Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add init function support #188

Merged
merged 7 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/firebase_functions/alerts_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import firebase_functions.private.util as _util

from firebase_functions.core import T, CloudEvent as _CloudEvent
from firebase_functions.core import T, CloudEvent as _CloudEvent, _with_init
from firebase_functions.options import FirebaseAlertOptions

# Explicitly import AlertType to make it available in the public API.
Expand Down Expand Up @@ -95,7 +95,7 @@ def on_alert_published_inner_decorator(func: OnAlertPublishedCallable):
@_functools.wraps(func)
def on_alert_published_wrapped(raw: _ce.CloudEvent):
from firebase_functions.private._alerts_fn import alerts_event_from_ce
func(alerts_event_from_ce(raw))
_with_init(func)(alerts_event_from_ce(raw))

_util.set_func_endpoint_attr(
on_alert_published_wrapped,
Expand Down
45 changes: 45 additions & 0 deletions src/firebase_functions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import datetime as _datetime
import typing as _typing

from . import logger as _logger

T = _typing.TypeVar("T")


Expand Down Expand Up @@ -80,3 +82,46 @@ class Change(_typing.Generic[T]):
"""
The state of data after the change.
"""


_did_init = False
_init_callback: _typing.Callable[[], _typing.Any] | None = None


def init(callback: _typing.Callable[[], _typing.Any]) -> None:
"""
Registers a function that should be run when in a production environment
before executing any functions code.
Calling this decorator more than once leads to undefined behavior.
"""

global _did_init
global _init_callback

if _did_init:
_logger.warn(
"Setting init callback more than once. Only the most recent callback will be called"
)

_init_callback = callback
_did_init = False


def _with_init(
fn: _typing.Callable[...,
_typing.Any]) -> _typing.Callable[..., _typing.Any]:
"""
A decorator that runs the init callback before running the decorated function.
"""

def wrapper(*args, **kwargs):
global _did_init

if not _did_init:
if _init_callback is not None:
_init_callback()
_did_init = True

return fn(*args, **kwargs)

return wrapper
2 changes: 1 addition & 1 deletion src/firebase_functions/db_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _db_endpoint_handler(
subject=event_attributes["subject"],
params=params,
)
func(database_event)
_core._with_init(func)(database_event)


@_util.copy_func_kwargs(DatabaseOptions)
Expand Down
4 changes: 2 additions & 2 deletions src/firebase_functions/eventarc_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import firebase_functions.options as _options
import firebase_functions.private.util as _util
from firebase_functions.core import CloudEvent
from firebase_functions.core import CloudEvent, _with_init


@_util.copy_func_kwargs(_options.EventarcTriggerOptions)
Expand Down Expand Up @@ -73,7 +73,7 @@ def on_custom_event_published_wrapped(raw: _ce.CloudEvent):
),
type=event_dict["type"],
)
func(event)
_with_init(func)(event)

_util.set_func_endpoint_attr(
on_custom_event_published_wrapped,
Expand Down
2 changes: 2 additions & 0 deletions src/firebase_functions/firestore_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ def _firestore_endpoint_handler(
params=params,
)

func = _core._with_init(func)

if event_type.endswith(".withAuthContext"):
database_event_with_auth_context = AuthEvent(**vars(database_event),
auth_type=event_auth_type,
Expand Down
4 changes: 2 additions & 2 deletions src/firebase_functions/https_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def _on_call_handler(func: _C2,
instance_id_token=request.headers.get(
"Firebase-Instance-ID-Token"),
)
result = func(context)
result = _core._with_init(func)(context)
return _jsonify(result=result)
# Disable broad exceptions lint since we want to handle all exceptions here
# and wrap as an HttpsError.
Expand Down Expand Up @@ -447,7 +447,7 @@ def on_request_wrapped(request: Request) -> Response:
methods=options.cors.cors_methods,
origins=options.cors.cors_origins,
)(func)(request)
return func(request)
return _core._with_init(func)(request)

_util.set_func_endpoint_attr(
on_request_wrapped,
Expand Down
9 changes: 5 additions & 4 deletions src/firebase_functions/private/_identity_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cloud functions to handle Eventarc events."""

# pylint: disable=protected-access
import typing as _typing
import datetime as _dt
import time as _time
import json as _json

from firebase_functions.core import _with_init
from firebase_functions.https_fn import HttpsError, FunctionsErrorCode

import firebase_functions.private.util as _util
Expand Down Expand Up @@ -351,8 +352,8 @@ def before_operation_handler(
jwt_token = request.json["data"]["jwt"]
decoded_token = _token_verifier.verify_auth_blocking_token(jwt_token)
event = _auth_blocking_event_from_token_data(decoded_token)
auth_response: BeforeCreateResponse | BeforeSignInResponse | None = func(
event)
auth_response: BeforeCreateResponse | BeforeSignInResponse | None = _with_init(
func)(event)
if not auth_response:
return _jsonify({})
auth_response_dict = _validate_auth_response(event_type, auth_response)
Expand All @@ -362,7 +363,7 @@ def before_operation_handler(
# pylint: disable=broad-except
except Exception as exception:
if not isinstance(exception, HttpsError):
_logging.error("Unhandled error", exception)
_logging.error("Unhandled error %s", exception)
exception = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL")
status = exception._http_error_code.status
return _make_response(_jsonify(error=exception._as_dict()), status)
4 changes: 2 additions & 2 deletions src/firebase_functions/pubsub_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import firebase_functions.private.util as _util

from firebase_functions.core import CloudEvent, T
from firebase_functions.core import CloudEvent, T, _with_init
from firebase_functions.options import PubSubOptions


Expand Down Expand Up @@ -151,7 +151,7 @@ def _message_handler(
type=event_dict["type"],
)

func(event)
_with_init(func)(event)


@_util.copy_func_kwargs(PubSubOptions)
Expand Down
4 changes: 2 additions & 2 deletions src/firebase_functions/remote_config_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import firebase_functions.private.util as _util

from firebase_functions.core import CloudEvent
from firebase_functions.core import CloudEvent, _with_init
from firebase_functions.options import EventHandlerOptions


Expand Down Expand Up @@ -189,7 +189,7 @@ def _config_handler(func: _C1, raw: _ce.CloudEvent) -> None:
type=event_dict["type"],
)

func(event)
_with_init(func)(event)


@_util.copy_func_kwargs(EventHandlerOptions)
Expand Down
3 changes: 2 additions & 1 deletion src/firebase_functions/scheduler_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
make_response as _make_response,
)

from firebase_functions.core import _with_init
# Export for user convenience.
# pylint: disable=unused-import
from firebase_functions.options import Timezone
Expand Down Expand Up @@ -108,7 +109,7 @@ def on_schedule_wrapped(request: _Request) -> _Response:
schedule_time=schedule_time,
)
try:
func(event)
_with_init(func)(event)
return _make_response()
# Disable broad exceptions lint since we want to handle all exceptions.
# pylint: disable=broad-except
Expand Down
4 changes: 2 additions & 2 deletions src/firebase_functions/storage_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import cloudevents.http as _ce

import firebase_functions.private.util as _util
from firebase_functions.core import CloudEvent
from firebase_functions.core import CloudEvent, _with_init
from firebase_functions.options import StorageOptions

_event_type_archived = "google.cloud.storage.object.v1.archived"
Expand Down Expand Up @@ -255,7 +255,7 @@ def _message_handler(
type=event_attributes["type"],
)

func(event)
_with_init(func)(event)


@_util.copy_func_kwargs(StorageOptions)
Expand Down
4 changes: 2 additions & 2 deletions src/firebase_functions/test_lab_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import firebase_functions.private.util as _util

from firebase_functions.core import CloudEvent
from firebase_functions.core import CloudEvent, _with_init
from firebase_functions.options import EventHandlerOptions


Expand Down Expand Up @@ -246,7 +246,7 @@ def _event_handler(func: _C1, raw: _ce.CloudEvent) -> None:
type=event_dict["type"],
)

func(event)
_with_init(func)(event)


@_util.copy_func_kwargs(EventHandlerOptions)
Expand Down
43 changes: 43 additions & 0 deletions tests/test_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Tests for the db module.
"""

import unittest
from unittest import mock
from cloudevents.http import CloudEvent
from firebase_functions import core, db_fn


class TestDb(unittest.TestCase):
"""
Tests for the db module.
"""

def test_calls_init_function(self):
hello = None

@core.init
def init():
nonlocal hello
hello = "world"

func = mock.Mock(__name__="example_func")
decorated_func = db_fn.on_value_created(reference="path")(func)

event = CloudEvent(attributes={
"specversion": "1.0",
"id": "id",
"source": "source",
"subject": "subject",
"type": "type",
"time": "2024-04-10T12:00:00.000Z",
"instance": "instance",
"ref": "ref",
"firebasedatabasehost": "firebasedatabasehost",
"location": "location",
},
data={"delta": "delta"})

decorated_func(event)

self.assertEqual(hello, "world")
34 changes: 34 additions & 0 deletions tests/test_eventarc_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
"""Eventarc trigger function tests."""
import unittest
from unittest.mock import Mock

from cloudevents.http import CloudEvent as _CloudEvent

from firebase_functions import core
from firebase_functions.core import CloudEvent
from firebase_functions.eventarc_fn import on_custom_event_published

Expand Down Expand Up @@ -83,3 +86,34 @@ def test_on_custom_event_published_wrapped(self):
event_arg.type,
"firebase.extensions.storage-resize-images.v1.complete",
)

def test_calls_init_function(self):
hello = None

@core.init
def init():
nonlocal hello
hello = "world"

func = Mock(__name__="example_func")
raw_event = _CloudEvent(
attributes={
"specversion": "1.0",
"type": "firebase.extensions.storage-resize-images.v1.complete",
"source": "https://example.com/testevent",
"id": "1234567890",
"subject": "test_subject",
"time": "2023-03-11T13:25:37.403Z",
},
data={
"some_key": "some_value",
},
)

decorated_func = on_custom_event_published(
event_type="firebase.extensions.storage-resize-images.v1.complete",
)(func)

decorated_func(raw_event)

self.assertEqual(hello, "world")
51 changes: 51 additions & 0 deletions tests/test_firestore_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,54 @@ def test_firestore_endpoint_handler_calls_function_with_correct_args(self):
self.assertIsInstance(event, AuthEvent)
self.assertEqual(event.auth_type, "unauthenticated")
self.assertEqual(event.auth_id, "foo")

def test_calls_init_function(self):
with patch.dict("sys.modules", mocked_modules):
from firebase_functions import firestore_fn, core
from cloudevents.http import CloudEvent

func = Mock(__name__="example_func")

hello = None

@core.init
def init():
nonlocal hello
hello = "world"

attributes = {
"specversion":
"1.0",
# pylint: disable=protected-access
"type":
firestore_fn._event_type_created,
"source":
"https://example.com/testevent",
"time":
"2023-03-11T13:25:37.403Z",
"subject":
"test_subject",
"datacontenttype":
"application/json",
"location":
"projects/project-id/databases/(default)/documents/foo/{bar}",
"project":
"project-id",
"namespace":
"(default)",
"document":
"foo/{bar}",
"database":
"projects/project-id/databases/(default)",
"authtype":
"unauthenticated",
"authid":
"foo"
}
raw_event = CloudEvent(attributes=attributes, data=json.dumps({}))
decorated_func = firestore_fn.on_document_created(
document="/foo/{bar}")(func)

decorated_func(raw_event)

self.assertEqual(hello, "world")
Loading
Loading