@@ -340,6 +340,36 @@ impl InboundGroupSession {
340340 Self :: try_from ( exported_session)
341341 }
342342
343+ /// Create a new [`InboundGroupSession`] which is a copy of this one, except
344+ /// that its Megolm ratchet is replaced with a copy of that from another
345+ /// [`InboundGroupSession`].
346+ ///
347+ /// This can be useful, for example, when we receive a new copy of the room
348+ /// key, but at an earlier ratchet index.
349+ ///
350+ /// # Panics
351+ ///
352+ /// If the two sessions are for different room IDs, or have different
353+ /// session IDs, this function will panic. It is up to the caller to ensure
354+ /// that it only attempts to merge related sessions.
355+ pub ( crate ) fn with_ratchet ( mut self , other : & InboundGroupSession ) -> Self {
356+ if self . session_id != other. session_id {
357+ panic ! (
358+ "Attempt to merge Megolm sessions with different session IDs: {} vs {}" ,
359+ self . session_id, other. session_id
360+ ) ;
361+ }
362+ if self . room_id != other. room_id {
363+ panic ! (
364+ "Attempt to merge Megolm sessions with different room IDs: {} vs {}" ,
365+ self . room_id, other. room_id,
366+ ) ;
367+ }
368+ self . inner = other. inner . clone ( ) ;
369+ self . first_known_index = other. first_known_index ;
370+ self
371+ }
372+
343373 /// Convert the [`InboundGroupSession`] into a
344374 /// [`PickledInboundGroupSession`] which can be serialized.
345375 pub async fn pickle ( & self ) -> PickledInboundGroupSession {
@@ -488,7 +518,34 @@ impl InboundGroupSession {
488518
489519 /// Check if the [`InboundGroupSession`] is better than the given other
490520 /// [`InboundGroupSession`]
521+ #[ deprecated(
522+ note = "Sessions cannot be compared on a linear scale. Consider calling `compare_ratchet`, as well as comparing the `sender_data`."
523+ ) ]
491524 pub async fn compare ( & self , other : & InboundGroupSession ) -> SessionOrdering {
525+ match self . compare_ratchet ( other) . await {
526+ SessionOrdering :: Equal => {
527+ match self . sender_data . compare_trust_level ( & other. sender_data ) {
528+ Ordering :: Less => SessionOrdering :: Worse ,
529+ Ordering :: Equal => SessionOrdering :: Equal ,
530+ Ordering :: Greater => SessionOrdering :: Better ,
531+ }
532+ }
533+ result => result,
534+ }
535+ }
536+
537+ /// Check if the [`InboundGroupSession`]'s ratchet index is better than that
538+ /// of the given other [`InboundGroupSession`].
539+ ///
540+ /// If the two sessions are not connected (i.e., they are from different
541+ /// senders, or if advancing the ratchets to the same index does not
542+ /// give the same ratchet value), returns [`SessionOrdering::Unconnected`].
543+ ///
544+ /// Otherwise, returns [`SessionOrdering::Equal`],
545+ /// [`SessionOrdering::Better`], or [`SessionOrdering::Worse`] respectively
546+ /// depending on whether this session's first known index is equal to,
547+ /// lower than, or higher than, that of `other`.
548+ pub async fn compare_ratchet ( & self , other : & InboundGroupSession ) -> SessionOrdering {
492549 // If this is the same object the ordering is the same, we can't compare because
493550 // we would deadlock while trying to acquire the same lock twice.
494551 if Arc :: ptr_eq ( & self . inner , & other. inner ) {
@@ -501,17 +558,7 @@ impl InboundGroupSession {
501558 SessionOrdering :: Unconnected
502559 } else {
503560 let mut other_inner = other. inner . lock ( ) . await ;
504-
505- match self . inner . lock ( ) . await . compare ( & mut other_inner) {
506- SessionOrdering :: Equal => {
507- match self . sender_data . compare_trust_level ( & other. sender_data ) {
508- Ordering :: Less => SessionOrdering :: Worse ,
509- Ordering :: Equal => SessionOrdering :: Equal ,
510- Ordering :: Greater => SessionOrdering :: Better ,
511- }
512- }
513- result => result,
514- }
561+ self . inner . lock ( ) . await . compare ( & mut other_inner)
515562 }
516563 }
517564
@@ -1057,6 +1104,7 @@ mod tests {
10571104 }
10581105
10591106 #[ async_test]
1107+ #[ allow( deprecated) ]
10601108 async fn test_session_comparison ( ) {
10611109 let alice = Account :: with_device_id ( alice_id ( ) , alice_device_id ( ) ) ;
10621110 let room_id = room_id ! ( "!test:localhost" ) ;
@@ -1067,18 +1115,24 @@ mod tests {
10671115 let mut copy = InboundGroupSession :: from_pickle ( inbound. pickle ( ) . await ) . unwrap ( ) ;
10681116
10691117 assert_eq ! ( inbound. compare( & worse) . await , SessionOrdering :: Better ) ;
1118+ assert_eq ! ( inbound. compare_ratchet( & worse) . await , SessionOrdering :: Better ) ;
10701119 assert_eq ! ( worse. compare( & inbound) . await , SessionOrdering :: Worse ) ;
1120+ assert_eq ! ( worse. compare_ratchet( & inbound) . await , SessionOrdering :: Worse ) ;
10711121 assert_eq ! ( inbound. compare( & inbound) . await , SessionOrdering :: Equal ) ;
1122+ assert_eq ! ( inbound. compare_ratchet( & inbound) . await , SessionOrdering :: Equal ) ;
10721123 assert_eq ! ( inbound. compare( & copy) . await , SessionOrdering :: Equal ) ;
1124+ assert_eq ! ( inbound. compare_ratchet( & copy) . await , SessionOrdering :: Equal ) ;
10731125
10741126 copy. creator_info . curve25519_key =
10751127 Curve25519PublicKey :: from_base64 ( "XbmrPa1kMwmdtNYng1B2gsfoo8UtF+NklzsTZiaVKyY" )
10761128 . unwrap ( ) ;
10771129
10781130 assert_eq ! ( inbound. compare( & copy) . await , SessionOrdering :: Unconnected ) ;
1131+ assert_eq ! ( inbound. compare_ratchet( & copy) . await , SessionOrdering :: Unconnected ) ;
10791132 }
10801133
10811134 #[ async_test]
1135+ #[ allow( deprecated) ]
10821136 async fn test_session_comparison_sender_data ( ) {
10831137 let alice = Account :: with_device_id ( alice_id ( ) , alice_device_id ( ) ) ;
10841138 let room_id = room_id ! ( "!test:localhost" ) ;
0 commit comments