|
1 | 1 | #import "@preview/cetz:0.3.1": canvas, draw
|
2 | 2 |
|
3 |
| -#set page(width: auto, height: auto, margin: 5pt) |
4 |
| - |
5 |
| -#let layer-sep = 2.5 |
6 |
| -#let node-sep = 1.5 |
7 |
| -#let node-radius = 0.3 |
8 |
| -#let cross-size = node-radius * 1.2 |
9 |
| - |
10 |
| -// Helper function to draw a neural network layer |
11 |
| -#let draw-layer(x-pos, n-nodes, disabled: (), center: false) = { |
12 |
| - let y-coords = () |
13 |
| - |
14 |
| - for idx in range(n-nodes) { |
15 |
| - // Center vertically if requested (for output layer) |
16 |
| - let y-offset = if center { (5 - n-nodes) * node-sep / 2 } else { 0 } |
17 |
| - let y-pos = idx * node-sep + y-offset |
18 |
| - |
19 |
| - // Draw node as filled circle with black outline |
20 |
| - draw.circle( |
21 |
| - (x-pos, y-pos), |
22 |
| - radius: node-radius, |
23 |
| - stroke: black + 1pt, |
24 |
| - fill: white, |
25 |
| - ) |
26 |
| - |
27 |
| - // Add X if node is disabled |
28 |
| - if idx + 1 in disabled { |
29 |
| - // Draw red X with thicker lines |
30 |
| - draw.line( |
31 |
| - (x-pos - cross-size, y-pos - cross-size), |
32 |
| - (x-pos + cross-size, y-pos + cross-size), |
33 |
| - stroke: red + 2pt, |
34 |
| - ) |
35 |
| - draw.line( |
36 |
| - (x-pos - cross-size, y-pos + cross-size), |
37 |
| - (x-pos + cross-size, y-pos - cross-size), |
38 |
| - stroke: red + 2pt, |
| 3 | +#set page(width: auto, height: auto, margin: 8pt) |
| 4 | + |
| 5 | +#canvas({ |
| 6 | + import draw: line, circle, content |
| 7 | + |
| 8 | + let node-style = (stroke: black + 1pt, fill: white) |
| 9 | + |
| 10 | + let layer-sep = 2.5 // Horizontal separation between layers |
| 11 | + let node-sep = 1.5 // Vertical separation between nodes |
| 12 | + let arrow-style = (stroke: black + 1pt, mark: (end: "stealth"), fill: black) |
| 13 | + |
| 14 | + // Helper function to draw a layer of nodes |
| 15 | + let draw-layer(x, nodes, prefix: "") = { |
| 16 | + for ii in range(nodes) { |
| 17 | + circle( |
| 18 | + (x, node-sep * (ii + 1)), |
| 19 | + radius: 0.3, |
| 20 | + name: prefix + str(ii + 1), |
| 21 | + ..node-style, |
39 | 22 | )
|
40 |
| - } else { |
41 |
| - y-coords.push(y-pos) |
42 | 23 | }
|
43 | 24 | }
|
44 |
| - return y-coords |
45 |
| -} |
46 |
| - |
47 |
| -// Helper to draw connections between layers |
48 |
| -#let connect-layers(x1, y1s, x2, y2s) = { |
49 |
| - for y1 in y1s { |
50 |
| - for y2 in y2s { |
51 |
| - draw.line( |
52 |
| - (float(x1), float(y1)), |
53 |
| - (float(x2), float(y2)), |
54 |
| - mark: (end: "stealth", fill: black), |
55 |
| - stroke: black + 1pt, |
56 |
| - ) |
| 25 | + |
| 26 | + // Helper to connect all nodes between layers |
| 27 | + let connect-layers(from-prefix, to-prefix, from-nodes, to-nodes) = { |
| 28 | + for ii in range(from-nodes) { |
| 29 | + for jj in range(to-nodes) { |
| 30 | + line( |
| 31 | + (from-prefix + str(ii + 1)), |
| 32 | + (to-prefix + str(jj + 1)), |
| 33 | + ..arrow-style, |
| 34 | + ) |
| 35 | + } |
57 | 36 | }
|
58 | 37 | }
|
59 |
| -} |
60 | 38 |
|
61 |
| -#canvas({ |
62 |
| - // Left network (before dropout) |
63 |
| - let x = 0.0 |
64 |
| - let y1s = draw-layer(x, 5) |
65 |
| - |
66 |
| - x += layer-sep |
67 |
| - let y2s = draw-layer(x, 5) |
68 |
| - |
69 |
| - x += layer-sep |
70 |
| - let y3s = draw-layer(x, 5) |
71 |
| - |
72 |
| - x += layer-sep |
73 |
| - let y4s = draw-layer(x, 2, disabled: (), center: true) // vertically center output layer |
74 |
| - |
75 |
| - // Connect all nodes in adjacent layers |
76 |
| - connect-layers(0.0, y1s, layer-sep, y2s) |
77 |
| - connect-layers(layer-sep, y2s, 2 * layer-sep, y3s) |
78 |
| - connect-layers(2 * layer-sep, y3s, 3 * layer-sep, y4s) |
79 |
| - |
80 |
| - // Dropout arrow |
81 |
| - let arrow-x = 3.5 * layer-sep |
82 |
| - draw.line( |
83 |
| - (arrow-x, 2 * node-sep), |
84 |
| - (arrow-x + 2, 2 * node-sep), |
85 |
| - mark: (end: "stealth", fill: black), |
86 |
| - stroke: black + 1pt, |
87 |
| - ) |
88 |
| - draw.content( |
89 |
| - (arrow-x + 1, 2 * node-sep + 0.3), |
90 |
| - text(weight: "bold", size: 1.2em, "dropout"), |
91 |
| - ) |
| 39 | + // Left network (fully connected) |
| 40 | + // Draw all layers |
| 41 | + draw-layer(0, 5, prefix: "i") // Input layer |
| 42 | + draw-layer(layer-sep, 5, prefix: "h1") // First hidden layer |
| 43 | + draw-layer(2 * layer-sep, 5, prefix: "h2") // Second hidden layer |
| 44 | + |
| 45 | + // Draw output nodes |
| 46 | + circle((3 * layer-sep, 2 * node-sep), radius: 0.3, name: "o1", ..node-style) |
| 47 | + circle((3 * layer-sep, 4 * node-sep), radius: 0.3, name: "o2", ..node-style) |
92 | 48 |
|
93 |
| - // Right network (after dropout) |
94 |
| - x = arrow-x + 3 |
95 |
| - let dy1s = draw-layer(x, 5, disabled: (1, 3)) |
| 49 | + // Connect all layers |
| 50 | + connect-layers("i", "h1", 5, 5) |
| 51 | + connect-layers("h1", "h2", 5, 5) |
96 | 52 |
|
97 |
| - x += layer-sep |
98 |
| - let dy2s = draw-layer(x, 5, disabled: (1, 3, 4)) |
| 53 | + // Connect to output nodes |
| 54 | + for ii in range(5) { |
| 55 | + line(("h2" + str(ii + 1)), "o1", ..arrow-style) |
| 56 | + line(("h2" + str(ii + 1)), "o2", ..arrow-style) |
| 57 | + } |
99 | 58 |
|
100 |
| - x += layer-sep |
101 |
| - let dy3s = draw-layer(x, 5, disabled: (2, 4)) |
| 59 | + // Draw dropout arrow |
| 60 | + let mid-x = 4 * layer-sep |
| 61 | + line( |
| 62 | + (3.5 * layer-sep, 3 * node-sep), |
| 63 | + (4.5 * layer-sep, 3 * node-sep), |
| 64 | + ..arrow-style, |
| 65 | + name: "dropout-arrow", |
| 66 | + ) |
| 67 | + content( |
| 68 | + "dropout-arrow.mid", |
| 69 | + text(weight: "bold", size: 1.2em)[dropout], |
| 70 | + anchor: "south", |
| 71 | + padding: 3pt, |
| 72 | + ) |
102 | 73 |
|
103 |
| - x += layer-sep |
104 |
| - let dy4s = draw-layer(x, 2, center: true) // vertically center output layer |
| 74 | + // Right network (with dropout) |
| 75 | + // Draw all layers |
| 76 | + draw-layer(mid-x + layer-sep, 5, prefix: "di") |
| 77 | + draw-layer(mid-x + 2 * layer-sep, 5, prefix: "dh1") |
| 78 | + draw-layer(mid-x + 3 * layer-sep, 5, prefix: "dh2") |
| 79 | + |
| 80 | + // Draw output nodes |
| 81 | + circle((mid-x + 4 * layer-sep, 2 * node-sep), radius: 0.3, name: "do1", ..node-style) |
| 82 | + circle((mid-x + 4 * layer-sep, 4 * node-sep), radius: 0.3, name: "do2", ..node-style) |
| 83 | + |
| 84 | + // Add dropout X marks |
| 85 | + let x-style = (fill: red, weight: "bold", size: 4em, baseline: -4pt) |
| 86 | + content("di1", text(..x-style)[×]) |
| 87 | + content("di3", text(..x-style)[×]) |
| 88 | + content("dh11", text(..x-style)[×]) |
| 89 | + content("dh13", text(..x-style)[×]) |
| 90 | + content("dh14", text(..x-style)[×]) |
| 91 | + content("dh22", text(..x-style)[×]) |
| 92 | + content("dh24", text(..x-style)[×]) |
| 93 | + |
| 94 | + // Connect remaining nodes (after dropout) |
| 95 | + for ii in (2, 4, 5) { |
| 96 | + for jj in (2, 5) { |
| 97 | + line(("di" + str(ii)), ("dh1" + str(jj)), ..arrow-style) |
| 98 | + } |
| 99 | + } |
105 | 100 |
|
106 |
| - // Connect only enabled nodes |
107 |
| - connect-layers(x - 3 * layer-sep, dy1s, x - 2 * layer-sep, dy2s) |
108 |
| - connect-layers(x - 2 * layer-sep, dy2s, x - layer-sep, dy3s) |
109 |
| - connect-layers(x - layer-sep, dy3s, x, dy4s) |
| 101 | + for ii in (2, 5) { |
| 102 | + for jj in (1, 3, 5) { |
| 103 | + line(("dh1" + str(ii)), ("dh2" + str(jj)), ..arrow-style) |
| 104 | + } |
| 105 | + } |
| 106 | + |
| 107 | + for ii in (1, 3, 5) { |
| 108 | + line(("dh2" + str(ii)), "do1", ..arrow-style) |
| 109 | + line(("dh2" + str(ii)), "do2", ..arrow-style) |
| 110 | + } |
110 | 111 | })
|
0 commit comments