Skip to content

Commit e9781a9

Browse files
authored
feat: add init function support (#188)
1 parent b7c0b43 commit e9781a9

23 files changed

+530
-23
lines changed

src/firebase_functions/alerts_fn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import firebase_functions.private.util as _util
2626

27-
from firebase_functions.core import T, CloudEvent as _CloudEvent
27+
from firebase_functions.core import T, CloudEvent as _CloudEvent, _with_init
2828
from firebase_functions.options import FirebaseAlertOptions
2929

3030
# Explicitly import AlertType to make it available in the public API.
@@ -95,7 +95,7 @@ def on_alert_published_inner_decorator(func: OnAlertPublishedCallable):
9595
@_functools.wraps(func)
9696
def on_alert_published_wrapped(raw: _ce.CloudEvent):
9797
from firebase_functions.private._alerts_fn import alerts_event_from_ce
98-
func(alerts_event_from_ce(raw))
98+
_with_init(func)(alerts_event_from_ce(raw))
9999

100100
_util.set_func_endpoint_attr(
101101
on_alert_published_wrapped,

src/firebase_functions/core.py

+45
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import datetime as _datetime
1919
import typing as _typing
2020

21+
from . import logger as _logger
22+
2123
T = _typing.TypeVar("T")
2224

2325

@@ -80,3 +82,46 @@ class Change(_typing.Generic[T]):
8082
"""
8183
The state of data after the change.
8284
"""
85+
86+
87+
_did_init = False
88+
_init_callback: _typing.Callable[[], _typing.Any] | None = None
89+
90+
91+
def init(callback: _typing.Callable[[], _typing.Any]) -> None:
92+
"""
93+
Registers a function that should be run when in a production environment
94+
before executing any functions code.
95+
Calling this decorator more than once leads to undefined behavior.
96+
"""
97+
98+
global _did_init
99+
global _init_callback
100+
101+
if _did_init:
102+
_logger.warn(
103+
"Setting init callback more than once. Only the most recent callback will be called"
104+
)
105+
106+
_init_callback = callback
107+
_did_init = False
108+
109+
110+
def _with_init(
111+
fn: _typing.Callable[...,
112+
_typing.Any]) -> _typing.Callable[..., _typing.Any]:
113+
"""
114+
A decorator that runs the init callback before running the decorated function.
115+
"""
116+
117+
def wrapper(*args, **kwargs):
118+
global _did_init
119+
120+
if not _did_init:
121+
if _init_callback is not None:
122+
_init_callback()
123+
_did_init = True
124+
125+
return fn(*args, **kwargs)
126+
127+
return wrapper

src/firebase_functions/db_fn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def _db_endpoint_handler(
119119
subject=event_attributes["subject"],
120120
params=params,
121121
)
122-
func(database_event)
122+
_core._with_init(func)(database_event)
123123

124124

125125
@_util.copy_func_kwargs(DatabaseOptions)

src/firebase_functions/eventarc_fn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import firebase_functions.options as _options
2323
import firebase_functions.private.util as _util
24-
from firebase_functions.core import CloudEvent
24+
from firebase_functions.core import CloudEvent, _with_init
2525

2626

2727
@_util.copy_func_kwargs(_options.EventarcTriggerOptions)
@@ -73,7 +73,7 @@ def on_custom_event_published_wrapped(raw: _ce.CloudEvent):
7373
),
7474
type=event_dict["type"],
7575
)
76-
func(event)
76+
_with_init(func)(event)
7777

7878
_util.set_func_endpoint_attr(
7979
on_custom_event_published_wrapped,

src/firebase_functions/firestore_fn.py

+2
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ def _firestore_endpoint_handler(
205205
params=params,
206206
)
207207

208+
func = _core._with_init(func)
209+
208210
if event_type.endswith(".withAuthContext"):
209211
database_event_with_auth_context = AuthEvent(**vars(database_event),
210212
auth_type=event_auth_type,

src/firebase_functions/https_fn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def _on_call_handler(func: _C2,
403403
instance_id_token=request.headers.get(
404404
"Firebase-Instance-ID-Token"),
405405
)
406-
result = func(context)
406+
result = _core._with_init(func)(context)
407407
return _jsonify(result=result)
408408
# Disable broad exceptions lint since we want to handle all exceptions here
409409
# and wrap as an HttpsError.
@@ -447,7 +447,7 @@ def on_request_wrapped(request: Request) -> Response:
447447
methods=options.cors.cors_methods,
448448
origins=options.cors.cors_origins,
449449
)(func)(request)
450-
return func(request)
450+
return _core._with_init(func)(request)
451451

452452
_util.set_func_endpoint_attr(
453453
on_request_wrapped,

src/firebase_functions/private/_identity_fn.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Cloud functions to handle Eventarc events."""
15-
1615
# pylint: disable=protected-access
1716
import typing as _typing
1817
import datetime as _dt
1918
import time as _time
2019
import json as _json
20+
21+
from firebase_functions.core import _with_init
2122
from firebase_functions.https_fn import HttpsError, FunctionsErrorCode
2223

2324
import firebase_functions.private.util as _util
@@ -351,8 +352,8 @@ def before_operation_handler(
351352
jwt_token = request.json["data"]["jwt"]
352353
decoded_token = _token_verifier.verify_auth_blocking_token(jwt_token)
353354
event = _auth_blocking_event_from_token_data(decoded_token)
354-
auth_response: BeforeCreateResponse | BeforeSignInResponse | None = func(
355-
event)
355+
auth_response: BeforeCreateResponse | BeforeSignInResponse | None = _with_init(
356+
func)(event)
356357
if not auth_response:
357358
return _jsonify({})
358359
auth_response_dict = _validate_auth_response(event_type, auth_response)
@@ -362,7 +363,7 @@ def before_operation_handler(
362363
# pylint: disable=broad-except
363364
except Exception as exception:
364365
if not isinstance(exception, HttpsError):
365-
_logging.error("Unhandled error", exception)
366+
_logging.error("Unhandled error %s", exception)
366367
exception = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL")
367368
status = exception._http_error_code.status
368369
return _make_response(_jsonify(error=exception._as_dict()), status)

src/firebase_functions/pubsub_fn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import firebase_functions.private.util as _util
2727

28-
from firebase_functions.core import CloudEvent, T
28+
from firebase_functions.core import CloudEvent, T, _with_init
2929
from firebase_functions.options import PubSubOptions
3030

3131

@@ -151,7 +151,7 @@ def _message_handler(
151151
type=event_dict["type"],
152152
)
153153

154-
func(event)
154+
_with_init(func)(event)
155155

156156

157157
@_util.copy_func_kwargs(PubSubOptions)

src/firebase_functions/remote_config_fn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import firebase_functions.private.util as _util
2626

27-
from firebase_functions.core import CloudEvent
27+
from firebase_functions.core import CloudEvent, _with_init
2828
from firebase_functions.options import EventHandlerOptions
2929

3030

@@ -189,7 +189,7 @@ def _config_handler(func: _C1, raw: _ce.CloudEvent) -> None:
189189
type=event_dict["type"],
190190
)
191191

192-
func(event)
192+
_with_init(func)(event)
193193

194194

195195
@_util.copy_func_kwargs(EventHandlerOptions)

src/firebase_functions/scheduler_fn.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
make_response as _make_response,
2828
)
2929

30+
from firebase_functions.core import _with_init
3031
# Export for user convenience.
3132
# pylint: disable=unused-import
3233
from firebase_functions.options import Timezone
@@ -108,7 +109,7 @@ def on_schedule_wrapped(request: _Request) -> _Response:
108109
schedule_time=schedule_time,
109110
)
110111
try:
111-
func(event)
112+
_with_init(func)(event)
112113
return _make_response()
113114
# Disable broad exceptions lint since we want to handle all exceptions.
114115
# pylint: disable=broad-except

src/firebase_functions/storage_fn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import cloudevents.http as _ce
2323

2424
import firebase_functions.private.util as _util
25-
from firebase_functions.core import CloudEvent
25+
from firebase_functions.core import CloudEvent, _with_init
2626
from firebase_functions.options import StorageOptions
2727

2828
_event_type_archived = "google.cloud.storage.object.v1.archived"
@@ -255,7 +255,7 @@ def _message_handler(
255255
type=event_attributes["type"],
256256
)
257257

258-
func(event)
258+
_with_init(func)(event)
259259

260260

261261
@_util.copy_func_kwargs(StorageOptions)

src/firebase_functions/test_lab_fn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import firebase_functions.private.util as _util
2626

27-
from firebase_functions.core import CloudEvent
27+
from firebase_functions.core import CloudEvent, _with_init
2828
from firebase_functions.options import EventHandlerOptions
2929

3030

@@ -246,7 +246,7 @@ def _event_handler(func: _C1, raw: _ce.CloudEvent) -> None:
246246
type=event_dict["type"],
247247
)
248248

249-
func(event)
249+
_with_init(func)(event)
250250

251251

252252
@_util.copy_func_kwargs(EventHandlerOptions)

tests/test_db.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""
2+
Tests for the db module.
3+
"""
4+
5+
import unittest
6+
from unittest import mock
7+
from cloudevents.http import CloudEvent
8+
from firebase_functions import core, db_fn
9+
10+
11+
class TestDb(unittest.TestCase):
12+
"""
13+
Tests for the db module.
14+
"""
15+
16+
def test_calls_init_function(self):
17+
hello = None
18+
19+
@core.init
20+
def init():
21+
nonlocal hello
22+
hello = "world"
23+
24+
func = mock.Mock(__name__="example_func")
25+
decorated_func = db_fn.on_value_created(reference="path")(func)
26+
27+
event = CloudEvent(attributes={
28+
"specversion": "1.0",
29+
"id": "id",
30+
"source": "source",
31+
"subject": "subject",
32+
"type": "type",
33+
"time": "2024-04-10T12:00:00.000Z",
34+
"instance": "instance",
35+
"ref": "ref",
36+
"firebasedatabasehost": "firebasedatabasehost",
37+
"location": "location",
38+
},
39+
data={"delta": "delta"})
40+
41+
decorated_func(event)
42+
43+
self.assertEqual(hello, "world")

tests/test_eventarc_fn.py

+34
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
"""Eventarc trigger function tests."""
1515
import unittest
1616
from unittest.mock import Mock
17+
1718
from cloudevents.http import CloudEvent as _CloudEvent
19+
20+
from firebase_functions import core
1821
from firebase_functions.core import CloudEvent
1922
from firebase_functions.eventarc_fn import on_custom_event_published
2023

@@ -83,3 +86,34 @@ def test_on_custom_event_published_wrapped(self):
8386
event_arg.type,
8487
"firebase.extensions.storage-resize-images.v1.complete",
8588
)
89+
90+
def test_calls_init_function(self):
91+
hello = None
92+
93+
@core.init
94+
def init():
95+
nonlocal hello
96+
hello = "world"
97+
98+
func = Mock(__name__="example_func")
99+
raw_event = _CloudEvent(
100+
attributes={
101+
"specversion": "1.0",
102+
"type": "firebase.extensions.storage-resize-images.v1.complete",
103+
"source": "https://example.com/testevent",
104+
"id": "1234567890",
105+
"subject": "test_subject",
106+
"time": "2023-03-11T13:25:37.403Z",
107+
},
108+
data={
109+
"some_key": "some_value",
110+
},
111+
)
112+
113+
decorated_func = on_custom_event_published(
114+
event_type="firebase.extensions.storage-resize-images.v1.complete",
115+
)(func)
116+
117+
decorated_func(raw_event)
118+
119+
self.assertEqual(hello, "world")

tests/test_firestore_fn.py

+51
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,54 @@ def test_firestore_endpoint_handler_calls_function_with_correct_args(self):
7070
self.assertIsInstance(event, AuthEvent)
7171
self.assertEqual(event.auth_type, "unauthenticated")
7272
self.assertEqual(event.auth_id, "foo")
73+
74+
def test_calls_init_function(self):
75+
with patch.dict("sys.modules", mocked_modules):
76+
from firebase_functions import firestore_fn, core
77+
from cloudevents.http import CloudEvent
78+
79+
func = Mock(__name__="example_func")
80+
81+
hello = None
82+
83+
@core.init
84+
def init():
85+
nonlocal hello
86+
hello = "world"
87+
88+
attributes = {
89+
"specversion":
90+
"1.0",
91+
# pylint: disable=protected-access
92+
"type":
93+
firestore_fn._event_type_created,
94+
"source":
95+
"https://example.com/testevent",
96+
"time":
97+
"2023-03-11T13:25:37.403Z",
98+
"subject":
99+
"test_subject",
100+
"datacontenttype":
101+
"application/json",
102+
"location":
103+
"projects/project-id/databases/(default)/documents/foo/{bar}",
104+
"project":
105+
"project-id",
106+
"namespace":
107+
"(default)",
108+
"document":
109+
"foo/{bar}",
110+
"database":
111+
"projects/project-id/databases/(default)",
112+
"authtype":
113+
"unauthenticated",
114+
"authid":
115+
"foo"
116+
}
117+
raw_event = CloudEvent(attributes=attributes, data=json.dumps({}))
118+
decorated_func = firestore_fn.on_document_created(
119+
document="/foo/{bar}")(func)
120+
121+
decorated_func(raw_event)
122+
123+
self.assertEqual(hello, "world")

0 commit comments

Comments
 (0)