Skip to content

Commit 02d8d26

Browse files
committed
reassemble control messages fragments
1 parent 1d88206 commit 02d8d26

10 files changed

+257
-140
lines changed

app/Main.hs

+93-17
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,106 @@
1+
{-# LANGUAGE LambdaCase #-}
2+
13
module Main (main) where
24

3-
import Control.Monad.Trans.Class (lift)
4-
import Data.Binary (get)
5-
import Data.Binary.Get (Decoder (..), runGetIncremental)
6-
import qualified Data.ByteString as B (null)
5+
import Common (FragmentationIndicator (..), consumeAll)
6+
import Control.Monad (when)
7+
import Data.Binary (Binary, Get, Word16, Word32, Word64, get)
8+
import Data.Binary.Get (runGetOrFail)
79
import qualified Data.ByteString.Lazy as L (readFile)
8-
import qualified Data.ByteString.Lazy.Internal as L (ByteString (Chunk, Empty), chunk)
10+
import Data.ByteString.Lazy.Internal (ByteString (Empty))
911
import Data.Int (Int64)
10-
import Lib (TLVPacket)
11-
import ListT (ListT, cons)
12-
import qualified ListT (head)
12+
import qualified Data.Map as Map
13+
import Lib
14+
import Message (ControlMessage, ControlMessages (..))
15+
import qualified Streamly.Data.Fold as Fold
16+
import Streamly.Data.Stream (Stream, mapMaybe, postscan, unfoldrM)
17+
import qualified Streamly.Data.Stream as S (foldr, foldrM, mapM, take, toList)
18+
import Streamly.Internal.Data.Pipe.Type (Pipe (Pipe), PipeState (..), Step (..))
19+
import Streamly.Internal.Data.Stream.StreamD.Transform (transform)
1320
import Text.Show.Pretty (pPrint)
1421

15-
parseTLVPackets :: MonadFail m => Decoder TLVPacket -> (Int64, L.ByteString) -> ListT m TLVPacket
16-
parseTLVPackets (Fail _ loc e) (br, _) = lift $ fail $ "error parse at " ++ show (br + loc) ++ ": " ++ e
17-
parseTLVPackets (Done r _ p) (_, L.Empty) | B.null r = return p
18-
parseTLVPackets (Done r o p) (br, input) = cons p $ parseTLVPackets (runGetIncremental get) (o + br, L.chunk r input)
19-
parseTLVPackets (Partial k) (br, L.Empty) = parseTLVPackets (k Nothing) (br, L.Empty)
20-
parseTLVPackets (Partial k) (br, L.Chunk bs input) = parseTLVPackets (k (Just bs)) (br, input)
22+
data ParseState = ParseState ByteString Int64
23+
24+
type CMState = Map.Map Word16 (Word32, ByteString)
25+
26+
type CMPipeState = PipeState CMState (CMState, [ByteString])
27+
28+
reassembleControlMessages :: Pipe IO (Word16, Word32, ControlMessages) ControlMessage
29+
reassembleControlMessages = Pipe consume produce Map.empty
30+
where
31+
produce :: (CMState, [ByteString]) -> IO (Step CMPipeState ControlMessage)
32+
produce (m, []) = return $ Continue $ Consume m
33+
produce (m, x : xs) = do
34+
p <- consumeAll get x
35+
return $ Yield p $ Produce (m, xs)
36+
37+
consume :: CMState -> (Word16, Word32, ControlMessages) -> IO (Step CMPipeState ControlMessage)
38+
consume m (pid, pseq, ControlMessages ind _ _ msgs) =
39+
case ind of
40+
FragmentationIndicatorUndivided -> case msgs of
41+
Left bs -> return $ Continue $ Produce (m, [bs])
42+
Right bs -> return $ Continue $ Produce (m, bs)
43+
FragmentationIndicatorDividedHead -> do
44+
when (Map.member pid m) $ fail $ "Packet ID " ++ show pid ++ " already started fragmented message"
45+
t <- extractOne msgs
46+
return $ Continue $ Consume $ Map.insert pid (pseq, t) m
47+
FragmentationIndicatorDividedBody -> do
48+
case Map.lookup pid m of
49+
Nothing -> fail $ "Packet ID " ++ show pid ++ " not started fragmented message"
50+
Just (s, bs) -> do
51+
when (s >= pseq) $ fail "Sequence number reversed"
52+
t <- extractOne msgs
53+
return $ Continue $ Consume $ Map.insert pid (pseq, bs <> t) m
54+
FragmentationIndicatorDividedEnd -> do
55+
case Map.lookup pid m of
56+
Nothing -> fail $ "Packet ID " ++ show pid ++ " not started fragmented message"
57+
Just (s, bs) -> do
58+
when (s >= pseq) $ fail "Sequence number reversed"
59+
t <- extractOne msgs
60+
return $ Continue $ Produce (Map.delete pid m, [bs <> t])
61+
where
62+
extractOne :: MonadFail m => Either ByteString [ByteString] -> m ByteString
63+
extractOne (Left bs) = return bs
64+
extractOne (Right _) = fail "more than one message"
65+
66+
parseTLVPackets :: ParseState -> IO (Maybe (TLVPacket, ParseState))
67+
parseTLVPackets (ParseState Empty _) = return Nothing
68+
parseTLVPackets (ParseState bs o) = do
69+
case runGetOrFail get bs of
70+
Left (_, off, err) -> fail $ "error parse at " ++ show (o + off) ++ ": " ++ err
71+
Right (r, off, p) -> return $ Just (p, ParseState r (o + off))
2172

2273
main :: IO ()
2374
main = do
2475
file <- L.readFile "F:\\29999.mmts"
25-
let packets = parseTLVPackets (runGetIncremental get) (0, file)
26-
len <- ListT.head packets
27-
pPrint len
76+
let packets = unfoldrM parseTLVPackets (ParseState file 0)
77+
headers =
78+
mapMaybe
79+
(\case TLVPacketHeaderCompressedIP c -> Just $ contextIdentificationHeader c; _ -> Nothing)
80+
packets
81+
payloads =
82+
mapMaybe
83+
(\case ContextIdentificationNoCompressedHeader h -> Just h; _ -> Nothing)
84+
headers
85+
fragments =
86+
mapMaybe
87+
( \case
88+
MMTPPacket
89+
{ packetId = pid,
90+
packetSequenceNumber = pseq,
91+
mmtpPayload = MMTPPayloadControlMessages m
92+
} -> Just (pid, pseq, m)
93+
_ -> Nothing
94+
)
95+
payloads
96+
messages = transform reassembleControlMessages fragments
97+
-- (r, o) <- S.foldrM reassembleControlMessages (pure (mempty, [])) fragments
98+
99+
_ <- S.toList $ S.mapM pPrint messages
100+
return ()
101+
102+
-- len <- S.foldr (\_ c -> c + 1) (0 :: Word64) messages
103+
-- print len
28104

29105
-- list <- toList $ parseTLVPackets (runGetIncremental get) file
30106
-- print list

package.yaml

+6-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ ghc-options:
3737

3838
library:
3939
source-dirs: src
40+
default-extensions:
41+
- BinaryLiterals
42+
- DuplicateRecordFields
43+
- NumericUnderscores
4044
dependencies:
4145
- extra
4246

@@ -50,9 +54,10 @@ executables:
5054
- -with-rtsopts=-N
5155
dependencies:
5256
- tlvmmt
53-
- list-t
57+
- streamly-core
5458
- transformers
5559
- pretty-show
60+
- containers
5661

5762
tests:
5863
tlvmmt-test:

src/Common.hs

+6-15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module Common where
22

3+
import Control.Monad.Extra (ifM)
34
import Data.Binary (Binary (..), Get, Word8)
45
import Data.Binary.Get
56
( Decoder (..),
@@ -9,6 +10,7 @@ import Data.Binary.Get
910
runGetIncremental,
1011
)
1112
import qualified Data.ByteString as B
13+
import Data.ByteString.Internal (w2c)
1214
import Data.ByteString.Lazy (ByteString)
1315

1416
data FragmentationIndicator
@@ -21,33 +23,22 @@ data FragmentationIndicator
2123
data ISO639LanguageCode = ISO639LanguageCode Word8 Word8 Word8
2224

2325
instance Show ISO639LanguageCode where
24-
show (ISO639LanguageCode a b c) =
25-
[ toEnum $ fromIntegral a,
26-
toEnum $ fromIntegral b,
27-
toEnum $ fromIntegral c
28-
]
26+
show (ISO639LanguageCode a b c) = [w2c a, w2c b, w2c c]
2927

3028
instance Binary ISO639LanguageCode where
3129
get = ISO639LanguageCode <$> get <*> get <*> get
3230
put (ISO639LanguageCode a b c) = put a >> put b >> put c
3331

34-
consumeAll :: Get a -> ByteString -> Get a
32+
consumeAll :: MonadFail m => Get a -> ByteString -> m a
3533
consumeAll g bs = do
3634
case pushEndOfInput $ runGetIncremental g `pushChunks` bs of
3735
Fail _ loc err -> fail $ "error at " ++ show loc ++ ": " ++ err
3836
Partial _ -> fail "not enough bytes"
3937
-- Done _ _ a -> return a
40-
Done r _ a -> if B.null r then return a else fail "unconsumed input"
38+
Done r _ a -> if B.null r then pure a else fail "unconsumed input"
4139

4240
repeatRead :: Get a -> Get [a]
43-
repeatRead g = do
44-
end <- isEmpty
45-
if end
46-
then return []
47-
else do
48-
a <- g
49-
as <- repeatRead g
50-
return $ a : as
41+
repeatRead g = ifM isEmpty (pure []) $ (:) <$> g <*> repeatRead g
5142

5243
readN :: Binary a => Word8 -> Get [a]
5344
readN 0 = return []

src/Descriptor.hs

+7-9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
{-# LANGUAGE BinaryLiterals #-}
2-
{-# LANGUAGE DuplicateRecordFields #-}
31
{-# OPTIONS_GHC -Wno-partial-fields #-}
42

53
module Descriptor where
@@ -34,7 +32,7 @@ data Descriptor
3432
instance Binary Descriptor where
3533
get = do
3634
tag <- lookAhead getWord16be
37-
case traceShow ("tag", tag) tag of
35+
case tag of
3836
0x0001 -> MPUTimestamp <$> get
3937
0x8010 -> VideoComponent <$> get
4038
0x8011 -> MHStreamID <$> get
@@ -93,12 +91,6 @@ instance Binary MHStreamIDDescriptor where
9391
MHStreamIDDescriptor <$> getWord16be
9492
put _ = undefined
9593

96-
verifyDescriptorHeader :: Word16 -> Get a -> Get a
97-
verifyDescriptorHeader tag parse = do
98-
t <- getWord16be
99-
when (t /= tag) $ fail "Invalid descriptor tag"
100-
getWord8 >>= getLazyByteString . fromIntegral >>= consumeAll parse
101-
10294
data MHAudioComponentDescriptor = MHAudioComponentDescriptor
10395
{ streamContent :: Word8,
10496
componentType :: Word8,
@@ -288,3 +280,9 @@ instance Binary MHDataComponentDescriptor where
288280
_ -> fail "Unknown MH-data component descriptor tag"
289281

290282
put = undefined
283+
284+
verifyDescriptorHeader :: Word16 -> Get a -> Get a
285+
verifyDescriptorHeader tag parse = do
286+
t <- getWord16be
287+
when (t /= tag) $ fail "Unexpected descriptor tag"
288+
getWord8 >>= getLazyByteString . fromIntegral >>= consumeAll parse

src/Lib.hs

+41-61
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
{-# LANGUAGE BinaryLiterals #-}
2-
{-# LANGUAGE DuplicateRecordFields #-}
3-
{-# LANGUAGE NumericUnderscores #-}
4-
5-
module Lib (TLVPacket) where
1+
module Lib where
62

73
import Common (FragmentationIndicator, consumeAll, repeatRead)
4+
import Control.Monad.Extra (whenMaybe)
85
import Data.Binary (Binary (..), Word16, Word32, Word8)
96
import Data.Binary.Get
107
( Get,
@@ -26,10 +23,10 @@ data MPUFragment
2623
deriving (Show)
2724

2825
data MFU
29-
= NonTimed MFUNonTimed
30-
| NonTimedAggregated [MFUNonTimed]
31-
| Timed MFUTimed
32-
| TimedAggregated [MFUTimed]
26+
= MFUNonTimedType MFUNonTimed
27+
| MFUNonTimedAggregatedType [MFUNonTimed]
28+
| MFUTimedType MFUTimed
29+
| MFUTimedAggregatedType [MFUTimed]
3330
deriving (Show)
3431

3532
data MFUTimed = MFUTimed
@@ -74,29 +71,21 @@ instance Binary MPU where
7471
get = do
7572
len <- getWord16be
7673
byte <- getWord8
77-
dnc <- getWord8
78-
mpuSeqNum <- getWord32be
79-
payload <- getLazyByteString $ fromIntegral $ len - 6
80-
81-
fragment <- case shiftR byte 4 of
82-
0x00 -> return MPUFragmentMPUMetadata
83-
0x01 -> return MPUFragmentMovieFragmentMetadata
84-
0x02 -> MPUFragmentMFU <$> consumeAll (parseMFU (testBit byte 3) (testBit byte 0)) payload
85-
_ -> fail "Invalid MPU fragmentation indicator"
8674

87-
return $
88-
MPU
89-
{ fragmentationIndicator = toEnum $ fromIntegral $ shiftR byte 1 .&. 0b11,
90-
divisionNumberCounter = dnc,
91-
mpuSequenceNumber = mpuSeqNum,
92-
mpuFragment = fragment
93-
}
75+
MPU (toEnum $ fromIntegral $ shiftR byte 1 .&. 0b11)
76+
<$> getWord8 <*> getWord32be <*> do
77+
payload <- getLazyByteString $ fromIntegral $ len - 6
78+
case shiftR byte 4 of
79+
0x00 -> return MPUFragmentMPUMetadata
80+
0x01 -> return MPUFragmentMovieFragmentMetadata
81+
0x02 -> MPUFragmentMFU <$> consumeAll (parseMFU (testBit byte 3) (testBit byte 0)) payload
82+
_ -> fail "Invalid MPU fragmentation indicator"
9483
where
9584
parseMFU :: Bool -> Bool -> Get MFU -- timed flag, aggregated flag
96-
parseMFU False False = NonTimed <$> get
97-
parseMFU False True = NonTimedAggregated <$> parseAggregatedNonTimedMFU
98-
parseMFU True False = Timed <$> get
99-
parseMFU True True = TimedAggregated <$> parseAggregatedTimedMFU
85+
parseMFU False False = MFUNonTimedType <$> get
86+
parseMFU False True = MFUNonTimedAggregatedType <$> parseAggregatedNonTimedMFU
87+
parseMFU True False = MFUTimedType <$> get
88+
parseMFU True True = MFUTimedAggregatedType <$> parseAggregatedTimedMFU
10089

10190
parseAggregatedNonTimedMFU :: Get [MFUNonTimed]
10291
parseAggregatedNonTimedMFU =
@@ -180,33 +169,24 @@ instance Binary MMTPPacket where
180169
get = do
181170
firstByte <- getWord8
182171
secondByte <- getWord8
183-
pid <- getWord16be
184-
dts <- getWord32be
185-
psn <- getWord32be
186-
187-
pktct <- if testBit firstByte 5 then Just <$> getWord32be else return Nothing
188-
extHdr <- if testBit firstByte 1 then Just <$> get else return Nothing
189-
190-
payload <- case secondByte .&. 0b11_1111 of
191-
0x00 -> MMTPPayloadMPU <$> get
192-
0x01 -> MMTPPayloadGenericObject <$> getRemainingLazyByteString
193-
0x02 -> MMTPPayloadControlMessages <$> get
194-
0x03 -> MMTPPayloadRepairSymbol <$> getRemainingLazyByteString
195-
typ | typ >= 0x04 && typ <= 0x1F -> MMTPPayloadReserved typ <$> getRemainingLazyByteString
196-
typ | typ >= 0x20 && typ <= 0x3F -> MMTPPayloadPrivate typ <$> getRemainingLazyByteString
197-
_ -> undefined -- unreachable
198-
return $
199-
MMTPPacket
200-
{ version = shiftR firstByte 6,
201-
fecType = shiftR firstByte 3 .&. 0b11,
202-
rapFlag = testBit firstByte 0,
203-
packetId = pid,
204-
deliveryTimestamp = dts,
205-
packetSequenceNumber = psn,
206-
packetCounter = pktct,
207-
extensionHeader = extHdr,
208-
mmtpPayload = payload
209-
}
172+
MMTPPacket
173+
(shiftR firstByte 6)
174+
(shiftR firstByte 3 .&. 0b11)
175+
(testBit firstByte 0)
176+
<$> getWord16be
177+
<*> getWord32be
178+
<*> getWord32be
179+
<*> whenMaybe (testBit firstByte 5) getWord32be
180+
<*> whenMaybe (testBit firstByte 1) get
181+
<*> ( case secondByte .&. 0b11_1111 of
182+
0x00 -> MMTPPayloadMPU <$> get
183+
0x01 -> MMTPPayloadGenericObject <$> getRemainingLazyByteString
184+
0x02 -> MMTPPayloadControlMessages <$> get
185+
0x03 -> MMTPPayloadRepairSymbol <$> getRemainingLazyByteString
186+
typ | typ >= 0x04 && typ <= 0x1F -> MMTPPayloadReserved typ <$> getRemainingLazyByteString
187+
typ | typ >= 0x20 && typ <= 0x3F -> MMTPPayloadPrivate typ <$> getRemainingLazyByteString
188+
_ -> undefined -- unreachable
189+
)
210190

211191
put = undefined
212192

@@ -227,13 +207,13 @@ data CompressedIPPacket = CompressedIPPacket
227207
instance Binary CompressedIPPacket where
228208
get = do
229209
firstTwo <- getWord16be
230-
header <- getWord8 >>= extractHeader
231-
return $ CompressedIPPacket (shiftR firstTwo 4) (fromIntegral (firstTwo .&. 0b1111)) header
210+
CompressedIPPacket (shiftR firstTwo 4) (fromIntegral (firstTwo .&. 0b1111))
211+
<$> (getWord8 >>= extractHeader)
232212
where
233213
extractHeader :: Word8 -> Get ContextIdentificationHeader
234-
extractHeader 0x20 = return ContextIdentificationHeaderPartialIPv4UDP
235-
extractHeader 0x21 = return ContextIdentificationHeaderIPv4Identifier
236-
extractHeader 0x60 = return ContextIdentificationHeaderPartialIPv6UDP
214+
extractHeader 0x20 = ContextIdentificationHeaderPartialIPv4UDP <$ getRemainingLazyByteString
215+
extractHeader 0x21 = ContextIdentificationHeaderIPv4Identifier <$ getRemainingLazyByteString
216+
extractHeader 0x60 = ContextIdentificationHeaderPartialIPv6UDP <$ getRemainingLazyByteString
237217
extractHeader 0x61 = ContextIdentificationNoCompressedHeader <$> get
238218
extractHeader _ = fail "unknown header"
239219

0 commit comments

Comments
 (0)