
Commit 4e50ffb

misc fixes (#192)
* misc fixes
* add datasets dependency
vince62s authored Jan 21, 2025
1 parent d2992b7 commit 4e50ffb
Showing 4 changed files with 9 additions and 9 deletions.
5 changes: 2 additions & 3 deletions eole/bin/model/lora_weights.py
@@ -82,7 +82,6 @@ def run(cls, args):
             # use compute_dtype from lora finetuning
             new_config.training.compute_dtype = config.training.compute_dtype
         elif args.action == "concat":
-            model.half()  # We keep FP16 for all
             optim = lora_checkpoint["optim"]
             model_state_dict = model.state_dict()
             new_config = config
@@ -108,11 +107,11 @@ def run(cls, args):
             shard_dict = {}
             f.append(safe_open(shard, framework="pt", device="cpu"))
             for key in f[i].keys():
-                shard_dict[key] = model_state_dict[key]
+                shard_dict[key] = model_state_dict[key].to(running_config.storage_dtype)
             if i == 0:
                 for key in model_state_dict.keys():
                     if "estimator" in key:
-                        shard_dict[key] = model_state_dict[key]
+                        shard_dict[key] = model_state_dict[key].to(running_config.storage_dtype)
             logger.info("saving shard" + args.output + "/model.{:02d}.safetensors".format(i))
             save_file(
                 shard_dict,
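
The hunk above casts every tensor to the configured storage dtype as each safetensors shard is written, instead of forcing FP16 up front with model.half(). A minimal standalone sketch of that save pattern (the dtype value and tensor names are illustrative stand-ins, not the actual eole config):

    import torch
    from safetensors.torch import save_file

    # Stand-in for running_config.storage_dtype (assumption for this example).
    storage_dtype = torch.float16
    model_state_dict = {"decoder.layer_0.weight": torch.randn(4, 4, dtype=torch.float32)}

    # Cast each tensor to the storage dtype only at save time.
    shard_dict = {key: t.to(storage_dtype) for key, t in model_state_dict.items()}
    save_file(shard_dict, "model.00.safetensors")
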
9 changes: 5 additions & 4 deletions eole/modules/multi_headed_attn.py
@@ -405,6 +405,7 @@ def forward(
             not self.flash
             or self.position_encoding_type in [PositionEncodingType.Relative, PositionEncodingType.Alibi]
             or query.dtype not in [torch.float16, torch.bfloat16]  # to match with flash
+            or query.device == torch.device("cpu")
         ):
             key, value, query = self._prepare_inputs_w_cache(
                 query,
@@ -419,8 +420,8 @@ def forward(
             cache_len = self._expand_cache(32, step, key)
             if position_embeddings is not None:
                 rotdim = self.rotary_dim // 2
-                cos = position_embeddings[0][:, :rotdim].to(query.dtype)
-                sin = position_embeddings[1][:, :rotdim].to(query.dtype)
+                cos = position_embeddings[0][:, :rotdim].to(query.dtype).contiguous()
+                sin = position_embeddings[1][:, :rotdim].to(query.dtype).contiguous()
             else:
                 cos, sin = None, None
             # restore initial tgt_pad_mask - migth be better to store it instead.
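
With this change the .contiguous() calls move from the flash-attention call site (next hunk) to where cos and sin are created, so the sliced rotary tables are densified once per step. A rough standalone illustration of the slicing involved, with made-up shapes:

    import torch

    rotary_dim = 64
    # Made-up precomputed rotary tables of shape (max_len, rotary_dim).
    position_embeddings = (torch.randn(128, rotary_dim), torch.randn(128, rotary_dim))

    rotdim = rotary_dim // 2
    # [:, :rotdim] is a strided view of the table; .contiguous() after the
    # dtype cast guarantees a densely laid out tensor for the fused kernel.
    cos = position_embeddings[0][:, :rotdim].to(torch.float16).contiguous()
    sin = position_embeddings[1][:, :rotdim].to(torch.float16).contiguous()
    print(cos.shape, cos.is_contiguous())  # torch.Size([128, 32]) True
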
@@ -435,8 +436,8 @@ def forward(
                 self.layer_cache[1]["values"].transpose(1, 2),
                 key.transpose(1, 2),
                 value.transpose(1, 2),
-                rotary_cos=cos.contiguous(),
-                rotary_sin=sin.contiguous(),
+                rotary_cos=cos,
+                rotary_sin=sin,
                 cache_seqlens=cache_len - 1,
                 cache_leftpad=cache_leftpad,
                 causal=step == 0,
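
The device check added in the first hunk keeps CPU tensors off the flash path, which only handles FP16/BF16 inputs on CUDA devices. A small sketch of that kind of dispatch condition (the function and variable names are invented for the example):

    import torch

    def can_use_flash(query: torch.Tensor, flash_available: bool) -> bool:
        # Flash attention kernels need half-precision tensors on a CUDA device;
        # anything else falls back to the standard attention path.
        return (
            flash_available
            and query.dtype in (torch.float16, torch.bfloat16)
            and query.device.type == "cuda"
        )

    q = torch.randn(2, 8, 16, 64)  # float32 on CPU -> standard path
    print(can_use_flash(q, flash_available=True))  # False
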
3 changes: 1 addition & 2 deletions requirements.opt.txt
@@ -4,5 +4,4 @@ scipy
 bitsandbytes>=0.41.2
 spacy
 gradio
-autoawq
-pyarrow
+pyarrow
1 change: 1 addition & 0 deletions setup.py
@@ -23,6 +23,7 @@
         "fastapi",
         "fasttext-wheel",
         "huggingface_hub",
+        "datasets",
         "numpy<2.0",
         "pandas",
         "protobuf==3.20.1",
