Skip to content

Commit 99a3682

Browse files
authored
Merge pull request #476 from machow/misc-fixes
Misc fixes
2 parents 4099f48 + 0c4d103 commit 99a3682

File tree

11 files changed

+95
-18
lines changed

11 files changed

+95
-18
lines changed

siuba/dply/verbs.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -2545,7 +2545,7 @@ def _extract_gdf(__data, *args, **kwargs):
25452545

25462546
# tbl ----
25472547

2548-
from siuba.siu._databackend import SqlaEngine
2548+
from siuba.siu._databackend import SqlaEngine, PlDataFrame, PdDataFrame
25492549

25502550
@singledispatch2((pd.DataFrame, DataFrameGroupBy))
25512551
def tbl(src, *args, **kwargs):
@@ -2613,8 +2613,10 @@ def _tbl_sqla(src: SqlaEngine, table_name, columns=None):
26132613

26142614
# TODO: once we subclass LazyTbl per dialect (e.g. duckdb), we can move out
26152615
# this dialect specific logic.
2616-
if src.dialect.name == "duckdb" and isinstance(columns, pd.DataFrame):
2617-
src.execute("register", (table_name, columns))
2616+
if src.dialect.name == "duckdb" and isinstance(columns, (PdDataFrame, PlDataFrame)):
2617+
with src.begin() as conn:
2618+
conn.exec_driver_sql("register", (table_name, columns))
2619+
26182620
return LazyTbl(src, table_name)
26192621

26202622
return LazyTbl(src, table_name, columns=columns)

siuba/experimental/pivot/pivot_wide.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -403,8 +403,9 @@ def pivot_wider_spec(
403403
# validate names and move id vars to columns ----
404404
# note: in pandas 1.5+ we can use the allow_duplicates option to reset, even
405405
# when index and column names overlap. for now, repair names, rename, then reset.
406-
unique_names = vec_as_names([*id_vars, *wide.columns], repair="unique")
407-
repaired_names = vec_as_names([*id_vars, *wide.columns], repair=names_repair)
406+
_all_raw_names = list(map(str, [*id_vars, *wide.columns]))
407+
unique_names = vec_as_names(_all_raw_names, repair="unique")
408+
repaired_names = vec_as_names(_all_raw_names, repair=names_repair)
408409

409410
uniq_id_vars = unique_names[:len(id_vars)]
410411
uniq_val_vars = unique_names[len(id_vars):]

siuba/experimental/pivot/sql_pivot_wide.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ def _pivot_wider_spec(
131131

132132
wide_id_cols = [sel_cols[id_] for id_ in id_vars]
133133

134-
repaired_names = vec_as_names([*id_vars, *spec[".name"]], repair=names_repair)
134+
_all_raw_names = list(map(str, [*id_vars, *spec[".name"]]))
135+
repaired_names = vec_as_names(_all_raw_names, repair=names_repair)
135136
labeled_cols = [
136137
col.label(name) for name, col in
137138
zip(repaired_names, [*wide_id_cols, *wide_name_cols])

siuba/siu/_databackend.py

+4
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,9 @@ def __subclasshook__(cls, subclass):
4848
# Implementations -------------------------------------------------------------
4949

5050
class SqlaEngine(AbstractBackend): pass
51+
class PlDataFrame(AbstractBackend): pass
52+
class PdDataFrame(AbstractBackend): pass
5153

5254
SqlaEngine.register_backend("sqlalchemy.engine", "Connectable")
55+
PlDataFrame.register_backend("polars", "DataFrame")
56+
PdDataFrame.register_backend("pandas", "DataFrame")

siuba/siu/calls.py

+32-6
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def __rshift__(self, x):
201201
stripped = strip_symbolic(x)
202202

203203
if isinstance(stripped, Call):
204-
return self._construct_pipe(MetaArg("_"), self, x)
204+
return self._construct_pipe(self, x)
205205

206206
raise TypeError()
207207

@@ -308,9 +308,28 @@ def obj_name(self):
308308

309309
return None
310310

311-
@classmethod
312-
def _construct_pipe(cls, *args):
313-
return PipeCall(*args)
311+
@staticmethod
def _construct_pipe(lhs, rhs):
    """Combine two pipe operands into a single, flat PipeCall.

    Each operand contributes either itself (when it is not a PipeCall) or
    its own args (when it is one). A leading MetaArg on an existing
    PipeCall's args is dropped, so that combining pipes does not keep
    stacking placeholders; exactly one fresh ``MetaArg("_")`` is put at the
    front of the result.
    """

    def _pipe_parts(operand):
        # anything that is not already a pipe is a single step
        if not isinstance(operand, PipeCall):
            return [operand]

        parts = operand.args

        # strip the placeholder the existing pipe starts with, if any
        if parts and isinstance(parts[0], MetaArg):
            return parts[1:]

        return parts

    left_parts = _pipe_parts(lhs)
    right_parts = _pipe_parts(rhs)

    return PipeCall(MetaArg("_"), *left_parts, *right_parts)
314333

315334

316335
class Lazy(Call):
@@ -674,8 +693,15 @@ class PipeCall(Call):
674693
"""
675694

676695
def __init__(self, func, *args, **kwargs):
677-
self.func = "__siu_pipe_call__"
678-
self.args = (func, *args)
696+
if isinstance(func, str) and func == "__siu_pipe_call__":
697+
# it was a mistake to make func the first parameter to Call
698+
# but basically we need to catch when it is passed, so
699+
# we can ignore it
700+
self.func = func
701+
self.args = args
702+
else:
703+
self.func = "__siu_pipe_call__"
704+
self.args = (func, *args)
679705
if kwargs:
680706
raise ValueError("Keyword arguments are not allowed.")
681707
self.kwargs = {}

siuba/siu/symbolic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __rshift__(self, x):
8585

8686
if isinstance(stripped, Call):
8787
lhs_call = self.__source
88-
return Call._construct_pipe(MetaArg("_"), lhs_call, stripped)
88+
return self.__class__(Call._construct_pipe(lhs_call, stripped))
8989
# strip_symbolic(self)(x)
9090
# x is a symbolic
9191
raise NotImplementedError("Symbolic may only be used on right-hand side of >> operator.")

siuba/sql/across.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@ def _across_sql_cols(
5151
lazy_tbl = ctx_verb_data.get()
5252
window = ctx_verb_window.get()
5353

54+
column_names = list(__data.keys())
55+
5456
name_template = _get_name_template(fns, names)
55-
selected_cols = var_select(__data, *var_create(cols), data=__data)
57+
selected_cols = var_select(column_names, *var_create(cols), data=__data)
5658

5759
fns_map = _across_setup_fns(fns)
5860

siuba/sql/dialects/base.py

+15
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,20 @@ def sql_func_capitalize(_, col):
127127
rest = fn.right(col, fn.length(col) - 1)
128128
return sql.functions.concat(first_char, rest)
129129

130+
def sql_str_cat(_, col, others=None, sep=None, na_rep=None, join=None):
    """Translate ``str.cat`` into a SQL ``concat`` of two column expressions.

    Parameters
    ----------
    _ :
        Unused placeholder for the first positional argument.
    col :
        The expression on the left of the concatenation.
    others :
        A single expression to append to ``col``. Required; lists and
        tuples are not supported.
    sep, na_rep, join :
        Accepted only for signature compatibility; any value other than
        ``None`` is rejected.

    Returns
    -------
    The result of ``sql.functions.concat(col, others)``.

    Raises
    ------
    NotImplementedError
        If ``sep``, ``na_rep`` or ``join`` is given, if ``others`` is
        missing, or if ``others`` is a list or tuple.
    """
    unsupported = (("sep", sep), ("na_rep", na_rep), ("join", join))
    for arg_name, arg_value in unsupported:
        if arg_value is not None:
            raise NotImplementedError(
                f"{arg_name} argument not supported for sql cat"
            )

    # Without `others` there is nothing to concatenate row-wise; previously
    # this fell through and silently produced concat(col, NULL).
    if others is None:
        raise NotImplementedError("others argument is required for sql cat")

    if isinstance(others, (list, tuple)):
        raise NotImplementedError(
            "others argument must be a single column for sql cat"
        )

    return sql.functions.concat(col, others)
130144

131145
# Numpy ufuncs ----------------------------------------------------------------
132146
# symbolic objects have a generic dispatch for when _.__array_ufunc__ is called,
@@ -252,6 +266,7 @@ def req_bool(f):
252266
**{
253267
# TODO: check generality of trim functions, since MYSQL overrides
254268
"str.capitalize" : sql_func_capitalize,
269+
"str.cat" : sql_str_cat,
255270
#"str.center" :,
256271
#"str.contains" :,
257272
#"str.count" :,

siuba/sql/dialects/duckdb.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
sql_not_impl,
1414
# wiring up translator
1515
extend_base,
16-
SqlTranslator
16+
SqlTranslator,
17+
convert_literal
1718
)
1819

1920
from .postgresql import (
@@ -44,6 +45,15 @@ def returns_int(func_names):
4445
f_annotated = wrap_annotate(f_concrete, result_type="int")
4546
generic.register(DuckdbColumn, f_annotated)
4647

48+
# Literal Conversions =========================================================
49+
50+
@convert_literal.register
def _cl_duckdb(codata: DuckdbColumn, lit):
    """Convert a Python literal for DuckdbColumn dispatch.

    A ``list`` is wrapped in the postgresql dialect's ``array`` construct;
    every other value is passed to ``sql.literal``.
    """
    from sqlalchemy.dialects.postgresql import array

    if not isinstance(lit, list):
        return sql.literal(lit)

    return array(lit)
4757

4858
# Translations ================================================================
4959

siuba/sql/translate.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ def wrapper(*args, **kwargs):
252252
# Translator =================================================================
253253

254254
from siuba.ops.translate import create_pandas_translator
255+
from functools import singledispatch
255256

256257

257258
def extend_base(cls, **kwargs):
@@ -323,7 +324,7 @@ def shape_call(
323324
# verbs that can use strings as accessors, like group_by, or
324325
# arrange, need to convert those strings into a getitem call
325326
return str_to_getitem_call(call)
326-
elif isinstance(call, sql.elements.ColumnClause):
327+
elif isinstance(call, (sql.elements.ClauseElement)):
327328
return Lazy(call)
328329
elif callable(call):
329330
#TODO: should not happen here
@@ -332,7 +333,8 @@ def shape_call(
332333
else:
333334
# verbs that use literal strings, need to convert them to a call
334335
# that returns a sqlalchemy "literal" object
335-
return Lazy(sql.literal(call))
336+
_lit = convert_literal(self.window.dispatch_cls(), call)
337+
return Lazy(_lit)
336338

337339
# raise informative error message if missing translation
338340
try:
@@ -367,3 +369,7 @@ def from_mappings(WinCls, AggCls):
367369
aggregate = create_pandas_translator(ALL_OPS, AggCls, sql.elements.ClauseElement)
368370
)
369371

372+
373+
@singledispatch
def convert_literal(codata, lit):
    """Default literal conversion: wrap ``lit`` with ``sql.literal``.

    Dispatches on the type of ``codata`` (the first argument), so other
    implementations can be registered per type; ``codata`` itself is not
    used by this fallback.
    """
    literal_clause = sql.literal(lit)
    return literal_clause

siuba/tests/test_siu_dispatchers.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
import pytest
22

3+
from siuba.siu.calls import PipeCall
34
from siuba.siu.dispatchers import call
4-
from siuba.siu import _
5+
from siuba.siu import _, strip_symbolic, Symbolic
56

67
# TODO: direct test of lazy elements
78
# TODO: NSECall - no map subcalls
89

10+
11+
def test_siu_pipe_call_is_flat():
    """Chaining ``>>`` twice should give one flat PipeCall, not nested pipes."""
    chained = _ >> _.a >> _.b
    stripped = strip_symbolic(chained)

    # the piped expression stays symbolic, and unwraps to a PipeCall
    assert isinstance(chained, Symbolic)
    assert isinstance(stripped, PipeCall)

    # a flat pipe holds every step directly in its args
    n_args = len(stripped.args)
    assert n_args == 4
18+
919
def test_siu_call_no_args():
1020
assert 1 >> call(range) == range(1)
1121

0 commit comments

Comments
 (0)