Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pyarrow==21.0.0; python_version < '3.14'
pytest==8.4.2
pytest-asyncio==1.2.0
pytest-mock==3.15.1
redshift_connector==2.1.10; python_version < '3.14'
syrupy==4.9.1; python_version == '3.9'
syrupy==5.0.0; python_version >= '3.10'
torch==2.8.0; python_version == '3.9'
Expand Down
179 changes: 178 additions & 1 deletion extensions/positron-python/python_files/posit/positron/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,6 @@ def _wrap_connection(self, obj: Any) -> Connection:
if not self.object_is_supported(obj):
type_name = type(obj).__name__
raise UnsupportedConnectionError(f"Unsupported connection type {type_name}")

if safe_isinstance(obj, "sqlite3", "Connection"):
return SQLite3Connection(obj)
elif safe_isinstance(obj, "sqlalchemy", "Engine"):
Expand All @@ -312,6 +311,8 @@ def _wrap_connection(self, obj: Any) -> Connection:
return SnowflakeConnection(obj)
elif safe_isinstance(obj, "databricks.sql.client", "Connection"):
return DatabricksConnection(obj)
elif safe_isinstance(obj, "redshift_connector", "Connection"):
return RedshiftConnection(obj)
else:
type_name = type(obj).__name__
raise UnsupportedConnectionError(f"Unsupported connection type {type(obj)}")
Expand All @@ -327,6 +328,7 @@ def object_is_supported(self, obj: Any) -> bool:
or safe_isinstance(obj, "duckdb", "DuckDBPyConnection")
or safe_isinstance(obj, "snowflake.connector", "SnowflakeConnection")
or safe_isinstance(obj, "databricks.sql.client", "Connection")
or safe_isinstance(obj, "redshift_connector", "Connection")
)
except Exception as err:
logger.error(f"Error checking supported {err}")
Expand Down Expand Up @@ -1200,3 +1202,178 @@ def _make_code(self) -> str:
")\n"
"%connection_show con\n"
)


class RedshiftConnection(Connection):
"""Support for Redshift connections to databases."""

def __init__(self, conn: Any):
self.conn = conn

try:
# Unfortunatelly there's no public API to get the host, so we access the protected member.
Copy link

Copilot AI Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spelling error: "Unfortunatelly" should be "Unfortunately".

Suggested change
# Unfortunatelly there's no public API to get the host, so we access the protected member.
# Unfortunately there's no public API to get the host, so we access the protected member.

Copilot uses AI. Check for mistakes.
# to at least provide some info in the connection display name.
host, _ = conn._usock.getpeername() # noqa: SLF001
except AttributeError:
host = "<unknown>"

self.host = str(host)

self.display_name = f"Redshift ({self.host})"
self.type = "Redshift"
self.code = self._make_code()

self.icon = ""

def disconnect(self):
with contextlib.suppress(Exception):
self.conn.close()

def list_object_types(self):
return {
"database": ConnectionObjectInfo({"contains": None, "icon": None}),
"schema": ConnectionObjectInfo({"contains": None, "icon": None}),
"table": ConnectionObjectInfo({"contains": "data", "icon": None}),
"view": ConnectionObjectInfo({"contains": "data", "icon": None}),
}

def list_objects(self, path: list[ObjectSchema]):
if len(path) == 0:
rows = self._query("SHOW DATABASES;")
return [
ConnectionObject({"name": row["database_name"], "kind": "database"}) for row in rows
]

if len(path) == 1:
database = path[0]
if database.kind != "database":
raise ValueError("Expected database on path position 0.", f"Path: {path}")
database_ident = self._qualify(database.name)
rows = self._query(f"SHOW SCHEMAS FROM DATABASE {database_ident};")
return [
ConnectionObject(
{
"name": row["schema_name"],
"kind": "schema",
}
)
for row in rows
]

if len(path) == 2:
database, schema = path
if database.kind != "database" or schema.kind != "schema":
raise ValueError(
"Expected database and schema objects at positions 0 and 1.", f"Path: {path}"
)
location = f"{self._qualify(database.name)}.{self._qualify(schema.name)}"
tables = self._query(f"SHOW TABLES FROM SCHEMA {location};")
return [
ConnectionObject(
{
"name": row["table_name"],
"kind": row["table_type"].lower(),
}
)
for row in tables
]

raise ValueError(f"Path length must be at most 2, but got {len(path)}. Path: {path}")

def list_fields(self, path: list[ObjectSchema]):
if len(path) != 3:
raise ValueError(f"Path length must be 3, but got {len(path)}. Path: {path}")

database, schema, table = path
if (
database.kind != "database"
or schema.kind != "schema"
or table.kind not in ("table", "view")
):
raise ValueError(
"Expected database, schema, and table/view kinds in the path.",
f"Path: {path}",
)

identifier = ".".join(
[self._qualify(database.name), self._qualify(schema.name), self._qualify(table.name)]
)
rows = self._query(f"SHOW COLUMNS FROM TABLE {identifier};")
return [
ConnectionObjectFields(
{
"name": row["column_name"],
"dtype": row["data_type"],
}
)
for row in rows
]

def preview_object(self, path: list[ObjectSchema], var_name: str | None = None):
if len(path) != 3:
raise ValueError(f"Path length must be 3, but got {len(path)}. Path: {path}")

database, schema, table = path
if (
database.kind != "database"
or schema.kind != "schema"
or table.kind not in ("table", "view")
):
raise ValueError(
"Expected database, schema, and table/view kinds in the path.",
f"Path: {path}",
)

identifier = ".".join(
[self._qualify(database.name), self._qualify(schema.name), self._qualify(table.name)]
)
sql = f"SELECT * FROM {identifier} LIMIT 1000;"

with self.conn.cursor() as cursor:
try:
cursor.execute(sql)
frame = cursor.fetch_dataframe()
except Exception as e:
# Rollback on error to avoid transaction issues
# for subsequent queries
self.conn.rollback()
raise e
Copy link

Copilot AI Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Use raise instead of raise e to re-raise the exception. This is more Pythonic and preserves the full traceback. The bare raise statement re-raises the current exception without modification.

Copilot uses AI. Check for mistakes.

var_name = var_name or "conn"
return frame, (
f"with {var_name}.cursor() as cursor:\n"
f" cursor.execute({sql!r})\n"
f" {table.name} = cursor.fetch_dataframe()"
)

def _query(self, sql: str) -> list[dict[str, Any]]:
cursor = self.conn.cursor()
try:
cursor.execute(sql)
rows = cursor.fetchall()
description = cursor.description or []
columns = [col[0] for col in description]
return [dict(zip(columns, row)) for row in rows]
except Exception as e:
# Rollback on error to avoid transaction issues
# for subsequent queries
self.conn.rollback()
raise e
Copy link

Copilot AI Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Use raise instead of raise e to re-raise the exception. This is more Pythonic and preserves the full traceback. The bare raise statement re-raises the current exception without modification.

Copilot uses AI. Check for mistakes.
finally:
cursor.close()

def _qualify(self, identifier: str) -> str:
escaped = identifier.replace('"', '""')
return f'"{escaped}"'

def _make_code(self) -> str:
return (
"# Requires redshift-connector package\n"
"# Authentication steps may be incomplete, adjust as needed.\n"
"import redshift_connector\n"
"con = redshift_connector.connect(\n"
f" iam = True,\n"
Copy link

Copilot AI Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary f-string prefix. This line doesn't contain any interpolated values, so it should be a regular string: " iam = True,\n" instead of f" iam = True,\n". This is consistent with the pattern used in other connection implementations (e.g., DatabricksConnection at lines 1198-1199, 1202-1203).

Suggested change
f" iam = True,\n"
" iam = True,\n"

Copilot uses AI. Check for mistakes.
f" host = '{self.host}',\n"
")\n"
"%connection_show con\n"
)
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@
except ImportError:
HAS_DATABRICKS = False


try:
import redshift_connector

HAS_REDSHIFT = "REDSHIFT_HOST" in os.environ
except ImportError:
HAS_REDSHIFT = False

from positron.access_keys import encode_access_key
from positron.connections import ConnectionsService

Expand Down Expand Up @@ -706,3 +714,135 @@ def _view_in_connections_pane(self, variables_comm: DummyComm, path):
assert variables_comm.messages == [json_rpc_response({})]
variables_comm.messages.clear()
return tuple(encoded_paths)


@pytest.mark.skipif(not HAS_REDSHIFT, reason="Redshift not available")
class TestRedshiftConnectionsService:
REDSHIFT_HOST = os.environ.get("REDSHIFT_HOST")
REDSHIFT_PROFILE = os.environ.get("REDSHIFT_PROFILE", "default")
REDSHIFT_DATABASE = "dev"
REDSHIFT_SCHEMA = "public"
REDSHIFT_TABLE = "airlines"

def _connect(self):
return redshift_connector.connect(
iam=True,
host=self.REDSHIFT_HOST,
database=self.REDSHIFT_DATABASE,
profile=self.REDSHIFT_PROFILE,
)

def _open_comm(self, connections_service: ConnectionsService):
con = self._connect()
comm_id = connections_service.register_connection(con)
dummy_comm = DummyComm(TARGET_NAME, comm_id=comm_id)
connections_service.on_comm_open(dummy_comm)
dummy_comm.messages.clear()
return dummy_comm, comm_id

def _database_path(self):
return [{"kind": "database", "name": self.REDSHIFT_DATABASE}]

def _schema_path(self):
return [*self._database_path(), {"kind": "schema", "name": self.REDSHIFT_SCHEMA}]

def _table_path(self):
return [*self._schema_path(), {"kind": "table", "name": self.REDSHIFT_TABLE}]

def _resolve_path(self, kind: str):
if kind == "root":
return []
if kind == "database":
return self._database_path()
if kind == "schema":
return self._schema_path()
if kind == "table":
return self._table_path()
raise ValueError(f"Unknown path kind: {kind}")

def test_register_connection(self, connections_service: ConnectionsService):
con = self._connect()
comm_id = connections_service.register_connection(con)
assert comm_id in connections_service.comms

@pytest.mark.parametrize(
("path_kind"),
Copy link

Copilot AI Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The @pytest.mark.parametrize decorator has incorrect syntax. For a single parameter, use "path_kind" instead of ("path_kind"). The parentheses without a trailing comma create a string, not a tuple. While this may still work, it's inconsistent with the pattern used elsewhere (e.g., line 382 in the Databricks test) and could lead to confusion.

Suggested change
("path_kind"),
"path_kind",

Copilot uses AI. Check for mistakes.
[
pytest.param("root", id="root"),
pytest.param("database", id="database"),
pytest.param("schema", id="schema"),
pytest.param("table", id="table"),
],
)
def test_contains_data(self, connections_service: ConnectionsService, path_kind: str):
dummy_comm, comm_id = self._open_comm(connections_service)
path = self._resolve_path(path_kind)

msg = _make_msg(params={"path": path}, method="contains_data", comm_id=comm_id)
dummy_comm.handle_msg(msg)
result = dummy_comm.messages[0]["data"]["result"]
assert result == (path_kind == "table")
Copy link

Copilot AI Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Inconsistent boolean assertion. This test uses assert result == (path_kind == "table") while the equivalent Databricks test at line 397 uses assert result is (path_kind == "table"). For consistency with the existing codebase, consider using is instead of == when comparing boolean values.

Suggested change
assert result == (path_kind == "table")
assert result is (path_kind == "table")

Copilot uses AI. Check for mistakes.

@pytest.mark.parametrize(
("path_kind", "expected"),
[
pytest.param("root", "data:image", id="root"),
pytest.param("database", "", id="database"),
pytest.param("schema", "", id="schema"),
pytest.param("table", "", id="table"),
],
)
def test_get_icon(self, connections_service: ConnectionsService, path_kind: str, expected: str):
dummy_comm, comm_id = self._open_comm(connections_service)
path = self._resolve_path(path_kind)

msg = _make_msg(params={"path": path}, method="get_icon", comm_id=comm_id)
dummy_comm.handle_msg(msg)
result = dummy_comm.messages[0]["data"]["result"]
if expected:
assert expected in result
else:
assert result == ""

@pytest.mark.parametrize(
"path_kind",
[
pytest.param("root", id="databases"),
pytest.param("database", id="schemas"),
pytest.param("schema", id="tables"),
],
)
def test_list_objects(self, connections_service: ConnectionsService, path_kind: str):
dummy_comm, comm_id = self._open_comm(connections_service)
path = self._resolve_path(path_kind)
expected = {
"root": self.REDSHIFT_DATABASE,
"database": self.REDSHIFT_SCHEMA,
"schema": self.REDSHIFT_TABLE,
}[path_kind]

msg = _make_msg(params={"path": path}, method="list_objects", comm_id=comm_id)
dummy_comm.handle_msg(msg)
result = dummy_comm.messages[0]["data"]["result"]
names = [item["name"] for item in result]
assert expected in names

def test_list_fields(self, connections_service: ConnectionsService):
dummy_comm, comm_id = self._open_comm(connections_service)
path = self._table_path()

msg = _make_msg(params={"path": path}, method="list_fields", comm_id=comm_id)
dummy_comm.handle_msg(msg)
result = dummy_comm.messages[0]["data"]["result"]
field_names = {field["name"].lower() for field in result}
assert {"carrier", "name"}.issubset(field_names)

def test_preview_object(self, connections_service: ConnectionsService):
dummy_comm, comm_id = self._open_comm(connections_service)
path = self._table_path()

msg = _make_msg(params={"path": path}, method="preview_object", comm_id=comm_id)
dummy_comm.handle_msg(msg)
connections_service._kernel.data_explorer_service.shutdown() # noqa: SLF001
result = dummy_comm.messages[0]["data"]["result"]
assert result is None
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pytest
pytest-asyncio
pytest-mock
syrupy
redshift_connector
torch
scipy
snowflake-connector-python; python_version < '3.14'
Expand Down
Loading