-
Notifications
You must be signed in to change notification settings - Fork 252
Digit Classifier Migration to Superthin Template #2690
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
MihirGore23
wants to merge
8
commits into
humble-devel
Choose a base branch
from
dig_class_migrate
base: humble-devel
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
52853fe
Add Digit Classifier Newmanager Exercise and Template
MihirGore23 8be41d5
Implement model quantization and graph optimization tools in Dig_Clas…
MihirGore23 ba07468
Remove launch
MihirGore23 45e053b
Remove Console
MihirGore23 fd2635b
Rename React components
MihirGore23 d254ec3
Update React component name
MihirGore23 c425d09
Add start_console
MihirGore23 44778b1
Update react-component names
MihirGore23 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
13 changes: 13 additions & 0 deletions
13
exercises/static/exercises/dl_digit_classifier_newmanager/entry_point/exercise.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import os.path | ||
from typing import Callable | ||
|
||
from src.manager.libs.applications.compatibility.exercise_wrapper_ros2 import CompatibilityExerciseWrapperRos2 | ||
|
||
|
||
class Exercise(CompatibilityExerciseWrapperRos2): | ||
def __init__(self, circuit: str, update_callback: Callable): | ||
current_path = os.path.dirname(__file__) | ||
|
||
super(Exercise, self).__init__(exercise_command=f"{current_path}/../../python_template/ros2_humble/exercise.py 0.0.0.0", | ||
gui_command=f"{current_path}/../../python_template/ros2_humble/gui.py 0.0.0.0 {circuit}", | ||
update_callback=update_callback) |
191 changes: 191 additions & 0 deletions
191
exercises/static/exercises/dl_digit_classifier_newmanager/python_template/ros2_humble/GUI.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
import json | ||
import os | ||
import rclpy | ||
import cv2 | ||
import sys | ||
import base64 | ||
import threading | ||
import time | ||
import numpy as np | ||
from datetime import datetime | ||
import websocket | ||
import subprocess | ||
import logging | ||
|
||
from hal_interfaces.general.odometry import OdometryNode | ||
from console_interfaces.general.console import start_console | ||
|
||
|
||
# Graphical User Interface Class | ||
class GUI: | ||
# Initialization function | ||
# The actual initialization | ||
def __init__(self, host): | ||
|
||
self.payload = {'image': '', 'shape': []} | ||
|
||
# ROS2 init | ||
if not rclpy.ok(): | ||
rclpy.init(args=None) | ||
|
||
|
||
# Image variables | ||
self.image_to_be_shown = None | ||
self.image_to_be_shown_updated = False | ||
self.image_show_lock = threading.Lock() | ||
self.host = host | ||
self.client = None | ||
|
||
|
||
|
||
self.ack = False | ||
self.ack_lock = threading.Lock() | ||
|
||
# Create the lap object | ||
# TODO: maybe move this to HAL and have it be hybrid | ||
|
||
|
||
self.client_thread = threading.Thread(target=self.run_websocket) | ||
self.client_thread.start() | ||
|
||
def run_websocket(self): | ||
while True: | ||
print("GUI WEBSOCKET CONNECTED") | ||
self.client = websocket.WebSocketApp(self.host, on_message=self.on_message) | ||
self.client.run_forever(ping_timeout=None, ping_interval=0) | ||
|
||
# Function to prepare image payload | ||
# Encodes the image as a JSON string and sends through the WS | ||
def payloadImage(self): | ||
with self.image_show_lock: | ||
image_to_be_shown_updated = self.image_to_be_shown_updated | ||
image_to_be_shown = self.image_to_be_shown | ||
|
||
image = image_to_be_shown | ||
payload = {'image': '', 'shape': ''} | ||
|
||
if not image_to_be_shown_updated: | ||
return payload | ||
|
||
shape = image.shape | ||
frame = cv2.imencode('.JPEG', image)[1] | ||
encoded_image = base64.b64encode(frame) | ||
|
||
payload['image'] = encoded_image.decode('utf-8') | ||
payload['shape'] = shape | ||
with self.image_show_lock: | ||
self.image_to_be_shown_updated = False | ||
|
||
return payload | ||
|
||
# Function for student to call | ||
def showImage(self, image): | ||
with self.image_show_lock: | ||
self.image_to_be_shown = image | ||
self.image_to_be_shown_updated = True | ||
|
||
# Update the gui | ||
def update_gui(self): | ||
# print("GUI update") | ||
# Payload Image Message | ||
payload = self.payloadImage() | ||
self.payload["image"] = json.dumps(payload) | ||
|
||
|
||
message = json.dumps(self.payload) | ||
if self.client: | ||
try: | ||
self.client.send(message) | ||
# print(message) | ||
except Exception as e: | ||
print(f"Error sending message: {e}") | ||
|
||
def on_message(self, ws, message): | ||
"""Handles incoming messages from the websocket client.""" | ||
if message.startswith("#ack"): | ||
# print("on message" + str(message)) | ||
self.set_acknowledge(True) | ||
|
||
def get_acknowledge(self): | ||
"""Gets the acknowledge status.""" | ||
with self.ack_lock: | ||
ack = self.ack | ||
|
||
return ack | ||
|
||
def set_acknowledge(self, value): | ||
"""Sets the acknowledge status.""" | ||
with self.ack_lock: | ||
self.ack = value | ||
|
||
|
||
class ThreadGUI: | ||
"""Class to manage GUI updates and frequency measurements in separate threads.""" | ||
|
||
def __init__(self, gui): | ||
"""Initializes the ThreadGUI with a reference to the GUI instance.""" | ||
self.gui = gui | ||
self.ideal_cycle = 80 | ||
self.real_time_factor = 0 | ||
self.frequency_message = {'brain': '', 'gui': ''} | ||
self.iteration_counter = 0 | ||
self.running = True | ||
|
||
def start(self): | ||
"""Starts the GUI, frequency measurement, and real-time factor threads.""" | ||
self.frequency_thread = threading.Thread(target=self.measure_and_send_frequency) | ||
self.gui_thread = threading.Thread(target=self.run) | ||
self.frequency_thread.start() | ||
self.gui_thread.start() | ||
print("GUI Thread Started!") | ||
|
||
def measure_and_send_frequency(self): | ||
"""Measures and sends the frequency of GUI updates and brain cycles.""" | ||
previous_time = datetime.now() | ||
while self.running: | ||
time.sleep(2) | ||
|
||
current_time = datetime.now() | ||
dt = current_time - previous_time | ||
ms = (dt.days * 24 * 60 * 60 + dt.seconds) * 1000 + dt.microseconds / 1000.0 | ||
previous_time = current_time | ||
measured_cycle = ms / self.iteration_counter if self.iteration_counter > 0 else 0 | ||
self.iteration_counter = 0 | ||
brain_frequency = round(1000 / measured_cycle, 1) if measured_cycle != 0 else 0 | ||
gui_frequency = round(1000 / self.ideal_cycle, 1) | ||
self.frequency_message = {'brain': brain_frequency, 'gui': gui_frequency} | ||
message = json.dumps(self.frequency_message) | ||
if self.gui.client: | ||
try: | ||
self.gui.client.send(message) | ||
except Exception as e: | ||
print(f"Error sending frequency message: {e}") | ||
|
||
def run(self): | ||
"""Main loop to update the GUI at regular intervals.""" | ||
while self.running: | ||
start_time = datetime.now() | ||
|
||
self.gui.update_gui() | ||
self.iteration_counter += 1 | ||
finish_time = datetime.now() | ||
|
||
dt = finish_time - start_time | ||
ms = (dt.days * 24 * 60 * 60 + dt.seconds) * 1000 + dt.microseconds / 1000.0 | ||
sleep_time = max(0, (50 - ms) / 1000.0) | ||
time.sleep(sleep_time) | ||
|
||
|
||
# Create a GUI interface | ||
host = "ws://127.0.0.1:2303" | ||
gui_interface = GUI(host) | ||
|
||
start_console() | ||
|
||
# Spin a thread to keep the interface updated | ||
thread_gui = ThreadGUI(gui_interface) | ||
thread_gui.start() | ||
|
||
def showImage(image): | ||
gui_interface.showImage(image) | ||
|
41 changes: 41 additions & 0 deletions
41
exercises/static/exercises/dl_digit_classifier_newmanager/python_template/ros2_humble/HAL.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import rclpy | ||
from rclpy.node import Node | ||
from sensor_msgs.msg import Image | ||
from cv_bridge import CvBridge | ||
import threading | ||
import cv2 | ||
|
||
current_frame = None # Global variable to store the frame | ||
|
||
class WebcamSubscriber(Node): | ||
def __init__(self): | ||
super().__init__('webcam_subscriber') | ||
self.subscription = self.create_subscription( | ||
Image, | ||
'/image_raw', | ||
self.listener_callback, | ||
10) | ||
self.subscription # prevent unused variable warning | ||
self.bridge = CvBridge() | ||
|
||
def listener_callback(self, msg): | ||
global current_frame | ||
self.get_logger().info('Receiving video frame') | ||
current_frame = self.bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8') | ||
|
||
def run_webcam_node(): | ||
|
||
webcam_subscriber = WebcamSubscriber() | ||
|
||
rclpy.spin(webcam_subscriber) | ||
webcam_subscriber.destroy_node() | ||
|
||
|
||
# Start the ROS2 node in a separate thread | ||
thread = threading.Thread(target=run_webcam_node) | ||
thread.start() | ||
|
||
def getImage(): | ||
global current_frame | ||
return current_frame | ||
|
1 change: 1 addition & 0 deletions
1
.../exercises/dl_digit_classifier_newmanager/python_template/ros2_humble/README.md
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
[Exercise Documentation Website](https://jderobot.github.io/RoboticsAcademy/exercises/ComputerVision/dl_digit_classifier) |
63 changes: 63 additions & 0 deletions
63
...exercises/dl_digit_classifier_newmanager/python_template/ros2_humble/demo_code/academy.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import GUI | ||
import HAL | ||
import base64 | ||
from datetime import datetime | ||
import json | ||
import sys | ||
import time | ||
import cv2 | ||
import numpy as np | ||
import onnxruntime | ||
|
||
roi_scale = 0.75 | ||
input_size = (28, 28) | ||
|
||
# Receive model | ||
raw_dl_model = '/workspace/code/demo_model/mnist_cnn.onnx' | ||
|
||
# Load ONNX model | ||
try: | ||
ort_session = onnxruntime.InferenceSession(raw_dl_model) | ||
except Exception: | ||
exc_type, exc_value, exc_traceback = sys.exc_info() | ||
print(str(exc_value)) | ||
print("ERROR: Model couldn't be loaded") | ||
|
||
previous_pred = 0 | ||
previous_established_pred = "-" | ||
count_same_digit = 0 | ||
|
||
while True: | ||
|
||
# Get input webcam image | ||
image = HAL.getImage() | ||
if image is not None: | ||
input_image_gray = np.mean(image, axis=2).astype(np.uint8) | ||
|
||
# Get original image and ROI dimensions | ||
h_in, w_in = image.shape[:2] | ||
min_dim_in = min(h_in, w_in) | ||
h_roi, w_roi = (int(min_dim_in * roi_scale), int(min_dim_in * roi_scale)) | ||
h_border, w_border = (int((h_in - h_roi) / 2.), int((w_in - w_roi) / 2.)) | ||
|
||
# Extract ROI and convert to tensor format required by the model | ||
roi = input_image_gray[h_border:h_border + h_roi, w_border:w_border + w_roi] | ||
roi_norm = (roi - np.mean(roi)) / np.std(roi) | ||
roi_resized = cv2.resize(roi_norm, input_size) | ||
input_tensor = roi_resized.reshape((1, 1, input_size[0], input_size[1])).astype(np.float32) | ||
|
||
# Inference | ||
ort_inputs = {ort_session.get_inputs()[0].name: input_tensor} | ||
output = ort_session.run(None, ort_inputs)[0] | ||
pred = int(np.argmax(output, axis=1)) # get the index of the max log-probability | ||
|
||
# Show region used as ROI | ||
cv2.rectangle(image, pt2=(w_border, h_border), pt1=(w_border + w_roi, h_border + h_roi), color=(255, 0, 0), thickness=3) | ||
|
||
# Show FPS count | ||
cv2.putText(image, "Pred: {}".format(int(pred)), (7, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) | ||
|
||
# Send result | ||
GUI.showImage(image) | ||
|
||
|
Binary file added
BIN
+1.65 MB
...ises/dl_digit_classifier_newmanager/python_template/ros2_humble/demo_model/mnist_cnn.onnx
Binary file not shown.
14 changes: 14 additions & 0 deletions
14
...ses/static/exercises/dl_digit_classifier_newmanager/react-components/DigitClassifierRR.js
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import * as React from "react"; | ||
import {Fragment} from "react"; | ||
|
||
import "./css/DigitClassifierRR.css"; | ||
|
||
const DigitClassifierRR = (props) => { | ||
return ( | ||
<Fragment> | ||
{props.children} | ||
</Fragment> | ||
); | ||
}; | ||
|
||
export default DigitClassifierRR; |
53 changes: 53 additions & 0 deletions
53
exercises/static/exercises/dl_digit_classifier_newmanager/react-components/DisplayFeed.js
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import * as React from "react"; | ||
import { Box } from "@mui/material"; | ||
import "./css/GUICanvas.css"; | ||
import { drawImage } from "./helpers/showImages"; | ||
|
||
|
||
const DisplayFeed = (props) => { | ||
const [image, setImage] = React.useState(null) | ||
const canvasRef = React.useRef(null) | ||
|
||
React.useEffect(() => { | ||
console.log("TestShowScreen subscribing to ['update'] events"); | ||
const callback = (message) => { | ||
if(message.data.update.image){ | ||
console.log('image') | ||
const image = JSON.parse(message.data.update.image) | ||
if(image.image){ | ||
drawImage(message.data.update) | ||
} | ||
} | ||
}; | ||
|
||
window.RoboticsExerciseComponents.commsManager.subscribe( | ||
[window.RoboticsExerciseComponents.commsManager.events.UPDATE], | ||
callback | ||
); | ||
|
||
return () => { | ||
console.log("TestShowScreen unsubscribing from ['state-changed'] events"); | ||
window.RoboticsExerciseComponents.commsManager.unsubscribe( | ||
[window.RoboticsExerciseComponents.commsManager.events.UPDATE], | ||
callback | ||
); | ||
}; | ||
}, []); | ||
|
||
return ( | ||
<Box sx={{ height: "100%" }}> | ||
<canvas | ||
ref={canvasRef} | ||
className={"exercise-canvas"} | ||
id="canvas" | ||
></canvas> | ||
</Box> | ||
); | ||
}; | ||
|
||
DisplayFeed.defaultProps = { | ||
width: 800, | ||
height: 600, | ||
}; | ||
|
||
export default DisplayFeed |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.