Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent write operations in the SQL query #56

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flux_sdk/etl/data_models/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class SQLQuery:

def __post_init__(self):
"""Perform validation."""
check_field(self, "text", str, required=True)
check_field(self, "text", str, required=True, block_list_regex=r'\b(?:INSERT\s+INTO|UPDATE|DELETE\s+FROM|DROP|ALTER|CREATE)\b')
check_field(self, "args", dict[str, SQLQueryArg])


Expand Down
5 changes: 5 additions & 0 deletions flux_sdk/etl/data_models/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ def test_validate_args_wrong_type(self):
with self.assertRaises(TypeError):
SQLQuery(text="select column from table", args=value)

def test_validate_text_contains_write_operation(self):
for value in ["insert into table", "update table", "delete from table", "drop table", "alter table", "create table"]:
with self.assertRaises(ValueError):
SQLQuery(text=value)

def test_validate_success_minimal(self):
SQLQuery(text="select column from table")

Expand Down
6 changes: 5 additions & 1 deletion flux_sdk/flux_core/validation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import re
from typing import Any, Type, Union, get_args, get_origin


def check_field(obj: Any, attr: str, desired_type: Type, required: bool = False):
def check_field(obj: Any, attr: str, desired_type: Type, required: bool = False, block_list_regex: str = None):
value = getattr(obj, attr)

if required and not value:
Expand All @@ -10,6 +11,9 @@ def check_field(obj: Any, attr: str, desired_type: Type, required: bool = False)
if value:
_check_type(value, attr, desired_type)

if block_list_regex and value and re.search(block_list_regex, value, re.IGNORECASE):
raise ValueError(f"{attr} is not allowed")


def _check_type(value: Any, attr: str, desired_type: Type):
origin = get_origin(desired_type)
Expand Down
Loading