Skip to content

Commit 9c9c12a

Browse files
committed
Fix EndpointDescriptor
Fix creation of EndpointDescriptor and add unit tests for both possible lengths.
1 parent a983073 commit 9c9c12a

File tree

1 file changed

+90
-10
lines changed

1 file changed

+90
-10
lines changed

usb_protocol/types/descriptors/standard.py

Lines changed: 90 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -128,22 +128,23 @@ class DeviceCapabilityTypes(IntEnum):
128128
)
129129

130130

131-
EndpointDescriptor = DescriptorFormat(
131+
EndpointDescriptorLength = construct.Rebuild(construct.Int8ul, 7 if (this.bRefresh is None) and (this.bSynchAddress is None) else 9)
132132

133+
EndpointDescriptor = DescriptorFormat(
133134
# [USB2.0: 9.6; USB Audio Device Class Definition 1.0: 4.6.1.1, 4.6.2.1]
134135
# Interfaces of the Audio 1.0 class extend their subordinate endpoint descriptors with
135136
# 2 additional bytes (extending it from 7 to 9 bytes). Thankfully, this is the only extension that
136137
# changes the length of a standard descriptor type, but we do have to handle this case in Construct.
137-
"bLength" / construct.Default(construct.OneOf(construct.Int8ul, [7, 9]), 7),
138+
"bLength" / EndpointDescriptorLength,
138139
"bDescriptorType" / DescriptorNumber(StandardDescriptorNumbers.ENDPOINT),
139140
"bEndpointAddress" / DescriptorField("Endpoint Address"),
140141
"bmAttributes" / DescriptorField("Attributes", default=2),
141142
"wMaxPacketSize" / DescriptorField("Maximum Packet Size", default=64),
142143
"bInterval" / DescriptorField("Polling interval", default=255),
143144

144145
# 2 bytes that are only present on endpoint descriptors for Audio 1.0 class interfaces.
145-
("bRefresh" / construct.Optional(construct.Int8ul)) * "Refresh Rate",
146-
("bSynchAddress" / construct.Optional(construct.Int8ul)) * "Synch Endpoint Address",
146+
("bRefresh" / construct.If(this.bLength == 9, construct.Optional(construct.Int8ul))) * "Refresh Rate",
147+
("bSynchAddress" / construct.If(this.bLength == 9, construct.Optional(construct.Int8ul))) * "Synch Endpoint Address",
147148
)
148149

149150

@@ -198,7 +199,6 @@ class DeviceCapabilityTypes(IntEnum):
198199
)
199200

200201

201-
202202
class DescriptorParserCases(unittest.TestCase):
203203

204204
STRING_DESCRIPTOR = bytes([
@@ -225,7 +225,6 @@ class DescriptorParserCases(unittest.TestCase):
225225
ord('s'), 0x00,
226226
])
227227

228-
229228
def test_string_descriptor_parse(self):
230229

231230
# Parse the relevant string...
@@ -236,23 +235,20 @@ def test_string_descriptor_parse(self):
236235
self.assertEqual(parsed.bDescriptorType, 3)
237236
self.assertEqual(parsed.bString, "Great Scott Gadgets")
238237

239-
240238
def test_string_descriptor_build(self):
241239
data = StringDescriptor.build({
242240
'bString': "Great Scott Gadgets"
243241
})
244242

245243
self.assertEqual(data, self.STRING_DESCRIPTOR)
246244

247-
248245
def test_string_language_descriptor_build(self):
249246
data = StringLanguageDescriptor.build({
250247
'wLANGID': (LanguageIDs.ENGLISH_US,)
251248
})
252249

253250
self.assertEqual(data, b"\x04\x03\x09\x04")
254251

255-
256252
def test_device_descriptor(self):
257253

258254
device_descriptor = [
@@ -291,7 +287,6 @@ def test_device_descriptor(self):
291287
self.assertEqual(parsed.iSerialNumber, 3)
292288
self.assertEqual(parsed.bNumConfigurations, 1)
293289

294-
295290
def test_bcd_constructor(self):
296291

297292
emitter = BCDFieldAdapter(construct.Int16ul)
@@ -300,5 +295,90 @@ def test_bcd_constructor(self):
300295
self.assertEqual(result, b"\x40\x01")
301296

302297

298+
def test_parse_endpoint_descriptor(self):
299+
# Parse the relevant descriptor ...
300+
parsed = EndpointDescriptor.parse([
301+
0x07, # Length
302+
0x05, # Type
303+
0x81, # Endpoint address
304+
0x02, # Attributes
305+
0x40, 0x00, # Maximum packet size
306+
0xFF, # Interval
307+
])
308+
309+
# ... and check the descriptor's fields.
310+
self.assertEqual(parsed.bLength, 7)
311+
self.assertEqual(parsed.bDescriptorType, StandardDescriptorNumbers.ENDPOINT)
312+
self.assertEqual(parsed.bEndpointAddress, 0x81)
313+
self.assertEqual(parsed.bmAttributes, 2)
314+
self.assertEqual(parsed.wMaxPacketSize, 64)
315+
self.assertEqual(parsed.bInterval, 255)
316+
317+
def test_build_endpoint_descriptor(self):
318+
# Build the relevant descriptor
319+
data = EndpointDescriptor.build({
320+
'bEndpointAddress': 0x81,
321+
'bmAttributes': 2,
322+
'wMaxPacketSize': 64,
323+
'bInterval': 255,
324+
})
325+
326+
# ... and check the binary output
327+
self.assertEqual(data, bytes([
328+
0x09, # Length
329+
0x05, # Type
330+
0x81, # Endpoint address
331+
0x02, # Attributes
332+
0x40, 0x00, # Maximum packet size
333+
0xFF, # Interval
334+
]))
335+
336+
def test_parse_endpoint_descriptor_audio(self):
337+
# Parse the relevant descriptor ...
338+
parsed = EndpointDescriptor.parse([
339+
0x09, # Length
340+
0x05, # Type
341+
0x81, # Endpoint address
342+
0x02, # Attributes
343+
0x40, 0x00, # Maximum packet size
344+
0xFF, # Interval
345+
0x20, # Refresh rate
346+
0x05, # Synch endpoint address
347+
])
348+
349+
# ... and check the descriptor's fields.
350+
self.assertEqual(parsed.bLength, 9)
351+
self.assertEqual(parsed.bDescriptorType, StandardDescriptorNumbers.ENDPOINT)
352+
self.assertEqual(parsed.bEndpointAddress, 0x81)
353+
self.assertEqual(parsed.bmAttributes, 2)
354+
self.assertEqual(parsed.wMaxPacketSize, 64)
355+
self.assertEqual(parsed.bInterval, 255)
356+
self.assertEqual(parsed.bRefresh, 32)
357+
self.assertEqual(parsed.bSynchAddress, 0x05)
358+
359+
def test_build_endpoint_descriptor_audio(self):
360+
# Build the relevant descriptor
361+
data = EndpointDescriptor.build({
362+
'bEndpointAddress': 0x81,
363+
'bmAttributes': 2,
364+
'wMaxPacketSize': 64,
365+
'bInterval': 255,
366+
'bRefresh': 32,
367+
'bSynchAddress': 0x05,
368+
})
369+
370+
# ... and check the binary output
371+
self.assertEqual(data, bytes([
372+
0x09, # Length
373+
0x05, # Type
374+
0x81, # Endpoint address
375+
0x02, # Attributes
376+
0x40, 0x00, # Maximum packet size
377+
0xFF, # Interval
378+
0x20, # Refresh rate
379+
0x05, # Synch endpoint address
380+
]))
381+
382+
303383
if __name__ == "__main__":
304384
unittest.main()

0 commit comments

Comments
 (0)