-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph.py
153 lines (122 loc) · 6.07 KB
/
graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Utility functions for dealing with trees
from collections import defaultdict
import numpy as np
import torch
from node import Node
from queue import Queue
import json
class Graph:
    """A rooted graph of node objects, keyed by node id.

    Nodes are expected to expose ``name``, ``children``, ``parents``,
    ``values``, ``reachable_leaves``, ``height`` and ``depth`` attributes plus
    ``set_reachable_leaves()``, ``set_height()`` and ``set_depth()`` methods
    (these are defined on the project's ``Node`` class, which is not shown
    here).

    Attributes:
        root_id: key of the root node in ``nodes``.
        nodes: dict mapping node id -> node; the root is always the first
            entry, so positional ids produced by ``serialize`` start at the
            root.
    """

    def __init__(self, nodes, root_id):
        """Build the graph from ``nodes`` (id -> node), rooted at ``root_id``.

        The caller's ``nodes`` mapping is not modified.

        Raises:
            KeyError: if ``root_id`` is not a key of ``nodes``.
        """
        self.root_id = root_id
        # Root first; re-listing it via **nodes keeps its first position and
        # value, so no pop() on the caller's dict is needed.
        self.nodes = {root_id: nodes[root_id], **nodes}

    def __str__(self):
        return f"Graph: root={self.root_id} num_nodes={len(self.nodes)}"

    def __len__(self):
        return len(self.nodes)

    def __getitem__(self, node_name):
        return self.nodes[node_name]

    def _leaves(self):
        """Return every node that has no children, in ``self.nodes`` order."""
        return [node for node in self.nodes.values() if not node.children]

    def compute_reachable_leaves(self):
        """Seed each leaf's ``reachable_leaves`` with itself, then let the root
        fill in the rest via ``Node.set_reachable_leaves()`` (presumably a
        recursive union over children -- confirm in node.py)."""
        for leaf in self._leaves():
            leaf.reachable_leaves = {leaf}
        self.nodes[self.root_id].set_reachable_leaves()

    def compute_height(self):
        """Delegate height computation to ``set_height()`` on the root node."""
        self.nodes[self.root_id].set_height()

    def compute_depth(self):
        """Call ``set_depth()`` on every leaf (presumably each call walks up
        through its ancestors -- confirm in node.py)."""
        for leaf in self._leaves():
            leaf.set_depth()

    def finalize(self):
        """Compute reachable leaves, heights and depths, then verify them.

        Raises:
            AssertionError: if any node still has ``reachable_leaves``,
                ``height`` or ``depth`` equal to ``None`` afterwards. These are
                invariant checks and are skipped under ``python -O``.
        """
        self.compute_reachable_leaves()
        self.compute_height()
        self.compute_depth()
        for node in self.nodes.values():
            # Diagnostics live in the assertion messages (evaluated only on
            # failure) instead of unconditional debug prints.
            assert node.reachable_leaves is not None, (
                f"node {node.name!r} has no reachable_leaves after finalize()")
            assert node.height is not None, (
                f"node {node.name!r} has no height after finalize()")
            assert node.depth is not None, (
                f"node {node.name!r} has no depth after finalize(); "
                "reachable leaves (name, depth, num_children): "
                f"{[(leaf.name, leaf.depth, len(leaf.children)) for leaf in node.reachable_leaves]}")

    def serialize(self, include_values=False):
        """Return a JSON-friendly list with one dict per node.

        Each dict has ``'id'`` (the node's position in ``self.nodes``, root
        first), ``'name'``, ``'values'`` (only when ``include_values`` is
        true) and ``'parents'`` (list of parent ids). Ids are looked up by
        ``node.name``, so node names are assumed to be unique.
        """
        node_ids = {node.name: i for i, node in enumerate(self.nodes.values())}
        output = []
        for node in self.nodes.values():
            json_object = {'id': node_ids[node.name], 'name': node.name}
            if include_values:
                json_object['values'] = node.values
            json_object['parents'] = [node_ids[parent.name] for parent in node.parents]
            output.append(json_object)
        return output
# def specify_tree(self, model_input, model, device, labels):
# """Specifies the confidence tree for a specific input.
# Args:
# model_input (Tensor): the model_input used to set the confidence.
# model (Pytorch Model): the model that takes in the model_input.
# device (Pytoch Device): the device the model is on.
# labels (list): a list of class names corresponding the model's
# outputs.
# allow_multiparents (bool, default=False): if True, confidence of a
# non-leaf node is the sum of the confidence of its reachable
# leaves. If False, multiple parents are not allowed, so
# confidence can only travel one path from a leaf to any other
# node.
# Returns: nothing. Updates the nodes in the tree with their confidence
# values.
# """
# model.eval()
# model_inputs = model_input.unsqueeze(0)
# with torch.no_grad():
# model_inputs = model_inputs.to(device)
# # Compute model predictions
# output = model(model_inputs)
# confidences = torch.nn.functional.softmax(output, dim=1).squeeze(0).detach().cpu().numpy()
# # Propogate confidences up the tree
# self.set_and_propogate_confidences(confidences, labels)
# def set_and_propogate_confidences(self, confidences, labels):
# """Sets the leaf confidences and propogates them through the tree."""
# self.clear_confidence()
# for i, confidence in enumerate(confidences):
# node_name = labels[i]
# node = self.nodes[node_name]
# node.value = confidence
# self.propogate_confidence()
# def clear_confidence(self):
# """Clears the confidence for every node in the tree."""
# for node in self.nodes.values():
# node.value = None
# def propogate_confidence(self, confidence_threshold=1e-4):
# """Propograte confidence values from leaf to root in a tree. Confidence
# at the leaf node is the model's predicted confidence. Confidence at an
# internal node is the sum of the confidence from leaf nodes the internal
# node can reach.
# Args:
# allow_multiparents (bool, default=False): if True, confidence of a
# non-leaf node is the sum of the confidence of its reachable
# leaves. If False, multiple parents are not allowed, so
# confidence can only travel one path from a leaf to any other
# node. At each confident node, it allows confidence to travel to
# the first parent only.
# confidence_threshold (float, default=1e-4): A node is considered
# confident if its confidence is above the confidence_threshold.
# Only matters when allow_multiparents=False and confident nodes
# can only have one parent.
# Returns: nothing. Nodes in the tree are modified with updated confidence
# values.
# """
# for node in self.nodes.values():
# value = np.sum([reachable_leaf.value for reachable_leaf in node.reachable_leaves])
# node.value = value
# def to_json(self, filename=None):
# result = []
# name_to_id = {node.name: i for i, node in enumerate(self.nodes.values())}
# for i, (code, node) in enumerate(self.nodes.items()):
# node_result = {}
# node_result['name'] = node.name
# node_result['id'] = name_to_id[node.name]
# if node.parent is not None:
# node_result['parent'] = node.parent.name
# result.append(node_result)
# if filename is not None:
# with open(filename, 'w') as f:
# json.dump(result, f, indent=4)
# return result