Skip to content

Commit 5650e40

Browse files
committed
allow methods as process-based routing functions
1 parent 1dc94c2 commit 5650e40

File tree

4 files changed

+32
-7
lines changed

4 files changed

+32
-7
lines changed

CHANGES.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
History
22
-------
33

4+
+ **3.1.3** (2024-04-08)**
5+
+ Allows class methods as generator functions for process-based routing.
6+
47
+ **3.1.2 (2024-04-08)**
58
+ Fix bug when using Mixture distribution.
69

ciw/import_params.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def create_network_from_dictionary(params_input):
115115
)
116116
)
117117
for clss_name in params['customer_class_names']:
118-
if all(isinstance(f, types.FunctionType) for f in params["routing"]):
118+
if all(isinstance(f, types.FunctionType) or isinstance(f, types.MethodType) for f in params["routing"]):
119119
classes[clss_name] = CustomerClass(
120120
params['arrival_distributions'][clss_name],
121121
params['service_distributions'][clss_name],
@@ -140,7 +140,7 @@ def create_network_from_dictionary(params_input):
140140
class_change_time_distributions[clss_name],
141141
)
142142
n = Network(nodes, classes)
143-
if all(isinstance(f, types.FunctionType) for f in params["routing"]):
143+
if all(isinstance(f, types.FunctionType) or isinstance(f, types.MethodType) for f in params["routing"]):
144144
n.process_based = True
145145
else:
146146
n.process_based = False
@@ -220,7 +220,7 @@ def validify_dictionary(params):
220220
Raises errors if there is something wrong with the
221221
parameters dictionary.
222222
"""
223-
if all(isinstance(f, types.FunctionType) for f in params["routing"]):
223+
if all(isinstance(f, types.FunctionType) or isinstance(f, types.MethodType) for f in params["routing"]):
224224
consistant_num_classes = (
225225
params["number_of_classes"]
226226
== len(params["arrival_distributions"])
@@ -241,7 +241,7 @@ def validify_dictionary(params):
241241
)
242242
if not consistant_num_classes:
243243
raise ValueError("Ensure consistant number of classes is used throughout.")
244-
if all(isinstance(f, types.FunctionType) for f in params["routing"]):
244+
if all(isinstance(f, types.FunctionType) or isinstance(f, types.MethodType) for f in params["routing"]):
245245
consistant_class_names = (
246246
set(params["arrival_distributions"])
247247
== set(params["service_distributions"])
@@ -266,7 +266,7 @@ def validify_dictionary(params):
266266
)
267267
if not consistant_class_names:
268268
raise ValueError("Ensure consistant names for customer classes.")
269-
if all(isinstance(f, types.FunctionType) for f in params["routing"]):
269+
if all(isinstance(f, types.FunctionType) or isinstance(f, types.MethodType) for f in params["routing"]):
270270
num_nodes_count = (
271271
[params["number_of_nodes"]]
272272
+ [len(obs) for obs in params["arrival_distributions"].values()]
@@ -296,7 +296,7 @@ def validify_dictionary(params):
296296
)
297297
if len(set(num_nodes_count)) != 1:
298298
raise ValueError("Ensure consistant number of nodes is used throughout.")
299-
if not all(isinstance(f, types.FunctionType) for f in params["routing"]):
299+
if not all(isinstance(f, types.FunctionType) or isinstance(f, types.MethodType) for f in params["routing"]):
300300
for clss in params["routing"].values():
301301
for row in clss:
302302
if sum(row) > 1.0 or min(row) < 0.0 or max(row) > 1.0:

ciw/tests/test_process_based.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ def generator_function_8(ind):
4545
return [1]
4646
return [1, 1, 1]
4747

48+
class ClassForProcessBasedMethod:
49+
def __init__(self, n):
50+
self.n = n
51+
def generator_method(self, ind):
52+
return [1, 1, 1]
53+
4854

4955
class TestProcessBased(unittest.TestCase):
5056
def test_network_takes_routing_function(self):
@@ -294,3 +300,19 @@ def test_customer_class_based_routing(self):
294300
inds = Q.nodes[-1].all_individuals
295301
routes_counter = set([tuple([ind.customer_class, tuple(dr.node for dr in ind.data_records)]) for ind in inds])
296302
self.assertEqual(routes_counter, {('Class 1', (1, 1, 1)), ('Class 0', (1,))})
303+
304+
def test_process_based_takes_methods(self):
305+
import types
306+
G = ClassForProcessBasedMethod(5)
307+
self.assertTrue(isinstance(G.generator_method, types.MethodType))
308+
N = ciw.create_network(
309+
arrival_distributions=[ciw.dists.Deterministic(1)],
310+
service_distributions=[ciw.dists.Deterministic(1000)],
311+
number_of_servers=[1],
312+
routing=[G.generator_method],
313+
)
314+
Q = ciw.Simulation(N)
315+
Q.simulate_until_max_time(4.5)
316+
inds = Q.nodes[1].all_individuals
317+
for ind in inds:
318+
self.assertEqual(ind.route, [1, 1, 1])

ciw/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.1.2"
1+
__version__ = "3.1.3"

0 commit comments

Comments
 (0)