
Commit 669cf98

Address comments
Signed-off-by: Mamta Singh <[email protected]>
1 parent 4753d9e commit 669cf98

7 files changed (+84, -39 lines)


QEfficient/cloud/finetune.py

Lines changed: 7 additions & 6 deletions
@@ -5,6 +5,7 @@
 #
 # -----------------------------------------------------------------------------
 
+import logging
 import random
 import warnings
 from typing import Any, Dict, Optional, Union
@@ -18,7 +19,7 @@
 import torch.utils.data
 from peft import PeftModel, get_peft_model
 from torch.optim.lr_scheduler import StepLR
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
 
 from QEfficient.finetune.configs.training import TrainConfig
 from QEfficient.finetune.utils.config_utils import (
@@ -33,16 +34,16 @@
 )
 from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
 from QEfficient.utils._utils import login_and_download_hf_lm
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import ft_logger as logger
 
 # Try importing QAIC-specific module, proceed without it if unavailable
 try:
     import torch_qaic  # noqa: F401
 except ImportError as e:
     logger.warning(f"{e}. Moving ahead without these qaic modules.")
 
+logger.setLevel(logging.INFO)
 
-from transformers import AutoModelForSequenceClassification
 
 # Suppress all warnings
 warnings.filterwarnings("ignore")
@@ -245,7 +246,7 @@ def setup_dataloaders(
     # )
     ##
     train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
-    logger.info("length of dataset_train", len(dataset_train))
+    logger.info(f"length of dataset_train = {len(dataset_train)}")
 
     # FIXME (Meet): Add custom data collator registration from the outside by the user.
     custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
@@ -260,7 +261,7 @@ def setup_dataloaders(
         pin_memory=True,
         **train_dl_kwargs,
    )
-    logger.info(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
+    logger.info(f"Num of Training Set Batches loaded = {len(train_dataloader)}")
 
     eval_dataloader = None
     if train_config.run_validation:
@@ -284,7 +285,7 @@ def setup_dataloaders(
                 f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
             )
         else:
-            logger.info(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+            logger.info(f"Num of Validation Set Batches loaded = {len(eval_dataloader)}")
 
     longest_seq_length, _ = get_longest_seq_length(
         torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
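
A minimal usage sketch (not part of the commit) of how the reworked logging in finetune.py is consumed: the shared ft_logger is imported under the local name logger and its level is raised to INFO, so messages reach both the console handler and finetune.log. The log messages below are purely illustrative.

import logging

from QEfficient.utils.logging_utils import ft_logger as logger  # console + finetune.log handlers

logger.setLevel(logging.INFO)
logger.info("Fine tuning started")    # emitted by both handlers, which are set to INFO
logger.debug("per-step details")      # filtered out, since the logger now sits at INFO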

QEfficient/finetune/dataset/custom_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 import importlib
 from pathlib import Path
 
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import ft_logger as logger
 
 
 def load_module_from_py_file(py_file: str) -> object:

QEfficient/finetune/dataset/grammar_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def __init__(self, tokenizer, csv_name=None, context_length=None):
            )
        except Exception as e:
            logger.error(
-                "Loading of grammar dataset failed! Please see [here](https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) for details on how to download the dataset."
+                "Loading of grammar dataset failed! Please check (https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) for details on how to download the dataset."
            )
            raise e

QEfficient/finetune/eval.py

Lines changed: 2 additions & 2 deletions
@@ -109,13 +109,13 @@ def main(**kwargs):
        pin_memory=True,
        **val_dl_kwargs,
    )
-    logger.info(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+    logger.info(f"Num of Validation Set Batches loaded = {len(eval_dataloader)}")
    if len(eval_dataloader) == 0:
        raise ValueError(
            f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
        )
    else:
-        logger.info(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+        logger.info(f"Num of Validation Set Batches loaded = {len(eval_dataloader)}")
 
    model.to(device)
    _ = evaluation(model, train_config, eval_dataloader, None, tokenizer, device)

QEfficient/finetune/utils/config_utils.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ def update_config(config, **kwargs):
                    raise ValueError(f"Config '{config_name}' does not have parameter: '{param_name}'")
        else:
            config_type = type(config).__name__
-            logger.warning(f"Unknown parameter '{k}' for config type '{config_type}'")
+            logger.debug(f"Unknown parameter '{k}' for config type '{config_type}'")
 
 
 def generate_peft_config(train_config: TrainConfig, peft_config_file: str = None, **kwargs) -> Any:
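
Illustrative only: the visible effect of demoting this message from warning to debug. Assuming update_config falls through to the branch above for a keyword that matches no config field (the field name below is made up), the mismatch is still recorded but no longer prints during a default INFO-level run.

from QEfficient.finetune.configs.training import TrainConfig
from QEfficient.finetune.utils.config_utils import update_config

cfg = TrainConfig()
update_config(cfg, not_a_real_field=1)  # now hits logger.debug(...): silent at INFO, visible at DEBUG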

QEfficient/finetune/utils/train_utils.py

Lines changed: 11 additions & 28 deletions
@@ -19,7 +19,7 @@
 from tqdm import tqdm
 
 from QEfficient.finetune.configs.training import TrainConfig
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import ft_logger as logger
 
 try:
     import torch_qaic  # noqa: F401
@@ -85,10 +85,7 @@ def train(
    device_type = device.split(":")[0]
 
    tensorboard_updates = None
-    if train_config.enable_ddp:
-        if local_rank == 0:
-            tensorboard_updates = SummaryWriter()
-    else:
+    if (not train_config.enable_ddp) or (train_config.enable_ddp and local_rank == 0):
        tensorboard_updates = SummaryWriter()
 
    if train_config.grad_scaler:
@@ -113,14 +110,9 @@
    # Start the training loop
    for epoch in range(train_config.num_epochs):
        if loss_0_counter.item() == train_config.convergence_counter:
-            if train_config.enable_ddp:
-                logger.info(
-                    f"Not proceeding with epoch {epoch + 1} on device {local_rank} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
-                )
-                break
-            else:
+            if (not train_config.enable_ddp) or (train_config.enable_ddp and local_rank == 0):
                logger.info(
-                    f"Not proceeding with epoch {epoch + 1} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
+                    f"Skipping epoch {epoch + 1} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
                )
                break
 
@@ -161,7 +153,7 @@
                if epoch == intermediate_epoch and step == 0:
                    total_train_steps += intermediate_step
                    logger.info(
-                        f"skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for them."
+                        f"Skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for it."
                    )
                if epoch == intermediate_epoch and step < intermediate_step:
                    total_train_steps += 1
@@ -221,10 +213,7 @@
                else:
                    loss_0_counter = torch.tensor([0]).to(device)
 
-            if train_config.enable_ddp:
-                if local_rank == 0:
-                    tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)
-            else:
+            if (not train_config.enable_ddp) or (train_config.enable_ddp and local_rank == 0):
                tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)
 
            if train_config.save_metrics:
@@ -275,16 +264,10 @@
                    val_step_metric,
                    val_metric,
                )
-            if train_config.enable_ddp:
-                if loss_0_counter.item() == train_config.convergence_counter:
-                    logger.info(
-                        f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps. Hence, stopping the fine tuning on device {local_rank}."
-                    )
-                    break
-            else:
+            if (not train_config.enable_ddp) or (train_config.enable_ddp and local_rank == 0):
                if loss_0_counter.item() == train_config.convergence_counter:
                    logger.info(
-                        f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps. Hence, stopping the fine tuning."
+                        f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps.Hence,stopping the fine tuning."
                    )
                    break
 
@@ -457,7 +440,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
    eval_metric = torch.exp(eval_epoch_loss)
 
    # Print evaluation metrics
-    logger.info(f" {eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
+    logger.info(f"{eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
 
    return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric
 
@@ -487,9 +470,9 @@ def print_model_size(model, config) -> None:
        model_name (str): Name of the model.
    """
 
-    logger.info(f"--> Model {config.model_name}")
+    logger.info(f"Model : {config.model_name}")
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    logger.info(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
+    logger.info(f"{config.model_name} has {total_params / 1e6} Million params\n")
 
 
 def save_to_json(
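
The same guard now appears at several points above. A hypothetical helper (not in the commit) makes the intent explicit: act on the single process when DDP is off, or only on rank 0 when it is on. Note that (not enable_ddp) or (enable_ddp and local_rank == 0) reduces to (not enable_ddp) or local_rank == 0.

def should_log(enable_ddp: bool, local_rank: int) -> bool:
    # Equivalent to the condition used throughout the diff:
    # (not enable_ddp) or (enable_ddp and local_rank == 0)
    return (not enable_ddp) or local_rank == 0

# Usage mirroring the new code path (names taken from train_utils.py):
# if should_log(train_config.enable_ddp, local_rank):
#     tensorboard_updates = SummaryWriter()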

QEfficient/utils/logging_utils.py

Lines changed: 61 additions & 0 deletions
@@ -56,3 +56,64 @@ def create_logger() -> logging.Logger:
 
 # Define the logger object that can be used for logging purposes throughout the module.
 logger = create_logger()
+
+
+def create_ft_logger(log_file="finetune.log") -> logging.Logger:
+    """
+    Creates a logger object with Colored QEffFormatter.
+    """
+    logger = logging.getLogger("QEfficient")
+
+    # create console handler and set level to debug
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.INFO)
+    ch.setFormatter(QEffFormatter())
+    logger.addHandler(ch)
+
+    # create file handler and set level to debug
+    fh = logging.FileHandler(log_file)
+    fh.setLevel(logging.INFO)
+    fh.setFormatter(QEffFormatter())
+    logger.addHandler(fh)
+
+    return logger
+
+
+# Define the logger object that can be used for logging purposes throughout the finetuning module.
+ft_logger = create_ft_logger()
+"""
+
+class FT_Logger:
+    def __init__(self, level=logging.INFO, log_file="finetune.log"):
+        self.logger = logging.getLogger("QEfficient")
+        self.logger.setLevel(level)
+        self.level = level
+
+        # Create handlers
+        self.file_handler = logging.FileHandler(log_file)
+        self.console_handler = logging.StreamHandler()
+
+        self.file_handler.setFormatter(QEffFormatter())
+        self.console_handler.setFormatter(QEffFormatter())
+
+        # Add handlers to the logger
+        self.logger.addHandler(self.file_handler)
+        self.logger.addHandler(self.console_handler)
+
+    def get_logger(self):
+        return self.logger
+
+    def raise_valueerror(self, message):
+        self.logger.error(message)
+        raise ValueError(message)
+
+    def raise_runtimeerror(self, message):
+        self.logger.error(message)
+        raise RuntimeError(message)
+
+    def raise_filenotfounderror(self, message):
+        self.logger.error(message)
+        raise FileNotFoundError(message)
+
+ft_logger = FT_Logger().get_logger()
+"""
