88
99import asyncio
1010from concurrent .futures import ThreadPoolExecutor
11- from typing import Any , Dict , Optional , Type
11+ from typing import Any , Dict , Optional
1212
13+ from azure .core .credentials import TokenCredential
1314from azure .core .exceptions import AzureError , ClientAuthenticationError
14- from azure .identity import DefaultAzureCredential
1515
1616from application_sdk .clients import ClientInterface
1717from application_sdk .clients .azure .azure_auth import AzureAuthProvider
1818from application_sdk .clients .azure .azure_services import AzureStorageClient
19- from application_sdk .common .error_codes import ClientError , CommonError
2019from application_sdk .common .credential_utils import resolve_credentials
20+ from application_sdk .common .error_codes import ClientError
2121from application_sdk .observability .logger_adaptor import get_logger
2222
2323logger = get_logger (__name__ )
2626class AzureClient (ClientInterface ):
2727 """
2828 Main Azure client for the application-sdk framework.
29-
29+
3030 This client provides a unified interface for connecting to and interacting
3131 with Azure Storage services. It supports Service Principal authentication
3232 and provides service-specific subclients for different Azure services.
33-
33+
3434 Attributes:
3535 credentials (Dict[str, Any]): Azure connection credentials
3636 resolved_credentials (Dict[str, Any]): Resolved credentials after processing
@@ -58,7 +58,7 @@ def __init__(
5858 """
5959 self .credentials = credentials or {}
6060 self .resolved_credentials : Dict [str , Any ] = {}
61- self .credential : Optional [DefaultAzureCredential ] = None
61+ self .credential : Optional [TokenCredential ] = None
6262 self .auth_provider = AzureAuthProvider ()
6363 self ._services : Dict [str , Any ] = {}
6464 self ._executor = ThreadPoolExecutor (max_workers = max_workers )
@@ -82,22 +82,21 @@ async def load(self, credentials: Optional[Dict[str, Any]] = None) -> None:
8282
8383 try :
8484 logger .info ("Loading Azure client..." )
85-
85+
8686 # Resolve credentials using framework's credential resolution
8787 self .resolved_credentials = await resolve_credentials (self .credentials )
88-
88+
8989 # Create Azure credential using Service Principal authentication
9090 self .credential = await self .auth_provider .create_credential (
91- auth_type = "service_principal" ,
92- credentials = self .resolved_credentials
91+ auth_type = "service_principal" , credentials = self .resolved_credentials
9392 )
94-
93+
9594 # Test the connection
9695 await self ._test_connection ()
97-
96+
9897 self ._connection_health = True
9998 logger .info ("Azure client loaded successfully" )
100-
99+
101100 except ClientAuthenticationError as e :
102101 logger .error (f"Azure authentication failed: { str (e )} " )
103102 raise ClientError (f"{ ClientError .CLIENT_AUTH_ERROR } : { str (e )} " )
@@ -112,28 +111,28 @@ async def close(self) -> None:
112111 """Close Azure connections and clean up resources."""
113112 try :
114113 logger .info ("Closing Azure client..." )
115-
114+
116115 # Close all service clients
117116 for service_name , service_client in self ._services .items ():
118117 try :
119- if hasattr (service_client , ' close' ):
118+ if hasattr (service_client , " close" ):
120119 await service_client .close ()
121- elif hasattr (service_client , ' disconnect' ):
120+ elif hasattr (service_client , " disconnect" ):
122121 await service_client .disconnect ()
123122 except Exception as e :
124123 logger .warning (f"Error closing { service_name } client: { str (e )} " )
125-
124+
126125 # Clear service cache
127126 self ._services .clear ()
128-
127+
129128 # Shutdown executor
130129 self ._executor .shutdown (wait = True )
131-
130+
132131 # Reset connection health
133132 self ._connection_health = False
134-
133+
135134 logger .info ("Azure client closed successfully" )
136-
135+
137136 except Exception as e :
138137 logger .error (f"Error closing Azure client: { str (e )} " )
139138
@@ -149,18 +148,17 @@ async def get_storage_client(self) -> AzureStorageClient:
149148 """
150149 if not self ._connection_health :
151150 raise ClientError (f"{ ClientError .CLIENT_AUTH_ERROR } : Client not loaded" )
152-
151+
153152 if "storage" not in self ._services :
154153 try :
155154 self ._services ["storage" ] = AzureStorageClient (
156- credential = self .credential ,
157- ** self ._kwargs
155+ credential = self .credential , ** self ._kwargs
158156 )
159157 await self ._services ["storage" ].load (self .resolved_credentials )
160158 except Exception as e :
161159 logger .error (f"Failed to create storage client: { str (e )} " )
162160 raise ClientError (f"{ ClientError .CLIENT_AUTH_ERROR } : { str (e )} " )
163-
161+
164162 return self ._services ["storage" ]
165163
166164 async def get_service_client (self , service_type : str ) -> Any :
@@ -181,10 +179,10 @@ async def get_service_client(self, service_type: str) -> Any:
181179 service_mapping = {
182180 "storage" : self .get_storage_client ,
183181 }
184-
182+
185183 if service_type not in service_mapping :
186184 raise ValueError (f"Unsupported service type: { service_type } " )
187-
185+
188186 return await service_mapping [service_type ]()
189187
190188 async def health_check (self ) -> Dict [str , Any ]:
@@ -197,33 +195,32 @@ async def health_check(self) -> Dict[str, Any]:
197195 health_status = {
198196 "connection_health" : self ._connection_health ,
199197 "services" : {},
200- "overall_health" : False
198+ "overall_health" : False ,
201199 }
202-
200+
203201 if not self ._connection_health :
204202 return health_status
205-
203+
206204 # Check each service
207205 for service_name , service_client in self ._services .items ():
208206 try :
209- if hasattr (service_client , ' health_check' ):
207+ if hasattr (service_client , " health_check" ):
210208 service_health = await service_client .health_check ()
211209 else :
212210 service_health = {"status" : "unknown" }
213-
211+
214212 health_status ["services" ][service_name ] = service_health
215213 except Exception as e :
216214 health_status ["services" ][service_name ] = {
217215 "status" : "error" ,
218- "error" : str (e )
216+ "error" : str (e ),
219217 }
220-
218+
221219 # Overall health is True if connection is healthy and at least one service is available
222220 health_status ["overall_health" ] = (
223- self ._connection_health and
224- len (health_status ["services" ]) > 0
221+ self ._connection_health and len (health_status ["services" ]) > 0
225222 )
226-
223+
227224 return health_status
228225
229226 async def _test_connection (self ) -> None :
@@ -233,12 +230,17 @@ async def _test_connection(self) -> None:
233230 Raises:
234231 ClientAuthenticationError: If connection test fails.
235232 """
233+ if not self .credential :
234+ raise ClientAuthenticationError (
235+ "No credential available for connection test"
236+ )
237+
236238 try :
237239 # Test the credential by getting a token
238240 await asyncio .get_event_loop ().run_in_executor (
239241 self ._executor ,
240242 self .credential .get_token ,
241- "https://management.azure.com/.default"
243+ "https://management.azure.com/.default" ,
242244 )
243245 except Exception as e :
244246 raise ClientAuthenticationError (f"Connection test failed: { str (e )} " )
@@ -249,4 +251,26 @@ def __enter__(self):
249251
250252 def __exit__ (self , exc_type , exc_val , exc_tb ):
251253 """Context manager exit."""
252- asyncio .create_task (self .close ())
254+ # Note: This is a synchronous context manager.
255+ # For proper async cleanup, use the async context manager instead.
256+ # This method is kept for backward compatibility but doesn't guarantee cleanup.
257+ logger .warning (
258+ "Using synchronous context manager. For proper async cleanup, "
259+ "use 'async with AzureClient() as client:' instead."
260+ )
261+ # Schedule cleanup but don't wait for it
262+ try :
263+ loop = asyncio .get_event_loop ()
264+ if loop .is_running ():
265+ loop .create_task (self .close ())
266+ except RuntimeError :
267+ # No event loop running, can't schedule async cleanup
268+ logger .warning ("No event loop running, async cleanup not possible" )
269+
270+ async def __aenter__ (self ):
271+ """Async context manager entry."""
272+ return self
273+
274+ async def __aexit__ (self , exc_type , exc_val , exc_tb ):
275+ """Async context manager exit."""
276+ await self .close ()
0 commit comments