Skip to content

Commit 08901d5

Browse files
feat: expose DataFrame.parse_sql_expr (apache#1274)
* feat: expose DataFrame.parse_sql_expr to python * Update python/tests/test_dataframe.py Co-authored-by: Tim Saucer <[email protected]> --------- Co-authored-by: Tim Saucer <[email protected]>
1 parent 16d4c03 commit 08901d5

File tree

3 files changed

+60
-0
lines changed

3 files changed

+60
-0
lines changed

python/datafusion/dataframe.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,28 @@ def filter(self, *predicates: Expr) -> DataFrame:
482482
df = df.filter(ensure_expr(p))
483483
return DataFrame(df)
484484

485+
def parse_sql_expr(self, expr: str) -> Expr:
    """Create a logical expression from SQL expression text.

    The expression is created and processed against the current schema
    of this DataFrame.

    Example::

        from datafusion import col, lit
        df.parse_sql_expr("a > 1")

    should produce::

        col("a") > lit(1)

    Args:
        expr: SQL expression string to convert into a DataFusion expression.

    Returns:
        The parsed logical expression.
    """
    # Delegate parsing to the wrapped internal dataframe, then wrap the raw
    # result in the Python-facing Expr type.
    raw_expr = self.df.parse_sql_expr(expr)
    return Expr(raw_expr)
506+
485507
def with_column(self, name: str, expr: Expr) -> DataFrame:
486508
"""Add an additional column to the DataFrame.
487509

python/tests/test_dataframe.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,36 @@ def test_filter(df):
274274
assert result.column(2) == pa.array([5])
275275

276276

277+
def test_parse_sql_expr(df):
    """Check that parsed SQL expressions behave like hand-built expressions."""
    # A filter built from parsed SQL text must yield the same logical plan as
    # the equivalent filter built with column()/literal().
    plan1 = df.filter(df.parse_sql_expr("a > 2")).logical_plan()
    plan2 = df.filter(column("a") > literal(2)).logical_plan()
    # object equality not implemented but string representation should match
    assert str(plan1) == str(plan2)

    df1 = df.filter(df.parse_sql_expr("a > 2")).select(
        column("a") + column("b"),
        column("a") - column("b"),
    )

    # execute and collect the first (and only) batch
    result = df1.collect()[0]

    assert result.column(0) == pa.array([9])
    assert result.column(1) == pa.array([-3])

    # verify that if there is no filter applied, internal dataframe is unchanged
    df2 = df.filter()
    assert df.df == df2.df

    # several parsed expressions passed as separate predicates to one filter call
    df3 = df.filter(df.parse_sql_expr("a > 1"), df.parse_sql_expr("b != 6"))
    result = df3.collect()[0]

    assert result.column(0) == pa.array([2])
    assert result.column(1) == pa.array([5])
    assert result.column(2) == pa.array([5])
305+
306+
277307
def test_show_empty(df, capsys):
278308
df_empty = df.filter(column("a") > literal(3))
279309
df_empty.show()

src/dataframe.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,14 @@ impl PyDataFrame {
454454
Ok(Self::new(df))
455455
}
456456

457+
fn parse_sql_expr(&self, expr: PyBackedStr) -> PyDataFusionResult<PyExpr> {
458+
self.df
459+
.as_ref()
460+
.parse_sql_expr(&expr)
461+
.map(|e| PyExpr::from(e))
462+
.map_err(PyDataFusionError::from)
463+
}
464+
457465
fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {
458466
let df = self.df.as_ref().clone().with_column(name, expr.into())?;
459467
Ok(Self::new(df))

0 commit comments

Comments
 (0)