Skip to content

Commit 777c5d3

Browse files
committed
feat: add core.init function, and add test
1 parent 37c1fc0 commit 777c5d3

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

src/firebase_functions/core.py

+47
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import datetime as _datetime
1919
import typing as _typing
2020

21+
from . import logger as _logger
22+
2123
T = _typing.TypeVar("T")
2224

2325

@@ -80,3 +82,48 @@ class Change(_typing.Generic[T]):
8082
"""
8183
The state of data after the change.
8284
"""
85+
86+
87+
_didInit = False
88+
_initCallback: _typing.Callable[[], _typing.Any] | None = None
89+
90+
91+
def init(callback: _typing.Callable[[], _typing.Any]) -> None:
92+
"""
93+
Registers a function that should be run when in a production environment
94+
before executing any functions code.
95+
Calling this decorator more than once leads to undefined behavior.
96+
"""
97+
98+
global _didInit
99+
global _initCallback
100+
101+
if _didInit:
102+
raise ValueError("Firebase Functions SDK already initialized")
103+
104+
_initCallback = callback
105+
106+
if _didInit:
107+
_logger.warn("Setting init callback more than once. Only the most recent callback will be called")
108+
109+
_initCallback = callback
110+
_didInit = False
111+
112+
113+
def _with_init(fn: _typing.Callable[..., _typing.Any]) -> _typing.Callable[..., _typing.Any]:
114+
"""
115+
A decorator that runs the init callback before running the decorated function.
116+
"""
117+
118+
def wrapper(*args, **kwargs):
119+
global _didInit
120+
global _initCallback
121+
122+
if not _didInit:
123+
if _initCallback is not None:
124+
_initCallback()
125+
_didInit = True
126+
127+
return fn(*args, **kwargs)
128+
129+
return wrapper

tests/test_init.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import unittest
2+
from firebase_functions import core
3+
4+
5+
class TestInit(unittest.TestCase):
6+
def test_init_is_initialized(self):
7+
@core.init
8+
def fn():
9+
pass
10+
11+
self.assertIsNotNone(core._initCallback)
12+
self.assertFalse(core._didInit)

0 commit comments

Comments
 (0)