From 777c5d3e64b4a9032255f5579f60feb94b090e82 Mon Sep 17 00:00:00 2001 From: exaby73 Date: Mon, 8 Apr 2024 05:09:57 +0530 Subject: [PATCH 1/6] feat: add core.init function, and add test --- src/firebase_functions/core.py | 47 ++++++++++++++++++++++++++++++++++ tests/test_init.py | 12 +++++++++ 2 files changed, 59 insertions(+) create mode 100644 tests/test_init.py diff --git a/src/firebase_functions/core.py b/src/firebase_functions/core.py index 26a3582..a71da0b 100644 --- a/src/firebase_functions/core.py +++ b/src/firebase_functions/core.py @@ -18,6 +18,8 @@ import datetime as _datetime import typing as _typing +from . import logger as _logger + T = _typing.TypeVar("T") @@ -80,3 +82,48 @@ class Change(_typing.Generic[T]): """ The state of data after the change. """ + + +_didInit = False +_initCallback: _typing.Callable[[], _typing.Any] | None = None + + +def init(callback: _typing.Callable[[], _typing.Any]) -> None: + """ + Registers a function that should be run when in a production environment + before executing any functions code. + Calling this decorator more than once leads to undefined behavior. + """ + + global _didInit + global _initCallback + + if _didInit: + raise ValueError("Firebase Functions SDK already initialized") + + _initCallback = callback + + if _didInit: + _logger.warn("Setting init callback more than once. Only the most recent callback will be called") + + _initCallback = callback + _didInit = False + + +def _with_init(fn: _typing.Callable[..., _typing.Any]) -> _typing.Callable[..., _typing.Any]: + """ + A decorator that runs the init callback before running the decorated function. + """ + + def wrapper(*args, **kwargs): + global _didInit + global _initCallback + + if not _didInit: + if _initCallback is not None: + _initCallback() + _didInit = True + + return fn(*args, **kwargs) + + return wrapper diff --git a/tests/test_init.py b/tests/test_init.py new file mode 100644 index 0000000..3508327 --- /dev/null +++ b/tests/test_init.py @@ -0,0 +1,12 @@ +import unittest +from firebase_functions import core + + +class TestInit(unittest.TestCase): + def test_init_is_initialized(self): + @core.init + def fn(): + pass + + self.assertIsNotNone(core._initCallback) + self.assertFalse(core._didInit) From 3558e0098b96c26eabbe4f5c9c0c4f806fa77314 Mon Sep 17 00:00:00 2001 From: exaby73 Date: Mon, 8 Apr 2024 05:22:10 +0530 Subject: [PATCH 2/6] feat: add _with_init to all triggers --- src/firebase_functions/alerts_fn.py | 4 ++-- src/firebase_functions/db_fn.py | 2 +- src/firebase_functions/eventarc_fn.py | 4 ++-- src/firebase_functions/firestore_fn.py | 2 +- src/firebase_functions/https_fn.py | 4 ++-- src/firebase_functions/private/_identity_fn.py | 4 +++- src/firebase_functions/pubsub_fn.py | 4 ++-- src/firebase_functions/remote_config_fn.py | 4 ++-- src/firebase_functions/scheduler_fn.py | 3 ++- src/firebase_functions/storage_fn.py | 4 ++-- src/firebase_functions/test_lab_fn.py | 4 ++-- 11 files changed, 21 insertions(+), 18 deletions(-) diff --git a/src/firebase_functions/alerts_fn.py b/src/firebase_functions/alerts_fn.py index 9fc09e0..ed68673 100644 --- a/src/firebase_functions/alerts_fn.py +++ b/src/firebase_functions/alerts_fn.py @@ -24,7 +24,7 @@ import firebase_functions.private.util as _util -from firebase_functions.core import T, CloudEvent as _CloudEvent +from firebase_functions.core import T, CloudEvent as _CloudEvent, _with_init from firebase_functions.options import FirebaseAlertOptions # Explicitly import AlertType to make it available in the public API. @@ -95,7 +95,7 @@ def on_alert_published_inner_decorator(func: OnAlertPublishedCallable): @_functools.wraps(func) def on_alert_published_wrapped(raw: _ce.CloudEvent): from firebase_functions.private._alerts_fn import alerts_event_from_ce - func(alerts_event_from_ce(raw)) + _with_init(func)(alerts_event_from_ce(raw)) _util.set_func_endpoint_attr( on_alert_published_wrapped, diff --git a/src/firebase_functions/db_fn.py b/src/firebase_functions/db_fn.py index 1681862..7298e99 100644 --- a/src/firebase_functions/db_fn.py +++ b/src/firebase_functions/db_fn.py @@ -119,7 +119,7 @@ def _db_endpoint_handler( subject=event_attributes["subject"], params=params, ) - func(database_event) + _core._with_init(func)(database_event) @_util.copy_func_kwargs(DatabaseOptions) diff --git a/src/firebase_functions/eventarc_fn.py b/src/firebase_functions/eventarc_fn.py index d3b62f3..d76772c 100644 --- a/src/firebase_functions/eventarc_fn.py +++ b/src/firebase_functions/eventarc_fn.py @@ -21,7 +21,7 @@ import firebase_functions.options as _options import firebase_functions.private.util as _util -from firebase_functions.core import CloudEvent +from firebase_functions.core import CloudEvent, _with_init @_util.copy_func_kwargs(_options.EventarcTriggerOptions) @@ -73,7 +73,7 @@ def on_custom_event_published_wrapped(raw: _ce.CloudEvent): ), type=event_dict["type"], ) - func(event) + _with_init(func)(event) _util.set_func_endpoint_attr( on_custom_event_published_wrapped, diff --git a/src/firebase_functions/firestore_fn.py b/src/firebase_functions/firestore_fn.py index 4b66f15..2fe1ed6 100644 --- a/src/firebase_functions/firestore_fn.py +++ b/src/firebase_functions/firestore_fn.py @@ -173,7 +173,7 @@ def _firestore_endpoint_handler( subject=event_attributes["subject"], params=params, ) - func(database_event) + _core._with_init(func)(database_event) @_util.copy_func_kwargs(FirestoreOptions) diff --git a/src/firebase_functions/https_fn.py b/src/firebase_functions/https_fn.py index fed3973..61ff13f 100644 --- a/src/firebase_functions/https_fn.py +++ b/src/firebase_functions/https_fn.py @@ -403,7 +403,7 @@ def _on_call_handler(func: _C2, instance_id_token=request.headers.get( "Firebase-Instance-ID-Token"), ) - result = func(context) + result = _core._with_init(func)(context) return _jsonify(result=result) # Disable broad exceptions lint since we want to handle all exceptions here # and wrap as an HttpsError. @@ -447,7 +447,7 @@ def on_request_wrapped(request: Request) -> Response: methods=options.cors.cors_methods, origins=options.cors.cors_origins, )(func)(request) - return func(request) + return _core._with_init(func)(request) _util.set_func_endpoint_attr( on_request_wrapped, diff --git a/src/firebase_functions/private/_identity_fn.py b/src/firebase_functions/private/_identity_fn.py index 89706a5..b066192 100644 --- a/src/firebase_functions/private/_identity_fn.py +++ b/src/firebase_functions/private/_identity_fn.py @@ -18,6 +18,8 @@ import datetime as _dt import time as _time import json as _json + +from firebase_functions.core import _with_init from firebase_functions.https_fn import HttpsError, FunctionsErrorCode import firebase_functions.private.util as _util @@ -351,7 +353,7 @@ def before_operation_handler( jwt_token = request.json["data"]["jwt"] decoded_token = _token_verifier.verify_auth_blocking_token(jwt_token) event = _auth_blocking_event_from_token_data(decoded_token) - auth_response: BeforeCreateResponse | BeforeSignInResponse | None = func( + auth_response: BeforeCreateResponse | BeforeSignInResponse | None = _with_init(func)( event) if not auth_response: return _jsonify({}) diff --git a/src/firebase_functions/pubsub_fn.py b/src/firebase_functions/pubsub_fn.py index 6d833d3..4647a1b 100644 --- a/src/firebase_functions/pubsub_fn.py +++ b/src/firebase_functions/pubsub_fn.py @@ -25,7 +25,7 @@ import firebase_functions.private.util as _util -from firebase_functions.core import CloudEvent, T +from firebase_functions.core import CloudEvent, T, _with_init from firebase_functions.options import PubSubOptions @@ -151,7 +151,7 @@ def _message_handler( type=event_dict["type"], ) - func(event) + _with_init(func)(event) @_util.copy_func_kwargs(PubSubOptions) diff --git a/src/firebase_functions/remote_config_fn.py b/src/firebase_functions/remote_config_fn.py index 402fd98..c48436d 100644 --- a/src/firebase_functions/remote_config_fn.py +++ b/src/firebase_functions/remote_config_fn.py @@ -24,7 +24,7 @@ import firebase_functions.private.util as _util -from firebase_functions.core import CloudEvent +from firebase_functions.core import CloudEvent, _with_init from firebase_functions.options import EventHandlerOptions @@ -189,7 +189,7 @@ def _config_handler(func: _C1, raw: _ce.CloudEvent) -> None: type=event_dict["type"], ) - func(event) + _with_init(func)(event) @_util.copy_func_kwargs(EventHandlerOptions) diff --git a/src/firebase_functions/scheduler_fn.py b/src/firebase_functions/scheduler_fn.py index 4d3bcd7..c5a92c9 100644 --- a/src/firebase_functions/scheduler_fn.py +++ b/src/firebase_functions/scheduler_fn.py @@ -27,6 +27,7 @@ make_response as _make_response, ) +from firebase_functions.core import _with_init # Export for user convenience. # pylint: disable=unused-import from firebase_functions.options import Timezone @@ -108,7 +109,7 @@ def on_schedule_wrapped(request: _Request) -> _Response: schedule_time=schedule_time, ) try: - func(event) + _with_init(func)(event) return _make_response() # Disable broad exceptions lint since we want to handle all exceptions. # pylint: disable=broad-except diff --git a/src/firebase_functions/storage_fn.py b/src/firebase_functions/storage_fn.py index fec61b4..342d257 100644 --- a/src/firebase_functions/storage_fn.py +++ b/src/firebase_functions/storage_fn.py @@ -22,7 +22,7 @@ import cloudevents.http as _ce import firebase_functions.private.util as _util -from firebase_functions.core import CloudEvent +from firebase_functions.core import CloudEvent, _with_init from firebase_functions.options import StorageOptions _event_type_archived = "google.cloud.storage.object.v1.archived" @@ -255,7 +255,7 @@ def _message_handler( type=event_attributes["type"], ) - func(event) + _with_init(func)(event) @_util.copy_func_kwargs(StorageOptions) diff --git a/src/firebase_functions/test_lab_fn.py b/src/firebase_functions/test_lab_fn.py index 15eb5ab..7aede95 100644 --- a/src/firebase_functions/test_lab_fn.py +++ b/src/firebase_functions/test_lab_fn.py @@ -24,7 +24,7 @@ import firebase_functions.private.util as _util -from firebase_functions.core import CloudEvent +from firebase_functions.core import CloudEvent, _with_init from firebase_functions.options import EventHandlerOptions @@ -246,7 +246,7 @@ def _event_handler(func: _C1, raw: _ce.CloudEvent) -> None: type=event_dict["type"], ) - func(event) + _with_init(func)(event) @_util.copy_func_kwargs(EventHandlerOptions) From 7ba735a799f91d75e42846042684b18cc959f6c4 Mon Sep 17 00:00:00 2001 From: exaby73 Date: Wed, 10 Apr 2024 07:11:20 +0530 Subject: [PATCH 3/6] fix: remove error raised in init --- src/firebase_functions/core.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/firebase_functions/core.py b/src/firebase_functions/core.py index a71da0b..425550f 100644 --- a/src/firebase_functions/core.py +++ b/src/firebase_functions/core.py @@ -98,9 +98,6 @@ def init(callback: _typing.Callable[[], _typing.Any]) -> None: global _didInit global _initCallback - if _didInit: - raise ValueError("Firebase Functions SDK already initialized") - _initCallback = callback if _didInit: From 7ab6f1a607f3d8d9e2584248eb37b237d6c0b126 Mon Sep 17 00:00:00 2001 From: exaby73 Date: Wed, 10 Apr 2024 09:33:21 +0530 Subject: [PATCH 4/6] feat: add tests --- src/firebase_functions/core.py | 35 ++++----- src/firebase_functions/firestore_fn.py | 11 ++- .../private/_identity_fn.py | 7 +- tests/test_db.py | 43 +++++++++++ tests/test_eventarc_fn.py | 34 +++++++++ tests/test_firestore_fn.py | 51 +++++++++++++ tests/test_https_fn.py | 72 +++++++++++++++++++ tests/test_identity_fn.py | 64 +++++++++++++++++ tests/test_init.py | 15 +++- tests/test_pubsub_fn.py | 36 ++++++++++ tests/test_scheduler_fn.py | 20 +++++- tests/test_storage_fn.py | 43 +++++++++++ tests/test_tasks_fn.py | 31 +++++++- tests/test_test_lab_fn.py | 48 ++++++++++++- 14 files changed, 479 insertions(+), 31 deletions(-) create mode 100644 tests/test_db.py create mode 100644 tests/test_https_fn.py create mode 100644 tests/test_identity_fn.py create mode 100644 tests/test_storage_fn.py diff --git a/src/firebase_functions/core.py b/src/firebase_functions/core.py index 425550f..491e467 100644 --- a/src/firebase_functions/core.py +++ b/src/firebase_functions/core.py @@ -84,8 +84,8 @@ class Change(_typing.Generic[T]): """ -_didInit = False -_initCallback: _typing.Callable[[], _typing.Any] | None = None +_did_init = False +_init_callback: _typing.Callable[[], _typing.Any] | None = None def init(callback: _typing.Callable[[], _typing.Any]) -> None: @@ -95,31 +95,34 @@ def init(callback: _typing.Callable[[], _typing.Any]) -> None: Calling this decorator more than once leads to undefined behavior. """ - global _didInit - global _initCallback + global _did_init + global _init_callback - _initCallback = callback + _init_callback = callback - if _didInit: - _logger.warn("Setting init callback more than once. Only the most recent callback will be called") + if _did_init: + _logger.warn( + "Setting init callback more than once. Only the most recent callback will be called" + ) - _initCallback = callback - _didInit = False + _init_callback = callback + _did_init = False -def _with_init(fn: _typing.Callable[..., _typing.Any]) -> _typing.Callable[..., _typing.Any]: +def _with_init( + fn: _typing.Callable[..., + _typing.Any]) -> _typing.Callable[..., _typing.Any]: """ A decorator that runs the init callback before running the decorated function. """ def wrapper(*args, **kwargs): - global _didInit - global _initCallback + global _did_init - if not _didInit: - if _initCallback is not None: - _initCallback() - _didInit = True + if not _did_init: + if _init_callback is not None: + _init_callback() + _did_init = True return fn(*args, **kwargs) diff --git a/src/firebase_functions/firestore_fn.py b/src/firebase_functions/firestore_fn.py index e517423..11156b7 100644 --- a/src/firebase_functions/firestore_fn.py +++ b/src/firebase_functions/firestore_fn.py @@ -145,8 +145,7 @@ def _firestore_endpoint_handler( app = get_app() firestore_client = _firestore_v1.Client(project=app.project_id, database=event_database) - firestore_ref: DocumentReference = firestore_client.document( - event_document) + firestore_ref: DocumentReference = firestore_client.document(event_document) value_snapshot: DocumentSnapshot | None = None old_value_snapshot: DocumentSnapshot | None = None @@ -268,7 +267,7 @@ def on_document_written_wrapped(raw: _ce.CloudEvent): @_util.copy_func_kwargs(FirestoreOptions) def on_document_written_with_auth_context(**kwargs - ) -> _typing.Callable[[_C1], _C1]: + ) -> _typing.Callable[[_C1], _C1]: """ Event handler that triggers when a document is created, updated, or deleted in Firestore. This trigger will also provide the authentication context of the principal who triggered @@ -367,7 +366,7 @@ def on_document_updated_wrapped(raw: _ce.CloudEvent): @_util.copy_func_kwargs(FirestoreOptions) def on_document_updated_with_auth_context(**kwargs - ) -> _typing.Callable[[_C1], _C1]: + ) -> _typing.Callable[[_C1], _C1]: """ Event handler that triggers when a document is updated in Firestore. This trigger will also provide the authentication context of the principal who triggered @@ -466,7 +465,7 @@ def on_document_created_wrapped(raw: _ce.CloudEvent): @_util.copy_func_kwargs(FirestoreOptions) def on_document_created_with_auth_context(**kwargs - ) -> _typing.Callable[[_C2], _C2]: + ) -> _typing.Callable[[_C2], _C2]: """ Event handler that triggers when a document is created in Firestore. This trigger will also provide the authentication context of the principal who triggered @@ -565,7 +564,7 @@ def on_document_deleted_wrapped(raw: _ce.CloudEvent): @_util.copy_func_kwargs(FirestoreOptions) def on_document_deleted_with_auth_context(**kwargs - ) -> _typing.Callable[[_C2], _C2]: + ) -> _typing.Callable[[_C2], _C2]: """ Event handler that triggers when a document is deleted in Firestore. This trigger will also provide the authentication context of the principal who triggered diff --git a/src/firebase_functions/private/_identity_fn.py b/src/firebase_functions/private/_identity_fn.py index b066192..2a8f516 100644 --- a/src/firebase_functions/private/_identity_fn.py +++ b/src/firebase_functions/private/_identity_fn.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Cloud functions to handle Eventarc events.""" - # pylint: disable=protected-access import typing as _typing import datetime as _dt @@ -353,8 +352,8 @@ def before_operation_handler( jwt_token = request.json["data"]["jwt"] decoded_token = _token_verifier.verify_auth_blocking_token(jwt_token) event = _auth_blocking_event_from_token_data(decoded_token) - auth_response: BeforeCreateResponse | BeforeSignInResponse | None = _with_init(func)( - event) + auth_response: BeforeCreateResponse | BeforeSignInResponse | None = _with_init( + func)(event) if not auth_response: return _jsonify({}) auth_response_dict = _validate_auth_response(event_type, auth_response) @@ -364,7 +363,7 @@ def before_operation_handler( # pylint: disable=broad-except except Exception as exception: if not isinstance(exception, HttpsError): - _logging.error("Unhandled error", exception) + _logging.error("Unhandled error %s", exception) exception = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL") status = exception._http_error_code.status return _make_response(_jsonify(error=exception._as_dict()), status) diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 0000000..4e8b487 --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,43 @@ +""" +Tests for the db module. +""" + +import unittest +from unittest import mock +from cloudevents.http import CloudEvent +from firebase_functions import core, db_fn + + +class TestDb(unittest.TestCase): + """ + Tests for the db module. + """ + + def test_calls_init_function(self): + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + func = mock.Mock(__name__="example_func") + decorated_func = db_fn.on_value_created(reference="path")(func) + + event = CloudEvent(attributes={ + "specversion": "1.0", + "id": "id", + "source": "source", + "subject": "subject", + "type": "type", + "time": "2024-04-10T12:00:00.000Z", + "instance": "instance", + "ref": "ref", + "firebasedatabasehost": "firebasedatabasehost", + "location": "location", + }, + data={"delta": "delta"}) + + decorated_func(event) + + self.assertEqual(hello, "world") diff --git a/tests/test_eventarc_fn.py b/tests/test_eventarc_fn.py index 882a6d1..730812a 100644 --- a/tests/test_eventarc_fn.py +++ b/tests/test_eventarc_fn.py @@ -14,7 +14,10 @@ """Eventarc trigger function tests.""" import unittest from unittest.mock import Mock + from cloudevents.http import CloudEvent as _CloudEvent + +from firebase_functions import core from firebase_functions.core import CloudEvent from firebase_functions.eventarc_fn import on_custom_event_published @@ -83,3 +86,34 @@ def test_on_custom_event_published_wrapped(self): event_arg.type, "firebase.extensions.storage-resize-images.v1.complete", ) + + def test_calls_init_function(self): + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + func = Mock(__name__="example_func") + raw_event = _CloudEvent( + attributes={ + "specversion": "1.0", + "type": "firebase.extensions.storage-resize-images.v1.complete", + "source": "https://example.com/testevent", + "id": "1234567890", + "subject": "test_subject", + "time": "2023-03-11T13:25:37.403Z", + }, + data={ + "some_key": "some_value", + }, + ) + + decorated_func = on_custom_event_published( + event_type="firebase.extensions.storage-resize-images.v1.complete", + )(func) + + decorated_func(raw_event) + + self.assertEqual(hello, "world") diff --git a/tests/test_firestore_fn.py b/tests/test_firestore_fn.py index d6a65a5..fa4ed15 100644 --- a/tests/test_firestore_fn.py +++ b/tests/test_firestore_fn.py @@ -70,3 +70,54 @@ def test_firestore_endpoint_handler_calls_function_with_correct_args(self): self.assertIsInstance(event, AuthEvent) self.assertEqual(event.auth_type, "unauthenticated") self.assertEqual(event.auth_id, "foo") + + def test_calls_init_function(self): + with patch.dict("sys.modules", mocked_modules): + from firebase_functions import firestore_fn, core + from cloudevents.http import CloudEvent + + func = Mock(__name__="example_func") + + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + attributes = { + "specversion": + "1.0", + # pylint: disable=protected-access + "type": + firestore_fn._event_type_created, + "source": + "https://example.com/testevent", + "time": + "2023-03-11T13:25:37.403Z", + "subject": + "test_subject", + "datacontenttype": + "application/json", + "location": + "projects/project-id/databases/(default)/documents/foo/{bar}", + "project": + "project-id", + "namespace": + "(default)", + "document": + "foo/{bar}", + "database": + "projects/project-id/databases/(default)", + "authtype": + "unauthenticated", + "authid": + "foo" + } + raw_event = CloudEvent(attributes=attributes, data=json.dumps({})) + decorated_func = firestore_fn.on_document_created( + document="/foo/{bar}")(func) + + decorated_func(raw_event) + + self.assertEqual(hello, "world") diff --git a/tests/test_https_fn.py b/tests/test_https_fn.py new file mode 100644 index 0000000..e128b39 --- /dev/null +++ b/tests/test_https_fn.py @@ -0,0 +1,72 @@ +""" +Tests for the https module. +""" + +import unittest +from unittest.mock import Mock +from flask import Flask, Request +from werkzeug.test import EnvironBuilder + +from firebase_functions import core, https_fn + + +class TestHttps(unittest.TestCase): + """ + Tests for the http module. + """ + + def test_on_request_calls_init_function(self): + app = Flask(__name__) + + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + func = Mock(__name__="example_func") + + with app.test_request_context("/"): + environ = EnvironBuilder( + method="POST", + json={ + "data": { + "test": "value" + }, + }, + ).get_environ() + request = Request(environ) + decorated_func = https_fn.on_request()(func) + + decorated_func(request) + + self.assertEqual(hello, "world") + + def test_on_call_calls_init_function(self): + app = Flask(__name__) + + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + func = Mock(__name__="example_func") + + with app.test_request_context("/"): + environ = EnvironBuilder( + method="POST", + json={ + "data": { + "test": "value" + }, + }, + ).get_environ() + request = Request(environ) + decorated_func = https_fn.on_call()(func) + + decorated_func(request) + + self.assertEqual("world", hello) diff --git a/tests/test_identity_fn.py b/tests/test_identity_fn.py new file mode 100644 index 0000000..c2f1b92 --- /dev/null +++ b/tests/test_identity_fn.py @@ -0,0 +1,64 @@ +""" +Identity function tests. +""" + +import unittest +from unittest.mock import Mock, patch, MagicMock +from flask import Flask, Request +from werkzeug.test import EnvironBuilder + +from firebase_functions import core, identity_fn + +token_verifier_mock = MagicMock() +token_verifier_mock.verify_auth_blocking_token = Mock( + return_value={ + "user_record": { + "uid": "uid", + "metadata": { + "creation_time": 0 + }, + "provider_data": [] + }, + "event_id": "event_id", + "ip_address": "ip_address", + "user_agent": "user_agent", + "iat": 0 + }) +mocked_modules = { + "firebase_functions.private.token_verifier": token_verifier_mock, +} + + +class TestIdentity(unittest.TestCase): + """ + Identity function tests. + """ + + def test_calls_init_function(self): + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + with patch.dict("sys.modules", mocked_modules): + app = Flask(__name__) + + func = Mock(__name__="example_func", + return_value=identity_fn.BeforeSignInResponse()) + + with app.test_request_context("/"): + environ = EnvironBuilder( + method="POST", + json={ + "data": { + "jwt": "jwt" + }, + }, + ).get_environ() + request = Request(environ) + decorated_func = identity_fn.before_user_signed_in()(func) + print(decorated_func(request)) + + self.assertEqual("world", hello) diff --git a/tests/test_init.py b/tests/test_init.py index 3508327..07ee924 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -1,12 +1,23 @@ +""" +Test the init decorator. +""" + import unittest from firebase_functions import core class TestInit(unittest.TestCase): + """ + Test the init decorator. + """ + def test_init_is_initialized(self): + @core.init def fn(): pass - self.assertIsNotNone(core._initCallback) - self.assertFalse(core._didInit) + # pylint: disable=protected-access + self.assertIsNotNone(core._init_callback) + # pylint: disable=protected-access + self.assertFalse(core._did_init) diff --git a/tests/test_pubsub_fn.py b/tests/test_pubsub_fn.py index 06777d8..9abe007 100644 --- a/tests/test_pubsub_fn.py +++ b/tests/test_pubsub_fn.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock from cloudevents.http import CloudEvent as _CloudEvent +from firebase_functions import core from firebase_functions.pubsub_fn import ( Message, MessagePublishedData, @@ -94,3 +95,38 @@ def test_message_handler(self): "eyJ0ZXN0IjogInZhbHVlIn0=") self.assertIsNone(event_arg.data.message.ordering_key) self.assertEqual(event_arg.data.subscription, "my-subscription") + + def test_calls_init(self): + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + func = MagicMock() + raw_event = _CloudEvent( + attributes={ + "id": "test-message", + "source": "https://example.com/pubsub", + "specversion": "1.0", + "time": "2023-03-11T13:25:37.403Z", + "type": "com.example.pubsub.message", + }, + data={ + "message": { + "attributes": { + "key": "value" + }, + # {"test": "value"} + "data": "eyJ0ZXN0IjogInZhbHVlIn0=", + "message_id": "message-id-123", + "publish_time": "2023-03-11T13:25:37.403Z", + }, + "subscription": "my-subscription", + }, + ) + + _message_handler(func, raw_event) + + self.assertEqual("world", hello) diff --git a/tests/test_scheduler_fn.py b/tests/test_scheduler_fn.py index eb68219..56853c6 100644 --- a/tests/test_scheduler_fn.py +++ b/tests/test_scheduler_fn.py @@ -17,7 +17,7 @@ from datetime import datetime from flask import Request, Flask from werkzeug.test import EnvironBuilder -from firebase_functions import scheduler_fn +from firebase_functions import scheduler_fn, core class TestScheduler(unittest.TestCase): @@ -118,3 +118,21 @@ def test_on_schedule_call_with_exception(self): self.assertEqual(response.status_code, 500) self.assertEqual(response.data, b"Test exception") + + def test_calls_init(self): + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + with Flask(__name__).test_request_context("/"): + environ = EnvironBuilder().get_environ() + mock_request = Request(environ) + example_func = Mock(__name__="example_func") + decorated_func = scheduler_fn.on_schedule( + schedule="* * * * *")(example_func) + decorated_func(mock_request) + + self.assertEqual("world", hello) diff --git a/tests/test_storage_fn.py b/tests/test_storage_fn.py new file mode 100644 index 0000000..ec55ca8 --- /dev/null +++ b/tests/test_storage_fn.py @@ -0,0 +1,43 @@ +""" +Tests for the storage function. +""" + +import unittest +from unittest.mock import Mock + +from firebase_functions import core, storage_fn +from cloudevents.http import CloudEvent + + +class TestStorage(unittest.TestCase): + """ + Storage function tests. + """ + + def test_calls_init(self): + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + func = Mock(__name__="example_func") + event = CloudEvent(attributes={ + "source": "source", + "type": "type" + }, + data={ + "bucket": "bucket", + "generation": "generation", + "id": "id", + "metageneration": "metageneration", + "name": "name", + "size": "size", + "storageClass": "storageClass", + }) + + decorated_func = storage_fn.on_object_archived(bucket="bucket")(func) + decorated_func(event) + + self.assertEqual("world", hello) diff --git a/tests/test_tasks_fn.py b/tests/test_tasks_fn.py index b52bd0f..531594c 100644 --- a/tests/test_tasks_fn.py +++ b/tests/test_tasks_fn.py @@ -14,9 +14,11 @@ """Task Queue function tests.""" import unittest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock from flask import Flask, Request from werkzeug.test import EnvironBuilder + +from firebase_functions import core from firebase_functions.tasks_fn import on_task_dispatched, CallableRequest @@ -103,3 +105,30 @@ def example(request: CallableRequest[object]) -> str: request = Request(environ) response = example(request) self.assertEqual(response.status_code, 200) + + def test_calls_init(self): + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + app = Flask(__name__) + + func = Mock(__name__="example_func") + + with app.test_request_context("/"): + environ = EnvironBuilder( + method="POST", + json={ + "data": { + "test": "value" + }, + }, + ).get_environ() + request = Request(environ) + decorated_func = on_task_dispatched()(func) + decorated_func(request) + + self.assertEqual("world", hello) diff --git a/tests/test_test_lab_fn.py b/tests/test_test_lab_fn.py index faa15dd..afa9836 100644 --- a/tests/test_test_lab_fn.py +++ b/tests/test_test_lab_fn.py @@ -13,9 +13,10 @@ # limitations under the License. """Test Lab function tests.""" import unittest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock from cloudevents.http import CloudEvent as _CloudEvent +from firebase_functions import core from firebase_functions.test_lab_fn import ( CloudEvent, TestMatrixCompletedData, @@ -94,3 +95,48 @@ def test_event_handler(self): self.assertEqual(event_arg.data.state, TestState.FINISHED) self.assertEqual(event_arg.data.outcome_summary, OutcomeSummary.SUCCESS) self.assertEqual(event_arg.data.test_matrix_id, "testmatrix-123") + + def test_calls_init(self): + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + func = Mock(__name__="example_func") + raw_event = _CloudEvent( + attributes={ + "specversion": "1.0", + "type": "com.example.someevent", + "source": "https://example.com/someevent", + "id": "A234-1234-1234", + "time": "2023-03-11T13:25:37.403Z", + }, + data={ + "createTime": "2023-03-11T13:25:37.403Z", + "state": "FINISHED", + "invalidMatrixDetails": "Some details", + "outcomeSummary": "SUCCESS", + "resultStorage": { + "toolResultsHistory": + "projects/123/histories/456", + "resultsUri": + "https://example.com/results", + "gcsPath": + "gs://bucket/path/to/somewhere", + "toolResultsExecution": + "projects/123/histories/456/executions/789", + }, + "clientInfo": { + "client": "gcloud", + }, + "testMatrixId": "testmatrix-123", + }) + + decorated_func = on_test_matrix_completed()(func) + decorated_func(raw_event) + + func.assert_called_once() + + self.assertEqual("world", hello) From 7f8f41f802229979569cf91bc06abca03ddfdfd5 Mon Sep 17 00:00:00 2001 From: exaby73 Date: Wed, 10 Apr 2024 09:43:06 +0530 Subject: [PATCH 5/6] fix: remove commented code --- tests/test_pubsub_fn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_pubsub_fn.py b/tests/test_pubsub_fn.py index 9abe007..c0b9665 100644 --- a/tests/test_pubsub_fn.py +++ b/tests/test_pubsub_fn.py @@ -118,7 +118,6 @@ def init(): "attributes": { "key": "value" }, - # {"test": "value"} "data": "eyJ0ZXN0IjogInZhbHVlIn0=", "message_id": "message-id-123", "publish_time": "2023-03-11T13:25:37.403Z", From 731d7ebafd46b75b13a36eebd43ca08735f8c68c Mon Sep 17 00:00:00 2001 From: exaby73 Date: Fri, 12 Apr 2024 07:52:00 +0530 Subject: [PATCH 6/6] fix: nits --- src/firebase_functions/core.py | 2 -- tests/test_identity_fn.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/firebase_functions/core.py b/src/firebase_functions/core.py index 491e467..a12e688 100644 --- a/src/firebase_functions/core.py +++ b/src/firebase_functions/core.py @@ -98,8 +98,6 @@ def init(callback: _typing.Callable[[], _typing.Any]) -> None: global _did_init global _init_callback - _init_callback = callback - if _did_init: _logger.warn( "Setting init callback more than once. Only the most recent callback will be called" diff --git a/tests/test_identity_fn.py b/tests/test_identity_fn.py index c2f1b92..b71414b 100644 --- a/tests/test_identity_fn.py +++ b/tests/test_identity_fn.py @@ -59,6 +59,6 @@ def init(): ).get_environ() request = Request(environ) decorated_func = identity_fn.before_user_signed_in()(func) - print(decorated_func(request)) + decorated_func(request) self.assertEqual("world", hello)