Skip to content

Commit 8190a28

Browse files
committed
full implementation of ssh-agent authentication
1 parent db69cdc commit 8190a28

File tree

8 files changed

+130
-72
lines changed

8 files changed

+130
-72
lines changed

ssh_utilities/abstract/_connection.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ class ConnectionABC(ABC):
3434

3535
__name__: str
3636
__abstractmethods__: FrozenSet[str]
37+
password: Optional[str]
38+
address: Optional[str]
39+
username: str
40+
pkey_file: Optional[Union[str, "Path"]]
41+
allow_agent: Optional[bool]
3742

3843
@abstractmethod
3944
def __str__(self):
@@ -135,11 +140,11 @@ def to_dict(self):
135140
@staticmethod
136141
def _to_dict(connection_name: str, host_name: str, address: Optional[str],
137142
user_name: str, ssh_key: Optional[Union[Path, str]],
138-
thread_safe: bool
143+
thread_safe: bool, allow_agent: bool
139144
) -> Dict[str, Optional[Union[str, bool, int]]]:
140145

141146
if ssh_key is None:
142-
key_path = ssh_key
147+
key_path = None
143148
else:
144149
key_path = str(Path(ssh_key).resolve())
145150

@@ -149,12 +154,14 @@ def _to_dict(connection_name: str, host_name: str, address: Optional[str],
149154
"user_name": user_name,
150155
"ssh_key": key_path,
151156
"address": address,
152-
"thread_safe": thread_safe
157+
"thread_safe": thread_safe,
158+
"allow_agent": allow_agent,
153159
}
154160

155161
def _to_str(self, connection_name: str, host_name: str,
156162
address: Optional[str], user_name: str,
157-
ssh_key: Optional[Union[Path, str]], thread_safe: bool) -> str:
163+
ssh_key: Optional[Union[Path, str]], thread_safe: bool,
164+
allow_agent: bool) -> str:
158165
"""Aims to ease persistance, returns string representation of instance.
159166
160167
With this method all data needed to initialize class are saved to sting
@@ -177,6 +184,9 @@ def _to_str(self, connection_name: str, host_name: str,
177184
make connection object thread safe so it can be safely accessed
178185
from any number of threads, it is disabled by default to avoid
179186
performance penalty of threading locks
187+
allow_agent: bool
188+
allows use of ssh agent for connection authentication, when this is
189+
`True` key for the host does not have to be available.
180190
181191
Returns
182192
-------
@@ -188,7 +198,8 @@ def _to_str(self, connection_name: str, host_name: str,
188198
:class:`ssh_utilities.conncection.Connection`
189199
"""
190200
return dumps(self._to_dict(connection_name, host_name, address,
191-
user_name, ssh_key, thread_safe))
201+
user_name, ssh_key, thread_safe,
202+
allow_agent))
192203

193204
def __del__(self):
194205
self.close(quiet=True)
@@ -206,10 +217,11 @@ def __setstate__(self, state: dict):
206217
self.__init__(state["address"], state["user_name"], # type: ignore
207218
pkey_file=state["ssh_key"],
208219
server_name=state["server_name"],
209-
quiet=True, thread_safe=state["thread_safe"])
220+
quiet=True, thread_safe=state["thread_safe"],
221+
allow_agent=state["allow_agent"])
210222

211223
def __enter__(self: "CONN_TYPE") -> "CONN_TYPE":
212224
return self
213225

214226
def __exit__(self, exc_type, exc_value, exc_traceback):
215-
self.close(quiet=True)
227+
self.close(quiet=True)

ssh_utilities/connection.py

Lines changed: 78 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
__all__ = ["Connection"]
3636

37-
logging.getLogger(__name__)
37+
log = logging.getLogger(__name__)
3838

3939
# guard for when readthedocs is building documentation or travis
4040
# is running CI build
@@ -121,21 +121,22 @@ class Connection(metaclass=_ConnectionMeta):
121121

122122
@overload
123123
def __new__(cls, ssh_server: str, local: Literal[False], quiet: bool,
124-
thread_safe: bool) -> SSHConnection:
124+
thread_safe: bool, allow_agent: bool) -> SSHConnection:
125125
...
126126

127127
@overload
128128
def __new__(cls, ssh_server: str, local: Literal[True], quiet: bool,
129-
thread_safe: bool) -> LocalConnection:
129+
thread_safe: bool, allow_agent: bool) -> LocalConnection:
130130
...
131131

132132
@overload
133133
def __new__(cls, ssh_server: str, local: bool, quiet: bool,
134-
thread_safe: bool) -> Union[SSHConnection, LocalConnection]:
134+
thread_safe: bool, allow_agent: bool
135+
) -> Union[SSHConnection, LocalConnection]:
135136
...
136137

137138
def __new__(cls, ssh_server: str, local: bool = False, quiet: bool = False,
138-
thread_safe: bool = False):
139+
thread_safe: bool = False, allow_agent: bool = True):
139140
"""Get Connection based on one of names defined in .ssh/config file.
140141
141142
If name of local PC is passed initilize LocalConnection
@@ -152,11 +153,14 @@ def __new__(cls, ssh_server: str, local: bool = False, quiet: bool = False,
152153
make connection object thread safe so it can be safely accessed
153154
from any number of threads, it is disabled by default to avoid
154155
performance penalty of threading locks
156+
allow_agent: bool
157+
allows use of ssh agent for connection authentication, when this is
158+
`True` key for the host does not have to be available.
155159
156160
Raises
157161
------
158162
KeyError
159-
if server name is not in config file
163+
if server name is not in config file and allow agent is false
160164
161165
Returns
162166
-------
@@ -173,14 +177,36 @@ def __new__(cls, ssh_server: str, local: bool = False, quiet: bool = False,
173177
raise KeyError(f"couldn't find login credentials for {ssh_server}:"
174178
f" {e}")
175179
else:
180+
# get username and address
176181
try:
177-
return cls.open(credentials["user"], credentials["hostname"],
178-
credentials["identityfile"][0],
179-
server_name=ssh_server, quiet=quiet,
180-
thread_safe=thread_safe)
182+
user = credentials["user"]
183+
hostname = credentials["hostname"]
181184
except KeyError as e:
182-
raise KeyError(f"{RED}missing key in config dictionary for "
183-
f"{ssh_server}: {R}{e}")
185+
raise KeyError(
186+
"Cannot find username or hostname for specified host"
187+
)
188+
189+
# get key or use agent
190+
if allow_agent:
191+
log.info(f"no private key supplied for {hostname}, will try "
192+
f"to authenticate through ssh-agent")
193+
pkey_file = None
194+
else:
195+
log.info(f"private key found for host: {hostname}")
196+
try:
197+
pkey_file = credentials["identityfile"][0]
198+
except (KeyError, IndexError) as e:
199+
raise KeyError(f"No private key found for specified host")
200+
201+
return cls.open(
202+
user,
203+
hostname,
204+
ssh_key_file=pkey_file,
205+
allow_agent=allow_agent,
206+
server_name=ssh_server,
207+
quiet=quiet,
208+
thread_safe=thread_safe
209+
)
184210

185211
@classmethod
186212
def get_available_hosts(cls) -> List[str]:
@@ -212,7 +238,8 @@ def get(cls, *args, **kwargs):
212238
get_connection = get
213239

214240
@classmethod
215-
def add_hosts(cls, hosts: Union["_HOSTS", List["_HOSTS"]]):
241+
def add_hosts(cls, hosts: Union["_HOSTS", List["_HOSTS"]],
242+
allow_agent: Union[bool, List[bool]]):
216243
"""Add or override availbale host read fron ssh config file.
217244
218245
You can use supplied config parser to parse some externaf ssh config
@@ -223,15 +250,22 @@ def add_hosts(cls, hosts: Union["_HOSTS", List["_HOSTS"]]):
223250
hosts : Union[_HOSTS, List[_HOSTS]]
224251
dictionary or a list of dictionaries containing keys: `user`,
225252
`hostname` and `identityfile`
253+
allow_agent: Union[bool, List[bool]]
254+
bool or a list of bools with corresponding length to list of hosts.
255+
if only one bool is passed in, it will be used for all host entries
226256
227257
See also
228258
--------
229259
:func:ssh_utilities.config_parser
230260
"""
231261
if not isinstance(hosts, list):
232262
hosts = [hosts]
263+
if not isinstance(allow_agent, list):
264+
allow_agent = [allow_agent] * len(hosts)
233265

234-
for h in hosts:
266+
for h, a in zip(hosts, allow_agent):
267+
if a:
268+
h["identityfile"][0] = None
235269
if not isinstance(h["identityfile"], list):
236270
h["identityfile"] = [h["identityfile"]]
237271
h["identityfile"][0] = os.path.abspath(
@@ -300,7 +334,7 @@ def open(ssh_username: str, ssh_server: None = None,
300334
ssh_password: Optional[str] = None,
301335
server_name: Optional[str] = None, quiet: bool = False,
302336
thread_safe: bool = False,
303-
ssh_allow_agent: bool = False) -> LocalConnection:
337+
allow_agent: bool = False) -> LocalConnection:
304338
...
305339

306340
@overload
@@ -310,7 +344,7 @@ def open(ssh_username: str, ssh_server: str,
310344
ssh_password: Optional[str] = None,
311345
server_name: Optional[str] = None, quiet: bool = False,
312346
thread_safe: bool = False,
313-
ssh_allow_agent: bool = False) -> SSHConnection:
347+
allow_agent: bool = False) -> SSHConnection:
314348
...
315349

316350
@staticmethod
@@ -319,7 +353,7 @@ def open(ssh_username: str, ssh_server: Optional[str] = "",
319353
ssh_password: Optional[str] = None,
320354
server_name: Optional[str] = None, quiet: bool = False,
321355
thread_safe: bool = False,
322-
ssh_allow_agent: bool = False):
356+
allow_agent: bool = False):
323357
"""Initialize SSH or local connection.
324358
325359
Local connection is only a wrapper around os and shutil module methods
@@ -346,7 +380,7 @@ def open(ssh_username: str, ssh_server: Optional[str] = "",
346380
make connection object thread safe so it can be safely accessed
347381
from any number of threads, it is disabled by default to avoid
348382
performance penalty of threading locks
349-
ssh_allow_agent: bool
383+
allow_agent: bool
350384
allow the use of the ssh-agent to connect. Will disable ssh_key_file.
351385
352386
Warnings
@@ -355,27 +389,29 @@ def open(ssh_username: str, ssh_server: Optional[str] = "",
355389
risk!
356390
"""
357391
if not ssh_server:
358-
return LocalConnection(ssh_server, ssh_username,
359-
pkey_file=ssh_key_file,
360-
server_name=server_name, quiet=quiet)
361-
else:
362-
if ssh_allow_agent:
363-
c = SSHConnection(ssh_server, ssh_username,
364-
allow_agent=ssh_allow_agent, line_rewrite=True,
365-
server_name=server_name, quiet=quiet,
366-
thread_safe=thread_safe)
367-
elif ssh_key_file:
368-
c = SSHConnection(ssh_server, ssh_username,
369-
pkey_file=ssh_key_file, line_rewrite=True,
370-
server_name=server_name, quiet=quiet,
371-
thread_safe=thread_safe)
372-
else:
373-
if not ssh_password:
374-
ssh_password = getpass.getpass(prompt="Enter password: ")
375-
376-
c = SSHConnection(ssh_server, ssh_username,
377-
password=ssh_password, line_rewrite=True,
378-
server_name=server_name, quiet=quiet,
379-
thread_safe=thread_safe)
380-
381-
return c
392+
return LocalConnection(
393+
ssh_server,
394+
ssh_username,
395+
pkey_file=ssh_key_file,
396+
server_name=server_name,
397+
quiet=quiet
398+
)
399+
elif allow_agent:
400+
ssh_key_file = None
401+
ssh_password = None
402+
elif ssh_key_file:
403+
ssh_password = None
404+
elif not ssh_password:
405+
ssh_password = getpass.getpass(prompt="Enter password: ")
406+
407+
return SSHConnection(
408+
ssh_server,
409+
ssh_username,
410+
allow_agent=allow_agent,
411+
pkey_file=ssh_key_file,
412+
password=ssh_password,
413+
line_rewrite=True,
414+
server_name=server_name,
415+
quiet=quiet,
416+
thread_safe=thread_safe
417+
)

ssh_utilities/local/local.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@ def __init__(self, address: Optional[str], username: str,
3030
password: Optional[str] = None,
3131
pkey_file: Optional[Union[str, "Path"]] = None,
3232
line_rewrite: bool = True, server_name: Optional[str] = None,
33-
quiet: bool = False, thread_safe: bool = False) -> None:
33+
quiet: bool = False, thread_safe: bool = False,
34+
allow_agent: Optional[bool] = False) -> None:
3435

3536
# set login credentials
3637
self.password = password
3738
self.address = address
3839
self.username = username
3940
self.pkey_file = pkey_file
41+
self.allow_agent = allow_agent
4042

4143
self.server_name = server_name if server_name else gethostname()
4244
self.server_name = self.server_name.upper()
@@ -92,11 +94,11 @@ def subprocess(self) -> "_SUBPROCESS_LOCAL":
9294

9395
def __str__(self) -> str:
9496
return self._to_str("LocalConnection", self.server_name, None,
95-
self.username, None, True)
97+
self.username, None, True, False)
9698

9799
def to_dict(self) -> Dict[str, Optional[Union[str, bool, int]]]:
98100
return self._to_dict("LocalConnection", self.server_name, None,
99-
self.username, None, True)
101+
self.username, None, True, False)
100102

101103
@staticmethod
102104
def close(*, quiet: bool = True):

ssh_utilities/multi_connection/_persistence.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@ def __getstate__(self):
5555

5656
def __setstate__(self, state: dict):
5757
"""Initializes the object after load from pickle."""
58-
ssh_servers, local, thread_safe = (
58+
ssh_servers, local, thread_safe, allow_agent = (
5959
self._parse_persistence_dict(state)
6060
)
6161

6262
self.__init__(ssh_servers, local, quiet=True, # type: ignore
63-
thread_safe=thread_safe)
63+
thread_safe=thread_safe, allow_agent=allow_agent)
6464

6565
def to_dict(self) -> Dict[int, Dict[str, Optional[Union[str, bool,
6666
int, None]]]]:
@@ -96,12 +96,14 @@ def _parse_persistence_dict(d: dict) -> Tuple[List[str], List[int],
9696
ssh_servers = []
9797
local = []
9898
thread_safe = []
99+
allow_agent = []
99100
for j in d.values():
100101
ssh_servers.append(j.pop("server_name"))
101102
thread_safe.append(j.pop("thread_safe"))
102103
local.append(not bool(j.pop("address")))
104+
allow_agent.append(j.pop("allow_agent"))
103105

104-
return ssh_servers, local, thread_safe
106+
return ssh_servers, local, thread_safe, allow_agent
105107

106108
@classmethod
107109
def from_dict(cls, json: dict, quiet: bool = False
@@ -129,12 +131,12 @@ def from_dict(cls, json: dict, quiet: bool = False
129131
KeyError
130132
if required key is missing from string
131133
"""
132-
ssh_servers, local, thread_safe = (
134+
ssh_servers, local, thread_safe, allow_agent = (
133135
cls._parse_persistence_dict(json)
134136
)
135137

136138
return cls(ssh_servers, local, quiet=quiet, # type: ignore
137-
thread_safe=thread_safe)
139+
thread_safe=thread_safe, allow_agent=allow_agent)
138140

139141
@classmethod
140142
def from_str(cls, string: str, quiet: bool = False

ssh_utilities/multi_connection/multi_connection.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ class MultiConnection(DictInterface, Pesistence, ConnectionABC):
9696

9797
def __init__(self, ssh_servers: Union[List[str], str],
9898
local: Union[List[bool], bool] = False, quiet: bool = False,
99-
thread_safe: Union[List[bool], bool] = False) -> None:
99+
thread_safe: Union[List[bool], bool] = False,
100+
allow_agent: Union[List[bool], bool] = True) -> None:
100101

101102
# TODO somehow adjust number of workers if connection are deleted or
102103
# TODO added
@@ -108,6 +109,8 @@ def __init__(self, ssh_servers: Union[List[str], str],
108109
local = [local] * len(ssh_servers)
109110
if not isinstance(thread_safe, list):
110111
thread_safe = [thread_safe] * len(ssh_servers)
112+
if not isinstance(allow_agent, list):
113+
allow_agent = [allow_agent] * len(ssh_servers)
111114

112115
self._connections = defaultdict(deque)
113116
for ss, l, ts in zip(ssh_servers, local, thread_safe):

0 commit comments

Comments
 (0)