Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 115 additions & 22 deletions pypika/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -283,57 +285,115 @@ def get_sql(self, **kwargs: Any) -> str:
raise NotImplementedError()


def idx_placeholder_gen(idx: int) -> str:
return str(idx + 1)


def named_placeholder_gen(idx: int) -> str:
return f"param{idx + 1}"


class Parameter(Term):
is_aggregate = None

def __init__(self, placeholder: Union[str, int]) -> None:
super().__init__()
self.placeholder = placeholder
self._placeholder = placeholder

@property
def placeholder(self):
return self._placeholder

def get_sql(self, **kwargs: Any) -> str:
return str(self.placeholder)

def update_parameters(self, param_key: Any, param_value: Any, **kwargs):
pass

def get_param_key(self, placeholder: Any, **kwargs):
return placeholder

class QmarkParameter(Parameter):
"""Question mark style, e.g. ...WHERE name=?"""

def __init__(self) -> None:
pass
class ListParameter(Parameter):
def __init__(
self, placeholder: Union[str, int, Callable[[int], str]] = idx_placeholder_gen
) -> None:
super().__init__(placeholder=placeholder)
self._parameters = list()

def get_sql(self, **kwargs: Any) -> str:
@property
def placeholder(self) -> str:
if callable(self._placeholder):
return self._placeholder(len(self._parameters))

return str(self._placeholder)

def get_parameters(self, **kwargs):
return self._parameters

def update_parameters(self, value: Any, **kwargs):
self._parameters.append(value)


class DictParameter(Parameter):
def __init__(
self, placeholder: Union[str, int, Callable[[int], str]] = named_placeholder_gen
) -> None:
super().__init__(placeholder=placeholder)
self._parameters = dict()

@property
def placeholder(self) -> str:
if callable(self._placeholder):
return self._placeholder(len(self._parameters))

return str(self._placeholder)

def get_parameters(self, **kwargs):
return self._parameters

def get_param_key(self, placeholder: Any, **kwargs):
return placeholder[1:]

def update_parameters(self, param_key: Any, value: Any, **kwargs):
self._parameters[param_key] = value


class QmarkParameter(ListParameter):
def get_sql(self, **kwargs):
return "?"


class NumericParameter(Parameter):
class NumericParameter(ListParameter):
"""Numeric, positional style, e.g. ...WHERE name=:1"""

def get_sql(self, **kwargs: Any) -> str:
return ":{placeholder}".format(placeholder=self.placeholder)


class NamedParameter(Parameter):
"""Named style, e.g. ...WHERE name=:name"""
class FormatParameter(ListParameter):
"""ANSI C printf format codes, e.g. ...WHERE name=%s"""

def get_sql(self, **kwargs: Any) -> str:
return ":{placeholder}".format(placeholder=self.placeholder)

return "%s"

class FormatParameter(Parameter):
"""ANSI C printf format codes, e.g. ...WHERE name=%s"""

def __init__(self) -> None:
pass
class NamedParameter(DictParameter):
"""Named style, e.g. ...WHERE name=:name"""

def get_sql(self, **kwargs: Any) -> str:
return "%s"
return ":{placeholder}".format(placeholder=self.placeholder)


class PyformatParameter(Parameter):
class PyformatParameter(DictParameter):
"""Python extended format codes, e.g. ...WHERE name=%(name)s"""

def get_sql(self, **kwargs: Any) -> str:
return "%({placeholder})s".format(placeholder=self.placeholder)

def get_param_key(self, placeholder: Any, **kwargs):
return placeholder[2:-2]


class Negative(Term):
def __init__(self, term: Term) -> None:
Expand Down Expand Up @@ -382,13 +442,46 @@ def get_value_sql(self, **kwargs: Any) -> str:
return "null"
return str(self.value)

def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]:
param_sql = parameter.get_sql(**kwargs)
param_key = parameter.get_param_key(placeholder=param_sql)

return param_sql, param_key

def get_sql(
self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", **kwargs: Any
self,
quote_char: Optional[str] = None,
secondary_quote_char: str = "'",
parameter: Parameter = None,
**kwargs: Any,
) -> str:
sql = self.get_value_sql(
quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs
)
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)
if parameter is None:
sql = self.get_value_sql(
quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs
)
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)

# Don't stringify numbers when using a parameter
if isinstance(self.value, (int, float)):
value_sql = self.value
else:
value_sql = self.get_value_sql(quote_char=quote_char, **kwargs)
param_sql, param_key = self._get_param_data(parameter, **kwargs)
parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs)

return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs)


class ParameterValueWrapper(ValueWrapper):
def __init__(self, parameter: Parameter, value: Any, alias: Optional[str] = None) -> None:
super().__init__(value, alias)
self._parameter = parameter

def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]:
param_sql = self._parameter.get_sql(**kwargs)
param_key = self._parameter.get_param_key(placeholder=param_sql)

return param_sql, param_key


class JSON(Term):
Expand Down
114 changes: 114 additions & 0 deletions tests/test_parameter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from datetime import date

from pypika import (
FormatParameter,
Expand All @@ -10,6 +11,7 @@
Query,
Tables,
)
from pypika.terms import ListParameter, ParameterValueWrapper


class ParametrizedTests(unittest.TestCase):
Expand Down Expand Up @@ -98,3 +100,115 @@ def test_format_parameter(self):

def test_pyformat_parameter(self):
self.assertEqual("%(buz)s", PyformatParameter("buz").get_sql())


class ParametrizedTestsWithValues(unittest.TestCase):
table_abc, table_efg = Tables("abc", "efg")

def test_param_insert(self):
q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, "foo")

parameter = QmarkParameter()
sql = q.get_sql(parameter=parameter)
self.assertEqual('INSERT INTO "abc" ("a","b","c") VALUES (?,?,?)', sql)
self.assertEqual([1, 2.2, "foo"], parameter.get_parameters())

def test_param_select_join(self):
q = (
Query.from_(self.table_abc)
.select("*")
.where(self.table_abc.category == "foobar")
.join(self.table_efg)
.on(self.table_abc.id == self.table_efg.abc_id)
.where(self.table_efg.date >= date(2024, 2, 22))
.limit(10)
)

parameter = FormatParameter()
sql = q.get_sql(parameter=parameter)
self.assertEqual(
'SELECT * FROM "abc" JOIN "efg" ON "abc"."id"="efg"."abc_id" WHERE "abc"."category"=%s AND "efg"."date">=%s LIMIT 10',
sql,
)
self.assertEqual(["foobar", "2024-02-22"], parameter.get_parameters())

def test_param_select_subquery(self):
q = (
Query.from_(self.table_abc)
.select("*")
.where(self.table_abc.category == "foobar")
.where(
self.table_abc.id.isin(
Query.from_(self.table_efg)
.select(self.table_efg.abc_id)
.where(self.table_efg.date >= date(2024, 2, 22))
)
)
.limit(10)
)

parameter = ListParameter(placeholder=lambda idx: f"&{idx+1}")
sql = q.get_sql(parameter=parameter)
self.assertEqual(
'SELECT * FROM "abc" WHERE "category"=&1 AND "id" IN (SELECT "abc_id" FROM "efg" WHERE "date">=&2) LIMIT 10',
sql,
)
self.assertEqual(["foobar", "2024-02-22"], parameter.get_parameters())

def test_join(self):
subquery = (
Query.from_(self.table_efg)
.select(self.table_efg.fiz, self.table_efg.buz)
.where(self.table_efg.buz == "buz")
)

q = (
Query.from_(self.table_abc)
.join(subquery)
.on(self.table_abc.bar == subquery.buz)
.select(self.table_abc.foo, subquery.fiz)
.where(self.table_abc.bar == "bar")
)

parameter = NamedParameter()
sql = q.get_sql(parameter=parameter)
self.assertEqual(
'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:param1)'
' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:param2',
sql,
)
self.assertEqual({"param1": "buz", "param2": "bar"}, parameter.get_parameters())

def test_join_with_parameter_value_wrapper(self):
subquery = (
Query.from_(self.table_efg)
.select(self.table_efg.fiz, self.table_efg.buz)
.where(self.table_efg.buz == ParameterValueWrapper(Parameter(":buz"), "buz"))
)

q = (
Query.from_(self.table_abc)
.join(subquery)
.on(self.table_abc.bar == subquery.buz)
.select(self.table_abc.foo, subquery.fiz)
.where(self.table_abc.bar == ParameterValueWrapper(NamedParameter("bar"), "bar"))
)

parameter = NamedParameter()
sql = q.get_sql(parameter=parameter)
self.assertEqual(
'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:buz)'
' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:bar',
sql,
)
self.assertEqual({":buz": "buz", "bar": "bar"}, parameter.get_parameters())

def test_pyformat_parameter(self):
q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, "foo")

parameter = PyformatParameter()
sql = q.get_sql(parameter=parameter)
self.assertEqual(
'INSERT INTO "abc" ("a","b","c") VALUES (%(param1)s,%(param2)s,%(param3)s)', sql
)
self.assertEqual({"param1": 1, "param2": 2.2, "param3": "foo"}, parameter.get_parameters())
Loading