Merge pull request #36 from machow/feat-sql-tests
[WIP] Feat sql tests
machow authored Jun 1, 2019
2 parents 4921fd8 + 7712f38 commit 8949097
Showing 27 changed files with 2,437 additions and 337 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -14,6 +14,6 @@ services:
   - postgresql
 env:
   global:
-    - PGPORT=5432
+    - SB_TEST_PGPORT=5432
 script:
   - make test-travis
1 change: 0 additions & 1 deletion docs/.gitattributes
@@ -1,3 +1,2 @@
*.ipynb filter=nbstripout

*.ipynb diff=ipynb
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -3,7 +3,7 @@
    :hidden:

    intro.Rmd
-   intro_sql.Rmd
+   intro_sql.ipynb

 .. toctree::
    :caption: Core One-table Verbs
119 changes: 106 additions & 13 deletions docs/intro_sql.Rmd
@@ -12,38 +12,131 @@ jupyter:
name: python3
---

```{python nbsphinx=hidden}
import matplotlib.cbook
import warnings
import plotnine
warnings.filterwarnings(module='plotnine*', action='ignore')
warnings.filterwarnings(module='matplotlib*', action='ignore')
# %matplotlib inline
```

# Using siuba to query SQL


# Setting up

```{python}
from sqlalchemy import create_engine
from siuba.data import mtcars
import pandas as pd
from siuba.tests.helpers import copy_to_sql
from siuba import *
from siuba.dply.vector import lag, desc, row_number
from siuba.dply.string import str_c
engine = create_engine('sqlite:///:memory:', echo=False)
tv_ratings = pd.read_csv(
    "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2019/2019-01-08/IMDb_Economist_tv_ratings.csv",
    parse_dates = ["date"]
)
# note that mtcars is a pandas DataFrame
mtcars.to_sql('mtcars', engine)
```

```{python}
from siuba import *
from siuba.sql import LazyTbl, show_query, collect
tbl_ratings = copy_to_sql(tv_ratings, "tv_ratings", "postgresql://postgres:@localhost:5433/postgres")
tbl_mtcars = LazyTbl(engine, 'mtcars')
```

```{python}
tbl_mtcars
tbl_ratings
```

## Inspecting a single show

```{python}
tbl_mtcars >> filter(_.hp > 250) >> collect()
buffy = (tbl_ratings
    >> filter(_.title == "Buffy the Vampire Slayer")
    >> collect()
)
buffy
```
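To make the translation concrete, here is a standard-library-only sketch of what a `filter` + `collect` pipeline amounts to on the SQL side: a `WHERE` clause plus a fetch. The table rows here are hypothetical stand-ins for mtcars, and this is not siuba's actual generated query.

```python
import sqlite3

# Hypothetical stand-in rows for the mtcars table, loaded into in-memory SQLite
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE mtcars (model TEXT, hp INTEGER)")
conn.executemany(
    "INSERT INTO mtcars VALUES (?, ?)",
    [("Maserati Bora", 335), ("Ford Pantera L", 264), ("Datsun 710", 93)],
)

# filter(_.hp > 250) >> collect() corresponds to a WHERE clause plus a fetch
rows = conn.execute(
    "SELECT model, hp FROM mtcars WHERE hp > 250 ORDER BY hp DESC"
).fetchall()
print(rows)
```

The key difference is that a `LazyTbl` builds this query lazily, only sending SQL to the database when results are needed.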

```{python}
buffy >> summarize(avg_rating = _.av_rating.mean())
```

## Average rating per show, along with dates

```{python}
(tbl_mtcars
    >> group_by(_.cyl)
    >> summarize(avg_mpg = _.mpg.mean())
)

avg_ratings = (tbl_ratings
    >> group_by(_.title)
    >> summarize(
        avg_rating = _.av_rating.mean(),
        date_range = str_c(_.date.dt.year.max(), " - ", _.date.dt.year.min())
    )
)

avg_ratings
```
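For readers more familiar with pandas, the grouped summarize above corresponds to a groupby-aggregate, with `str_c` becoming plain string concatenation of the aggregated years. A minimal sketch with a toy frame (hypothetical titles and ratings, not the IMDb data):

```python
import pandas as pd

# Toy ratings frame standing in for tv_ratings (values are made up)
df = pd.DataFrame({
    "title": ["A", "A", "B"],
    "av_rating": [8.0, 9.0, 7.0],
    "date": pd.to_datetime(["2001-01-01", "2003-01-01", "2010-01-01"]),
})

# group_by + summarize translate to a pandas groupby-aggregate
summ = df.groupby("title").agg(
    avg_rating=("av_rating", "mean"),
    year_max=("date", lambda s: s.dt.year.max()),
    year_min=("date", lambda s: s.dt.year.min()),
).reset_index()

# str_c(...) becomes plain string concatenation of the aggregated years
summ["date_range"] = summ.year_max.astype(str) + " - " + summ.year_min.astype(str)
print(summ[["title", "avg_rating", "date_range"]])
```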

## Biggest changes in ratings between two seasons

```{python}
top_4_shifts = (tbl_ratings
    >> group_by(_.title)
    >> mutate(rating_shift = _.av_rating - lag(_.av_rating))
    >> summarize(max_shift = _.rating_shift.max())
    >> arrange(-_.max_shift)
    >> head(4)
)

top_4_shifts
```
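The `lag(_.av_rating)` call within a grouped pipeline is the interesting part: on a SQL backend it is expressed as a window function (`LAG() OVER (PARTITION BY title)`), and in pandas it corresponds to a grouped `shift(1)`. A sketch with made-up per-season ratings:

```python
import pandas as pd

# Toy per-season ratings (hypothetical values)
df = pd.DataFrame({
    "title":     ["A", "A", "A", "B", "B"],
    "av_rating": [7.0, 9.0, 8.0, 6.0, 6.5],
})

# lag within a group = grouped shift(1); the first season of each show is NaN
df["rating_shift"] = df.av_rating - df.groupby("title").av_rating.shift(1)

# Largest season-to-season jump per show, sorted descending
max_shift = df.groupby("title").rating_shift.max().sort_values(ascending=False)
print(max_shift)
```

Because the shift is computed per group, a show's first season never borrows a rating from the previous show.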

```{python}
big_shift_series = (top_4_shifts
    >> select(_.title)
    >> inner_join(_, tbl_ratings, "title")
    >> collect()
)

from plotnine import *

(big_shift_series
    >> ggplot(aes("seasonNumber", "av_rating"))
    + geom_point()
    + geom_line()
    + facet_wrap("~ title")
    + labs(
        title = "Seasons with Biggest Shifts in Ratings",
        y = "Average rating",
        x = "Season"
    )
)
```

## Do we have full data for each season?

```{python}
mismatches = (tbl_ratings
    >> arrange(_.title, _.seasonNumber)
    >> group_by(_.title)
    >> mutate(
        row = row_number(_),
        mismatch = _.row != _.seasonNumber
    )
    >> filter(_.mismatch.any())
    >> ungroup()
)

mismatches
```
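The trick above is comparing each show's season numbers against a running row count. In pandas terms, `row_number(_)` within a group corresponds to `cumcount() + 1`. A sketch with toy seasons (hypothetical shows; show B is missing season 2, so its numbering drifts out of step):

```python
import pandas as pd

# Toy seasons: show B jumps from season 1 to season 3
df = pd.DataFrame({
    "title":        ["A", "A", "B", "B"],
    "seasonNumber": [1, 2, 1, 3],
}).sort_values(["title", "seasonNumber"])

# row_number(_) within a group corresponds to cumcount() + 1
df["row"] = df.groupby("title").cumcount() + 1
df["mismatch"] = df["row"] != df.seasonNumber

# Keep only shows where any season number is out of step
bad = df.groupby("title").filter(lambda g: g.mismatch.any())
print(sorted(bad.title.unique()))
```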

```{python}
mismatches >> distinct(_.title) >> count() >> collect()
```