chore(datasets): replace more "data_set" instances (#476)
Signed-off-by: Deepyaman Datta <[email protected]>
Co-authored-by: Juan Luis Cano Rodríguez <[email protected]>
deepyaman and astrojuanlu authored Dec 12, 2023
1 parent abae9f8 commit f8f4a7d
Showing 5 changed files with 45 additions and 45 deletions.
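
The change is mechanical: every `data_set` identifier becomes `dataset`, continuing the earlier rounds of this rename in kedro-datasets. Behaviour is untouched; only names change. A minimal before/after sketch of the pattern, using the geopandas example from the first file below:

```python
from kedro_datasets.geopandas import GeoJSONDataset

# Before: the old spelling treated "data set" as two words.
data_set = GeoJSONDataset(filepath="test.geojson")

# After: one word, matching the Dataset class names themselves.
dataset = GeoJSONDataset(filepath="test.geojson")
```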
6 changes: 3 additions & 3 deletions kedro-datasets/kedro_datasets/geopandas/README.md
@@ -14,9 +14,9 @@ data = gpd.GeoDataFrame(
     {"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]},
     geometry=[Point(1, 1), Point(2, 4)],
 )
-data_set = GeoJSONDataset(filepath="test.geojson")
-data_set.save(data)
-reloaded = data_set.load()
+dataset = GeoJSONDataset(filepath="test.geojson")
+dataset.save(data)
+reloaded = dataset.load()
 assert data.equals(reloaded)
 ```

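The hunk above shows only the changed slice of the README example. For reference, a self-contained version of the updated snippet might read as follows; the imports are assumed from the unshown lines above the hunk:

```python
import geopandas as gpd
from shapely.geometry import Point

from kedro_datasets.geopandas import GeoJSONDataset

data = gpd.GeoDataFrame(
    {"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]},
    geometry=[Point(1, 1), Point(2, 4)],
)

# Round-trip the GeoDataFrame through a GeoJSON file on disk.
dataset = GeoJSONDataset(filepath="test.geojson")
dataset.save(data)
reloaded = dataset.load()
assert data.equals(reloaded)
```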
6 changes: 3 additions & 3 deletions kedro-datasets/kedro_datasets/tensorflow/README.md
@@ -10,13 +10,13 @@ import tensorflow as tf

 from kedro_datasets.tensorflow import TensorFlowModelDataset

-data_set = TensorFlowModelDataset("tf_model_dirname")
+dataset = TensorFlowModelDataset("tf_model_dirname")

 model = tf.keras.Model()
 predictions = model.predict([...])

-data_set.save(model)
-loaded_model = data_set.load()
+dataset.save(model)
+loaded_model = dataset.load()

 new_predictions = loaded_model.predict([...])
 np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)
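Assembled the same way, the updated TensorFlow README example reads as below; the `numpy` import is assumed from the unshown lines, and the `[...]` input placeholders are the README's own, to be replaced with real model inputs:

```python
import numpy as np
import tensorflow as tf

from kedro_datasets.tensorflow import TensorFlowModelDataset

dataset = TensorFlowModelDataset("tf_model_dirname")

model = tf.keras.Model()
predictions = model.predict([...])  # placeholder inputs from the README

# Save the model to disk, then reload it through the same dataset.
dataset.save(model)
loaded_model = dataset.load()

new_predictions = loaded_model.predict([...])
np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)
```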
32 changes: 16 additions & 16 deletions kedro-datasets/tests/polars/test_lazy_polars_dataset.py
@@ -38,7 +38,7 @@ def dummy_dataframe():


 @pytest.fixture
-def csv_data_set(filepath_csv, load_args, save_args, fs_args):
+def csv_dataset(filepath_csv, load_args, save_args, fs_args):
     return LazyPolarsDataset(
         filepath=filepath_csv,
         file_format="csv",
@@ -109,17 +109,17 @@ def mocked_csv_in_s3(mocked_s3_bucket, dummy_dataframe):
 class TestLazyCSVDataset:
     """Test class for LazyPolarsDataset csv functionality"""

-    def test_exists(self, csv_data_set, dummy_dataframe):
+    def test_exists(self, csv_dataset, dummy_dataframe):
         """Test `exists` method invocation for both existing and
         nonexistent data set.
         """
-        assert not csv_data_set.exists()
-        csv_data_set.save(dummy_dataframe)
-        assert csv_data_set.exists()
+        assert not csv_dataset.exists()
+        csv_dataset.save(dummy_dataframe)
+        assert csv_dataset.exists()

-    def test_load(self, dummy_dataframe, csv_data_set, filepath_csv):
+    def test_load(self, dummy_dataframe, csv_dataset, filepath_csv):
         dummy_dataframe.write_csv(filepath_csv)
-        df = csv_data_set.load()
+        df = csv_dataset.load()
         assert df.collect().shape == (2, 3)

     def test_load_s3(self, dummy_dataframe, mocked_csv_in_s3):
@@ -130,32 +130,32 @@ def test_load_s3(self, dummy_dataframe, mocked_csv_in_s3):
         loaded_df = ds.load().collect()
         assert_frame_equal(loaded_df, dummy_dataframe)

-    def test_save_and_load(self, csv_data_set, dummy_dataframe):
-        csv_data_set.save(dummy_dataframe)
-        reloaded_df = csv_data_set.load().collect()
+    def test_save_and_load(self, csv_dataset, dummy_dataframe):
+        csv_dataset.save(dummy_dataframe)
+        reloaded_df = csv_dataset.load().collect()
         assert_frame_equal(dummy_dataframe, reloaded_df)

-    def test_load_missing_file(self, csv_data_set):
+    def test_load_missing_file(self, csv_dataset):
         """Check the error when trying to load missing file."""
         pattern = r"Failed while loading data from data set LazyPolarsDataset\(.*\)"
         with pytest.raises(DatasetError, match=pattern):
-            csv_data_set.load()
+            csv_dataset.load()

     @pytest.mark.parametrize(
         "load_args", [{"k1": "v1", "index": "value"}], indirect=True
     )
-    def test_load_extra_params(self, csv_data_set, load_args):
+    def test_load_extra_params(self, csv_dataset, load_args):
         """Test overriding the default load arguments."""
         for key, value in load_args.items():
-            assert csv_data_set._load_args[key] == value
+            assert csv_dataset._load_args[key] == value

     @pytest.mark.parametrize(
         "save_args", [{"k1": "v1", "index": "value"}], indirect=True
     )
-    def test_save_extra_params(self, csv_data_set, save_args):
+    def test_save_extra_params(self, csv_dataset, save_args):
         """Test overriding the default save arguments."""
         for key, value in save_args.items():
-            assert csv_data_set._save_args[key] == value
+            assert csv_dataset._save_args[key] == value

     @pytest.mark.parametrize(
         "load_args,save_args",
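The fixture rename is what drives most of the line count in this file: pytest resolves fixtures by parameter name, so once the `csv_data_set` fixture becomes `csv_dataset`, every test that requests it must rename its parameter too. A stripped-down sketch of the mechanism (illustrative names, not from the diff):

```python
import pytest


@pytest.fixture
def csv_dataset():
    # pytest injects this into any test with a parameter named
    # `csv_dataset`; a parameter still spelled `csv_data_set` would
    # fail with "fixture 'csv_data_set' not found".
    return {"file_format": "csv"}


def test_uses_renamed_fixture(csv_dataset):
    assert csv_dataset["file_format"] == "csv"
```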
42 changes: 21 additions & 21 deletions kedro-datasets/tests/redis/test_redis_dataset.py
@@ -44,7 +44,7 @@ def serialised_dummy_object(backend, dummy_object, save_args):


 @pytest.fixture
-def pickle_data_set(mocker, key, backend, load_args, save_args, redis_args):
+def pickle_dataset(mocker, key, backend, load_args, save_args, redis_args):
     mocker.patch(
         "redis.StrictRedis.from_url", return_value=redis.Redis.from_url("redis://")
     )
@@ -70,7 +70,7 @@ class TestPickleDataset:
     )
     def test_save_and_load(
         self,
-        pickle_data_set,
+        pickle_dataset,
         mocker,
         dummy_object,
         serialised_dummy_object,
@@ -81,64 +81,64 @@ def test_save_and_load(
         get_mocker = mocker.patch(
             "redis.StrictRedis.get", return_value=serialised_dummy_object
         )
-        pickle_data_set.save(dummy_object)
+        pickle_dataset.save(dummy_object)
         mocker.patch("redis.StrictRedis.exists", return_value=True)
-        loaded_dummy_object = pickle_data_set.load()
+        loaded_dummy_object = pickle_dataset.load()
         set_mocker.assert_called_once_with(
             key,
             serialised_dummy_object,
         )
         get_mocker.assert_called_once_with(key)
         assert_frame_equal(loaded_dummy_object, dummy_object)

-    def test_exists(self, mocker, pickle_data_set, dummy_object, key):
+    def test_exists(self, mocker, pickle_dataset, dummy_object, key):
         """Test `exists` method invocation for both existing and
         nonexistent data set."""
         mocker.patch("redis.StrictRedis.exists", return_value=False)
-        assert not pickle_data_set.exists()
+        assert not pickle_dataset.exists()
         mocker.patch("redis.StrictRedis.set")
-        pickle_data_set.save(dummy_object)
+        pickle_dataset.save(dummy_object)
         exists_mocker = mocker.patch("redis.StrictRedis.exists", return_value=True)
-        assert pickle_data_set.exists()
+        assert pickle_dataset.exists()
         exists_mocker.assert_called_once_with(key)

-    def test_exists_raises_error(self, pickle_data_set):
+    def test_exists_raises_error(self, pickle_dataset):
         """Check the error when trying to assert existence with no redis server."""
         pattern = r"The existence of key "
         with pytest.raises(DatasetError, match=pattern):
-            pickle_data_set.exists()
+            pickle_dataset.exists()

     @pytest.mark.parametrize(
         "load_args", [{"k1": "v1", "errors": "strict"}], indirect=True
     )
-    def test_load_extra_params(self, pickle_data_set, load_args):
+    def test_load_extra_params(self, pickle_dataset, load_args):
         """Test overriding the default load arguments."""
         for key, value in load_args.items():
-            assert pickle_data_set._load_args[key] == value
+            assert pickle_dataset._load_args[key] == value

     @pytest.mark.parametrize("save_args", [{"k1": "v1", "protocol": 2}], indirect=True)
-    def test_save_extra_params(self, pickle_data_set, save_args):
+    def test_save_extra_params(self, pickle_dataset, save_args):
         """Test overriding the default save arguments."""
         for key, value in save_args.items():
-            assert pickle_data_set._save_args[key] == value
+            assert pickle_dataset._save_args[key] == value

-    def test_redis_extra_args(self, pickle_data_set, redis_args):
-        assert pickle_data_set._redis_from_url_args == redis_args["from_url_args"]
-        assert pickle_data_set._redis_set_args == {}  # default unchanged
+    def test_redis_extra_args(self, pickle_dataset, redis_args):
+        assert pickle_dataset._redis_from_url_args == redis_args["from_url_args"]
+        assert pickle_dataset._redis_set_args == {}  # default unchanged

-    def test_load_missing_key(self, mocker, pickle_data_set):
+    def test_load_missing_key(self, mocker, pickle_dataset):
         """Check the error when trying to load missing file."""
         pattern = r"The provided key "
         mocker.patch("redis.StrictRedis.exists", return_value=False)
         with pytest.raises(DatasetError, match=pattern):
-            pickle_data_set.load()
+            pickle_dataset.load()

-    def test_unserialisable_data(self, pickle_data_set, dummy_object, mocker):
+    def test_unserialisable_data(self, pickle_dataset, dummy_object, mocker):
         mocker.patch("pickle.dumps", side_effect=pickle.PickleError)
         pattern = r".+ was not serialised due to:.*"

         with pytest.raises(DatasetError, match=pattern):
-            pickle_data_set.save(dummy_object)
+            pickle_dataset.save(dummy_object)

     def test_invalid_backend(self, mocker):
         pattern = (
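Independent of the rename, note how these tests isolate themselves from Redis: every client call is patched at the class level with pytest-mock, so no server needs to be running. A condensed sketch of that pattern (assuming `pytest-mock` is installed; the key name is illustrative):

```python
import redis


def test_exists_without_a_server(mocker):
    # Patching the class attribute affects every Redis instance;
    # constructing a client is lazy and opens no connection.
    exists_mock = mocker.patch("redis.StrictRedis.exists", return_value=True)

    client = redis.Redis.from_url("redis://")
    assert client.exists("my_key")
    exists_mock.assert_called_once_with("my_key")
```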
4 changes: 2 additions & 2 deletions kedro-datasets/tests/spark/test_spark_dataset.py
@@ -994,7 +994,7 @@ def test_spark_load_save(self, is_async, data_catalog):
         pipeline = modular_pipeline([node(identity, "spark_in", "spark_out")])
         SequentialRunner(is_async=is_async).run(pipeline, data_catalog)

-        save_path = Path(data_catalog._data_sets["spark_out"]._filepath.as_posix())
+        save_path = Path(data_catalog._datasets["spark_out"]._filepath.as_posix())
         files = list(save_path.glob("*.parquet"))
         assert len(files) > 0

@@ -1016,6 +1016,6 @@ def test_spark_memory_spark(self, is_async, data_catalog):
         )
         SequentialRunner(is_async=is_async).run(pipeline, data_catalog)

-        save_path = Path(data_catalog._data_sets["spark_out"]._filepath.as_posix())
+        save_path = Path(data_catalog._datasets["spark_out"]._filepath.as_posix())
         files = list(save_path.glob("*.parquet"))
         assert len(files) > 0
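
Unlike the other files, this one does not rename a local variable: `_datasets` is a private attribute of Kedro's `DataCatalog`, so these two lines track a rename in Kedro core and tie the tests to versions that spell it that way. Application code can stay version-agnostic by using the catalog's public API instead, as in this sketch (a `MemoryDataset` stands in for the Spark dataset; import paths assume Kedro 0.19+):

```python
from kedro.io import DataCatalog, MemoryDataset

catalog = DataCatalog({"spark_out": MemoryDataset()})

# Public methods avoid reaching into the private `_datasets` mapping.
catalog.save("spark_out", {"rows": 2})
assert catalog.load("spark_out") == {"rows": 2}
assert "spark_out" in catalog.list()
```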
