Skip to content

Commit bd3a229

Browse files
committed
Support pickling Record-s
Closes #451
1 parent 7df9812 commit bd3a229

File tree

5 files changed

+153
-35
lines changed

5 files changed

+153
-35
lines changed

asyncpg/protocol/protocol.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -1024,4 +1024,4 @@ def _create_record(object mapping, tuple elems):
10241024
return rec
10251025

10261026

1027-
Record = <object>record.ApgRecord_InitTypes()
1027+
Record, RecordDescriptor = record.ApgRecord_InitTypes()

asyncpg/protocol/record/__init__.pxd

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ cimport cpython
1010

1111
cdef extern from "record/recordobj.h":
1212

13-
cpython.PyTypeObject *ApgRecord_InitTypes() except NULL
13+
tuple ApgRecord_InitTypes()
1414

1515
int ApgRecord_CheckExact(object)
1616
object ApgRecord_New(type, object, int)

asyncpg/protocol/record/recordobj.c

+142-27
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ static PyObject * record_new_items_iter(PyObject *);
2020
static ApgRecordObject *free_list[ApgRecord_MAXSAVESIZE];
2121
static int numfree[ApgRecord_MAXSAVESIZE];
2222

23+
static PyObject *record_reconstruct_obj;
24+
static PyObject *record_desc_reconstruct_obj;
25+
2326
static size_t MAX_RECORD_SIZE = (
2427
((size_t)PY_SSIZE_T_MAX - sizeof(ApgRecordObject) - sizeof(PyObject *))
2528
/ sizeof(PyObject *)
@@ -575,14 +578,14 @@ record_repr(ApgRecordObject *v)
575578

576579

577580
static PyObject *
578-
record_values(PyObject *o, PyObject *args)
581+
record_values(PyObject *o, PyObject *Py_UNUSED(unused))
579582
{
580583
return record_iter(o);
581584
}
582585

583586

584587
static PyObject *
585-
record_keys(PyObject *o, PyObject *args)
588+
record_keys(PyObject *o, PyObject *Py_UNUSED(unused))
586589
{
587590
if (!ApgRecord_Check(o)) {
588591
PyErr_BadInternalCall();
@@ -594,7 +597,7 @@ record_keys(PyObject *o, PyObject *args)
594597

595598

596599
static PyObject *
597-
record_items(PyObject *o, PyObject *args)
600+
record_items(PyObject *o, PyObject *Py_UNUSED(unused))
598601
{
599602
if (!ApgRecord_Check(o)) {
600603
PyErr_BadInternalCall();
@@ -658,11 +661,69 @@ static PyMappingMethods record_as_mapping = {
658661
};
659662

660663

664+
static PyObject *
665+
record_reduce(ApgRecordObject *o, PyObject *Py_UNUSED(unused))
666+
{
667+
PyObject *value = PyTuple_New(2);
668+
if (value == NULL) {
669+
return NULL;
670+
}
671+
Py_ssize_t len = Py_SIZE(o);
672+
PyObject *state = PyTuple_New(1 + len);
673+
if (state == NULL) {
674+
Py_DECREF(value);
675+
return NULL;
676+
}
677+
PyTuple_SET_ITEM(value, 0, record_reconstruct_obj);
678+
Py_INCREF(record_reconstruct_obj);
679+
PyTuple_SET_ITEM(value, 1, state);
680+
PyTuple_SET_ITEM(state, 0, (PyObject *)o->desc);
681+
Py_INCREF(o->desc);
682+
for (Py_ssize_t i = 0; i < len; i++) {
683+
PyObject *item = ApgRecord_GET_ITEM(o, i);
684+
PyTuple_SET_ITEM(state, i + 1, item);
685+
Py_INCREF(item);
686+
}
687+
return value;
688+
}
689+
690+
static PyObject *
691+
record_reconstruct(PyObject *Py_UNUSED(unused), PyObject *args)
692+
{
693+
if (!PyTuple_CheckExact(args)) {
694+
return NULL;
695+
}
696+
Py_ssize_t len = PyTuple_GET_SIZE(args);
697+
if (len < 2) {
698+
return NULL;
699+
}
700+
len--;
701+
ApgRecordDescObject *desc = (ApgRecordDescObject *)PyTuple_GET_ITEM(args, 0);
702+
if (!ApgRecordDesc_CheckExact(desc)) {
703+
return NULL;
704+
}
705+
if (PyObject_Length(desc->mapping) != len) {
706+
return NULL;
707+
}
708+
PyObject *record = ApgRecord_New(&ApgRecord_Type, (PyObject *)desc, len);
709+
if (record == NULL) {
710+
return NULL;
711+
}
712+
for (Py_ssize_t i = 0; i < len; i++) {
713+
PyObject *item = PyTuple_GET_ITEM(args, i + 1);
714+
ApgRecord_SET_ITEM(record, i, item);
715+
Py_INCREF(item);
716+
}
717+
return record;
718+
}
719+
661720
static PyMethodDef record_methods[] = {
662721
{"values", (PyCFunction)record_values, METH_NOARGS},
663722
{"keys", (PyCFunction)record_keys, METH_NOARGS},
664723
{"items", (PyCFunction)record_items, METH_NOARGS},
665724
{"get", (PyCFunction)record_get, METH_VARARGS},
725+
{"__reduce__", (PyCFunction)record_reduce, METH_NOARGS},
726+
{"__reconstruct__", (PyCFunction)record_reconstruct, METH_VARARGS | METH_STATIC},
666727
{NULL, NULL} /* sentinel */
667728
};
668729

@@ -942,29 +1003,6 @@ record_new_items_iter(PyObject *seq)
9421003
}
9431004

9441005

945-
PyTypeObject *
946-
ApgRecord_InitTypes(void)
947-
{
948-
if (PyType_Ready(&ApgRecord_Type) < 0) {
949-
return NULL;
950-
}
951-
952-
if (PyType_Ready(&ApgRecordDesc_Type) < 0) {
953-
return NULL;
954-
}
955-
956-
if (PyType_Ready(&ApgRecordIter_Type) < 0) {
957-
return NULL;
958-
}
959-
960-
if (PyType_Ready(&ApgRecordItems_Type) < 0) {
961-
return NULL;
962-
}
963-
964-
return &ApgRecord_Type;
965-
}
966-
967-
9681006
/* ----------------- */
9691007

9701008

@@ -987,15 +1025,54 @@ record_desc_traverse(ApgRecordDescObject *o, visitproc visit, void *arg)
9871025
}
9881026

9891027

1028+
static PyObject *record_desc_reduce(ApgRecordDescObject *o, PyObject *Py_UNUSED(unused))
1029+
{
1030+
PyObject *value = PyTuple_New(2);
1031+
if (value == NULL) {
1032+
return NULL;
1033+
}
1034+
PyObject *state = PyTuple_New(2);
1035+
if (state == NULL) {
1036+
Py_DECREF(value);
1037+
return NULL;
1038+
}
1039+
PyTuple_SET_ITEM(value, 0, record_desc_reconstruct_obj);
1040+
Py_INCREF(record_desc_reconstruct_obj);
1041+
PyTuple_SET_ITEM(value, 1, state);
1042+
PyTuple_SET_ITEM(state, 0, o->mapping);
1043+
Py_INCREF(o->mapping);
1044+
PyTuple_SET_ITEM(state, 1, o->keys);
1045+
Py_INCREF(o->keys);
1046+
return value;
1047+
}
1048+
1049+
1050+
static PyObject *record_desc_reconstruct(PyObject *Py_UNUSED(unused), PyObject *args)
1051+
{
1052+
if (PyTuple_GET_SIZE(args) != 2) {
1053+
return NULL;
1054+
}
1055+
return ApgRecordDesc_New(PyTuple_GET_ITEM(args, 0), PyTuple_GET_ITEM(args, 1));
1056+
}
1057+
1058+
1059+
static PyMethodDef record_desc_methods[] = {
1060+
{"__reduce__", (PyCFunction)record_desc_reduce, METH_NOARGS},
1061+
{"__reconstruct__", (PyCFunction)record_desc_reconstruct, METH_VARARGS | METH_STATIC},
1062+
{NULL, NULL} /* sentinel */
1063+
};
1064+
1065+
9901066
PyTypeObject ApgRecordDesc_Type = {
9911067
PyVarObject_HEAD_INIT(NULL, 0)
992-
.tp_name = "RecordDescriptor",
1068+
.tp_name = "asyncpg.protocol.protocol.RecordDescriptor",
9931069
.tp_basicsize = sizeof(ApgRecordDescObject),
9941070
.tp_dealloc = (destructor)record_desc_dealloc,
9951071
.tp_getattro = PyObject_GenericGetAttr,
9961072
.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
9971073
.tp_traverse = (traverseproc)record_desc_traverse,
9981074
.tp_iter = PyObject_SelfIter,
1075+
.tp_methods = record_desc_methods,
9991076
};
10001077

10011078

@@ -1023,3 +1100,41 @@ ApgRecordDesc_New(PyObject *mapping, PyObject *keys)
10231100
PyObject_GC_Track(o);
10241101
return (PyObject *) o;
10251102
}
1103+
1104+
1105+
PyObject *
1106+
ApgRecord_InitTypes(void)
1107+
{
1108+
if (PyType_Ready(&ApgRecord_Type) < 0) {
1109+
return NULL;
1110+
}
1111+
1112+
if (PyType_Ready(&ApgRecordDesc_Type) < 0) {
1113+
return NULL;
1114+
}
1115+
1116+
if (PyType_Ready(&ApgRecordIter_Type) < 0) {
1117+
return NULL;
1118+
}
1119+
1120+
if (PyType_Ready(&ApgRecordItems_Type) < 0) {
1121+
return NULL;
1122+
}
1123+
1124+
record_reconstruct_obj = PyCFunction_New(
1125+
&record_methods[5], (PyObject *)&ApgRecord_Type
1126+
);
1127+
record_desc_reconstruct_obj = PyCFunction_New(
1128+
&record_desc_methods[1], (PyObject *)&ApgRecordDesc_Type
1129+
);
1130+
1131+
PyObject *types = PyTuple_New(2);
1132+
if (types == NULL) {
1133+
return NULL;
1134+
}
1135+
PyTuple_SET_ITEM(types, 0, (PyObject *)&ApgRecord_Type);
1136+
Py_INCREF(&ApgRecord_Type);
1137+
PyTuple_SET_ITEM(types, 1, (PyObject *)&ApgRecordDesc_Type);
1138+
Py_INCREF(&ApgRecordDesc_Type);
1139+
return types;
1140+
}

asyncpg/protocol/record/recordobj.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ extern PyTypeObject ApgRecordDesc_Type;
4646
#define ApgRecord_GET_ITEM(op, i) \
4747
(((ApgRecordObject *)(op))->ob_item[i])
4848

49-
PyTypeObject *ApgRecord_InitTypes(void);
49+
PyObject *ApgRecord_InitTypes(void);
5050
PyObject *ApgRecord_New(PyTypeObject *, PyObject *, Py_ssize_t);
5151
PyObject *ApgRecordDesc_New(PyObject *, PyObject *);
5252

tests/test_record.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -287,11 +287,6 @@ def test_record_get(self):
287287
self.assertEqual(r.get('nonexistent'), None)
288288
self.assertEqual(r.get('nonexistent', 'default'), 'default')
289289

290-
def test_record_not_pickleable(self):
291-
r = Record(R_A, (42,))
292-
with self.assertRaises(Exception):
293-
pickle.dumps(r)
294-
295290
def test_record_empty(self):
296291
r = Record(None, ())
297292
self.assertEqual(r, ())
@@ -575,3 +570,11 @@ class MyRecordBad:
575570
'record_class is expected to be a subclass of asyncpg.Record',
576571
):
577572
await self.connect(record_class=MyRecordBad)
573+
574+
def test_record_pickle(self):
575+
r = pickle.loads(pickle.dumps(Record(R_AB, (42, 43))))
576+
self.assertEqual(len(r), 2)
577+
self.assertEqual(r[0], 42)
578+
self.assertEqual(r[1], 43)
579+
self.assertEqual(r['a'], 42)
580+
self.assertEqual(r['b'], 43)

0 commit comments

Comments
 (0)