@@ -67,7 +67,6 @@ def __init__(self, config):
            tp.tril(tp.ones((config.block_size, config.block_size), dtype=config.dtype)),
            (1, 1, config.block_size, config.block_size),
        )
-        self.zeros = tp.zeros((1, 1, self.seq_len, self.seq_len), dtype=config.dtype)

    def __call__(self, x: tp.Tensor):
        B, T = x.shape[0:2]
@@ -87,11 +86,7 @@ def __call__(self, x: tp.Tensor):

        k_t = tp.transpose(k, -2, -1)
        att = (q @ k_t) * (1.0 / math.sqrt(self.embedding_size // self.num_heads))
-        att = tp.masked_fill(
-            att,
-            self.bias[:, :, :T, :T] == self.zeros[:, :, :T, :T],
-            float("-inf"),
-        )
+        att = tp.masked_fill(att, self.bias[:, :, :T, :T] == 0, float("-inf"))

        att = tp.softmax(att, dim=-1)

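For context on the first two hunks: the removed code materialized a full `self.zeros` tensor only to build the boolean causal mask, while the new line compares the lower-triangular `bias` buffer against the scalar `0` and lets broadcasting do the rest. A minimal NumPy sketch of why the two masks are identical (the sizes here are made up for illustration, and NumPy stands in for the `tp` ops):

```python
import numpy as np

block_size, T = 4, 3  # hypothetical sizes, for illustration only
ones = np.ones((block_size, block_size), dtype=np.float32)
bias = np.tril(ones).reshape(1, 1, block_size, block_size)          # mirrors tp.tril + tp.reshape
zeros = np.zeros((1, 1, block_size, block_size), dtype=np.float32)  # the buffer this change removes

# Old pattern: compare against a full zeros tensor kept as a module attribute.
mask_old = bias[:, :, :T, :T] == zeros[:, :, :T, :T]
# New pattern: compare against the scalar 0; broadcasting yields the same boolean mask.
mask_new = bias[:, :, :T, :T] == 0
assert np.array_equal(mask_old, mask_new)
```

Dropping the buffer also means the module no longer holds a `(1, 1, seq_len, seq_len)` tensor whose only job was to be compared against.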
@@ -135,18 +130,18 @@ def __call__(self, x):
class Transformer(tp.Module):
    def __init__(self, config):
        super().__init__()
+        self.seq_len = config.seq_len
        self.wte = tp.Embedding(config.vocab_size, config.embedding_size, dtype=config.dtype)
        self.wpe = tp.Embedding(config.block_size, config.embedding_size, dtype=config.dtype)
-        self.h = [Block(config) for _ in range(config.num_layers)]
+        self.h = tp.Sequential(*[Block(config) for _ in range(config.num_layers)])
        self.ln_f = tp.LayerNorm(config.embedding_size)
-        self.pos = tp.reshape(tp.arange(0, config.seq_len, dtype=tp.int32), (1, config.seq_len))

    def __call__(self, idx):
        tok_emb = self.wte(idx)  # token embeddings of shape (batch_size, seq_len, embedding_size)
-        pos_emb = self.wpe(self.pos[:, : idx.shape[1]])  # position embeddings of shape (seq_len, embedding_size)
+        pos = tp.unsqueeze(tp.arange(self.seq_len, dtype=tp.int32)[: idx.shape[1]], 0)
+        pos_emb = self.wpe(pos)  # position embeddings of shape (seq_len, embedding_size)
        x = tok_emb + pos_emb  # (batch_size, seq_len, embedding_size)
-        for block in self.h:
-            x = block(x)
+        x = self.h(x)
        x = tp.cast(self.ln_f(tp.cast(x, self.ln_f.dtype)), x.dtype)
        return x
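The last hunk makes two structural changes: the stack of `Block`s becomes a `tp.Sequential` container invoked as `self.h(x)`, and the position indices are built from `tp.arange` at call time instead of being cached as a reshaped `self.pos` attribute. A pure-Python sketch of what a Sequential-style container does with the unpacked modules (the `TinySequential` name and the toy callables are illustrative, not tripy's implementation):

```python
class TinySequential:
    """Illustrative stand-in for a Sequential container: it stores the modules
    passed to it and chains them in order when called, which is exactly what
    the removed `for block in self.h: x = block(x)` loop did by hand."""

    def __init__(self, *modules):
        self.modules = modules

    def __call__(self, x):
        for module in self.modules:
            x = module(x)
        return x


# Toy callables so the sketch runs end to end; the real modules are Block instances.
def double(x):
    return x * 2

def add_one(x):
    return x + 1

h = TinySequential(*[double, add_one])  # mirrors tp.Sequential(*[Block(config) for _ in range(...)])
assert h(3) == 7                        # (3 * 2) + 1, applied in order
```

For the positions, slicing `tp.arange(self.seq_len)` to the first `idx.shape[1]` entries and unsqueezing dim 0 yields the same `(1, T)` index tensor that the old cached `self.pos[:, :T]` slice produced.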