Skip to content

Commit 9f12bf1

Browse files
committed
Adding Pickle Support
1 parent 0cf7e1c commit 9f12bf1

File tree

1 file changed

+40
-1
lines changed
  • pyTsetlinMachineParallel

1 file changed

+40
-1
lines changed

pyTsetlinMachineParallel/tm.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2019 Ole-Christoffer Granmo
1+
# Copyright (c) 2021 Ole-Christoffer Granmo
22

33
# Permission is hereby granted, free of charge, to any person obtaining a copy
44
# of this software and associated documentation files (the "Software"), to deal
@@ -117,6 +117,19 @@ def __init__(self, number_of_clauses, T, s, patch_dim, boost_true_positive_feedb
117117
else:
118118
self.s_range = s
119119

120+
def __getstate__(self):
121+
state = self.__dict__.copy()
122+
state['mc_ctm_state'] = self.get_state()
123+
del state['mc_ctm']
124+
if 'encoded_X' in state:
125+
del state['encoded_X']
126+
return state
127+
128+
def __setstate__(self, state):
129+
self.__dict__.update(state)
130+
self.mc_ctm = _lib.CreateMultiClassTsetlinMachine(self.number_of_classes, self.number_of_clauses, self.number_of_features, self.number_of_patches, self.number_of_ta_chunks, self.number_of_state_bits, self.T, self.s, self.s_range, self.boost_true_positive_feedback, self.weighted_clauses, self.clause_drop_p, self.literal_drop_p)
131+
self.set_state(state['mc_ctm_state'])
132+
120133
def __del__(self):
121134
if self.mc_ctm != None:
122135
_lib.mc_tm_destroy(self.mc_ctm)
@@ -237,6 +250,19 @@ def __init__(self, number_of_clauses, T, s, boost_true_positive_feedback=1, numb
237250
else:
238251
self.s_range = s
239252

253+
def __getstate__(self):
254+
state = self.__dict__.copy()
255+
state['mc_tm_state'] = self.get_state()
256+
del state['mc_tm']
257+
if 'encoded_X' in state:
258+
del state['encoded_X']
259+
return state
260+
261+
def __setstate__(self, state):
262+
self.__dict__.update(state)
263+
self.mc_tm = _lib.CreateMultiClassTsetlinMachine(self.number_of_classes, self.number_of_clauses, self.number_of_features, 1, self.number_of_ta_chunks, self.number_of_state_bits, self.T, self.s, self.s_range, self.boost_true_positive_feedback, self.weighted_clauses, self.clause_drop_p, self.literal_drop_p)
264+
self.set_state(state['mc_tm_state'])
265+
240266
def __del__(self):
241267
if self.mc_tm != None:
242268
_lib.mc_tm_destroy(self.mc_tm)
@@ -348,6 +374,19 @@ def __init__(self, number_of_clauses, T, s, boost_true_positive_feedback=1, numb
348374
else:
349375
self.s_range = s
350376

377+
def __getstate__(self):
378+
state = self.__dict__.copy()
379+
state['rtm_state'] = self.get_state()
380+
del state['rtm']
381+
if 'encoded_X' in state:
382+
del state['encoded_X']
383+
return state
384+
385+
def __setstate__(self, state):
386+
self.__dict__.update(state)
387+
self.rtm = _lib.CreateTsetlinMachine(self.number_of_clauses, self.number_of_features, 1, self.number_of_ta_chunks, self.number_of_state_bits, self.T, self.s, self.s_range, self.boost_true_positive_feedback, self.weighted_clauses)
388+
self.set_state(state['rtm_state'])
389+
351390
def __del__(self):
352391
if self.rtm != None:
353392
_lib.tm_destroy(self.rtm)

0 commit comments

Comments
 (0)