diff --git a/extensions/positron-python/python_files/posit/pinned-test-requirements.txt b/extensions/positron-python/python_files/posit/pinned-test-requirements.txt index e3296784a2ee..de9a092548e1 100644 --- a/extensions/positron-python/python_files/posit/pinned-test-requirements.txt +++ b/extensions/positron-python/python_files/posit/pinned-test-requirements.txt @@ -44,6 +44,7 @@ pyarrow==21.0.0; python_version < '3.14' pytest==8.4.2 pytest-asyncio==1.2.0 pytest-mock==3.15.1 +redshift_connector==2.1.10; python_version < '3.14' syrupy==4.9.1; python_version == '3.9' syrupy==5.0.0; python_version >= '3.10' torch==2.8.0; python_version == '3.9' diff --git a/extensions/positron-python/python_files/posit/positron/connections.py b/extensions/positron-python/python_files/posit/positron/connections.py index 9931b2620468..8060c55d481e 100644 --- a/extensions/positron-python/python_files/posit/positron/connections.py +++ b/extensions/positron-python/python_files/posit/positron/connections.py @@ -301,7 +301,6 @@ def _wrap_connection(self, obj: Any) -> Connection: if not self.object_is_supported(obj): type_name = type(obj).__name__ raise UnsupportedConnectionError(f"Unsupported connection type {type_name}") - if safe_isinstance(obj, "sqlite3", "Connection"): return SQLite3Connection(obj) elif safe_isinstance(obj, "sqlalchemy", "Engine"): @@ -312,6 +311,8 @@ def _wrap_connection(self, obj: Any) -> Connection: return SnowflakeConnection(obj) elif safe_isinstance(obj, "databricks.sql.client", "Connection"): return DatabricksConnection(obj) + elif safe_isinstance(obj, "redshift_connector", "Connection"): + return RedshiftConnection(obj) else: type_name = type(obj).__name__ raise UnsupportedConnectionError(f"Unsupported connection type {type(obj)}") @@ -327,6 +328,7 @@ def object_is_supported(self, obj: Any) -> bool: or safe_isinstance(obj, "duckdb", "DuckDBPyConnection") or safe_isinstance(obj, "snowflake.connector", "SnowflakeConnection") or safe_isinstance(obj, "databricks.sql.client", "Connection") + or safe_isinstance(obj, "redshift_connector", "Connection") ) except Exception as err: logger.error(f"Error checking supported {err}") @@ -1200,3 +1202,178 @@ def _make_code(self) -> str: ")\n" "%connection_show con\n" ) + + +class RedshiftConnection(Connection): + """Support for Redshift connections to databases.""" + + def __init__(self, conn: Any): + self.conn = conn + + try: + # Unfortunatelly there's no public API to get the host, so we access the protected member. + # to at least provide some info in the connection display name. + host, _ = conn._usock.getpeername() # noqa: SLF001 + except AttributeError: + host = "" + + self.host = str(host) + + self.display_name = f"Redshift ({self.host})" + self.type = "Redshift" + self.code = self._make_code() + + self.icon = "" + + def disconnect(self): + with contextlib.suppress(Exception): + self.conn.close() + + def list_object_types(self): + return { + "database": ConnectionObjectInfo({"contains": None, "icon": None}), + "schema": ConnectionObjectInfo({"contains": None, "icon": None}), + "table": ConnectionObjectInfo({"contains": "data", "icon": None}), + "view": ConnectionObjectInfo({"contains": "data", "icon": None}), + } + + def list_objects(self, path: list[ObjectSchema]): + if len(path) == 0: + rows = self._query("SHOW DATABASES;") + return [ + ConnectionObject({"name": row["database_name"], "kind": "database"}) for row in rows + ] + + if len(path) == 1: + database = path[0] + if database.kind != "database": + raise ValueError("Expected database on path position 0.", f"Path: {path}") + database_ident = self._qualify(database.name) + rows = self._query(f"SHOW SCHEMAS FROM DATABASE {database_ident};") + return [ + ConnectionObject( + { + "name": row["schema_name"], + "kind": "schema", + } + ) + for row in rows + ] + + if len(path) == 2: + database, schema = path + if database.kind != "database" or schema.kind != "schema": + raise ValueError( + "Expected database and schema objects at positions 0 and 1.", f"Path: {path}" + ) + location = f"{self._qualify(database.name)}.{self._qualify(schema.name)}" + tables = self._query(f"SHOW TABLES FROM SCHEMA {location};") + return [ + ConnectionObject( + { + "name": row["table_name"], + "kind": row["table_type"].lower(), + } + ) + for row in tables + ] + + raise ValueError(f"Path length must be at most 2, but got {len(path)}. Path: {path}") + + def list_fields(self, path: list[ObjectSchema]): + if len(path) != 3: + raise ValueError(f"Path length must be 3, but got {len(path)}. Path: {path}") + + database, schema, table = path + if ( + database.kind != "database" + or schema.kind != "schema" + or table.kind not in ("table", "view") + ): + raise ValueError( + "Expected database, schema, and table/view kinds in the path.", + f"Path: {path}", + ) + + identifier = ".".join( + [self._qualify(database.name), self._qualify(schema.name), self._qualify(table.name)] + ) + rows = self._query(f"SHOW COLUMNS FROM TABLE {identifier};") + return [ + ConnectionObjectFields( + { + "name": row["column_name"], + "dtype": row["data_type"], + } + ) + for row in rows + ] + + def preview_object(self, path: list[ObjectSchema], var_name: str | None = None): + if len(path) != 3: + raise ValueError(f"Path length must be 3, but got {len(path)}. Path: {path}") + + database, schema, table = path + if ( + database.kind != "database" + or schema.kind != "schema" + or table.kind not in ("table", "view") + ): + raise ValueError( + "Expected database, schema, and table/view kinds in the path.", + f"Path: {path}", + ) + + identifier = ".".join( + [self._qualify(database.name), self._qualify(schema.name), self._qualify(table.name)] + ) + sql = f"SELECT * FROM {identifier} LIMIT 1000;" + + with self.conn.cursor() as cursor: + try: + cursor.execute(sql) + frame = cursor.fetch_dataframe() + except Exception as e: + # Rollback on error to avoid transaction issues + # for subsequent queries + self.conn.rollback() + raise e + + var_name = var_name or "conn" + return frame, ( + f"with {var_name}.cursor() as cursor:\n" + f" cursor.execute({sql!r})\n" + f" {table.name} = cursor.fetch_dataframe()" + ) + + def _query(self, sql: str) -> list[dict[str, Any]]: + cursor = self.conn.cursor() + try: + cursor.execute(sql) + rows = cursor.fetchall() + description = cursor.description or [] + columns = [col[0] for col in description] + return [dict(zip(columns, row)) for row in rows] + except Exception as e: + # Rollback on error to avoid transaction issues + # for subsequent queries + self.conn.rollback() + raise e + finally: + cursor.close() + + def _qualify(self, identifier: str) -> str: + escaped = identifier.replace('"', '""') + return f'"{escaped}"' + + def _make_code(self) -> str: + return ( + "# Requires redshift-connector package\n" + "# Authentication steps may be incomplete, adjust as needed.\n" + "import redshift_connector\n" + "con = redshift_connector.connect(\n" + f" iam = True,\n" + f" host = '{self.host}',\n" + ")\n" + "%connection_show con\n" + ) diff --git a/extensions/positron-python/python_files/posit/positron/tests/test_connections.py b/extensions/positron-python/python_files/posit/positron/tests/test_connections.py index 37ab01e5975e..e2ec028c0de9 100644 --- a/extensions/positron-python/python_files/posit/positron/tests/test_connections.py +++ b/extensions/positron-python/python_files/posit/positron/tests/test_connections.py @@ -43,6 +43,14 @@ except ImportError: HAS_DATABRICKS = False + +try: + import redshift_connector + + HAS_REDSHIFT = "REDSHIFT_HOST" in os.environ +except ImportError: + HAS_REDSHIFT = False + from positron.access_keys import encode_access_key from positron.connections import ConnectionsService @@ -706,3 +714,135 @@ def _view_in_connections_pane(self, variables_comm: DummyComm, path): assert variables_comm.messages == [json_rpc_response({})] variables_comm.messages.clear() return tuple(encoded_paths) + + +@pytest.mark.skipif(not HAS_REDSHIFT, reason="Redshift not available") +class TestRedshiftConnectionsService: + REDSHIFT_HOST = os.environ.get("REDSHIFT_HOST") + REDSHIFT_PROFILE = os.environ.get("REDSHIFT_PROFILE", "default") + REDSHIFT_DATABASE = "dev" + REDSHIFT_SCHEMA = "public" + REDSHIFT_TABLE = "airlines" + + def _connect(self): + return redshift_connector.connect( + iam=True, + host=self.REDSHIFT_HOST, + database=self.REDSHIFT_DATABASE, + profile=self.REDSHIFT_PROFILE, + ) + + def _open_comm(self, connections_service: ConnectionsService): + con = self._connect() + comm_id = connections_service.register_connection(con) + dummy_comm = DummyComm(TARGET_NAME, comm_id=comm_id) + connections_service.on_comm_open(dummy_comm) + dummy_comm.messages.clear() + return dummy_comm, comm_id + + def _database_path(self): + return [{"kind": "database", "name": self.REDSHIFT_DATABASE}] + + def _schema_path(self): + return [*self._database_path(), {"kind": "schema", "name": self.REDSHIFT_SCHEMA}] + + def _table_path(self): + return [*self._schema_path(), {"kind": "table", "name": self.REDSHIFT_TABLE}] + + def _resolve_path(self, kind: str): + if kind == "root": + return [] + if kind == "database": + return self._database_path() + if kind == "schema": + return self._schema_path() + if kind == "table": + return self._table_path() + raise ValueError(f"Unknown path kind: {kind}") + + def test_register_connection(self, connections_service: ConnectionsService): + con = self._connect() + comm_id = connections_service.register_connection(con) + assert comm_id in connections_service.comms + + @pytest.mark.parametrize( + ("path_kind"), + [ + pytest.param("root", id="root"), + pytest.param("database", id="database"), + pytest.param("schema", id="schema"), + pytest.param("table", id="table"), + ], + ) + def test_contains_data(self, connections_service: ConnectionsService, path_kind: str): + dummy_comm, comm_id = self._open_comm(connections_service) + path = self._resolve_path(path_kind) + + msg = _make_msg(params={"path": path}, method="contains_data", comm_id=comm_id) + dummy_comm.handle_msg(msg) + result = dummy_comm.messages[0]["data"]["result"] + assert result == (path_kind == "table") + + @pytest.mark.parametrize( + ("path_kind", "expected"), + [ + pytest.param("root", "data:image", id="root"), + pytest.param("database", "", id="database"), + pytest.param("schema", "", id="schema"), + pytest.param("table", "", id="table"), + ], + ) + def test_get_icon(self, connections_service: ConnectionsService, path_kind: str, expected: str): + dummy_comm, comm_id = self._open_comm(connections_service) + path = self._resolve_path(path_kind) + + msg = _make_msg(params={"path": path}, method="get_icon", comm_id=comm_id) + dummy_comm.handle_msg(msg) + result = dummy_comm.messages[0]["data"]["result"] + if expected: + assert expected in result + else: + assert result == "" + + @pytest.mark.parametrize( + "path_kind", + [ + pytest.param("root", id="databases"), + pytest.param("database", id="schemas"), + pytest.param("schema", id="tables"), + ], + ) + def test_list_objects(self, connections_service: ConnectionsService, path_kind: str): + dummy_comm, comm_id = self._open_comm(connections_service) + path = self._resolve_path(path_kind) + expected = { + "root": self.REDSHIFT_DATABASE, + "database": self.REDSHIFT_SCHEMA, + "schema": self.REDSHIFT_TABLE, + }[path_kind] + + msg = _make_msg(params={"path": path}, method="list_objects", comm_id=comm_id) + dummy_comm.handle_msg(msg) + result = dummy_comm.messages[0]["data"]["result"] + names = [item["name"] for item in result] + assert expected in names + + def test_list_fields(self, connections_service: ConnectionsService): + dummy_comm, comm_id = self._open_comm(connections_service) + path = self._table_path() + + msg = _make_msg(params={"path": path}, method="list_fields", comm_id=comm_id) + dummy_comm.handle_msg(msg) + result = dummy_comm.messages[0]["data"]["result"] + field_names = {field["name"].lower() for field in result} + assert {"carrier", "name"}.issubset(field_names) + + def test_preview_object(self, connections_service: ConnectionsService): + dummy_comm, comm_id = self._open_comm(connections_service) + path = self._table_path() + + msg = _make_msg(params={"path": path}, method="preview_object", comm_id=comm_id) + dummy_comm.handle_msg(msg) + connections_service._kernel.data_explorer_service.shutdown() # noqa: SLF001 + result = dummy_comm.messages[0]["data"]["result"] + assert result is None diff --git a/extensions/positron-python/python_files/posit/test-requirements.txt b/extensions/positron-python/python_files/posit/test-requirements.txt index ce3a6a124995..82d225c0c2ed 100644 --- a/extensions/positron-python/python_files/posit/test-requirements.txt +++ b/extensions/positron-python/python_files/posit/test-requirements.txt @@ -23,6 +23,7 @@ pytest pytest-asyncio pytest-mock syrupy +redshift_connector torch scipy snowflake-connector-python; python_version < '3.14'