Skip to content

Commit

Permalink
Keras 2 _*generator displaying warning always about semantic changes …
Browse files Browse the repository at this point in the history
…from Keras 1 (keras-team#7001)

* Warn always about semantic changes if having keras1 args in *_generator calls.

* modified api upgrade warning message to be more detailed

* minor fix to pep8 syntax
  • Loading branch information
holli authored and fchollet committed Jun 16, 2017
1 parent 5ca5699 commit d3c3361
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
19 changes: 13 additions & 6 deletions keras/legacy/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,14 +577,21 @@ def generator_methods_args_preprocessor(args, kwargs):
if hasattr(generator, 'batch_size'):
kwargs['steps_per_epoch'] = samples_per_epoch // generator.batch_size
else:
warnings.warn('The semantics of the Keras 2 argument '
' `steps_per_epoch` is not the same as the '
'Keras 1 argument `samples_per_epoch`. '
'`steps_per_epoch` is the number of batches '
'to draw from the generator at each epoch. '
'Update your method calls accordingly.', stacklevel=3)
kwargs['steps_per_epoch'] = samples_per_epoch
converted.append(('samples_per_epoch', 'steps_per_epoch'))

keras1_args = {'samples_per_epoch', 'val_samples', 'nb_epoch', 'nb_val_samples', 'nb_worker'}
if keras1_args.intersection(kwargs.keys()):
warnings.warn('The semantics of the Keras 2 argument '
'`steps_per_epoch` is not the same as the '
'Keras 1 argument `samples_per_epoch`. '
'`steps_per_epoch` is the number of batches '
'to draw from the generator at each epoch. '
'Basically steps_per_epoch = samples_per_epoch/batch_size. '
'Similarly `nb_val_samples`->`validation_steps` and '
'`val_samples`->`steps` arguments have changed. '
'Update your method calls accordingly.', stacklevel=3)

return args, kwargs, converted


Expand Down
5 changes: 5 additions & 0 deletions tests/keras/legacy/interface_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,11 @@ def pred_generator():
validation_data=val_generator(),
nb_val_samples=1,
nb_worker=1)
model.fit_generator(train_generator(),
10,
1,
nb_val_samples=1,
nb_worker=1)
model.evaluate_generator(generator=train_generator(),
val_samples=2,
nb_worker=1)
Expand Down

0 comments on commit d3c3361

Please sign in to comment.