Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dagster-ssh] Update to Pythonic resources #15180

Merged
merged 5 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 108 additions & 67 deletions python_modules/libraries/dagster-ssh/dagster_ssh/resources.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
import getpass
import logging
import os
from io import StringIO
from typing import Optional

import paramiko
from dagster import (
BoolSource,
Field,
Field as DagsterField,
IntSource,
StringSource,
_check as check,
resource,
)
from dagster._config.pythonic_config import ConfigurableResource
from dagster._core.definitions.resource_definition import dagster_maintained_resource
from dagster._core.execution.context.init import InitResourceContext
from dagster._utils import mkdir_p
from dagster._utils.merger import merge_dicts
from paramiko.client import SSHClient
from paramiko.config import SSH_PORT
from pydantic import (
Field,
PrivateAttr,
)
from sshtunnel import SSHTunnelForwarder


Expand All @@ -29,70 +37,99 @@ def key_from_str(key_str):
return result


class SSHResource:
class SSHResource(ConfigurableResource):
"""Resource for ssh remote execution using Paramiko.

ref: https://github.com/paramiko/paramiko
"""

def __init__(
self,
remote_host,
remote_port,
username=None,
password=None,
key_file=None,
key_string=None,
timeout=10,
keepalive_interval=30,
compress=True,
no_host_key_check=True,
allow_host_key_change=False,
logger=None,
):
self.remote_host = check.str_param(remote_host, "remote_host")
self.remote_port = check.opt_int_param(remote_port, "remote_port")
self.username = check.opt_str_param(username, "username")
self.password = check.opt_str_param(password, "password")
self.key_file = check.opt_str_param(key_file, "key_file")
self.timeout = check.opt_int_param(timeout, "timeout")
self.keepalive_interval = check.opt_int_param(keepalive_interval, "keepalive_interval")
self.compress = check.opt_bool_param(compress, "compress")
self.no_host_key_check = check.opt_bool_param(no_host_key_check, "no_host_key_check")
self.log = logger

self.host_proxy = None
remote_host: str = Field(description="Remote host to connect to")
remote_port: Optional[int] = Field(default=None, description="Port of remote host to connect")
username: Optional[str] = Field(default=None, description="Username to connect to remote host")
password: Optional[str] = Field(
default=None, description="Password of the username to connect to remote host"
)
key_file: Optional[str] = Field(
default=None, description="Key file to use to connect to remote host"
)
key_string: Optional[str] = Field(
default=None, description="Key string to use to connect to remote host"
)
timeout: int = Field(
default=10, description="Timeout for the attempt to connect to remote host"
)
keepalive_interval: int = Field(
default=30,
description="Send a keepalive packet to remote host every keepalive_interval seconds",
)
compress: bool = Field(default=True, description="Compress the transport stream")
no_host_key_check: bool = Field(
default=True,
description=(
"If True, the host key will not be verified. This is unsafe and not recommended"
),
)
allow_host_key_change: bool = Field(
default=False,
description="If True, allow connecting to hosts whose host key has changed",
)

_logger: Optional[logging.Logger] = PrivateAttr(default=None)
_host_proxy: Optional[paramiko.ProxyCommand] = PrivateAttr(default=None)
_key_obj: Optional[paramiko.RSAKey] = PrivateAttr(default=None)

def set_logger(self, logger: logging.Logger) -> None:
self._logger = logger

def setup_for_execution(self, context: InitResourceContext) -> None:
self._logger = context.log
self._host_proxy = None

# Create RSAKey object from private key string
self.key_obj = key_from_str(key_string) if key_string is not None else None
self._key_obj = key_from_str(self.key_string) if self.key_string is not None else None

# Auto detecting username values from system
if not self.username:
logger.debug(
"username to ssh to host: %s is not specified. Using system's default provided by"
" getpass.getuser()" % self.remote_host
)
if self._logger:
self._logger.debug(
"username to ssh to host: %s is not specified. Using system's default provided"
" by getpass.getuser()" % self.remote_host
)
self.username = getpass.getuser()

user_ssh_config_filename = os.path.expanduser("~/.ssh/config")
if os.path.isfile(user_ssh_config_filename):
ssh_conf = paramiko.SSHConfig()
ssh_conf.parse(open(user_ssh_config_filename, encoding="utf8"))
host_info = ssh_conf.lookup(self.remote_host)
if host_info and host_info.get("proxycommand"):
self.host_proxy = paramiko.ProxyCommand(host_info.get("proxycommand"))

proxy_command = host_info.get("proxycommand")
if host_info and proxy_command:
self._host_proxy = paramiko.ProxyCommand(proxy_command)

if not (self.password or self.key_file):
if host_info and host_info.get("identityfile"):
self.key_file = host_info.get("identityfile")[0]
identify_file = host_info.get("identityfile")
if host_info and identify_file:
self.key_file = identify_file[0]

@property
def log(self) -> logging.Logger:
return check.not_none(self._logger)

def get_connection(self):
def get_connection(self) -> SSHClient:
"""Opens a SSH connection to the remote host.

:rtype: paramiko.client.SSHClient
"""
client = paramiko.SSHClient()
client.load_system_host_keys()

if not self.allow_host_key_change:
self.log.warning(
"Remote Identification Change is not verified. This won't protect against "
"Man-In-The-Middle attacks"
)
client.load_system_host_keys()
if self.no_host_key_check:
self.log.warning(
"No Host Key Verification. This won't protect against Man-In-The-Middle attacks"
Expand All @@ -106,31 +143,33 @@ def get_connection(self):
username=self.username,
password=self.password,
key_filename=self.key_file,
pkey=self.key_obj,
pkey=self._key_obj,
timeout=self.timeout,
compress=self.compress,
port=self.remote_port,
sock=self.host_proxy,
port=self.remote_port, # type: ignore
sock=self._host_proxy, # type: ignore
look_for_keys=False,
)
else:
client.connect(
hostname=self.remote_host,
username=self.username,
key_filename=self.key_file,
pkey=self.key_obj,
pkey=self._key_obj,
timeout=self.timeout,
compress=self.compress,
port=self.remote_port,
sock=self.host_proxy,
port=self.remote_port, # type: ignore
sock=self._host_proxy, # type: ignore
)

if self.keepalive_interval:
client.get_transport().set_keepalive(self.keepalive_interval)
client.get_transport().set_keepalive(self.keepalive_interval) # type: ignore

return client

def get_tunnel(self, remote_port, remote_host="localhost", local_port=None):
def get_tunnel(
self, remote_port, remote_host="localhost", local_port=None
) -> SSHTunnelForwarder:
check.int_param(remote_port, "remote_port")
check.str_param(remote_host, "remote_host")
check.opt_int_param(local_port, "local_port")
Expand All @@ -141,7 +180,11 @@ def get_tunnel(self, remote_port, remote_host="localhost", local_port=None):
local_bind_address = ("localhost",)

# Will prefer key string if specified, otherwise use the key file
pkey = self.key_obj if self.key_obj else self.key_file
if self._key_obj and self.key_file:
self.log.warning(
"SSHResource: key_string and key_file both specified as config. Using key_string."
)
pkey = self._key_obj if self._key_obj else self.key_file

if self.password and self.password.strip():
client = SSHTunnelForwarder(
Expand All @@ -150,22 +193,22 @@ def get_tunnel(self, remote_port, remote_host="localhost", local_port=None):
ssh_username=self.username,
ssh_password=self.password,
ssh_pkey=pkey,
ssh_proxy=self.host_proxy,
ssh_proxy=self._host_proxy,
local_bind_address=local_bind_address,
remote_bind_address=(remote_host, remote_port),
logger=self.log,
logger=self._logger,
)
else:
client = SSHTunnelForwarder(
self.remote_host,
ssh_port=self.remote_port,
ssh_username=self.username,
ssh_pkey=pkey,
ssh_proxy=self.host_proxy,
ssh_proxy=self._host_proxy,
local_bind_address=local_bind_address,
remote_bind_address=(remote_host, remote_port),
host_pkey_directories=[],
logger=self.log,
logger=self._logger,
)

return client
Expand Down Expand Up @@ -203,53 +246,51 @@ def sftp_put(self, remote_filepath, local_filepath, confirm=True):
@dagster_maintained_resource
@resource(
config_schema={
"remote_host": Field(
"remote_host": DagsterField(
StringSource, description="remote host to connect to", is_required=True
),
"remote_port": Field(
"remote_port": DagsterField(
IntSource,
description="port of remote host to connect (Default is paramiko SSH_PORT)",
is_required=False,
default_value=SSH_PORT,
),
"username": Field(
"username": DagsterField(
StringSource, description="username to connect to the remote_host", is_required=False
),
"password": Field(
"password": DagsterField(
StringSource,
description="password of the username to connect to the remote_host",
is_required=False,
),
"key_file": Field(
"key_file": DagsterField(
StringSource,
description="key file to use to connect to the remote_host.",
is_required=False,
),
"key_string": Field(
"key_string": DagsterField(
StringSource,
description="key string to use to connect to remote_host",
is_required=False,
),
"timeout": Field(
"timeout": DagsterField(
IntSource,
description="timeout for the attempt to connect to the remote_host.",
is_required=False,
default_value=10,
),
"keepalive_interval": Field(
"keepalive_interval": DagsterField(
IntSource,
description="send a keepalive packet to remote host every keepalive_interval seconds",
is_required=False,
default_value=30,
),
"compress": Field(BoolSource, is_required=False, default_value=True),
"no_host_key_check": Field(BoolSource, is_required=False, default_value=True),
"allow_host_key_change": Field(
"compress": DagsterField(BoolSource, is_required=False, default_value=True),
"no_host_key_check": DagsterField(BoolSource, is_required=False, default_value=True),
"allow_host_key_change": DagsterField(
BoolSource, description="[Deprecated]", is_required=False, default_value=False
),
}
)
def ssh_resource(init_context):
args = init_context.resource_config
args = merge_dicts(init_context.resource_config, {"logger": init_context.log})
return SSHResource(**args)
return SSHResource.from_resource_context(init_context)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the config for the function resource different from the ConfigurableResource so we cant do SSHResource.to_config_schema()

Loading