Skip to content

Commit e9781a9

Browse files
authored
feat: add init function support (#188)
1 parent b7c0b43 commit e9781a9

23 files changed

+530
-23
lines changed

src/firebase_functions/alerts_fn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import firebase_functions.private.util as _util
2626

27-
from firebase_functions.core import T, CloudEvent as _CloudEvent
27+
from firebase_functions.core import T, CloudEvent as _CloudEvent, _with_init
2828
from firebase_functions.options import FirebaseAlertOptions
2929

3030
# Explicitly import AlertType to make it available in the public API.
@@ -95,7 +95,7 @@ def on_alert_published_inner_decorator(func: OnAlertPublishedCallable):
9595
@_functools.wraps(func)
9696
def on_alert_published_wrapped(raw: _ce.CloudEvent):
9797
from firebase_functions.private._alerts_fn import alerts_event_from_ce
98-
func(alerts_event_from_ce(raw))
98+
_with_init(func)(alerts_event_from_ce(raw))
9999

100100
_util.set_func_endpoint_attr(
101101
on_alert_published_wrapped,

src/firebase_functions/core.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import datetime as _datetime
1919
import typing as _typing
2020

21+
from . import logger as _logger
22+
2123
T = _typing.TypeVar("T")
2224

2325

@@ -80,3 +82,46 @@ class Change(_typing.Generic[T]):
8082
"""
8183
The state of data after the change.
8284
"""
85+
86+
87+
_did_init = False
88+
_init_callback: _typing.Callable[[], _typing.Any] | None = None
89+
90+
91+
def init(callback: _typing.Callable[[], _typing.Any]) -> None:
92+
"""
93+
Registers a function that should be run when in a production environment
94+
before executing any functions code.
95+
Calling this decorator more than once leads to undefined behavior.
96+
"""
97+
98+
global _did_init
99+
global _init_callback
100+
101+
if _did_init:
102+
_logger.warn(
103+
"Setting init callback more than once. Only the most recent callback will be called"
104+
)
105+
106+
_init_callback = callback
107+
_did_init = False
108+
109+
110+
def _with_init(
111+
fn: _typing.Callable[...,
112+
_typing.Any]) -> _typing.Callable[..., _typing.Any]:
113+
"""
114+
A decorator that runs the init callback before running the decorated function.
115+
"""
116+
117+
def wrapper(*args, **kwargs):
118+
global _did_init
119+
120+
if not _did_init:
121+
if _init_callback is not None:
122+
_init_callback()
123+
_did_init = True
124+
125+
return fn(*args, **kwargs)
126+
127+
return wrapper

src/firebase_functions/db_fn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def _db_endpoint_handler(
119119
subject=event_attributes["subject"],
120120
params=params,
121121
)
122-
func(database_event)
122+
_core._with_init(func)(database_event)
123123

124124

125125
@_util.copy_func_kwargs(DatabaseOptions)

src/firebase_functions/eventarc_fn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import firebase_functions.options as _options
2323
import firebase_functions.private.util as _util
24-
from firebase_functions.core import CloudEvent
24+
from firebase_functions.core import CloudEvent, _with_init
2525

2626

2727
@_util.copy_func_kwargs(_options.EventarcTriggerOptions)
@@ -73,7 +73,7 @@ def on_custom_event_published_wrapped(raw: _ce.CloudEvent):
7373
),
7474
type=event_dict["type"],
7575
)
76-
func(event)
76+
_with_init(func)(event)
7777

7878
_util.set_func_endpoint_attr(
7979
on_custom_event_published_wrapped,

src/firebase_functions/firestore_fn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ def _firestore_endpoint_handler(
205205
params=params,
206206
)
207207

208+
func = _core._with_init(func)
209+
208210
if event_type.endswith(".withAuthContext"):
209211
database_event_with_auth_context = AuthEvent(**vars(database_event),
210212
auth_type=event_auth_type,

src/firebase_functions/https_fn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def _on_call_handler(func: _C2,
403403
instance_id_token=request.headers.get(
404404
"Firebase-Instance-ID-Token"),
405405
)
406-
result = func(context)
406+
result = _core._with_init(func)(context)
407407
return _jsonify(result=result)
408408
# Disable broad exceptions lint since we want to handle all exceptions here
409409
# and wrap as an HttpsError.
@@ -447,7 +447,7 @@ def on_request_wrapped(request: Request) -> Response:
447447
methods=options.cors.cors_methods,
448448
origins=options.cors.cors_origins,
449449
)(func)(request)
450-
return func(request)
450+
return _core._with_init(func)(request)
451451

452452
_util.set_func_endpoint_attr(
453453
on_request_wrapped,

src/firebase_functions/private/_identity_fn.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Cloud functions to handle Eventarc events."""
15-
1615
# pylint: disable=protected-access
1716
import typing as _typing
1817
import datetime as _dt
1918
import time as _time
2019
import json as _json
20+
21+
from firebase_functions.core import _with_init
2122
from firebase_functions.https_fn import HttpsError, FunctionsErrorCode
2223

2324
import firebase_functions.private.util as _util
@@ -351,8 +352,8 @@ def before_operation_handler(
351352
jwt_token = request.json["data"]["jwt"]
352353
decoded_token = _token_verifier.verify_auth_blocking_token(jwt_token)
353354
event = _auth_blocking_event_from_token_data(decoded_token)
354-
auth_response: BeforeCreateResponse | BeforeSignInResponse | None = func(
355-
event)
355+
auth_response: BeforeCreateResponse | BeforeSignInResponse | None = _with_init(
356+
func)(event)
356357
if not auth_response:
357358
return _jsonify({})
358359
auth_response_dict = _validate_auth_response(event_type, auth_response)
@@ -362,7 +363,7 @@ def before_operation_handler(
362363
# pylint: disable=broad-except
363364
except Exception as exception:
364365
if not isinstance(exception, HttpsError):
365-
_logging.error("Unhandled error", exception)
366+
_logging.error("Unhandled error %s", exception)
366367
exception = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL")
367368
status = exception._http_error_code.status
368369
return _make_response(_jsonify(error=exception._as_dict()), status)

src/firebase_functions/pubsub_fn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import firebase_functions.private.util as _util
2727

28-
from firebase_functions.core import CloudEvent, T
28+
from firebase_functions.core import CloudEvent, T, _with_init
2929
from firebase_functions.options import PubSubOptions
3030

3131

@@ -151,7 +151,7 @@ def _message_handler(
151151
type=event_dict["type"],
152152
)
153153

154-
func(event)
154+
_with_init(func)(event)
155155

156156

157157
@_util.copy_func_kwargs(PubSubOptions)

src/firebase_functions/remote_config_fn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import firebase_functions.private.util as _util
2626

27-
from firebase_functions.core import CloudEvent
27+
from firebase_functions.core import CloudEvent, _with_init
2828
from firebase_functions.options import EventHandlerOptions
2929

3030

@@ -189,7 +189,7 @@ def _config_handler(func: _C1, raw: _ce.CloudEvent) -> None:
189189
type=event_dict["type"],
190190
)
191191

192-
func(event)
192+
_with_init(func)(event)
193193

194194

195195
@_util.copy_func_kwargs(EventHandlerOptions)

src/firebase_functions/scheduler_fn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
make_response as _make_response,
2828
)
2929

30+
from firebase_functions.core import _with_init
3031
# Export for user convenience.
3132
# pylint: disable=unused-import
3233
from firebase_functions.options import Timezone
@@ -108,7 +109,7 @@ def on_schedule_wrapped(request: _Request) -> _Response:
108109
schedule_time=schedule_time,
109110
)
110111
try:
111-
func(event)
112+
_with_init(func)(event)
112113
return _make_response()
113114
# Disable broad exceptions lint since we want to handle all exceptions.
114115
# pylint: disable=broad-except

0 commit comments

Comments
 (0)