diff --git a/visualizer_drag.py b/visualizer_drag.py index 9120906..b016bcc 100644 --- a/visualizer_drag.py +++ b/visualizer_drag.py @@ -177,15 +177,16 @@ def draw_frame(self): if self.result.init_net: self.drag_widget.reset_point() - if self.check_update_mask(**self.args): - h, w, _ = self.result.image.shape - self.drag_widget.init_mask(w, h) - # Display. max_w = self.content_width - self.pane_w max_h = self.content_height pos = np.array([self.pane_w + max_w / 2, max_h / 2]) if 'image' in self.result: + # Reset mask after loading a new pickle or changing seed. + if self.check_update_mask(**self.args): + h, w, _ = self.result.image.shape + self.drag_widget.init_mask(w, h) + if self._tex_img is not self.result.image: self._tex_img = self.result.image if self._tex_obj is None or not self._tex_obj.is_compatible(image=self._tex_img): diff --git a/viz/capture_widget.py b/viz/capture_widget.py index 48e1373..72bf3cf 100644 --- a/viz/capture_widget.py +++ b/viz/capture_widget.py @@ -31,7 +31,7 @@ def dump_png(self, image): viz = self.viz try: _height, _width, channels = image.shape - assert channels in [1, 3] + print(viz.result) assert image.dtype == np.uint8 os.makedirs(self.path, exist_ok=True) file_id = 0 @@ -43,8 +43,9 @@ def dump_png(self, image): if channels == 1: pil_image = PIL.Image.fromarray(image[:, :, 0], 'L') else: - pil_image = PIL.Image.fromarray(image, 'RGB') + pil_image = PIL.Image.fromarray(image[:, :, :3], 'RGB') pil_image.save(os.path.join(self.path, f'{file_id:05d}.png')) + np.save(os.path.join(self.path, f'{file_id:05d}.npy'), viz.result.w) except: viz.result.error = renderer.CapturedException() diff --git a/viz/renderer.py b/viz/renderer.py index ed81697..aa43bf9 100644 --- a/viz/renderer.py +++ b/viz/renderer.py @@ -382,5 +382,6 @@ def _render_drag_impl(self, res, img = img.cpu().numpy() img = Image.fromarray(img) res.image = img + res.w = ws.detach().cpu().numpy() #----------------------------------------------------------------------------