Skip to content

Commit 225b7f5

Browse files
authored
Include fixes to SRTP (#619)
* srtp: Fix roll over count calculation This brings in fixes for ROC over the last couple of years from PION * srtp: Fix packet length validation Ported from PION
1 parent ea8fb77 commit 225b7f5

File tree

9 files changed

+331
-101
lines changed

9 files changed

+331
-101
lines changed

srtp/src/cipher/cipher_aes_cm_hmac_sha1/ctrcipher.rs

+3
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,9 @@ impl Cipher for CipherAesCmHmacSha1 {
183183
}
184184

185185
let tail_offset = encrypted_len - (self.rtcp_auth_tag_len() + SRTCP_INDEX_SIZE);
186+
if tail_offset < 8 {
187+
return Err(Error::ErrTooShortRtcp);
188+
}
186189

187190
let mut writer = Vec::with_capacity(tail_offset);
188191

srtp/src/cipher/cipher_aes_cm_hmac_sha1/opensslcipher.rs

+3
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ impl Cipher for CipherAesCmHmacSha1 {
217217
}
218218

219219
let tail_offset = encrypted_len - (self.rtcp_auth_tag_len() + SRTCP_INDEX_SIZE);
220+
if tail_offset < 8 {
221+
return Err(Error::ErrTooShortRtcp);
222+
}
220223

221224
let mut writer = Vec::with_capacity(tail_offset);
222225

srtp/src/context/context_test.rs

+114-32
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ fn test_valid_packet_counter() -> Result<()> {
113113
0xcf, 0x90, 0x1e, 0xa5, 0xda, 0xd3, 0x2c, 0x15, 0x00, 0xa2, 0x24, 0xae, 0xae, 0xaf, 0x00,
114114
0x00,
115115
];
116-
let counter = generate_counter(32846, s.rollover_counter, s.ssrc, &srtp_session_salt);
116+
let counter = generate_counter(32846, (s.index >> 16) as _, s.ssrc, &srtp_session_salt);
117117
assert_eq!(
118118
counter, expected_counter,
119119
"Session Key {counter:?} does not match expected {expected_counter:?}",
@@ -124,15 +124,13 @@ fn test_valid_packet_counter() -> Result<()> {
124124

125125
#[test]
126126
fn test_rollover_count() -> Result<()> {
127-
let mut s = SrtpSsrcState {
128-
ssrc: DEFAULT_SSRC,
129-
..Default::default()
130-
};
127+
let mut s = SrtpSsrcState::default();
131128

132129
// Set initial seqnum
133-
let roc = s.next_rollover_count(65530);
130+
let (roc, diff, ovf) = s.next_rollover_count(65530);
134131
assert_eq!(roc, 0, "Initial rolloverCounter must be 0");
135-
s.update_rollover_count(65530);
132+
assert!(!ovf, "Should not overflow");
133+
s.update_rollover_count(65530, diff);
136134

137135
// Invalid packets never update ROC
138136
s.next_rollover_count(0);
@@ -142,64 +140,148 @@ fn test_rollover_count() -> Result<()> {
142140
s.next_rollover_count(0);
143141

144142
// We rolled over to 0
145-
let roc = s.next_rollover_count(0);
143+
let (roc, diff, ovf) = s.next_rollover_count(0);
146144
assert_eq!(roc, 1, "rolloverCounter was not updated after it crossed 0");
147-
s.update_rollover_count(0);
145+
assert!(!ovf, "Should not overflow");
146+
s.update_rollover_count(0, diff);
148147

149-
let roc = s.next_rollover_count(65530);
148+
let (roc, diff, ovf) = s.next_rollover_count(65530);
150149
assert_eq!(
151150
roc, 0,
152151
"rolloverCounter was not updated when it rolled back, failed to handle out of order"
153152
);
154-
s.update_rollover_count(65530);
153+
assert!(!ovf, "Should not overflow");
154+
s.update_rollover_count(65530, diff);
155155

156-
let roc = s.next_rollover_count(5);
156+
let (roc, diff, ovf) = s.next_rollover_count(5);
157157
assert_eq!(
158158
roc, 1,
159159
"rolloverCounter was not updated when it rolled over initial, to handle out of order"
160160
);
161-
s.update_rollover_count(5);
162-
163-
s.next_rollover_count(6);
164-
s.update_rollover_count(6);
165-
166-
s.next_rollover_count(7);
167-
s.update_rollover_count(7);
168-
169-
let roc = s.next_rollover_count(8);
161+
assert!(!ovf, "Should not overflow");
162+
s.update_rollover_count(5, diff);
163+
164+
let (_, diff, _) = s.next_rollover_count(6);
165+
s.update_rollover_count(6, diff);
166+
let (_, diff, _) = s.next_rollover_count(7);
167+
s.update_rollover_count(7, diff);
168+
let (roc, diff, _) = s.next_rollover_count(8);
170169
assert_eq!(
171170
roc, 1,
172171
"rolloverCounter was improperly updated for non-significant packets"
173172
);
174-
s.update_rollover_count(8);
173+
s.update_rollover_count(8, diff);
175174

176175
// valid packets never update ROC
177-
let roc = s.next_rollover_count(0x4000);
176+
let (roc, diff, ovf) = s.next_rollover_count(0x4000);
178177
assert_eq!(
179178
roc, 1,
180179
"rolloverCounter was improperly updated for non-significant packets"
181180
);
182-
s.update_rollover_count(0x4000);
183-
184-
let roc = s.next_rollover_count(0x8000);
181+
assert!(!ovf, "Should not overflow");
182+
s.update_rollover_count(0x4000, diff);
183+
let (roc, diff, ovf) = s.next_rollover_count(0x8000);
185184
assert_eq!(
186185
roc, 1,
187186
"rolloverCounter was improperly updated for non-significant packets"
188187
);
189-
s.update_rollover_count(0x8000);
190-
191-
let roc = s.next_rollover_count(0xFFFF);
188+
assert!(!ovf, "Should not overflow");
189+
s.update_rollover_count(0x8000, diff);
190+
let (roc, diff, ovf) = s.next_rollover_count(0xFFFF);
192191
assert_eq!(
193192
roc, 1,
194193
"rolloverCounter was improperly updated for non-significant packets"
195194
);
196-
s.update_rollover_count(0xFFFF);
197-
198-
let roc = s.next_rollover_count(0);
195+
assert!(!ovf, "Should not overflow");
196+
s.update_rollover_count(0xFFFF, diff);
197+
let (roc, _, ovf) = s.next_rollover_count(0);
199198
assert_eq!(
200199
roc, 2,
201200
"rolloverCounter must be incremented after wrapping, got {roc}"
202201
);
202+
assert!(!ovf, "Should not overflow");
203+
204+
Ok(())
205+
}
206+
207+
#[test]
208+
fn test_rollover_count_overflow() -> Result<()> {
209+
let mut s = SrtpSsrcState {
210+
index: (MAX_ROC as u64) << 16,
211+
..Default::default()
212+
};
213+
s.update_rollover_count(0xFFFF, 0);
214+
let (_, _, ovf) = s.next_rollover_count(0);
215+
assert!(ovf, "Should overflow");
216+
217+
Ok(())
218+
}
219+
220+
#[test]
221+
fn test_rollover_count_2() -> Result<()> {
222+
let mut s = SrtpSsrcState::default();
223+
224+
let (roc, diff, ovf) = s.next_rollover_count(30123);
225+
assert_eq!(roc, 0, "Initial rolloverCounter must be 0");
226+
assert!(!ovf, "Should not overflow");
227+
s.update_rollover_count(30123, diff);
228+
229+
// 62892 = 30123 + (1 << 15) + 1
230+
let (roc, diff, ovf) = s.next_rollover_count(62892);
231+
assert_eq!(roc, 0, "Initial rolloverCounter must be 0");
232+
assert!(!ovf, "Should not overflow");
233+
s.update_rollover_count(62892, diff);
234+
let (roc, diff, ovf) = s.next_rollover_count(204);
235+
assert_eq!(roc, 1, "rolloverCounter was not updated after it crossed 0");
236+
assert!(!ovf, "Should not overflow");
237+
s.update_rollover_count(62892, diff);
238+
let (roc, diff, ovf) = s.next_rollover_count(64535);
239+
assert_eq!(
240+
roc, 0,
241+
"rolloverCounter was not updated when it rolled back, failed to handle out of order"
242+
);
243+
assert!(!ovf, "Should not overflow");
244+
s.update_rollover_count(64535, diff);
245+
let (roc, diff, ovf) = s.next_rollover_count(205);
246+
assert_eq!(
247+
roc, 1,
248+
"rolloverCounter was improperly updated for non-significant packets"
249+
);
250+
assert!(!ovf, "Should not overflow");
251+
s.update_rollover_count(205, diff);
252+
let (roc, diff, ovf) = s.next_rollover_count(1);
253+
assert_eq!(
254+
roc, 1,
255+
"rolloverCounter was improperly updated for non-significant packets"
256+
);
257+
assert!(!ovf, "Should not overflow");
258+
s.update_rollover_count(1, diff);
259+
260+
let (roc, diff, ovf) = s.next_rollover_count(64532);
261+
assert_eq!(
262+
roc, 0,
263+
"rolloverCounter was improperly updated for non-significant packets"
264+
);
265+
assert!(!ovf, "Should not overflow");
266+
s.update_rollover_count(64532, diff);
267+
let (roc, diff, ovf) = s.next_rollover_count(64534);
268+
assert_eq!(
269+
roc, 0,
270+
"index was improperly updated for non-significant packets"
271+
);
272+
assert!(!ovf, "Should not overflow");
273+
s.update_rollover_count(64534, diff);
274+
let (roc, diff, ovf) = s.next_rollover_count(64532);
275+
assert_eq!(
276+
roc, 0,
277+
"index was improperly updated for non-significant packets"
278+
);
279+
assert!(!ovf, "Should not overflow");
280+
s.update_rollover_count(64532, diff);
281+
let (roc, diff, ovf) = s.next_rollover_count(205);
282+
assert_eq!(roc, 1, "index was not updated after it crossed 0");
283+
assert!(!ovf, "Should not overflow");
284+
s.update_rollover_count(205, diff);
203285

204286
Ok(())
205287
}

srtp/src/context/mod.rs

+46-53
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,16 @@ use crate::protection_profile::*;
1919
pub mod srtcp;
2020
pub mod srtp;
2121

22-
const MAX_ROC_DISORDER: u16 = 100;
22+
const MAX_ROC: u32 = u32::MAX;
23+
const SEQ_NUM_MEDIAN: u16 = 1 << 15;
24+
const SEQ_NUM_MAX: u16 = u16::MAX;
2325

2426
/// Encrypt/Decrypt state for a single SRTP SSRC
2527
#[derive(Default)]
2628
pub(crate) struct SrtpSsrcState {
2729
ssrc: u32,
28-
rollover_counter: u32,
30+
index: u64,
2931
rollover_has_processed: bool,
30-
last_sequence_number: u16,
3132
replay_detector: Option<Box<dyn ReplayDetector + Send + 'static>>,
3233
}
3334

@@ -40,61 +41,49 @@ pub(crate) struct SrtcpSsrcState {
4041
}
4142

4243
impl SrtpSsrcState {
43-
pub fn next_rollover_count(&self, sequence_number: u16) -> u32 {
44-
let mut roc = self.rollover_counter;
45-
46-
if !self.rollover_has_processed {
47-
} else if sequence_number == 0 {
48-
// We exactly hit the rollover count
49-
50-
// Only update rolloverCounter if lastSequenceNumber is greater then MAX_ROCDISORDER
51-
// otherwise we already incremented for disorder
52-
if self.last_sequence_number > MAX_ROC_DISORDER {
53-
roc += 1;
44+
pub fn next_rollover_count(&self, sequence_number: u16) -> (u32, i32, bool) {
45+
let local_roc = (self.index >> 16) as u32;
46+
let local_seq = self.index as u16;
47+
48+
let mut guess_roc = local_roc;
49+
50+
let diff = if self.rollover_has_processed {
51+
let seq = (sequence_number as i32).wrapping_sub(local_seq as i32);
52+
// When local_roc is equal to 0, and entering seq-local_seq > SEQ_NUM_MEDIAN
53+
// judgment, it will cause guess_roc calculation error
54+
if self.index > SEQ_NUM_MEDIAN as _ {
55+
if local_seq < SEQ_NUM_MEDIAN {
56+
if seq > SEQ_NUM_MEDIAN as i32 {
57+
guess_roc = local_roc.wrapping_sub(1);
58+
seq.wrapping_sub(SEQ_NUM_MAX as i32 + 1)
59+
} else {
60+
seq
61+
}
62+
} else if local_seq - SEQ_NUM_MEDIAN > sequence_number {
63+
guess_roc = local_roc.wrapping_add(1);
64+
seq.wrapping_add(SEQ_NUM_MAX as i32 + 1)
65+
} else {
66+
seq
67+
}
68+
} else {
69+
// local_roc is equal to 0
70+
seq
5471
}
55-
} else if self.last_sequence_number < MAX_ROC_DISORDER
56-
&& sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
57-
{
58-
// Our last sequence number incremented because we crossed 0, but then our current number was within MAX_ROCDISORDER of the max
59-
// So we fell behind, drop to account for jitter
60-
roc -= 1;
61-
} else if sequence_number < MAX_ROC_DISORDER
62-
&& self.last_sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
63-
{
64-
// our current is within a MAX_ROCDISORDER of 0
65-
// and our last sequence number was a high sequence number, increment to account for jitter
66-
roc += 1;
67-
}
72+
} else {
73+
0i32
74+
};
6875

69-
roc
76+
(guess_roc, diff, (guess_roc == 0 && local_roc == MAX_ROC))
7077
}
7178

7279
/// https://tools.ietf.org/html/rfc3550#appendix-A.1
73-
pub fn update_rollover_count(&mut self, sequence_number: u16) {
80+
pub fn update_rollover_count(&mut self, sequence_number: u16, diff: i32) {
7481
if !self.rollover_has_processed {
82+
self.index |= sequence_number as u64;
7583
self.rollover_has_processed = true;
76-
} else if sequence_number == 0 {
77-
// We exactly hit the rollover count
78-
79-
// Only update rolloverCounter if lastSequenceNumber is greater then MAX_ROCDISORDER
80-
// otherwise we already incremented for disorder
81-
if self.last_sequence_number > MAX_ROC_DISORDER {
82-
self.rollover_counter += 1;
83-
}
84-
} else if self.last_sequence_number < MAX_ROC_DISORDER
85-
&& sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
86-
{
87-
// Our last sequence number incremented because we crossed 0, but then our current number was within MAX_ROCDISORDER of the max
88-
// So we fell behind, drop to account for jitter
89-
self.rollover_counter -= 1;
90-
} else if sequence_number < MAX_ROC_DISORDER
91-
&& self.last_sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
92-
{
93-
// our current is within a MAX_ROCDISORDER of 0
94-
// and our last sequence number was a high sequence number, increment to account for jitter
95-
self.rollover_counter += 1;
84+
} else {
85+
self.index = self.index.wrapping_add(diff as _);
9686
}
97-
self.last_sequence_number = sequence_number;
9887
}
9988
}
10089

@@ -181,12 +170,16 @@ impl Context {
181170

182171
/// roc returns SRTP rollover counter value of specified SSRC.
183172
fn get_roc(&self, ssrc: u32) -> Option<u32> {
184-
self.srtp_ssrc_states.get(&ssrc).map(|s| s.rollover_counter)
173+
self.srtp_ssrc_states
174+
.get(&ssrc)
175+
.map(|s| (s.index >> 16) as _)
185176
}
186177

187178
/// set_roc sets SRTP rollover counter value of specified SSRC.
188179
fn set_roc(&mut self, ssrc: u32, roc: u32) {
189-
self.get_srtp_ssrc_state(ssrc).rollover_counter = roc;
180+
let state = self.get_srtp_ssrc_state(ssrc);
181+
state.index = (roc as u64) << 16;
182+
state.rollover_has_processed = false;
190183
}
191184

192185
/// index returns SRTCP index value of specified SSRC.
@@ -196,6 +189,6 @@ impl Context {
196189

197190
/// set_index sets SRTCP index value of specified SSRC.
198191
fn set_index(&mut self, ssrc: u32, index: usize) {
199-
self.get_srtcp_ssrc_state(ssrc).srtcp_index = index;
192+
self.get_srtcp_ssrc_state(ssrc).srtcp_index = index % (MAX_SRTCP_INDEX + 1);
200193
}
201194
}

srtp/src/context/srtcp.rs

+4
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ impl Context {
3131
/// EncryptRTCP marshals and encrypts an RTCP packet, writing to the dst buffer provided.
3232
/// If the dst buffer does not have the capacity to hold `len(plaintext) + 14` bytes, a new one will be allocated and returned.
3333
pub fn encrypt_rtcp(&mut self, decrypted: &[u8]) -> Result<Bytes> {
34+
if decrypted.len() < 8 {
35+
return Err(Error::ErrTooShortRtcp);
36+
}
37+
3438
let mut buf = decrypted;
3539
rtcp::header::Header::unmarshal(&mut buf)?;
3640

0 commit comments

Comments
 (0)