Skip to content

Commit b1a85a4

Browse files
committed
Add example scripts for image, video, and webcam detection using YOLO-NAS
1 parent 648761c commit b1a85a4

6 files changed

Lines changed: 324 additions & 9 deletions

File tree

examples/detect_image.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""Example: Run YOLO-NAS inference on a single image.
2+
3+
Usage:
4+
python examples/detect_image.py path/to/image.jpg
5+
python examples/detect_image.py path/to/image.jpg --model yolo_nas_l --device cpu
6+
"""
7+
8+
import argparse
9+
10+
from modern_yolonas.inference.detect import Detector
11+
12+
13+
def main():
    """Detect objects in a single image, list them, and save an annotated copy.

    Command-line arguments select the input image, the model variant, the
    device, the confidence and NMS IoU thresholds, and the output path.
    """
    cli = argparse.ArgumentParser(description="YOLO-NAS image detection")
    cli.add_argument("image", help="Path to input image")
    cli.add_argument(
        "--model",
        default="yolo_nas_s",
        choices=["yolo_nas_s", "yolo_nas_m", "yolo_nas_l"],
    )
    cli.add_argument("--device", default="cuda")
    cli.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
    cli.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold")
    cli.add_argument("--output", default="output.jpg", help="Output image path")
    opts = cli.parse_args()

    # Building the detector downloads pretrained weights on the first run.
    detector = Detector(
        opts.model,
        device=opts.device,
        conf_threshold=opts.conf,
        iou_threshold=opts.iou,
    )

    # Inference on the image path given on the command line.
    detections = detector(opts.image)

    # Report one line per detected object.
    print(f"Found {len(detections.boxes)} objects:")
    from modern_yolonas.inference.visualize import COCO_NAMES

    rows = zip(detections.boxes, detections.scores, detections.class_ids)
    for (x1, y1, x2, y2), score, cls_id in rows:
        name = COCO_NAMES[int(cls_id)]
        print(f" {name}: {score:.2f} [{x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f}]")

    # Write the annotated image to disk.
    detections.save(opts.output)
    print(f"Saved to {opts.output}")


if __name__ == "__main__":
    main()

examples/detect_video.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Example: Run YOLO-NAS inference on a video file.
2+
3+
Usage:
4+
python examples/detect_video.py path/to/video.mp4
5+
python examples/detect_video.py path/to/video.mp4 --output output.mp4 --model yolo_nas_l
6+
"""
7+
8+
import argparse
9+
from pathlib import Path
10+
11+
from modern_yolonas.inference.detect import Detector
12+
13+
14+
def main():
    """Annotate a video file with YOLO-NAS detections and print summary stats.

    When ``--output`` is omitted the result is written next to the input as
    ``<input>_detect.<ext>``.
    """
    cli = argparse.ArgumentParser(description="YOLO-NAS video detection")
    cli.add_argument("video", help="Path to input video")
    cli.add_argument(
        "--model",
        default="yolo_nas_s",
        choices=["yolo_nas_s", "yolo_nas_m", "yolo_nas_l"],
    )
    cli.add_argument("--device", default="cuda")
    cli.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
    cli.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold")
    cli.add_argument("--output", default=None, help="Output video path (default: <input>_detect.<ext>)")
    cli.add_argument("--skip-frames", type=int, default=0, help="Process every N-th frame (0 = all)")
    cli.add_argument("--codec", default="mp4v", help="Output video codec")
    opts = cli.parse_args()

    # Derive the destination from the input name when none was given.
    destination = opts.output
    if destination is None:
        video_path = Path(opts.video)
        derived_name = f"{video_path.stem}_detect{video_path.suffix}"
        destination = str(video_path.parent / derived_name)

    detector = Detector(
        opts.model,
        device=opts.device,
        conf_threshold=opts.conf,
        iou_threshold=opts.iou,
    )

    # --- Option 1: let the detector write the annotated video itself ---
    print(f"Processing {opts.video} ...")
    summary = detector.detect_video_to_file(
        source=opts.video,
        output=destination,
        codec=opts.codec,
        skip_frames=opts.skip_frames,
    )
    frames_seen = summary["total_frames"]
    frames_done = summary["processed_frames"]
    detections = summary["total_detections"]
    print(
        f"Done! {frames_seen} frames, "
        f"{frames_done} processed, "
        f"{detections} total detections"
    )
    print(f"Saved to {destination}")

    # --- Option 2: iterate frames with a generator (left commented out) ---
    # Handy when custom per-frame logic is needed:
    #
    # for frame_idx, result in detector.detect_video(opts.video):
    #     print(f"Frame {frame_idx}: {len(result.boxes)} detections")
    #     # Access result.boxes, result.scores, result.class_ids
    #     # Or get annotated frame: annotated = result.visualize()


if __name__ == "__main__":
    main()

examples/detect_webcam.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""Example: Run YOLO-NAS live on webcam feed.
2+
3+
Usage:
4+
python examples/detect_webcam.py
5+
python examples/detect_webcam.py --model yolo_nas_m --device cuda
6+
"""
7+
8+
import argparse
9+
10+
import cv2
11+
12+
from modern_yolonas.inference.detect import Detector
13+
14+
15+
def main():
    """Show live YOLO-NAS detections from a webcam until 'q' is pressed.

    Command-line arguments select the model variant, the device, the
    confidence threshold and the camera index. Display windows are torn
    down even when the loop ends through an exception (for example Ctrl-C).
    """
    parser = argparse.ArgumentParser(description="YOLO-NAS webcam detection")
    parser.add_argument("--model", default="yolo_nas_s", choices=["yolo_nas_s", "yolo_nas_m", "yolo_nas_l"])
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--conf", type=float, default=0.25)
    parser.add_argument("--camera", type=int, default=0, help="Camera index")
    args = parser.parse_args()

    det = Detector(args.model, device=args.device, conf_threshold=args.conf)

    print("Press 'q' to quit")
    frames = det.detect_video(source=args.camera)
    try:
        # The frame index yielded alongside each result is not needed here.
        for _, result in frames:
            annotated = result.visualize()
            cv2.imshow("YOLO-NAS", annotated)
            # waitKey also pumps the GUI event loop; mask to the low byte
            # before comparing with the key code.
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
    finally:
        # NOTE(review): detect_video presumably returns a generator; closing
        # it lets its cleanup run right away instead of at garbage
        # collection. The getattr guard keeps this safe for any iterable.
        close = getattr(frames, "close", None)
        if callable(close):
            close()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()

src/modern_yolonas/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from modern_yolonas._version import __version__
44
from modern_yolonas.model import YoloNAS
55
from modern_yolonas.weights import load_pretrained
6+
from modern_yolonas.inference.detect import Detector, Detection
67

78

89
def yolo_nas_s(pretrained: bool = False, num_classes: int = 80) -> YoloNAS:
@@ -36,4 +37,6 @@ def yolo_nas_l(pretrained: bool = False, num_classes: int = 80) -> YoloNAS:
3637
"yolo_nas_m",
3738
"yolo_nas_l",
3839
"load_pretrained",
40+
"Detector",
41+
"Detection",
3942
]

src/modern_yolonas/cli/detect_cmd.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
import click
88

9+
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
10+
VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv", ".wmv", ".m4v"}
11+
912

1013
@click.command()
1114
@click.option("--model", default="yolo_nas_s", type=click.Choice(["yolo_nas_s", "yolo_nas_m", "yolo_nas_l"]))
@@ -15,7 +18,19 @@
1518
@click.option("--device", default="cuda", help="Device (cuda or cpu).")
1619
@click.option("--output", default="results", help="Output directory.")
1720
@click.option("--input-size", default=640, help="Model input size.")
18-
def detect(model: str, source: str, conf: float, iou: float, device: str, output: str, input_size: int):
21+
@click.option("--skip-frames", default=0, help="Process every N-th frame for video (0 = every frame).")
22+
@click.option("--codec", default="mp4v", help="Video output codec (e.g. mp4v, XVID, avc1).")
23+
def detect(
24+
model: str,
25+
source: str,
26+
conf: float,
27+
iou: float,
28+
device: str,
29+
output: str,
30+
input_size: int,
31+
skip_frames: int,
32+
codec: str,
33+
):
1934
"""Run object detection on images or video."""
2035
from rich.console import Console
2136

@@ -31,16 +46,51 @@ def detect(model: str, source: str, conf: float, iou: float, device: str, output
3146
source_path = Path(source)
3247

3348
if source_path.is_dir():
49+
# Directory of images
3450
files = sorted(source_path.glob("*.*"))
35-
files = [f for f in files if f.suffix.lower() in (".jpg", ".jpeg", ".png", ".bmp")]
51+
files = [f for f in files if f.suffix.lower() in IMAGE_EXTENSIONS]
52+
_detect_images(det, files, out_dir, console)
53+
54+
elif source_path.suffix.lower() in VIDEO_EXTENSIONS:
55+
# Video file
56+
_detect_video(det, source_path, out_dir, console, skip_frames, codec)
57+
58+
elif source_path.suffix.lower() in IMAGE_EXTENSIONS:
59+
# Single image
60+
_detect_images(det, [source_path], out_dir, console)
61+
3662
else:
37-
files = [source_path]
63+
console.print(f"[red]Unknown source type: {source_path.suffix}[/red]")
64+
raise click.Abort()
65+
3866

67+
def _detect_images(det, files: list[Path], out_dir: Path, console):
    """Run detection on a list of image files and save annotated copies.

    Args:
        det: Callable taking an image path string and returning a result
            object with a ``boxes`` sequence and a ``save(path)`` method.
        files: Image files to process.
        out_dir: Directory receiving one annotated image per input, saved
            under the input's file name.
        console: Object with a ``print`` method used for progress output.
    """
    # An empty list (e.g. a directory with no supported images) used to be
    # reported as a success; say explicitly that there was nothing to do.
    if not files:
        console.print("[yellow]No images found to process.[/yellow]")
        return

    for f in files:
        console.print(f"Processing {f.name}...")
        result = det(str(f))
        out_path = out_dir / f.name
        result.save(out_path)
        console.print(f" {len(result.boxes)} detections -> {out_path}")

    console.print(f"[green]Done! {len(files)} images saved to {out_dir}[/green]")
77+
78+
79+
def _detect_video(det, source_path: Path, out_dir: Path, console, skip_frames: int, codec: str):
    """Run detection on a video file and write the annotated video to ``out_dir``.

    The output keeps the source file name. If that would point at the source
    file itself (``out_dir`` is the source's own directory), the output is
    named ``<stem>_detect<suffix>`` instead.

    Args:
        det: Object exposing ``detect_video_to_file(source, output, codec,
            skip_frames)`` that returns a mapping with ``total_frames``,
            ``processed_frames`` and ``total_detections`` entries.
        source_path: Input video file.
        out_dir: Directory receiving the annotated video.
        console: Object with a ``print`` method used for progress output.
        skip_frames: Passed through to ``detect_video_to_file``.
        codec: Passed through to ``detect_video_to_file``.
    """
    out_path = out_dir / source_path.name
    # Writing to the same path that is being read would clobber the input
    # mid-run, so switch to a distinct name in that case.
    if out_path.resolve() == source_path.resolve():
        out_path = out_dir / f"{source_path.stem}_detect{source_path.suffix}"

    console.print(f"Processing video {source_path.name}...")

    stats = det.detect_video_to_file(
        source=str(source_path),
        output=str(out_path),
        codec=codec,
        skip_frames=skip_frames,
    )

    console.print(
        f" {stats['total_frames']} frames, "
        f"{stats['processed_frames']} processed, "
        f"{stats['total_detections']} total detections"
    )
    console.print(f"[green]Done! Video saved to {out_path}[/green]")

0 commit comments

Comments
 (0)