diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index a7b2ba0db3a7..bfd5b08da121 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -228,7 +228,7 @@ class TrainingArguments: command line. Parameters: - output_dir (`str`): + output_dir (`str`, *optional*, defaults to `"trainer_output"`): The output directory where the model predictions and checkpoints will be written. overwrite_output_dir (`bool`, *optional*, defaults to `False`): If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir` @@ -813,8 +813,11 @@ class TrainingArguments: """ framework = "pt" - output_dir: str = field( - metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, + output_dir: Optional[str] = field( + default=None, + metadata={ + "help": "The output directory where the model predictions and checkpoints will be written. Defaults to 'trainer_output' if not provided." + }, ) overwrite_output_dir: bool = field( default=False, @@ -1547,6 +1550,14 @@ class TrainingArguments: ) def __post_init__(self): + # Set default output_dir if not provided + if self.output_dir is None: + self.output_dir = "trainer_output" + logger.info( + "No output directory specified, defaulting to 'trainer_output'. " + "To change this behavior, specify --output_dir when creating TrainingArguments." + ) + # Parse in args that could be `dict` sent in from the CLI as a string for field in _VALID_DICT_FIELDS: passed_value = getattr(self, field) diff --git a/tests/test_training_args.py b/tests/test_training_args.py new file mode 100644 index 000000000000..7b1daabe1634 --- /dev/null +++ b/tests/test_training_args.py @@ -0,0 +1,42 @@ +import os +import tempfile +import unittest + +from transformers import TrainingArguments + + +class TestTrainingArguments(unittest.TestCase): + def test_default_output_dir(self): + """Test that output_dir defaults to 'tmp_trainer' when not specified.""" + args = TrainingArguments(output_dir=None) + self.assertEqual(args.output_dir, "tmp_trainer") + + def test_custom_output_dir(self): + """Test that output_dir is respected when specified.""" + with tempfile.TemporaryDirectory() as tmp_dir: + args = TrainingArguments(output_dir=tmp_dir) + self.assertEqual(args.output_dir, tmp_dir) + + def test_output_dir_creation(self): + """Test that output_dir is created only when needed.""" + with tempfile.TemporaryDirectory() as tmp_dir: + output_dir = os.path.join(tmp_dir, "test_output") + + # Directory should not exist before creating args + self.assertFalse(os.path.exists(output_dir)) + + # Create args with save_strategy="no" - should not create directory + args = TrainingArguments( + output_dir=output_dir, + do_train=True, + save_strategy="no", + report_to=None, + ) + self.assertFalse(os.path.exists(output_dir)) + + # Now set save_strategy="steps" - should create directory when needed + args.save_strategy = "steps" + args.save_steps = 1 + self.assertFalse(os.path.exists(output_dir)) # Still shouldn't exist + + # Directory should be created when actually needed (e.g. in Trainer)