-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathEventQueue.py
More file actions
109 lines (90 loc) · 4.78 KB
/
EventQueue.py
File metadata and controls
109 lines (90 loc) · 4.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
from collections import defaultdict
from typing import Callable
from itertools import chain
class Event:
    """A single queued operation: apply the message at ``msg_idx`` to ``dst``.

    ``op`` is the operation name. ``EventQueue`` uses ``'insert'`` and
    ``'remove'`` for monotonic events, ``'update'`` for accumulative events,
    and arbitrary names for user events. ``msg_idx`` is an index into the
    owning queue's ``message_buffer``.
    """

    def __init__(self, op, dst, msg_idx):
        self.op = op
        self.dst = dst
        self.msg_idx = msg_idx

    def __repr__(self):
        # Readable form for debugging and logging event buffers.
        return (f"{type(self).__name__}(op={self.op!r}, dst={self.dst!r}, "
                f"msg_idx={self.msg_idx!r})")
class EventQueue:
    """Buffers update events together with the message payloads they reference.

    Events are split into three buffers:

    * ``event_buffer_monotonic``: ``'insert'`` / ``'remove'`` events, produced
      by ``bulky_push`` for the ``min`` / ``max`` aggregators;
    * ``event_buffer_accumulative``: ``'update'`` events, produced by
      ``bulky_push`` for the ``mean`` / ``add`` aggregators;
    * ``event_buffer_user``: events pushed through ``push_user_event``.

    Each ``Event`` stores only an index into ``message_buffer``, so one stored
    message can be shared by many destinations.
    """

    def __init__(self):
        self.event_buffer_monotonic = []
        self.event_buffer_accumulative = []
        self.event_buffer_user = []
        # Message payloads, addressed by ``Event.msg_idx``.
        self.message_buffer = []

    def push_monotonic_event(self, op, dst, msg):
        """Store ``msg`` and queue a monotonic event ``(op, dst)`` referencing it.

        ``msg`` is stored by reference (not cloned).
        """
        msg_idx = self.push_message(msg)
        self.event_buffer_monotonic.append(Event(op, dst, msg_idx))

    def push_accumulative_event(self, op, dst, msg):
        """Store ``msg`` and queue an accumulative event ``(op, dst)`` referencing it.

        ``msg`` is stored by reference (not cloned).
        """
        msg_idx = self.push_message(msg)
        self.event_buffer_accumulative.append(Event(op, dst, msg_idx))

    def push_user_event(self, op, dst, msg):
        """Store ``msg`` and queue a user-defined event ``(op, dst)`` referencing it.

        ``msg`` is stored by reference (not cloned).
        """
        msg_idx = self.push_message(msg)
        self.event_buffer_user.append(Event(op, dst, msg_idx))

    def push_message(self, message: torch.Tensor) -> int:
        """Append ``message`` to the message buffer and return its index."""
        idx = len(self.message_buffer)
        self.message_buffer.append(message)
        return idx

    def bulky_push(self, old_out_neighbors: list, new_out_neighbors: list,
                   old_message: torch.Tensor, new_message: torch.Tensor,
                   aggregator: str = "min"):
        """Queue the events that replace ``old_message`` with ``new_message``.

        The events produced depend on ``aggregator``:

        * ``"min"`` / ``"max"``: one ``'remove'`` event carrying ``old_message``
          per entry of ``old_out_neighbors``, and one ``'insert'`` event
          carrying ``new_message`` per entry of ``new_out_neighbors``. Duplicate
          entries produce duplicate events.
        * ``"mean"`` / ``"add"``: one ``'update'`` event per *distinct*
          neighbour. The payload is ``new_message - old_message`` for neighbours
          in both collections, ``-old_message`` for those only in the old
          collection, and ``new_message`` for those only in the new one.
        * Any other value is delegated to ``bulky_push_user``.

        Stored payloads are fresh tensors (clones or arithmetic results), so
        later in-place edits to the arguments do not affect queued events. A
        payload is only stored when at least one event references it.
        """
        # Materialize once: the neighbour collections are tested for
        # emptiness and iterated more than once below.
        old_neighbors = list(old_out_neighbors)
        new_neighbors = list(new_out_neighbors)
        if aggregator in ("min", "max"):
            self._bulky_push_monotonic(old_neighbors, new_neighbors,
                                       old_message, new_message)
        elif aggregator in ("mean", "add"):
            self._bulky_push_accumulative(old_neighbors, new_neighbors,
                                          old_message, new_message)
        else:
            self.bulky_push_user(old_neighbors, new_neighbors,
                                 old_message, new_message, aggregator)

    def _bulky_push_monotonic(self, old_neighbors, new_neighbors,
                              old_message, new_message):
        """Queue 'remove' events for old neighbours and 'insert' events for new ones."""
        if old_neighbors:
            old_idx = self.push_message(old_message.clone())
            self.event_buffer_monotonic.extend(
                Event('remove', neighbor, old_idx) for neighbor in old_neighbors)
        if new_neighbors:
            new_idx = self.push_message(new_message.clone())
            self.event_buffer_monotonic.extend(
                Event('insert', neighbor, new_idx) for neighbor in new_neighbors)

    def _bulky_push_accumulative(self, old_neighbors, new_neighbors,
                                 old_message, new_message):
        """Queue 'update' events carrying the per-neighbour message delta."""
        old_set = set(old_neighbors)
        new_set = set(new_neighbors)
        # dict.fromkeys de-duplicates while keeping first-seen order, so events
        # follow the caller's neighbour order rather than set-iteration order.
        old_unique = dict.fromkeys(old_neighbors)
        shared = [n for n in old_unique if n in new_set]
        removed = [n for n in old_unique if n not in new_set]
        added = [n for n in dict.fromkeys(new_neighbors) if n not in old_set]

        # Subtraction and negation already return new tensors, so only the
        # pass-through payload for added neighbours needs an explicit clone.
        if shared:
            delta_idx = self.push_message(new_message - old_message)
            self.event_buffer_accumulative.extend(
                Event('update', n, delta_idx) for n in shared)
        if removed:
            removed_idx = self.push_message(-old_message)
            self.event_buffer_accumulative.extend(
                Event('update', n, removed_idx) for n in removed)
        if added:
            added_idx = self.push_message(new_message.clone())
            self.event_buffer_accumulative.extend(
                Event('update', n, added_idx) for n in added)

    def bulky_push_user(self, old_out_neighbors: list, new_out_neighbors: list,
                        old_message: torch.Tensor, new_message: torch.Tensor,
                        aggregator: str = "min"):
        """Hook for aggregators ``bulky_push`` does not handle itself.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError("User-defined events are not supported yet")

    def empty(self):
        """Clear every event buffer and the message buffer.

        Despite the name, this mutates the queue; it is not an emptiness test.
        """
        self.event_buffer_monotonic.clear()
        self.event_buffer_accumulative.clear()
        self.event_buffer_user.clear()
        self.message_buffer.clear()

    def default_factory(self):
        """Return a fresh ``defaultdict(list)`` for one destination in ``reduce``."""
        return defaultdict(list)

    def reduce(self, monotonic_aggregator: Callable,
               accumulative_aggregator: Callable,
               user_defined_reduce_function: Callable) -> dict:
        """Group queued messages by destination and op, then reduce each group.

        Args:
            monotonic_aggregator: called with the message list of every
                ``(dst, 'insert')`` and ``(dst, 'remove')`` group.
            accumulative_aggregator: called with the message list of every
                ``(dst, 'update')`` group.
            user_defined_reduce_function: called with the message list of any
                group whose op is none of the above.

        Routing is by op name only. A user event whose op is ``'insert'``,
        ``'remove'`` or ``'update'`` is therefore reduced with the matching
        built-in aggregator.

        Returns:
            A nested ``defaultdict`` mapping ``dst -> op -> reduced value``.
            Within a group, messages are passed in buffer order: monotonic
            events first, then accumulative events, then user events.
        """
        # First, group messages by destination and op.
        task_dict = defaultdict(self.default_factory)
        for task in chain(self.event_buffer_monotonic,
                          self.event_buffer_accumulative,
                          self.event_buffer_user):
            task_dict[task.dst][task.op].append(self.message_buffer[task.msg_idx])
        # Then replace each group's message list with its reduced value.
        # Only values of existing keys are reassigned, so assigning while
        # iterating is safe.
        for ops in task_dict.values():
            for op, messages in ops.items():
                if op in ('remove', 'insert'):
                    ops[op] = monotonic_aggregator(messages)
                elif op == 'update':
                    ops[op] = accumulative_aggregator(messages)
                else:  # user-defined events
                    ops[op] = user_defined_reduce_function(messages)
        return task_dict

    def print(self):
        """Print every queued event, grouped by buffer, as ``<op, dst, shape>``."""
        sections = (
            ("Events for monotonic operations", self.event_buffer_monotonic),
            ("Events for accumulative operations", self.event_buffer_accumulative),
            ("Events for user-defined operations", self.event_buffer_user),
        )
        # Inside this method, ``print`` still resolves to the builtin, because
        # class scope is not visible from method bodies.
        for title, events in sections:
            print(title)
            for event in events:
                print(f"<{event.op}, {event.dst}, "
                      f"{self.message_buffer[event.msg_idx].shape}>")