Skip to content

Commit 7ab6f1a

Browse files
committed
feat: add tests
1 parent 7ba735a commit 7ab6f1a

14 files changed

+479
-31
lines changed

src/firebase_functions/core.py

+19-16
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ class Change(_typing.Generic[T]):
8484
"""
8585

8686

87-
_didInit = False
88-
_initCallback: _typing.Callable[[], _typing.Any] | None = None
87+
_did_init = False
88+
_init_callback: _typing.Callable[[], _typing.Any] | None = None
8989

9090

9191
def init(callback: _typing.Callable[[], _typing.Any]) -> None:
@@ -95,31 +95,34 @@ def init(callback: _typing.Callable[[], _typing.Any]) -> None:
9595
Calling this decorator more than once leads to undefined behavior.
9696
"""
9797

98-
global _didInit
99-
global _initCallback
98+
global _did_init
99+
global _init_callback
100100

101-
_initCallback = callback
101+
_init_callback = callback
102102

103-
if _didInit:
104-
_logger.warn("Setting init callback more than once. Only the most recent callback will be called")
103+
if _did_init:
104+
_logger.warn(
105+
"Setting init callback more than once. Only the most recent callback will be called"
106+
)
105107

106-
_initCallback = callback
107-
_didInit = False
108+
_init_callback = callback
109+
_did_init = False
108110

109111

110-
def _with_init(fn: _typing.Callable[..., _typing.Any]) -> _typing.Callable[..., _typing.Any]:
112+
def _with_init(
113+
fn: _typing.Callable[...,
114+
_typing.Any]) -> _typing.Callable[..., _typing.Any]:
111115
"""
112116
A decorator that runs the init callback before running the decorated function.
113117
"""
114118

115119
def wrapper(*args, **kwargs):
116-
global _didInit
117-
global _initCallback
120+
global _did_init
118121

119-
if not _didInit:
120-
if _initCallback is not None:
121-
_initCallback()
122-
_didInit = True
122+
if not _did_init:
123+
if _init_callback is not None:
124+
_init_callback()
125+
_did_init = True
123126

124127
return fn(*args, **kwargs)
125128

src/firebase_functions/firestore_fn.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,7 @@ def _firestore_endpoint_handler(
145145
app = get_app()
146146
firestore_client = _firestore_v1.Client(project=app.project_id,
147147
database=event_database)
148-
firestore_ref: DocumentReference = firestore_client.document(
149-
event_document)
148+
firestore_ref: DocumentReference = firestore_client.document(event_document)
150149
value_snapshot: DocumentSnapshot | None = None
151150
old_value_snapshot: DocumentSnapshot | None = None
152151

@@ -268,7 +267,7 @@ def on_document_written_wrapped(raw: _ce.CloudEvent):
268267

269268
@_util.copy_func_kwargs(FirestoreOptions)
270269
def on_document_written_with_auth_context(**kwargs
271-
) -> _typing.Callable[[_C1], _C1]:
270+
) -> _typing.Callable[[_C1], _C1]:
272271
"""
273272
Event handler that triggers when a document is created, updated, or deleted in Firestore.
274273
This trigger will also provide the authentication context of the principal who triggered
@@ -367,7 +366,7 @@ def on_document_updated_wrapped(raw: _ce.CloudEvent):
367366

368367
@_util.copy_func_kwargs(FirestoreOptions)
369368
def on_document_updated_with_auth_context(**kwargs
370-
) -> _typing.Callable[[_C1], _C1]:
369+
) -> _typing.Callable[[_C1], _C1]:
371370
"""
372371
Event handler that triggers when a document is updated in Firestore.
373372
This trigger will also provide the authentication context of the principal who triggered
@@ -466,7 +465,7 @@ def on_document_created_wrapped(raw: _ce.CloudEvent):
466465

467466
@_util.copy_func_kwargs(FirestoreOptions)
468467
def on_document_created_with_auth_context(**kwargs
469-
) -> _typing.Callable[[_C2], _C2]:
468+
) -> _typing.Callable[[_C2], _C2]:
470469
"""
471470
Event handler that triggers when a document is created in Firestore.
472471
This trigger will also provide the authentication context of the principal who triggered
@@ -565,7 +564,7 @@ def on_document_deleted_wrapped(raw: _ce.CloudEvent):
565564

566565
@_util.copy_func_kwargs(FirestoreOptions)
567566
def on_document_deleted_with_auth_context(**kwargs
568-
) -> _typing.Callable[[_C2], _C2]:
567+
) -> _typing.Callable[[_C2], _C2]:
569568
"""
570569
Event handler that triggers when a document is deleted in Firestore.
571570
This trigger will also provide the authentication context of the principal who triggered

src/firebase_functions/private/_identity_fn.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Cloud functions to handle Eventarc events."""
15-
1615
# pylint: disable=protected-access
1716
import typing as _typing
1817
import datetime as _dt
@@ -353,8 +352,8 @@ def before_operation_handler(
353352
jwt_token = request.json["data"]["jwt"]
354353
decoded_token = _token_verifier.verify_auth_blocking_token(jwt_token)
355354
event = _auth_blocking_event_from_token_data(decoded_token)
356-
auth_response: BeforeCreateResponse | BeforeSignInResponse | None = _with_init(func)(
357-
event)
355+
auth_response: BeforeCreateResponse | BeforeSignInResponse | None = _with_init(
356+
func)(event)
358357
if not auth_response:
359358
return _jsonify({})
360359
auth_response_dict = _validate_auth_response(event_type, auth_response)
@@ -364,7 +363,7 @@ def before_operation_handler(
364363
# pylint: disable=broad-except
365364
except Exception as exception:
366365
if not isinstance(exception, HttpsError):
367-
_logging.error("Unhandled error", exception)
366+
_logging.error("Unhandled error %s", exception)
368367
exception = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL")
369368
status = exception._http_error_code.status
370369
return _make_response(_jsonify(error=exception._as_dict()), status)

tests/test_db.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""
2+
Tests for the db module.
3+
"""
4+
5+
import unittest
6+
from unittest import mock
7+
from cloudevents.http import CloudEvent
8+
from firebase_functions import core, db_fn
9+
10+
11+
class TestDb(unittest.TestCase):
12+
"""
13+
Tests for the db module.
14+
"""
15+
16+
def test_calls_init_function(self):
17+
hello = None
18+
19+
@core.init
20+
def init():
21+
nonlocal hello
22+
hello = "world"
23+
24+
func = mock.Mock(__name__="example_func")
25+
decorated_func = db_fn.on_value_created(reference="path")(func)
26+
27+
event = CloudEvent(attributes={
28+
"specversion": "1.0",
29+
"id": "id",
30+
"source": "source",
31+
"subject": "subject",
32+
"type": "type",
33+
"time": "2024-04-10T12:00:00.000Z",
34+
"instance": "instance",
35+
"ref": "ref",
36+
"firebasedatabasehost": "firebasedatabasehost",
37+
"location": "location",
38+
},
39+
data={"delta": "delta"})
40+
41+
decorated_func(event)
42+
43+
self.assertEqual(hello, "world")

tests/test_eventarc_fn.py

+34
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
"""Eventarc trigger function tests."""
1515
import unittest
1616
from unittest.mock import Mock
17+
1718
from cloudevents.http import CloudEvent as _CloudEvent
19+
20+
from firebase_functions import core
1821
from firebase_functions.core import CloudEvent
1922
from firebase_functions.eventarc_fn import on_custom_event_published
2023

@@ -83,3 +86,34 @@ def test_on_custom_event_published_wrapped(self):
8386
event_arg.type,
8487
"firebase.extensions.storage-resize-images.v1.complete",
8588
)
89+
90+
def test_calls_init_function(self):
91+
hello = None
92+
93+
@core.init
94+
def init():
95+
nonlocal hello
96+
hello = "world"
97+
98+
func = Mock(__name__="example_func")
99+
raw_event = _CloudEvent(
100+
attributes={
101+
"specversion": "1.0",
102+
"type": "firebase.extensions.storage-resize-images.v1.complete",
103+
"source": "https://example.com/testevent",
104+
"id": "1234567890",
105+
"subject": "test_subject",
106+
"time": "2023-03-11T13:25:37.403Z",
107+
},
108+
data={
109+
"some_key": "some_value",
110+
},
111+
)
112+
113+
decorated_func = on_custom_event_published(
114+
event_type="firebase.extensions.storage-resize-images.v1.complete",
115+
)(func)
116+
117+
decorated_func(raw_event)
118+
119+
self.assertEqual(hello, "world")

tests/test_firestore_fn.py

+51
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,54 @@ def test_firestore_endpoint_handler_calls_function_with_correct_args(self):
7070
self.assertIsInstance(event, AuthEvent)
7171
self.assertEqual(event.auth_type, "unauthenticated")
7272
self.assertEqual(event.auth_id, "foo")
73+
74+
def test_calls_init_function(self):
75+
with patch.dict("sys.modules", mocked_modules):
76+
from firebase_functions import firestore_fn, core
77+
from cloudevents.http import CloudEvent
78+
79+
func = Mock(__name__="example_func")
80+
81+
hello = None
82+
83+
@core.init
84+
def init():
85+
nonlocal hello
86+
hello = "world"
87+
88+
attributes = {
89+
"specversion":
90+
"1.0",
91+
# pylint: disable=protected-access
92+
"type":
93+
firestore_fn._event_type_created,
94+
"source":
95+
"https://example.com/testevent",
96+
"time":
97+
"2023-03-11T13:25:37.403Z",
98+
"subject":
99+
"test_subject",
100+
"datacontenttype":
101+
"application/json",
102+
"location":
103+
"projects/project-id/databases/(default)/documents/foo/{bar}",
104+
"project":
105+
"project-id",
106+
"namespace":
107+
"(default)",
108+
"document":
109+
"foo/{bar}",
110+
"database":
111+
"projects/project-id/databases/(default)",
112+
"authtype":
113+
"unauthenticated",
114+
"authid":
115+
"foo"
116+
}
117+
raw_event = CloudEvent(attributes=attributes, data=json.dumps({}))
118+
decorated_func = firestore_fn.on_document_created(
119+
document="/foo/{bar}")(func)
120+
121+
decorated_func(raw_event)
122+
123+
self.assertEqual(hello, "world")

tests/test_https_fn.py

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""
2+
Tests for the https module.
3+
"""
4+
5+
import unittest
6+
from unittest.mock import Mock
7+
from flask import Flask, Request
8+
from werkzeug.test import EnvironBuilder
9+
10+
from firebase_functions import core, https_fn
11+
12+
13+
class TestHttps(unittest.TestCase):
14+
"""
15+
Tests for the http module.
16+
"""
17+
18+
def test_on_request_calls_init_function(self):
19+
app = Flask(__name__)
20+
21+
hello = None
22+
23+
@core.init
24+
def init():
25+
nonlocal hello
26+
hello = "world"
27+
28+
func = Mock(__name__="example_func")
29+
30+
with app.test_request_context("/"):
31+
environ = EnvironBuilder(
32+
method="POST",
33+
json={
34+
"data": {
35+
"test": "value"
36+
},
37+
},
38+
).get_environ()
39+
request = Request(environ)
40+
decorated_func = https_fn.on_request()(func)
41+
42+
decorated_func(request)
43+
44+
self.assertEqual(hello, "world")
45+
46+
def test_on_call_calls_init_function(self):
47+
app = Flask(__name__)
48+
49+
hello = None
50+
51+
@core.init
52+
def init():
53+
nonlocal hello
54+
hello = "world"
55+
56+
func = Mock(__name__="example_func")
57+
58+
with app.test_request_context("/"):
59+
environ = EnvironBuilder(
60+
method="POST",
61+
json={
62+
"data": {
63+
"test": "value"
64+
},
65+
},
66+
).get_environ()
67+
request = Request(environ)
68+
decorated_func = https_fn.on_call()(func)
69+
70+
decorated_func(request)
71+
72+
self.assertEqual("world", hello)

0 commit comments

Comments
 (0)