@@ -128,22 +128,23 @@ class DeviceCapabilityTypes(IntEnum):
128128)
129129
130130
131- EndpointDescriptor = DescriptorFormat (
131+ EndpointDescriptorLength = construct . Rebuild ( construct . Int8ul , 7 if ( this . bRefresh is None ) and ( this . bSynchAddress is None ) else 9 )
132132
133+ EndpointDescriptor = DescriptorFormat (
133134 # [USB2.0: 9.6; USB Audio Device Class Definition 1.0: 4.6.1.1, 4.6.2.1]
134135 # Interfaces of the Audio 1.0 class extend their subordinate endpoint descriptors with
135136 # 2 additional bytes (extending it from 7 to 9 bytes). Thankfully, this is the only extension that
136137 # changes the length of a standard descriptor type, but we do have to handle this case in Construct.
137- "bLength" / construct . Default ( construct . OneOf ( construct . Int8ul , [ 7 , 9 ]), 7 ) ,
138+ "bLength" / EndpointDescriptorLength ,
138139 "bDescriptorType" / DescriptorNumber (StandardDescriptorNumbers .ENDPOINT ),
139140 "bEndpointAddress" / DescriptorField ("Endpoint Address" ),
140141 "bmAttributes" / DescriptorField ("Attributes" , default = 2 ),
141142 "wMaxPacketSize" / DescriptorField ("Maximum Packet Size" , default = 64 ),
142143 "bInterval" / DescriptorField ("Polling interval" , default = 255 ),
143144
144145 # 2 bytes that are only present on endpoint descriptors for Audio 1.0 class interfaces.
145- ("bRefresh" / construct .Optional (construct .Int8ul )) * "Refresh Rate" ,
146- ("bSynchAddress" / construct .Optional (construct .Int8ul )) * "Synch Endpoint Address" ,
146+ ("bRefresh" / construct .If ( this . bLength == 9 , construct . Optional (construct .Int8ul ) )) * "Refresh Rate" ,
147+ ("bSynchAddress" / construct .If ( this . bLength == 9 , construct . Optional (construct .Int8ul ) )) * "Synch Endpoint Address" ,
147148)
148149
149150
@@ -198,7 +199,6 @@ class DeviceCapabilityTypes(IntEnum):
198199)
199200
200201
201-
202202class DescriptorParserCases (unittest .TestCase ):
203203
204204 STRING_DESCRIPTOR = bytes ([
@@ -225,7 +225,6 @@ class DescriptorParserCases(unittest.TestCase):
225225 ord ('s' ), 0x00 ,
226226 ])
227227
228-
229228 def test_string_descriptor_parse (self ):
230229
231230 # Parse the relevant string...
@@ -236,23 +235,20 @@ def test_string_descriptor_parse(self):
236235 self .assertEqual (parsed .bDescriptorType , 3 )
237236 self .assertEqual (parsed .bString , "Great Scott Gadgets" )
238237
239-
240238 def test_string_descriptor_build (self ):
241239 data = StringDescriptor .build ({
242240 'bString' : "Great Scott Gadgets"
243241 })
244242
245243 self .assertEqual (data , self .STRING_DESCRIPTOR )
246244
247-
248245 def test_string_language_descriptor_build (self ):
249246 data = StringLanguageDescriptor .build ({
250247 'wLANGID' : (LanguageIDs .ENGLISH_US ,)
251248 })
252249
253250 self .assertEqual (data , b"\x04 \x03 \x09 \x04 " )
254251
255-
256252 def test_device_descriptor (self ):
257253
258254 device_descriptor = [
@@ -291,7 +287,6 @@ def test_device_descriptor(self):
291287 self .assertEqual (parsed .iSerialNumber , 3 )
292288 self .assertEqual (parsed .bNumConfigurations , 1 )
293289
294-
295290 def test_bcd_constructor (self ):
296291
297292 emitter = BCDFieldAdapter (construct .Int16ul )
@@ -300,5 +295,90 @@ def test_bcd_constructor(self):
300295 self .assertEqual (result , b"\x40 \x01 " )
301296
302297
298+ def test_parse_endpoint_descriptor (self ):
299+ # Parse the relevant descriptor ...
300+ parsed = EndpointDescriptor .parse ([
301+ 0x07 , # Length
302+ 0x05 , # Type
303+ 0x81 , # Endpoint address
304+ 0x02 , # Attributes
305+ 0x40 , 0x00 , # Maximum packet size
306+ 0xFF , # Interval
307+ ])
308+
309+ # ... and check the descriptor's fields.
310+ self .assertEqual (parsed .bLength , 7 )
311+ self .assertEqual (parsed .bDescriptorType , StandardDescriptorNumbers .ENDPOINT )
312+ self .assertEqual (parsed .bEndpointAddress , 0x81 )
313+ self .assertEqual (parsed .bmAttributes , 2 )
314+ self .assertEqual (parsed .wMaxPacketSize , 64 )
315+ self .assertEqual (parsed .bInterval , 255 )
316+
317+ def test_build_endpoint_descriptor (self ):
318+ # Build the relevant descriptor
319+ data = EndpointDescriptor .build ({
320+ 'bEndpointAddress' : 0x81 ,
321+ 'bmAttributes' : 2 ,
322+ 'wMaxPacketSize' : 64 ,
323+ 'bInterval' : 255 ,
324+ })
325+
326+ # ... and check the binary output
327+ self .assertEqual (data , bytes ([
328+ 0x09 , # Length
329+ 0x05 , # Type
330+ 0x81 , # Endpoint address
331+ 0x02 , # Attributes
332+ 0x40 , 0x00 , # Maximum packet size
333+ 0xFF , # Interval
334+ ]))
335+
336+ def test_parse_endpoint_descriptor_audio (self ):
337+ # Parse the relevant descriptor ...
338+ parsed = EndpointDescriptor .parse ([
339+ 0x09 , # Length
340+ 0x05 , # Type
341+ 0x81 , # Endpoint address
342+ 0x02 , # Attributes
343+ 0x40 , 0x00 , # Maximum packet size
344+ 0xFF , # Interval
345+ 0x20 , # Refresh rate
346+ 0x05 , # Synch endpoint address
347+ ])
348+
349+ # ... and check the descriptor's fields.
350+ self .assertEqual (parsed .bLength , 9 )
351+ self .assertEqual (parsed .bDescriptorType , StandardDescriptorNumbers .ENDPOINT )
352+ self .assertEqual (parsed .bEndpointAddress , 0x81 )
353+ self .assertEqual (parsed .bmAttributes , 2 )
354+ self .assertEqual (parsed .wMaxPacketSize , 64 )
355+ self .assertEqual (parsed .bInterval , 255 )
356+ self .assertEqual (parsed .bRefresh , 32 )
357+ self .assertEqual (parsed .bSynchAddress , 0x05 )
358+
359+ def test_build_endpoint_descriptor_audio (self ):
360+ # Build the relevant descriptor
361+ data = EndpointDescriptor .build ({
362+ 'bEndpointAddress' : 0x81 ,
363+ 'bmAttributes' : 2 ,
364+ 'wMaxPacketSize' : 64 ,
365+ 'bInterval' : 255 ,
366+ 'bRefresh' : 32 ,
367+ 'bSynchAddress' : 0x05 ,
368+ })
369+
370+ # ... and check the binary output
371+ self .assertEqual (data , bytes ([
372+ 0x09 , # Length
373+ 0x05 , # Type
374+ 0x81 , # Endpoint address
375+ 0x02 , # Attributes
376+ 0x40 , 0x00 , # Maximum packet size
377+ 0xFF , # Interval
378+ 0x20 , # Refresh rate
379+ 0x05 , # Synch endpoint address
380+ ]))
381+
382+
303383if __name__ == "__main__" :
304384 unittest .main ()
0 commit comments