|
| 1 | +<!DOCTYPE html> |
| 2 | +<html> |
| 3 | +<head> |
| 4 | + <title>Generative Models as Priors for Inverse Problems</title> |
| 5 | +<style> |
| 6 | + body { |
| 7 | + height: 500px; |
| 8 | + width: 500px; |
| 9 | +} |
| 10 | +/* #animation { |
| 11 | + position: absolute; |
| 12 | + top: 0px; |
| 13 | + left: 0px; |
| 14 | + background: #000; |
| 15 | +} body { |
| 16 | + text-align: center; |
| 17 | + } |
| 18 | +
|
| 19 | + #mynetwork { |
| 20 | + height: 500px; |
| 21 | + } */ |
| 22 | +</style> |
| 23 | + |
| 24 | + <!-- Import TensorFlow.js --> |
| 25 | + <script src=" https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js" ></script> |
| 26 | + <script src="https://d3js.org/d3.v5.min.js"></script> |
| 27 | +</head> |
| 28 | + |
| 29 | +<body> |
| 30 | + <canvas id="animation" height="500" width="500"></canvas> |
| 31 | + <script> |
| 32 | + |
| 33 | + (async function animation() { |
| 34 | + // This function is closely modeled on http://bl.ocks.org/newby-jay/767c5ffdbbe43b65902f |
| 35 | + const model = await tf.loadGraphModel('models/js/export3/model.json'); |
| 36 | + const grads = tf.grad(x => model.predict(x)); |
| 37 | + const vfunc = (x,y) => { |
| 38 | + c = tf.concat([tf.reshape(x, [-1,1] ), tf.reshape(y, [-1,1])], axis=1); |
| 39 | + gr = grads(c); |
| 40 | + gn = tf.sum(tf.mul(gr,gr), axis=1, keepDims=true); |
| 41 | + gr = tf.mul(tf.div(gr, gn), tf.clipByValue(gn,0,100)); |
| 42 | + |
| 43 | + return tf.split(gr, 2, axis=1); |
| 44 | + }; |
| 45 | + |
| 46 | + |
| 47 | + // vector field data |
| 48 | + var dt = 0.003, |
| 49 | + X0 = [], Y0 = [], // to store initial starting locations |
| 50 | + X = [], Y = [], // to store current point for each curve |
| 51 | + Xd=[], Yd=[], |
| 52 | + xb = 4, yb = 3; |
| 53 | + var sigma=0.2; |
| 54 | + var X1=0.5, Y1=0.5, X2=0, Y2=0; |
| 55 | + var XC=1, YC=1; |
| 56 | + var width = 500, height = 500; |
| 57 | + |
| 58 | + // First draw the modelled density in the background |
| 59 | + var N = 128 |
| 60 | + var xd = d3.range(N).map( |
| 61 | + function (i) { |
| 62 | + return -1.5 + xb*i/N; |
| 63 | + }), |
| 64 | + yd = d3.range(N).map( |
| 65 | + function (i) { |
| 66 | + return -1 + yb*i/N; |
| 67 | + }); |
| 68 | + // array of starting positions for each curve on a uniform grid |
| 69 | + for (var i = 0; i < N; i++) { |
| 70 | + for (var j = 0; j < N; j++) { |
| 71 | + Xd.push(xd[j]), Yd.push(yd[i]); |
| 72 | + } |
| 73 | + } |
| 74 | + |
| 75 | + // Compute the density field in this input resolution and rescale it to output res |
| 76 | + const logp = tf.tidy(() => { |
| 77 | + const c = tf.concat([tf.reshape(Xd, [-1,1] ), tf.reshape(Yd, [-1,1])], axis=1); |
| 78 | + const out = model.predict(c); |
| 79 | + const out_resized = tf.exp(tf.image.resizeBilinear(tf.reshape(out, [N,N,1]), [width, height])); |
| 80 | + return out_resized.dataSync(); |
| 81 | + }); |
| 82 | + |
| 83 | + // Store this array as image data |
| 84 | + var g = d3.select("#animation").node().getContext("2d"); |
| 85 | + var imagedata = g.createImageData(width, height); |
| 86 | + for (var x=0; x<width; x++) { |
| 87 | + for (var y=0; y<height; y++) { |
| 88 | + var pixelindex = (y * width + x) * 4; |
| 89 | + // Generate a xor pattern with some random noise |
| 90 | + var po = logp[((height -1 - y) * width + x)]*0.5; |
| 91 | + if(isNaN(po)){ po = 0; } |
| 92 | + c = d3.rgb(d3.interpolateInferno(po)); |
| 93 | + // Set the pixel data |
| 94 | + imagedata.data[pixelindex] = c.r; // Red |
| 95 | + imagedata.data[pixelindex+1] = c.g; // discretize the vfield coordsgreen; // Green |
| 96 | + imagedata.data[pixelindex+2] = c.b; // Blue |
| 97 | + imagedata.data[pixelindex+3] = 255; // Alpha |
| 98 | + } |
| 99 | + } |
| 100 | + g.putImageData(imagedata,0,0); |
| 101 | + for (var x=0; x<width; x++) { |
| 102 | + for (var y=0; y<height; y++) { |
| 103 | + var pixelindex = (y * width + x) * 4; |
| 104 | + // Generate a xor pattern with some random noise |
| 105 | + var po = logp[((height -1 - y) * width + x)]*0.5; |
| 106 | + if(isNaN(po)){ po = 0; } |
| 107 | + c = d3.rgb(d3.interpolateInferno(po)); |
| 108 | + // Set the pixel data |
| 109 | + imagedata.data[pixelindex] = c.r; // Red |
| 110 | + imagedata.data[pixelindex+1] = c.g; // discretize the vfield coordsgreen; // Green |
| 111 | + imagedata.data[pixelindex+2] = c.b; // Blue |
| 112 | + imagedata.data[pixelindex+3] = 25; // Alpha |
| 113 | + } |
| 114 | + } |
| 115 | + |
| 116 | + var N = 50; |
| 117 | + var xp = d3.range(N).map( |
| 118 | + function (i) { |
| 119 | + return -1.5 + xb*i/N; |
| 120 | + }), |
| 121 | + yp = d3.range(N).map( |
| 122 | + function (i) { |
| 123 | + return -1 + yb*i/N; |
| 124 | + }); |
| 125 | + // array of starting positions for each curve on a uniform grid |
| 126 | + for (var i = 0; i < N; i++) { |
| 127 | + for (var j = 0; j < N; j++) { |
| 128 | + X.push(xp[j]), Y.push(yp[i]); |
| 129 | + X0.push(xp[j]), Y0.push(yp[i]); |
| 130 | + } |
| 131 | + } |
| 132 | + |
| 133 | + // // vfield |
| 134 | + function F(x, y) { |
| 135 | + const [px, py] = tf.tidy(() => { |
| 136 | + const [predx, predy] = vfunc(x, y); |
| 137 | + return [predx.dataSync(), predy.dataSync()]; |
| 138 | + }); |
| 139 | + return [px, py]; |
| 140 | + } |
| 141 | + |
| 142 | + //// frame setup |
| 143 | + var mw = 0; |
| 144 | + |
| 145 | + g.lineWidth = 0.8; |
| 146 | + g.strokeStyle = "#FF8000"; // html color code |
| 147 | + |
| 148 | + //// mapping from vfield coords to web page coords |
| 149 | + var xMap = d3.scaleLinear() |
| 150 | + .domain([-1.5, 2.5]) |
| 151 | + .range([mw, width - mw]), |
| 152 | + yMap = d3.scaleLinear() |
| 153 | + .domain([-1, 2.]) |
| 154 | + .range([height - mw, mw]); |
| 155 | + //// animation setup |
| 156 | + var animAge = 0, |
| 157 | + frameRate = 30, // ms per timestep (yeah I know it's not really a rate) |
| 158 | + M = X.length, |
| 159 | + thr=200, |
| 160 | + MaxAge = 100, // # timesteps before restart |
| 161 | + age = []; |
| 162 | + |
| 163 | + for (var i=0; i<M; i++) {age.push(randage());} |
| 164 | + var drawFlag = false; |
| 165 | + |
| 166 | + d3.timer(function () {if (drawFlag) {draw();}}, frameRate); |
| 167 | + d3.select("#animation") |
| 168 | + .on("click", function() { |
| 169 | + var mouse = d3.mouse(this); |
| 170 | + XC = xMap.invert(mouse[0]); |
| 171 | + YC = yMap.invert(mouse[1]); |
| 172 | + }) |
| 173 | + |
| 174 | + d3.select("body").on("keypress", function() { |
| 175 | + if(d3.event.keyCode === 32 || d3.event.keyCode === 13){ |
| 176 | + drawFlag = (drawFlag) ? false : true; |
| 177 | + } |
| 178 | + if(d3.event.keyCode === 61 ){ |
| 179 | + sigma = sigma*2.; |
| 180 | + } |
| 181 | + if(d3.event.keyCode === 45 ){ |
| 182 | + sigma /= 2.; |
| 183 | + } |
| 184 | + }) |
| 185 | + function randage() { |
| 186 | + // to randomize starting ages for each curve |
| 187 | + return Math.round(Math.random()*100); |
| 188 | + } |
| 189 | + |
| 190 | + var overlayCanvas = document.createElement("canvas"); |
| 191 | + overlayCanvas.width = width; |
| 192 | + overlayCanvas.height = height; |
| 193 | + overlayCanvas.getContext("2d").putImageData(imagedata, 0, 0); |
| 194 | + g.imageSmoothingEnabled = false; |
| 195 | + |
| 196 | + // for info on the global canvas operations see |
| 197 | + // http://bucephalus.org/text/CanvasHandbook/CanvasHandbook.html#globalcompositeoperation |
| 198 | + g.globalCompositeOperation = "source-over"; |
| 199 | + function draw() { |
| 200 | + var s = (xMap(sigma) - xMap(0)); |
| 201 | + //g.fillRect(0, 0, width, height); // fades all existing curves by a set amount determined by fillStyle (above), which sets opacity using rgba |
| 202 | + //g.putImageData(imagedata,0,0); |
| 203 | + g.drawImage(overlayCanvas,0,0); |
| 204 | + // Compute dr for all points |
| 205 | + g.lineWidth = 1.5; |
| 206 | + g.strokeStyle = "#FF8000"; // html color code |
| 207 | + var [dx, dy] = F(X, Y); |
| 208 | + for (var i=0; i<M; i++) { // draw a single timestep for every curve |
| 209 | + // if dx dy is larger than our threshold, we don't need to move this point |
| 210 | + if((dx[i]**2 + dy[i]**2) < thr){ |
| 211 | + g.beginPath(); |
| 212 | + g.moveTo(xMap(X[i]), yMap(Y[i])); // the start point of the path |
| 213 | + g.lineTo(xMap(X[i]+=dx[i]*dt), yMap(Y[i]+=dy[i]*dt)); // the end point |
| 214 | + g.stroke(); // final draw command |
| 215 | + }; |
| 216 | + if (age[i]++ > MaxAge) { |
| 217 | + // incriment age of each curve, restart if MaxAge is reached |
| 218 | + age[i] = randage(); |
| 219 | + X[i] = X0[i], Y[i] = Y0[i]; |
| 220 | + } |
| 221 | + } |
| 222 | + // Computes gradients of the solution |
| 223 | + var [dx, dy] = F([X1, X2], [Y1, Y2]); |
| 224 | + dx[0]+= 0.5*(XC - (X1+X2))/sigma/sigma; |
| 225 | + dx[1]+= 0.5*(XC - (X1+X2))/sigma/sigma; |
| 226 | + dy[0]+= 0.5*(YC - (Y1+Y2))/sigma/sigma; |
| 227 | + dy[1]+= 0.5*(YC - (Y1+Y2))/sigma/sigma; |
| 228 | + |
| 229 | + // Draw solution points |
| 230 | + g.lineWidth = 14; |
| 231 | + g.strokeStyle = g.fillStyle = "#ADFF2F"; // html color code |
| 232 | + XS=X1+X2; YS=Y1+Y2; |
| 233 | + g.beginPath(); |
| 234 | + g.moveTo(xMap(X1), yMap(Y1)); |
| 235 | + g.lineTo(xMap(X1+=dx[0]*dt), yMap(Y1+=dy[0]*dt)); |
| 236 | + g.stroke(); |
| 237 | + g.beginPath(); |
| 238 | + g.arc(xMap(X1), yMap(Y1), 7, 0, 2 * Math.PI); |
| 239 | + g.fill(); |
| 240 | + |
| 241 | + g.strokeStyle = g.fillStyle = "#96CDFF"; // html color code |
| 242 | + g.beginPath(); |
| 243 | + g.moveTo(xMap(X2), yMap(Y2)); |
| 244 | + g.lineTo(xMap(X2+=dx[1]*dt), yMap(Y2+=dy[1]*dt)); |
| 245 | + g.stroke(); |
| 246 | + g.beginPath(); |
| 247 | + g.arc(xMap(X2), yMap(Y2), 7, 0, 2 * Math.PI); |
| 248 | + g.fill(); |
| 249 | + |
| 250 | + g.strokeStyle = g.fillStyle = "#E32E52";//#896ED1"; // html color code |
| 251 | + g.beginPath(); |
| 252 | + g.moveTo(xMap(XS), yMap(YS)); |
| 253 | + XS=X1+X2; YS=Y1+Y2; |
| 254 | + g.lineTo(xMap(XS), yMap(YS)); |
| 255 | + g.stroke(); |
| 256 | + g.beginPath(); |
| 257 | + g.arc(xMap(XS), yMap(YS), 7, 0, 2 * Math.PI); |
| 258 | + g.fill(); |
| 259 | + |
| 260 | + g.beginPath(); |
| 261 | + g.strokeStyle = "#C94277"; |
| 262 | + g.lineWidth = 1.5; |
| 263 | + g.arc(xMap(XC), yMap(YC), s, 0, 2 * Math.PI); |
| 264 | + g.stroke(); |
| 265 | + |
| 266 | + } |
| 267 | + })() |
| 268 | + |
| 269 | + </script> |
| 270 | +</body> |
| 271 | + |
| 272 | +</html> |
0 commit comments