Skip to content

Commit c045834

Browse files
authored (author name lost in page extraction)
Merge pull request #581 from ftgreat/master
Update AltCLIP and vit_cifar100
2 parents f138a08 + 2f61336 commit c045834

File tree

5 files changed

+16
-4
lines changed

5 files changed

+16
-4
lines changed

examples/AltCLIP/altclip_finetuning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1313

14-
dataset_root = "./clip_benchmark_datasets"
14+
dataset_root = "./data"
1515
dataset_name = "cifar10"
1616

1717
batch_size = 4
@@ -62,4 +62,4 @@ def cifar10_collate_fn(batch):
6262
}
6363

6464
if __name__ == "__main__":
65-
trainer.train(model=model, train_dataset=dataset, collate_fn=cifar10_collate_fn)
65+
trainer.train(model=model, train_dataset=dataset, collate_fn=cifar10_collate_fn)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pillow
2+
antlr4
3+
pytorch-lightning==1.9.0
4+
taming-transformers==0.0.6
5+
transformers==4.30.0
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pillow
2+
antlr4
3+
pytorch-lightning==1.9.0
4+
taming-transformers==0.0.6
5+
transformers==4.30.0

examples/vit_cifar100/train_single_gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
weight_decay=1e-5,
2222
epochs=n_epochs,
2323
log_interval=100,
24-
eval_interval=1000,
24+
eval_interval=10000,
2525
load_dir=None,
2626
pytorch_device=device,
2727
save_dir="checkpoints_vit_cifar100_single_gpu",

flagai/model/vision/vit.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,10 +225,12 @@ def __init__(
225225
norm_layer: (nn.Module): normalization layer
226226
act_layer: (nn.Module): MLP activation layer
227227
"""
228+
config = config.json_config
228229
super().__init__(config)
229230
embed_layer=PatchEmbed
230231
block_fn=Block
231-
config = config.json_config
232+
if 'use_cache' in config:
233+
del config['use_cache']
232234
vit_config = VitConfig(**config)
233235
vit_config.num_classes = num_classes
234236
# config = vit_config

0 commit comments

Comments (0)