Skip to content

Training data collector #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 60 additions & 4 deletions parallax/dialogs.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from PyQt5.QtWidgets import QPushButton, QLabel, QRadioButton, QSpinBox
from PyQt5.QtWidgets import QPushButton, QLabel, QSpinBox
from PyQt5.QtWidgets import QGridLayout
from PyQt5.QtWidgets import QDialog, QLineEdit, QDialogButtonBox
from PyQt5.QtCore import Qt
from PyQt5.QtGui import QDoubleValidator

import pyqtgraph as pg
import numpy as np
import time
import datetime

from .toggle_switch import ToggleSwitch
from .helper import FONT_BOLD
Expand Down Expand Up @@ -343,3 +341,61 @@ def get_params(self):
y = float(self.yedit.text())
z = float(self.zedit.text())
return x,y,z


class TrainingDataDialog(QDialog):

def __init__(self, model):
QDialog.__init__(self)
self.model = model

self.setWindowTitle('Training Data Generator')

self.stage_label = QLabel('Select a Stage:')
self.stage_label.setAlignment(Qt.AlignCenter)
self.stage_label.setFont(FONT_BOLD)

self.stage_dropdown = StageDropdown(self.model)
self.stage_dropdown.activated.connect(self.update_status)

self.img_count_label = QLabel('Image Count:')
self.img_count_label.setAlignment(Qt.AlignCenter)
self.img_count_box = QSpinBox()
self.img_count_box.setMinimum(1)
self.img_count_box.setValue(100)

self.extent_label = QLabel('Extent:')
self.extent_label.setAlignment(Qt.AlignCenter)
self.extent_spin = pg.SpinBox(value=4e-3, suffix='m', siPrefix=True, bounds=[0.1e-3, 20e-3], dec=True, step=0.5, minStep=1e-6, compactHeight=False)

self.go_button = QPushButton('Start Data Collection')
self.go_button.setEnabled(False)
self.go_button.clicked.connect(self.go)

layout = QGridLayout()
layout.addWidget(self.stage_label, 0,0, 1,1)
layout.addWidget(self.stage_dropdown, 0,1, 1,1)
layout.addWidget(self.img_count_label, 1,0, 1,1)
layout.addWidget(self.img_count_box, 1,1, 1,1)
layout.addWidget(self.extent_label, 2,0, 1,1)
layout.addWidget(self.extent_spin, 2,1, 1,1)
layout.addWidget(self.go_button, 4,0, 1,2)
self.setLayout(layout)

self.setMinimumWidth(300)

def get_stage(self):
return self.stage_dropdown.current_stage()

def get_img_count(self):
return self.img_count_box.value()

def get_extent(self):
return self.extent_spin.value() * 1e6

def go(self):
self.accept()

def update_status(self):
if self.stage_dropdown.is_selected():
self.go_button.setEnabled(True)
14 changes: 13 additions & 1 deletion parallax/main_window.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QHBoxLayout, QMainWindow, QAction
from PyQt5.QtWidgets import QApplication, QWidget, QHBoxLayout, QVBoxLayout, QGridLayout, QMainWindow, QAction, QSplitter
from PyQt5.QtCore import Qt, QTimer
from PyQt5.QtGui import QIcon
import pyqtgraph.console
Expand All @@ -10,6 +10,7 @@
from .dialogs import AboutDialog
from .rigid_body_transform_tool import RigidBodyTransformTool
from .stage_manager import StageManager
from .training_data import TrainingDataCollector


class MainWindow(QMainWindow):
Expand All @@ -18,6 +19,11 @@ def __init__(self, model):
QMainWindow.__init__(self)
self.model = model

# allow main window to be accessed globally
model.main_window = self

self.data_collector = None

self.widget = MainWidget(model)
self.setCentralWidget(self.widget)

Expand All @@ -37,6 +43,8 @@ def __init__(self, model):
self.rbt_action.triggered.connect(self.launch_rbt)
self.console_action = QAction("Python Console")
self.console_action.triggered.connect(self.show_console)
self.training_data_action = QAction("Collect Training Data")
self.training_data_action.triggered.connect(self.collect_training_data)
self.about_action = QAction("About")
self.about_action.triggered.connect(self.launch_about)

Expand All @@ -56,6 +64,7 @@ def __init__(self, model):
self.tools_menu = self.menuBar().addMenu("Tools")
self.tools_menu.addAction(self.rbt_action)
self.tools_menu.addAction(self.console_action)
self.tools_menu.addAction(self.training_data_action)

self.help_menu = self.menuBar().addMenu("Help")
self.help_menu.addAction(self.about_action)
Expand Down Expand Up @@ -106,6 +115,9 @@ def refresh_focus_controllers(self):
for screen in self.screens():
screen.update_focus_control_menu()

def collect_training_data(self):
self.data_collector = TrainingDataCollector(self.model)
self.data_collector.start()

class MainWidget(QWidget):

Expand Down
78 changes: 78 additions & 0 deletions parallax/training_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import threading, pickle, os
import numpy as np
from PyQt5 import QtWidgets, QtCore
from .dialogs import TrainingDataDialog


class TrainingDataCollector(QtCore.QObject):
def __init__(self, model):
QtCore.QObject.__init__(self)
self.model = model

def start(self):
dlg = TrainingDataDialog(self.model)
dlg.exec_()
if dlg.result() != dlg.Accepted:
return

self.stage = dlg.get_stage()
self.img_count = dlg.get_img_count()
self.extent = dlg.get_extent()
self.path = QtWidgets.QFileDialog.getExistingDirectory(parent=None, caption="Select Storage Directory")
if self.path == '':
return

self.start_pos = self.stage.get_position()
self.stage_cal = list(self.model.calibrations.values())[0]

self.thread = threading.Thread(target=self.thread_run, daemon=True)
self.thread.start()

def thread_run(self):
meta_file = os.path.join(self.path, 'meta.pkl')
if os.path.exists(meta_file):
# todo: just append
raise Exception("Already data in this folder!")
trials = []
meta = {
'calibration': self.stage_cal,
'stage': self.stage.get_name(),
'trials': trials,
}

# move electrode out of fov for background images
pos = self.start_pos.coordinates.copy()
pos[2] += 10000
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this line should be changed to move the electrode out of the FOV

self.stage.move_to_target_3d(*pos, block=True)
imgs = self.save_images('background')
meta['background'] = imgs

for i in range(self.img_count):

# first image in random location
rnd = np.random.uniform(-self.extent/2, self.extent/2, size=3)
pos1 = self.start_pos.coordinates + rnd
self.stage.move_to_target_3d(*pos1, block=True)
images1 = self.save_images(f'{i:04d}-a')

# take a second image slightly shifted
pos2 = pos1.copy()
pos2[2] += 10
self.stage.move_to_target_3d(*pos2, block=True)
images2 = self.save_images(f'{i:04d}-b')

trials.append([
{'pos': pos1, 'images': images1},
{'pos': pos2, 'images': images2},
])

with open(meta_file, 'wb') as fh:
pickle.dump(meta, fh)

def save_images(self, name):
images = []
for camera in self.model.cameras:
filename = f'{name}-{camera.name()}.png'
camera.save_last_image(os.path.join(self.path, filename))
images.append({'camera': camera.name(), 'image': filename})
return images
128 changes: 128 additions & 0 deletions tools/annotate_training_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import pyqtgraph as pg
import json


class MainWindow(pg.GraphicsView):
def __init__(self, meta_file, img_files):
super().__init__()
self.img_files = img_files

self.view = pg.ViewBox()
self.view.invertY()
self.setCentralItem(self.view)

self.img_item = pg.QtWidgets.QGraphicsPixmapItem()
self.view.addItem(self.img_item)

self.line_item = pg.QtWidgets.QGraphicsLineItem()
self.line_item.setPen(pg.mkPen('r'))
self.circle_item = pg.QtWidgets.QGraphicsEllipseItem()
self.circle_item.setPen(pg.mkPen('r'))
self.view.addItem(self.line_item)
self.view.addItem(self.circle_item)

self.next_click = 0
self.attached_pt = None
self.loaded_file = None

self.meta_file = meta_file
if os.path.exists(meta_file):
self.meta = json.load(open(meta_file, 'r'))
else:
self.meta = {}

self.load_image(0)

def keyPressEvent(self, ev):
if ev.key() == pg.QtCore.Qt.Key_Left:
self.load_image(self.current_index - 1)
elif ev.key() == pg.QtCore.Qt.Key_Right:
self.load_image(self.current_index + 1)
else:
print(ev.key())

def mousePressEvent(self, ev):
# print('press', ev)
if ev.button() == pg.QtCore.Qt.LeftButton:
self.attached_pt = self.next_click
self.update_pos(ev.pos())
ev.accept()
return
# return super().mousePressEvent(ev)

def mouseReleaseEvent(self, ev):
# print('release', ev)
self.attached_pt = None
self.next_click = (self.next_click + 1) % 2

def mouseMoveEvent(self, ev):
# print('move', ev)
self.update_pos(ev.pos())
ev.accept()

def update_pos(self, pos):
pos = self.view.mapDeviceToView(pos)
if self.attached_pt == 0:
self.set_pts(pos, None)
elif self.attached_pt == 1:
self.set_pts(None, pos)
else:
return
self.update_meta()

def set_pts(self, pt1, pt2):
line = self.line_item.line()
if pt1 is not None:
line.setP1(pt1)
self.circle_item.setRect(pt1.x()-10, pt1.y()-10, 20, 20)
self.circle_item.setVisible(True)
if pt2 is not None:
line.setP2(pt2)
self.line_item.setVisible(True)
self.line_item.setLine(line)

def hide_line(self):
self.line_item.setVisible(False)
self.circle_item.setVisible(False)

def update_meta(self):
line = self.line_item.line()
self.meta[self.loaded_file] = {
'pt1': (line.x1(), line.y1()),
'pt2': (line.x2(), line.y2()),
}
json.dump(self.meta, open(self.meta_file, 'w'))

def load_image(self, index):
filename = self.img_files[index]
pxm = pg.QtGui.QPixmap()
pxm.load(filename)
self.img_item.setPixmap(pxm)
self.img_item.pxm = pxm
self.view.autoRange(padding=0)
self.current_index = index
self.setWindowTitle(filename)
self.loaded_file = filename

meta = self.meta.get(filename, {})
pt1 = meta.get('pt1', None)
pt2 = meta.get('pt2', None)
if None in (pt1, pt2):
self.hide_line()
else:
self.set_pts(pg.QtCore.QPointF(*pt1), pg.QtCore.QPointF(*pt2))


if __name__ == '__main__':
import os, sys

app = pg.mkQApp()

meta_file = sys.argv[1]
img_files = sys.argv[2:]
win = MainWindow(meta_file, img_files)
win.resize(1000, 800)
win.show()

if sys.flags.interactive == 0:
app.exec_()