forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathFunctional.cs
392 lines (343 loc) · 15.3 KB
/
Functional.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;
using static Tensorflow.Binding;
namespace Tensorflow.Keras.Engine
{
/// <summary>
/// A `Functional` model is a `Model` defined as a directed graph of layers.
/// </summary>
public partial class Functional : Model
{
// Layers that produce the model's output tensors, in output order (filled by _init_graph_network).
List<ILayer> _output_layers;
// Layers that own the model's input tensors, in input order (filled by _init_graph_network).
List<ILayer> _input_layers;
// (layer, node_index, tensor_index) coordinates for each input tensor.
List<KerasHistory> _input_coordinates;
// (layer, node_index, tensor_index) coordinates for each output tensor.
List<KerasHistory> _output_coordinates;
// Keys of every node in the graph, in the format produced by MakeNodeKey.
public string[] NetworkNodes { get; set; }
// Per tensor id: how many times that tensor is consumed during a forward pass
// (computed by ComputeTensorUsageCount, read by Call).
Dictionary<long, int> tensor_usage_count;
/// <summary>
/// Dictionary of layer dependencies to be included in the checkpoint.
/// Every layer is registered under "layer-{i}"; layers that own weights are
/// additionally registered under "layer_with_weights-{n}", where n counts
/// only weight-bearing layers.
/// </summary>
public IDictionary<string, ILayer> LayerCheckpointDependencies
{
    get
    {
        int weight_layer_index = 0;
        Dictionary<string, ILayer> dependencies = new();
        for (int i = 0; i < Layers.Count; i++)
        {
            var layer = Layers[i];
            // Any() avoids materializing the concatenated weight list just to count it.
            if (layer.TrainableWeights.concat(layer.NonTrainableWeights).Any())
            {
                dependencies[$"layer_with_weights-{weight_layer_index}"] = layer;
                weight_layer_index++;
            }
            dependencies[$"layer-{i}"] = layer;
        }
        return dependencies;
    }
}
/// <summary>
/// Creates a functional model wiring <paramref name="inputs"/> to
/// <paramref name="outputs"/> through the layer graph that connects them.
/// </summary>
/// <param name="inputs">Symbolic input tensors of the model.</param>
/// <param name="outputs">Symbolic output tensors of the model.</param>
/// <param name="name">Optional model name; auto-generated when null.</param>
public Functional(Tensors inputs, Tensors outputs, string name = null)
    : base(new ModelArgs
    {
        Inputs = inputs,
        Outputs = outputs,
        Name = name
    })
{
    Initialize(inputs, outputs, name);
}
/// <summary>
/// Resets the graph-tracking collections and, unless this instance is a
/// Sequential model, builds the static layer graph immediately.
/// </summary>
internal void Initialize(Tensors inputs, Tensors outputs, string name = null)
{
    _input_layers = new();
    _output_layers = new();
    _input_coordinates = new();
    _output_coordinates = new();
    tensor_usage_count = new();

    // Sequential builds its graph lazily as layers are added.
    if (this is Sequential)
        return;

    _init_graph_network(inputs, outputs);
}
/// <summary>
/// Wires up the static layer graph: records the input/output layers and
/// their Keras-history coordinates, maps all nodes and layers by depth,
/// derives output names and precomputes per-tensor usage counts.
/// </summary>
protected void _init_graph_network(Tensors inputs, Tensors outputs)
{
    _is_graph_network = true;
    this.inputs = inputs;
    this.outputs = outputs;
    built = true;

    if (inputs.Length > 0)
        _buildInputShape = inputs.shape;
    else
        _buildInputShape = new TensorShapeConfig();

    // Every output must carry a KerasHistory before the graph can be walked.
    if (outputs.Any(t => t.KerasHistory == null))
        base_layer_utils.create_keras_history(outputs);

    // Record the producing layer and coordinates of each output tensor.
    foreach (var output in outputs)
    {
        var (layer, node_index, tensor_index) = output.KerasHistory;
        _output_layers.append(layer);
        _output_coordinates.append(new KerasHistory(layer, node_index, tensor_index));
    }

    // Record the owning layer and coordinates of each input tensor.
    foreach (var input in inputs)
    {
        var (layer, node_index, tensor_index) = input.KerasHistory;
        _input_layers.append(layer);
        _input_coordinates.append(new KerasHistory(layer, node_index, tensor_index));
    }

    // Keep track of the network's nodes and layers, then derive names and
    // usage counts from the mapped graph.
    (NetworkNodes, NodesByDepth, _self_tracked_trackables, _) = MapGraphNetwork(inputs, outputs);
    _set_output_names();
    ComputeTensorUsageCount();
}
/// <summary>
/// Assigns unique names to the Network's outputs, disambiguating repeated
/// layer names by appending an increasing numeric suffix.
/// </summary>
void _set_output_names()
{
    var seen = new List<string>();          // names already claimed
    var unique_names = new List<string>();  // final per-output names, in order
    var prefix_count = new Dictionary<string, int>();

    foreach (var layer in _output_layers)
    {
        var candidate = layer.Name;
        // Bump the suffix until the candidate is unused.
        while (seen.Contains(candidate))
        {
            var suffix = prefix_count.Get(layer.Name, 1);
            candidate = $"{layer.Name}_{suffix}";
            prefix_count[layer.Name] = suffix + 1;
        }
        seen.add(candidate);
        unique_names.append(candidate);
    }

    this.output_names = unique_names.ToArray();
}
/// <summary>
/// Counts, per tensor id, how many times the tensor will be consumed during
/// a forward pass (one extra use per model output), walking nodes from the
/// inputs toward the outputs.
/// </summary>
void ComputeTensorUsageCount()
{
    // Increment the usage counter for a tensor id, initializing it at zero.
    void Bump(long id)
    {
        if (!tensor_usage_count.ContainsKey(id))
            tensor_usage_count[id] = 0;
        tensor_usage_count[id] += 1;
    }

    var available_tensors = inputs.Select(t => t.Id).ToList();
    // Depths in descending order, skipping the deepest level (the inputs).
    var depth_keys = NodesByDepth.Keys.OrderBy(d => d).Reverse().Skip(1).ToArray();
    foreach (var depth in depth_keys)
    {
        foreach (var node in NodesByDepth[depth])
        {
            var input_ids = node.KerasInputs.Select(t => t.Id).ToArray();
            // Only count a node once all of its inputs are producible.
            if (!input_ids.issubset(available_tensors))
                continue;
            foreach (var tensor in node.KerasInputs)
                Bump(tensor.Id);
            foreach (var output_tensor in node.Outputs)
                available_tensors.Add(output_tensor.Id);
        }
    }

    // Each model output is read once more when the results are collected.
    foreach (var tensor in outputs)
        Bump(tensor.Id);
}
/// <summary>
/// Validates a network's topology and gathers its layers and nodes.
/// </summary>
/// <param name="inputs">Input tensors of the model.</param>
/// <param name="outputs">Output tensors of the model.</param>
/// <returns>
/// (node keys, nodes grouped by depth, layers ordered by depth,
/// layers grouped by depth).
/// </returns>
(string[], Dictionary<int, List<INode>>, List<ILayer>, Dictionary<int, List<ILayer>>) MapGraphNetwork(Tensors inputs, Tensors outputs)
{
    var (nodes_in_decreasing_depth, layer_indices) = BuildMap(outputs);
    var network_nodes = nodes_in_decreasing_depth
        .Select(node => MakeNodeKey(node.Layer.Name, node.Layer.InboundNodes.IndexOf(node)))
        .ToList();

    var nodes_depths = new Dictionary<INode, int>();
    var layers_depths = new Dictionary<ILayer, int>();

    nodes_in_decreasing_depth.Reverse();
    foreach (var node in nodes_in_decreasing_depth)
    {
        // If the depth is not set, the node has no outbound nodes (depth 0).
        int depth = nodes_depths.SetDefault(node, 0);

        // Update the depth of the corresponding layer.
        int previous_depth = layers_depths.Get(node.Layer, 0);

        // If we've seen this layer before at a higher depth, we should use
        // that depth instead of the node depth. This is necessary for shared
        // layers that have inputs at different depth levels in the graph.
        depth = Math.Max(depth, previous_depth);
        layers_depths[node.Layer] = depth;
        nodes_depths[node] = depth;

        // The "depth" of a node is the max of the depths of all nodes it is
        // connected to, plus 1.
        foreach (var node_dep in node.ParentNodes)
        {
            previous_depth = nodes_depths.Get(node_dep, 0);
            nodes_depths[node_dep] = Math.Max(depth + 1, previous_depth);
        }
    }

    // Handle inputs that are not connected to outputs. We do not error out
    // here because the inputs may be used to compute losses and metrics.
    foreach (var input_t in inputs)
    {
        var (input_layer, _, _) = input_t.KerasHistory;
        if (!layers_depths.ContainsKey(input_layer))
        {
            layers_depths[input_layer] = 0;
            layer_indices[input_layer] = -1;
            nodes_depths[input_layer.InboundNodes[0]] = 0;
            network_nodes.Add(MakeNodeKey(input_layer.Name, 0));
        }
    }

    // Build a dict {depth: list of nodes with this depth}.
    var nodes_by_depth = new Dictionary<int, List<INode>>();
    foreach (var (node, depth) in enumerate(nodes_depths))
    {
        if (!nodes_by_depth.ContainsKey(depth))
            nodes_by_depth[depth] = new List<INode>();
        nodes_by_depth[depth].Add(node);
    }

    // Build a dict {depth: list of layers with this depth}.
    var layers_by_depth = new Dictionary<int, List<ILayer>>();
    foreach (var (layer, depth) in enumerate(layers_depths))
    {
        if (!layers_by_depth.ContainsKey(depth))
            layers_by_depth[depth] = new List<ILayer>();
        layers_by_depth[depth].Add(layer);
    }

    // Set self.layers ordered by depth, deepest first. Network.layers needs
    // a deterministic order, so within one depth we sort by traversal order.
    var layers = new List<ILayer>();
    foreach (var depth in layers_by_depth.Keys.OrderByDescending(d => d))
    {
        var layers_for_depth = layers_by_depth[depth]
            .OrderBy(l => layer_indices[l])
            .ToList();
        layers.AddRange(layers_for_depth);
    }

    // NOTE: the original code recomputed a sorted list of node depths here
    // and discarded it; that dead statement has been removed.
    return (network_nodes.ToArray(), nodes_by_depth, layers, layers_by_depth);
}
// Builds the unique key "<layer name>_ib-<node index>" under which a node is
// registered in NetworkNodes.
string MakeNodeKey(string layer_name, int node_index)
{
    return $"{layer_name}_ib-{node_index}";
}
/// <summary>
/// Topologically sorts all nodes reachable from <paramref name="outputs"/>,
/// ordered from inputs to outputs, and records each layer's traversal index.
/// </summary>
/// <param name="outputs">Output tensors whose Keras histories are traversed.</param>
(List<INode>, Dictionary<ILayer, int>) BuildMap(Tensors outputs)
{
    var finished_nodes = new List<INode>();
    var nodes_in_progress = new List<INode>();
    var nodes_in_decreasing_depth = new List<INode>();
    var layer_indices = new Dictionary<ILayer, int>();

    foreach (var output in outputs)
    {
        BuildMapHelper(output,
                       finished_nodes,
                       nodes_in_progress,
                       nodes_in_decreasing_depth,
                       layer_indices);
    }

    return (nodes_in_decreasing_depth, layer_indices);
}
// Depth-first walk from `tensor` back through its producing node: recurses
// into every feeding tensor, emits nodes into `nodes_in_decreasing_depth`
// in post-order, and assigns each layer a first-visit index used later for
// deterministic layer sorting. Throws ValueError when a cycle is detected.
void BuildMapHelper(Tensor tensor,
    List<INode> finished_nodes,
    List<INode> nodes_in_progress,
    List<INode> nodes_in_decreasing_depth,
    Dictionary<ILayer, int> layer_indices)
{
    var (layer, node_index, _) = tensor.KerasHistory;
    var node = layer.InboundNodes[node_index] as Node;

    // Don't repeat work for shared subgraphs.
    if (finished_nodes.Contains(node))
        return;

    // Prevent cycles.
    if (nodes_in_progress.Contains(node))
        throw new ValueError($"The tensor {tensor.name} at layer {layer.Name} is part of a cycle.");

    // Store the traversal order for layer sorting.
    if (!layer_indices.ContainsKey(layer))
        layer_indices[layer] = layer_indices.Count;

    nodes_in_progress.Add(node);

    // Propagate to all previous tensors connected to this node
    // (input nodes have nothing upstream).
    if (!node.is_input)
    {
        foreach (var k_tensor in node.KerasInputs)
            BuildMapHelper(k_tensor,
                           finished_nodes,
                           nodes_in_progress,
                           nodes_in_decreasing_depth,
                           layer_indices);
    }

    finished_nodes.Add(node);
    nodes_in_progress.Remove(node);
    nodes_in_decreasing_depth.append(node);
}
/// <summary>
/// Runs the forward pass: maps the fed values onto the graph's symbolic
/// inputs, applies every layer in depth order, and collects the tensors
/// corresponding to the model outputs.
/// </summary>
/// <param name="inputs">Concrete tensors fed to the model's inputs, in order.</param>
/// <param name="state">Unused here; part of the common layer signature.</param>
/// <param name="training">Training/inference flag forwarded to each layer; defaults to false.</param>
/// <param name="optional_args">Unused here; part of the common layer signature.</param>
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
    // Per tensor id, a queue holding one reference per expected consumer
    // (sized by tensor_usage_count so each consumer dequeues its own copy).
    var tensor_dict = new Dictionary<long, Queue<Tensor>>();

    // Map input values. Enumerable.Repeat replaces the original
    // Range(...).Select(x => y), whose lambda parameter shadowed the
    // enclosing deconstructed `x`.
    foreach (var (x, y) in zip(this.inputs, inputs))
    {
        tensor_dict[x.Id] = new Queue<Tensor>(Enumerable.Repeat(y, tensor_usage_count[x.Id]));
    }

    var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().ToArray();
    foreach (var depth in depth_keys)
    {
        var nodes = NodesByDepth[depth];
        foreach (Node node in nodes)
        {
            // Input tensors already exist.
            if (node.is_input)
                continue;

            var layer_inputs = node.MapArguments(tensor_dict);
            tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name}");
            var outputs = node.Layer.Apply(layer_inputs, training: training ?? false);
            foreach (var output in outputs.Where(t => t != null))
                tf.Logger.Information($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.shape}");

            // Update tensor_dict for next or later input.
            foreach (var (x_id, y) in zip(node.Outputs.Select(t => t.Id), outputs))
                tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Repeat(y, tensor_usage_count[x_id]));
        }
    }

    var output_tensors = new Tensors();
    foreach (var x in outputs)
        output_tensors.Add(tensor_dict[x.Id].Dequeue());

    return output_tensors;
}
/// <summary>
/// Returns this model's trackable children: the per-layer checkpoint
/// dependencies merged with whatever the base class tracks.
/// </summary>
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
    var layer_children = LayerCheckpointDependencies
        .ToDictionary(pair => pair.Key, pair => pair.Value.GetTrackable());
    var merged = layer_children.Concat(base._trackable_children(save_type, cache));
    return merged.ToDictionary(pair => pair.Key, pair => pair.Value);
}
/// <summary>
/// Sets the model's name. An explicit name is used verbatim; otherwise a
/// unique snake_case name is generated from the concrete type, with a bare
/// Functional instance presented as "Model".
/// </summary>
protected override void _init_set_name(string name, bool zero_based = true)
{
    if (!string.IsNullOrEmpty(name))
    {
        this.name = name;
        return;
    }

    // A bare Functional (not a subclass) is named like "model", mirroring Keras.
    var class_name = GetType() == typeof(Functional) ? "Model" : GetType().Name;
    this.name = base_layer_utils.unique_layer_name(generic_utils.to_snake_case(class_name), zero_based: zero_based);
}
}
}