diff --git a/dask_ml/model_selection/_split.py b/dask_ml/model_selection/_split.py index 26134da5b..90e342b19 100644 --- a/dask_ml/model_selection/_split.py +++ b/dask_ml/model_selection/_split.py @@ -424,9 +424,9 @@ def train_test_split( test_size = 0.1 if train_size is None and test_size is not None: - train_size = 1 - test_size + train_size = round(1 - test_size, 6) if test_size is None and train_size is not None: - test_size = 1 - train_size + test_size = round(1 - train_size, 6) if options: raise TypeError("Unexpected options {}".format(options))