|
1 | | -# Copyright (c) 2019 Ole-Christoffer Granmo |
| 1 | +# Copyright (c) 2021 Ole-Christoffer Granmo |
2 | 2 |
|
3 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy |
4 | 4 | # of this software and associated documentation files (the "Software"), to deal |
@@ -117,6 +117,19 @@ def __init__(self, number_of_clauses, T, s, patch_dim, boost_true_positive_feedb |
117 | 117 | else: |
118 | 118 | self.s_range = s |
119 | 119 |
|
| 120 | + def __getstate__(self): |
| 121 | + state = self.__dict__.copy() |
| 122 | + state['mc_ctm_state'] = self.get_state() |
| 123 | + del state['mc_ctm'] |
| 124 | + if 'encoded_X' in state: |
| 125 | + del state['encoded_X'] |
| 126 | + return state |
| 127 | + |
| 128 | + def __setstate__(self, state): |
| 129 | + self.__dict__.update(state) |
| 130 | + self.mc_ctm = _lib.CreateMultiClassTsetlinMachine(self.number_of_classes, self.number_of_clauses, self.number_of_features, self.number_of_patches, self.number_of_ta_chunks, self.number_of_state_bits, self.T, self.s, self.s_range, self.boost_true_positive_feedback, self.weighted_clauses, self.clause_drop_p, self.literal_drop_p) |
| 131 | + self.set_state(state['mc_ctm_state']) |
| 132 | + |
120 | 133 | def __del__(self): |
121 | 134 | if self.mc_ctm != None: |
122 | 135 | _lib.mc_tm_destroy(self.mc_ctm) |
@@ -237,6 +250,19 @@ def __init__(self, number_of_clauses, T, s, boost_true_positive_feedback=1, numb |
237 | 250 | else: |
238 | 251 | self.s_range = s |
239 | 252 |
|
| 253 | + def __getstate__(self): |
| 254 | + state = self.__dict__.copy() |
| 255 | + state['mc_tm_state'] = self.get_state() |
| 256 | + del state['mc_tm'] |
| 257 | + if 'encoded_X' in state: |
| 258 | + del state['encoded_X'] |
| 259 | + return state |
| 260 | + |
| 261 | + def __setstate__(self, state): |
| 262 | + self.__dict__.update(state) |
| 263 | + self.mc_tm = _lib.CreateMultiClassTsetlinMachine(self.number_of_classes, self.number_of_clauses, self.number_of_features, 1, self.number_of_ta_chunks, self.number_of_state_bits, self.T, self.s, self.s_range, self.boost_true_positive_feedback, self.weighted_clauses, self.clause_drop_p, self.literal_drop_p) |
| 264 | + self.set_state(state['mc_tm_state']) |
| 265 | + |
240 | 266 | def __del__(self): |
241 | 267 | if self.mc_tm != None: |
242 | 268 | _lib.mc_tm_destroy(self.mc_tm) |
@@ -348,6 +374,19 @@ def __init__(self, number_of_clauses, T, s, boost_true_positive_feedback=1, numb |
348 | 374 | else: |
349 | 375 | self.s_range = s |
350 | 376 |
|
| 377 | + def __getstate__(self): |
| 378 | + state = self.__dict__.copy() |
| 379 | + state['rtm_state'] = self.get_state() |
| 380 | + del state['rtm'] |
| 381 | + if 'encoded_X' in state: |
| 382 | + del state['encoded_X'] |
| 383 | + return state |
| 384 | + |
| 385 | + def __setstate__(self, state): |
| 386 | + self.__dict__.update(state) |
| 387 | + self.rtm = _lib.CreateTsetlinMachine(self.number_of_clauses, self.number_of_features, 1, self.number_of_ta_chunks, self.number_of_state_bits, self.T, self.s, self.s_range, self.boost_true_positive_feedback, self.weighted_clauses) |
| 388 | + self.set_state(state['rtm_state']) |
| 389 | + |
351 | 390 | def __del__(self): |
352 | 391 | if self.rtm != None: |
353 | 392 | _lib.tm_destroy(self.rtm) |
|
0 commit comments