diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py index 5ccb115c..a60bb898 100644 --- a/siuba/sql/verbs.py +++ b/siuba/sql/verbs.py @@ -141,28 +141,41 @@ def ordered_union(x, y): class LazyTbl: def __init__( - self, source, tbl, ops = None, - group_by = tuple(), order_by = tuple(), funcs = None, + self, source, tbl, columns = None, + ops = None, group_by = tuple(), order_by = tuple(), funcs = None, rm_attr = ('str', 'dt'), call_sub_attr = ('dt',) ): + """Create a representation of a SQL table. + + Args: + source: a sqlalchemy.Engine or sqlalchemy.Connection instance. + tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. + columns: if specified, a listlike of column names. + + Examples + -------- + + :: + from sqlalchemy import create_engine + from siuba.data import mtcars + + # create database and table + engine = create_engine("sqlite:///:memory:") + mtcars.to_sql('mtcars', engine) + + tbl_mtcars = LazyTbl(engine, 'mtcars') + + """ # connection and dialect specific functions self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source self.funcs = get_dialect_funcs(self.source.dialect.name) if funcs is None else funcs - if isinstance(tbl, str): - schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] - self.tbl = sqlalchemy.Table( - table_name, - sqlalchemy.MetaData(), - autoload_with = self.source, - schema = schema - ) - else: - self.tbl = tbl + self.tbl = self._create_table(tbl, columns, self.source) # important states the query can be in (e.g. grouped) self.ops = [sql.Select([self.tbl])] if ops is None else ops + self.group_by = group_by self.order_by = order_by @@ -170,6 +183,7 @@ def __init__( self.rm_attr = rm_attr self.call_sub_attr = call_sub_attr + def append_op(self, op, **kwargs): cpy = self.copy(**kwargs) cpy.ops = cpy.ops + [op] @@ -215,6 +229,37 @@ def get_ordered_col_names(self): def last_op(self): return self.ops[-1] if len(self.ops) else None + @staticmethod + def _create_table(tbl, columns = None, source = None): + """Return a sqlalchemy.Table, autoloading column info if needed. + + Arguments: + tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. + columns: a tuple of column names for the table. Overrides source argument. + source: a sqlalchemy engine, used to autoload columns. + + """ + if isinstance(tbl, sqlalchemy.Table): + return tbl + + if not isinstance(tbl, str): + raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) + + if columns is None and source is None: + raise ValueError("One of columns or source must be specified") + + schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] + + columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() + + return sqlalchemy.Table( + table_name, + sqlalchemy.MetaData(), + *columns, + schema = schema, + autoload_with = source if not columns else None + ) + def _get_preview(self): # need to make prev op a cte, so we don't override any previous limit new_sel = sql.select([self.last_op.alias()]).limit(5) diff --git a/siuba/tests/test_sql_verbs.py b/siuba/tests/test_sql_verbs.py index 908c1f01..3fe32a65 100644 --- a/siuba/tests/test_sql_verbs.py +++ b/siuba/tests/test_sql_verbs.py @@ -41,6 +41,19 @@ def db(): conn.execute(ins, id=2, name='wendy', fullname='Wendy Williams') yield conn +# LazyTbl --------------------------------------------------------------------- + +def test_lazy_tbl_table_string(db): + tbl = LazyTbl(db, 'addresses') + tbl.tbl.columns.user_id + +def test_lazy_tbl_manual_columns(db): + tbl = LazyTbl(db, 'addresses', columns = ('user_id', 'wrong_name')) + tbl.tbl.columns.wrong_name + tbl.tbl.columns.user_id + + with pytest.raises(AttributeError): + tbl.tbl.columns.email_address # mutate ----------------------------------------------------------------------