forked from aioz-ai/CFR_VQA
-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathclassifier.py
30 lines (28 loc) · 1023 Bytes
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""
This code is from Hengyuan Hu's repository.
https://github.com/hengyuan-hu/bottom-up-attention-vqa
"""
import torch
import numpy as np
import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm
import src.activation as act
from lxrt.modeling import BertLayerNorm, GeLU
class SimpleClassifier(nn.Module):
    """Two-layer MLP head that maps input features to output logits.

    Layers, in order: weight-normalised ``Linear(in_dim -> hid_dim)``, the
    configured activation, dropout, weight-normalised
    ``Linear(hid_dim -> out_dim)``. No softmax/sigmoid is applied, so
    ``forward`` returns raw logits.

    Args:
        in_dim: size of the last dimension of the input tensor.
        hid_dim: width of the hidden layer.
        out_dim: number of output logits.
        args: config object; this class reads ``args.activation``
            (``'relu'`` or ``'swish'``) and ``args.dropout`` (the dropout
            probability passed to ``nn.Dropout``).

    Raises:
        AssertionError: if ``args.activation`` is not a supported name.
    """

    def __init__(self, in_dim, hid_dim, out_dim, args):
        super().__init__()

        # Build only the requested activation instead of instantiating every
        # supported one up front.
        name = args.activation
        if name == 'relu':
            activation_func = nn.ReLU()
        elif name == 'swish':
            activation_func = act.Swish()
        else:
            # AssertionError is kept for backward compatibility with callers.
            raise AssertionError(f"{name} is not supported yet!")

        layers = [
            # dim=None: a single magnitude for the whole weight matrix.
            weight_norm(nn.Linear(in_dim, hid_dim), dim=None),
            activation_func,
            # Not in-place: nn.ReLU saves its *output* for the backward pass,
            # so modifying that output in place makes backward() fail with
            # "one of the variables needed for gradient computation has been
            # modified by an inplace operation".
            nn.Dropout(args.dropout),
            weight_norm(nn.Linear(hid_dim, out_dim), dim=None),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return logits for ``x``.

        ``x`` must have ``in_dim`` as its last dimension (``nn.Linear``
        acts on the last axis); the result has ``out_dim`` there instead.
        """
        logits = self.main(x)
        return logits