@@ -128,22 +128,23 @@ class DeviceCapabilityTypes(IntEnum):
128128)
129129
130130
131- EndpointDescriptor = DescriptorFormat (
131+ EndpointDescriptorLength = construct . Rebuild ( construct . Int8ul , 7 if ( this . bRefresh is None ) and ( this . bSynchAddress is None ) else 9 )
132132
133+ EndpointDescriptor = DescriptorFormat (
133134 # [USB2.0: 9.6; USB Audio Device Class Definition 1.0: 4.6.1.1, 4.6.2.1]
134135 # Interfaces of the Audio 1.0 class extend their subordinate endpoint descriptors with
135136 # 2 additional bytes (extending it from 7 to 9 bytes). Thankfully, this is the only extension that
136137 # changes the length of a standard descriptor type, but we do have to handle this case in Construct.
137- "bLength" / construct . Default ( construct . OneOf ( construct . Int8ul , [ 7 , 9 ]), 7 ) ,
138+ "bLength" / EndpointDescriptorLength ,
138139 "bDescriptorType" / DescriptorNumber (StandardDescriptorNumbers .ENDPOINT ),
139140 "bEndpointAddress" / DescriptorField ("Endpoint Address" ),
140141 "bmAttributes" / DescriptorField ("Attributes" , default = 2 ),
141142 "wMaxPacketSize" / DescriptorField ("Maximum Packet Size" , default = 64 ),
142143 "bInterval" / DescriptorField ("Polling interval" , default = 255 ),
143144
144145 # 2 bytes that are only present on endpoint descriptors for Audio 1.0 class interfaces.
145- ("bRefresh" / construct .Optional ( construct .Int8ul )) * "Refresh Rate" ,
146- ("bSynchAddress" / construct .Optional ( construct .Int8ul )) * "Synch Endpoint Address" ,
146+ ("bRefresh" / construct .If ( this . bLength == 9 , construct .Int8ul )) * "Refresh Rate" ,
147+ ("bSynchAddress" / construct .If ( this . bLength == 9 , construct .Int8ul )) * "Synch Endpoint Address" ,
147148)
148149
149150
@@ -198,7 +199,6 @@ class DeviceCapabilityTypes(IntEnum):
198199)
199200
200201
201-
202202class DescriptorParserCases (unittest .TestCase ):
203203
204204 STRING_DESCRIPTOR = bytes ([
@@ -225,7 +225,6 @@ class DescriptorParserCases(unittest.TestCase):
225225 ord ('s' ), 0x00 ,
226226 ])
227227
228-
229228 def test_string_descriptor_parse (self ):
230229
231230 # Parse the relevant string...
@@ -236,23 +235,20 @@ def test_string_descriptor_parse(self):
236235 self .assertEqual (parsed .bDescriptorType , 3 )
237236 self .assertEqual (parsed .bString , "Great Scott Gadgets" )
238237
239-
240238 def test_string_descriptor_build (self ):
241239 data = StringDescriptor .build ({
242240 'bString' : "Great Scott Gadgets"
243241 })
244242
245243 self .assertEqual (data , self .STRING_DESCRIPTOR )
246244
247-
248245 def test_string_language_descriptor_build (self ):
249246 data = StringLanguageDescriptor .build ({
250247 'wLANGID' : (LanguageIDs .ENGLISH_US ,)
251248 })
252249
253250 self .assertEqual (data , b"\x04 \x03 \x09 \x04 " )
254251
255-
256252 def test_device_descriptor (self ):
257253
258254 device_descriptor = [
@@ -291,7 +287,6 @@ def test_device_descriptor(self):
291287 self .assertEqual (parsed .iSerialNumber , 3 )
292288 self .assertEqual (parsed .bNumConfigurations , 1 )
293289
294-
295290 def test_bcd_constructor (self ):
296291
297292 emitter = BCDFieldAdapter (construct .Int16ul )
@@ -300,5 +295,71 @@ def test_bcd_constructor(self):
300295 self .assertEqual (result , b"\x40 \x01 " )
301296
302297
298+ def test_parse_endpoint_descriptor (self ):
299+ # Parse the relevant descriptor ...
300+ parsed = EndpointDescriptor .parse ([
301+ 0x07 , # Length
302+ 0x05 , # Type
303+ 0x81 , # Endpoint address
304+ 0x02 , # Attributes
305+ 0x40 , 0x00 , # Maximum packet size
306+ 0xFF , # Interval
307+ ])
308+
309+ # ... and check the descriptor's fields.
310+ self .assertEqual (parsed .bLength , 7 )
311+ self .assertEqual (parsed .bDescriptorType , StandardDescriptorNumbers .ENDPOINT )
312+ self .assertEqual (parsed .bEndpointAddress , 0x81 )
313+ self .assertEqual (parsed .bmAttributes , 2 )
314+ self .assertEqual (parsed .wMaxPacketSize , 64 )
315+ self .assertEqual (parsed .bInterval , 255 )
316+
317+ def test_parse_endpoint_descriptor_audio (self ):
318+ # Parse the relevant descriptor ...
319+ parsed = EndpointDescriptor .parse ([
320+ 0x09 , # Length
321+ 0x05 , # Type
322+ 0x81 , # Endpoint address
323+ 0x02 , # Attributes
324+ 0x40 , 0x00 , # Maximum packet size
325+ 0xFF , # Interval
326+ 0x20 , # Refresh rate
327+ 0x05 , # Synch endpoint address
328+ ])
329+
330+ # ... and check the descriptor's fields.
331+ self .assertEqual (parsed .bLength , 9 )
332+ self .assertEqual (parsed .bDescriptorType , StandardDescriptorNumbers .ENDPOINT )
333+ self .assertEqual (parsed .bEndpointAddress , 0x81 )
334+ self .assertEqual (parsed .bmAttributes , 2 )
335+ self .assertEqual (parsed .wMaxPacketSize , 64 )
336+ self .assertEqual (parsed .bInterval , 255 )
337+ self .assertEqual (parsed .bRefresh , 32 )
338+ self .assertEqual (parsed .bSynchAddress , 0x05 )
339+
340+ def test_build_endpoint_descriptor_audio (self ):
341+ # Build the relevant descriptor
342+ data = EndpointDescriptor .build ({
343+ 'bEndpointAddress' : 0x81 ,
344+ 'bmAttributes' : 2 ,
345+ 'wMaxPacketSize' : 64 ,
346+ 'bInterval' : 255 ,
347+ 'bRefresh' : 32 ,
348+ 'bSynchAddress' : 0x05 ,
349+ })
350+
351+ # ... and check the binary output
352+ self .assertEqual (data , bytes ([
353+ 0x09 , # Length
354+ 0x05 , # Type
355+ 0x81 , # Endpoint address
356+ 0x02 , # Attributes
357+ 0x40 , 0x00 , # Maximum packet size
358+ 0xFF , # Interval
359+ 0x20 , # Refresh rate
360+ 0x05 , # Synch endpoint address
361+ ]))
362+
363+
303364if __name__ == "__main__" :
304365 unittest .main ()
0 commit comments