34
34
35
35
__all__ = ["Connection" ]
36
36
37
- logging .getLogger (__name__ )
37
+ log = logging .getLogger (__name__ )
38
38
39
39
# guard for when readthedocs is building documentation or travis
40
40
# is running CI build
@@ -121,21 +121,22 @@ class Connection(metaclass=_ConnectionMeta):
121
121
122
122
@overload
123
123
def __new__ (cls , ssh_server : str , local : Literal [False ], quiet : bool ,
124
- thread_safe : bool ) -> SSHConnection :
124
+ thread_safe : bool , allow_agent : bool ) -> SSHConnection :
125
125
...
126
126
127
127
@overload
128
128
def __new__ (cls , ssh_server : str , local : Literal [True ], quiet : bool ,
129
- thread_safe : bool ) -> LocalConnection :
129
+ thread_safe : bool , allow_agent : bool ) -> LocalConnection :
130
130
...
131
131
132
132
@overload
133
133
def __new__ (cls , ssh_server : str , local : bool , quiet : bool ,
134
- thread_safe : bool ) -> Union [SSHConnection , LocalConnection ]:
134
+ thread_safe : bool , allow_agent : bool
135
+ ) -> Union [SSHConnection , LocalConnection ]:
135
136
...
136
137
137
138
def __new__ (cls , ssh_server : str , local : bool = False , quiet : bool = False ,
138
- thread_safe : bool = False ):
139
+ thread_safe : bool = False , allow_agent : bool = True ):
139
140
"""Get Connection based on one of names defined in .ssh/config file.
140
141
141
142
If name of local PC is passed initilize LocalConnection
@@ -152,11 +153,14 @@ def __new__(cls, ssh_server: str, local: bool = False, quiet: bool = False,
152
153
make connection object thread safe so it can be safely accessed
153
154
from any number of threads, it is disabled by default to avoid
154
155
performance penalty of threading locks
156
+ allow_agent: bool
157
+ allows use of ssh agent for connection authentication, when this is
158
+ `True` key for the host does not have to be available.
155
159
156
160
Raises
157
161
------
158
162
KeyError
159
- if server name is not in config file
163
+ if server name is not in config file and allow agent is false
160
164
161
165
Returns
162
166
-------
@@ -173,14 +177,36 @@ def __new__(cls, ssh_server: str, local: bool = False, quiet: bool = False,
173
177
raise KeyError (f"couldn't find login credentials for { ssh_server } :"
174
178
f" { e } " )
175
179
else :
180
+ # get username and address
176
181
try :
177
- return cls .open (credentials ["user" ], credentials ["hostname" ],
178
- credentials ["identityfile" ][0 ],
179
- server_name = ssh_server , quiet = quiet ,
180
- thread_safe = thread_safe )
182
+ user = credentials ["user" ]
183
+ hostname = credentials ["hostname" ]
181
184
except KeyError as e :
182
- raise KeyError (f"{ RED } missing key in config dictionary for "
183
- f"{ ssh_server } : { R } { e } " )
185
+ raise KeyError (
186
+ "Cannot find username or hostname for specified host"
187
+ )
188
+
189
+ # get key or use agent
190
+ if allow_agent :
191
+ log .info (f"no private key supplied for { hostname } , will try "
192
+ f"to authenticate through ssh-agent" )
193
+ pkey_file = None
194
+ else :
195
+ log .info (f"private key found for host: { hostname } " )
196
+ try :
197
+ pkey_file = credentials ["identityfile" ][0 ]
198
+ except (KeyError , IndexError ) as e :
199
+ raise KeyError (f"No private key found for specified host" )
200
+
201
+ return cls .open (
202
+ user ,
203
+ hostname ,
204
+ ssh_key_file = pkey_file ,
205
+ allow_agent = allow_agent ,
206
+ server_name = ssh_server ,
207
+ quiet = quiet ,
208
+ thread_safe = thread_safe
209
+ )
184
210
185
211
@classmethod
186
212
def get_available_hosts (cls ) -> List [str ]:
@@ -212,7 +238,8 @@ def get(cls, *args, **kwargs):
212
238
get_connection = get
213
239
214
240
@classmethod
215
- def add_hosts (cls , hosts : Union ["_HOSTS" , List ["_HOSTS" ]]):
241
+ def add_hosts (cls , hosts : Union ["_HOSTS" , List ["_HOSTS" ]],
242
+ allow_agent : Union [bool , List [bool ]]):
216
243
"""Add or override availbale host read fron ssh config file.
217
244
218
245
You can use supplied config parser to parse some externaf ssh config
@@ -223,15 +250,22 @@ def add_hosts(cls, hosts: Union["_HOSTS", List["_HOSTS"]]):
223
250
hosts : Union[_HOSTS, List[_HOSTS]]
224
251
dictionary or a list of dictionaries containing keys: `user`,
225
252
`hostname` and `identityfile`
253
+ allow_agent: Union[bool, List[bool]]
254
+ bool or a list of bools with corresponding length to list of hosts.
255
+ if only one bool is passed in, it will be used for all host entries
226
256
227
257
See also
228
258
--------
229
259
:func:ssh_utilities.config_parser
230
260
"""
231
261
if not isinstance (hosts , list ):
232
262
hosts = [hosts ]
263
+ if not isinstance (allow_agent , list ):
264
+ allow_agent = [allow_agent ] * len (hosts )
233
265
234
- for h in hosts :
266
+ for h , a in zip (hosts , allow_agent ):
267
+ if a :
268
+ h ["identityfile" ][0 ] = None
235
269
if not isinstance (h ["identityfile" ], list ):
236
270
h ["identityfile" ] = [h ["identityfile" ]]
237
271
h ["identityfile" ][0 ] = os .path .abspath (
@@ -300,7 +334,7 @@ def open(ssh_username: str, ssh_server: None = None,
300
334
ssh_password : Optional [str ] = None ,
301
335
server_name : Optional [str ] = None , quiet : bool = False ,
302
336
thread_safe : bool = False ,
303
- ssh_allow_agent : bool = False ) -> LocalConnection :
337
+ allow_agent : bool = False ) -> LocalConnection :
304
338
...
305
339
306
340
@overload
@@ -310,7 +344,7 @@ def open(ssh_username: str, ssh_server: str,
310
344
ssh_password : Optional [str ] = None ,
311
345
server_name : Optional [str ] = None , quiet : bool = False ,
312
346
thread_safe : bool = False ,
313
- ssh_allow_agent : bool = False ) -> SSHConnection :
347
+ allow_agent : bool = False ) -> SSHConnection :
314
348
...
315
349
316
350
@staticmethod
@@ -319,7 +353,7 @@ def open(ssh_username: str, ssh_server: Optional[str] = "",
319
353
ssh_password : Optional [str ] = None ,
320
354
server_name : Optional [str ] = None , quiet : bool = False ,
321
355
thread_safe : bool = False ,
322
- ssh_allow_agent : bool = False ):
356
+ allow_agent : bool = False ):
323
357
"""Initialize SSH or local connection.
324
358
325
359
Local connection is only a wrapper around os and shutil module methods
@@ -346,7 +380,7 @@ def open(ssh_username: str, ssh_server: Optional[str] = "",
346
380
make connection object thread safe so it can be safely accessed
347
381
from any number of threads, it is disabled by default to avoid
348
382
performance penalty of threading locks
349
- ssh_allow_agent : bool
383
+ allow_agent : bool
350
384
allow the use of the ssh-agent to connect. Will disable ssh_key_file.
351
385
352
386
Warnings
@@ -355,27 +389,29 @@ def open(ssh_username: str, ssh_server: Optional[str] = "",
355
389
risk!
356
390
"""
357
391
if not ssh_server :
358
- return LocalConnection (ssh_server , ssh_username ,
359
- pkey_file = ssh_key_file ,
360
- server_name = server_name , quiet = quiet )
361
- else :
362
- if ssh_allow_agent :
363
- c = SSHConnection (ssh_server , ssh_username ,
364
- allow_agent = ssh_allow_agent , line_rewrite = True ,
365
- server_name = server_name , quiet = quiet ,
366
- thread_safe = thread_safe )
367
- elif ssh_key_file :
368
- c = SSHConnection (ssh_server , ssh_username ,
369
- pkey_file = ssh_key_file , line_rewrite = True ,
370
- server_name = server_name , quiet = quiet ,
371
- thread_safe = thread_safe )
372
- else :
373
- if not ssh_password :
374
- ssh_password = getpass .getpass (prompt = "Enter password: " )
375
-
376
- c = SSHConnection (ssh_server , ssh_username ,
377
- password = ssh_password , line_rewrite = True ,
378
- server_name = server_name , quiet = quiet ,
379
- thread_safe = thread_safe )
380
-
381
- return c
392
+ return LocalConnection (
393
+ ssh_server ,
394
+ ssh_username ,
395
+ pkey_file = ssh_key_file ,
396
+ server_name = server_name ,
397
+ quiet = quiet
398
+ )
399
+ elif allow_agent :
400
+ ssh_key_file = None
401
+ ssh_password = None
402
+ elif ssh_key_file :
403
+ ssh_password = None
404
+ elif not ssh_password :
405
+ ssh_password = getpass .getpass (prompt = "Enter password: " )
406
+
407
+ return SSHConnection (
408
+ ssh_server ,
409
+ ssh_username ,
410
+ allow_agent = allow_agent ,
411
+ pkey_file = ssh_key_file ,
412
+ password = ssh_password ,
413
+ line_rewrite = True ,
414
+ server_name = server_name ,
415
+ quiet = quiet ,
416
+ thread_safe = thread_safe
417
+ )
0 commit comments