-
Notifications
You must be signed in to change notification settings - Fork 128
Connections Pane: AWS Redshift Support #10639
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6d3fbbe
3163088
c64bfaa
034dab9
48e90be
463e90f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -301,7 +301,6 @@ def _wrap_connection(self, obj: Any) -> Connection: | |||||
| if not self.object_is_supported(obj): | ||||||
| type_name = type(obj).__name__ | ||||||
| raise UnsupportedConnectionError(f"Unsupported connection type {type_name}") | ||||||
|
|
||||||
| if safe_isinstance(obj, "sqlite3", "Connection"): | ||||||
| return SQLite3Connection(obj) | ||||||
| elif safe_isinstance(obj, "sqlalchemy", "Engine"): | ||||||
|
|
@@ -312,6 +311,8 @@ def _wrap_connection(self, obj: Any) -> Connection: | |||||
| return SnowflakeConnection(obj) | ||||||
| elif safe_isinstance(obj, "databricks.sql.client", "Connection"): | ||||||
| return DatabricksConnection(obj) | ||||||
| elif safe_isinstance(obj, "redshift_connector", "Connection"): | ||||||
| return RedshiftConnection(obj) | ||||||
| else: | ||||||
| type_name = type(obj).__name__ | ||||||
| raise UnsupportedConnectionError(f"Unsupported connection type {type(obj)}") | ||||||
|
|
@@ -327,6 +328,7 @@ def object_is_supported(self, obj: Any) -> bool: | |||||
| or safe_isinstance(obj, "duckdb", "DuckDBPyConnection") | ||||||
| or safe_isinstance(obj, "snowflake.connector", "SnowflakeConnection") | ||||||
| or safe_isinstance(obj, "databricks.sql.client", "Connection") | ||||||
| or safe_isinstance(obj, "redshift_connector", "Connection") | ||||||
| ) | ||||||
| except Exception as err: | ||||||
| logger.error(f"Error checking supported {err}") | ||||||
|
|
@@ -1200,3 +1202,178 @@ def _make_code(self) -> str: | |||||
| ")\n" | ||||||
| "%connection_show con\n" | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| class RedshiftConnection(Connection): | ||||||
| """Support for Redshift connections to databases.""" | ||||||
|
|
||||||
| def __init__(self, conn: Any): | ||||||
| self.conn = conn | ||||||
|
|
||||||
| try: | ||||||
| # Unfortunatelly there's no public API to get the host, so we access the protected member. | ||||||
| # to at least provide some info in the connection display name. | ||||||
| host, _ = conn._usock.getpeername() # noqa: SLF001 | ||||||
| except AttributeError: | ||||||
| host = "<unknown>" | ||||||
|
|
||||||
| self.host = str(host) | ||||||
|
|
||||||
| self.display_name = f"Redshift ({self.host})" | ||||||
| self.type = "Redshift" | ||||||
| self.code = self._make_code() | ||||||
|
|
||||||
| self.icon = "" | ||||||
|
|
||||||
| def disconnect(self): | ||||||
| with contextlib.suppress(Exception): | ||||||
| self.conn.close() | ||||||
|
|
||||||
| def list_object_types(self): | ||||||
| return { | ||||||
| "database": ConnectionObjectInfo({"contains": None, "icon": None}), | ||||||
| "schema": ConnectionObjectInfo({"contains": None, "icon": None}), | ||||||
| "table": ConnectionObjectInfo({"contains": "data", "icon": None}), | ||||||
| "view": ConnectionObjectInfo({"contains": "data", "icon": None}), | ||||||
| } | ||||||
|
|
||||||
| def list_objects(self, path: list[ObjectSchema]): | ||||||
| if len(path) == 0: | ||||||
| rows = self._query("SHOW DATABASES;") | ||||||
| return [ | ||||||
| ConnectionObject({"name": row["database_name"], "kind": "database"}) for row in rows | ||||||
| ] | ||||||
|
|
||||||
| if len(path) == 1: | ||||||
| database = path[0] | ||||||
| if database.kind != "database": | ||||||
| raise ValueError("Expected database on path position 0.", f"Path: {path}") | ||||||
| database_ident = self._qualify(database.name) | ||||||
| rows = self._query(f"SHOW SCHEMAS FROM DATABASE {database_ident};") | ||||||
| return [ | ||||||
| ConnectionObject( | ||||||
| { | ||||||
| "name": row["schema_name"], | ||||||
| "kind": "schema", | ||||||
| } | ||||||
| ) | ||||||
| for row in rows | ||||||
| ] | ||||||
|
|
||||||
| if len(path) == 2: | ||||||
| database, schema = path | ||||||
| if database.kind != "database" or schema.kind != "schema": | ||||||
| raise ValueError( | ||||||
| "Expected database and schema objects at positions 0 and 1.", f"Path: {path}" | ||||||
| ) | ||||||
| location = f"{self._qualify(database.name)}.{self._qualify(schema.name)}" | ||||||
| tables = self._query(f"SHOW TABLES FROM SCHEMA {location};") | ||||||
| return [ | ||||||
| ConnectionObject( | ||||||
| { | ||||||
| "name": row["table_name"], | ||||||
| "kind": row["table_type"].lower(), | ||||||
| } | ||||||
| ) | ||||||
| for row in tables | ||||||
| ] | ||||||
|
|
||||||
| raise ValueError(f"Path length must be at most 2, but got {len(path)}. Path: {path}") | ||||||
|
|
||||||
| def list_fields(self, path: list[ObjectSchema]): | ||||||
| if len(path) != 3: | ||||||
| raise ValueError(f"Path length must be 3, but got {len(path)}. Path: {path}") | ||||||
|
|
||||||
| database, schema, table = path | ||||||
| if ( | ||||||
| database.kind != "database" | ||||||
| or schema.kind != "schema" | ||||||
| or table.kind not in ("table", "view") | ||||||
| ): | ||||||
| raise ValueError( | ||||||
| "Expected database, schema, and table/view kinds in the path.", | ||||||
| f"Path: {path}", | ||||||
| ) | ||||||
|
|
||||||
| identifier = ".".join( | ||||||
| [self._qualify(database.name), self._qualify(schema.name), self._qualify(table.name)] | ||||||
| ) | ||||||
| rows = self._query(f"SHOW COLUMNS FROM TABLE {identifier};") | ||||||
| return [ | ||||||
| ConnectionObjectFields( | ||||||
| { | ||||||
| "name": row["column_name"], | ||||||
| "dtype": row["data_type"], | ||||||
| } | ||||||
| ) | ||||||
| for row in rows | ||||||
| ] | ||||||
|
|
||||||
| def preview_object(self, path: list[ObjectSchema], var_name: str | None = None): | ||||||
| if len(path) != 3: | ||||||
| raise ValueError(f"Path length must be 3, but got {len(path)}. Path: {path}") | ||||||
|
|
||||||
| database, schema, table = path | ||||||
| if ( | ||||||
| database.kind != "database" | ||||||
| or schema.kind != "schema" | ||||||
| or table.kind not in ("table", "view") | ||||||
| ): | ||||||
| raise ValueError( | ||||||
| "Expected database, schema, and table/view kinds in the path.", | ||||||
| f"Path: {path}", | ||||||
| ) | ||||||
|
|
||||||
| identifier = ".".join( | ||||||
| [self._qualify(database.name), self._qualify(schema.name), self._qualify(table.name)] | ||||||
| ) | ||||||
| sql = f"SELECT * FROM {identifier} LIMIT 1000;" | ||||||
|
|
||||||
| with self.conn.cursor() as cursor: | ||||||
| try: | ||||||
| cursor.execute(sql) | ||||||
| frame = cursor.fetch_dataframe() | ||||||
| except Exception as e: | ||||||
| # Rollback on error to avoid transaction issues | ||||||
| # for subsequent queries | ||||||
| self.conn.rollback() | ||||||
| raise e | ||||||
|
||||||
|
|
||||||
| var_name = var_name or "conn" | ||||||
| return frame, ( | ||||||
| f"with {var_name}.cursor() as cursor:\n" | ||||||
| f" cursor.execute({sql!r})\n" | ||||||
| f" {table.name} = cursor.fetch_dataframe()" | ||||||
| ) | ||||||
|
|
||||||
| def _query(self, sql: str) -> list[dict[str, Any]]: | ||||||
| cursor = self.conn.cursor() | ||||||
| try: | ||||||
| cursor.execute(sql) | ||||||
| rows = cursor.fetchall() | ||||||
| description = cursor.description or [] | ||||||
| columns = [col[0] for col in description] | ||||||
| return [dict(zip(columns, row)) for row in rows] | ||||||
| except Exception as e: | ||||||
| # Rollback on error to avoid transaction issues | ||||||
| # for subsequent queries | ||||||
| self.conn.rollback() | ||||||
| raise e | ||||||
|
||||||
| finally: | ||||||
| cursor.close() | ||||||
|
|
||||||
| def _qualify(self, identifier: str) -> str: | ||||||
| escaped = identifier.replace('"', '""') | ||||||
| return f'"{escaped}"' | ||||||
|
|
||||||
| def _make_code(self) -> str: | ||||||
| return ( | ||||||
| "# Requires redshift-connector package\n" | ||||||
| "# Authentication steps may be incomplete, adjust as needed.\n" | ||||||
| "import redshift_connector\n" | ||||||
| "con = redshift_connector.connect(\n" | ||||||
| f" iam = True,\n" | ||||||
|
||||||
| f" iam = True,\n" | |
| " iam = True,\n" |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -43,6 +43,14 @@ | |||||
| except ImportError: | ||||||
| HAS_DATABRICKS = False | ||||||
|
|
||||||
|
|
||||||
| try: | ||||||
| import redshift_connector | ||||||
|
|
||||||
| HAS_REDSHIFT = "REDSHIFT_HOST" in os.environ | ||||||
| except ImportError: | ||||||
| HAS_REDSHIFT = False | ||||||
|
|
||||||
| from positron.access_keys import encode_access_key | ||||||
| from positron.connections import ConnectionsService | ||||||
|
|
||||||
|
|
@@ -706,3 +714,135 @@ def _view_in_connections_pane(self, variables_comm: DummyComm, path): | |||||
| assert variables_comm.messages == [json_rpc_response({})] | ||||||
| variables_comm.messages.clear() | ||||||
| return tuple(encoded_paths) | ||||||
|
|
||||||
|
|
||||||
| @pytest.mark.skipif(not HAS_REDSHIFT, reason="Redshift not available") | ||||||
| class TestRedshiftConnectionsService: | ||||||
| REDSHIFT_HOST = os.environ.get("REDSHIFT_HOST") | ||||||
| REDSHIFT_PROFILE = os.environ.get("REDSHIFT_PROFILE", "default") | ||||||
| REDSHIFT_DATABASE = "dev" | ||||||
| REDSHIFT_SCHEMA = "public" | ||||||
| REDSHIFT_TABLE = "airlines" | ||||||
|
|
||||||
| def _connect(self): | ||||||
| return redshift_connector.connect( | ||||||
| iam=True, | ||||||
| host=self.REDSHIFT_HOST, | ||||||
| database=self.REDSHIFT_DATABASE, | ||||||
| profile=self.REDSHIFT_PROFILE, | ||||||
| ) | ||||||
|
|
||||||
| def _open_comm(self, connections_service: ConnectionsService): | ||||||
| con = self._connect() | ||||||
| comm_id = connections_service.register_connection(con) | ||||||
| dummy_comm = DummyComm(TARGET_NAME, comm_id=comm_id) | ||||||
| connections_service.on_comm_open(dummy_comm) | ||||||
| dummy_comm.messages.clear() | ||||||
| return dummy_comm, comm_id | ||||||
|
|
||||||
| def _database_path(self): | ||||||
| return [{"kind": "database", "name": self.REDSHIFT_DATABASE}] | ||||||
|
|
||||||
| def _schema_path(self): | ||||||
| return [*self._database_path(), {"kind": "schema", "name": self.REDSHIFT_SCHEMA}] | ||||||
|
|
||||||
| def _table_path(self): | ||||||
| return [*self._schema_path(), {"kind": "table", "name": self.REDSHIFT_TABLE}] | ||||||
|
|
||||||
| def _resolve_path(self, kind: str): | ||||||
| if kind == "root": | ||||||
| return [] | ||||||
| if kind == "database": | ||||||
| return self._database_path() | ||||||
| if kind == "schema": | ||||||
| return self._schema_path() | ||||||
| if kind == "table": | ||||||
| return self._table_path() | ||||||
| raise ValueError(f"Unknown path kind: {kind}") | ||||||
|
|
||||||
| def test_register_connection(self, connections_service: ConnectionsService): | ||||||
| con = self._connect() | ||||||
| comm_id = connections_service.register_connection(con) | ||||||
| assert comm_id in connections_service.comms | ||||||
|
|
||||||
| @pytest.mark.parametrize( | ||||||
| ("path_kind"), | ||||||
|
||||||
| ("path_kind"), | |
| "path_kind", |
Copilot
AI
Nov 21, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] Inconsistent boolean assertion. This test uses assert result == (path_kind == "table") while the equivalent Databricks test at line 397 uses assert result is (path_kind == "table"). For consistency with the existing codebase, consider using is instead of == when comparing boolean values.
| assert result == (path_kind == "table") | |
| assert result is (path_kind == "table") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Spelling error: "Unfortunatelly" should be "Unfortunately".