66
77import click
88
9+ IMAGE_EXTENSIONS = {".jpg" , ".jpeg" , ".png" , ".bmp" , ".tiff" , ".webp" }
10+ VIDEO_EXTENSIONS = {".mp4" , ".avi" , ".mov" , ".mkv" , ".webm" , ".flv" , ".wmv" , ".m4v" }
11+
912
1013@click .command ()
1114@click .option ("--model" , default = "yolo_nas_s" , type = click .Choice (["yolo_nas_s" , "yolo_nas_m" , "yolo_nas_l" ]))
1518@click .option ("--device" , default = "cuda" , help = "Device (cuda or cpu)." )
1619@click .option ("--output" , default = "results" , help = "Output directory." )
1720@click .option ("--input-size" , default = 640 , help = "Model input size." )
18- def detect (model : str , source : str , conf : float , iou : float , device : str , output : str , input_size : int ):
21+ @click .option ("--skip-frames" , default = 0 , help = "Process every N-th frame for video (0 = every frame)." )
22+ @click .option ("--codec" , default = "mp4v" , help = "Video output codec (e.g. mp4v, XVID, avc1)." )
23+ def detect (
24+ model : str ,
25+ source : str ,
26+ conf : float ,
27+ iou : float ,
28+ device : str ,
29+ output : str ,
30+ input_size : int ,
31+ skip_frames : int ,
32+ codec : str ,
33+ ):
1934 """Run object detection on images or video."""
2035 from rich .console import Console
2136
@@ -31,16 +46,51 @@ def detect(model: str, source: str, conf: float, iou: float, device: str, output
3146 source_path = Path (source )
3247
3348 if source_path .is_dir ():
49+ # Directory of images
3450 files = sorted (source_path .glob ("*.*" ))
35- files = [f for f in files if f .suffix .lower () in (".jpg" , ".jpeg" , ".png" , ".bmp" )]
51+ files = [f for f in files if f .suffix .lower () in IMAGE_EXTENSIONS ]
52+ _detect_images (det , files , out_dir , console )
53+
54+ elif source_path .suffix .lower () in VIDEO_EXTENSIONS :
55+ # Video file
56+ _detect_video (det , source_path , out_dir , console , skip_frames , codec )
57+
58+ elif source_path .suffix .lower () in IMAGE_EXTENSIONS :
59+ # Single image
60+ _detect_images (det , [source_path ], out_dir , console )
61+
3662 else :
37- files = [source_path ]
63+ console .print (f"[red]Unknown source type: { source_path .suffix } [/red]" )
64+ raise click .Abort ()
65+
3866
67+ def _detect_images (det , files : list [Path ], out_dir : Path , console ):
68+ """Run detection on a list of image files."""
3969 for f in files :
4070 console .print (f"Processing { f .name } ..." )
4171 result = det (str (f ))
4272 out_path = out_dir / f .name
4373 result .save (out_path )
44- console .print (f" { len (result .boxes )} detections → { out_path } " )
74+ console .print (f" { len (result .boxes )} detections -> { out_path } " )
75+
76+ console .print (f"[green]Done! { len (files )} images saved to { out_dir } [/green]" )
77+
78+
79+ def _detect_video (det , source_path : Path , out_dir : Path , console , skip_frames : int , codec : str ):
80+ """Run detection on a video file."""
81+ out_path = out_dir / source_path .name
82+ console .print (f"Processing video { source_path .name } ..." )
83+
84+ stats = det .detect_video_to_file (
85+ source = str (source_path ),
86+ output = str (out_path ),
87+ codec = codec ,
88+ skip_frames = skip_frames ,
89+ )
4590
46- console .print (f"[green]Done! Results saved to { out_dir } [/green]" )
91+ console .print (
92+ f" { stats ['total_frames' ]} frames, "
93+ f"{ stats ['processed_frames' ]} processed, "
94+ f"{ stats ['total_detections' ]} total detections"
95+ )
96+ console .print (f"[green]Done! Video saved to { out_path } [/green]" )
0 commit comments