-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrelabel_episodes.py
More file actions
99 lines (77 loc) · 3.82 KB
/
relabel_episodes.py
File metadata and controls
99 lines (77 loc) · 3.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
Relabel slip episodes to separate normal steps from actual slip events.

Uses feature thresholds to automatically detect when slip occurs within an
episode: inside each slip episode, samples before the detected onset are
relabeled as normal (0) and the onset and everything after it as slip (1).
"""
import pandas as pd
import argparse
import numpy as np
def _find_slip_start(rms_values, variance_values, rms_threshold, variance_threshold):
    """Return the position of the first sample that exceeds either threshold.

    A sample is flagged when ``rms > rms_threshold`` or
    ``variance > variance_threshold`` (strict comparisons; NaN never exceeds
    a threshold). Returns ``None`` when no sample is flagged.
    """
    exceeds = (np.asarray(rms_values, dtype=float) > rms_threshold) | (
        np.asarray(variance_values, dtype=float) > variance_threshold
    )
    if not exceeds.any():
        return None
    # argmax on a boolean array yields the index of the first True.
    return int(np.argmax(exceeds))


def relabel_slip_episodes(input_file, output_file, rms_threshold=0.1, variance_threshold=0.02):
    """
    Relabel slip episodes: normal steps before slip -> label 0, slip event -> label 1

    Only rows with ``episode_type == 'slip'`` and a non-null ``label`` are
    touched. Each such episode is ordered by ``timestamp``. The first sample
    where ``feature_1 > rms_threshold`` or ``feature_2 > variance_threshold``
    marks the slip onset: samples before it get label 0, the onset and every
    later sample get label 1. If no sample crosses a threshold, the episode
    is split at its midpoint instead (heuristic fallback).

    Args:
        input_file: Input CSV file. Reads the columns ``episode_id``,
            ``episode_type``, ``timestamp``, ``feature_1``, ``feature_2``
            and ``label``.
        output_file: Output CSV file. Always written, even when the input
            contains no slip episodes.
        rms_threshold: RMS value threshold to detect slip (feature_1)
        variance_threshold: Variance threshold to detect slip (feature_2)

    Returns:
        The relabeled DataFrame (the same data written to ``output_file``).
    """
    df = pd.read_csv(input_file)

    print(f"Relabeling episodes in {input_file}")
    print(f"Thresholds: RMS > {rms_threshold}, Variance > {variance_threshold}")
    print("="*60)

    # Only labeled rows belonging to slip episodes are candidates.
    slip_episodes = df[(df['episode_type'] == 'slip') & (df['label'].notna())].copy()

    if len(slip_episodes) == 0:
        print("No slip episodes found!")
        # Still write the (unchanged) data so callers always get the output file.
        df.to_csv(output_file, index=False)
        print(f"\nSaved to: {output_file}")
        return df

    # groupby(sort=False) visits episodes in order of first appearance but
    # partitions the rows in one pass instead of re-filtering per episode.
    episode_groups = slip_episodes.groupby('episode_id', sort=False)
    print(f"Found {episode_groups.ngroups} slip episodes")

    relabeled_count = 0
    for ep_id, ep_samples in episode_groups:
        # Stable sort keeps file order for rows sharing a timestamp.
        ep_samples = ep_samples.sort_values('timestamp', kind='stable')

        slip_start_idx = _find_slip_start(
            ep_samples['feature_1'], ep_samples['feature_2'],
            rms_threshold, variance_threshold,
        )
        if slip_start_idx is None:
            slip_start_idx = len(ep_samples) // 2
            print(f" Episode {ep_id}: No clear slip threshold, using middle point")

        before_slip = ep_samples.iloc[:slip_start_idx]
        after_slip = ep_samples.iloc[slip_start_idx:]

        # Count only rows whose label actually changes to 0; rows that were
        # already 0 are not "relabeled from slip".
        relabeled_count += int((before_slip['label'] != 0).sum())

        # ep_samples keeps df's index, so these writes hit the original rows.
        df.loc[before_slip.index, 'label'] = 0.0
        df.loc[after_slip.index, 'label'] = 1.0
        print(f" Episode {ep_id}: {len(before_slip)} samples -> normal (0), {len(after_slip)} samples -> slip (1)")

    df.to_csv(output_file, index=False)

    # NaN labels compare unequal to both 0 and 1, so they are excluded here.
    normal_count = int((df['label'] == 0).sum())
    slip_count = int((df['label'] == 1).sum())

    print(f"\n{'='*60}")
    print("Relabeling Summary:")
    print(f" Samples relabeled from slip->normal: {relabeled_count}")
    print(" Final counts:")
    print(f" Normal (0): {normal_count}")
    print(f" Slip (1): {slip_count}")
    print(f"\nSaved to: {output_file}")
    return df
def main():
    """Parse command-line options and run the slip-episode relabeling."""
    cli = argparse.ArgumentParser(
        description='Relabel slip episodes to separate normal steps from slip')

    # Both file paths are mandatory string options.
    for flag, help_text in (('--input', 'Input CSV file'),
                            ('--output', 'Output CSV file')):
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument('--rms-threshold', type=float, default=0.1,
                     help='RMS threshold for slip detection (default: 0.1)')
    cli.add_argument('--variance-threshold', type=float, default=0.02,
                     help='Variance threshold for slip detection (default: 0.02)')

    opts = cli.parse_args()
    relabel_slip_episodes(
        opts.input,
        opts.output,
        rms_threshold=opts.rms_threshold,
        variance_threshold=opts.variance_threshold,
    )
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == '__main__':
    main()