Skip to content

Commit f16173f

Browse files
committed
Fix EndpointDescriptor
Fix creation of EndpointDescriptor and add unit tests for both possible lengths.
1 parent a983073 commit f16173f

File tree

1 file changed

+71
-10
lines changed

1 file changed

+71
-10
lines changed

usb_protocol/types/descriptors/standard.py

Lines changed: 71 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -128,22 +128,23 @@ class DeviceCapabilityTypes(IntEnum):
128128
)
129129

130130

131-
EndpointDescriptor = DescriptorFormat(
131+
EndpointDescriptorLength = construct.Rebuild(construct.Int8ul, 7 if (this.bRefresh is None) and (this.bSynchAddress is None) else 9)
132132

133+
EndpointDescriptor = DescriptorFormat(
133134
# [USB2.0: 9.6; USB Audio Device Class Definition 1.0: 4.6.1.1, 4.6.2.1]
134135
# Interfaces of the Audio 1.0 class extend their subordinate endpoint descriptors with
135136
# 2 additional bytes (extending it from 7 to 9 bytes). Thankfully, this is the only extension that
136137
# changes the length of a standard descriptor type, but we do have to handle this case in Construct.
137-
"bLength" / construct.Default(construct.OneOf(construct.Int8ul, [7, 9]), 7),
138+
"bLength" / EndpointDescriptorLength,
138139
"bDescriptorType" / DescriptorNumber(StandardDescriptorNumbers.ENDPOINT),
139140
"bEndpointAddress" / DescriptorField("Endpoint Address"),
140141
"bmAttributes" / DescriptorField("Attributes", default=2),
141142
"wMaxPacketSize" / DescriptorField("Maximum Packet Size", default=64),
142143
"bInterval" / DescriptorField("Polling interval", default=255),
143144

144145
# 2 bytes that are only present on endpoint descriptors for Audio 1.0 class interfaces.
145-
("bRefresh" / construct.Optional(construct.Int8ul)) * "Refresh Rate",
146-
("bSynchAddress" / construct.Optional(construct.Int8ul)) * "Synch Endpoint Address",
146+
("bRefresh" / construct.If(this.bLength == 9, construct.Int8ul)) * "Refresh Rate",
147+
("bSynchAddress" / construct.If(this.bLength == 9, construct.Int8ul)) * "Synch Endpoint Address",
147148
)
148149

149150

@@ -198,7 +199,6 @@ class DeviceCapabilityTypes(IntEnum):
198199
)
199200

200201

201-
202202
class DescriptorParserCases(unittest.TestCase):
203203

204204
STRING_DESCRIPTOR = bytes([
@@ -225,7 +225,6 @@ class DescriptorParserCases(unittest.TestCase):
225225
ord('s'), 0x00,
226226
])
227227

228-
229228
def test_string_descriptor_parse(self):
230229

231230
# Parse the relevant string...
@@ -236,23 +235,20 @@ def test_string_descriptor_parse(self):
236235
self.assertEqual(parsed.bDescriptorType, 3)
237236
self.assertEqual(parsed.bString, "Great Scott Gadgets")
238237

239-
240238
def test_string_descriptor_build(self):
241239
data = StringDescriptor.build({
242240
'bString': "Great Scott Gadgets"
243241
})
244242

245243
self.assertEqual(data, self.STRING_DESCRIPTOR)
246244

247-
248245
def test_string_language_descriptor_build(self):
249246
data = StringLanguageDescriptor.build({
250247
'wLANGID': (LanguageIDs.ENGLISH_US,)
251248
})
252249

253250
self.assertEqual(data, b"\x04\x03\x09\x04")
254251

255-
256252
def test_device_descriptor(self):
257253

258254
device_descriptor = [
@@ -291,7 +287,6 @@ def test_device_descriptor(self):
291287
self.assertEqual(parsed.iSerialNumber, 3)
292288
self.assertEqual(parsed.bNumConfigurations, 1)
293289

294-
295290
def test_bcd_constructor(self):
296291

297292
emitter = BCDFieldAdapter(construct.Int16ul)
@@ -300,5 +295,71 @@ def test_bcd_constructor(self):
300295
self.assertEqual(result, b"\x40\x01")
301296

302297

298+
def test_parse_endpoint_descriptor(self):
299+
# Parse the relevant descriptor ...
300+
parsed = EndpointDescriptor.parse([
301+
0x07, # Length
302+
0x05, # Type
303+
0x81, # Endpoint address
304+
0x02, # Attributes
305+
0x40, 0x00, # Maximum packet size
306+
0xFF, # Interval
307+
])
308+
309+
# ... and check the descriptor's fields.
310+
self.assertEqual(parsed.bLength, 7)
311+
self.assertEqual(parsed.bDescriptorType, StandardDescriptorNumbers.ENDPOINT)
312+
self.assertEqual(parsed.bEndpointAddress, 0x81)
313+
self.assertEqual(parsed.bmAttributes, 2)
314+
self.assertEqual(parsed.wMaxPacketSize, 64)
315+
self.assertEqual(parsed.bInterval, 255)
316+
317+
def test_parse_endpoint_descriptor_audio(self):
318+
# Parse the relevant descriptor ...
319+
parsed = EndpointDescriptor.parse([
320+
0x09, # Length
321+
0x05, # Type
322+
0x81, # Endpoint address
323+
0x02, # Attributes
324+
0x40, 0x00, # Maximum packet size
325+
0xFF, # Interval
326+
0x20, # Refresh rate
327+
0x05, # Synch endpoint address
328+
])
329+
330+
# ... and check the descriptor's fields.
331+
self.assertEqual(parsed.bLength, 9)
332+
self.assertEqual(parsed.bDescriptorType, StandardDescriptorNumbers.ENDPOINT)
333+
self.assertEqual(parsed.bEndpointAddress, 0x81)
334+
self.assertEqual(parsed.bmAttributes, 2)
335+
self.assertEqual(parsed.wMaxPacketSize, 64)
336+
self.assertEqual(parsed.bInterval, 255)
337+
self.assertEqual(parsed.bRefresh, 32)
338+
self.assertEqual(parsed.bSynchAddress, 0x05)
339+
340+
def test_build_endpoint_descriptor_audio(self):
341+
# Build the relevant descriptor
342+
data = EndpointDescriptor.build({
343+
'bEndpointAddress': 0x81,
344+
'bmAttributes': 2,
345+
'wMaxPacketSize': 64,
346+
'bInterval': 255,
347+
'bRefresh': 32,
348+
'bSynchAddress': 0x05,
349+
})
350+
351+
# ... and check the binary output
352+
self.assertEqual(data, bytes([
353+
0x09, # Length
354+
0x05, # Type
355+
0x81, # Endpoint address
356+
0x02, # Attributes
357+
0x40, 0x00, # Maximum packet size
358+
0xFF, # Interval
359+
0x20, # Refresh rate
360+
0x05, # Synch endpoint address
361+
]))
362+
363+
303364
if __name__ == "__main__":
304365
unittest.main()

0 commit comments

Comments
 (0)