Skip to content

Commit 34d7745

Browse files
Internal change
PiperOrigin-RevId: 840778041
1 parent 2504a29 commit 34d7745

File tree

2 files changed

+177
-0
lines changed

2 files changed

+177
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""Exercises proto_api.h. Most of the code is in proto_api_test.cc."""
2+
3+
from google.protobuf.internal import proto_api_test_ext
4+
from google3.testing.pybase import unittest
5+
from google.protobuf import unittest_pb2
6+
7+
8+
class ProtoApiTest(unittest.TestCase):
9+
10+
def test_get_const_message(self):
11+
msg = unittest_pb2.TestAllTypes(optional_int32=123, optional_string='hello')
12+
result = proto_api_test_ext.get_const_message(msg)
13+
self.assertEqual(result, (123, 'hello'))
14+
15+
def test_set_message_field_with_mutator(self):
16+
msg = unittest_pb2.TestAllTypes(optional_int32=123, optional_string='hello')
17+
proto_api_test_ext.set_message_field_with_mutator(msg, 456)
18+
# GetClearedMessageMutator clears message first, then update it.
19+
# On destruction, it copies content back to python message.
20+
self.assertEqual(456, msg.optional_int32)
21+
self.assertFalse(msg.HasField('optional_string'))
22+
23+
def test_dynamic_message(self):
24+
result = proto_api_test_ext.repr_dynamic_message(789)
25+
self.assertEqual(result, 'optional_int32: 789\n')
26+
27+
28+
if __name__ == '__main__':
29+
unittest.main()
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
// An extension module to test proto_api.h.
2+
3+
#include <memory>
4+
#include <stdexcept>
5+
#include <string>
6+
7+
#include "google/protobuf/descriptor.pb.h"
8+
#include "google/protobuf/descriptor.h"
9+
#include "google/protobuf/dynamic_message.h"
10+
#include "google/protobuf/message.h"
11+
#include "google/protobuf/unittest.pb.h"
12+
#include "google/protobuf/proto_api.h"
13+
#include "third_party/pybind11/include/pybind11/eval.h"
14+
#include "third_party/pybind11/include/pybind11/pybind11.h"
15+
#include "third_party/pybind11/include/pybind11/stl.h"
16+
17+
namespace google {
18+
namespace protobuf {
19+
namespace python {
20+
21+
namespace py = pybind11;
22+
using ::google_protobuf_unittest::TestAllTypes;
23+
24+
const PyProto_API* GetProtoApi() {
25+
py::module_::import("google.protobuf.pyext._message");
26+
const PyProto_API* py_proto_api = static_cast<const PyProto_API*>(
27+
PyCapsule_Import(PyProtoAPICapsuleName(), 0));
28+
if (!py_proto_api) {
29+
throw py::error_already_set();
30+
}
31+
return py_proto_api;
32+
}
33+
34+
// Test for GetConstMessagePointer
35+
auto GetConstMessage(py::handle py_msg) {
36+
const PyProto_API* api = GetProtoApi();
37+
auto msg_ptr = api->GetConstMessagePointer(py_msg.ptr());
38+
if (!msg_ptr.ok()) {
39+
throw std::runtime_error(msg_ptr.status().ToString());
40+
}
41+
const auto* msg = DynamicCastMessage<TestAllTypes>(&msg_ptr->get());
42+
if (!msg) {
43+
throw std::runtime_error("Invalid message type");
44+
}
45+
return py::make_tuple(msg->optional_int32(), msg->optional_string());
46+
}
47+
48+
// Test for GetClearedMessageMutator
49+
auto SetMessageFieldWithMutator(py::handle py_msg, int value) {
50+
const PyProto_API* api = GetProtoApi();
51+
auto status_or_mutator = api->GetClearedMessageMutator(py_msg.ptr());
52+
if (!status_or_mutator.ok()) {
53+
throw std::runtime_error(status_or_mutator.status().ToString());
54+
}
55+
TestAllTypes* msg = DownCastMessage<TestAllTypes>(status_or_mutator->get());
56+
msg->set_optional_int32(value);
57+
// On destruction, the mutator will copy content back to python message.
58+
}
59+
60+
// Test for DescriptorPool_FromPool and NewMessageOwnedExternally
61+
auto ReprDynamicMessage(int value) {
62+
const PyProto_API* api = GetProtoApi();
63+
64+
// Create a descriptor pool which copies everything from the linked protos.
65+
DescriptorPool pool(DescriptorPool::internal_generated_database());
66+
// FileDescriptorProto file_descriptor;
67+
// TestAllTypes::descriptor()->file()->CopyTo(&file_descriptor);
68+
// if (!pool.BuildFile(file_descriptor)) {
69+
// throw std::runtime_error("Failed to build file descriptor");
70+
// }
71+
const Descriptor* descriptor =
72+
pool.FindMessageTypeByName("google_protobuf_unittest.TestAllTypes");
73+
if (!descriptor) {
74+
throw std::runtime_error("Failed to find file descriptor");
75+
}
76+
DynamicMessageFactory factory(&pool);
77+
const Message* prototype = factory.GetPrototype(descriptor);
78+
if (!prototype) {
79+
throw std::runtime_error("Failed to get prototype for descriptor");
80+
}
81+
std::unique_ptr<Message> msg(prototype->New());
82+
if (!msg) {
83+
throw std::runtime_error("Failed to create message");
84+
}
85+
msg->GetReflection()->SetInt32(
86+
msg.get(), descriptor->FindFieldByName("optional_int32"), value);
87+
88+
// These calls to NewMessage fail because the descriptor pool is not
89+
// known to Python yet.
90+
{
91+
auto py_msg =
92+
py::reinterpret_steal<py::object>(api->NewMessage(descriptor, nullptr));
93+
if (py_msg) {
94+
throw std::runtime_error("NewMessage succeeded unexpectedly");
95+
}
96+
py_msg = py::reinterpret_steal<py::object>(
97+
api->NewMessageOwnedExternally(msg.get(), nullptr));
98+
if (py_msg) {
99+
throw std::runtime_error("NewMessage succeeded unexpectedly");
100+
}
101+
}
102+
103+
// Create the Python DescriptorPool...
104+
auto py_pool =
105+
py::reinterpret_steal<py::object>(api->DescriptorPool_FromPool(&pool));
106+
if (!py_pool) {
107+
throw py::error_already_set();
108+
}
109+
110+
// ... And now the API Can use it to create the messages.
111+
std::string result_string;
112+
{
113+
auto py_msg =
114+
py::reinterpret_steal<py::object>(api->NewMessage(descriptor, nullptr));
115+
if (!py_msg) {
116+
throw py::error_already_set();
117+
}
118+
119+
py_msg = py::reinterpret_steal<py::object>(
120+
api->NewMessageOwnedExternally(msg.get(), nullptr));
121+
if (!py_msg) {
122+
throw py::error_already_set();
123+
}
124+
result_string = py::repr(py_msg);
125+
}
126+
127+
// The code above is dangerous! It relies on the C++ DescriptorPool being
128+
// alive for whole duration of the test.
129+
// At this point, there are no external references to the Python Message
130+
// classes, but they always form a reference cycle with their Python
131+
// MessageFactory.
132+
// So it is necessary to run the garbage collector.
133+
py::exec("import gc; gc.collect()");
134+
// Now the Python MessageFactory has been deleted, and it is safe to destroy
135+
// the C++ DescriptorPool.
136+
137+
return result_string;
138+
}
139+
140+
PYBIND11_MODULE(proto_api_test_ext, m) {
141+
m.def("get_const_message", &GetConstMessage);
142+
m.def("set_message_field_with_mutator", &SetMessageFieldWithMutator);
143+
m.def("repr_dynamic_message", &ReprDynamicMessage);
144+
}
145+
146+
} // namespace python
147+
} // namespace protobuf
148+
} // namespace google

0 commit comments

Comments
 (0)