Skip to content

Commit babcc3b

Browse files
zhong-al and egrace479 authored
Find frames per track (#65)
* Find frames per track
* Account for track extraction
* Add check to miniscene2behavior
* Use index because track frames can be noncontiguous
* Apply suggestions from code review

Co-authored-by: Elizabeth Campolongo <[email protected]>

* Add tests

---------

Co-authored-by: Elizabeth Campolongo <[email protected]>
1 parent 8b2b110 commit babcc3b

File tree

2 files changed

+59
-24
lines changed

2 files changed

+59
-24
lines changed

src/kabr_tools/miniscene2behavior.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@ def get_input_clip(cap: cv2.VideoCapture, cfg: CfgNode, keyframe_idx: int) -> li
1919
# https://github.com/facebookresearch/SlowFast/blob/bac7b672f40d44166a84e8c51d1a5ba367ace816/slowfast/visualization/ava_demo_precomputed_boxes.py
2020
seq_length = cfg.DATA.NUM_FRAMES * cfg.DATA.SAMPLING_RATE
2121
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
22+
assert keyframe_idx < total_frames, f"keyframe_idx: {keyframe_idx}" \
23+
f" >= total_frames: {total_frames}"
2224
seq = get_sequence(
2325
keyframe_idx,
2426
seq_length // 2,
2527
cfg.DATA.SAMPLING_RATE,
2628
total_frames,
2729
)
30+
2831
clip = []
2932
for frame_idx in seq:
3033
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
@@ -124,29 +127,34 @@ def annotate_miniscene(cfg: CfgNode, model: torch.nn.Module,
124127

125128
# find all tracks
126129
tracks = []
130+
frames = {}
127131
for track in root.iterfind("track"):
128132
track_id = track.attrib["id"]
129133
tracks.append(track_id)
134+
frames[track_id] = []
130135

131-
# find all frames
132-
# TODO: rewrite - some tracks may have different frames
133-
assert len(tracks) > 0, "No tracks found in track file"
134-
frames = []
135-
for box in track.iterfind("box"):
136-
frames.append(int(box.attrib["frame"]))
136+
# find all frames
137+
for box in track.iterfind("box"):
138+
frames[track_id].append(int(box.attrib["frame"]))
137139

138140
# run model on miniscene
139141
for track in tracks:
140142
video_file = f"{miniscene_path}/{track}.mp4"
141143
cap = cv2.VideoCapture(video_file)
142-
for frame in tqdm(frames, desc=f"{track} frames"):
143-
inputs = get_input_clip(cap, cfg, frame)
144+
index = 0
145+
for frame in tqdm(frames[track], desc=f"{track} frames"):
146+
try:
147+
inputs = get_input_clip(cap, cfg, index)
148+
except AssertionError as e:
149+
print(e)
150+
break
151+
index += 1
144152

145153
if cfg.NUM_GPUS:
146154
# transfer the data to the current GPU device.
147155
if isinstance(inputs, (list,)):
148-
for i in range(len(inputs)):
149-
inputs[i] = inputs[i].cuda(non_blocking=True)
156+
for i, input_clip in enumerate(inputs):
157+
inputs[i] = input_clip.cuda(non_blocking=True)
150158
else:
151159
inputs = inputs.cuda(non_blocking=True)
152160

@@ -163,6 +171,7 @@ def annotate_miniscene(cfg: CfgNode, model: torch.nn.Module,
163171
if frame % 20 == 0:
164172
pd.DataFrame(label_data).to_csv(
165173
output_path, sep=" ", index=False)
174+
cap.release()
166175
pd.DataFrame(label_data).to_csv(output_path, sep=" ", index=False)
167176

168177

tests/test_miniscene2behavior.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import requests
66
from unittest.mock import Mock, patch
77
import torch
8+
from lxml import etree
89
import numpy as np
910
import pandas as pd
1011
from kabr_tools import (
@@ -97,12 +98,13 @@ def test_run(self):
9798
@patch('kabr_tools.miniscene2behavior.process_cv2_inputs')
9899
@patch('kabr_tools.miniscene2behavior.cv2.VideoCapture')
99100
def test_matching_tracks(self, video_capture, process_cv2_inputs):
100-
101-
# Create fake model that always returns a prediction of 1
101+
# create fake model that weights class 98
102102
mock_model = Mock()
103-
mock_model.return_value = torch.tensor([1])
103+
prob = torch.zeros(99)
104+
prob[-1] = 1
105+
mock_model.return_value = prob
104106

105-
# Create fake cfg
107+
# create fake cfg
106108
mock_config = Mock(
107109
DATA=Mock(NUM_FRAMES=16,
108110
SAMPLING_RATE=5,
@@ -111,25 +113,36 @@ def test_matching_tracks(self, video_capture, process_cv2_inputs):
111113
OUTPUT_DIR=''
112114
)
113115

114-
# Create fake video capture
116+
# create fake video capture
115117
vc = video_capture.return_value
116118
vc.read.return_value = True, np.zeros((8, 8, 3), np.uint8)
117-
vc.get.return_value = 1
119+
vc.get.return_value = 21
118120

119121
self.output = '/tmp/annotation_data.csv'
122+
miniscene_dir = os.path.join(EXAMPLESDIR, "MINISCENE1")
123+
video_name = "DJI"
120124

121125
annotate_miniscene(cfg=mock_config,
122126
model=mock_model,
123-
miniscene_path=os.path.join(
124-
EXAMPLESDIR, "MINISCENE1"),
125-
video='DJI',
127+
miniscene_path=miniscene_dir,
128+
video=video_name,
126129
output_path=self.output)
127130

128-
# Read in output CSV and make sure we have the expected columns and at least one row
131+
# check output CSV
129132
df = pd.read_csv(self.output, sep=' ')
130133
self.assertEqual(list(df.columns), [
131134
"video", "track", "frame", "label"])
132-
self.assertGreater(len(df.index), 0)
135+
row_ct = 0
136+
137+
root = etree.parse(
138+
f"{miniscene_dir}/metadata/DJI_tracks.xml").getroot()
139+
for track in root.iterfind("track"):
140+
track_id = int(track.get("id"))
141+
for box in track.iterfind("box"):
142+
row_val = [video_name, track_id, int(box.get("frame")), 98]
143+
self.assertEqual(list(df.loc[row_ct]), row_val)
144+
row_ct += 1
145+
self.assertEqual(len(df.index), row_ct)
133146

134147
@patch('kabr_tools.miniscene2behavior.process_cv2_inputs')
135148
@patch('kabr_tools.miniscene2behavior.cv2.VideoCapture')
@@ -151,9 +164,11 @@ def test_nonmatching_tracks(self, video_capture, process_cv2_inputs):
151164
# Create fake video capture
152165
vc = video_capture.return_value
153166
vc.read.return_value = True, np.zeros((8, 8, 3), np.uint8)
154-
vc.get.return_value = 1
167+
vc.get.return_value = 21
155168

156169
self.output = '/tmp/annotation_data.csv'
170+
miniscene_dir = os.path.join(EXAMPLESDIR, "MINISCENE2")
171+
video_name = "DJI"
157172

158173
annotate_miniscene(cfg=mock_config,
159174
model=mock_model,
@@ -162,11 +177,22 @@ def test_nonmatching_tracks(self, video_capture, process_cv2_inputs):
162177
video='DJI',
163178
output_path=self.output)
164179

165-
# Read in output CSV and make sure we have the expected columns and at least one row
180+
# check output CSV
166181
df = pd.read_csv(self.output, sep=' ')
167182
self.assertEqual(list(df.columns), [
168183
"video", "track", "frame", "label"])
169-
self.assertGreater(len(df.index), 0)
184+
row_ct = 0
185+
186+
root = etree.parse(
187+
f"{miniscene_dir}/metadata/DJI_tracks.xml").getroot()
188+
for track in root.iterfind("track"):
189+
track_id = int(track.get("id"))
190+
for box in track.iterfind("box"):
191+
row_val = [video_name, track_id, int(box.get("frame")), 0]
192+
self.assertEqual(list(df.loc[row_ct]), row_val)
193+
row_ct += 1
194+
self.assertEqual(len(df.index), row_ct)
195+
170196

171197
def test_parse_arg_min(self):
172198
# parse arguments

0 commit comments

Comments
 (0)