From 64607147637d7b4e00d4bf9eb02ac0ba35cc46d8 Mon Sep 17 00:00:00 2001 From: Shimao Zheng Date: Thu, 21 Mar 2024 10:02:00 -0700 Subject: [PATCH] Prevent write operations in the SQL query --- flux_sdk/etl/data_models/query.py | 2 +- flux_sdk/etl/data_models/tests/test_query.py | 5 +++++ flux_sdk/flux_core/validation.py | 6 +++++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/flux_sdk/etl/data_models/query.py b/flux_sdk/etl/data_models/query.py index 843e0611..0f783d97 100644 --- a/flux_sdk/etl/data_models/query.py +++ b/flux_sdk/etl/data_models/query.py @@ -34,7 +34,7 @@ class SQLQuery: def __post_init__(self): """Perform validation.""" - check_field(self, "text", str, required=True) + check_field(self, "text", str, required=True, block_list_regex=r'\b(?:INSERT\s+INTO|UPDATE|DELETE\s+FROM|DROP|ALTER|CREATE)\b') check_field(self, "args", dict[str, SQLQueryArg]) diff --git a/flux_sdk/etl/data_models/tests/test_query.py b/flux_sdk/etl/data_models/tests/test_query.py index 80f7a2ec..28c70970 100644 --- a/flux_sdk/etl/data_models/tests/test_query.py +++ b/flux_sdk/etl/data_models/tests/test_query.py @@ -24,6 +24,11 @@ def test_validate_args_wrong_type(self): with self.assertRaises(TypeError): SQLQuery(text="select column from table", args=value) + def test_validate_text_contains_write_operation(self): + for value in ["insert into table", "update table", "delete from table", "drop table", "alter table", "create table"]: + with self.assertRaises(ValueError): + SQLQuery(text=value) + def test_validate_success_minimal(self): SQLQuery(text="select column from table") diff --git a/flux_sdk/flux_core/validation.py b/flux_sdk/flux_core/validation.py index a48ae6e5..4a3676f0 100644 --- a/flux_sdk/flux_core/validation.py +++ b/flux_sdk/flux_core/validation.py @@ -1,7 +1,8 @@ +import re from typing import Any, Type, Union, get_args, get_origin -def check_field(obj: Any, attr: str, desired_type: Type, required: bool = False): +def check_field(obj: Any, attr: str, desired_type: Type, required: bool = False, block_list_regex: str = None): value = getattr(obj, attr) if required and not value: @@ -10,6 +11,9 @@ def check_field(obj: Any, attr: str, desired_type: Type, required: bool = False) if value: _check_type(value, attr, desired_type) + if block_list_regex and value and re.search(block_list_regex, value, re.IGNORECASE): + raise ValueError(f"{attr} is not allowed") + def _check_type(value: Any, attr: str, desired_type: Type): origin = get_origin(desired_type)