Skip to content

Commit 3a9ead8

Browse files
authored
Merge pull request #8 from EiffL/Seattle2019
Adds missing files
2 parents 77ec989 + d682153 commit 3a9ead8

File tree

3 files changed

+273
-0
lines changed

3 files changed

+273
-0
lines changed

Seattle2019/dgm_prior.html

+272
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
<!DOCTYPE html>
2+
<html>
3+
<head>
4+
<title>Generative Models as Priors for Inverse Problems</title>
5+
<style>
6+
body {
7+
height: 500px;
8+
width: 500px;
9+
}
10+
/* #animation {
11+
position: absolute;
12+
top: 0px;
13+
left: 0px;
14+
background: #000;
15+
} body {
16+
text-align: center;
17+
}
18+
19+
#mynetwork {
20+
height: 500px;
21+
} */
22+
</style>
23+
24+
<!-- Import TensorFlow.js -->
25+
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"></script>
26+
<script src="https://d3js.org/d3.v5.min.js"></script>
27+
</head>
28+
29+
<body>
30+
<canvas id="animation" height="500" width="500"></canvas>
31+
<script>
32+
33+
(async function animation() {
34+
// This function is closely modeled on http://bl.ocks.org/newby-jay/767c5ffdbbe43b65902f
35+
const model = await tf.loadGraphModel('models/js/export3/model.json');
36+
const grads = tf.grad(x => model.predict(x));
37+
const vfunc = (x,y) => {
38+
c = tf.concat([tf.reshape(x, [-1,1] ), tf.reshape(y, [-1,1])], axis=1);
39+
gr = grads(c);
40+
gn = tf.sum(tf.mul(gr,gr), axis=1, keepDims=true);
41+
gr = tf.mul(tf.div(gr, gn), tf.clipByValue(gn,0,100));
42+
43+
return tf.split(gr, 2, axis=1);
44+
};
45+
46+
47+
// vector field data
48+
var dt = 0.003,
49+
X0 = [], Y0 = [], // to store initial starting locations
50+
X = [], Y = [], // to store current point for each curve
51+
Xd=[], Yd=[],
52+
xb = 4, yb = 3;
53+
var sigma=0.2;
54+
var X1=0.5, Y1=0.5, X2=0, Y2=0;
55+
var XC=1, YC=1;
56+
var width = 500, height = 500;
57+
58+
// First draw the modelled density in the background
59+
var N = 128
60+
var xd = d3.range(N).map(
61+
function (i) {
62+
return -1.5 + xb*i/N;
63+
}),
64+
yd = d3.range(N).map(
65+
function (i) {
66+
return -1 + yb*i/N;
67+
});
68+
// array of starting positions for each curve on a uniform grid
69+
for (var i = 0; i < N; i++) {
70+
for (var j = 0; j < N; j++) {
71+
Xd.push(xd[j]), Yd.push(yd[i]);
72+
}
73+
}
74+
75+
// Compute the density field in this input resolution and rescale it to output res
76+
const logp = tf.tidy(() => {
77+
const c = tf.concat([tf.reshape(Xd, [-1,1] ), tf.reshape(Yd, [-1,1])], axis=1);
78+
const out = model.predict(c);
79+
const out_resized = tf.exp(tf.image.resizeBilinear(tf.reshape(out, [N,N,1]), [width, height]));
80+
return out_resized.dataSync();
81+
});
82+
83+
// Store this array as image data
84+
var g = d3.select("#animation").node().getContext("2d");
85+
var imagedata = g.createImageData(width, height);
86+
for (var x=0; x<width; x++) {
87+
for (var y=0; y<height; y++) {
88+
var pixelindex = (y * width + x) * 4;
89+
// Generate a xor pattern with some random noise
90+
var po = logp[((height -1 - y) * width + x)]*0.5;
91+
if(isNaN(po)){ po = 0; }
92+
c = d3.rgb(d3.interpolateInferno(po));
93+
// Set the pixel data
94+
imagedata.data[pixelindex] = c.r; // Red
95+
imagedata.data[pixelindex+1] = c.g; // discretize the vfield coordsgreen; // Green
96+
imagedata.data[pixelindex+2] = c.b; // Blue
97+
imagedata.data[pixelindex+3] = 255; // Alpha
98+
}
99+
}
100+
g.putImageData(imagedata,0,0);
101+
for (var x=0; x<width; x++) {
102+
for (var y=0; y<height; y++) {
103+
var pixelindex = (y * width + x) * 4;
104+
// Generate a xor pattern with some random noise
105+
var po = logp[((height -1 - y) * width + x)]*0.5;
106+
if(isNaN(po)){ po = 0; }
107+
c = d3.rgb(d3.interpolateInferno(po));
108+
// Set the pixel data
109+
imagedata.data[pixelindex] = c.r; // Red
110+
imagedata.data[pixelindex+1] = c.g; // discretize the vfield coordsgreen; // Green
111+
imagedata.data[pixelindex+2] = c.b; // Blue
112+
imagedata.data[pixelindex+3] = 25; // Alpha
113+
}
114+
}
115+
116+
var N = 50;
117+
var xp = d3.range(N).map(
118+
function (i) {
119+
return -1.5 + xb*i/N;
120+
}),
121+
yp = d3.range(N).map(
122+
function (i) {
123+
return -1 + yb*i/N;
124+
});
125+
// array of starting positions for each curve on a uniform grid
126+
for (var i = 0; i < N; i++) {
127+
for (var j = 0; j < N; j++) {
128+
X.push(xp[j]), Y.push(yp[i]);
129+
X0.push(xp[j]), Y0.push(yp[i]);
130+
}
131+
}
132+
133+
// // vfield
134+
function F(x, y) {
135+
const [px, py] = tf.tidy(() => {
136+
const [predx, predy] = vfunc(x, y);
137+
return [predx.dataSync(), predy.dataSync()];
138+
});
139+
return [px, py];
140+
}
141+
142+
//// frame setup
143+
var mw = 0;
144+
145+
g.lineWidth = 0.8;
146+
g.strokeStyle = "#FF8000"; // html color code
147+
148+
//// mapping from vfield coords to web page coords
149+
var xMap = d3.scaleLinear()
150+
.domain([-1.5, 2.5])
151+
.range([mw, width - mw]),
152+
yMap = d3.scaleLinear()
153+
.domain([-1, 2.])
154+
.range([height - mw, mw]);
155+
//// animation setup
156+
var animAge = 0,
157+
frameRate = 30, // ms per timestep (yeah I know it's not really a rate)
158+
M = X.length,
159+
thr=200,
160+
MaxAge = 100, // # timesteps before restart
161+
age = [];
162+
163+
for (var i=0; i<M; i++) {age.push(randage());}
164+
var drawFlag = false;
165+
166+
d3.timer(function () {if (drawFlag) {draw();}}, frameRate);
167+
d3.select("#animation")
168+
.on("click", function() {
169+
var mouse = d3.mouse(this);
170+
XC = xMap.invert(mouse[0]);
171+
YC = yMap.invert(mouse[1]);
172+
})
173+
174+
d3.select("body").on("keypress", function() {
175+
if(d3.event.keyCode === 32 || d3.event.keyCode === 13){
176+
drawFlag = (drawFlag) ? false : true;
177+
}
178+
if(d3.event.keyCode === 61 ){
179+
sigma = sigma*2.;
180+
}
181+
if(d3.event.keyCode === 45 ){
182+
sigma /= 2.;
183+
}
184+
})
185+
function randage() {
186+
// to randomize starting ages for each curve
187+
return Math.round(Math.random()*100);
188+
}
189+
190+
var overlayCanvas = document.createElement("canvas");
191+
overlayCanvas.width = width;
192+
overlayCanvas.height = height;
193+
overlayCanvas.getContext("2d").putImageData(imagedata, 0, 0);
194+
g.imageSmoothingEnabled = false;
195+
196+
// for info on the global canvas operations see
197+
// http://bucephalus.org/text/CanvasHandbook/CanvasHandbook.html#globalcompositeoperation
198+
g.globalCompositeOperation = "source-over";
199+
function draw() {
200+
var s = (xMap(sigma) - xMap(0));
201+
//g.fillRect(0, 0, width, height); // fades all existing curves by a set amount determined by fillStyle (above), which sets opacity using rgba
202+
//g.putImageData(imagedata,0,0);
203+
g.drawImage(overlayCanvas,0,0);
204+
// Compute dr for all points
205+
g.lineWidth = 1.5;
206+
g.strokeStyle = "#FF8000"; // html color code
207+
var [dx, dy] = F(X, Y);
208+
for (var i=0; i<M; i++) { // draw a single timestep for every curve
209+
// if dx dy is larger than our threshold, we don't need to move this point
210+
if((dx[i]**2 + dy[i]**2) < thr){
211+
g.beginPath();
212+
g.moveTo(xMap(X[i]), yMap(Y[i])); // the start point of the path
213+
g.lineTo(xMap(X[i]+=dx[i]*dt), yMap(Y[i]+=dy[i]*dt)); // the end point
214+
g.stroke(); // final draw command
215+
};
216+
if (age[i]++ > MaxAge) {
217+
// incriment age of each curve, restart if MaxAge is reached
218+
age[i] = randage();
219+
X[i] = X0[i], Y[i] = Y0[i];
220+
}
221+
}
222+
// Computes gradients of the solution
223+
var [dx, dy] = F([X1, X2], [Y1, Y2]);
224+
dx[0]+= 0.5*(XC - (X1+X2))/sigma/sigma;
225+
dx[1]+= 0.5*(XC - (X1+X2))/sigma/sigma;
226+
dy[0]+= 0.5*(YC - (Y1+Y2))/sigma/sigma;
227+
dy[1]+= 0.5*(YC - (Y1+Y2))/sigma/sigma;
228+
229+
// Draw solution points
230+
g.lineWidth = 14;
231+
g.strokeStyle = g.fillStyle = "#ADFF2F"; // html color code
232+
XS=X1+X2; YS=Y1+Y2;
233+
g.beginPath();
234+
g.moveTo(xMap(X1), yMap(Y1));
235+
g.lineTo(xMap(X1+=dx[0]*dt), yMap(Y1+=dy[0]*dt));
236+
g.stroke();
237+
g.beginPath();
238+
g.arc(xMap(X1), yMap(Y1), 7, 0, 2 * Math.PI);
239+
g.fill();
240+
241+
g.strokeStyle = g.fillStyle = "#96CDFF"; // html color code
242+
g.beginPath();
243+
g.moveTo(xMap(X2), yMap(Y2));
244+
g.lineTo(xMap(X2+=dx[1]*dt), yMap(Y2+=dy[1]*dt));
245+
g.stroke();
246+
g.beginPath();
247+
g.arc(xMap(X2), yMap(Y2), 7, 0, 2 * Math.PI);
248+
g.fill();
249+
250+
g.strokeStyle = g.fillStyle = "#E32E52";//#896ED1"; // html color code
251+
g.beginPath();
252+
g.moveTo(xMap(XS), yMap(YS));
253+
XS=X1+X2; YS=Y1+Y2;
254+
g.lineTo(xMap(XS), yMap(YS));
255+
g.stroke();
256+
g.beginPath();
257+
g.arc(xMap(XS), yMap(YS), 7, 0, 2 * Math.PI);
258+
g.fill();
259+
260+
g.beginPath();
261+
g.strokeStyle = "#C94277";
262+
g.lineWidth = 1.5;
263+
g.arc(xMap(XC), yMap(YC), s, 0, 2 * Math.PI);
264+
g.stroke();
265+
266+
}
267+
})()
268+
269+
</script>
270+
</body>
271+
272+
</html>
Binary file not shown.

Seattle2019/models/js/export3/model.json

+1
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)