diff --git a/classification/README.md b/classification/README.md index de75387..3cb7fd8 100644 --- a/classification/README.md +++ b/classification/README.md @@ -68,10 +68,11 @@ smaller models are Resnet 26, 20, 14 and 8, e.g. ## ResNet-101 on ImageNet-1K +Note: Specify the path of your imagenet dataset using the `--dataset-root` argument, or change the path in the `main_imagenet.py` file. ### Evaluate a pretrained sparse DynConv network - python main_imagenet.py --batchsize 64 -r exp/imagenet/resnet101/sparse03/checkpoint_best.pth --budget 0.3 -e + python main_imagenet.py -r exp/imagenet/resnet101/sparse03/checkpoint_best.pth --budget 0.3 -e should result in @@ -79,19 +80,19 @@ should result in >\* FLOPS (multiply-accumulates, MACs) per image: 2997.180928 MMac use the `--plot_ponder` flag to visualize the ponder cost maps (computation heatmaps): - python main_imagenet.py --batchsize 64 -r exp/imagenet/resnet101/sparse03/checkpoint_best.pth --budget 0.3 -e --plot_ponder + python main_imagenet.py -r exp/imagenet/resnet101/sparse03/checkpoint_best.pth --budget 0.3 -e --plot_ponder Likewise: - python main_imagenet.py --batchsize 64 -r exp/imagenet/resnet101/sparse05/checkpoint_best.pth --budget 0.5 -e + python main_imagenet.py -r exp/imagenet/resnet101/sparse05/checkpoint_best.pth --budget 0.5 -e - python main_imagenet.py --batchsize 64 -r exp/imagenet/resnet101/sparse07/checkpoint_best.pth --budget 0.7 -e + python main_imagenet.py -r exp/imagenet/resnet101/sparse07/checkpoint_best.pth --budget 0.7 -e - python main_imagenet.py --batchsize 64 -r exp/imagenet/resnet101/sparse08/checkpoint_best.pth --budget 0.8 -e + python main_imagenet.py -r exp/imagenet/resnet101/sparse08/checkpoint_best.pth --budget 0.8 -e ### Evaluate a pretrained baseline - python main_imagenet.py --batchsize 64 -r exp/imagenet/resnet101/base/checkpoint_best.pth -e + python main_imagenet.py -r exp/imagenet/resnet101/base/checkpoint_best.pth -e diff --git a/classification/main_imagenet.py 
b/classification/main_imagenet.py index 798043a..6ffb8a6 100644 --- a/classification/main_imagenet.py +++ b/classification/main_imagenet.py @@ -24,7 +24,7 @@ def main(): parser = argparse.ArgumentParser(description='PyTorch ImageNet Training with sparse masks') - parser.add_argument('--lr', default=0.1, type=float, help='learning rate') + parser.add_argument('--lr', default=0.025, type=float, help='learning rate') parser.add_argument('--lr_decay', default=[30,60,90], nargs='+', type=int, help='learning rate decay epochs') parser.add_argument('--momentum', default=0.9, type=float, help='momentum') parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay')