Skip to content

Commit 357278d

Browse files
Added type annotations for public API + flake8 fixes
1 parent 5c21b81 commit 357278d

File tree

8 files changed

+49
-34
lines changed

8 files changed

+49
-34
lines changed

firebase_admin/__init__.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import json
1818
import os
1919
import threading
20+
from typing import Any, Callable, Dict, Optional
2021

2122
from firebase_admin import credentials
2223
from firebase_admin.__about__ import __version__
@@ -31,7 +32,8 @@
3132
_CONFIG_VALID_KEYS = ['databaseAuthVariableOverride', 'databaseURL', 'httpTimeout', 'projectId',
3233
'storageBucket']
3334

34-
def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME):
35+
36+
def initialize_app(credential: Optional[credentials.Base] = None, options: Optional[Dict[str, Any]] = None, name: str = _DEFAULT_APP_NAME) -> "App":
3537
"""Initializes and returns a new App instance.
3638
3739
Creates a new App instance using the specified options
@@ -83,7 +85,7 @@ def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME):
8385
'you call initialize_app().').format(name))
8486

8587

86-
def delete_app(app):
88+
def delete_app(app: "App"):
8789
"""Gracefully deletes an App instance.
8890
8991
Args:
@@ -98,7 +100,7 @@ def delete_app(app):
98100
with _apps_lock:
99101
if _apps.get(app.name) is app:
100102
del _apps[app.name]
101-
app._cleanup() # pylint: disable=protected-access
103+
app._cleanup() # pylint: disable=protected-access
102104
return
103105
if app.name == _DEFAULT_APP_NAME:
104106
raise ValueError(
@@ -111,7 +113,7 @@ def delete_app(app):
111113
'second argument.').format(app.name))
112114

113115

114-
def get_app(name=_DEFAULT_APP_NAME):
116+
def get_app(name: str = _DEFAULT_APP_NAME) -> "App":
115117
"""Retrieves an App instance by name.
116118
117119
Args:
@@ -190,7 +192,7 @@ class App:
190192
common to all Firebase APIs.
191193
"""
192194

193-
def __init__(self, name, credential, options):
195+
def __init__(self, name: str, credential: credentials.Base, options: Optional[Dict[str, Any]]):
194196
"""Constructs a new App using the provided name and options.
195197
196198
Args:
@@ -265,7 +267,7 @@ def _lookup_project_id(self):
265267
App._validate_project_id(self._options.get('projectId'))
266268
return project_id
267269

268-
def _get_service(self, name, initializer):
270+
def _get_service(self, name: str, initializer: Callable):
269271
"""Returns the service instance identified by the given name.
270272
271273
Services are functional entities exposed by the Admin SDK (e.g. auth, database). Each

firebase_admin/_utils.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Internal utilities common to all modules."""
1616

1717
import json
18+
from typing import Callable, Optional
1819

1920
import google.auth
2021
import requests
@@ -76,7 +77,7 @@
7677
}
7778

7879

79-
def _get_initialized_app(app):
80+
def _get_initialized_app(app: Optional[firebase_admin.App]):
8081
"""Returns a reference to an initialized App instance."""
8182
if app is None:
8283
return firebase_admin.get_app()
@@ -92,10 +93,9 @@ def _get_initialized_app(app):
9293
' firebase_admin.App, but given "{0}".'.format(type(app)))
9394

9495

95-
96-
def get_app_service(app, name, initializer):
96+
def get_app_service(app: Optional[firebase_admin.App], name: str, initializer: Callable):
9797
app = _get_initialized_app(app)
98-
return app._get_service(name, initializer) # pylint: disable=protected-access
98+
return app._get_service(name, initializer) # pylint: disable=protected-access
9999

100100

101101
def handle_platform_error_from_requests(error, handle_func=None):

firebase_admin/credentials.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
import collections
1717
import json
1818
import pathlib
19+
from typing import Any, Dict, Union
1920

2021
import google.auth
2122
from google.auth.transport import requests
2223
from google.oauth2 import credentials
2324
from google.oauth2 import service_account
25+
import google.auth.credentials
2426

2527

2628
_request = requests.Request()
@@ -44,7 +46,7 @@
4446
class Base:
4547
"""Provides OAuth2 access tokens for accessing Firebase services."""
4648

47-
def get_access_token(self):
49+
def get_access_token(self) -> AccessTokenInfo:
4850
"""Fetches a Google OAuth2 access token using this credential instance.
4951
5052
Returns:
@@ -54,7 +56,7 @@ def get_access_token(self):
5456
google_cred.refresh(_request)
5557
return AccessTokenInfo(google_cred.token, google_cred.expiry)
5658

57-
def get_credential(self):
59+
def get_credential(self) -> google.auth.credentials.Credentials:
5860
"""Returns the Google credential instance used for authentication."""
5961
raise NotImplementedError
6062

@@ -64,7 +66,7 @@ class Certificate(Base):
6466

6567
_CREDENTIAL_TYPE = 'service_account'
6668

67-
def __init__(self, cert):
69+
def __init__(self, cert: Union[str, Dict[str, Any]]):
6870
"""Initializes a credential from a Google service account certificate.
6971
7072
Service account certificates can be downloaded as JSON files from the Firebase console.
@@ -158,6 +160,7 @@ def _load_credential(self):
158160
if not self._g_credential:
159161
self._g_credential, self._project_id = google.auth.default(scopes=_scopes)
160162

163+
161164
class RefreshToken(Base):
162165
"""A credential initialized from an existing refresh token."""
163166

firebase_admin/firestore.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"""
2020

2121
try:
22-
from google.cloud import firestore # pylint: disable=import-error,no-name-in-module
22+
from google.cloud import firestore # pylint: disable=import-error,no-name-in-module
2323
existing = globals().keys()
2424
for key, value in firestore.__dict__.items():
2525
if not key.startswith('_') and key not in existing:
@@ -28,13 +28,15 @@
2828
raise ImportError('Failed to import the Cloud Firestore library for Python. Make sure '
2929
'to install the "google-cloud-firestore" module.')
3030

31-
from firebase_admin import _utils
31+
from firebase_admin import _utils, App
32+
import google.auth.credentials
33+
from typing import Optional
3234

3335

3436
_FIRESTORE_ATTRIBUTE = '_firestore'
3537

3638

37-
def client(app=None):
39+
def client(app: Optional[App] = None) -> firestore.Client:
3840
"""Returns a client that can be used to interact with Google Cloud Firestore.
3941
4042
Args:
@@ -57,14 +59,14 @@ def client(app=None):
5759
class _FirestoreClient:
5860
"""Holds a Google Cloud Firestore client instance."""
5961

60-
def __init__(self, credentials, project):
62+
def __init__(self, credentials: google.auth.credentials.Credentials, project: str):
6163
self._client = firestore.Client(credentials=credentials, project=project)
6264

63-
def get(self):
65+
def get(self) -> firestore.Client:
6466
return self._client
6567

6668
@classmethod
67-
def from_app(cls, app):
69+
def from_app(cls, app: App):
6870
"""Creates a new _FirestoreClient for the specified app."""
6971
credentials = app.credential.get_credential()
7072
project = app.project_id

firebase_admin/messaging.py

+5
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
def _get_messaging_service(app):
9696
return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingService)
9797

98+
9899
def send(message, dry_run=False, app=None):
99100
"""Sends the given message via Firebase Cloud Messaging (FCM).
100101
@@ -115,6 +116,7 @@ def send(message, dry_run=False, app=None):
115116
"""
116117
return _get_messaging_service(app).send(message, dry_run)
117118

119+
118120
def send_all(messages, dry_run=False, app=None):
119121
"""Sends the given list of messages via Firebase Cloud Messaging as a single batch.
120122
@@ -135,6 +137,7 @@ def send_all(messages, dry_run=False, app=None):
135137
"""
136138
return _get_messaging_service(app).send_all(messages, dry_run)
137139

140+
138141
def send_multicast(multicast_message, dry_run=False, app=None):
139142
"""Sends the given mutlicast message to all tokens via Firebase Cloud Messaging (FCM).
140143
@@ -166,6 +169,7 @@ def send_multicast(multicast_message, dry_run=False, app=None):
166169
) for token in multicast_message.tokens]
167170
return _get_messaging_service(app).send_all(messages, dry_run)
168171

172+
169173
def subscribe_to_topic(tokens, topic, app=None):
170174
"""Subscribes a list of registration tokens to an FCM topic.
171175
@@ -185,6 +189,7 @@ def subscribe_to_topic(tokens, topic, app=None):
185189
return _get_messaging_service(app).make_topic_management_request(
186190
tokens, topic, 'iid/v1:batchAdd')
187191

192+
188193
def unsubscribe_from_topic(tokens, topic, app=None):
189194
"""Unsubscribes a list of registration tokens from an FCM topic.
190195

firebase_admin/ml.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -211,13 +211,13 @@ def from_dict(cls, data, app=None):
211211
tflite_format = TFLiteFormat.from_dict(tflite_format_data)
212212
model = Model(model_format=tflite_format)
213213
model._data = data_copy # pylint: disable=protected-access
214-
model._app = app # pylint: disable=protected-access
214+
model._app = app # pylint: disable=protected-access
215215
return model
216216

217217
def _update_from_dict(self, data):
218218
copy = Model.from_dict(data)
219219
self.model_format = copy.model_format
220-
self._data = copy._data # pylint: disable=protected-access
220+
self._data = copy._data # pylint: disable=protected-access
221221

222222
def __eq__(self, other):
223223
if isinstance(other, self.__class__):
@@ -334,7 +334,7 @@ def model_format(self):
334334
def model_format(self, model_format):
335335
if model_format is not None:
336336
_validate_model_format(model_format)
337-
self._model_format = model_format #Can be None
337+
self._model_format = model_format # Can be None
338338
return self
339339

340340
def as_dict(self, for_upload=False):
@@ -370,7 +370,7 @@ def from_dict(cls, data):
370370
"""Create an instance of the object from a dict."""
371371
data_copy = dict(data)
372372
tflite_format = TFLiteFormat(model_source=cls._init_model_source(data_copy))
373-
tflite_format._data = data_copy # pylint: disable=protected-access
373+
tflite_format._data = data_copy # pylint: disable=protected-access
374374
return tflite_format
375375

376376
def __eq__(self, other):
@@ -405,7 +405,7 @@ def model_source(self, model_source):
405405
if model_source is not None:
406406
if not isinstance(model_source, TFLiteModelSource):
407407
raise TypeError('Model source must be a TFLiteModelSource object.')
408-
self._model_source = model_source # Can be None
408+
self._model_source = model_source # Can be None
409409

410410
@property
411411
def size_bytes(self):
@@ -485,7 +485,7 @@ def __init__(self, gcs_tflite_uri, app=None):
485485

486486
def __eq__(self, other):
487487
if isinstance(other, self.__class__):
488-
return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access
488+
return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access
489489
return False
490490

491491
def __ne__(self, other):
@@ -775,7 +775,7 @@ def _validate_display_name(display_name):
775775

776776
def _validate_tags(tags):
777777
if not isinstance(tags, list) or not \
778-
all(isinstance(tag, str) for tag in tags):
778+
all(isinstance(tag, str) for tag in tags):
779779
raise TypeError('Tags must be a list of strings.')
780780
if not all(_TAG_PATTERN.match(tag) for tag in tags):
781781
raise ValueError('Tag format is invalid.')
@@ -789,6 +789,7 @@ def _validate_gcs_tflite_uri(uri):
789789
raise ValueError('GCS TFLite URI format is invalid.')
790790
return uri
791791

792+
792793
def _validate_auto_ml_model(model):
793794
if not _AUTO_ML_MODEL_PATTERN.match(model):
794795
raise ValueError('Model resource name format is invalid.')
@@ -809,7 +810,7 @@ def _validate_list_filter(list_filter):
809810

810811
def _validate_page_size(page_size):
811812
if page_size is not None:
812-
if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck
813+
if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck
813814
# Specifically type() to disallow boolean which is a subtype of int
814815
raise TypeError('Page size must be a number or None.')
815816
if page_size < 1 or page_size > _MAX_PAGE_SIZE:
@@ -864,7 +865,7 @@ def _exponential_backoff(self, current_attempt, stop_time):
864865

865866
if stop_time is not None:
866867
max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds()
867-
if max_seconds_left < 1: # allow a bit of time for rpc
868+
if max_seconds_left < 1: # allow a bit of time for rpc
868869
raise exceptions.DeadlineExceededError('Polling max time exceeded.')
869870
wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1)
870871
time.sleep(wait_time_seconds)
@@ -925,7 +926,6 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds
925926
# If the operation is not complete or timed out, return a (locked) model instead
926927
return get_model(model_id).as_dict()
927928

928-
929929
def create_model(self, model):
930930
_validate_model(model)
931931
try:

firebase_admin/storage.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@
2525
raise ImportError('Failed to import the Cloud Storage library for Python. Make sure '
2626
'to install the "google-cloud-storage" module.')
2727

28-
from firebase_admin import _utils
28+
from firebase_admin import _utils, App
29+
from typing import Optional
2930

3031

3132
_STORAGE_ATTRIBUTE = '_storage'
3233

33-
def bucket(name=None, app=None) -> storage.Bucket:
34+
35+
def bucket(name: Optional[str] = None, app: Optional[App] = None) -> storage.Bucket:
3436
"""Returns a handle to a Google Cloud Storage bucket.
3537
3638
If the name argument is not provided, uses the 'storageBucket' option specified when
@@ -67,7 +69,7 @@ def from_app(cls, app):
6769
# significantly speeds up the initialization of the storage client.
6870
return _StorageClient(credentials, app.project_id, default_bucket)
6971

70-
def bucket(self, name=None):
72+
def bucket(self, name: Optional[str] = None):
7173
"""Returns a handle to the specified Cloud Storage Bucket."""
7274
bucket_name = name if name is not None else self._default_bucket
7375
if bucket_name is None:

firebase_admin/tenant_mgt.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def list_tenants(page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS, app=Non
183183
FirebaseError: If an error occurs while retrieving the user accounts.
184184
"""
185185
tenant_mgt_service = _get_tenant_mgt_service(app)
186+
186187
def download(page_token, max_results):
187188
return tenant_mgt_service.list_tenants(page_token, max_results)
188189
return ListTenantsPage(download, page_token, max_results)
@@ -206,7 +207,7 @@ class Tenant:
206207
def __init__(self, data):
207208
if not isinstance(data, dict):
208209
raise ValueError('Invalid data argument in Tenant constructor: {0}'.format(data))
209-
if not 'name' in data:
210+
if 'name' not in data:
210211
raise ValueError('Tenant response missing required keys.')
211212

212213
self._data = data
@@ -256,7 +257,7 @@ def auth_for_tenant(self, tenant_id):
256257

257258
client = auth.Client(self.app, tenant_id=tenant_id)
258259
self.tenant_clients[tenant_id] = client
259-
return client
260+
return client
260261

261262
def get_tenant(self, tenant_id):
262263
"""Gets the tenant corresponding to the given ``tenant_id``."""

0 commit comments

Comments
 (0)