Make output_dir Optional in TrainingArguments #27866 #35735

Merged 9 commits on Feb 11, 2025
make output_dir optional
sambhavnoobcoder committed Jan 16, 2025
commit da42052dad5e85fac6b1dc4ed578d675e20ac44a
13 changes: 11 additions & 2 deletions src/transformers/training_args.py
@@ -228,7 +228,7 @@ class TrainingArguments:
command line.

Parameters:
- output_dir (`str`):
+ output_dir (`str`, *optional*):
The output directory where the model predictions and checkpoints will be written.
Member

Please update the description to include the default value for output_dir. Maybe we could call it trainer_output, WDYT?

Contributor Author

Yes, I agree, trainer_output sounds more intuitive. I've made that change and updated the descriptions in commit 7831474.

Also, the PR is now clean and ready for merging.
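
For reference, the pattern under discussion (an optional field that falls back to a default in `__post_init__`) can be sketched roughly as follows. This is a minimal illustration, not the code in `training_args.py`: `SketchArguments` and `DEFAULT_OUTPUT_DIR` are hypothetical names, and the default value used here is the `tmp_trainer` string from this first commit, while the thread above settles on `trainer_output`.

```python
import logging
from dataclasses import dataclass, field
from typing import Optional

logger = logging.getLogger(__name__)

# Hypothetical constant name; the value is the one used in this first commit.
DEFAULT_OUTPUT_DIR = "tmp_trainer"


@dataclass
class SketchArguments:
    """Hypothetical stand-in for TrainingArguments, reduced to a single field."""

    output_dir: Optional[str] = field(
        default=None,
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )

    def __post_init__(self):
        # Fall back to the default only when the caller passed nothing.
        if self.output_dir is None:
            self.output_dir = DEFAULT_OUTPUT_DIR
            logger.info("No output directory specified, defaulting to %r.", DEFAULT_OUTPUT_DIR)


# The field can now be omitted; an explicit value is left untouched.
implicit = SketchArguments()
explicit = SketchArguments(output_dir="runs/exp1")
print(implicit.output_dir)
print(explicit.output_dir)
```

Resolving the default in `__post_init__` rather than in the `field(...)` declaration keeps `None` as the declared default, so the class can tell "not provided" apart from an explicit value and log the fallback.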

overwrite_output_dir (`bool`, *optional*, defaults to `False`):
If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir`
@@ -813,7 +813,8 @@ class TrainingArguments:
"""

framework = "pt"
- output_dir: str = field(
+ output_dir: Optional[str] = field(
+ default=None,
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
)
overwrite_output_dir: bool = field(
@@ -1547,6 +1548,14 @@ class TrainingArguments:
)

def __post_init__(self):
+ # Set default output_dir if not provided
+ if self.output_dir is None:
+ self.output_dir = "tmp_trainer"
+ logger.info(
+ "No output directory specified, defaulting to 'tmp_trainer'. "
+ "To change this behavior, specify --output_dir when creating TrainingArguments."
+ )
+
# Parse in args that could be `dict` sent in from the CLI as a string
for field in _VALID_DICT_FIELDS:
passed_value = getattr(self, field)