55import requests
66from unittest .mock import Mock , patch
77import torch
8+ from lxml import etree
89import numpy as np
910import pandas as pd
1011from kabr_tools import (
@@ -97,12 +98,13 @@ def test_run(self):
9798 @patch ('kabr_tools.miniscene2behavior.process_cv2_inputs' )
9899 @patch ('kabr_tools.miniscene2behavior.cv2.VideoCapture' )
99100 def test_matching_tracks (self , video_capture , process_cv2_inputs ):
100-
101- # Create fake model that always returns a prediction of 1
101+ # create fake model that weights class 98
102102 mock_model = Mock ()
103- mock_model .return_value = torch .tensor ([1 ])
103+ prob = torch .zeros (99 )
104+ prob [- 1 ] = 1
105+ mock_model .return_value = prob
104106
105- # Create fake cfg
107+ # create fake cfg
106108 mock_config = Mock (
107109 DATA = Mock (NUM_FRAMES = 16 ,
108110 SAMPLING_RATE = 5 ,
@@ -111,25 +113,36 @@ def test_matching_tracks(self, video_capture, process_cv2_inputs):
111113 OUTPUT_DIR = ''
112114 )
113115
114- # Create fake video capture
116+ # create fake video capture
115117 vc = video_capture .return_value
116118 vc .read .return_value = True , np .zeros ((8 , 8 , 3 ), np .uint8 )
117- vc .get .return_value = 1
119+ vc .get .return_value = 21
118120
119121 self .output = '/tmp/annotation_data.csv'
122+ miniscene_dir = os .path .join (EXAMPLESDIR , "MINISCENE1" )
123+ video_name = "DJI"
120124
121125 annotate_miniscene (cfg = mock_config ,
122126 model = mock_model ,
123- miniscene_path = os .path .join (
124- EXAMPLESDIR , "MINISCENE1" ),
125- video = 'DJI' ,
127+ miniscene_path = miniscene_dir ,
128+ video = video_name ,
126129 output_path = self .output )
127130
128- # Read in output CSV and make sure we have the expected columns and at least one row
131+ # check output CSV
129132 df = pd .read_csv (self .output , sep = ' ' )
130133 self .assertEqual (list (df .columns ), [
131134 "video" , "track" , "frame" , "label" ])
132- self .assertGreater (len (df .index ), 0 )
135+ row_ct = 0
136+
137+ root = etree .parse (
138+ f"{ miniscene_dir } /metadata/DJI_tracks.xml" ).getroot ()
139+ for track in root .iterfind ("track" ):
140+ track_id = int (track .get ("id" ))
141+ for box in track .iterfind ("box" ):
142+ row_val = [video_name , track_id , int (box .get ("frame" )), 98 ]
143+ self .assertEqual (list (df .loc [row_ct ]), row_val )
144+ row_ct += 1
145+ self .assertEqual (len (df .index ), row_ct )
133146
134147 @patch ('kabr_tools.miniscene2behavior.process_cv2_inputs' )
135148 @patch ('kabr_tools.miniscene2behavior.cv2.VideoCapture' )
@@ -151,9 +164,11 @@ def test_nonmatching_tracks(self, video_capture, process_cv2_inputs):
151164 # Create fake video capture
152165 vc = video_capture .return_value
153166 vc .read .return_value = True , np .zeros ((8 , 8 , 3 ), np .uint8 )
154- vc .get .return_value = 1
167+ vc .get .return_value = 21
155168
156169 self .output = '/tmp/annotation_data.csv'
170+ miniscene_dir = os .path .join (EXAMPLESDIR , "MINISCENE2" )
171+ video_name = "DJI"
157172
158173 annotate_miniscene (cfg = mock_config ,
159174 model = mock_model ,
@@ -162,11 +177,22 @@ def test_nonmatching_tracks(self, video_capture, process_cv2_inputs):
162177 video = 'DJI' ,
163178 output_path = self .output )
164179
165- # Read in output CSV and make sure we have the expected columns and at least one row
180+ # check output CSV
166181 df = pd .read_csv (self .output , sep = ' ' )
167182 self .assertEqual (list (df .columns ), [
168183 "video" , "track" , "frame" , "label" ])
169- self .assertGreater (len (df .index ), 0 )
184+ row_ct = 0
185+
186+ root = etree .parse (
187+ f"{ miniscene_dir } /metadata/DJI_tracks.xml" ).getroot ()
188+ for track in root .iterfind ("track" ):
189+ track_id = int (track .get ("id" ))
190+ for box in track .iterfind ("box" ):
191+ row_val = [video_name , track_id , int (box .get ("frame" )), 0 ]
192+ self .assertEqual (list (df .loc [row_ct ]), row_val )
193+ row_ct += 1
194+ self .assertEqual (len (df .index ), row_ct )
195+
170196
171197 def test_parse_arg_min (self ):
172198 # parse arguments
0 commit comments