4545import math
4646import datetime
4747import string
48+ from uuid import uuid4
4849
4950from impacket import ntlm , uuid , LOG
5051from impacket .structure import Structure
@@ -469,7 +470,7 @@ class TDS_COLMETADATA(Structure):
469470 )
470471
471472class MSSQL :
472- def __init__ (self , address , port = 1433 , remoteName = '' , rowsPrinter = DummyPrint ()):
473+ def __init__ (self , address , port = 1433 , remoteName = '' , workstation_id : str = "" , application_name : str = "" , rowsPrinter = DummyPrint ()):
473474 #self.packetSize = 32764
474475 self .packetSize = 32763
475476 self .server = address
@@ -487,6 +488,9 @@ def __init__(self, address, port=1433, remoteName = '', rowsPrinter=DummyPrint()
487488 self .__rowsPrinter = rowsPrinter
488489 self .mssql_version = ""
489490
491+ self ._workstation_id = workstation_id or f"DESKTOP-{ uuid4 ().hex [:8 ].upper ()} "
492+ self ._application_name = application_name or "Microsoft SQL Server Management Studio - Query"
493+
490494 # With Kerberos we need to know to which MSSQL instance we are going to connect (to compute the SPN)
491495 # As such we need to be able to list these instances which is what this code does
492496 def getInstances (self , timeout = 5 ):
@@ -550,9 +554,10 @@ def preLogin(self):
550554 def encryptPassword (self , password ):
551555 return bytes (bytearray ([((x & 0x0f ) << 4 ) + ((x & 0xf0 ) >> 4 ) ^ 0xa5 for x in bytearray (password )]))
552556
553- def connect (self ):
557+ def connect (self , timeout = 30 ):
554558 af , socktype , proto , canonname , sa = socket .getaddrinfo (self .server , self .port , 0 , socket .SOCK_STREAM )[0 ]
555559 sock = socket .socket (af , socktype , proto )
560+ sock .settimeout (timeout )
556561
557562 try :
558563 sock .connect (sa )
@@ -808,8 +813,8 @@ def kerberosLogin(self, database, username, password='', domain='', hashes=None,
808813 self .version ["ProductMajorVersion" ], self .version ["ProductMinorVersion" ], self .version ["ProductBuild" ] = 10 , 0 , 20348
809814
810815 login = TDS_LOGIN ()
811- login ['HostName' ] = ( '' . join ([ random . choice ( string . ascii_letters ) for _ in range ( 8 )])) .encode ('utf-16le' )
812- login ['AppName' ] = ( '' . join ([ random . choice ( string . ascii_letters ) for _ in range ( 8 )])) .encode ('utf-16le' )
816+ login ['HostName' ] = self . workstation_id .encode ('utf-16le' )
817+ login ['AppName' ] = self . application_name .encode ('utf-16le' )
813818 login ['ServerName' ] = self .remoteName .encode ('utf-16le' )
814819 login ['CltIntName' ] = login ['AppName' ]
815820 login ['ClientPID' ] = random .randint (0 ,1024 )
@@ -1013,8 +1018,8 @@ def login(self, database, username, password='', domain='', hashes = None, useWi
10131018 self .version ["ProductMajorVersion" ], self .version ["ProductMinorVersion" ], self .version ["ProductBuild" ] = 10 , 0 , 20348
10141019
10151020 login = TDS_LOGIN ()
1016- login ['HostName' ] = ( '' . join ([ random . choice ( string . ascii_letters ) for i in range ( 8 )])) .encode ('utf-16le' )
1017- login ['AppName' ] = ( '' . join ([ random . choice ( string . ascii_letters ) for i in range ( 8 )])) .encode ('utf-16le' )
1021+ login ['HostName' ] = self . workstation_id .encode ('utf-16le' )
1022+ login ['AppName' ] = self . application_name .encode ('utf-16le' )
10181023 login ['ServerName' ] = self .remoteName .encode ('utf-16le' )
10191024 login ['CltIntName' ] = login ['AppName' ]
10201025 login ['ClientPID' ] = random .randint (0 ,1024 )
@@ -1714,4 +1719,13 @@ def RunSQLStatement(self,db,sql_query,wait=True,**kwArgs):
17141719 self .RunSQLQuery (db ,sql_query ,wait = wait )
17151720 if self .lastError :
17161721 raise self .lastError
1717- return True
1722+ return True
1723+
1724+ # Properties
1725+ @property
1726+ def workstation_id (self ):
1727+ return self ._workstation_id
1728+
1729+ @property
1730+ def application_name (self ):
1731+ return self ._application_name
0 commit comments