Skip to content

Commit 6cf697e

Browse files
authored
[Update] Create dinov2.py (#210)
DINOv2
1 parent 0833481 commit 6cf697e

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

Diff for: semilearn/nets/dinov2.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import torch
2+
import torch.nn as nn
3+
from transformers import Dinov2Model, Dinov2PreTrainedModel
4+
import os
5+
6+
class CustomDINONormModel(nn.Module):
7+
def __init__(self, name, num_classes=8):
8+
super(CustomDINONormModel, self).__init__()
9+
self.dino_model = Dinov2Model.from_pretrained(name)
10+
self.classifier = nn.Sequential(*[
11+
nn.Linear(1024, 256),
12+
nn.LayerNorm(256),
13+
nn.Linear(256, 128),
14+
nn.ReLU(),
15+
nn.Linear(128, num_classes),
16+
])
17+
18+
def forward(self, x, only_fc=False, only_feat=False, return_embed=False, **kwargs):
19+
"""
20+
Args:
21+
x: input tensor, depends on only_fc and only_feat flag
22+
only_fc: only use classifier, input should be features before classifier
23+
only_feat: only return pooled features
24+
return_embed: return word embedding, used for vat
25+
"""
26+
# Extract features using DinoV2 model
27+
if return_embed:
28+
embed = self.dino_model(x)
29+
return embed
30+
31+
out_dict = self.dino_model(x, output_hidden_states=True, return_dict=True)
32+
last_hidden_state = out_dict['last_hidden_state']
33+
pooled_output = torch.mean(last_hidden_state, 1) # Perform mean pooling
34+
35+
if only_fc:
36+
logits = self.classifier(pooled_output)
37+
return logits
38+
39+
if only_feat:
40+
return pooled_output
41+
42+
logits = self.classifier(pooled_output)
43+
result_dict = {'logits': logits, 'feat': pooled_output}
44+
return result_dict
45+
46+
47+
def group_matcher(self, coarse=False, prefix=''):
48+
matcher = dict(stem=r'^{}dino_model.embeddings'.format(prefix), blocks=r'^{}dino_model.encoder.layer.(\d+)'.format(prefix))
49+
return matcher
50+
51+
def no_weight_decay(self):
52+
return []
53+
54+
55+
56+
def dinov2_vitl14(pretrained=True, pretrained_path=None, **kwargs):
57+
model = CustomDINONormModel(name='facebookresearch/dinov2_vitl14', **kwargs)
58+
return model
59+
60+
61+
def dinov2_vitb14(pretrained=True, pretrained_path=None, **kwargs):
62+
model = CustomDINONormModel(name='facebookresearch/dinov2_vitb14', **kwargs)
63+
return model

0 commit comments

Comments
 (0)