diff --git a/src/firebase_functions/alerts_fn.py b/src/firebase_functions/alerts_fn.py index 9fc09e0..ed68673 100644 --- a/src/firebase_functions/alerts_fn.py +++ b/src/firebase_functions/alerts_fn.py @@ -24,7 +24,7 @@ import firebase_functions.private.util as _util -from firebase_functions.core import T, CloudEvent as _CloudEvent +from firebase_functions.core import T, CloudEvent as _CloudEvent, _with_init from firebase_functions.options import FirebaseAlertOptions # Explicitly import AlertType to make it available in the public API. @@ -95,7 +95,7 @@ def on_alert_published_inner_decorator(func: OnAlertPublishedCallable): @_functools.wraps(func) def on_alert_published_wrapped(raw: _ce.CloudEvent): from firebase_functions.private._alerts_fn import alerts_event_from_ce - func(alerts_event_from_ce(raw)) + _with_init(func)(alerts_event_from_ce(raw)) _util.set_func_endpoint_attr( on_alert_published_wrapped, diff --git a/src/firebase_functions/core.py b/src/firebase_functions/core.py index 26a3582..a12e688 100644 --- a/src/firebase_functions/core.py +++ b/src/firebase_functions/core.py @@ -18,6 +18,8 @@ import datetime as _datetime import typing as _typing +from . import logger as _logger + T = _typing.TypeVar("T") @@ -80,3 +82,46 @@ class Change(_typing.Generic[T]): """ The state of data after the change. """ + + +_did_init = False +_init_callback: _typing.Callable[[], _typing.Any] | None = None + + +def init(callback: _typing.Callable[[], _typing.Any]) -> None: + """ + Registers a function that should be run when in a production environment + before executing any functions code. + Calling this decorator more than once leads to undefined behavior. + """ + + global _did_init + global _init_callback + + if _did_init: + _logger.warn( + "Setting init callback more than once. Only the most recent callback will be called" + ) + + _init_callback = callback + _did_init = False + + +def _with_init( + fn: _typing.Callable[..., + _typing.Any]) -> _typing.Callable[..., _typing.Any]: + """ + A decorator that runs the init callback before running the decorated function. + """ + + def wrapper(*args, **kwargs): + global _did_init + + if not _did_init: + if _init_callback is not None: + _init_callback() + _did_init = True + + return fn(*args, **kwargs) + + return wrapper diff --git a/src/firebase_functions/db_fn.py b/src/firebase_functions/db_fn.py index 1681862..7298e99 100644 --- a/src/firebase_functions/db_fn.py +++ b/src/firebase_functions/db_fn.py @@ -119,7 +119,7 @@ def _db_endpoint_handler( subject=event_attributes["subject"], params=params, ) - func(database_event) + _core._with_init(func)(database_event) @_util.copy_func_kwargs(DatabaseOptions) diff --git a/src/firebase_functions/eventarc_fn.py b/src/firebase_functions/eventarc_fn.py index d3b62f3..d76772c 100644 --- a/src/firebase_functions/eventarc_fn.py +++ b/src/firebase_functions/eventarc_fn.py @@ -21,7 +21,7 @@ import firebase_functions.options as _options import firebase_functions.private.util as _util -from firebase_functions.core import CloudEvent +from firebase_functions.core import CloudEvent, _with_init @_util.copy_func_kwargs(_options.EventarcTriggerOptions) @@ -73,7 +73,7 @@ def on_custom_event_published_wrapped(raw: _ce.CloudEvent): ), type=event_dict["type"], ) - func(event) + _with_init(func)(event) _util.set_func_endpoint_attr( on_custom_event_published_wrapped, diff --git a/src/firebase_functions/firestore_fn.py b/src/firebase_functions/firestore_fn.py index c75c288..11156b7 100644 --- a/src/firebase_functions/firestore_fn.py +++ b/src/firebase_functions/firestore_fn.py @@ -205,6 +205,8 @@ def _firestore_endpoint_handler( params=params, ) + func = _core._with_init(func) + if event_type.endswith(".withAuthContext"): database_event_with_auth_context = AuthEvent(**vars(database_event), auth_type=event_auth_type, diff --git a/src/firebase_functions/https_fn.py b/src/firebase_functions/https_fn.py index fed3973..61ff13f 100644 --- a/src/firebase_functions/https_fn.py +++ b/src/firebase_functions/https_fn.py @@ -403,7 +403,7 @@ def _on_call_handler(func: _C2, instance_id_token=request.headers.get( "Firebase-Instance-ID-Token"), ) - result = func(context) + result = _core._with_init(func)(context) return _jsonify(result=result) # Disable broad exceptions lint since we want to handle all exceptions here # and wrap as an HttpsError. @@ -447,7 +447,7 @@ def on_request_wrapped(request: Request) -> Response: methods=options.cors.cors_methods, origins=options.cors.cors_origins, )(func)(request) - return func(request) + return _core._with_init(func)(request) _util.set_func_endpoint_attr( on_request_wrapped, diff --git a/src/firebase_functions/private/_identity_fn.py b/src/firebase_functions/private/_identity_fn.py index 89706a5..2a8f516 100644 --- a/src/firebase_functions/private/_identity_fn.py +++ b/src/firebase_functions/private/_identity_fn.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Cloud functions to handle Eventarc events.""" - # pylint: disable=protected-access import typing as _typing import datetime as _dt import time as _time import json as _json + +from firebase_functions.core import _with_init from firebase_functions.https_fn import HttpsError, FunctionsErrorCode import firebase_functions.private.util as _util @@ -351,8 +352,8 @@ def before_operation_handler( jwt_token = request.json["data"]["jwt"] decoded_token = _token_verifier.verify_auth_blocking_token(jwt_token) event = _auth_blocking_event_from_token_data(decoded_token) - auth_response: BeforeCreateResponse | BeforeSignInResponse | None = func( - event) + auth_response: BeforeCreateResponse | BeforeSignInResponse | None = _with_init( + func)(event) if not auth_response: return _jsonify({}) auth_response_dict = _validate_auth_response(event_type, auth_response) @@ -362,7 +363,7 @@ def before_operation_handler( # pylint: disable=broad-except except Exception as exception: if not isinstance(exception, HttpsError): - _logging.error("Unhandled error", exception) + _logging.error("Unhandled error %s", exception) exception = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL") status = exception._http_error_code.status return _make_response(_jsonify(error=exception._as_dict()), status) diff --git a/src/firebase_functions/pubsub_fn.py b/src/firebase_functions/pubsub_fn.py index 6d833d3..4647a1b 100644 --- a/src/firebase_functions/pubsub_fn.py +++ b/src/firebase_functions/pubsub_fn.py @@ -25,7 +25,7 @@ import firebase_functions.private.util as _util -from firebase_functions.core import CloudEvent, T +from firebase_functions.core import CloudEvent, T, _with_init from firebase_functions.options import PubSubOptions @@ -151,7 +151,7 @@ def _message_handler( type=event_dict["type"], ) - func(event) + _with_init(func)(event) @_util.copy_func_kwargs(PubSubOptions) diff --git a/src/firebase_functions/remote_config_fn.py b/src/firebase_functions/remote_config_fn.py index 402fd98..c48436d 100644 --- a/src/firebase_functions/remote_config_fn.py +++ b/src/firebase_functions/remote_config_fn.py @@ -24,7 +24,7 @@ import firebase_functions.private.util as _util -from firebase_functions.core import CloudEvent +from firebase_functions.core import CloudEvent, _with_init from firebase_functions.options import EventHandlerOptions @@ -189,7 +189,7 @@ def _config_handler(func: _C1, raw: _ce.CloudEvent) -> None: type=event_dict["type"], ) - func(event) + _with_init(func)(event) @_util.copy_func_kwargs(EventHandlerOptions) diff --git a/src/firebase_functions/scheduler_fn.py b/src/firebase_functions/scheduler_fn.py index 4d3bcd7..c5a92c9 100644 --- a/src/firebase_functions/scheduler_fn.py +++ b/src/firebase_functions/scheduler_fn.py @@ -27,6 +27,7 @@ make_response as _make_response, ) +from firebase_functions.core import _with_init # Export for user convenience. # pylint: disable=unused-import from firebase_functions.options import Timezone @@ -108,7 +109,7 @@ def on_schedule_wrapped(request: _Request) -> _Response: schedule_time=schedule_time, ) try: - func(event) + _with_init(func)(event) return _make_response() # Disable broad exceptions lint since we want to handle all exceptions. # pylint: disable=broad-except diff --git a/src/firebase_functions/storage_fn.py b/src/firebase_functions/storage_fn.py index fec61b4..342d257 100644 --- a/src/firebase_functions/storage_fn.py +++ b/src/firebase_functions/storage_fn.py @@ -22,7 +22,7 @@ import cloudevents.http as _ce import firebase_functions.private.util as _util -from firebase_functions.core import CloudEvent +from firebase_functions.core import CloudEvent, _with_init from firebase_functions.options import StorageOptions _event_type_archived = "google.cloud.storage.object.v1.archived" @@ -255,7 +255,7 @@ def _message_handler( type=event_attributes["type"], ) - func(event) + _with_init(func)(event) @_util.copy_func_kwargs(StorageOptions) diff --git a/src/firebase_functions/test_lab_fn.py b/src/firebase_functions/test_lab_fn.py index 15eb5ab..7aede95 100644 --- a/src/firebase_functions/test_lab_fn.py +++ b/src/firebase_functions/test_lab_fn.py @@ -24,7 +24,7 @@ import firebase_functions.private.util as _util -from firebase_functions.core import CloudEvent +from firebase_functions.core import CloudEvent, _with_init from firebase_functions.options import EventHandlerOptions @@ -246,7 +246,7 @@ def _event_handler(func: _C1, raw: _ce.CloudEvent) -> None: type=event_dict["type"], ) - func(event) + _with_init(func)(event) @_util.copy_func_kwargs(EventHandlerOptions) diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 0000000..4e8b487 --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,43 @@ +""" +Tests for the db module. +""" + +import unittest +from unittest import mock +from cloudevents.http import CloudEvent +from firebase_functions import core, db_fn + + +class TestDb(unittest.TestCase): + """ + Tests for the db module. + """ + + def test_calls_init_function(self): + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + func = mock.Mock(__name__="example_func") + decorated_func = db_fn.on_value_created(reference="path")(func) + + event = CloudEvent(attributes={ + "specversion": "1.0", + "id": "id", + "source": "source", + "subject": "subject", + "type": "type", + "time": "2024-04-10T12:00:00.000Z", + "instance": "instance", + "ref": "ref", + "firebasedatabasehost": "firebasedatabasehost", + "location": "location", + }, + data={"delta": "delta"}) + + decorated_func(event) + + self.assertEqual(hello, "world") diff --git a/tests/test_eventarc_fn.py b/tests/test_eventarc_fn.py index 882a6d1..730812a 100644 --- a/tests/test_eventarc_fn.py +++ b/tests/test_eventarc_fn.py @@ -14,7 +14,10 @@ """Eventarc trigger function tests.""" import unittest from unittest.mock import Mock + from cloudevents.http import CloudEvent as _CloudEvent + +from firebase_functions import core from firebase_functions.core import CloudEvent from firebase_functions.eventarc_fn import on_custom_event_published @@ -83,3 +86,34 @@ def test_on_custom_event_published_wrapped(self): event_arg.type, "firebase.extensions.storage-resize-images.v1.complete", ) + + def test_calls_init_function(self): + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + func = Mock(__name__="example_func") + raw_event = _CloudEvent( + attributes={ + "specversion": "1.0", + "type": "firebase.extensions.storage-resize-images.v1.complete", + "source": "https://example.com/testevent", + "id": "1234567890", + "subject": "test_subject", + "time": "2023-03-11T13:25:37.403Z", + }, + data={ + "some_key": "some_value", + }, + ) + + decorated_func = on_custom_event_published( + event_type="firebase.extensions.storage-resize-images.v1.complete", + )(func) + + decorated_func(raw_event) + + self.assertEqual(hello, "world") diff --git a/tests/test_firestore_fn.py b/tests/test_firestore_fn.py index d6a65a5..fa4ed15 100644 --- a/tests/test_firestore_fn.py +++ b/tests/test_firestore_fn.py @@ -70,3 +70,54 @@ def test_firestore_endpoint_handler_calls_function_with_correct_args(self): self.assertIsInstance(event, AuthEvent) self.assertEqual(event.auth_type, "unauthenticated") self.assertEqual(event.auth_id, "foo") + + def test_calls_init_function(self): + with patch.dict("sys.modules", mocked_modules): + from firebase_functions import firestore_fn, core + from cloudevents.http import CloudEvent + + func = Mock(__name__="example_func") + + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + attributes = { + "specversion": + "1.0", + # pylint: disable=protected-access + "type": + firestore_fn._event_type_created, + "source": + "https://example.com/testevent", + "time": + "2023-03-11T13:25:37.403Z", + "subject": + "test_subject", + "datacontenttype": + "application/json", + "location": + "projects/project-id/databases/(default)/documents/foo/{bar}", + "project": + "project-id", + "namespace": + "(default)", + "document": + "foo/{bar}", + "database": + "projects/project-id/databases/(default)", + "authtype": + "unauthenticated", + "authid": + "foo" + } + raw_event = CloudEvent(attributes=attributes, data=json.dumps({})) + decorated_func = firestore_fn.on_document_created( + document="/foo/{bar}")(func) + + decorated_func(raw_event) + + self.assertEqual(hello, "world") diff --git a/tests/test_https_fn.py b/tests/test_https_fn.py new file mode 100644 index 0000000..e128b39 --- /dev/null +++ b/tests/test_https_fn.py @@ -0,0 +1,72 @@ +""" +Tests for the https module. +""" + +import unittest +from unittest.mock import Mock +from flask import Flask, Request +from werkzeug.test import EnvironBuilder + +from firebase_functions import core, https_fn + + +class TestHttps(unittest.TestCase): + """ + Tests for the http module. + """ + + def test_on_request_calls_init_function(self): + app = Flask(__name__) + + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + func = Mock(__name__="example_func") + + with app.test_request_context("/"): + environ = EnvironBuilder( + method="POST", + json={ + "data": { + "test": "value" + }, + }, + ).get_environ() + request = Request(environ) + decorated_func = https_fn.on_request()(func) + + decorated_func(request) + + self.assertEqual(hello, "world") + + def test_on_call_calls_init_function(self): + app = Flask(__name__) + + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + func = Mock(__name__="example_func") + + with app.test_request_context("/"): + environ = EnvironBuilder( + method="POST", + json={ + "data": { + "test": "value" + }, + }, + ).get_environ() + request = Request(environ) + decorated_func = https_fn.on_call()(func) + + decorated_func(request) + + self.assertEqual("world", hello) diff --git a/tests/test_identity_fn.py b/tests/test_identity_fn.py new file mode 100644 index 0000000..b71414b --- /dev/null +++ b/tests/test_identity_fn.py @@ -0,0 +1,64 @@ +""" +Identity function tests. +""" + +import unittest +from unittest.mock import Mock, patch, MagicMock +from flask import Flask, Request +from werkzeug.test import EnvironBuilder + +from firebase_functions import core, identity_fn + +token_verifier_mock = MagicMock() +token_verifier_mock.verify_auth_blocking_token = Mock( + return_value={ + "user_record": { + "uid": "uid", + "metadata": { + "creation_time": 0 + }, + "provider_data": [] + }, + "event_id": "event_id", + "ip_address": "ip_address", + "user_agent": "user_agent", + "iat": 0 + }) +mocked_modules = { + "firebase_functions.private.token_verifier": token_verifier_mock, +} + + +class TestIdentity(unittest.TestCase): + """ + Identity function tests. + """ + + def test_calls_init_function(self): + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + with patch.dict("sys.modules", mocked_modules): + app = Flask(__name__) + + func = Mock(__name__="example_func", + return_value=identity_fn.BeforeSignInResponse()) + + with app.test_request_context("/"): + environ = EnvironBuilder( + method="POST", + json={ + "data": { + "jwt": "jwt" + }, + }, + ).get_environ() + request = Request(environ) + decorated_func = identity_fn.before_user_signed_in()(func) + decorated_func(request) + + self.assertEqual("world", hello) diff --git a/tests/test_init.py b/tests/test_init.py new file mode 100644 index 0000000..07ee924 --- /dev/null +++ b/tests/test_init.py @@ -0,0 +1,23 @@ +""" +Test the init decorator. +""" + +import unittest +from firebase_functions import core + + +class TestInit(unittest.TestCase): + """ + Test the init decorator. + """ + + def test_init_is_initialized(self): + + @core.init + def fn(): + pass + + # pylint: disable=protected-access + self.assertIsNotNone(core._init_callback) + # pylint: disable=protected-access + self.assertFalse(core._did_init) diff --git a/tests/test_pubsub_fn.py b/tests/test_pubsub_fn.py index 06777d8..c0b9665 100644 --- a/tests/test_pubsub_fn.py +++ b/tests/test_pubsub_fn.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock from cloudevents.http import CloudEvent as _CloudEvent +from firebase_functions import core from firebase_functions.pubsub_fn import ( Message, MessagePublishedData, @@ -94,3 +95,37 @@ def test_message_handler(self): "eyJ0ZXN0IjogInZhbHVlIn0=") self.assertIsNone(event_arg.data.message.ordering_key) self.assertEqual(event_arg.data.subscription, "my-subscription") + + def test_calls_init(self): + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + func = MagicMock() + raw_event = _CloudEvent( + attributes={ + "id": "test-message", + "source": "https://example.com/pubsub", + "specversion": "1.0", + "time": "2023-03-11T13:25:37.403Z", + "type": "com.example.pubsub.message", + }, + data={ + "message": { + "attributes": { + "key": "value" + }, + "data": "eyJ0ZXN0IjogInZhbHVlIn0=", + "message_id": "message-id-123", + "publish_time": "2023-03-11T13:25:37.403Z", + }, + "subscription": "my-subscription", + }, + ) + + _message_handler(func, raw_event) + + self.assertEqual("world", hello) diff --git a/tests/test_scheduler_fn.py b/tests/test_scheduler_fn.py index eb68219..56853c6 100644 --- a/tests/test_scheduler_fn.py +++ b/tests/test_scheduler_fn.py @@ -17,7 +17,7 @@ from datetime import datetime from flask import Request, Flask from werkzeug.test import EnvironBuilder -from firebase_functions import scheduler_fn +from firebase_functions import scheduler_fn, core class TestScheduler(unittest.TestCase): @@ -118,3 +118,21 @@ def test_on_schedule_call_with_exception(self): self.assertEqual(response.status_code, 500) self.assertEqual(response.data, b"Test exception") + + def test_calls_init(self): + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + with Flask(__name__).test_request_context("/"): + environ = EnvironBuilder().get_environ() + mock_request = Request(environ) + example_func = Mock(__name__="example_func") + decorated_func = scheduler_fn.on_schedule( + schedule="* * * * *")(example_func) + decorated_func(mock_request) + + self.assertEqual("world", hello) diff --git a/tests/test_storage_fn.py b/tests/test_storage_fn.py new file mode 100644 index 0000000..ec55ca8 --- /dev/null +++ b/tests/test_storage_fn.py @@ -0,0 +1,43 @@ +""" +Tests for the storage function. +""" + +import unittest +from unittest.mock import Mock + +from firebase_functions import core, storage_fn +from cloudevents.http import CloudEvent + + +class TestStorage(unittest.TestCase): + """ + Storage function tests. + """ + + def test_calls_init(self): + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + func = Mock(__name__="example_func") + event = CloudEvent(attributes={ + "source": "source", + "type": "type" + }, + data={ + "bucket": "bucket", + "generation": "generation", + "id": "id", + "metageneration": "metageneration", + "name": "name", + "size": "size", + "storageClass": "storageClass", + }) + + decorated_func = storage_fn.on_object_archived(bucket="bucket")(func) + decorated_func(event) + + self.assertEqual("world", hello) diff --git a/tests/test_tasks_fn.py b/tests/test_tasks_fn.py index b52bd0f..531594c 100644 --- a/tests/test_tasks_fn.py +++ b/tests/test_tasks_fn.py @@ -14,9 +14,11 @@ """Task Queue function tests.""" import unittest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock from flask import Flask, Request from werkzeug.test import EnvironBuilder + +from firebase_functions import core from firebase_functions.tasks_fn import on_task_dispatched, CallableRequest @@ -103,3 +105,30 @@ def example(request: CallableRequest[object]) -> str: request = Request(environ) response = example(request) self.assertEqual(response.status_code, 200) + + def test_calls_init(self): + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + app = Flask(__name__) + + func = Mock(__name__="example_func") + + with app.test_request_context("/"): + environ = EnvironBuilder( + method="POST", + json={ + "data": { + "test": "value" + }, + }, + ).get_environ() + request = Request(environ) + decorated_func = on_task_dispatched()(func) + decorated_func(request) + + self.assertEqual("world", hello) diff --git a/tests/test_test_lab_fn.py b/tests/test_test_lab_fn.py index faa15dd..afa9836 100644 --- a/tests/test_test_lab_fn.py +++ b/tests/test_test_lab_fn.py @@ -13,9 +13,10 @@ # limitations under the License. """Test Lab function tests.""" import unittest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock from cloudevents.http import CloudEvent as _CloudEvent +from firebase_functions import core from firebase_functions.test_lab_fn import ( CloudEvent, TestMatrixCompletedData, @@ -94,3 +95,48 @@ def test_event_handler(self): self.assertEqual(event_arg.data.state, TestState.FINISHED) self.assertEqual(event_arg.data.outcome_summary, OutcomeSummary.SUCCESS) self.assertEqual(event_arg.data.test_matrix_id, "testmatrix-123") + + def test_calls_init(self): + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + func = Mock(__name__="example_func") + raw_event = _CloudEvent( + attributes={ + "specversion": "1.0", + "type": "com.example.someevent", + "source": "https://example.com/someevent", + "id": "A234-1234-1234", + "time": "2023-03-11T13:25:37.403Z", + }, + data={ + "createTime": "2023-03-11T13:25:37.403Z", + "state": "FINISHED", + "invalidMatrixDetails": "Some details", + "outcomeSummary": "SUCCESS", + "resultStorage": { + "toolResultsHistory": + "projects/123/histories/456", + "resultsUri": + "https://example.com/results", + "gcsPath": + "gs://bucket/path/to/somewhere", + "toolResultsExecution": + "projects/123/histories/456/executions/789", + }, + "clientInfo": { + "client": "gcloud", + }, + "testMatrixId": "testmatrix-123", + }) + + decorated_func = on_test_matrix_completed()(func) + decorated_func(raw_event) + + func.assert_called_once() + + self.assertEqual("world", hello)