
chore(datasets): Fix doctests (#488)
Signed-off-by: Merel Theisen <[email protected]>
merelcht authored Dec 19, 2023
1 parent f8f4a7d commit 6997b11
Showing 7 changed files with 11 additions and 14 deletions.
3 changes: 0 additions & 3 deletions Makefile
@@ -37,17 +37,14 @@ dataset-doctest%:
# TODO(deepyaman): Fix as many doctests as possible (so that they run).
cd kedro-datasets && pytest kedro_datasets --doctest-modules --doctest-continue-on-failure --no-cov \
--ignore kedro_datasets/databricks/managed_table_dataset.py \
--ignore kedro_datasets/pandas/deltatable_dataset.py \
--ignore kedro_datasets/pandas/gbq_dataset.py \
--ignore kedro_datasets/partitions/incremental_dataset.py \
--ignore kedro_datasets/partitions/partitioned_dataset.py \
--ignore kedro_datasets/polars/lazy_polars_dataset.py \
--ignore kedro_datasets/redis/redis_dataset.py \
--ignore kedro_datasets/snowflake/snowpark_dataset.py \
--ignore kedro_datasets/spark/deltatable_dataset.py \
--ignore kedro_datasets/spark/spark_hive_dataset.py \
--ignore kedro_datasets/spark/spark_jdbc_dataset.py \
--ignore kedro_datasets/tensorflow/tensorflow_model_dataset.py \
$(extra_pytest_arg${*})

test-sequential:
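For reference, the trimmed doctest run can also be reproduced from Python rather than make; a minimal sketch, run from the kedro-datasets directory, with the flags copied from the target above and the ignore list shortened to a single illustrative entry:

# Sketch: invoke the kedro_datasets doctest suite programmatically.
# Flags mirror the Makefile target above; in practice pass the full --ignore
# list that remains in the Makefile, not just the single entry shown here.
import pytest

pytest.main(
    [
        "kedro_datasets",
        "--doctest-modules",
        "--doctest-continue-on-failure",
        "--no-cov",
        "--ignore=kedro_datasets/databricks/managed_table_dataset.py",
    ]
)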
kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py
@@ -187,7 +187,7 @@ class ManagedTableDataset(AbstractVersionedDataset):
... )
>>> data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
>>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema)
>>> dataset = ManagedTableDataset(table="names_and_ages")
>>> dataset = ManagedTableDataset(table="names_and_ages", write_mode="overwrite")
>>> dataset.save(spark_df)
>>> reloaded = dataset.load()
>>> reloaded.take(4)
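The only change here is the explicit write_mode="overwrite" passed to the constructor. Assembled as a plain script, the amended example looks roughly like the sketch below; it assumes a Spark session with Delta managed-table support (for example a Databricks runtime), and the two-field schema is a stand-in for the StructType defined just above this hunk.

# Sketch of the amended ManagedTableDataset example as a script.
# Assumes a Spark session with Delta managed-table support (e.g. Databricks);
# the schema below is a stand-in for the one defined above the hunk.
from kedro_datasets.databricks import ManagedTableDataset
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

schema = StructType(
    [
        StructField("name", StringType(), True),
        StructField("age", IntegerType(), True),
    ]
)
data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema)

dataset = ManagedTableDataset(table="names_and_ages", write_mode="overwrite")
dataset.save(spark_df)
reloaded = dataset.load()
print(reloaded.take(4))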
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/pandas/deltatable_dataset.py
@@ -73,7 +73,7 @@ class DeltaTableDataset(AbstractDataset):
>>>
>>> new_data = pd.DataFrame({"col1": [7, 8], "col2": [9, 10], "col3": [11, 12]})
>>> dataset.save(new_data)
>>> dataset.get_loaded_version()
>>> assert isinstance(dataset.get_loaded_version(), int)
"""

2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/pandas/gbq_dataset.py
@@ -51,7 +51,7 @@ class GBQTableDataset(AbstractDataset[None, pd.DataFrame]):
>>>
>>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>>
>>> dataset = GBQTableDataset("dataset", "table_name", project="my-project")
>>> dataset = GBQTableDataset(dataset="dataset", table_name="table_name", project="my-project")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>>
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/polars/lazy_polars_dataset.py
@@ -63,10 +63,10 @@ class LazyPolarsDataset(AbstractVersionedDataset[pl.LazyFrame, PolarsFrame]):
>>>
>>> data = pl.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>>
>>> dataset = LazyPolarsDataset(filepath=tmp_path / "test.csv")
>>> dataset = LazyPolarsDataset(filepath=tmp_path / "test.csv", file_format="csv")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data.frame_equal(reloaded)
>>> assert data.frame_equal(reloaded.collect())
"""

2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py
@@ -40,7 +40,7 @@ class SparkJDBCDataset(AbstractDataset[DataFrame, DataFrame]):
.. code-block:: pycon
>>> import pandas as pd
>>> from kedro_datasets.spark import SparkJBDCDataset
>>> from kedro_datasets.spark import SparkJDBCDataset
>>> from pyspark.sql import SparkSession
>>>
>>> spark = SparkSession.builder.getOrCreate()
kedro-datasets/kedro_datasets/tensorflow/tensorflow_model_dataset.py
@@ -50,14 +50,14 @@ class TensorFlowModelDataset(AbstractVersionedDataset[tf.keras.Model, tf.keras.M
>>> import tensorflow as tf
>>> import numpy as np
>>>
>>> dataset = TensorFlowModelDataset(tmp_path / "data/06_models/tensorflow_model.h5")
>>> model = tf.keras.Model()
>>> predictions = model.predict([...])
>>> dataset = TensorFlowModelDataset(filepath=tmp_path / "data/06_models/tensorflow_model.h5")
>>> model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(3,)),tf.keras.layers.Softmax()])
>>>
>>> # x = tf.random.uniform((10, 3))
>>> # predictions = model.predict(x)
>>>
>>> dataset.save(model)
>>> loaded_model = dataset.load()
>>> new_predictions = loaded_model.predict([...])
>>> np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)
"""

