diff --git a/pypika/terms.py b/pypika/terms.py index 90ceb89..44c556f 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -7,12 +7,14 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Iterable, Iterator, List, Optional, Sequence, Set, + Tuple, Type, TypeVar, Union, @@ -283,57 +285,115 @@ def get_sql(self, **kwargs: Any) -> str: raise NotImplementedError() +def idx_placeholder_gen(idx: int) -> str: + return str(idx + 1) + + +def named_placeholder_gen(idx: int) -> str: + return f"param{idx + 1}" + + class Parameter(Term): is_aggregate = None def __init__(self, placeholder: Union[str, int]) -> None: super().__init__() - self.placeholder = placeholder + self._placeholder = placeholder + + @property + def placeholder(self): + return self._placeholder def get_sql(self, **kwargs: Any) -> str: return str(self.placeholder) + def update_parameters(self, param_key: Any, param_value: Any, **kwargs): + pass + + def get_param_key(self, placeholder: Any, **kwargs): + return placeholder -class QmarkParameter(Parameter): - """Question mark style, e.g. ...WHERE name=?""" - def __init__(self) -> None: - pass +class ListParameter(Parameter): + def __init__( + self, placeholder: Union[str, int, Callable[[int], str]] = idx_placeholder_gen + ) -> None: + super().__init__(placeholder=placeholder) + self._parameters = list() - def get_sql(self, **kwargs: Any) -> str: + @property + def placeholder(self) -> str: + if callable(self._placeholder): + return self._placeholder(len(self._parameters)) + + return str(self._placeholder) + + def get_parameters(self, **kwargs): + return self._parameters + + def update_parameters(self, value: Any, **kwargs): + self._parameters.append(value) + + +class DictParameter(Parameter): + def __init__( + self, placeholder: Union[str, int, Callable[[int], str]] = named_placeholder_gen + ) -> None: + super().__init__(placeholder=placeholder) + self._parameters = dict() + + @property + def placeholder(self) -> str: + if callable(self._placeholder): + return self._placeholder(len(self._parameters)) + + return str(self._placeholder) + + def get_parameters(self, **kwargs): + return self._parameters + + def get_param_key(self, placeholder: Any, **kwargs): + return placeholder[1:] + + def update_parameters(self, param_key: Any, value: Any, **kwargs): + self._parameters[param_key] = value + + +class QmarkParameter(ListParameter): + def get_sql(self, **kwargs): return "?" -class NumericParameter(Parameter): +class NumericParameter(ListParameter): """Numeric, positional style, e.g. ...WHERE name=:1""" def get_sql(self, **kwargs: Any) -> str: return ":{placeholder}".format(placeholder=self.placeholder) -class NamedParameter(Parameter): - """Named style, e.g. ...WHERE name=:name""" +class FormatParameter(ListParameter): + """ANSI C printf format codes, e.g. ...WHERE name=%s""" def get_sql(self, **kwargs: Any) -> str: - return ":{placeholder}".format(placeholder=self.placeholder) - + return "%s" -class FormatParameter(Parameter): - """ANSI C printf format codes, e.g. ...WHERE name=%s""" - def __init__(self) -> None: - pass +class NamedParameter(DictParameter): + """Named style, e.g. ...WHERE name=:name""" def get_sql(self, **kwargs: Any) -> str: - return "%s" + return ":{placeholder}".format(placeholder=self.placeholder) -class PyformatParameter(Parameter): +class PyformatParameter(DictParameter): """Python extended format codes, e.g. ...WHERE name=%(name)s""" def get_sql(self, **kwargs: Any) -> str: return "%({placeholder})s".format(placeholder=self.placeholder) + def get_param_key(self, placeholder: Any, **kwargs): + return placeholder[2:-2] + class Negative(Term): def __init__(self, term: Term) -> None: @@ -382,13 +442,46 @@ def get_value_sql(self, **kwargs: Any) -> str: return "null" return str(self.value) + def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]: + param_sql = parameter.get_sql(**kwargs) + param_key = parameter.get_param_key(placeholder=param_sql) + + return param_sql, param_key + def get_sql( - self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", **kwargs: Any + self, + quote_char: Optional[str] = None, + secondary_quote_char: str = "'", + parameter: Parameter = None, + **kwargs: Any, ) -> str: - sql = self.get_value_sql( - quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs - ) - return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) + if parameter is None: + sql = self.get_value_sql( + quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs + ) + return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) + + # Don't stringify numbers when using a parameter + if isinstance(self.value, (int, float)): + value_sql = self.value + else: + value_sql = self.get_value_sql(quote_char=quote_char, **kwargs) + param_sql, param_key = self._get_param_data(parameter, **kwargs) + parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs) + + return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs) + + +class ParameterValueWrapper(ValueWrapper): + def __init__(self, parameter: Parameter, value: Any, alias: Optional[str] = None) -> None: + super().__init__(value, alias) + self._parameter = parameter + + def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]: + param_sql = self._parameter.get_sql(**kwargs) + param_key = self._parameter.get_param_key(placeholder=param_sql) + + return param_sql, param_key class JSON(Term): diff --git a/tests/test_parameter.py b/tests/test_parameter.py index fb53bcf..59b6bc3 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -1,4 +1,5 @@ import unittest +from datetime import date from pypika import ( FormatParameter, @@ -10,6 +11,7 @@ Query, Tables, ) +from pypika.terms import ListParameter, ParameterValueWrapper class ParametrizedTests(unittest.TestCase): @@ -98,3 +100,115 @@ def test_format_parameter(self): def test_pyformat_parameter(self): self.assertEqual("%(buz)s", PyformatParameter("buz").get_sql()) + + +class ParametrizedTestsWithValues(unittest.TestCase): + table_abc, table_efg = Tables("abc", "efg") + + def test_param_insert(self): + q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, "foo") + + parameter = QmarkParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual('INSERT INTO "abc" ("a","b","c") VALUES (?,?,?)', sql) + self.assertEqual([1, 2.2, "foo"], parameter.get_parameters()) + + def test_param_select_join(self): + q = ( + Query.from_(self.table_abc) + .select("*") + .where(self.table_abc.category == "foobar") + .join(self.table_efg) + .on(self.table_abc.id == self.table_efg.abc_id) + .where(self.table_efg.date >= date(2024, 2, 22)) + .limit(10) + ) + + parameter = FormatParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual( + 'SELECT * FROM "abc" JOIN "efg" ON "abc"."id"="efg"."abc_id" WHERE "abc"."category"=%s AND "efg"."date">=%s LIMIT 10', + sql, + ) + self.assertEqual(["foobar", "2024-02-22"], parameter.get_parameters()) + + def test_param_select_subquery(self): + q = ( + Query.from_(self.table_abc) + .select("*") + .where(self.table_abc.category == "foobar") + .where( + self.table_abc.id.isin( + Query.from_(self.table_efg) + .select(self.table_efg.abc_id) + .where(self.table_efg.date >= date(2024, 2, 22)) + ) + ) + .limit(10) + ) + + parameter = ListParameter(placeholder=lambda idx: f"&{idx+1}") + sql = q.get_sql(parameter=parameter) + self.assertEqual( + 'SELECT * FROM "abc" WHERE "category"=&1 AND "id" IN (SELECT "abc_id" FROM "efg" WHERE "date">=&2) LIMIT 10', + sql, + ) + self.assertEqual(["foobar", "2024-02-22"], parameter.get_parameters()) + + def test_join(self): + subquery = ( + Query.from_(self.table_efg) + .select(self.table_efg.fiz, self.table_efg.buz) + .where(self.table_efg.buz == "buz") + ) + + q = ( + Query.from_(self.table_abc) + .join(subquery) + .on(self.table_abc.bar == subquery.buz) + .select(self.table_abc.foo, subquery.fiz) + .where(self.table_abc.bar == "bar") + ) + + parameter = NamedParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual( + 'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:param1)' + ' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:param2', + sql, + ) + self.assertEqual({"param1": "buz", "param2": "bar"}, parameter.get_parameters()) + + def test_join_with_parameter_value_wrapper(self): + subquery = ( + Query.from_(self.table_efg) + .select(self.table_efg.fiz, self.table_efg.buz) + .where(self.table_efg.buz == ParameterValueWrapper(Parameter(":buz"), "buz")) + ) + + q = ( + Query.from_(self.table_abc) + .join(subquery) + .on(self.table_abc.bar == subquery.buz) + .select(self.table_abc.foo, subquery.fiz) + .where(self.table_abc.bar == ParameterValueWrapper(NamedParameter("bar"), "bar")) + ) + + parameter = NamedParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual( + 'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:buz)' + ' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:bar', + sql, + ) + self.assertEqual({":buz": "buz", "bar": "bar"}, parameter.get_parameters()) + + def test_pyformat_parameter(self): + q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, "foo") + + parameter = PyformatParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual( + 'INSERT INTO "abc" ("a","b","c") VALUES (%(param1)s,%(param2)s,%(param3)s)', sql + ) + self.assertEqual({"param1": 1, "param2": 2.2, "param3": "foo"}, parameter.get_parameters())