Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DISC] Having IbisML closer to sklearn #141

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
397 changes: 397 additions & 0 deletions examples/Example estimators.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,397 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "4377a22e-f751-4a12-b07a-c6b31e3c48d9",
"metadata": {},
"source": [
"## The New York City flight data\n",
"\n",
"Let's use the [nycflights13 data](https://github.com/hadley/nycflights13) to predict whether a plane arrives more than 30 minutes late. This dataset contains information on 325,819 flights departing near New York City in 2013. Let's start by loading the data and making a few changes to the variables:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "be74f0f2-bccd-424e-984d-35d5d6d0bc53",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"import ibis\n",
"\n",
"con = ibis.connect(\"duckdb://nycflights13.ddb\")\n",
"con.create_table(\n",
" \"flights\", ibis.examples.nycflights13_flights.fetch().to_pyarrow(), overwrite=True\n",
")\n",
"con.create_table(\n",
" \"weather\", ibis.examples.nycflights13_weather.fetch().to_pyarrow(), overwrite=True\n",
")"
]
},
{
"cell_type": "markdown",
"id": "38c7ea7f-10cd-4c6b-8d55-cb98b10c317e",
"metadata": {},
"source": [
"You can now see the example dataset copied over to the database:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d0fca09-e4c4-4c89-bf26-a025a2ec3b49",
"metadata": {},
"outputs": [],
"source": [
"con = ibis.connect(\"duckdb://nycflights13.ddb\")\n",
"con.list_tables()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a102d7ca-ee8c-4212-9f12-1cb16eee7d72",
"metadata": {},
"outputs": [],
"source": [
"flights = con.table(\"flights\")\n",
"flights = flights.mutate(\n",
" dep_time=(\n",
" flights.dep_time.lpad(4, \"0\").substr(0, 2)\n",
" + \":\"\n",
" + flights.dep_time.substr(-2, 2)\n",
" + \":00\"\n",
" ).try_cast(\"time\"),\n",
" arr_delay=flights.arr_delay.try_cast(int),\n",
" air_time=flights.air_time.try_cast(int),\n",
")\n",
"flights"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "77d73bab-3ad3-489e-bbb2-a7175693a01e",
"metadata": {},
"outputs": [],
"source": [
"weather = con.table(\"weather\")\n",
"weather"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc81d493-11cf-47a1-b808-c6cda41a8739",
"metadata": {},
"outputs": [],
"source": [
"flight_data = (\n",
" flights.mutate(\n",
" # Convert the arrival delay to a factor\n",
" # By default, PyTorch expects the target to have a Long datatype\n",
" arr_delay=ibis.ifelse(flights.arr_delay >= 30, 1, 0).cast(\"int64\"),\n",
" # We will use the date (not date-time) in the recipe below\n",
" date=flights.time_hour.date(),\n",
" )\n",
" # Include the weather data\n",
" .inner_join(weather, [\"origin\", \"time_hour\"])\n",
" # Only retain the specific columns we will use\n",
" .select(\n",
" \"dep_time\",\n",
" \"flight\",\n",
" \"origin\",\n",
" \"dest\",\n",
" \"air_time\",\n",
" \"distance\",\n",
" \"carrier\",\n",
" \"date\",\n",
" \"arr_delay\",\n",
" \"time_hour\",\n",
" )\n",
" # Exclude missing data\n",
" .dropna()\n",
")\n",
"flight_data"
]
},
{
"cell_type": "markdown",
"id": "722b2213-3b84-4f03-9006-59bf72591613",
"metadata": {},
"source": [
"We can see that about 16% of the flights in this dataset arrived more than 30 minutes late."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82601a08-ddd0-42db-949a-c8a423af6a1b",
"metadata": {},
"outputs": [],
"source": [
"flight_data.arr_delay.value_counts().rename(n=\"arr_delay_count\").mutate(\n",
" prop=ibis._.n / ibis._.n.sum()\n",
")"
]
},
{
"cell_type": "markdown",
"id": "5e9eaa27-b1de-46eb-a910-5d94a67e07a1",
"metadata": {},
"source": [
"## Data splitting\n",
"\n",
"To get started, let's split this single dataset into two: a _training_ set and a _testing_ set. We'll keep most of the rows in the original dataset (subset chosen randomly) in the _training_ set. The training data will be used to _fit_ the model, and the _testing_ set will be used to measure model performance.\n",
"\n",
"Because the order of rows in an Ibis table is undefined, we need a unique key to split the data reproducibly. [It is permissible for airlines to use the same flight number for different routes, as long as the flights do not operate on the same day. This means that the combination of the flight number and the date of travel is always unique.](https://www.euclaim.com/blog/flight-numbers-explained#:~:text=Can%20flight%20numbers%20be%20reused,of%20travel%20is%20always%20unique.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "732624f4-a2af-4c6e-b29d-4fb7cb5fc99e",
"metadata": {},
"outputs": [],
"source": [
"flight_data_with_unique_key = flight_data.mutate(\n",
" unique_key=ibis.literal(\",\").join(\n",
" [flight_data.carrier, flight_data.flight.cast(str), flight_data.date.cast(str)]\n",
" )\n",
")\n",
"flight_data_with_unique_key"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9cd58ce-dc2d-4e4e-8b4a-51100fe1182c",
"metadata": {},
"outputs": [],
"source": [
"# FIXME(deepyaman): Proposed key isn't unique for actual departure date.\n",
"flight_data_with_unique_key.group_by(\"unique_key\").mutate(\n",
" cnt=flight_data_with_unique_key.count()\n",
")[ibis._.cnt > 1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6be459de-73cd-4d6e-a195-41b9e5c481a6",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"# Fix the random numbers by setting the seed\n",
"# This enables the analysis to be reproducible when random numbers are used\n",
"random.seed(222)\n",
"\n",
"# Put 3/4 of the data into the training set\n",
"random_key = str(random.getrandbits(256))\n",
"data_split = flight_data_with_unique_key.mutate(\n",
" train=(flight_data_with_unique_key.unique_key + random_key).hash().abs() % 4 < 3\n",
")\n",
"\n",
"# Create data frames for the two sets:\n",
"train_data = data_split[data_split.train].drop(\"unique_key\", \"train\")\n",
"test_data = data_split[~data_split.train].drop(\"unique_key\", \"train\")"
]
},
{
"cell_type": "markdown",
"id": "e18eeb66-b2fc-4da1-933d-58d78beebc2b",
"metadata": {},
"source": [
"## Create features"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a223b57d-31b7-4ad1-88fd-a216de7da01a",
"metadata": {},
"outputs": [],
"source": [
"import ibis_ml as ml\n",
"\n",
"flights_rec = ml.Recipe(\n",
" ml.ExpandDate(\"date\", components=[\"dow\", \"month\"]),\n",
" ml.Drop(\"date\"),\n",
" ml.TargetEncode(ml.nominal()),\n",
" ml.DropZeroVariance(ml.everything()),\n",
" ml.MutateAt(\"dep_time\", ibis._.hour() * 60 + ibis._.minute()),\n",
" ml.MutateAt(ml.timestamp(), ibis._.epoch_seconds()),\n",
" # By default, PyTorch requires that the type of `X` is `np.float32`.\n",
" # https://discuss.pytorch.org/t/mat1-and-mat2-must-have-the-same-dtype-but-got-double-and-float/197555/2\n",
" ml.Cast(ml.numeric(), \"float32\"),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5e7dbe5",
"metadata": {},
"outputs": [],
"source": [
"ml_transformer = ml.PolynomialFeatures(columns=ml.numeric(), degree=2)"
]
},
{
"cell_type": "markdown",
"id": "373a0cb6-1bca-4add-8807-b9105b27fa5d",
"metadata": {},
"source": [
"## Fit a model with a recipe\n",
"\n",
"Let's model the flight data. We can use any scikit-learn-compatible estimator."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc04f24e-c8cb-4580-b502-a9410c64a126",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline\n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"mod = LogisticRegression()\n",
"\n",
"pipe_trans = Pipeline([(\"flights_rec\", flights_rec), (\"poly\", ml_transformer)])"
]
},
{
"cell_type": "markdown",
"id": "f666c1b3",
"metadata": {},
"source": [
"get_params works as expected for our ml_transformer as we have:\n",
"{\n",
"...\n",
" 'poly__columns': numeric(),\n",
" 'poly__degree': 2,\n",
"...\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d89aa805",
"metadata": {},
"outputs": [],
"source": [
"pipe_trans.get_params()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "98d5d60d",
"metadata": {},
"outputs": [],
"source": [
"X_train = train_data.drop(\"arr_delay\")\n",
"y_train = train_data.arr_delay"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c7736ac0",
"metadata": {},
"outputs": [],
"source": [
"pipe_trans.fit_transform(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f43b1c0",
"metadata": {},
"outputs": [],
"source": [
"pipe = Pipeline([(\"trans\", pipe_trans), (\"mod\", mod)])"
]
},
{
"cell_type": "markdown",
"id": "412922b1-bcb8-48c7-ad9f-3f37ed969cd8",
"metadata": {},
"source": [
"Now, there is a single function that can be used to prepare the recipe and train the model from the resulting predictors:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42ac1426-0561-4a8b-a949-127b2b0c4f01",
"metadata": {},
"outputs": [],
"source": [
"X_train = train_data.drop(\"arr_delay\")\n",
"y_train = train_data.arr_delay\n",
"pipe.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"id": "21d9774a-543d-4272-9310-7743201cf9b6",
"metadata": {},
"source": [
"## Use a trained workflow to predict\n",
"\n",
"..."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "be3ff129-d56f-4441-acbc-da7d6cd93d19",
"metadata": {},
"outputs": [],
"source": [
"X_test = test_data.drop(\"arr_delay\")\n",
"y_test = test_data.arr_delay\n",
"pipe.score(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"id": "cc21b842-b85c-4ed9-af03-1feace909172",
"metadata": {},
"source": [
"## Acknowledgments\n",
"\n",
"This tutorial is derived from the [tidymodels article of the same name](https://www.tidymodels.org/start/recipes/). The transformation logic is very similar, and much of the text is copied verbatim."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading
Loading