Skip to content

Commit

Permalink
minimal changes for finetuning support
Browse files Browse the repository at this point in the history
  • Loading branch information
francoishernandez committed Jan 29, 2025
1 parent 7c3232f commit f4a26f7
Show file tree
Hide file tree
Showing 13 changed files with 177 additions and 21 deletions.
6 changes: 3 additions & 3 deletions eole/config/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,13 +232,13 @@ def _validate_data(self):
logger.info(f"Missing transforms field for {cname} data, " f"set to default: {default_transforms}.")
corpus.transforms = default_transforms
# Check path
if corpus.path_src is None:
if corpus.path_src is None and corpus.path_txt is None:
raise ValueError(
f"Corpus {cname} src path is required."
f"Corpus {cname} `path_src` or `path_tgt` is required."
"tgt path is also required for non language"
" modeling tasks."
)
else:
elif corpus.path_src is not None:
self.__class__._validate_file(corpus.path_src, info=f"{cname}/path_src")
if corpus.path_tgt is None:
logger.debug("path_tgt is None, it should be set unless the task" " is language modeling")
Expand Down
4 changes: 4 additions & 0 deletions eole/config/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ class TrainConfig(LoggingConfig, MiscConfig, DataConfig, VocabConfig): # ModelC
def get_model_path(self):
return self.training.get_model_path()

@property
def data_type(self):
return self.training.data_type

@classmethod
def get_defaults(cls, architecture):
return cls(
Expand Down
4 changes: 3 additions & 1 deletion eole/encoders/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def __init__(self, model_config, running_config=None):
self.ln_pre = RMSNorm(model_config.hidden_size, eps=1e-5)
self.transformer_layers = torch.nn.ModuleList()
for _ in range(model_config.layers):
self.transformer_layers.append(TransformerEncoderLayer(model_config))
self.transformer_layers.append(
TransformerEncoderLayer(model_config, running_config=running_config)
)

head_dim = model_config.hidden_size // model_config.heads
assert head_dim % 2 == 0, "ROPE requires even head_dim"
Expand Down
5 changes: 4 additions & 1 deletion eole/inputters/dynamic_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,10 @@ def __iter__(self):
"cid_line_number",
"left_pad",
]:
tensor_batch[key] = tensor_batch[key].to(self.device)
if isinstance(tensor_batch[key], list):
tensor_batch[key] = [t.to(self.device) for t in tensor_batch[key]]
else:
tensor_batch[key] = tensor_batch[key].to(self.device)
yield (tensor_batch, bucket_idx)


Expand Down
43 changes: 32 additions & 11 deletions eole/inputters/text_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,17 @@ class ImageTextCorpus(object):
```
"""

def __init__(self, name, data):
def __init__(self, name, data, is_train=False):
self.id = name
self.data = data
self.is_train = is_train

def load(self, offset=0, stride=1):
def make_ex(item):
example = {"text": item["text"], "images": item["images"]}
example = {
"text": item["text"],
"images": item.get("images", {})
}
return example

if isinstance(self.data, list):
Expand Down Expand Up @@ -253,25 +257,38 @@ def get_corpora(config, task=CorpusTask.TRAIN, src=None, tgt=None, align=None):
corpus_dict.path_sco,
corpus_dict.path_align,
)
else:
elif config.data_type == "text":
corpora_dict[corpus_id] = BlockwiseCorpus(
corpus_id,
corpus_dict.path_txt,
block_size=8192, # number of characters
)
elif config.data_type == "image":
corpora_dict[corpus_id] = ImageTextCorpus(
corpus_id,
corpus_dict.path_txt,
is_train=True,
)
elif task == CorpusTask.VALID:
if CorpusName.VALID in config.data.keys():
if config.data[CorpusName.VALID].path_tgt is None:
path_tgt = config.data[CorpusName.VALID].path_src
else:
path_tgt = config.data[CorpusName.VALID].path_tgt
corpora_dict[CorpusName.VALID] = ParallelCorpus(
CorpusName.VALID,
config.data[CorpusName.VALID].path_src,
path_tgt if tgt is None else None,
None,
config.data[CorpusName.VALID].path_align,
)
if config.data_type == "text":
corpora_dict[CorpusName.VALID] = ParallelCorpus(
CorpusName.VALID,
config.data[CorpusName.VALID].path_src,
path_tgt if tgt is None else None,
None,
config.data[CorpusName.VALID].path_align,
)
elif config.data_type == "image":
corpora_dict[CorpusName.VALID] = ImageTextCorpus(
CorpusName.VALID,
config.data[CorpusName.VALID].path_txt,
is_train=True,
)
else:
return None
else:
Expand All @@ -284,7 +301,7 @@ def get_corpora(config, task=CorpusTask.TRAIN, src=None, tgt=None, align=None):
)
elif config.data_type == "image":
corpora_dict[CorpusName.INFER] = ImageTextCorpus(
CorpusName.INFER, src # maybe homogenize to some better name
CorpusName.INFER, src, is_train=False # maybe homogenize to some better name
)
return corpora_dict

Expand Down Expand Up @@ -355,6 +372,7 @@ def __init__(
skip_empty_level="warning",
stride=1,
offset=0,
is_train=False,
):
self.cid = corpus.id
self.corpus = corpus
Expand All @@ -364,6 +382,7 @@ def __init__(
self.skip_empty_level = skip_empty_level
self.stride = stride
self.offset = offset
self.is_train = is_train

def _process(self, stream):
for i, example in enumerate(stream):
Expand All @@ -379,6 +398,7 @@ def _process(self, stream):
line_number = i * self.stride + self.offset
example = {
"src": text,
"tgt": text if self.is_train else None,
"images": {k: v["image"] for k, v in processed_images.items()},
"cid": self.cid,
"cid_line_number": line_number,
Expand Down Expand Up @@ -408,6 +428,7 @@ def build_corpora_iters(corpora, transforms, corpora_info, skip_empty_level="war
skip_empty_level=skip_empty_level,
stride=stride,
offset=offset,
is_train=corpus.is_train,
)
corpora_iters[c_id] = corpus_iter
return corpora_iters
Expand Down
2 changes: 2 additions & 0 deletions eole/inputters/text_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def clean_example(maybe_example):
maybe_example["tgt"] = {"tgt": " ".join(maybe_example["tgt"])}
if "align" in maybe_example:
maybe_example["align"] = " ".join(maybe_example["align"])
if "sco" not in maybe_example:
maybe_example["sco"] = 1
return maybe_example


Expand Down
8 changes: 7 additions & 1 deletion eole/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,7 @@ class VisionEncoderDecoderModel(BaseModel):

def __init__(self, **kwargs):
super(VisionEncoderDecoderModel, self).__init__(**kwargs)
self.tgt_shift = 1
self.image_token_id = kwargs.get("image_token_id", None)
if self.encoder is None or self.decoder is None:
raise ValueError("A EncoderDecoderModel requires both an Encoder and a Decoder")
Expand Down Expand Up @@ -942,6 +943,8 @@ def embed_vision_language_features(self, src, images):
text_locations = src != self.image_token_id
image_locations = src == self.image_token_id
text_features = self.tgt_emb(src[text_locations].view(batch_size, -1))
if len(images) == 0:
return text_features
encoded_images = self.encoder(images)
image_features = self.adapter(encoded_images)

Expand All @@ -959,7 +962,8 @@ def embed_vision_language_features(self, src, images):
device=text_features.device,
)
combined_features[text_locations, :] = text_features
combined_features[image_locations, :] = image_features
if len(images) > 0:
combined_features[image_locations, :] = image_features

return combined_features

Expand All @@ -972,12 +976,14 @@ def forward(self, src, tgt, src_len, bptt=False, with_align=False, images=[]):
emb = self.embed_vision_language_features(src, images)
pad_idx = self.tgt_emb.word_padding_idx
pad_mask = src.eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt]
position_embeddings = self.rope.update(emb.size(1), step=None)
dec_out, attns = self.decoder(
emb,
enc_out=None,
src_len=src_len,
with_align=with_align,
tgt_pad_mask=pad_mask,
position_embeddings=position_embeddings,
)

return dec_out, attns, None
Expand Down
1 change: 0 additions & 1 deletion eole/predict/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ def _score_target(self, batch, enc_out, src_len):
log_probs, attn = self._decode_and_generate(
src,
None,
batch,
src_len=src_len,
)

Expand Down
5 changes: 4 additions & 1 deletion eole/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,10 +459,13 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats, repor
report_stats.n_src_words += src_len.sum().item()
total_stats.n_src_words += src_len.sum().item()
tgt = batch["tgt"]
kwargs = {}
if "images" in batch.keys():
kwargs["images"] = batch["images"]

try:
with get_autocast(enabled=self.optim.amp):
model_out, attns, estim = self.model(src, tgt, src_len, with_align=self.with_align)
model_out, attns, estim = self.model(src, tgt, src_len, with_align=self.with_align, **kwargs)
if self.zero_out_prompt_loss:
# The loss of the prompt will be set to zero.
batch = self.train_loss.ignore_prompt(batch)
Expand Down
5 changes: 4 additions & 1 deletion eole/transforms/tokenize_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
raise ValueError(f"Unsupported src type: {type(example['src'])}")
example["src_ids"] = src_tokens
if example.get("tgt", None) is not None:
tgt_tokens = self.tokenize_string(" ".join(example["tgt"]), side="tgt", is_train=is_train)
if isinstance(example["tgt"], str):
tgt_tokens = self.tokenize_string(example["tgt"], side="tgt", is_train=is_train)
elif isinstance(example["tgt"], list):
tgt_tokens = self.tokenize_string(" ".join(example["tgt"]), side="tgt", is_train=is_train)
example["tgt_ids"] = tgt_tokens
return example

Expand Down
6 changes: 5 additions & 1 deletion recipes/pixtral/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,8 @@ eole convert HF --model_dir mistral-community/pixtral-12b --output ./pixtral-12b
python3 test_inference.py
```

There are several examples in the test script (taken from pixtral blog posts). A single one is activated by default, but you can uncomment the others to test the various cases.
There are several examples in the test script (taken from pixtral blog posts). A single one is activated by default, but you can uncomment the others to test the various cases.

## Finetuning

Finetuning is untested for now. Feel free to try it out and fix any arising issues.
89 changes: 89 additions & 0 deletions recipes/pixtral/finetune.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# General settings
seed: 1234
share_vocab: true
save_data: "./finetune/pixtral-finetune"
src_vocab: "./pixtral-12b/vocab.txt"
src_vocab_size: 132000
tgt_vocab_size: 132000

overwrite: true

report_every: 10

# datasets
data:
test_data:
path_txt: "./train_data.json"

# valid:
# path_src: "./data/valid.txt"

skip_empty_level: silent

transforms_configs:
huggingface_tokenize:
max_length: 4096

training:
data_type: "image"
# GPU dispatching
world_size: 1
gpu_ranks: [0]
# 2 GPU
# world_size: 2
# gpu_ranks: [0, 1]
# parallel_mode: tensor_parallel
dropout_steps: [0]
dropout: [0.0]
attention_dropout: [0.0]
# Batching
# bucket_size: 32768
bucket_size: 10
num_workers: 0
batch_type: "sents"
batch_size: 1
valid_batch_size: 1
batch_size_multiple: 1

# Optimization
# compute_dtype: "fp16"
# apex_opt_level: ""
# optim: "fusedadam"
compute_dtype: "bf16"
optim: "adam"
learning_rate: 0.0001
warmup_steps: 100
decay_method: "none"
#learning_rate_decay: 0.98
#start_decay_steps: 100
#decay_steps: 10
adam_beta2: 0.998
accum_count: [8]
accum_steps: [0]
max_grad_norm: 0
label_smoothing: 0.0
param_init_method: xavier_uniform
normalization: "tokens"

# folders
train_from: "./pixtral-12b"
model_path: "./finetune/pixtral-finetuned"
keep_checkpoint: 10
save_checkpoint_steps: 100

train_steps: 1000
valid_steps: 100

# 4/8bit
quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear']
quant_type: "bnb_NF4"

# LoRa
lora_layers: ['linear_values', 'linear_query', 'linear_keys', 'final_linear']
lora_rank: 2
lora_dropout: 0.05
lora_alpha: 8
lora_embedding: false

# Chekpointing
#use_ckpting: ['ffn', 'lora']
20 changes: 20 additions & 0 deletions recipes/pixtral/train_data.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[
{
"text": "List the top 5 countries in Europe with the highest GDP\n{image1}\nGermany\nUnited Kingdom\nFrance\nItaly\nSpain",
"images": {
"image1": "./test_data/gdp.png"
}
},
{
"text": "When did things start to go wrong for dark dragon?\n{image1}\nAround step 10k, then worsened around step 20k.",
"images": {
"image1": "./test_data/loss_curve.jpg"
}
},
{
"text": "Is this person really big, or is this building just super small?\n{image1}\nNone, it's a trick of perspective.",
"images": {
"image1": "./test_data/pisa_2.jpg"
}
}
]

0 comments on commit f4a26f7

Please sign in to comment.