diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py index 64a85ef434..3b1eb95adb 100644 --- a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: 2025-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, Optional +from typing import Any, Dict, Literal, Optional, Union import polars as pl from haystack import component, default_from_dict, default_to_dict, logging @@ -20,8 +20,9 @@ class SnowflakeTableRetriever: For more information, see [Polars documentation](https://docs.pola.rs/api/python/dev/reference/api/polars.read_database_uri.html). and [ADBC documentation](https://arrow.apache.org/adbc/main/driver/snowflake.html). - ### Usage example: + ### Usage examples: + #### Password Authentication (default): ```python executor = SnowflakeTableRetriever( user="", @@ -31,7 +32,39 @@ class SnowflakeTableRetriever: db_schema="", warehouse="", ) + ``` + #### Key-pair Authentication (MFA): + ```python + executor = SnowflakeTableRetriever( + user="", + account="", + authenticator="SNOWFLAKE_JWT", + private_key_file=Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_FILE"), + private_key_file_pwd=Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_PWD"), + database="", + db_schema="", + warehouse="", + ) + ``` + + #### OAuth Authentication (MFA): + ```python + executor = SnowflakeTableRetriever( + user="", + account="", + authenticator="OAUTH", + oauth_client_id=Secret.from_env_var("SNOWFLAKE_OAUTH_CLIENT_ID"), + oauth_client_secret=Secret.from_env_var("SNOWFLAKE_OAUTH_CLIENT_SECRET"), + oauth_token_request_url="", + database="", + db_schema="", + warehouse="", + ) + ``` + + #### Running queries: + ```python query = "SELECT * FROM table_name" results = executor.run(query=query) @@ -58,22 +91,41 @@ def __init__( self, user: str, account: str, - api_key: Secret = Secret.from_env_var("SNOWFLAKE_API_KEY"), # noqa: B008 + api_key: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_API_KEY"), # noqa: B008 database: Optional[str] = None, db_schema: Optional[str] = None, warehouse: Optional[str] = None, login_timeout: Optional[int] = 60, return_markdown: bool = True, + authenticator: Optional[Literal["SNOWFLAKE", "SNOWFLAKE_JWT", "OAUTH"]] = None, + private_key_file: Optional[Union[str, Secret]] = None, + private_key_file_pwd: Optional[Union[str, Secret]] = None, + oauth_client_id: Optional[Union[str, Secret]] = None, + oauth_client_secret: Optional[Union[str, Secret]] = None, + oauth_token_request_url: Optional[str] = None, + oauth_authorization_url: Optional[str] = None, ) -> None: """ :param user: User's login. :param account: Snowflake account identifier. - :param api_key: Snowflake account password. + :param api_key: Snowflake account password. Required for default password authentication. :param database: Name of the database to use. :param db_schema: Name of the schema to use. :param warehouse: Name of the warehouse to use. :param login_timeout: Timeout in seconds for login. :param return_markdown: Whether to return a Markdown-formatted string of the DataFrame. + :param authenticator: Authentication method. Options: "SNOWFLAKE" (default password), + "SNOWFLAKE_JWT" (key-pair), or "OAUTH". + :param private_key_file: Path to private key file or Secret containing the path. + Required for SNOWFLAKE_JWT authentication. + :param private_key_file_pwd: Passphrase for private key file or Secret containing the passphrase. + Required for SNOWFLAKE_JWT authentication. + :param oauth_client_id: OAuth client ID or Secret containing the client ID. + Required for OAUTH authentication. + :param oauth_client_secret: OAuth client secret or Secret containing the client secret. + Required for OAUTH authentication. + :param oauth_token_request_url: OAuth token request URL for Client Credentials flow. + :param oauth_authorization_url: OAuth authorization URL for Authorization Code flow. """ self.user = user @@ -85,6 +137,66 @@ def __init__( self.login_timeout = login_timeout or 60 self.return_markdown = return_markdown + # Authentication parameters + self.authenticator = authenticator or "SNOWFLAKE" + self.private_key_file = private_key_file + self.private_key_file_pwd = private_key_file_pwd + self.oauth_client_id = oauth_client_id + self.oauth_client_secret = oauth_client_secret + self.oauth_token_request_url = oauth_token_request_url + self.oauth_authorization_url = oauth_authorization_url + + # Validate authentication parameters + self._validate_auth_params() + + def _validate_auth_params(self) -> None: + """ + Validates authentication parameters based on the chosen authentication method. + + :raises ValueError: If required parameters are missing for the selected authentication method. + """ + if self.authenticator == "SNOWFLAKE_JWT": + if not self.private_key_file: + msg = "private_key_file is required for SNOWFLAKE_JWT authentication" + raise ValueError(msg) + if not self.private_key_file_pwd: + msg = "private_key_file_pwd is required for SNOWFLAKE_JWT authentication" + raise ValueError(msg) + elif self.authenticator == "OAUTH": + if not self.oauth_client_id: + msg = "oauth_client_id is required for OAUTH authentication" + raise ValueError(msg) + if not self.oauth_client_secret: + msg = "oauth_client_secret is required for OAUTH authentication" + raise ValueError(msg) + elif self.authenticator == "SNOWFLAKE": + if not self.api_key: + msg = "api_key is required for SNOWFLAKE (password) authentication" + raise ValueError(msg) + try: + api_key_value = self.api_key.resolve_value() + if not api_key_value: + msg = "api_key is required for SNOWFLAKE (password) authentication" + raise ValueError(msg) + except ValueError as e: + if "authentication environment variables are set" in str(e): + msg = "api_key is required for SNOWFLAKE (password) authentication" + raise ValueError(msg) from e + raise + + def _resolve_secret_value(self, value: Optional[Union[str, Secret]]) -> Optional[str]: + """ + Resolves a Secret value or returns the string value. + + :param value: String or Secret to resolve. + :returns: Resolved string value or None. + """ + if value is None: + return None + if isinstance(value, Secret): + return value.resolve_value() + return value + def to_dict(self) -> Dict[str, Any]: """ Serializes the component to a dictionary. @@ -92,17 +204,44 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. """ - return default_to_dict( # type: ignore - self, - user=self.user, - account=self.account, - api_key=self.api_key.to_dict(), - database=self.database, - db_schema=self.db_schema, - warehouse=self.warehouse, - login_timeout=self.login_timeout, - return_markdown=self.return_markdown, - ) + data: Dict[str, Any] = { + "user": self.user, + "account": self.account, + "database": self.database, + "db_schema": self.db_schema, + "warehouse": self.warehouse, + "login_timeout": self.login_timeout, + "return_markdown": self.return_markdown, + "authenticator": self.authenticator, + "oauth_token_request_url": self.oauth_token_request_url, + "oauth_authorization_url": self.oauth_authorization_url, + } + + # Handle Secret fields + if self.api_key: + data["api_key"] = self.api_key.to_dict() + if self.private_key_file: + data["private_key_file"] = ( + self.private_key_file.to_dict() if isinstance(self.private_key_file, Secret) else self.private_key_file + ) + if self.private_key_file_pwd: + data["private_key_file_pwd"] = ( + self.private_key_file_pwd.to_dict() + if isinstance(self.private_key_file_pwd, Secret) + else self.private_key_file_pwd + ) + if self.oauth_client_id: + data["oauth_client_id"] = ( + self.oauth_client_id.to_dict() if isinstance(self.oauth_client_id, Secret) else self.oauth_client_id + ) + if self.oauth_client_secret: + data["oauth_client_secret"] = ( + self.oauth_client_secret.to_dict() + if isinstance(self.oauth_client_secret, Secret) + else self.oauth_client_secret + ) + + return default_to_dict(self, **data) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "SnowflakeTableRetriever": @@ -115,40 +254,110 @@ def from_dict(cls, data: Dict[str, Any]) -> "SnowflakeTableRetriever": Deserialized component. """ init_params = data.get("init_parameters", {}) - deserialize_secrets_inplace(init_params, ["api_key"]) + secret_fields = [ + "api_key", + "private_key_file", + "private_key_file_pwd", + "oauth_client_id", + "oauth_client_secret", + ] + deserialize_secrets_inplace(init_params, secret_fields) return default_from_dict(cls, data) def _snowflake_uri_constructor(self) -> str: """ - Constructs the Snowflake connection URI. + Constructs the Snowflake connection URI based on the authentication method. - Format: "snowflake://user:password@account/database/schema?warehouse=warehouse" + Formats: + - Password: "snowflake://user:password@account/database/schema?warehouse=warehouse" + - Key-pair JWT: "snowflake://user@account/database/schema?warehouse=warehouse&authenticator=SNOWFLAKE_JWT&private_key_file=path&private_key_file_pwd=pwd" + - OAuth: "snowflake://user@account/database/schema?warehouse=warehouse&authenticator=OAUTH&oauth_client_id=id&oauth_client_secret=secret" - :raises ValueError: If required credentials (`user` or `account`) are missing. + :raises ValueError: If required credentials are missing. :returns: A formatted Snowflake connection URI. """ if not self.user or not self.account: msg = "Missing required Snowflake connection parameters: user and account." raise ValueError(msg) - uri = f"snowflake://{self.user}:{self.api_key.resolve_value()}@{self.account}" + # Base URI construction + if self.authenticator == "SNOWFLAKE" and self.api_key: + # Traditional password authentication + uri = f"snowflake://{self.user}:{self.api_key.resolve_value()}@{self.account}" + else: + # MFA authentication methods (no password in URI) + uri = f"snowflake://{self.user}@{self.account}" + + # Add database and schema if self.database: uri += f"/{self.database}" if self.db_schema: uri += f"/{self.db_schema}" - uri += "?" - if self.warehouse: - uri += f"warehouse={self.warehouse}&" - uri += f"login_timeout={self.login_timeout}&" - uri = uri.rstrip("&?") - # Logging placeholder for the actual password - masked_uri = uri - if resolved_api_key := self.api_key.resolve_value(): - masked_uri = uri.replace(resolved_api_key, "***REDACTED***") + # Add query parameters + params = [] + if self.warehouse: + params.append(f"warehouse={self.warehouse}") + params.append(f"login_timeout={self.login_timeout}") + + # Add authentication-specific parameters + if self.authenticator == "SNOWFLAKE_JWT": + params.append(f"authenticator={self.authenticator}") + if self.private_key_file: + private_key_path = self._resolve_secret_value(self.private_key_file) + params.append(f"private_key_file={private_key_path}") + if self.private_key_file_pwd: + private_key_pwd = self._resolve_secret_value(self.private_key_file_pwd) + params.append(f"private_key_file_pwd={private_key_pwd}") + elif self.authenticator == "OAUTH": + params.append(f"authenticator={self.authenticator}") + if self.oauth_client_id: + client_id = self._resolve_secret_value(self.oauth_client_id) + params.append(f"oauth_client_id={client_id}") + if self.oauth_client_secret: + client_secret = self._resolve_secret_value(self.oauth_client_secret) + params.append(f"oauth_client_secret={client_secret}") + if self.oauth_token_request_url: + params.append(f"oauth_token_request_url={self.oauth_token_request_url}") + if self.oauth_authorization_url: + params.append(f"oauth_authorization_url={self.oauth_authorization_url}") + + if params: + uri += "?" + "&".join(params) + + # Create masked URI for logging + masked_uri = self._create_masked_uri(uri) logger.info("Constructed Snowflake URI: {masked_uri}", masked_uri=masked_uri) return uri + def _create_masked_uri(self, uri: str) -> str: + """ + Creates a masked version of the URI for safe logging. + + :param uri: Original URI. + :returns: URI with sensitive information masked. + """ + masked_uri = uri + + # Mask password if present + if self.api_key and self.authenticator == "SNOWFLAKE": + if resolved_api_key := self.api_key.resolve_value(): + masked_uri = masked_uri.replace(resolved_api_key, "***REDACTED***") + + # Mask private key password + if self.private_key_file_pwd: + private_key_pwd = self._resolve_secret_value(self.private_key_file_pwd) + if private_key_pwd: + masked_uri = masked_uri.replace(private_key_pwd, "***REDACTED***") + + # Mask OAuth client secret + if self.oauth_client_secret: + client_secret = self._resolve_secret_value(self.oauth_client_secret) + if client_secret: + masked_uri = masked_uri.replace(client_secret, "***REDACTED***") + + return masked_uri + @staticmethod def _polars_to_md(data: pl.DataFrame) -> str: """ diff --git a/integrations/snowflake/tests/test_snowflake_table_retriever.py b/integrations/snowflake/tests/test_snowflake_table_retriever.py index 744d6ca163..d61b98cd83 100644 --- a/integrations/snowflake/tests/test_snowflake_table_retriever.py +++ b/integrations/snowflake/tests/test_snowflake_table_retriever.py @@ -49,6 +49,44 @@ def expected_markdown() -> str: return "shape: (3, 2)\n| col1 | col2 |\n|------|------|\n| 1 | A |\n| 2 | B |\n| 3 | C |" +@pytest.fixture +def jwt_retriever(mocker: Mock) -> SnowflakeTableRetriever: + mocker.patch.dict( + os.environ, {"SNOWFLAKE_PRIVATE_KEY_FILE": "/path/to/key.pem", "SNOWFLAKE_PRIVATE_KEY_PWD": "test_password"} + ) + return SnowflakeTableRetriever( + user="test_user", + account="test_account", + authenticator="SNOWFLAKE_JWT", + private_key_file=Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_FILE"), + private_key_file_pwd=Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_PWD"), + database="test_db", + db_schema="test_schema", + warehouse="test_warehouse", + return_markdown=True, + ) + + +@pytest.fixture +def oauth_retriever(mocker: Mock) -> SnowflakeTableRetriever: + mocker.patch.dict( + os.environ, + {"SNOWFLAKE_OAUTH_CLIENT_ID": "test_client_id", "SNOWFLAKE_OAUTH_CLIENT_SECRET": "test_client_secret"}, + ) + return SnowflakeTableRetriever( + user="test_user", + account="test_account", + authenticator="OAUTH", + oauth_client_id=Secret.from_env_var("SNOWFLAKE_OAUTH_CLIENT_ID"), + oauth_client_secret=Secret.from_env_var("SNOWFLAKE_OAUTH_CLIENT_SECRET"), + oauth_token_request_url="https://test.snowflakecomputing.com/oauth/token-request", + database="test_db", + db_schema="test_schema", + warehouse="test_warehouse", + return_markdown=True, + ) + + class TestSnowflakeTableRetriever: def test_init_and_serialization(self, retriever: SnowflakeTableRetriever) -> None: serialized = retriever.to_dict() @@ -345,3 +383,140 @@ def test_custom_login_timeout(self, mocker: Mock) -> None: uri = retriever._snowflake_uri_constructor() expected_uri = f"snowflake://test_user:test_api_key@test_account/test_db?login_timeout={custom_timeout}" assert uri == expected_uri + + def test_jwt_authentication_serialization(self, jwt_retriever: SnowflakeTableRetriever) -> None: + serialized = jwt_retriever.to_dict() + init_params = serialized["init_parameters"] + + assert init_params["authenticator"] == "SNOWFLAKE_JWT" + assert "private_key_file" in init_params + assert "private_key_file_pwd" in init_params + + deserialized = SnowflakeTableRetriever.from_dict(serialized) + assert isinstance(deserialized, SnowflakeTableRetriever) + assert deserialized.authenticator == "SNOWFLAKE_JWT" + + def test_oauth_authentication_serialization(self, oauth_retriever: SnowflakeTableRetriever) -> None: + serialized = oauth_retriever.to_dict() + init_params = serialized["init_parameters"] + + assert init_params["authenticator"] == "OAUTH" + assert "oauth_client_id" in init_params + assert "oauth_client_secret" in init_params + assert init_params["oauth_token_request_url"] == "https://test.snowflakecomputing.com/oauth/token-request" + + deserialized = SnowflakeTableRetriever.from_dict(serialized) + assert isinstance(deserialized, SnowflakeTableRetriever) + assert deserialized.authenticator == "OAUTH" + + def test_jwt_uri_construction(self, jwt_retriever: SnowflakeTableRetriever) -> None: + uri = jwt_retriever._snowflake_uri_constructor() + expected_uri = "snowflake://test_user@test_account/test_db/test_schema?warehouse=test_warehouse&login_timeout=60&authenticator=SNOWFLAKE_JWT&private_key_file=/path/to/key.pem&private_key_file_pwd=test_password" + assert uri == expected_uri + + def test_oauth_uri_construction(self, oauth_retriever: SnowflakeTableRetriever) -> None: + uri = oauth_retriever._snowflake_uri_constructor() + expected_uri = "snowflake://test_user@test_account/test_db/test_schema?warehouse=test_warehouse&login_timeout=60&authenticator=OAUTH&oauth_client_id=test_client_id&oauth_client_secret=test_client_secret&oauth_token_request_url=https://test.snowflakecomputing.com/oauth/token-request" + assert uri == expected_uri + + def test_masked_uri_logging_jwt(self, jwt_retriever: SnowflakeTableRetriever) -> None: + uri = jwt_retriever._snowflake_uri_constructor() + masked_uri = jwt_retriever._create_masked_uri(uri) + + assert "test_password" not in masked_uri + assert "***REDACTED***" in masked_uri + + def test_masked_uri_logging_oauth(self, oauth_retriever: SnowflakeTableRetriever) -> None: + uri = oauth_retriever._snowflake_uri_constructor() + masked_uri = oauth_retriever._create_masked_uri(uri) + + assert "test_client_secret" not in masked_uri + assert "***REDACTED***" in masked_uri + + @pytest.mark.parametrize( + "authenticator, missing_param, expected_error", + [ + ("SNOWFLAKE_JWT", "private_key_file", "private_key_file is required for SNOWFLAKE_JWT authentication"), + ( + "SNOWFLAKE_JWT", + "private_key_file_pwd", + "private_key_file_pwd is required for SNOWFLAKE_JWT authentication", + ), + ("OAUTH", "oauth_client_id", "oauth_client_id is required for OAUTH authentication"), + ("OAUTH", "oauth_client_secret", "oauth_client_secret is required for OAUTH authentication"), + ("SNOWFLAKE", "api_key", "api_key is required for SNOWFLAKE \\(password\\) authentication"), + ], + ) + def test_authentication_validation_errors( + self, mocker: Mock, authenticator: str, missing_param: str, expected_error: str + ) -> None: + # Set up environment variables, excluding the one being tested as missing + env_vars = { + "SNOWFLAKE_PRIVATE_KEY_FILE": "/path/to/key.pem", + "SNOWFLAKE_PRIVATE_KEY_PWD": "test_password", + "SNOWFLAKE_OAUTH_CLIENT_ID": "test_client_id", + "SNOWFLAKE_OAUTH_CLIENT_SECRET": "test_client_secret", + } + + # Only set SNOWFLAKE_API_KEY if we're not testing its absence + if not (authenticator == "SNOWFLAKE" and missing_param == "api_key"): + env_vars["SNOWFLAKE_API_KEY"] = "test_api_key" + + mocker.patch.dict(os.environ, env_vars, clear=True) + + kwargs = { + "user": "test_user", + "account": "test_account", + "authenticator": authenticator, + } + + if authenticator == "SNOWFLAKE_JWT": + if missing_param != "private_key_file": + kwargs["private_key_file"] = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_FILE") + if missing_param != "private_key_file_pwd": + kwargs["private_key_file_pwd"] = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_PWD") + elif authenticator == "OAUTH": + if missing_param != "oauth_client_id": + kwargs["oauth_client_id"] = Secret.from_env_var("SNOWFLAKE_OAUTH_CLIENT_ID") + if missing_param != "oauth_client_secret": + kwargs["oauth_client_secret"] = Secret.from_env_var("SNOWFLAKE_OAUTH_CLIENT_SECRET") + elif authenticator == "SNOWFLAKE": + if missing_param != "api_key": + kwargs["api_key"] = Secret.from_env_var("SNOWFLAKE_API_KEY") + + with pytest.raises(ValueError, match=expected_error): + SnowflakeTableRetriever(**kwargs) + + def test_jwt_authentication_happy_path( + self, + jwt_retriever: SnowflakeTableRetriever, + mocker: Mock, + toy_polars_df: pl.DataFrame, + toy_pandas_df: pd.DataFrame, + expected_markdown: str, + ) -> None: + mocker.patch("polars.read_database_uri", return_value=toy_polars_df) + mocker.patch.object(toy_polars_df, "to_pandas", return_value=toy_pandas_df) + mocker.patch.object(SnowflakeTableRetriever, "_polars_to_md", return_value=expected_markdown) + + result = jwt_retriever.run(query="SELECT * FROM table_name") + + assert result["dataframe"].equals(toy_pandas_df) + assert result["table"] == expected_markdown + + def test_oauth_authentication_happy_path( + self, + oauth_retriever: SnowflakeTableRetriever, + mocker: Mock, + toy_polars_df: pl.DataFrame, + toy_pandas_df: pd.DataFrame, + expected_markdown: str, + ) -> None: + mocker.patch("polars.read_database_uri", return_value=toy_polars_df) + mocker.patch.object(toy_polars_df, "to_pandas", return_value=toy_pandas_df) + mocker.patch.object(SnowflakeTableRetriever, "_polars_to_md", return_value=expected_markdown) + + result = oauth_retriever.run(query="SELECT * FROM table_name") + + assert result["dataframe"].equals(toy_pandas_df) + assert result["table"] == expected_markdown