Commit 2de4c6f

Update logging_utils and log for zero rank
Signed-off-by: Mamta Singh <[email protected]>
1 parent 7bf2b24 commit 2de4c6f

File tree: 4 files changed, +61 -89 lines

QEfficient/cloud/finetune.py (14 additions, 12 deletions)

@@ -34,13 +34,13 @@
 )
 from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
 from QEfficient.utils._utils import login_and_download_hf_lm
-from QEfficient.utils.logging_utils import ft_logger as logger
+from QEfficient.utils.logging_utils import logger

 # Try importing QAIC-specific module, proceed without it if unavailable
 try:
     import torch_qaic  # noqa: F401
 except ImportError as e:
-    logger.warning(f"{e}. Moving ahead without these qaic modules.")
+    logger.log_rank_zero(f"{e}. Moving ahead without these qaic modules.")

 logger.setLevel(logging.INFO)

@@ -121,7 +121,7 @@ def load_model_and_tokenizer(
         )

         if not hasattr(model, "base_model_prefix"):
-            raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
+            logger.raise_runtimeerror("Given huggingface model does not have 'base_model_prefix' attribute.")

         for param in getattr(model, model.base_model_prefix).parameters():
             param.requires_grad = False
@@ -146,7 +146,7 @@ def load_model_and_tokenizer(
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        logger.warning("Resizing the embedding matrix to match the tokenizer vocab size.")
+        logger.log_rank_zero("Resizing the embedding matrix to match the tokenizer vocab size.", logger.WARNING)
         model.resize_token_embeddings(len(tokenizer))

     # FIXME (Meet): Cover below line inside the logger once it is implemented.
@@ -162,7 +162,9 @@ def load_model_and_tokenizer(
         if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
         else:
-            raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")
+            logger.raise_runtimeerror(
+                "Given model doesn't support gradient checkpointing. Please disable it and run it."
+            )

     model = apply_peft(model, train_config, peft_config_file, **kwargs)

@@ -222,7 +224,7 @@ def setup_dataloaders(
         - Length of longest sequence in the dataset.

     Raises:
-        ValueError: If validation is enabled but the validation set is too small.
+        RuntimeError: If validation is enabled but the validation set is too small.

     Notes:
         - Applies a custom data collator if provided by get_custom_data_collator.
@@ -246,12 +248,12 @@ def setup_dataloaders(
     # )
     ##
     train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
-    logger.info(f"length of dataset_train = {len(dataset_train)}")
+    logger.log_rank_zero(f"Length of dataset_train = {len(dataset_train)}")

     # FIXME (Meet): Add custom data collator registration from the outside by the user.
     custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
     if custom_data_collator:
-        logger.info("custom_data_collator is used")
+        logger.log_rank_zero("Custom_data_collator is used")
         train_dl_kwargs["collate_fn"] = custom_data_collator

     # Create DataLoaders for the training and validation dataset
@@ -261,7 +263,7 @@ def setup_dataloaders(
         pin_memory=True,
         **train_dl_kwargs,
     )
-    logger.info(f"Num of Training Set Batches loaded = {len(train_dataloader)}")
+    logger.log_rank_zero(f"Number of Training Set Batches loaded = {len(train_dataloader)}")

     eval_dataloader = None
     if train_config.run_validation:
@@ -281,11 +283,11 @@ def setup_dataloaders(
             **val_dl_kwargs,
         )
         if len(eval_dataloader) == 0:
-            raise ValueError(
+            logger.raise_runtimeerror(
                 f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
             )
         else:
-            logger.info(f"Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+            logger.log_rank_zero(f"Number of Validation Set Batches loaded = {len(eval_dataloader)}")

     longest_seq_length, _ = get_longest_seq_length(
         torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
@@ -329,7 +331,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:

     # Create DataLoaders for the training and validation dataset
     train_dataloader, eval_dataloader, longest_seq_length = setup_dataloaders(train_config, dataset_config, tokenizer)
-    logger.info(
+    logger.log_rank_zero(
         f"The longest sequence length in the train data is {longest_seq_length}, "
         f"passed context length is {train_config.context_length} and overall model's context length is "
         f"{model.config.max_position_embeddings}"

QEfficient/finetune/dataset/custom_dataset.py (1 addition, 1 deletion)

@@ -8,7 +8,7 @@
 import importlib
 from pathlib import Path

-from QEfficient.utils.logging_utils import ft_logger as logger
+from QEfficient.utils.logging_utils import logger


 def load_module_from_py_file(py_file: str) -> object:

QEfficient/finetune/utils/train_utils.py (13 additions, 15 deletions)

@@ -19,7 +19,7 @@
 from tqdm import tqdm

 from QEfficient.finetune.configs.training import TrainConfig
-from QEfficient.utils.logging_utils import ft_logger as logger
+from QEfficient.utils.logging_utils import logger

 try:
     import torch_qaic  # noqa: F401
@@ -28,7 +28,7 @@
     import torch_qaic.utils as qaic_utils  # noqa: F401
     from torch.qaic.amp import GradScaler as QAicGradScaler
 except ImportError as e:
-    logger.warning(f"{e}. Moving ahead without these qaic modules.")
+    logger.log_rank_zero(f"{e}. Moving ahead without these qaic modules.")

 from torch.amp import GradScaler

@@ -111,21 +111,21 @@ def train(
     for epoch in range(train_config.num_epochs):
         if loss_0_counter.item() == train_config.convergence_counter:
             if (not train_config.enable_ddp) or (train_config.enable_ddp and local_rank == 0):
-                logger.info(
+                logger.log_rank_zero(
                     f"Skipping epoch {epoch + 1} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
                 )
                 break

         if train_config.use_peft and train_config.from_peft_checkpoint:
             intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
             if epoch < intermediate_epoch:
-                logger.info(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
+                logger.log_rank_zero(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
                 # to bring the count of train_step in sync with where it left off
                 total_train_steps += len(train_dataloader)
                 continue

-        logger.info(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
-        logger.info(f"train_config.max_train_step: {train_config.max_train_step}")
+        logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
+        logger.log_rank_zero(f"train_config.max_train_step: {train_config.max_train_step}")
         # stop when the maximum number of training steps is reached
         if max_steps_reached:
             break
@@ -152,7 +152,7 @@ def train(
                 # to bring the count of train_step in sync with where it left off
                 if epoch == intermediate_epoch and step == 0:
                     total_train_steps += intermediate_step
-                    logger.info(
+                    logger.log_rank_zero(
                         f"Skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for it."
                     )
                 if epoch == intermediate_epoch and step < intermediate_step:
@@ -266,7 +266,7 @@ def train(
                 )
             if (not train_config.enable_ddp) or (train_config.enable_ddp and local_rank == 0):
                 if loss_0_counter.item() == train_config.convergence_counter:
-                    logger.info(
+                    logger.log_rank_zero(
                         f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps.Hence,stopping the fine tuning."
                     )
                     break
@@ -328,15 +328,15 @@ def train(
         if train_config.run_validation:
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
-                logger.info(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
+                logger.log_rank_zero(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
             val_loss.append(float(eval_epoch_loss))
             val_metric.append(float(eval_metric))
         if train_config.task_type == "seq_classification":
-            logger.info(
+            logger.log_rank_zero(
                 f"Epoch {epoch + 1}: train_acc={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
             )
         else:
-            logger.info(
+            logger.log_rank_zero(
                 f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
             )

@@ -440,7 +440,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     eval_metric = torch.exp(eval_epoch_loss)

     # Print evaluation metrics
-    logger.info(f"{eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
+    logger.log_rank_zero(f"{eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")

     return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric

@@ -469,10 +469,8 @@ def print_model_size(model, config) -> None:
         model: The PyTorch model.
         model_name (str): Name of the model.
     """
-
-    logger.info(f"Model : {config.model_name}")
     total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    logger.info(f"{config.model_name} has {total_params / 1e6} Million params\n")
+    logger.log_rank_zero(f"{config.model_name} has {total_params / 1e6} Million params\n")


 def save_to_json(
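
Under DDP, each of the training-loop messages above was previously printed once per process; with log_rank_zero they appear once, from rank 0 only. The standalone sketch below mirrors the gating logic of CustomLogger.log_rank_zero without importing QEfficient; it assumes a CPU host, the gloo backend, and a launch such as torchrun --nproc_per_node=2 demo_log.py (the script name and logger name are hypothetical).

# demo_log.py -- standalone illustration of rank-zero gating; mirrors the logic of
# CustomLogger.log_rank_zero from this commit but does not import QEfficient.
import logging

import torch.distributed as dist

logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(name)s - %(message)s")
logger = logging.getLogger("demo")


def log_rank_zero(msg: str, level: int = logging.INFO) -> None:
    # Fall back to rank 0 when torch.distributed is not initialized (single-process runs).
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    if rank != 0:
        return
    logger.log(level, msg)


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # torchrun supplies RANK/WORLD_SIZE/MASTER_ADDR
    logger.info("hello from rank %d", dist.get_rank())  # printed by every rank
    log_rank_zero("Starting epoch 1/3")  # printed once, by rank 0 only
    dist.destroy_process_group()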

QEfficient/utils/logging_utils.py (33 additions, 61 deletions)

@@ -6,6 +6,12 @@
 # -----------------------------------------------------------------------------

 import logging
+import os
+from datetime import datetime
+
+import torch.distributed as dist
+
+from QEfficient.utils.constants import ROOT_DIR


 class QEffFormatter(logging.Formatter):
@@ -38,82 +44,48 @@ def format(self, record):
         return formatter.format(record)


-def create_logger() -> logging.Logger:
+def create_logger(level=logging.INFO, dump_logs=True) -> logging.Logger:
     """
     Creates a logger object with Colored QEffFormatter.
     """
     logger = logging.getLogger("QEfficient")

-    # create console handler and set level to debug
+    # create console handler and set level
     ch = logging.StreamHandler()
-    ch.setLevel(logging.INFO)
-    # define formatter
+    ch.setLevel(level)
     ch.setFormatter(QEffFormatter())
-
     logger.addHandler(ch)
-    return logger
-
-
-# Define the logger object that can be used for logging purposes throughout the module.
-logger = create_logger()
-
-
-def create_ft_logger(log_file="finetune.log") -> logging.Logger:
-    """
-    Creates a logger object with Colored QEffFormatter.
-    """
-    logger = logging.getLogger("QEfficient")

-    # create console handler and set level to debug
-    ch = logging.StreamHandler()
-    ch.setLevel(logging.INFO)
-    ch.setFormatter(QEffFormatter())
-    logger.addHandler(ch)
+    if dump_logs:
+        logs_path = os.path.join(ROOT_DIR, "logs")
+        if not os.path.exists(logs_path):
+            os.makedirs(logs_path, exist_ok=True)
+        file_name = f"log-file-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + ".txt"
+        log_file = os.path.join(logs_path, file_name)

-    # create file handler and set level to debug
-    fh = logging.FileHandler(log_file)
-    fh.setLevel(logging.INFO)
-    fh.setFormatter(QEffFormatter())
-    logger.addHandler(fh)
+        # create file handler and set level
+        fh = logging.FileHandler(log_file)
+        fh.setLevel(level)
+        formatter = logging.Formatter("%(levelname)s - %(name)s - %(message)s")
+        fh.setFormatter(formatter)
+        logger.addHandler(fh)

     return logger


-# Define the logger object that can be used for logging purposes throughout the finetuning module.
-ft_logger = create_ft_logger()
-"""
-
-class FT_Logger:
-    def __init__(self, level=logging.INFO, log_file="finetune.log"):
-        self.logger = logging.getLogger("QEfficient")
-        self.logger.setLevel(level)
-        self.level = level
-
-        # Create handlers
-        self.file_handler = logging.FileHandler(log_file)
-        self.console_handler = logging.StreamHandler()
-
-        self.file_handler.setFormatter(QEffFormatter())
-        self.console_handler.setFormatter(QEffFormatter())
+class CustomLogger(logging.Logger):
+    def raise_runtimeerror(self, message):
+        self.error(message)
+        raise RuntimeError(message)

-        # Add handlers to the logger
-        self.logger.addHandler(self.file_handler)
-        self.logger.addHandler(self.console_handler)
+    def log_rank_zero(self, msg: str, level: int = logging.INFO) -> None:
+        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
+        if rank != 0:
+            return
+        self.log(level, msg, stacklevel=2)

-    def get_logger(self):
-        return self.logger
-
-    def raise_valueerror(self, message):
-        self.logger.error(message)
-        raise ValueError(message)

-    def raise_runtimeerror(self, message):
-        self.logger.error(message)
-        raise RuntimeError(message)
-
-    def raise_filenotfounderror(self, message):
-        self.logger.error(message)
-        raise FileNotFoundError(message)
+logging.setLoggerClass(CustomLogger)

-ft_logger = FT_Logger().get_logger()
-"""
+# Define the logger object that can be used for logging purposes throughout the module.
+logger = create_logger()
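
Two details of the new module are easy to miss: logging.setLoggerClass only affects loggers created after it is called, which is why CustomLogger is registered before create_logger() runs getLogger("QEfficient"), and stacklevel=2 makes the emitted record point at the caller of log_rank_zero rather than at logging_utils itself. A small standalone sketch of both mechanisms follows (the class, logger name, and message are placeholders, not QEfficient code; the rank check is omitted here and shown in the diff above).

# Standalone sketch; does not import QEfficient.
import logging


class RankZeroLogger(logging.Logger):
    def log_rank_zero(self, msg: str, level: int = logging.INFO) -> None:
        # Distributed rank check omitted here; see CustomLogger.log_rank_zero above.
        self.log(level, msg, stacklevel=2)  # attribute the record to the caller's line


logging.setLoggerClass(RankZeroLogger)  # must run before getLogger() creates the logger
logger = logging.getLogger("rank_zero_demo")  # returns a RankZeroLogger instance
assert isinstance(logger, RankZeroLogger)

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(levelname)s - %(filename)s:%(lineno)d - %(message)s"))
logger.addHandler(handler)
logger.setLevel(logging.INFO)

logger.log_rank_zero("emitted once, reported against this call site")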
