diff --git a/.travis.yml b/.travis.yml
index 1db80743..38f60163 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -14,6 +14,6 @@ services:
- postgresql
env:
global:
- - PGPORT=5432
+ - SB_TEST_PGPORT=5432
script:
- make test-travis
diff --git a/docs/.gitattributes b/docs/.gitattributes
index d9d68857..19c35d26 100644
--- a/docs/.gitattributes
+++ b/docs/.gitattributes
@@ -1,3 +1,2 @@
-*.ipynb filter=nbstripout
*.ipynb diff=ipynb
diff --git a/docs/index.rst b/docs/index.rst
index 7f27327c..6ef37a36 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -3,7 +3,7 @@
:hidden:
intro.Rmd
- intro_sql.Rmd
+ intro_sql.ipynb
.. toctree::
:caption: Core One-table Verbs
diff --git a/docs/intro_sql.Rmd b/docs/intro_sql.Rmd
index 8f4d3909..77c56131 100644
--- a/docs/intro_sql.Rmd
+++ b/docs/intro_sql.Rmd
@@ -12,38 +12,131 @@ jupyter:
name: python3
---
+```{python nbsphinx=hidden}
+import matplotlib.cbook
+
+import warnings
+import plotnine
+warnings.filterwarnings(module='plotnine*', action='ignore')
+warnings.filterwarnings(module='matplotlib*', action='ignore')
+
+# %matplotlib inline
+```
+
-# Using to query SQL
+# Using siuba to query SQL
+
+# Setting up
+
```{python}
-from sqlalchemy import create_engine
-from siuba.data import mtcars
import pandas as pd
+from siuba.tests.helpers import copy_to_sql
+from siuba import *
+from siuba.dply.vector import lag, desc, row_number
+from siuba.dply.string import str_c
-engine = create_engine('sqlite:///:memory:', echo=False)
+tv_ratings = pd.read_csv(
+ "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2019/2019-01-08/IMDb_Economist_tv_ratings.csv",
+ parse_dates = ["date"]
+ )
-# note that mtcars is a pandas DataFrame
-mtcars.to_sql('mtcars', engine)
```
```{python}
-from siuba import *
-from siuba.sql import LazyTbl, show_query, collect
+tbl_ratings = copy_to_sql(tv_ratings, "tv_ratings", "postgresql://postgres:@localhost:5433/postgres")
-tbl_mtcars = LazyTbl(engine, 'mtcars')
```
```{python}
-tbl_mtcars
+tbl_ratings
+
```
+## Inspecting a single show
+
```{python}
-tbl_mtcars >> filter(_.hp > 250) >> collect()
+buffy = (tbl_ratings
+ >> filter(_.title == "Buffy the Vampire Slayer")
+ >> collect()
+ )
+
+buffy
+```
+
+```{python}
+buffy >> summarize(avg_rating = _.av_rating.mean())
```
+## Average rating per show, along with dates
+
```{python}
-(tbl_mtcars
- >> group_by(_.cyl)
- >> summarize(avg_mpg = _.mpg.mean())
+avg_ratings = (tbl_ratings
+ >> group_by(_.title)
+ >> summarize(
+ avg_rating = _.av_rating.mean(),
+ date_range = str_c(_.date.dt.year.max(), " - ", _.date.dt.year.min())
+ )
+ )
+
+avg_ratings
+```
+
+## Biggest changes in ratings between two seasons
+
+```{python}
+top_4_shifts = (tbl_ratings
+ >> group_by(_.title)
+ >> mutate(rating_shift = _.av_rating - lag(_.av_rating))
+ >> summarize(
+ max_shift = _.rating_shift.max()
+ )
+ >> arrange(-_.max_shift)
+ >> head(4)
+ )
+
+top_4_shifts
+```
+
+```{python}
+big_shift_series = (top_4_shifts
+ >> select(_.title)
+ >> inner_join(_, tbl_ratings, "title")
>> collect()
)
+
+from plotnine import *
+
+(big_shift_series
+ >> ggplot(aes("seasonNumber", "av_rating"))
+ + geom_point()
+ + geom_line()
+ + facet_wrap("~ title")
+ + labs(
+ title = "Seasons with Biggest Shifts in Ratings",
+ y = "Average rating",
+ x = "Season"
+ )
+ )
+```
+
+## Do we have full data for each season?
+
+```{python}
+mismatches = (tbl_ratings
+ >> arrange(_.title, _.seasonNumber)
+ >> group_by(_.title)
+ >> mutate(
+ row = row_number(_),
+ mismatch = _.row != _.seasonNumber
+ )
+ >> filter(_.mismatch.any())
+ >> ungroup()
+ )
+
+
+mismatches
+```
+
+```{python}
+mismatches >> distinct(_.title) >> count() >> collect()
```
diff --git a/docs/intro_sql.ipynb b/docs/intro_sql.ipynb
new file mode 100644
index 00000000..fb4712ea
--- /dev/null
+++ b/docs/intro_sql.ipynb
@@ -0,0 +1,877 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "nbsphinx": "hidden"
+ },
+ "outputs": [],
+ "source": [
+ "import matplotlib.cbook\n",
+ "\n",
+ "import warnings\n",
+ "import plotnine\n",
+ "warnings.filterwarnings(module='plotnine*', action='ignore')\n",
+ "warnings.filterwarnings(module='matplotlib*', action='ignore')\n",
+ "\n",
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Using siuba to query SQL"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Setting up"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "from siuba.tests.helpers import copy_to_sql\n",
+ "from siuba import *\n",
+ "from siuba.dply.vector import lag, desc, row_number\n",
+ "from siuba.dply.string import str_c\n",
+ "\n",
+ "tv_ratings = pd.read_csv(\n",
+ " \"https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2019/2019-01-08/IMDb_Economist_tv_ratings.csv\",\n",
+ " parse_dates = [\"date\"]\n",
+ " )\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tbl_ratings = copy_to_sql(tv_ratings, \"tv_ratings\", \"postgresql://postgres:@localhost:5433/postgres\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
# Source: lazy query\n",
+ "# DB Conn: Engine(postgresql://postgres:***@localhost:5433/postgres)\n",
+ "# Preview:\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " titleId | \n",
+ " seasonNumber | \n",
+ " title | \n",
+ " date | \n",
+ " av_rating | \n",
+ " share | \n",
+ " genres | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " tt2879552 | \n",
+ " 1 | \n",
+ " 11.22.63 | \n",
+ " 2016-03-10 | \n",
+ " 8.4890 | \n",
+ " 0.51 | \n",
+ " Drama,Mystery,Sci-Fi | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " tt3148266 | \n",
+ " 1 | \n",
+ " 12 Monkeys | \n",
+ " 2015-02-27 | \n",
+ " 8.3407 | \n",
+ " 0.46 | \n",
+ " Adventure,Drama,Mystery | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " tt3148266 | \n",
+ " 2 | \n",
+ " 12 Monkeys | \n",
+ " 2016-05-30 | \n",
+ " 8.8196 | \n",
+ " 0.25 | \n",
+ " Adventure,Drama,Mystery | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " tt3148266 | \n",
+ " 3 | \n",
+ " 12 Monkeys | \n",
+ " 2017-05-19 | \n",
+ " 9.0369 | \n",
+ " 0.19 | \n",
+ " Adventure,Drama,Mystery | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " tt3148266 | \n",
+ " 4 | \n",
+ " 12 Monkeys | \n",
+ " 2018-06-26 | \n",
+ " 9.1363 | \n",
+ " 0.38 | \n",
+ " Adventure,Drama,Mystery | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
# .. may have more rows
"
+ ],
+ "text/plain": [
+ "# Source: lazy query\n",
+ "# DB Conn: Engine(postgresql://postgres:***@localhost:5433/postgres)\n",
+ "# Preview:\n",
+ " titleId seasonNumber title date av_rating share \\\n",
+ "0 tt2879552 1 11.22.63 2016-03-10 8.4890 0.51 \n",
+ "1 tt3148266 1 12 Monkeys 2015-02-27 8.3407 0.46 \n",
+ "2 tt3148266 2 12 Monkeys 2016-05-30 8.8196 0.25 \n",
+ "3 tt3148266 3 12 Monkeys 2017-05-19 9.0369 0.19 \n",
+ "4 tt3148266 4 12 Monkeys 2018-06-26 9.1363 0.38 \n",
+ "\n",
+ " genres \n",
+ "0 Drama,Mystery,Sci-Fi \n",
+ "1 Adventure,Drama,Mystery \n",
+ "2 Adventure,Drama,Mystery \n",
+ "3 Adventure,Drama,Mystery \n",
+ "4 Adventure,Drama,Mystery \n",
+ "# .. may have more rows"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tbl_ratings\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Inspecting a single show"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " titleId | \n",
+ " seasonNumber | \n",
+ " title | \n",
+ " date | \n",
+ " av_rating | \n",
+ " share | \n",
+ " genres | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " tt0118276 | \n",
+ " 1 | \n",
+ " Buffy the Vampire Slayer | \n",
+ " 1997-04-14 | \n",
+ " 7.9629 | \n",
+ " 11.70 | \n",
+ " Action,Drama,Fantasy | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " tt0118276 | \n",
+ " 2 | \n",
+ " Buffy the Vampire Slayer | \n",
+ " 1997-12-31 | \n",
+ " 8.4191 | \n",
+ " 19.41 | \n",
+ " Action,Drama,Fantasy | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " tt0118276 | \n",
+ " 3 | \n",
+ " Buffy the Vampire Slayer | \n",
+ " 1999-01-29 | \n",
+ " 8.6233 | \n",
+ " 17.12 | \n",
+ " Action,Drama,Fantasy | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " tt0118276 | \n",
+ " 4 | \n",
+ " Buffy the Vampire Slayer | \n",
+ " 2000-01-19 | \n",
+ " 8.2205 | \n",
+ " 16.19 | \n",
+ " Action,Drama,Fantasy | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " tt0118276 | \n",
+ " 5 | \n",
+ " Buffy the Vampire Slayer | \n",
+ " 2001-01-12 | \n",
+ " 8.3028 | \n",
+ " 11.99 | \n",
+ " Action,Drama,Fantasy | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " tt0118276 | \n",
+ " 6 | \n",
+ " Buffy the Vampire Slayer | \n",
+ " 2002-01-29 | \n",
+ " 8.1008 | \n",
+ " 8.45 | \n",
+ " Action,Drama,Fantasy | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " tt0118276 | \n",
+ " 7 | \n",
+ " Buffy the Vampire Slayer | \n",
+ " 2003-01-18 | \n",
+ " 8.0460 | \n",
+ " 9.89 | \n",
+ " Action,Drama,Fantasy | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " titleId seasonNumber title date av_rating \\\n",
+ "0 tt0118276 1 Buffy the Vampire Slayer 1997-04-14 7.9629 \n",
+ "1 tt0118276 2 Buffy the Vampire Slayer 1997-12-31 8.4191 \n",
+ "2 tt0118276 3 Buffy the Vampire Slayer 1999-01-29 8.6233 \n",
+ "3 tt0118276 4 Buffy the Vampire Slayer 2000-01-19 8.2205 \n",
+ "4 tt0118276 5 Buffy the Vampire Slayer 2001-01-12 8.3028 \n",
+ "5 tt0118276 6 Buffy the Vampire Slayer 2002-01-29 8.1008 \n",
+ "6 tt0118276 7 Buffy the Vampire Slayer 2003-01-18 8.0460 \n",
+ "\n",
+ " share genres \n",
+ "0 11.70 Action,Drama,Fantasy \n",
+ "1 19.41 Action,Drama,Fantasy \n",
+ "2 17.12 Action,Drama,Fantasy \n",
+ "3 16.19 Action,Drama,Fantasy \n",
+ "4 11.99 Action,Drama,Fantasy \n",
+ "5 8.45 Action,Drama,Fantasy \n",
+ "6 9.89 Action,Drama,Fantasy "
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "buffy = (tbl_ratings\n",
+ " >> filter(_.title == \"Buffy the Vampire Slayer\")\n",
+ " >> collect()\n",
+ " )\n",
+ "\n",
+ "buffy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " avg_rating | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 8.239343 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " avg_rating\n",
+ "0 8.239343"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "buffy >> summarize(avg_rating = _.av_rating.mean())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Average rating per show, along with dates"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "# Source: lazy query\n",
+ "# DB Conn: Engine(postgresql://postgres:***@localhost:5433/postgres)\n",
+ "# Preview:\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " title | \n",
+ " avg_rating | \n",
+ " date_range | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " Friends from College | \n",
+ " 6.875100 | \n",
+ " 2017 - 2017 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " Better Things | \n",
+ " 8.133150 | \n",
+ " 2017 - 2016 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " How to Get Away with Murder | \n",
+ " 8.762340 | \n",
+ " 2018 - 2014 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " Dexter | \n",
+ " 8.582400 | \n",
+ " 2013 - 2006 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " Queen of the South | \n",
+ " 8.574733 | \n",
+ " 2018 - 2016 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
# .. may have more rows
"
+ ],
+ "text/plain": [
+ "# Source: lazy query\n",
+ "# DB Conn: Engine(postgresql://postgres:***@localhost:5433/postgres)\n",
+ "# Preview:\n",
+ " title avg_rating date_range\n",
+ "0 Friends from College 6.875100 2017 - 2017\n",
+ "1 Better Things 8.133150 2017 - 2016\n",
+ "2 How to Get Away with Murder 8.762340 2018 - 2014\n",
+ "3 Dexter 8.582400 2013 - 2006\n",
+ "4 Queen of the South 8.574733 2018 - 2016\n",
+ "# .. may have more rows"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "avg_ratings = (tbl_ratings \n",
+ " >> group_by(_.title)\n",
+ " >> summarize(\n",
+ " avg_rating = _.av_rating.mean(),\n",
+ " date_range = str_c(_.date.dt.year.max(), \" - \", _.date.dt.year.min())\n",
+ " )\n",
+ " )\n",
+ "\n",
+ "avg_ratings"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Biggest changes in ratings between two seasons"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "# Source: lazy query\n",
+ "# DB Conn: Engine(postgresql://postgres:***@localhost:5433/postgres)\n",
+ "# Preview:\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " title | \n",
+ " max_shift | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " Third Watch | \n",
+ " 4.8500 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " Are You Afraid of the Dark? | \n",
+ " 2.3430 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " Lethal Weapon | \n",
+ " 2.3070 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " Law & Order: Special Victims Unit | \n",
+ " 2.0508 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
# .. may have more rows
"
+ ],
+ "text/plain": [
+ "# Source: lazy query\n",
+ "# DB Conn: Engine(postgresql://postgres:***@localhost:5433/postgres)\n",
+ "# Preview:\n",
+ " title max_shift\n",
+ "0 Third Watch 4.8500\n",
+ "1 Are You Afraid of the Dark? 2.3430\n",
+ "2 Lethal Weapon 2.3070\n",
+ "3 Law & Order: Special Victims Unit 2.0508\n",
+ "# .. may have more rows"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "top_4_shifts = (tbl_ratings\n",
+ " >> group_by(_.title)\n",
+ " >> mutate(rating_shift = _.av_rating - lag(_.av_rating))\n",
+ " >> summarize(\n",
+ " max_shift = _.rating_shift.max()\n",
+ " )\n",
+ " >> arrange(-_.max_shift)\n",
+ " >> head(4)\n",
+ " )\n",
+ "\n",
+ "top_4_shifts"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "big_shift_series = (top_4_shifts\n",
+ " >> select(_.title)\n",
+ " >> inner_join(_, tbl_ratings, \"title\")\n",
+ " >> collect()\n",
+ " )\n",
+ "\n",
+ "from plotnine import *\n",
+ "\n",
+ "(big_shift_series\n",
+ " >> ggplot(aes(\"seasonNumber\", \"av_rating\"))\n",
+ " + geom_point()\n",
+ " + geom_line()\n",
+ " + facet_wrap(\"~ title\")\n",
+ " + labs(\n",
+ " title = \"Seasons with Biggest Shifts in Ratings\",\n",
+ " y = \"Average rating\",\n",
+ " x = \"Season\"\n",
+ " )\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Do we have full data for each season?"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "# Source: lazy query\n",
+ "# DB Conn: Engine(postgresql://postgres:***@localhost:5433/postgres)\n",
+ "# Preview:\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " title | \n",
+ " titleId | \n",
+ " seasonNumber | \n",
+ " date | \n",
+ " av_rating | \n",
+ " share | \n",
+ " genres | \n",
+ " row | \n",
+ " mismatch | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 7th Heaven | \n",
+ " tt0115083 | \n",
+ " 1 | \n",
+ " 1996-08-26 | \n",
+ " 7.700 | \n",
+ " 0.10 | \n",
+ " Drama,Family,Romance | \n",
+ " 1 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 7th Heaven | \n",
+ " tt0115083 | \n",
+ " 10 | \n",
+ " 2006-05-08 | \n",
+ " 6.300 | \n",
+ " 0.01 | \n",
+ " Drama,Family,Romance | \n",
+ " 2 | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " ABC Afterschool Specials | \n",
+ " tt0202179 | \n",
+ " 25 | \n",
+ " 1996-09-12 | \n",
+ " 3.300 | \n",
+ " 0.10 | \n",
+ " Adventure,Comedy,Drama | \n",
+ " 1 | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " American Gothic | \n",
+ " tt5257744 | \n",
+ " 1 | \n",
+ " 2016-08-05 | \n",
+ " 7.535 | \n",
+ " 0.07 | \n",
+ " Crime,Drama,Mystery | \n",
+ " 1 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " American Gothic | \n",
+ " tt0111880 | \n",
+ " 1 | \n",
+ " 1995-09-22 | \n",
+ " 7.800 | \n",
+ " 0.08 | \n",
+ " Drama,Horror,Thriller | \n",
+ " 2 | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
# .. may have more rows
"
+ ],
+ "text/plain": [
+ "# Source: lazy query\n",
+ "# DB Conn: Engine(postgresql://postgres:***@localhost:5433/postgres)\n",
+ "# Preview:\n",
+ " title titleId seasonNumber date av_rating \\\n",
+ "0 7th Heaven tt0115083 1 1996-08-26 7.700 \n",
+ "1 7th Heaven tt0115083 10 2006-05-08 6.300 \n",
+ "2 ABC Afterschool Specials tt0202179 25 1996-09-12 3.300 \n",
+ "3 American Gothic tt5257744 1 2016-08-05 7.535 \n",
+ "4 American Gothic tt0111880 1 1995-09-22 7.800 \n",
+ "\n",
+ " share genres row mismatch \n",
+ "0 0.10 Drama,Family,Romance 1 False \n",
+ "1 0.01 Drama,Family,Romance 2 True \n",
+ "2 0.10 Adventure,Comedy,Drama 1 True \n",
+ "3 0.07 Crime,Drama,Mystery 1 False \n",
+ "4 0.08 Drama,Horror,Thriller 2 True \n",
+ "# .. may have more rows"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "mismatches = (tbl_ratings\n",
+ " >> arrange(_.title, _.seasonNumber)\n",
+ " >> group_by(_.title)\n",
+ " >> mutate(\n",
+ " row = row_number(_),\n",
+ " mismatch = _.row != _.seasonNumber\n",
+ " )\n",
+ " >> filter(_.mismatch.any())\n",
+ " >> ungroup()\n",
+ " )\n",
+ "\n",
+ "\n",
+ "mismatches"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " n | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 54 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " n\n",
+ "0 54"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "mismatches >> distinct(_.title) >> count() >> collect()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.7"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": false
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/examples-postgres.ipynb b/examples/examples-postgres.ipynb
index bfb3e097..7ea34ac7 100644
--- a/examples/examples-postgres.ipynb
+++ b/examples/examples-postgres.ipynb
@@ -17,7 +17,7 @@
{
"data": {
"text/plain": [
- ""
+ ""
]
},
"execution_count": 1,
@@ -87,7 +87,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "SELECT anon_1.id, anon_1.user_id, anon_1.email_address, anon_1.num, anon_1.anon_2 \n",
+ "SELECT anon_1.user_id, anon_1.id, anon_1.email_address, anon_1.num \n",
"FROM (SELECT id, user_id, email_address, num, min(anon_3.id) OVER (PARTITION BY anon_3.user_id) AS anon_2 \n",
"FROM (SELECT id, user_id, email_address, dense_rank() OVER (PARTITION BY addresses.user_id ORDER BY addresses.id) AS num \n",
"FROM addresses) AS anon_3) AS anon_1 \n",
@@ -115,29 +115,27 @@
" \n",
" \n",
" | \n",
- " id | \n",
" user_id | \n",
+ " id | \n",
" email_address | \n",
" num | \n",
- " anon_2 | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
- " 2 | \n",
" 1 | \n",
+ " 2 | \n",
" jack@msn.com | \n",
" 2 | \n",
- " 1 | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- " id user_id email_address num anon_2\n",
- "0 2 1 jack@msn.com 2 1"
+ " user_id id email_address num\n",
+ "0 1 2 jack@msn.com 2"
]
},
"execution_count": 2,
@@ -180,6 +178,26 @@
"#tbl_addresses >> group_by(_, \"user_id\") >> mutate(_, num = dense_rank(_.id)) >> show_query(_)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ">"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sql.functions.sum().over"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
@@ -189,7 +207,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -210,7 +228,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -232,7 +250,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -264,7 +282,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -288,14 +306,14 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "SELECT anon_1.id, anon_1.user_id, anon_1.email_address \n",
+ "SELECT anon_1.user_id, anon_1.id, anon_1.email_address \n",
"FROM (SELECT anon_2.id AS id, anon_2.user_id AS user_id, anon_2.email_address AS email_address \n",
"FROM (SELECT addresses.id AS id, addresses.user_id AS user_id, addresses.email_address AS email_address \n",
"FROM addresses) AS anon_2) AS anon_1 \n",
@@ -313,18 +331,18 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "SELECT anon_1.id, anon_1.user_id, anon_1.email_address, anon_1.anon_2 \n",
- "FROM (SELECT anon_3.id AS id, anon_3.user_id AS user_id, anon_3.email_address AS email_address, dense_rank() OVER (PARTITION BY anon_3.user_id ORDER BY anon_3.id) AS anon_2 \n",
+ "SELECT anon_1.user_id, anon_1.id, anon_1.email_address \n",
+ "FROM (SELECT anon_2.id AS id, anon_2.user_id AS user_id, anon_2.email_address AS email_address, dense_rank() OVER (PARTITION BY anon_2.user_id ORDER BY anon_2.id) AS anon_3 \n",
"FROM (SELECT addresses.id AS id, addresses.user_id AS user_id, addresses.email_address AS email_address \n",
- "FROM addresses) AS anon_3) AS anon_1 \n",
- "WHERE anon_1.anon_2 > 1\n"
+ "FROM addresses) AS anon_2) AS anon_1 \n",
+ "WHERE anon_1.anon_3 > 1\n"
]
},
{
@@ -348,38 +366,35 @@
" \n",
" \n",
" | \n",
- " id | \n",
" user_id | \n",
+ " id | \n",
" email_address | \n",
- " anon_2 | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
- " 2 | \n",
" 1 | \n",
- " jack@msn.com | \n",
" 2 | \n",
+ " jack@msn.com | \n",
"
\n",
" \n",
" 1 | \n",
- " 4 | \n",
" 2 | \n",
+ " 4 | \n",
" wendy@aol.com | \n",
- " 2 | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- " id user_id email_address anon_2\n",
- "0 2 1 jack@msn.com 2\n",
- "1 4 2 wendy@aol.com 2"
+ " user_id id email_address\n",
+ "0 1 2 jack@msn.com\n",
+ "1 2 4 wendy@aol.com"
]
},
- "execution_count": 8,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -404,7 +419,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -461,7 +476,7 @@
"1 1 1.5"
]
},
- "execution_count": 9,
+ "execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@@ -479,15 +494,16 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "SELECT avg(addresses.id + 1) AS m_id \n",
- "FROM addresses\n"
+ "SELECT avg(anon_1.id2) AS m_id \n",
+ "FROM (SELECT addresses.id AS id, addresses.user_id AS user_id, addresses.email_address AS email_address, addresses.id + 1 AS id2 \n",
+ "FROM addresses) AS anon_1\n"
]
}
],
@@ -504,7 +520,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 12,
"metadata": {},
"outputs": [
{
@@ -566,7 +582,7 @@
"1 2 0 2"
]
},
- "execution_count": 11,
+ "execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@@ -591,16 +607,16 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "SELECT anon_1.id AS id_x, anon_1.user_id, anon_1.email_address, anon_2.id AS id_y, anon_2.name, anon_2.fullname \n",
+ "SELECT anon_1.id, anon_1.user_id, anon_1.email_address, anon_2.fullname, anon_2.name \n",
"FROM (SELECT addresses.id AS id, addresses.user_id AS user_id, addresses.email_address AS email_address \n",
- "FROM addresses) AS anon_1 JOIN (SELECT users.id AS id, users.name AS name, users.fullname AS fullname \n",
+ "FROM addresses) AS anon_1 LEFT OUTER JOIN (SELECT users.id AS id, users.name AS name, users.fullname AS fullname \n",
"FROM users) AS anon_2 ON anon_1.user_id = anon_2.id\n"
]
},
@@ -625,12 +641,11 @@
" \n",
" \n",
" | \n",
- " id_x | \n",
+ " id | \n",
" user_id | \n",
" email_address | \n",
- " id_y | \n",
- " name | \n",
" fullname | \n",
+ " name | \n",
"
\n",
" \n",
" \n",
@@ -639,50 +654,46 @@
" 1 | \n",
" 1 | \n",
" jack@yahoo.com | \n",
- " 1 | \n",
- " jack | \n",
" Jack Jones | \n",
+ " jack | \n",
" \n",
" \n",
" 1 | \n",
" 2 | \n",
" 1 | \n",
" jack@msn.com | \n",
- " 1 | \n",
- " jack | \n",
" Jack Jones | \n",
+ " jack | \n",
"
\n",
" \n",
" 2 | \n",
" 3 | \n",
" 2 | \n",
" www@www.org | \n",
- " 2 | \n",
- " wendy | \n",
" Wendy Williams | \n",
+ " wendy | \n",
"
\n",
" \n",
" 3 | \n",
" 4 | \n",
" 2 | \n",
" wendy@aol.com | \n",
- " 2 | \n",
- " wendy | \n",
" Wendy Williams | \n",
+ " wendy | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- " id_x user_id email_address id_y name fullname\n",
- "0 1 1 jack@yahoo.com 1 jack Jack Jones\n",
- "1 2 1 jack@msn.com 1 jack Jack Jones\n",
- "2 3 2 www@www.org 2 wendy Wendy Williams\n",
- "3 4 2 wendy@aol.com 2 wendy Wendy Williams"
+ " id user_id email_address fullname name\n",
+ "0 1 1 jack@yahoo.com Jack Jones jack\n",
+ "1 2 1 jack@msn.com Jack Jones jack\n",
+ "2 3 2 www@www.org Wendy Williams wendy\n",
+ "3 4 2 wendy@aol.com Wendy Williams wendy"
]
},
- "execution_count": 12,
+ "execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -708,7 +719,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
@@ -787,7 +798,7 @@
"3 4 2 wendy@aol.com 1"
]
},
- "execution_count": 13,
+ "execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@@ -811,7 +822,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 15,
"metadata": {},
"outputs": [
{
@@ -867,7 +878,7 @@
"0 1 1 jack@yahoo.com"
]
},
- "execution_count": 14,
+ "execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
@@ -892,7 +903,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 16,
"metadata": {},
"outputs": [
{
@@ -971,7 +982,7 @@
"3 4 2 wendy@aol.com 0"
]
},
- "execution_count": 15,
+ "execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -997,14 +1008,14 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"█─'__call__'\n",
- "├─\n",
+ "├─\n",
"├─_\n",
"└─█─''\n",
" └─█─'__call__'\n",
@@ -1012,7 +1023,7 @@
" └─{_.id > 1: 'yeah', True: 'no'}"
]
},
- "execution_count": 16,
+ "execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
@@ -1031,7 +1042,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 18,
"metadata": {},
"outputs": [
{
@@ -1059,7 +1070,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 19,
"metadata": {},
"outputs": [
{
@@ -1127,7 +1138,7 @@
"2 3 2 www@www.org"
]
},
- "execution_count": 18,
+ "execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
@@ -1149,7 +1160,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 20,
"metadata": {},
"outputs": [
{
@@ -1223,7 +1234,7 @@
"3 4 2 wendy@aol.com"
]
},
- "execution_count": 19,
+ "execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
@@ -1245,7 +1256,7 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 21,
"metadata": {},
"outputs": [
{
@@ -1302,7 +1313,7 @@
"1 1 2"
]
},
- "execution_count": 20,
+ "execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
@@ -1317,7 +1328,7 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 22,
"metadata": {},
"outputs": [
{
@@ -1386,7 +1397,7 @@
"3 wendy@aol.com 1"
]
},
- "execution_count": 21,
+ "execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
@@ -1416,7 +1427,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 23,
"metadata": {},
"outputs": [
{
@@ -1465,7 +1476,7 @@
"1 1 2"
]
},
- "execution_count": 22,
+ "execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
@@ -1495,7 +1506,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 24,
"metadata": {},
"outputs": [
{
@@ -1518,6 +1529,55 @@
"## SQL escapes"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Window functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "SELECT addresses.id, addresses.user_id, addresses.email_address, sum(addresses.user_id) OVER (ORDER BY addresses.id DESC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cumsum \n",
+ "FROM addresses ORDER BY addresses.id DESC, cumsum\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "# Source: lazy query\n",
+ "# DB Conn: Engine(postgresql://postgres:***@localhost:5433/postgres)\n",
+ "# Preview:\n",
+ " id user_id email_address cumsum\n",
+ "0 4 2 wendy@aol.com 2\n",
+ "1 3 2 www@www.org 4\n",
+ "2 2 1 jack@msn.com 5\n",
+ "3 1 1 jack@yahoo.com 6\n",
+ "# .. may have more rows"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from siuba.dply.vector import desc\n",
+ "(tbl_addresses\n",
+ " >> arrange(_.id.desc())\n",
+ " >> mutate(cumsum = _.user_id.cumsum())\n",
+ " >> arrange(_.cumsum)\n",
+ " >> show_query()\n",
+ " )"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
@@ -1534,7 +1594,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 26,
"metadata": {},
"outputs": [
{
@@ -1613,7 +1673,7 @@
"3 4 2 wendy@aol.com 4.0"
]
},
- "execution_count": 24,
+ "execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
@@ -1635,7 +1695,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 27,
"metadata": {},
"outputs": [
{
@@ -1698,7 +1758,7 @@
"1 2 wendy Wendy Williams 3"
]
},
- "execution_count": 25,
+ "execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
@@ -1727,7 +1787,7 @@
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 28,
"metadata": {},
"outputs": [
{
@@ -1790,7 +1850,7 @@
"1 2 wendy Wendy Williams 3"
]
},
- "execution_count": 26,
+ "execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
@@ -1818,7 +1878,7 @@
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 29,
"metadata": {},
"outputs": [
{
@@ -1833,7 +1893,7 @@
"# .. may have more rows"
]
},
- "execution_count": 27,
+ "execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
diff --git a/requirements.txt b/requirements.txt
index cb3bd17b..1507e4ec 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,10 +2,12 @@ numpy==1.16.1
pandas==0.24.1
python-dateutil==2.8.0
pytz==2018.9
+scipy==1.2.1
six==1.12.0
SQLAlchemy==1.2.17
nbval==0.9.1
# tests
+pytest==4.4.2
psycopg2==2.8.2
# only used for iris dataset
scikit-learn==0.20.2
@@ -13,3 +15,5 @@ scikit-learn==0.20.2
nbsphinx==0.4.2
jupytext==1.1.1
gapminder==0.1
+matplotlib==3.1.0
+plotnine==0.5.1
diff --git a/siuba/dply/string.py b/siuba/dply/string.py
new file mode 100644
index 00000000..b4099868
--- /dev/null
+++ b/siuba/dply/string.py
@@ -0,0 +1,33 @@
+import pandas as pd
+import numpy as np
+from functools import singledispatch
+import itertools
+
+from ..siu import Symbolic, create_sym_call,Call
+
+
+def register_symbolic(f):
+ # TODO: don't use singledispatch if it has already been done
+ f = singledispatch(f)
+ @f.register(Symbolic)
+ def _dispatch_symbol(__data, *args, **kwargs):
+ return create_sym_call(f, __data.source, *args, **kwargs)
+
+ return f
+
+def _coerce_to_str(x):
+ if isinstance(x, (pd.Series, np.ndarray)):
+ return x.astype(str)
+ elif not np.ndim(x) < 2:
+ raise ValueError("np.ndim must be less than 2, but is %s" %np.ndim(x))
+
+ return pd.Series(x, dtype = str)
+
+
+@register_symbolic
+def str_c(x, *args, sep = "", collapse = None):
+ all_args = itertools.chain([x], args)
+ strings = list(map(_coerce_to_str, all_args))
+
+ return np.sum(strings, axis = 0)
+
diff --git a/siuba/dply/vector.py b/siuba/dply/vector.py
index 54a07319..dd000e4c 100644
--- a/siuba/dply/vector.py
+++ b/siuba/dply/vector.py
@@ -32,7 +32,7 @@ def cummean(x):
@register_symbolic
def desc(x):
- NotImplementedError("Use minus sign in arrange instead (e.g. -_.somecol)")
+ return x.sort_values()
@register_symbolic
@@ -61,7 +61,7 @@ def row_number(x):
n = x.shape[0]
else:
n = len(x)
- return np.arange(n)
+ return np.arange(1, n + 1)
@register_symbolic
@@ -91,7 +91,7 @@ def lead(x, n = 1, default = None):
@register_symbolic
-def lag():
+def lag(x, n = 1, default = None):
res = x.shift(n)
if default is not None:
diff --git a/siuba/dply/verbs.py b/siuba/dply/verbs.py
index 60e6787a..b5b2fde4 100644
--- a/siuba/dply/verbs.py
+++ b/siuba/dply/verbs.py
@@ -19,9 +19,10 @@
"nest", "unnest",
"expand", "complete",
# Joins ----
- "join", "inner_join", "left_join", "right_join", "semi_join", "full_join",
+ "join", "inner_join", "full_join", "left_join", "right_join", "semi_join", "anti_join",
# TODO: move to vectors
"if_else", "case_when",
+ "collect", "show_query"
)
__all__ = [*DPLY_FUNCTIONS, "Pipeable", "pipe"]
@@ -206,6 +207,20 @@ def raise_type_error(f):
types = ", ".join(map(str, f.registry.keys()))
))
+# Collect and show_query =========
+
+@pipe_no_args
+@singledispatch2((DataFrame, DataFrameGroupBy))
+def collect(__data, *args, **kwargs):
+ # simply return DataFrame, since requires no execution
+ return __data
+
+
+@pipe_no_args
+@singledispatch2((DataFrame, DataFrameGroupBy))
+def show_query(__data, simplify = False):
+ print("No query to show for a DataFrame")
+ return __data
# Mutate ======================================================================
@@ -538,6 +553,11 @@ def var_create(*args):
@singledispatch2(DataFrame)
def select(__data, *args, **kwargs):
+ if kwargs:
+ raise NotImplementedError(
+ "Using kwargs in select not currently supported. "
+ "Use _.newname == _.oldname instead"
+ )
var_list = var_create(*args)
od = var_select(__data.columns, *var_list)
@@ -557,13 +577,15 @@ def _select(__data, *args, **kwargs):
@singledispatch2(DataFrame)
def rename(__data, **kwargs):
# TODO: allow names with spaces, etc..
- col_names = {v:k for k,v in kwargs.items()}
+ col_names = {simple_varname(v):k for k,v in kwargs.items()}
+ if None in col_names:
+ raise ValueError("Rename needs column name (e.g. 'a' or _.a), but received %s"%col_names[None])
return __data.rename(columns = col_names)
@rename.register(DataFrameGroupBy)
def _rename(__data, **kwargs):
- raise Exception("Selecting columns of grouped DataFrame currently not allowed")
+ raise NotImplementedError("Selecting columns of grouped DataFrame currently not allowed")
@@ -616,6 +638,11 @@ def arrange(__data, *args):
.drop(tmp_colnames, axis = 1)
+@arrange.register(DataFrameGroupBy)
+def _arrange(__data, *args):
+ raise NotImplementedError("TODO: arrange with grouped DataFrame")
+
+
# Distinct ====================================================================
@@ -890,6 +917,9 @@ def semi_join(left, right = None, on = None):
return left.merge(right.loc[:,on_cols], how = 'inner', on = on_cols)
+@singledispatch2(pd.DataFrame)
+def anti_join(left, right = None, on = None):
+ raise NotImplementedError("anti_join not currently implemented")
left_join = partial(join, how = "left")
right_join = partial(join, how = "right")
diff --git a/siuba/siu.py b/siuba/siu.py
index 3dd4463b..9d5b38e8 100644
--- a/siuba/siu.py
+++ b/siuba/siu.py
@@ -502,6 +502,9 @@ def slice_to_call(x):
return strip_symbolic(x)
+def str_to_getitem_call(x):
+ return Call("__getitem__", MetaArg("_"), x)
+
def strip_symbolic(x):
if isinstance(x, Symbolic):
diff --git a/siuba/sql/dialects/postgresql.py b/siuba/sql/dialects/postgresql.py
index 124e3e01..31b01ed8 100644
--- a/siuba/sql/dialects/postgresql.py
+++ b/siuba/sql/dialects/postgresql.py
@@ -1,14 +1,42 @@
# sqlvariant, allow defining 3 namespaces to override defaults
-from ..translate import base_scalar, base_agg, base_win, SqlTranslator
+from ..translate import (
+ base_scalar, base_agg, base_win, SqlTranslator,
+ win_agg, sql_scalar
+ )
import sqlalchemy.sql.sqltypes as sa_types
from sqlalchemy import sql
+def sql_log(col, base = None):
+ if base is None:
+ return sql.func.ln(col)
+ return sql.func.log(col)
+
def sql_round(col, n):
return sql.func.round(sql.cast(col, sa_types.Numeric()), n)
+def sql_str_contains(col, pat, case, *args, **kwargs):
+ if args or kwargs:
+ raise NotImplementedError("Only pat and case arg of contains allowed.")
+
+ infix = "~" if case else "~*"
+
+ return col.op(infix, pat)
+
+# handle when others is a list?
+def sql_str_cat(col, others, sep, join = None):
+ if join is not None:
+ raise NotImplementedError("join argument of cat not supported")
+
+
scalar = SqlTranslator(
base_scalar,
- round = sql_round
+ log = sql_log,
+ round = sql_round,
+ contains = sql_str_contains,
+ year = lambda col: sql.func.extract('year', sql.cast(col, sql.sqltypes.Date)),
+ concat = sql.func.concat,
+ cat = sql.func.concat,
+ str_c = sql.func.concat
)
aggregate = SqlTranslator(
@@ -16,7 +44,10 @@ def sql_round(col, n):
)
window = SqlTranslator(
- base_win
+ base_win,
+ any = win_agg("bool_or"),
+ all = win_agg("bool_and"),
+ lag = win_agg("lag")
)
funcs = dict(scalar = scalar, aggregate = aggregate, window = window)
diff --git a/siuba/sql/dialects/sqlite.py b/siuba/sql/dialects/sqlite.py
index 910c971c..de30b0e2 100644
--- a/siuba/sql/dialects/sqlite.py
+++ b/siuba/sql/dialects/sqlite.py
@@ -1,5 +1,5 @@
# sqlvariant, allow defining 3 namespaces to override defaults
-from ..translate import base_scalar, base_agg, base_win, SqlTranslator, win_agg
+from ..translate import base_scalar, base_agg, base_nowin, SqlTranslator, win_agg
import sqlalchemy.sql.sqltypes as sa_types
from sqlalchemy import sql
@@ -13,7 +13,7 @@
window = SqlTranslator(
# TODO: should check sqlite version, since < 3.25 can't use windows
- base_win,
+ base_nowin,
sd = win_agg("stddev")
)
diff --git a/siuba/sql/translate.py b/siuba/sql/translate.py
index b36cb4ee..b6930336 100644
--- a/siuba/sql/translate.py
+++ b/siuba/sql/translate.py
@@ -1,7 +1,19 @@
+"""
+This module holds default translations from pandas syntax to sql for 3 kinds of operations...
+
+1. scalar - elementwise operations (e.g. array1 + array2)
+2. aggregation - operations that result in a single number (e.g. array1.mean())
+3. window - operations that do calculations across a window
+ (e.g. array1.lag() or array1.expanding().mean())
+
+
+"""
+
from sqlalchemy import sql
from sqlalchemy.sql import sqltypes as types
from functools import singledispatch
from .verbs import case_when, if_else
+import warnings
# TODO: must make these take both tbl, col as args, since hard to find window funcs
def sa_is_window(clause):
@@ -9,35 +21,80 @@ def sa_is_window(clause):
or isinstance(clause, sql.elements.WithinGroup)
-def sa_modify_window(clause, columns, group_by = None, order_by = None):
- cls = clause.__class__ if sa_is_window(clause) else getattr(clause, "over")
+def sa_modify_window(clause, group_by = None, order_by = None):
if group_by:
- partition_by = [columns[name] for name in group_by]
- return cls(**{**clause.__dict__, 'partition_by': partition_by})
+ group_cols = [columns[name] for name in group_by]
+ partition_by = sql.elements.ClauseList(*group_cols)
+ clone = clause._clone()
+ clone.partition_by = partition_by
+
+ return clone
return clause
+from sqlalchemy.sql.elements import Over
+# windowed agg (group by)
+# agg
+# windowed scalar
+# ordered set agg
+
+class CustomOverClause: pass
+
+class AggOver(Over, CustomOverClause):
+ def set_over(self, group_by, order_by = None):
+ self.partition_by = group_by
+ return self
+
+
+class RankOver(Over, CustomOverClause):
+ def set_over(self, group_by, order_by = None):
+ self.partition_by = group_by
+ return self
+
+
+class CumlOver(Over, CustomOverClause):
+ def set_over(self, group_by, order_by):
+ self.partition_by = group_by
+ self.order_by = order_by
+
+ if not len(order_by):
+ warnings.warn(
+ "No order by columns explicitly set in window function. SQL engine"
+ "does not guarantee a row ordering. Recommend using an arrange beforehand.",
+ RuntimeWarning
+ )
+ return self
+
+
+def win_absent(name):
+ def not_implemented(*args, **kwargs):
+ raise NotImplementedError("SQL dialect does not support {}.".format(name))
+
+ return not_implemented
def win_over(name):
sa_func = getattr(sql.func, name)
- return lambda col: sa_func().over(order_by = col)
+ return lambda col: RankOver(sa_func(), order_by = col)
+def win_cumul(name):
+ sa_func = getattr(sql.func, name)
+ return lambda col: CumlOver(sa_func(col), rows = (None,0))
def win_agg(name):
sa_func = getattr(sql.func, name)
- return lambda col: sa_func(col).over()
+ return lambda col: AggOver(sa_func(col))
def sql_agg(name):
sa_func = getattr(sql.func, name)
return lambda col: sa_func(col)
-def sql_scalar(name):
+def sql_scalar(name, *args):
sa_func = getattr(sql.func, name)
- return lambda col: sa_func(col)
+ return lambda col: sa_func(col, *args)
-def sql_colmeth(meth):
+def sql_colmeth(meth, *outerargs):
def f(col, *args):
- return getattr(col, meth)(*args)
+ return getattr(col, meth)(*outerargs, *args)
return f
def sql_astype(col, _type):
@@ -47,7 +104,10 @@ def sql_astype(col, _type):
float: types.Numeric,
bool: types.Boolean
}
- sa_type = mappings[_type]
+ try:
+ sa_type = mappings[_type]
+ except KeyError:
+ raise ValueError("sql astype currently only supports type objects: str, int, float, bool")
return sql.cast(col, sa_type)
base_scalar = dict(
@@ -70,9 +130,10 @@ def sql_astype(col, _type):
# TODO: I think these are postgres specific?
hour = lambda col: sql.func.date_trunc('hour', col),
week = lambda col: sql.func.date_trunc('week', col),
- isna = lambda col: col.is_(None),
- isnull = lambda col: col.is_(None),
+ isna = sql_colmeth("is_", None),
+ isnull = sql_colmeth("is_", None),
# dply.vector funcs ----
+ desc = lambda col: col.desc(),
# TODO: string methods
#str.len,
@@ -83,7 +144,6 @@ def sql_astype(col, _type):
#str_trim func to cut text off sides
# TODO: move to postgres specific
n = lambda col: sql.func.count(),
- sum = sql_scalar("sum"),
# TODO: this is to support a DictCall (e.g. used in case_when)
dict = dict,
# TODO: don't use singledispatch to add sql support to case_when
@@ -93,13 +153,16 @@ def sql_astype(col, _type):
base_agg = dict(
mean = sql_agg("avg"),
+ sum = sql_agg("sum"),
+ min = sql_agg("min"),
+ max = sql_agg("max"),
# TODO: generalize case where doesn't use col
# need better handeling of vector funcs
len = lambda col: sql.func.count()
)
base_win = dict(
- row_number = win_over("row_number"),
+ row_number = lambda col: CumlOver(sql.func.row_number()),
min_rank = win_over("rank"),
rank = win_over("rank"),
dense_rank = win_over("dense_rank"),
@@ -130,12 +193,47 @@ def sql_astype(col, _type):
# cumulative funcs ---
#avg("id") OVER (PARTITION BY "email" ORDER BY "id" ROWS UNBOUNDED PRECEDING)
#cummean = win_agg("
- #cumsum
+ cumsum = win_cumul("sum")
#cummin
#cummax
)
+# based on https://github.com/tidyverse/dbplyr/blob/master/R/backend-.R
+base_nowin = dict(
+ row_number = win_absent("ROW_NUMBER"),
+ min_rank = win_absent("RANK"),
+ rank = win_absent("RANK"),
+ dense_rank = win_absent("DENSE_RANK"),
+ percent_rank = win_absent("PERCENT_RANK"),
+ cume_dist = win_absent("CUME_DIST"),
+ ntile = win_absent("NTILE"),
+ mean = win_absent("AVG"),
+ sd = win_absent("SD"),
+ var = win_absent("VAR"),
+ cov = win_absent("COV"),
+ cor = win_absent("COR"),
+ sum = win_absent("SUM"),
+ min = win_absent("MIN"),
+ max = win_absent("MAX"),
+ median = win_absent("PERCENTILE_CONT"),
+ quantile = win_absent("PERCENTILE_CONT"),
+ n = win_absent("N"),
+ n_distinct = win_absent("N_DISTINCT"),
+ cummean = win_absent("MEAN"),
+ cumsum = win_absent("SUM"),
+ cummin = win_absent("MIN"),
+ cummax = win_absent("MAX"),
+ nth = win_absent("NTH_VALUE"),
+ first = win_absent("FIRST_VALUE"),
+ last = win_absent("LAST_VALUE"),
+ lead = win_absent("LEAD"),
+ lag = win_absent("LAG"),
+ order_by = win_absent("ORDER_BY"),
+ str_flatten = win_absent("STR_FLATTEN"),
+ count = win_absent("COUNT")
+ )
+
funcs = dict(scalar = base_scalar, aggregate = base_agg, window = base_win)
# MISC ===========================================================================
diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py
index 722e380c..c2685182 100644
--- a/siuba/sql/verbs.py
+++ b/siuba/sql/verbs.py
@@ -1,6 +1,6 @@
from siuba.dply.verbs import (
singledispatch2,
- pipe_no_args,
+ show_query, collect,
simple_varname,
select, VarList, var_select,
mutate,
@@ -10,18 +10,18 @@
count,
group_by, ungroup,
case_when,
- join, left_join, right_join, inner_join,
+ join, left_join, right_join, inner_join, semi_join, anti_join,
head,
rename,
distinct,
if_else
)
-from .translate import sa_modify_window, sa_is_window
+from .translate import sa_modify_window, sa_is_window, CustomOverClause
from .utils import get_dialect_funcs
from sqlalchemy import sql
import sqlalchemy
-from siuba.siu import Call, CallTreeLocal
+from siuba.siu import Call, CallTreeLocal, str_to_getitem_call, Lazy
# TODO: currently needed for select, but can we remove pandas?
from pandas import Series
import pandas as pd
@@ -57,17 +57,26 @@ class WindowReplacer(CallListener):
TODO: could replace with a sqlalchemy transformer
"""
- def __init__(self, columns, group_by, window_cte = None):
+ def __init__(self, columns, group_by, order_by, window_cte = None):
self.columns = columns
self.group_by = group_by
+ self.order_by = order_by
self.window_cte = window_cte
self.windows = []
def exit(self, node):
# evaluate
col_expr = node(self.columns)
- if sa_is_window(col_expr):
- label = sa_modify_window(col_expr, self.columns, self.group_by).label(None)
+ if isinstance(col_expr, CustomOverClause):
+ group_by = sql.elements.ClauseList(
+ *[self.columns[name] for name in self.group_by]
+ )
+ order_by = sql.elements.ClauseList(
+ *_create_order_by_clause(self.columns, *self.order_by)
+ )
+
+ label = col_expr.set_over(group_by, order_by).label(None)
+ #label = sa_modify_window(col_expr, self.columns, self.group_by).label(None)
self.windows.append(label)
@@ -81,8 +90,8 @@ def exit(self, node):
return col_expr
-def track_call_windows(call, columns, group_by, window_cte = None):
- listener = WindowReplacer(columns, group_by, window_cte)
+def track_call_windows(call, columns, group_by, order_by, window_cte = None):
+ listener = WindowReplacer(columns, group_by, order_by, window_cte)
col = listener.enter(call)
return col, listener.windows
@@ -93,15 +102,22 @@ def lift_inner_cols(tbl):
return sql.base.ImmutableColumnCollection(data, cols)
-def has_windows(clause):
- windows = []
- append_win = lambda col: windows.append(col)
+def col_expr_requires_cte(call, sel):
+ """Return whether a variable assignment needs a CTE"""
+
+ call_vars = set(call.op_vars(attr_calls = False))
- sql.util.visitors.traverse(clause, {}, {"over": append_win})
- if len(windows):
- return True
+ columns = lift_inner_cols(sel)
+ sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label))
+
+ return ( len(sel._group_by_clause)
+ or len(sel._order_by_clause)
+ or not sel_labs.isdisjoint(call_vars)
+ )
- return False
+def get_missing_columns(call, columns):
+ missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys())
+ return missing_cols
def compile_el(tbl, el):
compiled = el.compile(
@@ -110,6 +126,14 @@ def compile_el(tbl, el):
)
return compiled
+# Misc utilities --------------------------------------------------------------
+
+def ordered_union(x, y):
+ dx = {el: True for el in x}
+ dy = {el: True for el in y}
+
+ return tuple({**dx, **dy})
+
@@ -140,22 +164,25 @@ def __init__(
self.rm_attr = rm_attr
self.call_sub_attr = call_sub_attr
- def append_op(self, op):
- return self.__class__(
- self.source,
- self.tbl,
- self.ops + [op],
- self.group_by,
- self.order_by,
- self.funcs,
- self.rm_attr,
- self.call_sub_attr
- )
+ def append_op(self, op, **kwargs):
+ cpy = self.copy(**kwargs)
+ cpy.ops = cpy.ops + [op]
+ return cpy
def copy(self, **kwargs):
return self.__class__(**{**self.__dict__, **kwargs})
- def shape_call(self, call, window = True):
+ def shape_call(self, call, window = True, str_accessors = False):
+ # TODO: error if mutate receives a literal value?
+ if str_accessors and isinstance(call, str):
+ # verbs that can use strings as accessors, like group_by, or
+ # arrange, need to convert those strings into a getitem call
+ return str_to_get_item_call(call)
+ elif not isinstance(call, Call):
+ # verbs that use literal strings, need to convert them to a call
+ # that returns a sqlalchemy "literal" object
+ return Lazy(sql.literal(call))
+
f_dict1 = self.funcs['scalar']
f_dict2 = self.funcs['window' if window else 'aggregate']
@@ -172,25 +199,54 @@ def track_call_windows(self, call, columns = None, window_cte = None):
"""Returns tuple of (new column expression, list of window exprs)"""
columns = self.last_op.columns if columns is None else columns
- return track_call_windows(call, columns, self.group_by, window_cte)
+ return track_call_windows(call, columns, self.group_by, self.order_by, window_cte)
+
+ def get_ordered_col_names(self):
+ ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by]
+ return list(self.group_by) + ungrouped
@property
def last_op(self):
return self.ops[-1] if len(self.ops) else None
- def __repr__(self):
- tbl_small = self.append_op(self.last_op.limit(5))
-
- # makes sure to get engine, even if sqlalchemy connection obj
- engine = self.source.engine
+ def _get_preview(self):
+ # need to make prev op a cte, so we don't override any previous limit
+ new_sel = sql.select([self.last_op.alias()]).limit(5)
+ tbl_small = self.append_op(new_sel)
+ return collect(tbl_small)
- return ("# Source: lazy query\n"
+ def __repr__(self):
+ template = (
+ "# Source: lazy query\n"
"# DB Conn: {}\n"
"# Preview:\n{}\n"
"# .. may have more rows"
- .format(repr(engine), repr(collect(tbl_small)))
)
+ return template.format(repr(self.source.engine), repr(self._get_preview()))
+
+ def _repr_html_(self):
+ template = (
+ ""
+ "
"
+ "# Source: lazy query\n"
+ "# DB Conn: {}\n"
+ "# Preview:\n"
+ "
"
+ "{}"
+ "
# .. may have more rows
"
+ "
"
+ )
+
+ data = self._get_preview()
+ html_data = getattr(data, '_repr_html_', lambda: repr(data))()
+ return template.format(self.source.engine, html_data)
+
+
+def _repr_grouped_df_html_(self):
+ return "(grouped data frame)
" + self._selected_obj._repr_html_() + "
"
+
+
# Main Funcs
# =============================================================================
@@ -210,9 +266,8 @@ def use_simple_names():
finally:
deregister(sql.compiler._CompileLabel)
-@pipe_no_args
-@singledispatch2(LazyTbl)
-def show_query(tbl, simplify = False):
+@show_query.register(LazyTbl)
+def _show_query(tbl, simplify = False):
query = tbl.last_op #if not simplify else
compile_query = lambda: query.compile(
dialect = tbl.source.dialect,
@@ -231,9 +286,9 @@ def show_query(tbl, simplify = False):
return tbl
# collect ----------
-@pipe_no_args
-@singledispatch2(LazyTbl)
-def collect(__data, as_df = True):
+
+@collect.register(LazyTbl)
+def _collect(__data, as_df = True):
# TODO: maybe remove as_df options, always return dataframe
# normally can just pass the sql objects to execute, but for some reason
# psycopg2 completes about incomplete template.
@@ -248,15 +303,15 @@ def collect(__data, as_df = True):
return __data.source.execute(compiled).fetchall()
-@collect.register(pd.DataFrame)
-def _collect(__data, *args, **kwargs):
- # simply return DataFrame, since requires no execution
- return __data
-
@select.register(LazyTbl)
def _select(__data, *args, **kwargs):
# see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object
+ if kwargs:
+ raise NotImplementedError(
+ "Using kwargs in select not currently supported. "
+ "Use _.newname == _.oldname instead"
+ )
last_op = __data.last_op
columns = {c.key: c for c in last_op.inner_columns}
@@ -282,14 +337,13 @@ def _filter(__data, *args, **kwargs):
# 1 for window/aggs, and 1 for the where clause
sel = __data.last_op.alias()
win_sel = sql.select([sel], from_obj = sel)
- #fil_sel = sql.select([win_sel], from_obj = win_sel)
conds = []
windows = []
for arg in args:
if isinstance(arg, Call):
new_call = __data.shape_call(arg)
- var_cols = new_call.op_vars(attr_calls = False)
+ #var_cols = new_call.op_vars(attr_calls = False)
col_expr, win_cols = __data.track_call_windows(
new_call,
@@ -297,8 +351,6 @@ def _filter(__data, *args, **kwargs):
window_cte = win_sel
)
- #if sa_is_window(col_expr):
- # col_expr = sa_modify_window(col_expr, columns, __data.group_by)
conds.append(col_expr)
else:
conds.append(arg)
@@ -309,9 +361,10 @@ def _filter(__data, *args, **kwargs):
win_alias = win_sel.alias()
bool_clause = sql.util.ClauseAdapter(win_alias).traverse(bool_clause)
-
- sel = sql.select([win_alias], from_obj = win_alias, whereclause = bool_clause)
- return __data.append_op(sel)
+
+ orig_cols = [win_alias.columns[k] for k in __data.get_ordered_col_names()]
+ filt_sel = sql.select(orig_cols, from_obj = win_alias, whereclause = bool_clause)
+ return __data.append_op(filt_sel)
@mutate.register(LazyTbl)
@@ -342,17 +395,14 @@ def _mutate_select(sel, colname, func, labs, __data):
function handles whether to add a column to the existing select statement,
or to use it as a subquery.
"""
- #colname, func
- replace_col = colname in sel.columns
+ replace_col = False
# Call objects let us check whether column expr used a derived column
# e.g. SELECT a as b, b + 1 as c raises an error in SQL, so need subquery
call_vars = func.op_vars(attr_calls = False)
- if isinstance(func, Call) and labs.isdisjoint(call_vars):
+ if labs.isdisjoint(call_vars):
# New column may be able to modify existing select
+ replace_col = colname in sel.columns
columns = lift_inner_cols(sel)
- # replacing an existing column, so strip it from select statement
- if replace_col:
- sel = sel.with_only_columns([v for k,v in columns.items() if k != colname])
else:
# anything else requires a subquery
@@ -363,6 +413,12 @@ def _mutate_select(sel, colname, func, labs, __data):
# evaluate call expr on columns, making sure to use group vars
new_col, windows = __data.track_call_windows(func, columns)
+ # replacing an existing column, so strip it from select statement
+ if replace_col:
+ replaced = {**columns}
+ replaced[colname] = new_col.label(colname)
+ return sel.with_only_columns(list(replaced.values()))
+
return sel.column(new_col.label(colname))
@@ -371,20 +427,35 @@ def _arrange(__data, *args):
last_op = __data.last_op
cols = lift_inner_cols(last_op)
+ new_calls = tuple(
+ __data.shape_call(expr, window = False) if callable(expr) else expr
+ for expr in args
+ )
+
+ sort_cols = _create_order_by_clause(cols, *new_calls)
+
+ order_by = __data.order_by + new_calls
+ return __data.append_op(last_op.order_by(*sort_cols), order_by = order_by)
+
+
+# TODO: consolidate / pull expr handling funcs into own file?
+def _create_order_by_clause(columns, *args):
sort_cols = []
for arg in args:
# simple named column
if isinstance(arg, str):
- sort_cols.append(cols[arg])
+ sort_cols.append(columns[arg])
# an expression
elif callable(arg):
- f, asc = _call_strip_ascending(arg)
- col_op = f(cols) if asc else f(cols).desc()
+ #f, asc = _call_strip_ascending(arg)
+ #col_op = f(cols) if asc else f(cols).desc()
+ col_op = arg(columns)
sort_cols.append(col_op)
else:
raise NotImplementedError("Must be string or callable")
- return __data.append_op(last_op.order_by(*sort_cols))
+ return sort_cols
+
@count.register(LazyTbl)
@@ -440,18 +511,24 @@ def _summarize(__data, **kwargs):
# - filter is fine, since it uses a CTE
# - need to detect any window functions...
sel = __data.last_op._clone()
- labs = set(k for k,v in sel.columns.items() if isinstance(v, sql.elements.Label))
+
+ new_calls = {k: __data.shape_call(expr, window = False) for k, expr in kwargs.items()}
+ needs_cte = [col_expr_requires_cte(call, sel) for call in new_calls.values()]
# create select statement ----
- if len(sel._group_by_clause):
- # current select stmt has window functions, so need to make it subquery
+ if any(needs_cte):
+ # need a cte, due to alias cols or existing group by
+ # current select stmt has group by clause, so need to make it subquery
cte = sel.alias()
columns = cte.columns
sel = sql.select(from_obj = cte)
else:
# otherwise, can alter the existing select statement
columns = lift_inner_cols(sel)
+ old_froms = sel.froms
+
sel = sel.with_only_columns([])
+ sel.append_from(*old_froms)
# add group by columns ----
group_cols = [columns[k] for k in __data.group_by]
@@ -462,22 +539,34 @@ def _summarize(__data, **kwargs):
# add each aggregate column ----
# TODO: can't do summarize(b = mean(a), c = b + mean(a))
# since difficult for c to refer to agg and unagg cols in SQL
- for k, expr in kwargs.items():
- new_call = __data.shape_call(expr, window = False)
- col = new_call(columns).label(k)
+ for k, expr in new_calls.items():
+ missing_cols = get_missing_columns(expr, columns)
+ if missing_cols:
+ raise NotImplementedError(
+ "Summarize cannot find the following columns: %s. "
+ "Note that it cannot refer to variables defined earlier in the "
+ "same summarize call." % missing_cols
+ )
+
+ col = expr(columns).label(k)
sel.append_column(col)
- # TODO: is a simple method on __data for doing this...
- new_data = __data.append_op(sel)
- new_data.group_by = None
+ new_data = __data.append_op(sel, group_by = tuple(), order_by = tuple())
return new_data
@group_by.register(LazyTbl)
-def _group_by(__data, *args):
- cols = __data.last_op.columns
- groups = [simple_varname(arg) for arg in args]
+def _group_by(__data, *args, add = False, **kwargs):
+ if kwargs:
+ data = mutate(__data, **kwargs)
+ else:
+ data = __data
+
+ cols = data.last_op.columns
+
+ # put kwarg grouping vars last, so similar order to function call
+ groups = tuple(simple_varname(arg) for arg in args) + tuple(kwargs)
if None in groups:
raise NotImplementedError("Complex expressions not supported in sql group_by")
@@ -485,11 +574,15 @@ def _group_by(__data, *args):
if unmatched:
raise KeyError("group_by specifies columns missing from table: %s" %unmatched)
- return __data.copy(group_by = groups)
+ if add:
+ groups = ordered_union(data.group_by, groups)
+
+ return data.copy(group_by = groups)
+
@ungroup.register(LazyTbl)
def _ungroup(__data):
- return __data.copy(group_by = None)
+ return __data.copy(group_by = tuple())
@case_when.register(sql.base.ImmutableColumnCollection)
@@ -523,14 +616,28 @@ def _case_when(__data, cases):
from collections.abc import Mapping
-def _joined_cols(left_cols, right_cols, shared_keys):
+def _joined_cols(left_cols, right_cols, on_keys, full = False):
+ """Return labeled columns, according to selection rules for joins.
+
+ Rules:
+ 1. For join keys, keep left table's column
+ 2. When keys have the same labels, add suffix
+ """
# TODO: remove sets, so uses stable ordering
# when left and right cols have same name, suffix with _x / _y
- shared_labs = set(left_cols.keys()) \
- .intersection(right_cols.keys()) \
- .difference(shared_keys)
+ keep_right = set(right_cols.keys()) - set(on_keys.values())
+ shared_labs = set(left_cols.keys()).intersection(keep_right)
- right_cols_no_keys = {k: v for k, v in right_cols.items() if k not in shared_keys}
+ right_cols_no_keys = {k: right_cols[k] for k in keep_right}
+
+ # for an outer join, have key columns coalesce values
+ if full:
+ left_cols = {**left_cols}
+ for lk, rk in on_keys.items():
+ col = sql.functions.coalesce(left_cols[lk], right_cols[rk])
+ left_cols[lk] = col.label(lk)
+
+ # create labels ----
l_labs = _relabeled_cols(left_cols, shared_labs, "_x")
r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, "_y")
@@ -548,19 +655,102 @@ def _relabeled_cols(columns, keys, suffix):
@join.register(LazyTbl)
-def _join(left, right, on = None, how = None):
+def _join(left, right, on = None, how = "inner"):
# Needs to be on the table, not the select
left_sel = left.last_op.alias()
right_sel = right.last_op.alias()
+
+ # handle arguments ----
+ on = _validate_join_arg_on(on)
+ how = _validate_join_arg_how(how)
+ if how == "right":
+ # switch joins, since sqlalchemy doesn't have right join arg
+ # see https://stackoverflow.com/q/11400307/1144523
+ left_sel, right_sel = right_sel, left_sel
+
+ # create join conditions ----
+ bool_clause = _create_join_conds(left_sel, right_sel, on)
+
+ # create join ----
+ join = left_sel.join(
+ right_sel,
+ onclause = bool_clause,
+ isouter = how != "inner",
+ full = how == "full"
+ )
+
+ # note, shared_keys assumes on is a mapping...
+ shared_keys = [k for k,v in on.items() if k == v]
+ labeled_cols = _joined_cols(
+ left_sel.columns,
+ right_sel.columns,
+ on_keys = on,
+ full = how == "full"
+ )
+
+ sel = sql.select(labeled_cols, from_obj = join)
+ return left.append_op(sel)
+
+
+@semi_join.register(LazyTbl)
+def _semi_join(left, right = None, on = None):
+
+ left_sel = left.last_op.alias()
+ right_sel = right.last_op.alias()
+
+ # handle arguments ----
+ on = _validate_join_arg_on(on)
+
+ # create join conditions ----
+ bool_clause = _create_join_conds(left_sel, right_sel, on)
+
+ # create inner join ----
+ join = left_sel.join(right_sel, onclause = bool_clause)
+
+ # only keep left hand select's columns ----
+ sel = sql.select(left_sel.columns, from_obj = join)
+ return left.append_op(sel)
+
+
+@anti_join.register(LazyTbl)
+def _anti_join(left, right = None, on = None):
+ left_sel = left.last_op.alias()
+ right_sel = right.last_op.alias()
+
+ # handle arguments ----
+ on = _validate_join_arg_on(on)
+
+ # create join conditions ----
+ bool_clause = _create_join_conds(left_sel, right_sel, on)
+
+ # create inner join ----
+ not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause)
+ sel = sql.select(left_sel.columns, from_obj = left_sel).where(not_exists)
+ return left.append_op(sel)
+
+
+def _validate_join_arg_on(on):
if on is None:
raise NotImplementedError("on arg must currently be dict")
+ elif isinstance(on, str):
+ on = {on: on}
elif isinstance(on, (list, tuple)):
on = dict(zip(on, on))
if not isinstance(on, Mapping):
- raise Exception("on must be a Mapping (e.g. dict)")
+ raise TypeError("on must be a Mapping (e.g. dict)")
+
+ return on
+
+def _validate_join_arg_how(how):
+ how_options = ("inner", "left", "right", "full")
+ if how not in how_options:
+ raise ValueError("how argument needs to be one of %s" %how_options)
+
+ return how
+def _create_join_conds(left_sel, right_sel, on):
left_cols = left_sel.columns #lift_inner_cols(left_sel)
right_cols = right_sel.columns #lift_inner_cols(right_sel)
@@ -569,21 +759,8 @@ def _join(left, right, on = None, how = None):
col_expr = left_cols[l] == right_cols[r]
conds.append(col_expr)
-
- bool_clause = sql.and_(*conds)
- join = left_sel.join(right_sel, onclause = bool_clause)
+ return sql.and_(*conds)
- # note, shared_keys assumes on is a mapping...
- shared_keys = [k for k,v in on.items() if k == v]
- labeled_cols = _joined_cols(
- left_sel.columns,
- right_sel.columns,
- shared_keys = shared_keys
- )
-
- sel = sql.select(labeled_cols, from_obj = join)
- return left.append_op(sel)
-
# Head ------------------------------------------------------------------------
@@ -602,7 +779,12 @@ def _rename(__data, **kwargs):
columns = lift_inner_cols(sel)
# old_keys uses dict as ordered set
- old_to_new = {v:k for k,v in kwargs.items()}
+ old_to_new = {simple_varname(v):k for k,v in kwargs.items()}
+
+ if None in old_to_new:
+ raise KeyError("positional arguments must be simple column, "
+ "e.g. _.colname or _['colname']"
+ )
labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()]
@@ -617,7 +799,7 @@ def _rename(__data, **kwargs):
def _distinct(__data, *args, _keep_all = False, **kwargs):
if (args or kwargs) and _keep_all:
raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False")
-
+
inner_sel = mutate(__data, **kwargs).last_op if kwargs else __data.last_op
# TODO: this is copied from the df distinct version
@@ -626,16 +808,26 @@ def _distinct(__data, *args, _keep_all = False, **kwargs):
cols.update(kwargs)
if None in cols:
- raise Exception("positional arguments must be simple column, "
+ raise KeyError("positional arguments must be simple column, "
"e.g. _.colname or _['colname']"
)
- if not cols: cols = list(inner_sel.columns.keys())
+ # use all columns by default
+ if not cols:
+ cols = list(inner_sel.columns.keys())
- sel_cols = lift_inner_cols(inner_sel)
- distinct_cols = [sel_cols[k] for k in cols]
+ if not len(inner_sel._order_by_clause):
+ # select distinct has to include any columns in the order by clause,
+ # so can only safely modify existing statement when there's no order by
+ sel_cols = lift_inner_cols(inner_sel)
+ distinct_cols = [sel_cols[k] for k in cols]
+ sel = inner_sel.with_only_columns(distinct_cols).distinct()
+ else:
+ # fallback to selecting from an aliased subquery
+ cte = inner_sel.alias()
+ distinct_cols = [cte.columns[k] for k in cols]
+ sel = sql.select(distinct_cols, from_obj = cte).distinct()
- sel = inner_sel.with_only_columns(distinct_cols).distinct()
return __data.append_op(sel)
diff --git a/siuba/tests/conftest.py b/siuba/tests/conftest.py
index 3fc39a01..e2247ee2 100644
--- a/siuba/tests/conftest.py
+++ b/siuba/tests/conftest.py
@@ -1,6 +1,25 @@
import pytest
+from .helpers import assert_equal_query, Backend, SqlBackend, data_frame
def pytest_addoption(parser):
parser.addoption(
"--dbs", action="store", default="sqlite", help="databases tested against (comma separated)"
)
+
+params_backend = [
+ pytest.param(lambda: SqlBackend("postgresql"), id = "postgresql", marks=pytest.mark.postgresql),
+ pytest.param(lambda: SqlBackend("sqlite"), id = "sqlite", marks=pytest.mark.sqlite),
+ pytest.param(lambda: Backend("pandas"), id = "pandas", marks=pytest.mark.pandas)
+ ]
+
+@pytest.fixture(params = params_backend, scope = "session")
+def backend(request):
+ return request.param()
+
+@pytest.fixture(autouse=True)
+def skip_backend(request, backend):
+ if request.node.get_closest_marker('skip_backend'):
+ mark_args = request.node.get_closest_marker('skip_backend').args
+ if backend.name in mark_args:
+ pytest.skip('skipped on backend: {}'.format(backend.name))
+
diff --git a/siuba/tests/helpers.py b/siuba/tests/helpers.py
index 10782f25..8168e7c6 100644
--- a/siuba/tests/helpers.py
+++ b/siuba/tests/helpers.py
@@ -1,46 +1,89 @@
from sqlalchemy import create_engine, types
from siuba.sql import LazyTbl, collect
+from siuba.dply.verbs import ungroup
from pandas.testing import assert_frame_equal
+import pandas as pd
+import os
+import numpy as np
-class DbConRegistry:
- table_name_indx = 0
+def data_frame(**kwargs):
+ fixed = {k: [v] if not np.ndim(v) else v for k,v in kwargs.items()}
+ return pd.DataFrame(fixed)
+
+BACKEND_CONFIG = {
+ "postgresql": {
+ "dialect": "postgresql",
+ "dbname": ["SB_TEST_PGDATABASE", "postgres"],
+ "port": ["SB_TEST_PGPORT", "5433"],
+ "user": ["SB_TEST_PGUSER", "postgres"],
+ "password": ["SB_TEST_PGPASSWORD", ""],
+ "host": ["SB_TEST_PGHOST", "localhost"],
+ },
+ "sqlite": {
+ "dialect": "sqlite",
+ "dbname": ":memory:",
+ "port": "0",
+ "user": "",
+ "password": "",
+ "host": ""
+ }
+ }
+
+class Backend:
+ def __init__(self, name):
+ self.name = name
+
+ def dispose(self):
+ pass
+
+ def load_df(self, df = None, **kwargs):
+ if df is None and kwargs:
+ df = pd.DataFrame(kwargs)
+ elif df is not None and kwargs:
+ raise ValueError("Cannot pass kwargs, and a DataFrame")
+
+ return df
+
+ def __repr__(self):
+ return "{0}({1})".format(self.__class__.__name__, repr(self.name))
- def __init__(self):
- self.connections = {}
- def register(self, name, engine):
- self.connections[name] = engine
+class SqlBackend(Backend):
+ table_name_indx = 0
+ sa_conn_fmt = "{dialect}://{user}:{password}@{host}:{port}/{dbname}"
+
+ def __init__(self, name):
+ cnfg = BACKEND_CONFIG[name]
+ params = {k: os.environ.get(*v) if isinstance(v, (list)) else v for k,v in cnfg.items()}
- def remove(self, name):
- con = self.connections[name]
- con.close()
- del self.connections[name]
+ self.name = name
+ self.engine = create_engine(self.sa_conn_fmt.format(**params))
- return con
+ def dispose(self):
+ self.engine.dispose()
@classmethod
def unique_table_name(cls):
cls.table_name_indx += 1
return "siuba_{0:03d}".format(cls.table_name_indx)
- def load_df(self, df):
- out = []
- for k, engine in self.connections.items():
- lazy_tbl = copy_to_sql(df, self.unique_table_name(), engine)
- out.append(lazy_tbl)
- return out
+ def load_df(self, df = None, **kwargs):
+ df = super().load_df(df, **kwargs)
+ return copy_to_sql(df, self.unique_table_name(), self.engine)
+
def assert_frame_sort_equal(a, b):
"""Tests that DataFrames are equal, even if rows are in different order"""
- sorted_a = a.sort_values(by = a.columns.tolist()).reset_index(drop = True)
- sorted_b = b.sort_values(by = b.columns.tolist()).reset_index(drop = True)
+ df_a = ungroup(a)
+ df_b = ungroup(b)
+ sorted_a = df_a.sort_values(by = df_a.columns.tolist()).reset_index(drop = True)
+ sorted_b = df_b.sort_values(by = df_b.columns.tolist()).reset_index(drop = True)
assert_frame_equal(sorted_a, sorted_b)
-def assert_equal_query(tbls, lazy_query, target):
- for tbl in tbls:
- out = collect(lazy_query(tbl))
- assert_frame_sort_equal(out, target)
+def assert_equal_query(tbl, lazy_query, target):
+ out = collect(lazy_query(tbl))
+ assert_frame_sort_equal(out, target)
PREFIX_TO_TYPE = {
@@ -61,7 +104,40 @@ def auto_types(df):
def copy_to_sql(df, name, engine):
+ if isinstance(engine, str):
+ engine = create_engine(engine)
+
df.to_sql(name, engine, dtype = auto_types(df), index = False, if_exists = "replace")
return LazyTbl(engine, name)
+
+
+from functools import wraps
+import pytest
-
+def backend_notimpl(*names):
+ def outer(f):
+ @wraps(f)
+ def wrapper(backend, *args, **kwargs):
+ if backend.name in names:
+ with pytest.raises(NotImplementedError):
+ f(backend, *args, **kwargs)
+ pytest.xfail("Not implemented!")
+ else:
+ return f(backend, *args, **kwargs)
+ return wrapper
+ return outer
+
+def backend_sql(msg):
+ # allow decorating without an extra call
+ if callable(msg):
+ return backend_sql(None)(msg)
+
+ def outer(f):
+ @wraps(f)
+ def wrapper(backend, *args, **kwargs):
+ if not isinstance(backend, SqlBackend):
+ pytest.skip(msg)
+ else:
+ return f(backend, *args, **kwargs)
+ return wrapper
+ return outer
diff --git a/siuba/tests/test_sql_verbs_distinct.py b/siuba/tests/test_sql_verbs_distinct.py
deleted file mode 100644
index f02d3586..00000000
--- a/siuba/tests/test_sql_verbs_distinct.py
+++ /dev/null
@@ -1,79 +0,0 @@
-"""
-Note: this test file was heavily influenced by its dbplyr counterpart.
-
-https://github.com/tidyverse/dbplyr/blob/master/tests/testthat/test-verb-distinct.R
-"""
-
-from siuba.sql import LazyTbl, collect
-from siuba import _, distinct
-import pandas as pd
-import os
-
-import pytest
-from sqlalchemy import create_engine
-
-from .helpers import assert_equal_query, DbConRegistry
-
-DATA = pd.DataFrame({
- "x": [1,1,1,1],
- "y": [1,1,2,2],
- "z": [1,2,1,2]
- })
-
-@pytest.fixture(scope = "module")
-def dbs(request):
- dialects = set(request.config.getoption("--dbs").split(","))
- dbs = DbConRegistry()
-
- if "sqlite" in dialects:
- dbs.register("sqlite", create_engine("sqlite:///:memory:"))
- if "postgresql" in dialects:
- port = os.environ.get("PGPORT", "5433")
- dbs.register("postgresql", create_engine('postgresql://postgres:@localhost:%s/postgres'%port))
-
-
- yield dbs
-
- # cleanup
- for engine in dbs.connections.values():
- engine.dispose()
-
-@pytest.fixture(scope = "module")
-def dfs(dbs):
- yield dbs.load_df(DATA)
-
-def test_distinct_no_args(dfs):
- assert_equal_query(dfs, distinct(), DATA.drop_duplicates())
- assert_equal_query(dfs, distinct(), distinct(DATA))
-
-def test_distinct_one_arg(dfs):
- assert_equal_query(
- dfs,
- distinct(_.y),
- DATA.drop_duplicates(['y'])[['y']].reset_index(drop = True)
- )
-
- assert_equal_query(dfs, distinct(_.y), distinct(DATA, _.y))
-
-def test_distinct_keep_all_not_impl(dfs):
- # TODO: should just mock LazyTbl
- for tbl in dfs:
- with pytest.raises(NotImplementedError):
- distinct(tbl, _.y, _keep_all = True) >> collect()
-
-
-@pytest.mark.xfail
-def test_distinct_via_group_by(dfs):
- # NotImplemented
- assert False
-
-def test_distinct_kwargs(dfs):
- dst = DATA.drop_duplicates(['y', 'x']) \
- .rename(columns = {'x': 'a'}) \
- .reset_index(drop = True)[['y', 'a']]
-
- assert_equal_query(dfs, distinct(_.y, a = _.x), dst)
-
-
-
-
diff --git a/siuba/tests/test_verb_arrange.py b/siuba/tests/test_verb_arrange.py
new file mode 100644
index 00000000..4119f02c
--- /dev/null
+++ b/siuba/tests/test_verb_arrange.py
@@ -0,0 +1,30 @@
+from siuba.dply.verbs import simple_varname
+from siuba import _, filter, group_by, arrange, mutate
+from siuba.dply.vector import row_number, desc
+import pandas as pd
+
+import pytest
+
+from .helpers import assert_equal_query, data_frame, backend_notimpl, backend_sql
+
+
+@backend_sql
+@backend_notimpl("sqlite")
+def test_no_arrange_before_cuml_window_warning(backend):
+ data = data_frame(x = range(1, 5), g = [1,1,2,2])
+ dfs = backend.load_df(data)
+ with pytest.warns(RuntimeWarning):
+ dfs >> mutate(y = _.x.cumsum())
+
+@backend_sql
+def test_arranges_back_to_back(backend):
+ data = data_frame(x = range(1, 5), g = [1,1,2,2])
+ dfs = backend.load_df(data)
+
+ lazy_tbl = dfs >> arrange(_.x) >> arrange(_.g)
+ order_by_vars = tuple(simple_varname(call) for call in lazy_tbl.order_by)
+
+ assert order_by_vars == ("x", "g")
+ assert [c.name for c in lazy_tbl.last_op._order_by_clause] == ["x", "g"]
+
+
diff --git a/siuba/tests/test_verb_distinct.py b/siuba/tests/test_verb_distinct.py
new file mode 100644
index 00000000..b3bca08a
--- /dev/null
+++ b/siuba/tests/test_verb_distinct.py
@@ -0,0 +1,77 @@
+"""
+Note: this test file was heavily influenced by its dbplyr counterpart.
+
+https://github.com/tidyverse/dbplyr/blob/master/tests/testthat/test-verb-distinct.R
+"""
+
+from siuba.sql import LazyTbl, collect
+from siuba import _, distinct, group_by, summarize, arrange, mutate
+from .helpers import assert_equal_query, backend_sql
+import pandas as pd
+import os
+
+import pytest
+from sqlalchemy import create_engine
+
+DATA = pd.DataFrame({
+ "x": [1,2,3,4,5],
+ "y": [5,4,3,2,1]
+ })
+
+@pytest.fixture(scope = "module")
+def df(backend):
+ yield backend.load_df(DATA)
+
+def test_distinct_no_args(df):
+ assert_equal_query(df, distinct(), DATA.drop_duplicates())
+ assert_equal_query(df, distinct(), distinct(DATA))
+
+def test_distinct_one_arg(df):
+ assert_equal_query(
+ df,
+ distinct(_.y),
+ DATA.drop_duplicates(['y'])[['y']].reset_index(drop = True)
+ )
+
+ assert_equal_query(df, distinct(_.y), distinct(DATA, _.y))
+
+@backend_sql
+def test_distinct_keep_all_not_impl(backend, df):
+ # TODO: should just mock LazyTbl
+ with pytest.raises(NotImplementedError):
+ distinct(df, _.y, _keep_all = True) >> collect()
+
+
+@pytest.mark.xfail
+def test_distinct_via_group_by(df):
+ # NotImplemented
+ assert False
+
+
+def test_distinct_after_summarize(df):
+ query = group_by(g = _.x) >> summarize(z = (_.y - _.y).min()) >> distinct(_.z)
+
+ assert_equal_query(df, query, pd.DataFrame({'z': [0]}))
+
+def test_distinct_after_arrange(df):
+ query = arrange(_.x) >> distinct(_.y)
+
+ assert_equal_query(df, query, pd.DataFrame({'y': [5,4,3,2,1]}))
+
+
+def test_distinct_of_mutate_col(df):
+ query = mutate(z = _.x + 1) >> distinct(_.z)
+
+ assert_equal_query(df, query, pd.DataFrame({'z': [2,3,4,5,6]}))
+
+
+def test_distinct_kwargs(df):
+ dst = DATA.drop_duplicates(['y', 'x']) \
+ .rename(columns = {'x': 'a'}) \
+ .reset_index(drop = True)[['y', 'a']]
+
+ assert_equal_query(df, distinct(_.y, a = _.x), dst)
+
+
+
+
diff --git a/siuba/tests/test_verb_filter.py b/siuba/tests/test_verb_filter.py
new file mode 100644
index 00000000..50a9669b
--- /dev/null
+++ b/siuba/tests/test_verb_filter.py
@@ -0,0 +1,78 @@
+"""
+Note: this test file was heavily influenced by its dbplyr counterpart.
+
+https://github.com/tidyverse/dbplyr/blob/master/tests/testthat/test-verb-filter.R
+"""
+
+from siuba import _, filter, group_by, arrange
+from siuba.dply.vector import row_number, desc
+import pandas as pd
+
+import pytest
+
+from .helpers import assert_equal_query, data_frame, backend_notimpl, backend_sql
+
+DATA = pd.DataFrame({
+ "x": [1,1,1,1],
+ "y": [1,1,2,2],
+ "z": [1,2,1,2]
+ })
+
+
+def test_filter_basic(backend):
+ df = data_frame(x = [1,2,3,4,5], y = [5,4,3,2,1])
+ dfs = backend.load_df(df)
+
+ assert_equal_query(dfs, filter(_.x > 3), df[lambda _: _.x > 3])
+
+
+@backend_sql("TODO: pandas - grouped col should be first after mutate")
+@backend_notimpl("sqlite")
+def test_filter_via_group_by(backend):
+ df = data_frame(
+ x = range(1, 11),
+ g = [1]*5 + [2]*5
+ )
+
+ dfs = backend.load_df(df)
+
+ assert_equal_query(
+ dfs,
+ group_by(_.g) >> filter(row_number(_) < 3),
+ data_frame(g = [1,1,2,2], x = [1,2,6,7])
+ )
+
+
+@backend_sql("TODO: pandas - grouped col should be first after mutate")
+@backend_notimpl("sqlite")
+def test_filter_via_group_by_agg(backend):
+ dfs = backend.load_df(x = range(1,11), g = [1]*5 + [2]*5)
+
+ assert_equal_query(
+ dfs,
+ group_by(_.g) >> filter(_.x > _.x.mean()),
+ data_frame(g = [1, 1, 2, 2], x = [4, 5, 9, 10])
+ )
+
+@backend_sql("TODO: pandas - implement arrange over group by")
+@backend_notimpl("sqlite")
+def test_filter_via_group_by_arrange(backend):
+ dfs = backend.load_df(x = [3,2,1] + [2,3,4], g = [1]*3 + [2]*3)
+
+ assert_equal_query(
+ dfs,
+ group_by(_.g) >> arrange(_.x) >> filter(_.x.cumsum() > 3),
+ data_frame(g = [1, 2, 2], x = [3, 3, 4])
+ )
+
+@backend_sql("TODO: pandas - implement arrange over group by")
+@backend_notimpl("sqlite")
+def test_filter_via_group_by_desc_arrange(backend):
+ dfs = backend.load_df(x = [3,2,1] + [2,3,4], g = [1]*3 + [2]*3)
+
+ assert_equal_query(
+ dfs,
+ group_by(_.g) >> arrange(desc(_.x)) >> filter(_.x.cumsum() > 3),
+ data_frame(g = [1, 1, 2, 2, 2], x = [2, 1, 4, 3, 2])
+ )
+
diff --git a/siuba/tests/test_verb_group_by.py b/siuba/tests/test_verb_group_by.py
new file mode 100644
index 00000000..0df83bf4
--- /dev/null
+++ b/siuba/tests/test_verb_group_by.py
@@ -0,0 +1,54 @@
+"""
+Note: this test file was heavily influenced by its dbplyr counterpart.
+
+https://github.com/tidyverse/dbplyr/blob/master/tests/testthat/test-verb-group_by.R
+"""
+
+from siuba import _, group_by, ungroup, summarize
+from siuba.dply.vector import row_number, n
+
+import pytest
+from .helpers import assert_equal_query, data_frame, backend_notimpl, SqlBackend
+from string import ascii_lowercase
+
+DATA = data_frame(x = [1,2,3], y = [9,8,7], g = ['a', 'a', 'b'])
+
+@pytest.fixture(scope = "module")
+def df(backend):
+ if not isinstance(backend, SqlBackend):
+ pytest.skip("TODO: generalize tests to pandas")
+ return backend.load_df(DATA)
+
+
+def test_group_by_no_add(df):
+ gdf = group_by(df, _.x, _.y)
+ assert gdf.group_by == ("x", "y")
+
+def test_group_by_override(df):
+ gdf = df >> group_by(_.x, _.y) >> group_by(_.g)
+ assert gdf.group_by == ("g",)
+
+def test_group_by_add(df):
+ gdf = group_by(df, _.x) >> group_by(_.y, add = True)
+
+ assert gdf.group_by == ("x", "y")
+
+def test_group_by_ungroup(df):
+ q1 = df >> group_by(_.g)
+ assert q1.group_by == ("g",)
+
+ q2 = q1 >> ungroup()
+ assert q2.group_by == tuple()
+
+
+@pytest.mark.skip("TODO: need to test / validate joins first")
+def test_group_by_before_joins(df):
+ assert False
+
+def test_group_by_performs_mutate(df):
+ assert_equal_query(
+ df,
+ group_by(z = _.x + _.y) >> summarize(n = n(_)),
+ data_frame(z = 10, n = 3)
+ )
+
diff --git a/siuba/tests/test_verb_join.py b/siuba/tests/test_verb_join.py
new file mode 100644
index 00000000..9ffc3a51
--- /dev/null
+++ b/siuba/tests/test_verb_join.py
@@ -0,0 +1,137 @@
+"""
+Note: this test file was heavily influenced by its dbplyr counterpart.
+
+https://github.com/tidyverse/dbplyr/blob/master/tests/testthat/test-verb-joins.R
+"""
+
+from siuba import (
+ _, group_by,
+ join, inner_join, left_join, right_join, full_join,
+ semi_join, anti_join
+ )
+from siuba.dply.vector import row_number, n
+from siuba.sql.verbs import collect
+
+import pytest
+from .helpers import assert_equal_query, assert_frame_sort_equal, data_frame, backend_notimpl, backend_sql
+
+
+DF1 = data_frame(
+ ii = [1,2,3,4],
+ x = ["a", "b", "c", "d"]
+ )
+
+DF2 = data_frame(
+ ii = [1,2,26],
+ y = ["a", "b", "z"]
+ )
+
+DF3 = data_frame(
+ ii = [26],
+ z = ["z"]
+ )
+
+@pytest.fixture(scope = "module")
+def df1(backend):
+ return backend.load_df(DF1)
+
+@pytest.fixture(scope = "module")
+def df2(backend):
+ return backend.load_df(DF2)
+
+@pytest.fixture(scope = "module")
+def df2_jj(backend):
+ return backend.load_df(DF2.rename(columns = {"ii": "jj"}))
+
+@pytest.fixture(scope = "module")
+def df3(backend):
+ return backend.load_df(DF3)
+
+
+
+@backend_sql("TODO: pandas")
+def test_join_diff_vars_keeps_left(backend, df1, df2_jj):
+ out = inner_join(df1, df2_jj, {"ii": "jj"}) >> collect()
+
+ assert out.columns.tolist() == ["ii", "x", "y"]
+
+def test_join_on_str_arg(df1, df2):
+ out = inner_join(df1, df2, "ii") >> collect()
+
+ target = DF1.iloc[:2,].assign(y = ["a", "b"])
+ assert_frame_sort_equal(out, target)
+
+def test_join_on_list_arg(backend):
+ # TODO: how to validate how cols are being matched up?
+ data = DF1.assign(jj = lambda d: d.ii)
+ df_a = backend.load_df(data)
+ df_b = backend.load_df(DF2.assign(jj = lambda d: d.ii))
+ out = inner_join(df_a, df_b, ["ii", "jj"]) >> collect()
+
+ assert_frame_sort_equal(out, data.iloc[:2, :].assign(y = ["a", "b"]))
+
+@pytest.mark.skip("TODO: note, unsure of this syntax")
+def test_join_on_same_col_multiple_times():
+ data = data_frame(ii = [1,2,3], jj = [1,2, 9])
+ df_a = backend.load_df(data)
+ df_b = backend.load_df(data_frame(ii = [1,2,3]))
+
+ out = inner_join(df_a, df_b, {("ii", "jj"): "ii"}) >> collect()
+ # keeps all but last row
+ assert_frame_sort_equal(out, data.iloc[:2,])
+
+def test_join_on_missing_col(df1, df2):
+ with pytest.raises(KeyError):
+ inner_join(df1, df2, {"ABCDEF": "ii"})
+
+ with pytest.raises(KeyError):
+ inner_join(df1, df2, {"ii": "ABCDEF"})
+
+def test_join_suffixes_dupe_names(df1):
+ out = inner_join(df1, df1, {"ii": "ii"}) >> collect()
+ non_index_cols = DF1.columns[DF1.columns != "ii"]
+ assert all((non_index_cols + "_x").isin(out))
+ assert all((non_index_cols + "_y").isin(out))
+
+
+
+# Test basic join types -------------------------------------------------------
+
+def test_basic_left_join(df1, df2):
+ out = left_join(df1, df2, {"ii": "ii"}) >> collect()
+ target = DF1.assign(y = ["a", "b", None, None])
+ assert_frame_sort_equal(out, target)
+
+@backend_sql("TODO: pandas returns columns in rev name order")
+def test_basic_right_join(backend, df1, df2):
+ # same as left join, but flip df arguments
+ out = right_join(df2, df1, {"ii": "ii"}) >> collect()
+ target = DF1.assign(y = ["a", "b", None, None])
+ assert_frame_sort_equal(out, target)
+
+def test_basic_inner_join(df1, df2):
+ out = inner_join(df1, df2, {"ii": "ii"}) >> collect()
+ target = DF1.iloc[:2,:].assign(y = ["a", "b"])
+ assert_frame_sort_equal(out, target)
+
+@backend_sql("TODO: pandas - full should be converted to 'outer'")
+@pytest.mark.skip_backend("sqlite")
+def test_basic_full_join(backend, df1, df2):
+ out = full_join(df1, df2, {"ii": "ii"}) >> collect()
+ target = DF1.merge(DF2, on = "ii", how = "outer")
+ assert_frame_sort_equal(out, target)
+
+@backend_sql("TODO: pandas - key error?")
+def test_basic_semi_join(backend, df1, df2):
+ assert_frame_sort_equal(
+ semi_join(df1, df2, {"ii": "ii"}) >> collect(),
+ DF1.iloc[:2,]
+ )
+
+@backend_sql("TODO: pandas - implement anti join")
+def test_basic_anti_join(backend, df1, df2):
+ assert_frame_sort_equal(
+ anti_join(df1, df2, on = {"ii": "ii"}) >> collect(),
+ DF1.iloc[2:,]
+ )
+
diff --git a/siuba/tests/test_verb_mutate.py b/siuba/tests/test_verb_mutate.py
new file mode 100644
index 00000000..da92bbb7
--- /dev/null
+++ b/siuba/tests/test_verb_mutate.py
@@ -0,0 +1,112 @@
+"""
+Note: this test file was heavily influenced by its dbplyr counterpart.
+
+https://github.com/tidyverse/dbplyr/blob/master/tests/testthat/test-verb-mutate.R
+"""
+
+from siuba import _, mutate, select, group_by, summarize, filter
+from siuba.dply.vector import row_number
+
+import pytest
+from .helpers import assert_equal_query, data_frame, backend_notimpl, backend_sql
+from string import ascii_lowercase
+
+DATA = data_frame(a = [1,2,3], b = [9,8,7])
+
+@pytest.fixture(scope = "module")
+def dfs(backend):
+ return backend.load_df(DATA)
+
+@pytest.mark.parametrize("query, output", [
+ (mutate(x = _.a + _.b), DATA.assign(x = [10, 10, 10])),
+ pytest.param( mutate(x = _.a + _.b) >> summarize(ttl = _.x.sum()), data_frame(ttl = 30.0), marks = pytest.mark.skip("TODO: failing sqlite?")),
+ (mutate(x = _.a + 1, y = _.b - 1), DATA.assign(x = [2,3,4], y = [8,7,6])),
+ (mutate(x = _.a + 1) >> mutate(y = _.b - 1), DATA.assign(x = [2,3,4], y = [8,7,6])),
+ (mutate(x = _.a + 1, y = _.x + 1), DATA.assign(x = [2,3,4], y = [3,4,5]))
+ ])
+def test_mutate_basic(dfs, query, output):
+ assert_equal_query(dfs, query, output)
+
+@pytest.mark.parametrize("query, output", [
+ (mutate(x = 1), DATA.assign(x = 1)),
+ (mutate(x = "a"), DATA.assign(x = "a")),
+ (mutate(x = 1.2), DATA.assign(x = 1.2))
+ ])
+def test_mutate_literal(dfs, query, output):
+ assert_equal_query(dfs, query, output)
+
+
+def test_select_mutate_filter(dfs):
+ assert_equal_query(
+ dfs,
+ select(_.x == _.a) >> mutate(y = _.x * 2) >> filter(_.y == 2),
+ data_frame(x = 1, y = 2)
+ )
+
+@pytest.mark.skip("TODO: check most recent vars for efficient mutate (#41)")
+def test_mutate_smart_nesting(dfs):
+ # y and z both use x, so should create only 1 extra query
+ lazy_tbl = dfs >> mutate(x = _.a + 1, y = _.x + 1, z = _.x + 1)
+
+ query = lazy_tbl.last_op.fromclause
+
+ assert query is lazy_tbl.ops[0]
+ assert isinstance(query.fromclause, sqlalchemy.Table )
+
+
+@pytest.mark.skip("TODO: does pandas backend preserve order? (#42)")
+def test_mutate_reassign_column_ordering(dfs):
+ assert_equal_query(
+ dfs,
+ mutate(c = 3, a = 1, b = 2),
+ data_frame(a = 1, b = 2, c = 3)
+ )
+
+
+@backend_sql
+@backend_notimpl("sqlite")
+def test_mutate_window_funcs(backend):
+ data = data_frame(x = range(1, 5), g = [1,1,2,2])
+ dfs = backend.load_df(data)
+ assert_equal_query(
+ dfs,
+ group_by(_.g) >> mutate(row_num = row_number(_).astype(float)),
+ data.assign(row_num = [1.0, 2, 1, 2])
+ )
+
+
+@backend_notimpl("sqlite")
+def test_mutate_using_agg_expr(backend):
+ data = data_frame(x = range(1, 5), g = [1,1,2,2])
+ dfs = backend.load_df(data)
+ assert_equal_query(
+ dfs,
+ group_by(_.g) >> mutate(y = _.x - _.x.mean()),
+ data.assign(y = [-.5, .5, -.5, .5])
+ )
+
+@backend_sql # TODO: pandas outputs an int column
+@backend_notimpl("sqlite")
+def test_mutate_using_cuml_agg(backend):
+ data = data_frame(x = range(1, 5), g = [1,1,2,2])
+ dfs = backend.load_df(data)
+
+ # cuml window without arrange before generates warning
+ with pytest.warns(None):
+ assert_equal_query(
+ dfs,
+ group_by(_.g) >> mutate(y = _.x.cumsum()),
+ data.assign(y = [1.0, 3, 3, 7])
+ )
+
+def test_mutate_overwrites_prev(backend):
+ # TODO: check that query doesn't generate a CTE
+ dfs = backend.load_df(data_frame(x = range(1, 5), g = [1,1,2,2]))
+ assert_equal_query(
+ dfs,
+ mutate(x = _.x + 1) >> mutate(x = _.x + 1),
+ data_frame(x = [3,4,5,6], g = [1,1,2,2])
+ )
+
+
+
diff --git a/siuba/tests/test_verb_select.py b/siuba/tests/test_verb_select.py
new file mode 100644
index 00000000..91b63889
--- /dev/null
+++ b/siuba/tests/test_verb_select.py
@@ -0,0 +1,56 @@
+"""
+Note: this test file was heavily influenced by its dbplyr counterpart.
+
+https://github.com/tidyverse/dbplyr/blob/master/tests/testthat/test-verb-select.R
+"""
+
+from siuba import _, mutate, select, group_by, rename
+
+import pytest
+from .helpers import assert_equal_query, data_frame, backend_notimpl, backend_sql
+from string import ascii_lowercase
+
+DATA = data_frame(a = 1, b = 2, c = 3)
+
+@pytest.fixture(scope = "module")
+def dfs(backend):
+ return backend.load_df(DATA)
+
+@pytest.mark.parametrize("query, output", [
+ ( select(_.c), data_frame(c = 3) ),
+ ( select(_.b == _.c), data_frame(b = 3) ),
+ ( select(_["a":"c"]), data_frame(a = 1, b = 2, c = 3) ),
+ ( select(_[_.a:_.c]), data_frame(a = 1, b = 2, c = 3) ),
+ ( select(_.a, _.b) >> select(_.b), data_frame(b = 2) ),
+ ( mutate(a = _.b + _.c) >> select(_.a), data_frame(a = 5) ),
+ pytest.param( group_by(_.a) >> select(_.b), data_frame(b = 2, a = 1), marks = pytest.mark.xfail),
+ ])
+def test_select_siu(dfs, query, output):
+ assert_equal_query(dfs, query, output)
+
+
+@pytest.mark.skip("TODO: #63")
+def test_select_kwargs(dfs):
+ assert_equal_query(dfs, select(x = _.a), data_frame(x = 1))
+
+
+# Rename ----------------------------------------------------------------------
+
+@pytest.mark.parametrize("query, output", [
+ ( rename(A = _.a), data_frame(A = 1, b = 2, c = 3) ),
+ ( rename(A = "a"), data_frame(A = 1, b = 2, c = 3) ),
+ ( rename(A = _.a, B = _.c), data_frame(A = 1, b = 2, B = 3) ),
+ ( rename(A = "a", B = "c"), data_frame(A = 1, b = 2, B = 3) )
+ ])
+def test_rename_siu(dfs, query, output):
+ assert_equal_query(dfs, query, output)
+
+
+@backend_sql("TODO: pandas - grouped df rename")
+@pytest.mark.parametrize("query, output", [
+ ( group_by(_.a) >> rename(z = _.a), data_frame(z = 1, b = 2, c = 3) ),
+ ( group_by(_.a) >> rename(z = "a"), data_frame(z = 1, b = 2, c = 3) )
+ ])
+def test_grouped_rename_siu(backend, dfs, query, output):
+ assert_equal_query(dfs, query, output)
+
diff --git a/siuba/tests/test_verb_summarize.py b/siuba/tests/test_verb_summarize.py
new file mode 100644
index 00000000..91062443
--- /dev/null
+++ b/siuba/tests/test_verb_summarize.py
@@ -0,0 +1,100 @@
+"""
+Note: this test file was heavily influenced by its dbplyr counterpart.
+
+https://github.com/tidyverse/dbplyr/blob/master/tests/testthat/test-verb-summarise.R
+"""
+
+from siuba import _, mutate, select, group_by, summarize, filter
+from siuba.dply.vector import row_number, n
+
+import pytest
+from .helpers import assert_equal_query, data_frame, backend_notimpl, backend_sql
+from string import ascii_lowercase
+
+DATA = data_frame(x = [1,2,3,4], g = ['a', 'a', 'b', 'b'])
+
+@pytest.fixture(scope = "module")
+def df(backend):
+ return backend.load_df(DATA)
+
+@pytest.fixture(scope = "module")
+def df_float(backend):
+ return backend.load_df(DATA.assign(x = lambda d: d.x.astype(float)))
+
+@pytest.fixture(scope = "module")
+def gdf(df):
+ return df >> group_by(_.g)
+
+
+@pytest.mark.parametrize("query, output", [
+ (summarize(y = n(_)), data_frame(y = 4)),
+ (summarize(y = _.x.min()), data_frame(y = 1)),
+ ])
+def test_summarize_ungrouped(df, query, output):
+ assert_equal_query(df, query, output)
+
+
+@pytest.mark.skip("TODO: should return 1 row (#63)")
+def test_ungrouped_summarize_literal(df, query, output):
+ assert_equal_query(df, summarize(y = 1), data_frame(y = 1))
+
+
+@backend_notimpl("sqlite")
+def test_summarize_after_mutate_cuml_win(backend, df_float):
+ assert_equal_query(
+ df_float,
+ mutate(y = _.x.cumsum()) >> summarize(z = _.y.max()),
+ data_frame(z = [10.])
+ )
+
+
+@backend_sql
+def test_summarize_keeps_group_vars(backend, gdf):
+ q = gdf >> summarize(n = n(_))
+ assert list(q.last_op.c.keys()) == ["g", "n"]
+
+
+@pytest.mark.parametrize("query, output", [
+ (summarize(y = 1), data_frame(g = ['a', 'b'], y = [1, 1])),
+ (summarize(y = n(_)), data_frame(g = ['a', 'b'], y = [2,2])),
+ (summarize(y = _.x.min()), data_frame(g = ['a', 'b'], y = [1, 3])),
+ # TODO: same issue as above
+ #(mutate(y = _.x.cumsum()) >> summarize(z = _.y.max()), data_frame(y = [3, 7]))
+ ])
+def test_summarize_grouped(gdf, query, output):
+ assert_equal_query(gdf, query, output)
+
+
+@pytest.mark.skip("TODO: (#48)")
+def test_summarize_removes_1_grouping(backend):
+ data = data_frame(a = 1, b = 2, c = 3)
+ df = backend.load_df(data)
+
+ q1 = df >> group_by(_.a, _.b) >> summarize(n = n(_))
+ assert q1.group_by == ("a")
+
+ q2 = q1 >> summarize(n = n(_))
+ assert not len(q2.group_by)
+
+
+@backend_sql("TODO: pandas - need to implement or raise this warning")
+def test_summarize_no_same_call_var_refs(backend, df):
+ with pytest.raises(NotImplementedError):
+ df >> summarize(y = _.x.min(), z = _.y + 1)
+
+
+@backend_sql
+def test_summarize_removes_order_vars(backend, df):
+ lazy_tbl = df >> summarize(n = n(_))
+
+ assert not len(lazy_tbl.order_by)
+
+
+@pytest.mark.skip("TODO (see #50)")
+def test_summarize_unnamed_args(df):
+ assert_equal_query(
+ df,
+ summarize(n(_)),
+ pd.DataFrame({'n(_)': 4})
+ )
+
diff --git a/siuba/tests/test_verb_utils.py b/siuba/tests/test_verb_utils.py
new file mode 100644
index 00000000..3ec00c4b
--- /dev/null
+++ b/siuba/tests/test_verb_utils.py
@@ -0,0 +1,20 @@
+from siuba.sql.verbs import collect, show_query, LazyTbl
+from siuba.dply.verbs import Pipeable
+from .helpers import data_frame
+import pandas as pd
+
+import pytest
+
+@pytest.fixture(scope = "module")
+def df(backend):
+ return backend.load_df(data_frame(x = [1,2,3]))
+
+def test_show_query(df):
+ assert isinstance(show_query(df), df.__class__)
+ assert isinstance(df >> show_query(), df.__class__)
+ assert isinstance(show_query(), Pipeable)
+
+def test_collect(df):
+ assert isinstance(collect(df), pd.DataFrame)
+ assert isinstance(df >> collect(), pd.DataFrame)
+ assert isinstance(collect(), Pipeable)