From 92002de3754d029aed448f68d12dc95cb5f39ac3 Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Wed, 27 Aug 2025 11:17:20 +0200 Subject: [PATCH 01/16] quickcheck-monoids: compatibility with QuickCheck-2.16 --- cabal.project | 43 +++++++++- nix/ouroboros-network.nix | 2 + .../ouroboros-network-framework.cabal | 16 ++-- .../Test/Ouroboros/Network/Server/Sim.hs | 3 + .../Network/ConnectionManager/Timeouts.hs | 5 ++ .../Network/ConnectionManager/Utils.hs | 5 ++ .../Network/InboundGovernor/Utils.hs | 5 ++ .../ouroboros-network-protocols.cabal | 2 +- .../Network/Protocol/LocalStateQuery/Test.hs | 9 +- ouroboros-network/ouroboros-network.cabal | 11 ++- .../Network/Diffusion/Testnet/Cardano.hs | 4 + .../Test/Ouroboros/Network/PeerSelection.hs | 1 + .../Network/PeerSelection/PeerMetric.hs | 4 + quickcheck-monoids/CHANGELOG.md | 20 +++++ quickcheck-monoids/quickcheck-monoids.cabal | 57 +++++++++++++ .../src/Test/QuickCheck/Monoids.hs | 85 +++++++++++++++++++ 16 files changed, 258 insertions(+), 14 deletions(-) create mode 100644 quickcheck-monoids/CHANGELOG.md create mode 100644 quickcheck-monoids/quickcheck-monoids.cabal create mode 100644 quickcheck-monoids/src/Test/QuickCheck/Monoids.hs diff --git a/cabal.project b/cabal.project index be0faae202..cc9cfcc52c 100644 --- a/cabal.project +++ b/cabal.project @@ -58,7 +58,48 @@ package ouroboros-network package acts flags: -finitary -allow-newer: quickcheck-instances:QuickCheck +source-repository-package + type: git + location: https://github.com/IntersectMBO/ouroboros-consensus + tag: 9433554b866a7135af5d4097082bf0481fa1a05d + --sha256: sha256-BNdV7fSgJoZR9Z5LafJKCbhAUwfAk2NTK68Yv/V0eBA= + subdir: + ouroboros-consensus-cardano + ouroboros-consensus-diffusion + sop-extras + ouroboros-consensus-protocol + ouroboros-consensus + +source-repository-package + type: git + location: https://github.com/IntersectMBO/cardano-ledger + tag: 20485948f78ab139d246695e540f9ec00963a16e + --sha256: sha256-SHnyp+GvNeR82UXoKeDEgsp1AUE2yF5dGL4HIZm0zK8= + subdir: + eras/allegra/impl + eras/alonzo/impl + eras/alonzo/test-suite + eras/babbage/impl + eras/babbage/test-suite + eras/byron/chain/executable-spec + eras/byron/crypto + eras/byron/ledger/executable-spec + eras/byron/ledger/impl + eras/conway/impl + eras/dijkstra + eras/mary/impl + eras/shelley/impl + eras/shelley-ma/test-suite + eras/shelley/test-suite + libs/cardano-data + libs/cardano-ledger-api + libs/cardano-ledger-binary + libs/cardano-ledger-core + libs/cardano-protocol-tpraos + libs/non-integral + libs/set-algebra + libs/small-steps + libs/vector-map -- kes-agent is not yet in CHaP, so we pull it from its GitHub repo source-repository-package diff --git a/nix/ouroboros-network.nix b/nix/ouroboros-network.nix index cdc2204a98..6ba7af746d 100644 --- a/nix/ouroboros-network.nix +++ b/nix/ouroboros-network.nix @@ -43,6 +43,8 @@ let # stdenv.hostPlatform.isWindows will work as expected src = ./..; name = "ouroboros-network"; + index-state = "2025-07-16T09:24:19Z"; + index-sha256 = "sha256-fmnSRF68/UIQYzzdmNs3UT0cbYhn9d5nlhb3BnVXe48="; compiler-nix-name = lib.mkDefault defaultCompiler; cabalProjectLocal = if pkgs.stdenv.hostPlatform.isWindows diff --git a/ouroboros-network-framework/ouroboros-network-framework.cabal b/ouroboros-network-framework/ouroboros-network-framework.cabal index d5eb85bf06..25a077814b 100644 --- a/ouroboros-network-framework/ouroboros-network-framework.cabal +++ b/ouroboros-network-framework/ouroboros-network-framework.cabal @@ -101,7 +101,6 @@ library -Widentities -Wredundant-constraints -Wno-unticked-promoted-constructors - -Wunused-packages library testlib visibility: public @@ -116,7 +115,7 @@ library testlib other-modules: build-depends: - QuickCheck >=2.16, + QuickCheck, base >=4.14 && <4.22, bytestring, cborg, @@ -129,6 +128,7 @@ library testlib ouroboros-network-api, ouroboros-network-framework, ouroboros-network-testing, + quickcheck-monoids, random, serialise, typed-protocols:{typed-protocols, examples}, @@ -144,7 +144,7 @@ library testlib -Widentities -Wredundant-constraints -Wno-unticked-promoted-constructors - -Wunused-packages + -Wno-unused-packages test-suite sim-tests type: exitcode-stdio-1.0 @@ -157,8 +157,11 @@ test-suite sim-tests Test.Ouroboros.Network.Server.Sim Test.Simulation.Network.Snocket + mixins: + QuickCheck hiding (Test.QuickCheck.Monoids) + build-depends: - QuickCheck >=2.16, + QuickCheck, base >=4.14 && <4.22, bytestring, cborg, @@ -175,6 +178,7 @@ test-suite sim-tests pretty-simple, psqueues, quickcheck-instances, + quickcheck-monoids, quiet, random, serialise, @@ -197,7 +201,7 @@ test-suite sim-tests -Widentities -Wredundant-constraints -Wno-unticked-promoted-constructors - -Wunused-packages + -Wno-unused-packages if flag(ipv6) cpp-options: -DOUROBOROS_NETWORK_IPV6 @@ -213,7 +217,7 @@ test-suite io-tests Test.Ouroboros.Network.Socket build-depends: - QuickCheck >=2.16, + QuickCheck, base >=4.14 && <4.22, bytestring, contra-tracer, diff --git a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server/Sim.hs b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server/Sim.hs index 939b520ea8..5d22674ced 100644 --- a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server/Sim.hs +++ b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server/Sim.hs @@ -64,6 +64,9 @@ import System.Random (StdGen, mkStdGen, split) import Text.Printf import Test.QuickCheck +#if !MIN_VERSION_QuickCheck(2,16,0) +import "quickcheck-monoids" Test.QuickCheck.Monoids +#endif import Test.Tasty (TestTree, testGroup) import Test.Tasty.QuickCheck diff --git a/ouroboros-network-framework/testlib/Test/Ouroboros/Network/ConnectionManager/Timeouts.hs b/ouroboros-network-framework/testlib/Test/Ouroboros/Network/ConnectionManager/Timeouts.hs index 22d07e6a33..d13a22a8da 100644 --- a/ouroboros-network-framework/testlib/Test/Ouroboros/Network/ConnectionManager/Timeouts.hs +++ b/ouroboros-network-framework/testlib/Test/Ouroboros/Network/ConnectionManager/Timeouts.hs @@ -1,5 +1,7 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE PackageImports #-} module Test.Ouroboros.Network.ConnectionManager.Timeouts ( verifyAllTimeouts @@ -39,6 +41,9 @@ import Data.Monoid (Sum (Sum)) import Text.Printf (printf) import Test.QuickCheck +#if !MIN_VERSION_QuickCheck(2,16,0) +import "quickcheck-monoids" Test.QuickCheck.Monoids +#endif import Ouroboros.Network.ConnectionHandler (ConnectionHandlerTrace) import Ouroboros.Network.ConnectionManager.Core qualified as CM diff --git a/ouroboros-network-framework/testlib/Test/Ouroboros/Network/ConnectionManager/Utils.hs b/ouroboros-network-framework/testlib/Test/Ouroboros/Network/ConnectionManager/Utils.hs index 573ce87ba8..3a02bcd677 100644 --- a/ouroboros-network-framework/testlib/Test/Ouroboros/Network/ConnectionManager/Utils.hs +++ b/ouroboros-network-framework/testlib/Test/Ouroboros/Network/ConnectionManager/Utils.hs @@ -1,4 +1,6 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE PackageImports #-} {-# LANGUAGE ScopedTypeVariables #-} module Test.Ouroboros.Network.ConnectionManager.Utils where @@ -10,6 +12,9 @@ import Ouroboros.Network.ConnectionManager.Core as CM import Ouroboros.Network.ConnectionManager.Types import Test.QuickCheck +#if !MIN_VERSION_QuickCheck(2,16,0) +import "quickcheck-monoids" Test.QuickCheck.Monoids +#endif verifyAbstractTransition :: AbstractTransition diff --git a/ouroboros-network-framework/testlib/Test/Ouroboros/Network/InboundGovernor/Utils.hs b/ouroboros-network-framework/testlib/Test/Ouroboros/Network/InboundGovernor/Utils.hs index 171674153d..57d2443448 100644 --- a/ouroboros-network-framework/testlib/Test/Ouroboros/Network/InboundGovernor/Utils.hs +++ b/ouroboros-network-framework/testlib/Test/Ouroboros/Network/InboundGovernor/Utils.hs @@ -1,11 +1,16 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE PackageImports #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ViewPatterns #-} module Test.Ouroboros.Network.InboundGovernor.Utils where import Test.QuickCheck +#if !MIN_VERSION_QuickCheck(2,16,0) +import "quickcheck-monoids" Test.QuickCheck.Monoids +#endif import Ouroboros.Network.ConnectionManager.Types import Ouroboros.Network.InboundGovernor (RemoteSt (..)) diff --git a/ouroboros-network-protocols/ouroboros-network-protocols.cabal b/ouroboros-network-protocols/ouroboros-network-protocols.cabal index a3e81f74ba..9f7898718e 100644 --- a/ouroboros-network-protocols/ouroboros-network-protocols.cabal +++ b/ouroboros-network-protocols/ouroboros-network-protocols.cabal @@ -219,7 +219,7 @@ test-suite test default-language: Haskell2010 default-extensions: ImportQualifiedPost build-depends: - QuickCheck ^>=2.16, + QuickCheck, base >=4.14 && <4.22, ouroboros-network-api, ouroboros-network-mock, diff --git a/ouroboros-network-protocols/testlib/Ouroboros/Network/Protocol/LocalStateQuery/Test.hs b/ouroboros-network-protocols/testlib/Ouroboros/Network/Protocol/LocalStateQuery/Test.hs index 727aadebf8..85bb0d94eb 100644 --- a/ouroboros-network-protocols/testlib/Ouroboros/Network/Protocol/LocalStateQuery/Test.hs +++ b/ouroboros-network-protocols/testlib/Ouroboros/Network/Protocol/LocalStateQuery/Test.hs @@ -45,7 +45,8 @@ import Ouroboros.Network.Mock.Chain (Point) import Ouroboros.Network.Mock.ConcreteBlock (Block) import Ouroboros.Network.Protocol.LocalStateQuery.Client -import Ouroboros.Network.Protocol.LocalStateQuery.Codec +import Ouroboros.Network.Protocol.LocalStateQuery.Codec hiding (Some (..)) +import Ouroboros.Network.Protocol.LocalStateQuery.Codec qualified as LocalStateQuery import Ouroboros.Network.Protocol.LocalStateQuery.Direct import Ouroboros.Network.Protocol.LocalStateQuery.Examples import Ouroboros.Network.Protocol.LocalStateQuery.Server @@ -54,7 +55,7 @@ import Ouroboros.Network.Protocol.LocalStateQuery.Type import Test.ChainGenerators () import Test.Ouroboros.Network.Protocol.Utils -import Test.QuickCheck as QC hiding (Result, Some (Some)) +import Test.QuickCheck as QC hiding (Result) import Test.Tasty (TestTree, testGroup) import Test.Tasty.QuickCheck (testProperty) import Text.Show.Functions () @@ -387,10 +388,10 @@ codec = encodeQuery :: Query result -> CBOR.Encoding encodeQuery GetTheLedgerState = Serialise.encode () - decodeQuery :: forall s . CBOR.Decoder s (Some Query) + decodeQuery :: forall s . CBOR.Decoder s (LocalStateQuery.Some Query) decodeQuery = do () <- Serialise.decode - return $ Some GetTheLedgerState + return $ LocalStateQuery.Some GetTheLedgerState encodeResult :: Query result -> result -> CBOR.Encoding encodeResult GetTheLedgerState = Serialise.encode diff --git a/ouroboros-network/ouroboros-network.cabal b/ouroboros-network/ouroboros-network.cabal index 7707272011..60bcadd1a0 100644 --- a/ouroboros-network/ouroboros-network.cabal +++ b/ouroboros-network/ouroboros-network.cabal @@ -305,13 +305,16 @@ library cardano-diffusion -- Simulation Test Library library testlib + mixins: + QuickCheck hiding (Test.QuickCheck.Monoids) + import: ghc-options-tests default-language: Haskell2010 default-extensions: ImportQualifiedPost visibility: public hs-source-dirs: testlib build-depends: - QuickCheck >=2.16, + QuickCheck, aeson, array, base >=4.14 && <4.22, @@ -343,6 +346,7 @@ library testlib pipes, pretty-simple, psqueues, + quickcheck-monoids, random, serialise, tasty, @@ -391,6 +395,9 @@ library testlib Test.Ouroboros.Network.TxSubmission.TxLogic Test.Ouroboros.Network.TxSubmission.Types + ghc-options: + -Wno-unused-packages + -- Simulation tests, and IO tests which don't require native system calls. -- (i.e. they don't require system call API provided by `Win32-network` or -- `network` dependency). test-suite sim-tests @@ -432,7 +439,7 @@ test-suite io-tests default-language: Haskell2010 default-extensions: ImportQualifiedPost build-depends: - QuickCheck >=2.16, + QuickCheck, base >=4.14 && <4.22, bytestring, contra-tracer, diff --git a/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Testnet/Cardano.hs b/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Testnet/Cardano.hs index f4b9e97b0e..f45d23764f 100644 --- a/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Testnet/Cardano.hs +++ b/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Testnet/Cardano.hs @@ -3,6 +3,7 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PackageImports #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} @@ -109,6 +110,9 @@ import Test.Ouroboros.Network.Utils hiding (SmallDelay, debugTracer) import Test.QuickCheck +#if !MIN_VERSION_QuickCheck(2,16,0) +import "quickcheck-monoids" Test.QuickCheck.Monoids +#endif import Test.Tasty import Test.Tasty.QuickCheck (testProperty) diff --git a/ouroboros-network/testlib/Test/Ouroboros/Network/PeerSelection.hs b/ouroboros-network/testlib/Test/Ouroboros/Network/PeerSelection.hs index 031453ff5e..6854099a9c 100644 --- a/ouroboros-network/testlib/Test/Ouroboros/Network/PeerSelection.hs +++ b/ouroboros-network/testlib/Test/Ouroboros/Network/PeerSelection.hs @@ -105,6 +105,7 @@ import Cardano.Network.Types (LedgerStateJudgement (..), NumberOfBigLedgerPeers (..)) import Ouroboros.Network.BlockFetch (FetchMode (..), PraosFetchMode (..)) import Test.QuickCheck +import Test.QuickCheck.Monoids import Test.Tasty import Test.Tasty.QuickCheck import Text.Pretty.Simple diff --git a/ouroboros-network/testlib/Test/Ouroboros/Network/PeerSelection/PeerMetric.hs b/ouroboros-network/testlib/Test/Ouroboros/Network/PeerSelection/PeerMetric.hs index 8da417642b..b5741fc93c 100644 --- a/ouroboros-network/testlib/Test/Ouroboros/Network/PeerSelection/PeerMetric.hs +++ b/ouroboros-network/testlib/Test/Ouroboros/Network/PeerSelection/PeerMetric.hs @@ -4,6 +4,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE PackageImports #-} {-# LANGUAGE ScopedTypeVariables #-} #if __GLASGOW_HASKELL__ >= 908 @@ -46,6 +47,9 @@ import NoThunks.Class import Test.Ouroboros.Network.Data.Script import Test.QuickCheck +#if !MIN_VERSION_QuickCheck(2,16,0) +import "quickcheck-monoids" Test.QuickCheck.Monoids +#endif import Test.Tasty (TestTree, testGroup) import Test.Tasty.QuickCheck (testProperty) diff --git a/quickcheck-monoids/CHANGELOG.md b/quickcheck-monoids/CHANGELOG.md new file mode 100644 index 0000000000..faa779ee3a --- /dev/null +++ b/quickcheck-monoids/CHANGELOG.md @@ -0,0 +1,20 @@ +# Revision history for quickcheck-monoids + +## 0.1.0.3 -- 2025-08-27 + +* Somewhat compatible with `QuickCheck-2.16`: `QuickCheck` is also defining + `Test.QuickCheck.Monoids` module. + +## 0.1.0.2 -- 2025-06-28 + +* Package is deprecated, use `QuickCheck >= 2.16` which provides `Every` and + `Some` monoids. + +## 0.1.0.1 -- 2024-08-07 + +* Make it build with ghc-9.10 + * fix base upper bound + +## 0.1.0.0 -- 2024-06-07 + +* First version. Released on an unsuspecting world. diff --git a/quickcheck-monoids/quickcheck-monoids.cabal b/quickcheck-monoids/quickcheck-monoids.cabal new file mode 100644 index 0000000000..d0e4f45495 --- /dev/null +++ b/quickcheck-monoids/quickcheck-monoids.cabal @@ -0,0 +1,57 @@ +cabal-version: 3.0 +name: quickcheck-monoids +version: 0.1.0.3 +synopsis: QuickCheck monoids +description: All and Any monoids for `Testable` instances based on `.&&.` and `.||.`. +license: Apache-2.0 +license-files: + LICENSE + NOTICE + +author: Marcin Szamotulski +maintainer: coot@coot.me +category: Testing +copyright: 2024 Input Output Global Inc (IOG) +build-type: Simple +extra-doc-files: CHANGELOG.md +extra-source-files: README.md + +common warnings + ghc-options: -Wall + +library + import: warnings + exposed-modules: Test.QuickCheck.Monoids + build-depends: + QuickCheck, + base <4.22, + + hs-source-dirs: src + default-language: Haskell2010 + ghc-options: + -Wall + -Wno-unticked-promoted-constructors + -Wcompat + -Wincomplete-uni-patterns + -Wincomplete-record-updates + -Wpartial-fields + -Widentities + -Wredundant-constraints + -Wunused-packages + +test-suite quickcheck-monoids-test + import: warnings + default-language: Haskell2010 + type: exitcode-stdio-1.0 + hs-source-dirs: test + main-is: Main.hs + build-depends: + QuickCheck, + base, + quickcheck-monoids, + tasty, + tasty-quickcheck, + + ghc-options: + -Wall + -rtsopts diff --git a/quickcheck-monoids/src/Test/QuickCheck/Monoids.hs b/quickcheck-monoids/src/Test/QuickCheck/Monoids.hs new file mode 100644 index 0000000000..fb6abf8b46 --- /dev/null +++ b/quickcheck-monoids/src/Test/QuickCheck/Monoids.hs @@ -0,0 +1,85 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE ExistentialQuantification #-} +{-# LANGUAGE ExplicitNamespaces #-} +{-# LANGUAGE PackageImports #-} +{-# LANGUAGE PatternSynonyms #-} + +-- | Monoids using `.&&.` and `.||.`. +-- +-- They satisfy monoid laws with respect to the `isSuccess` unless one is using +-- `checkCoverage` (see test for a counterexample). +-- +module Test.QuickCheck.Monoids +#if !MIN_VERSION_QuickCheck(2,16,0) + ( type Every + , All(Every, getEvery, ..) + , type Some + , Any(Some, getSome, ..) +#else + ( All (..) + , Any (..) + , Every (..) + , Some (..) +#endif + ) where + +import Data.List.NonEmpty as NonEmpty +import Data.Semigroup (Semigroup (..)) +import Test.QuickCheck + +-- | Conjunction monoid build with `.&&.`. +-- +-- Use `property @All` as an accessor which doesn't leak +-- existential variables. +-- +data All = forall p. Testable p => All { getAll :: p } + +#if !MIN_VERSION_QuickCheck(2,16,0) +type Every = All + +pattern Every :: () + => Testable p + => p + -> All +pattern Every { getEvery } = All getEvery +#endif + +instance Testable All where + property (All p) = property p + +instance Semigroup All where + All p <> All p' = All (p .&&. p') + sconcat = All . conjoin . NonEmpty.toList + +instance Monoid All where + mempty = All True + mconcat = All . conjoin + + +-- | Disjunction monoid build with `.||.`. +-- +-- Use `property @Any` as an accessor which doesn't leak +-- existential variables. +-- +data Any = forall p. Testable p => Any { getAny :: p } + +#if !MIN_VERSION_QuickCheck(2,16,0) +type Some = Any + +pattern Some :: () + => Testable p + => p + -> Any +pattern Some { getSome } = Any getSome +#endif + +instance Testable Any where + property (Any p) = property p + +instance Semigroup Any where + Any p <> Any p' = Any (p .||. p') + sconcat = Any . disjoin . NonEmpty.toList + +instance Monoid Any where + mempty = Any False + mconcat = Any . disjoin From 6f11935a51fb569c335100027a09994d5e566a00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20W=C3=B3jtowicz?= Date: Sun, 14 Sep 2025 12:56:30 +0200 Subject: [PATCH 02/16] cabal file fix --- ouroboros-network/ouroboros-network.cabal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ouroboros-network/ouroboros-network.cabal b/ouroboros-network/ouroboros-network.cabal index 60bcadd1a0..cd159fb515 100644 --- a/ouroboros-network/ouroboros-network.cabal +++ b/ouroboros-network/ouroboros-network.cabal @@ -305,10 +305,10 @@ library cardano-diffusion -- Simulation Test Library library testlib + import: ghc-options-tests mixins: QuickCheck hiding (Test.QuickCheck.Monoids) - import: ghc-options-tests default-language: Haskell2010 default-extensions: ImportQualifiedPost visibility: public From 0b8170b1b938e5720bcb6da50ff7ab0728cdccb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20W=C3=B3jtowicz?= Date: Thu, 11 Sep 2025 21:35:24 +0200 Subject: [PATCH 03/16] make it build --- cabal.project | 52 ++++++++--------------- ouroboros-network/ouroboros-network.cabal | 2 +- 2 files changed, 18 insertions(+), 36 deletions(-) diff --git a/cabal.project b/cabal.project index cc9cfcc52c..bdaeaa10f2 100644 --- a/cabal.project +++ b/cabal.project @@ -18,7 +18,7 @@ index-state: , hackage.haskell.org 2025-08-05T15:28:56Z -- Bump this if you need newer packages from CHaP - , cardano-haskell-packages 2025-03-18T17:41:11Z + , cardano-haskell-packages 2025-09-10T20:31:08Z packages: ./cardano-ping ./monoidal-synchronisation @@ -61,8 +61,8 @@ package acts source-repository-package type: git location: https://github.com/IntersectMBO/ouroboros-consensus - tag: 9433554b866a7135af5d4097082bf0481fa1a05d - --sha256: sha256-BNdV7fSgJoZR9Z5LafJKCbhAUwfAk2NTK68Yv/V0eBA= + tag: 6b71fb3f32e516613e1a05d402f52d60c2cb188d + --sha256: subdir: ouroboros-consensus-cardano ouroboros-consensus-diffusion @@ -70,41 +70,23 @@ source-repository-package ouroboros-consensus-protocol ouroboros-consensus +-- kes-agent is not yet in CHaP, so we pull it from its GitHub repo source-repository-package type: git - location: https://github.com/IntersectMBO/cardano-ledger - tag: 20485948f78ab139d246695e540f9ec00963a16e - --sha256: sha256-SHnyp+GvNeR82UXoKeDEgsp1AUE2yF5dGL4HIZm0zK8= + location: https://github.com/coot/kes-agent + tag: 2d41b37a9d199b3f987453594182c579977ad69c + --sha256: subdir: - eras/allegra/impl - eras/alonzo/impl - eras/alonzo/test-suite - eras/babbage/impl - eras/babbage/test-suite - eras/byron/chain/executable-spec - eras/byron/crypto - eras/byron/ledger/executable-spec - eras/byron/ledger/impl - eras/conway/impl - eras/dijkstra - eras/mary/impl - eras/shelley/impl - eras/shelley-ma/test-suite - eras/shelley/test-suite - libs/cardano-data - libs/cardano-ledger-api - libs/cardano-ledger-binary - libs/cardano-ledger-core - libs/cardano-protocol-tpraos - libs/non-integral - libs/set-algebra - libs/small-steps - libs/vector-map + kes-agent + kes-agent-crypto --- kes-agent is not yet in CHaP, so we pull it from its GitHub repo source-repository-package type: git - location: https://github.com/input-output-hk/kes-agent - tag: ebf8c0e480adf7b3ccd68bc7dd5b57f781f369ea - --sha256: sha256-QIb6qgcwtO7aB9PUhZTHyKw50GV3ViXOakQvnR3HFIY= - subdir: kes-agent-crypto + location: https://github.com/input-output-hk/typed-protocols + tag: 326733ac873588f366fa988701c87fc58bad87eb + --sha256: + subdir: + typed-protocols + +constraints: + QuickCheck < 2.16 diff --git a/ouroboros-network/ouroboros-network.cabal b/ouroboros-network/ouroboros-network.cabal index cd159fb515..856ec58dcc 100644 --- a/ouroboros-network/ouroboros-network.cabal +++ b/ouroboros-network/ouroboros-network.cabal @@ -314,7 +314,7 @@ library testlib visibility: public hs-source-dirs: testlib build-depends: - QuickCheck, + QuickCheck < 2.16, aeson, array, base >=4.14 && <4.22, From 26e85bbca1a01b46b2d854d5e4030b4aa59291ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20W=C3=B3jtowicz?= Date: Mon, 25 Aug 2025 11:05:24 +0200 Subject: [PATCH 04/16] Revert "Ensure that compared fragments always intersect" This reverts commit a5b5ba28ea58af9dd3dc695317c7562c25b2d4fb. --- .../Network/BlockFetch/ConsensusInterface.hs | 2 - .../Ouroboros/Network/BlockFetch/Decision.hs | 60 +++++-------------- .../Network/BlockFetch/Decision/Genesis.hs | 5 +- 3 files changed, 17 insertions(+), 50 deletions(-) diff --git a/ouroboros-network-api/src/Ouroboros/Network/BlockFetch/ConsensusInterface.hs b/ouroboros-network-api/src/Ouroboros/Network/BlockFetch/ConsensusInterface.hs index c0f5447c50..0a181dcc0b 100644 --- a/ouroboros-network-api/src/Ouroboros/Network/BlockFetch/ConsensusInterface.hs +++ b/ouroboros-network-api/src/Ouroboros/Network/BlockFetch/ConsensusInterface.hs @@ -201,8 +201,6 @@ data ChainComparison header = -- This is used as part of selecting which chains to prioritise for -- downloading block bodies. -- - -- PRECONDITION: The two fragments must intersect. - -- compareCandidateChains :: HasCallStack => AnchoredFragment header -> AnchoredFragment header diff --git a/ouroboros-network/src/Ouroboros/Network/BlockFetch/Decision.hs b/ouroboros-network/src/Ouroboros/Network/BlockFetch/Decision.hs index 67e7cbac1c..c242c6313f 100644 --- a/ouroboros-network/src/Ouroboros/Network/BlockFetch/Decision.hs +++ b/ouroboros-network/src/Ouroboros/Network/BlockFetch/Decision.hs @@ -33,7 +33,7 @@ import Data.Set qualified as Set import Data.Function (on) import Data.Hashable import Data.List as List (foldl', groupBy, sortBy, transpose) -import Data.Maybe (fromMaybe, mapMaybe) +import Data.Maybe (mapMaybe) import Data.Set (Set) import Control.Exception (assert) @@ -473,19 +473,8 @@ empty fetch range, but this is ok since we never request empty ranges. -- -- A 'ChainSuffix' must be non-empty, as an empty suffix, i.e. the candidate -- chain is equal to the current chain, would not be a plausible candidate. --- --- Additionally, we store the full candidate (with the same anchor as our --- current chain), as this is needed for comparing different candidates via --- 'compareCandidateChains'. -data ChainSuffix header = ChainSuffix { - -- | The suffix of the candidate after the intersection with the current - -- chain. - getChainSuffix :: !(AnchoredFragment header), - -- | The full candidate, characterized by having the same tip as - -- 'getChainSuffix' and the same anchor as our current chain. In particular, - -- 'getChainSuffix' is a suffix of 'getFullCandidate'. - getFullCandidate :: !(AnchoredFragment header) - } +newtype ChainSuffix header = + ChainSuffix { getChainSuffix :: AnchoredFragment header } {- We define the /chain suffix/ as the suffix of the candidate chain up until (but @@ -522,31 +511,25 @@ interested in this candidate at all. -- current chain. -- chainForkSuffix - :: HasHeader header - => AnchoredFragment header - -> AnchoredFragment header + :: (HasHeader header, HasHeader block, + HeaderHash header ~ HeaderHash block) + => AnchoredFragment block -- ^ Current chain. + -> AnchoredFragment header -- ^ Candidate chain -> Maybe (ChainSuffix header) chainForkSuffix current candidate = case AF.intersect current candidate of Nothing -> Nothing - Just (currentChainPrefix, _, _, candidateSuffix) -> + Just (_, _, _, candidateSuffix) -> -- If the suffix is empty, it means the candidate chain was equal to -- the current chain and didn't fork off. Such a candidate chain is -- not a plausible candidate, so it must have been filtered out. assert (not (AF.null candidateSuffix)) $ - Just ChainSuffix { - getChainSuffix = candidateSuffix, - getFullCandidate = fullCandidate - } - where - fullCandidate = - fromMaybe (error "invariant violation of AF.intersect") $ - AF.join currentChainPrefix candidateSuffix - + Just (ChainSuffix candidateSuffix) selectForkSuffixes - :: HasHeader header - => AnchoredFragment header + :: (HasHeader header, HasHeader block, + HeaderHash header ~ HeaderHash block) + => AnchoredFragment block -> [(FetchDecision (AnchoredFragment header), peerinfo)] -> [(FetchDecision (ChainSuffix header), peerinfo)] selectForkSuffixes current chains = @@ -760,11 +743,7 @@ prioritisePeerChains FetchModeDeadline salt compareCandidateChains blockFetchSiz (equatingPair -- compare on probability band first, then preferred chain (==) - -- Precondition of 'compareCandidateChains' (used by - -- 'equateCandidateChains') is fulfilled as all - -- 'getFullCandidate's intersect pairwise (due to having the - -- same anchor as our current chain). - (equateCandidateChains `on` getFullCandidate) + (equateCandidateChains `on` getChainSuffix) `on` (\(band, chain, _fragments) -> (band, chain))))) . sortBy (descendingOrder @@ -773,10 +752,7 @@ prioritisePeerChains FetchModeDeadline salt compareCandidateChains blockFetchSiz (comparingPair -- compare on probability band first, then preferred chain compare - -- Precondition of 'compareCandidateChains' is fulfilled as - -- all 'getFullCandidate's intersect pairwise (due to - -- having the same anchor as our current chain). - (compareCandidateChains `on` getFullCandidate) + (compareCandidateChains `on` getChainSuffix) `on` (\(band, chain, _fragments) -> (band, chain)))))) . map annotateProbabilityBand @@ -800,7 +776,7 @@ prioritisePeerChains FetchModeDeadline salt compareCandidateChains blockFetchSiz | EQ <- compareCandidateChains chain1 chain2 = True | otherwise = False - chainHeadPoint (_,ChainSuffix {getChainSuffix = c},_) = AF.headPoint c + chainHeadPoint (_,ChainSuffix c,_) = AF.headPoint c prioritisePeerChains FetchModeBulkSync salt compareCandidateChains blockFetchSize = map (\(decision, peer) -> @@ -809,11 +785,7 @@ prioritisePeerChains FetchModeBulkSync salt compareCandidateChains blockFetchSiz (comparingRight (comparingPair -- compare on preferred chain first, then duration - -- - -- Precondition of 'compareCandidateChains' is fulfilled as - -- all 'getFullCandidate's intersect pairwise (due to having - -- the same anchor as our current chain). - (compareCandidateChains `on` getFullCandidate) + (compareCandidateChains `on` getChainSuffix) compare `on` (\(duration, chain, _fragments) -> (chain, duration))))) diff --git a/ouroboros-network/src/Ouroboros/Network/BlockFetch/Decision/Genesis.hs b/ouroboros-network/src/Ouroboros/Network/BlockFetch/Decision/Genesis.hs index 4071a5d6fe..e0eb182427 100644 --- a/ouroboros-network/src/Ouroboros/Network/BlockFetch/Decision/Genesis.hs +++ b/ouroboros-network/src/Ouroboros/Network/BlockFetch/Decision/Genesis.hs @@ -462,16 +462,13 @@ selectTheCandidate case inRace of [] -> pure Nothing _ : _ -> do - -- Precondition of 'compareCandidateChains' is fulfilled as all - -- 'getFullCandidate's intersect pairwise (due to having the same - -- anchor as our current chain). let maxChainOn f c0 c1 = case compareCandidateChains (f c0) (f c1) of LT -> c1 _ -> c0 -- maximumBy yields the last element in case of a tie while we -- prefer the first one chainSfx = fst $ - List.foldl1' (maxChainOn (getFullCandidate . fst)) inRace + List.foldl1' (maxChainOn (getChainSuffix . fst)) inRace pure $ Just (chainSfx, inRace) -- | Given _the_ candidate fragment to sync from, and a list of peers (with From 7615c6f8bb13485e67966ca84a7aa642205b5bdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20W=C3=B3jtowicz?= Date: Mon, 25 Aug 2025 11:07:57 +0200 Subject: [PATCH 05/16] Revert "Support dynamic chain comparisons" This reverts commit 6b688c0926feed0065fcd4d5593f8ea75ce12590. --- ouroboros-network-api/CHANGELOG.md | 2 - .../Network/BlockFetch/ConsensusInterface.hs | 93 +++++-------------- ouroboros-network/CHANGELOG.md | 2 +- ouroboros-network/demo/chain-sync.hs | 9 +- .../src/Ouroboros/Network/BlockFetch.hs | 8 +- .../Ouroboros/Network/BlockFetch/Decision.hs | 35 ++++--- .../Network/BlockFetch/Decision/Genesis.hs | 15 +-- .../src/Ouroboros/Network/BlockFetch/State.hs | 34 +++---- .../Ouroboros/Network/BlockFetch/Examples.hs | 7 +- .../Test/Ouroboros/Network/Diffusion/Node.hs | 9 +- 10 files changed, 73 insertions(+), 141 deletions(-) diff --git a/ouroboros-network-api/CHANGELOG.md b/ouroboros-network-api/CHANGELOG.md index c3c052169a..2f90fd3ede 100644 --- a/ouroboros-network-api/CHANGELOG.md +++ b/ouroboros-network-api/CHANGELOG.md @@ -8,8 +8,6 @@ * Simplify type of `headerForgeUTCTime` in `BlockFetchConsensusInterface`, and remove the supporting type `FromConsensus`. -* Changed `BlockFetchConsensusInterface` to support dynamic (weighted) chain - comparisons. ### Non-breaking changes diff --git a/ouroboros-network-api/src/Ouroboros/Network/BlockFetch/ConsensusInterface.hs b/ouroboros-network-api/src/Ouroboros/Network/BlockFetch/ConsensusInterface.hs index 0a181dcc0b..f4e975c257 100644 --- a/ouroboros-network-api/src/Ouroboros/Network/BlockFetch/ConsensusInterface.hs +++ b/ouroboros-network-api/src/Ouroboros/Network/BlockFetch/ConsensusInterface.hs @@ -1,22 +1,15 @@ -{-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} module Ouroboros.Network.BlockFetch.ConsensusInterface ( PraosFetchMode (..) , FetchMode (..) , BlockFetchConsensusInterface (..) , ChainSelStarvation (..) - , ChainComparison (..) , mkReadFetchMode - -- * Utilities - , WithFingerprint (..) - , Fingerprint (..) - , initialWithFingerprint ) where import Control.Monad.Class.MonadSTM @@ -25,7 +18,6 @@ import Control.Monad.Class.MonadTime.SI (Time) import Data.Functor ((<&>)) import Data.Map.Strict (Map) -import Data.Word (Word64) import GHC.Generics (Generic) import GHC.Stack (HasCallStack) import NoThunks.Class (NoThunks) @@ -137,9 +129,24 @@ data BlockFetchConsensusInterface peer header block m = -- have been downloaded anyway. readFetchedMaxSlotNo :: STM m MaxSlotNo, - -- | Compare chain fragments. This might involve further state, such as - -- Peras certificates (which give certain blocks additional weight). - readChainComparison :: STM m (WithFingerprint (ChainComparison header)), + -- | Given the current chain, is the given chain plausible as a + -- candidate chain. Classically for Ouroboros this would simply + -- check if the candidate is strictly longer, but for Ouroboros + -- with operational key certificates there are also cases where + -- we would consider a chain of equal length to the current chain. + -- + plausibleCandidateChain :: HasCallStack + => AnchoredFragment header + -> AnchoredFragment header -> Bool, + + -- | Compare two candidate chains and return a preference ordering. + -- This is used as part of selecting which chains to prioritise for + -- downloading block bodies. + -- + compareCandidateChains :: HasCallStack + => AnchoredFragment header + -> AnchoredFragment header + -> Ordering, -- | Much of the logic for deciding which blocks to download from which -- peer depends on making estimates based on recent performance metrics. @@ -177,57 +184,3 @@ data ChainSelStarvation = ChainSelStarvationOngoing | ChainSelStarvationEndedAt Time deriving (Eq, Show, NoThunks, Generic) - - -data ChainComparison header = - ChainComparison { - -- | Given the current chain, is the given chain plausible as a candidate - -- chain. Classically for Ouroboros this would simply check if the - -- candidate is strictly longer, but it can also involve further - -- criteria: - -- - -- * Tiebreakers (e.g. based on the opcert numbers and VRFs) for chains - -- of equal length. - -- - -- * Weight in the context of Ouroboros Peras, due to a boost from a - -- Peras certificate. - -- - plausibleCandidateChain :: HasCallStack - => AnchoredFragment header - -> AnchoredFragment header - -> Bool, - - -- | Compare two candidate chains and return a preference ordering. - -- This is used as part of selecting which chains to prioritise for - -- downloading block bodies. - -- - compareCandidateChains :: HasCallStack - => AnchoredFragment header - -> AnchoredFragment header - -> Ordering - } - -{------------------------------------------------------------------------------- - Utilities --------------------------------------------------------------------------------} - --- | Simple type that can be used to indicate some value (without/only with an --- expensive 'Eq' instance) changed. -newtype Fingerprint = Fingerprint Word64 - deriving stock (Show, Eq, Generic) - deriving newtype (Enum) - deriving anyclass (NoThunks) - --- | Store a value together with its 'Fingerprint'. -data WithFingerprint a = WithFingerprint - { forgetFingerprint :: !a - , getFingerprint :: !Fingerprint - } - deriving stock (Show, Functor, Generic) - deriving anyclass (NoThunks) - --- | Attach @'Fingerprint' 0@ to the given value. When the underlying @a@ is --- changed, the 'Fingerprint' must be updated to a new unique value (e.g. via --- 'succ'). -initialWithFingerprint :: a -> WithFingerprint a -initialWithFingerprint a = WithFingerprint a (Fingerprint 0) diff --git a/ouroboros-network/CHANGELOG.md b/ouroboros-network/CHANGELOG.md index 0db7d20af7..f019378c14 100644 --- a/ouroboros-network/CHANGELOG.md +++ b/ouroboros-network/CHANGELOG.md @@ -11,7 +11,7 @@ * Adapt to simplified type of `headerForgeUTCTime` in `BlockFetchConsensusInterface`. * Type of `defaultSyncTargets` changed. * Type of `defaultPeerSharing` changed. -* Adapted to changes of `BlockFetchConsensusInterface`. +* (REVERTED temporarily) Adapted to changes of `BlockFetchConsensusInterface`. * `Ouroboros.Network.TxSubmission.Inbound` moved to `Ouroboros.Network.TxSubmission.Inbound.V1` * `Ouroboros.Network.TxSubmission.Inbound.V1.txSubmissionInbound` takes extra argument: `TxSubmissionInitDelay` (previously configurable through `cabal` flags). * Removed the `txsubmission-delay` cabal flag. diff --git a/ouroboros-network/demo/chain-sync.hs b/ouroboros-network/demo/chain-sync.hs index b8dc166134..1c0ffd05dd 100644 --- a/ouroboros-network/demo/chain-sync.hs +++ b/ouroboros-network/demo/chain-sync.hs @@ -75,8 +75,7 @@ import Ouroboros.Network.Protocol.BlockFetch.Type qualified as BlockFetch import Ouroboros.Network.BlockFetch import Ouroboros.Network.BlockFetch.Client import Ouroboros.Network.BlockFetch.ClientRegistry (FetchClientRegistry (..)) -import Ouroboros.Network.BlockFetch.ConsensusInterface (ChainSelStarvation (..), - initialWithFingerprint) +import Ouroboros.Network.BlockFetch.ConsensusInterface (ChainSelStarvation (..)) import Ouroboros.Network.DeltaQ (defaultGSV) import Ouroboros.Network.Server.Simple qualified as Server.Simple @@ -434,10 +433,8 @@ clientBlockFetch sockAddrs maxSlotNo = withIOManager $ \iocp -> do pure $ \p b -> addTestFetchedBlock blockHeap (castPoint p) (blockHeader b), - readChainComparison = pure $ initialWithFingerprint ChainComparison { - plausibleCandidateChain, - compareCandidateChains - }, + plausibleCandidateChain, + compareCandidateChains, blockFetchSize = \_ -> 1000, blockMatchesHeader = \_ _ -> True, diff --git a/ouroboros-network/src/Ouroboros/Network/BlockFetch.hs b/ouroboros-network/src/Ouroboros/Network/BlockFetch.hs index adfd52df18..8cd0d79bf4 100644 --- a/ouroboros-network/src/Ouroboros/Network/BlockFetch.hs +++ b/ouroboros-network/src/Ouroboros/Network/BlockFetch.hs @@ -99,7 +99,6 @@ module Ouroboros.Network.BlockFetch -- * Re-export types used by 'BlockFetchConsensusInterface' , PraosFetchMode (..) , FetchMode (..) - , ChainComparison (..) , SizeInBytes ) where @@ -122,7 +121,7 @@ import Ouroboros.Network.BlockFetch.ClientRegistry (FetchClientPolicy (..), readFetchClientsStateVars, readFetchClientsStatus, readPeerGSVs, setFetchClientContext) import Ouroboros.Network.BlockFetch.ConsensusInterface - (BlockFetchConsensusInterface (..), ChainComparison (..)) + (BlockFetchConsensusInterface (..)) import Ouroboros.Network.BlockFetch.Decision.Trace (TraceDecisionEvent) import Ouroboros.Network.BlockFetch.State @@ -222,6 +221,8 @@ blockFetchLogic decisionTracer clientStateTracer peerSalt = bfcSalt, bulkSyncGracePeriod = gbfcGracePeriod bfcGenesisBFConfig, + plausibleCandidateChain, + compareCandidateChains, blockFetchSize } @@ -230,8 +231,7 @@ blockFetchLogic decisionTracer clientStateTracer FetchTriggerVariables { readStateCurrentChain = readCurrentChain, readStateCandidateChains = readCandidateChains, - readStatePeerStatus = readFetchClientsStatus registry, - readStateChainComparison = readChainComparison + readStatePeerStatus = readFetchClientsStatus registry } fetchNonTriggerVariables :: FetchNonTriggerVariables addr header block m diff --git a/ouroboros-network/src/Ouroboros/Network/BlockFetch/Decision.hs b/ouroboros-network/src/Ouroboros/Network/BlockFetch/Decision.hs index c242c6313f..5f086f097f 100644 --- a/ouroboros-network/src/Ouroboros/Network/BlockFetch/Decision.hs +++ b/ouroboros-network/src/Ouroboros/Network/BlockFetch/Decision.hs @@ -35,6 +35,7 @@ import Data.Hashable import Data.List as List (foldl', groupBy, sortBy, transpose) import Data.Maybe (mapMaybe) import Data.Set (Set) +import GHC.Stack (HasCallStack) import Control.Exception (assert) import Control.Monad (guard) @@ -47,8 +48,8 @@ import Ouroboros.Network.Point (withOriginToMaybe) import Ouroboros.Network.BlockFetch.ClientState (FetchRequest (..), PeerFetchInFlight (..), PeerFetchStatus (..)) -import Ouroboros.Network.BlockFetch.ConsensusInterface (ChainComparison (..), - FetchMode (..), PraosFetchMode (..)) +import Ouroboros.Network.BlockFetch.ConsensusInterface (FetchMode (..), + PraosFetchMode (..)) import Ouroboros.Network.BlockFetch.DeltaQ (PeerFetchInFlightLimits (..), PeerGSV (..), SizeInBytes, calculatePeerFetchInFlightLimits, comparePeerGSV, comparePeerGSV', estimateExpectedResponseDuration, @@ -56,16 +57,25 @@ import Ouroboros.Network.BlockFetch.DeltaQ (PeerFetchInFlightLimits (..), data FetchDecisionPolicy header = FetchDecisionPolicy { - maxInFlightReqsPerPeer :: Word, -- A protocol constant. + maxInFlightReqsPerPeer :: Word, -- A protocol constant. - maxConcurrencyBulkSync :: Word, - maxConcurrencyDeadline :: Word, + maxConcurrencyBulkSync :: Word, + maxConcurrencyDeadline :: Word, decisionLoopIntervalGenesis :: DiffTime, - decisionLoopIntervalPraos :: DiffTime, - peerSalt :: Int, - bulkSyncGracePeriod :: DiffTime, + decisionLoopIntervalPraos :: DiffTime, + peerSalt :: Int, + bulkSyncGracePeriod :: DiffTime, - blockFetchSize :: header -> SizeInBytes + plausibleCandidateChain :: HasCallStack + => AnchoredFragment header + -> AnchoredFragment header -> Bool, + + compareCandidateChains :: HasCallStack + => AnchoredFragment header + -> AnchoredFragment header + -> Ordering, + + blockFetchSize :: header -> SizeInBytes } @@ -254,7 +264,6 @@ fetchDecisions HasHeader header, HeaderHash header ~ HeaderHash block) => FetchDecisionPolicy header - -> ChainComparison header -> PraosFetchMode -> AnchoredFragment header -> (Point block -> Bool) @@ -262,13 +271,11 @@ fetchDecisions -> [(AnchoredFragment header, PeerInfo header peer extra)] -> [(FetchDecision (FetchRequest header), PeerInfo header peer extra)] fetchDecisions fetchDecisionPolicy@FetchDecisionPolicy { + plausibleCandidateChain, + compareCandidateChains, blockFetchSize, peerSalt } - ChainComparison { - plausibleCandidateChain, - compareCandidateChains - } fetchMode currentChain fetchedBlocks diff --git a/ouroboros-network/src/Ouroboros/Network/BlockFetch/Decision/Genesis.hs b/ouroboros-network/src/Ouroboros/Network/BlockFetch/Decision/Genesis.hs index e0eb182427..db1aa588a2 100644 --- a/ouroboros-network/src/Ouroboros/Network/BlockFetch/Decision/Genesis.hs +++ b/ouroboros-network/src/Ouroboros/Network/BlockFetch/Decision/Genesis.hs @@ -146,8 +146,8 @@ import Ouroboros.Network.AnchoredFragment qualified as AF import Ouroboros.Network.Block import Ouroboros.Network.BlockFetch.ClientState (FetchRequest (..), PeerFetchInFlight (..), PeersOrder (..)) -import Ouroboros.Network.BlockFetch.ConsensusInterface (ChainComparison(..), - ChainSelStarvation (..), FetchMode (..)) +import Ouroboros.Network.BlockFetch.ConsensusInterface (ChainSelStarvation (..), + FetchMode (..)) import Ouroboros.Network.BlockFetch.DeltaQ (calculatePeerFetchInFlightLimits) import Cardano.Slotting.Slot (WithOrigin) @@ -167,7 +167,6 @@ fetchDecisionsGenesisM HeaderHash header ~ HeaderHash block, MonadMonotonicTime m) => Tracer m (TraceDecisionEvent peer header) -> FetchDecisionPolicy header - -> ChainComparison header -> AnchoredFragment header -> (Point block -> Bool) -- ^ Whether the block has been fetched (only if recent, i.e. within @k@). @@ -182,7 +181,6 @@ fetchDecisionsGenesisM fetchDecisionsGenesisM tracer fetchDecisionPolicy@FetchDecisionPolicy {bulkSyncGracePeriod} - chainComparison currentChain fetchedBlocks fetchedMaxSlotNo @@ -205,7 +203,6 @@ fetchDecisionsGenesisM let (theDecision, declines) = fetchDecisionsGenesis fetchDecisionPolicy - chainComparison currentChain fetchedBlocks fetchedMaxSlotNo @@ -319,7 +316,6 @@ fetchDecisionsGenesis , HeaderHash header ~ HeaderHash block ) => FetchDecisionPolicy header - -> ChainComparison header -> AnchoredFragment header -- ^ The current chain, anchored at the immutable tip. -> (Point block -> Bool) @@ -338,7 +334,6 @@ fetchDecisionsGenesis -- one @'FetchRequest' header@. fetchDecisionsGenesis fetchDecisionPolicy - chainComparison currentChain fetchedBlocks fetchedMaxSlotNo @@ -351,7 +346,7 @@ fetchDecisionsGenesis ) <- MaybeT $ selectTheCandidate - chainComparison + fetchDecisionPolicy currentChain candidatesAndPeers @@ -428,7 +423,7 @@ dropAlreadyFetched alreadyDownloaded fetchedMaxSlotNo candidate = selectTheCandidate :: forall header peerInfo. HasHeader header - => ChainComparison header + => FetchDecisionPolicy header -> AnchoredFragment header -- ^ The current chain. -> [(AnchoredFragment header, peerInfo)] @@ -441,7 +436,7 @@ selectTheCandidate -- selected candidate that we choose to sync from and a list of peers that -- are still in the race to serve that candidate. selectTheCandidate - ChainComparison {compareCandidateChains, plausibleCandidateChain} + FetchDecisionPolicy {compareCandidateChains, plausibleCandidateChain} currentChain = separateDeclinedAndStillInRace -- Select the suffix up to the intersection with the current chain. This can diff --git a/ouroboros-network/src/Ouroboros/Network/BlockFetch/State.hs b/ouroboros-network/src/Ouroboros/Network/BlockFetch/State.hs index 62d0f4a301..1ee5718b6e 100644 --- a/ouroboros-network/src/Ouroboros/Network/BlockFetch/State.hs +++ b/ouroboros-network/src/Ouroboros/Network/BlockFetch/State.hs @@ -43,9 +43,8 @@ import Ouroboros.Network.BlockFetch.ClientState (FetchClientStateVars (..), FetchRequest (..), PeerFetchInFlight (..), PeerFetchStatus (..), PeersOrder (..), TraceFetchClientState (..), TraceLabelPeer (..), addNewFetchRequest, readFetchClientState) -import Ouroboros.Network.BlockFetch.ConsensusInterface (ChainComparison (..), - ChainSelStarvation, FetchMode (..), Fingerprint (..), - WithFingerprint (..)) +import Ouroboros.Network.BlockFetch.ConsensusInterface (ChainSelStarvation, + FetchMode (..)) import Ouroboros.Network.BlockFetch.Decision (FetchDecision, FetchDecisionPolicy (..), FetchDecline (..), PeerInfo, PraosFetchMode (..), fetchDecisions) @@ -228,8 +227,7 @@ fetchDecisionsForStateSnapshot fetchStateFetchedBlocks, fetchStateFetchedMaxSlotNo, fetchStateFetchMode, - fetchStateChainSelStarvation, - fetchStateChainComparison + fetchStateChainSelStarvation } peersOrderHandlers = assert ( Map.keysSet fetchStatePeerChains @@ -242,7 +240,6 @@ fetchDecisionsForStateSnapshot PraosFetchMode fetchMode -> pure $ fetchDecisions fetchDecisionPolicy - fetchStateChainComparison fetchMode fetchStateCurrentChain fetchStateFetchedBlocks @@ -252,7 +249,6 @@ fetchDecisionsForStateSnapshot fetchDecisionsGenesisM tracer fetchDecisionPolicy - fetchStateChainComparison fetchStateCurrentChain fetchStateFetchedBlocks fetchStateFetchedMaxSlotNo @@ -306,8 +302,7 @@ fetchLogicIterationAct clientStateTracer FetchDecisionPolicy{blockFetchSize} data FetchTriggerVariables peer header m = FetchTriggerVariables { readStateCurrentChain :: STM m (AnchoredFragment header), readStateCandidateChains :: STM m (Map peer (AnchoredFragment header)), - readStatePeerStatus :: STM m (Map peer (PeerFetchStatus header)), - readStateChainComparison :: STM m (WithFingerprint (ChainComparison header)) + readStatePeerStatus :: STM m (Map peer (PeerFetchStatus header)) } -- | STM actions to read various state variables that the fetch logic uses. @@ -329,7 +324,6 @@ data FetchStateFingerprint peer header block = !(Maybe (Point block)) !(Map peer (Point header)) !(Map peer (PeerFetchStatus header)) - !Fingerprint -- ^ From 'ChainComparison' deriving Eq initialFetchStateFingerprint :: FetchStateFingerprint peer header block @@ -338,19 +332,17 @@ initialFetchStateFingerprint = Nothing Map.empty Map.empty - (Fingerprint 0) updateFetchStateFingerprintPeerStatus :: Ord peer => [(peer, PeerFetchStatus header)] -> FetchStateFingerprint peer header block -> FetchStateFingerprint peer header block updateFetchStateFingerprintPeerStatus statuses' - (FetchStateFingerprint current candidates statuses fpChainComp) = + (FetchStateFingerprint current candidates statuses) = FetchStateFingerprint current candidates (Map.union (Map.fromList statuses') statuses) -- left overrides right - fpChainComp -- | -- @@ -367,8 +359,7 @@ data FetchStateSnapshot peer header block m = FetchStateSnapshot { fetchStateFetchedBlocks :: Point block -> Bool, fetchStateFetchMode :: FetchMode, fetchStateFetchedMaxSlotNo :: MaxSlotNo, - fetchStateChainSelStarvation :: ChainSelStarvation, - fetchStateChainComparison :: ChainComparison header + fetchStateChainSelStarvation :: ChainSelStarvation } readStateVariables :: (MonadSTM m, Eq peer, @@ -387,11 +378,10 @@ readStateVariables FetchTriggerVariables{..} fetchStateFingerprint = do -- Read all the trigger state variables - fetchStateCurrentChain <- readStateCurrentChain - fetchStatePeerChains <- readStateCandidateChains - fetchStatePeerStatus <- readStatePeerStatus - chainComparison <- readStateChainComparison - gracePeriodExpired <- LazySTM.readTVar gracePeriodTVar + fetchStateCurrentChain <- readStateCurrentChain + fetchStatePeerChains <- readStateCandidateChains + fetchStatePeerStatus <- readStatePeerStatus + gracePeriodExpired <- LazySTM.readTVar gracePeriodTVar -- Construct the change detection fingerprint let !fetchStateFingerprint' = @@ -399,7 +389,6 @@ readStateVariables FetchTriggerVariables{..} (Just (castPoint (AF.headPoint fetchStateCurrentChain))) (Map.map AF.headPoint fetchStatePeerChains) fetchStatePeerStatus - (getFingerprint chainComparison) -- Check the fingerprint changed, or block and wait until it does check (gracePeriodExpired || fetchStateFingerprint' /= fetchStateFingerprint) @@ -423,8 +412,7 @@ readStateVariables FetchTriggerVariables{..} fetchStateFetchedBlocks, fetchStateFetchMode, fetchStateFetchedMaxSlotNo, - fetchStateChainSelStarvation, - fetchStateChainComparison = forgetFingerprint chainComparison + fetchStateChainSelStarvation } return (fetchStateSnapshot, gracePeriodExpired, fetchStateFingerprint') diff --git a/ouroboros-network/testlib/Ouroboros/Network/BlockFetch/Examples.hs b/ouroboros-network/testlib/Ouroboros/Network/BlockFetch/Examples.hs index 28ca67ed77..fbde64a974 100644 --- a/ouroboros-network/testlib/Ouroboros/Network/BlockFetch/Examples.hs +++ b/ouroboros-network/testlib/Ouroboros/Network/BlockFetch/Examples.hs @@ -54,7 +54,6 @@ import Ouroboros.Network.Protocol.BlockFetch.Server import Ouroboros.Network.Protocol.BlockFetch.Type import Ouroboros.Network.Util.ShowProxy -import Ouroboros.Network.BlockFetch.ConsensusInterface (initialWithFingerprint) import Ouroboros.Network.BlockFetch.Decision.Trace (TraceDecisionEvent) import Ouroboros.Network.Mock.ConcreteBlock @@ -296,10 +295,8 @@ sampleBlockFetchPolicy1 fetchMode headerFieldsForgeUTCTime blockHeap currentChai getTestFetchedBlocks blockHeap, mkAddFetchedBlock = pure $ addTestFetchedBlock blockHeap, - readChainComparison = pure $ initialWithFingerprint ChainComparison { - plausibleCandidateChain, - compareCandidateChains - }, + plausibleCandidateChain, + compareCandidateChains, blockFetchSize = \_ -> 2000, blockMatchesHeader = \_ _ -> True, diff --git a/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Node.hs b/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Node.hs index db466746e8..d224525c76 100644 --- a/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Node.hs +++ b/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Node.hs @@ -78,8 +78,7 @@ import Ouroboros.Network.Block (MaxSlotNo (..), maxSlotNoFromWithOrigin, pointSlot) import Ouroboros.Network.BlockFetch import Ouroboros.Network.BlockFetch.ConsensusInterface - (ChainSelStarvation (ChainSelStarvationEndedAt), - initialWithFingerprint) + (ChainSelStarvation (ChainSelStarvationEndedAt)) import Ouroboros.Network.ConnectionManager.State (ConnStateIdSupply) import Ouroboros.Network.ConnectionManager.Types (DataFlow (..)) import Ouroboros.Network.Diffusion qualified as Diffusion @@ -414,10 +413,8 @@ run blockGeneratorArgs limits ni na pure $ \_p b -> atomically (addBlock b (nkChainDB nodeKernel)), - readChainComparison = pure $ initialWithFingerprint ChainComparison { - plausibleCandidateChain, - compareCandidateChains - }, + plausibleCandidateChain, + compareCandidateChains, blockFetchSize = \_ -> 1000, blockMatchesHeader = \_ _ -> True, From df7b6154c4b8ddf8720725252b4742d86afc79f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20W=C3=B3jtowicz?= Date: Mon, 25 Aug 2025 11:08:48 +0200 Subject: [PATCH 06/16] Revert "`BlockFetchConsensusInterface`: simplify `headerForgeUTCTime`" This reverts commit 2312d2bbb518795c4e7036e00add8e0590f57e13. --- ouroboros-network-api/CHANGELOG.md | 3 -- .../Network/BlockFetch/ConsensusInterface.hs | 34 ++++++++++++++++++- ouroboros-network/CHANGELOG.md | 2 +- ouroboros-network/demo/chain-sync.hs | 5 +-- .../src/Ouroboros/Network/BlockFetch.hs | 3 +- .../Ouroboros/Network/BlockFetch/Client.hs | 9 ++--- .../Network/BlockFetch/ClientState.hs | 4 ++- .../Ouroboros/Network/BlockFetch/Examples.hs | 10 +++--- .../Test/Ouroboros/Network/BlockFetch.hs | 2 +- .../Test/Ouroboros/Network/Diffusion/Node.hs | 5 +-- 10 files changed, 56 insertions(+), 21 deletions(-) diff --git a/ouroboros-network-api/CHANGELOG.md b/ouroboros-network-api/CHANGELOG.md index 2f90fd3ede..4148b6e566 100644 --- a/ouroboros-network-api/CHANGELOG.md +++ b/ouroboros-network-api/CHANGELOG.md @@ -6,9 +6,6 @@ ### Breaking changes -* Simplify type of `headerForgeUTCTime` in `BlockFetchConsensusInterface`, and - remove the supporting type `FromConsensus`. - ### Non-breaking changes ## 0.16.0.0 -- 2025-07-21 diff --git a/ouroboros-network-api/src/Ouroboros/Network/BlockFetch/ConsensusInterface.hs b/ouroboros-network-api/src/Ouroboros/Network/BlockFetch/ConsensusInterface.hs index f4e975c257..99b4e99dda 100644 --- a/ouroboros-network-api/src/Ouroboros/Network/BlockFetch/ConsensusInterface.hs +++ b/ouroboros-network-api/src/Ouroboros/Network/BlockFetch/ConsensusInterface.hs @@ -8,6 +8,7 @@ module Ouroboros.Network.BlockFetch.ConsensusInterface ( PraosFetchMode (..) , FetchMode (..) , BlockFetchConsensusInterface (..) + , FromConsensus (..) , ChainSelStarvation (..) , mkReadFetchMode ) where @@ -161,7 +162,19 @@ data BlockFetchConsensusInterface peer header block m = blockMatchesHeader :: header -> block -> Bool, -- | Calculate when a header's block was forged. - headerForgeUTCTime :: header -> UTCTime, + -- + -- PRECONDITION: This function will succeed and give a _correct_ result + -- when applied to headers obtained via this interface (ie via + -- Consensus, ie via 'readCurrentChain' or 'readCandidateChains'). + -- + -- WARNING: This function may fail or, worse, __give an incorrect result + -- (!!)__ if applied to headers obtained from sources outside of this + -- interface. The 'FromConsensus' newtype wrapper is intended to make it + -- difficult to make that mistake, so please pay that syntactic price + -- and consider its meaning at each call to this function. Relatedly, + -- preserve that argument wrapper as much as possible when deriving + -- ancillary functions\/interfaces from this function. + headerForgeUTCTime :: FromConsensus header -> STM m UTCTime, -- | Information on the ChainSel starvation status; whether it is ongoing -- or has ended recently. Needed by the bulk sync decision logic. @@ -184,3 +197,22 @@ data ChainSelStarvation = ChainSelStarvationOngoing | ChainSelStarvationEndedAt Time deriving (Eq, Show, NoThunks, Generic) + +{------------------------------------------------------------------------------- + Syntactic indicator of key precondition about Consensus time conversions +-------------------------------------------------------------------------------} + +-- | A new type used to emphasize the precondition of +-- 'Ouroboros.Network.BlockFetch.ConsensusInterface.headerForgeUTCTime' at each +-- call site. +-- +-- At time of writing, the @a@ is either a header or a block. The headers are +-- literally from Consensus (ie provided by ChainSync). Blocks, on the other +-- hand, are indirectly from Consensus: they were fetched only because we +-- favored the corresponding header that Consensus provided. +newtype FromConsensus a = FromConsensus {unFromConsensus :: a} + deriving (Functor) + +instance Applicative FromConsensus where + pure = FromConsensus + FromConsensus f <*> FromConsensus a = FromConsensus (f a) diff --git a/ouroboros-network/CHANGELOG.md b/ouroboros-network/CHANGELOG.md index f019378c14..7bc5ba03fc 100644 --- a/ouroboros-network/CHANGELOG.md +++ b/ouroboros-network/CHANGELOG.md @@ -8,7 +8,7 @@ * `Ouroboros.Network.NodeTo{Client,Node}` modules moved to `ouroboros-network:cardano-diffusion` (as `Cardano.Network.NodeTo{Node,Client}`) -* Adapt to simplified type of `headerForgeUTCTime` in `BlockFetchConsensusInterface`. +* (REVERTED temporarily) Adapt to simplified type of `headerForgeUTCTime` in `BlockFetchConsensusInterface`. * Type of `defaultSyncTargets` changed. * Type of `defaultPeerSharing` changed. * (REVERTED temporarily) Adapted to changes of `BlockFetchConsensusInterface`. diff --git a/ouroboros-network/demo/chain-sync.hs b/ouroboros-network/demo/chain-sync.hs index 1c0ffd05dd..7d5ba62725 100644 --- a/ouroboros-network/demo/chain-sync.hs +++ b/ouroboros-network/demo/chain-sync.hs @@ -448,8 +448,9 @@ clientBlockFetch sockAddrs maxSlotNo = withIOManager $ \iocp -> do plausibleCandidateChain cur candidate = AF.headBlockNo candidate > AF.headBlockNo cur - headerForgeUTCTime = - convertSlotToTimeForTestsAssumingNoHardFork . headerSlot + headerForgeUTCTime (FromConsensus hdr) = + pure $ + convertSlotToTimeForTestsAssumingNoHardFork (headerSlot hdr) compareCandidateChains c1 c2 = AF.headBlockNo c1 `compare` AF.headBlockNo c2 diff --git a/ouroboros-network/src/Ouroboros/Network/BlockFetch.hs b/ouroboros-network/src/Ouroboros/Network/BlockFetch.hs index 8cd0d79bf4..9b5087b72d 100644 --- a/ouroboros-network/src/Ouroboros/Network/BlockFetch.hs +++ b/ouroboros-network/src/Ouroboros/Network/BlockFetch.hs @@ -99,6 +99,7 @@ module Ouroboros.Network.BlockFetch -- * Re-export types used by 'BlockFetchConsensusInterface' , PraosFetchMode (..) , FetchMode (..) + , FromConsensus (..) , SizeInBytes ) where @@ -121,7 +122,7 @@ import Ouroboros.Network.BlockFetch.ClientRegistry (FetchClientPolicy (..), readFetchClientsStateVars, readFetchClientsStatus, readPeerGSVs, setFetchClientContext) import Ouroboros.Network.BlockFetch.ConsensusInterface - (BlockFetchConsensusInterface (..)) + (BlockFetchConsensusInterface (..), FromConsensus (..)) import Ouroboros.Network.BlockFetch.Decision.Trace (TraceDecisionEvent) import Ouroboros.Network.BlockFetch.State diff --git a/ouroboros-network/src/Ouroboros/Network/BlockFetch/Client.hs b/ouroboros-network/src/Ouroboros/Network/BlockFetch/Client.hs index 81f5612875..ec6253cf64 100644 --- a/ouroboros-network/src/Ouroboros/Network/BlockFetch/Client.hs +++ b/ouroboros-network/src/Ouroboros/Network/BlockFetch/Client.hs @@ -41,9 +41,9 @@ import Ouroboros.Network.AnchoredFragment (AnchoredFragment) import Ouroboros.Network.AnchoredFragment qualified as AF import Ouroboros.Network.BlockFetch.ClientState (FetchClientContext (..), FetchClientPolicy (..), FetchClientStateVars (..), FetchRequest (..), - TraceFetchClientState (..), acknowledgeFetchRequest, - completeBlockDownload, completeFetchBatch, fetchClientCtxStateVars, - rejectedFetchBatch, startedFetchBatch) + FromConsensus (..), TraceFetchClientState (..), + acknowledgeFetchRequest, completeBlockDownload, completeFetchBatch, + fetchClientCtxStateVars, rejectedFetchBatch, startedFetchBatch) import Ouroboros.Network.BlockFetch.DeltaQ (PeerFetchInFlightLimits (..), PeerGSV (..)) import Ouroboros.Network.PeerSelection.PeerMetric.Type (FetchedMetricsTracer) @@ -267,7 +267,8 @@ blockFetchClient _version controlMessageSTM reportFetched -- Add the block to the chain DB, notifying of any new chains. addFetchedBlock (castPoint (blockPoint header)) block - let blockDelay = diffUTCTime now (headerForgeUTCTime header) + forgeTime <- atomically $ headerForgeUTCTime $ FromConsensus header + let blockDelay = diffUTCTime now forgeTime let hf = getHeaderFields header slotNo = headerFieldSlot hf diff --git a/ouroboros-network/src/Ouroboros/Network/BlockFetch/ClientState.hs b/ouroboros-network/src/Ouroboros/Network/BlockFetch/ClientState.hs index 74958ad58b..3a386d2a18 100644 --- a/ouroboros-network/src/Ouroboros/Network/BlockFetch/ClientState.hs +++ b/ouroboros-network/src/Ouroboros/Network/BlockFetch/ClientState.hs @@ -33,6 +33,7 @@ module Ouroboros.Network.BlockFetch.ClientState , TraceLabelPeer (..) , ChainRange (..) -- * Ancillary + , FromConsensus (..) , PeersOrder (..) ) where @@ -56,6 +57,7 @@ import Ouroboros.Network.AnchoredFragment (AnchoredFragment) import Ouroboros.Network.AnchoredFragment qualified as AF import Ouroboros.Network.Block (HasHeader, HeaderHash, MaxSlotNo (..), Point, blockPoint, castPoint) +import Ouroboros.Network.BlockFetch.ConsensusInterface (FromConsensus (..)) import Ouroboros.Network.BlockFetch.DeltaQ (PeerFetchInFlightLimits (..), PeerGSV, SizeInBytes, calculatePeerFetchInFlightLimits) import Ouroboros.Network.ControlMessage (ControlMessageSTM, @@ -82,7 +84,7 @@ data FetchClientPolicy header block m = blockFetchSize :: header -> SizeInBytes, blockMatchesHeader :: header -> block -> Bool, addFetchedBlock :: Point block -> block -> m (), - headerForgeUTCTime :: header -> UTCTime + headerForgeUTCTime :: FromConsensus header -> STM m UTCTime } -- | A set of variables shared between the block fetch logic thread and each diff --git a/ouroboros-network/testlib/Ouroboros/Network/BlockFetch/Examples.hs b/ouroboros-network/testlib/Ouroboros/Network/BlockFetch/Examples.hs index fbde64a974..28eedb2ae1 100644 --- a/ouroboros-network/testlib/Ouroboros/Network/BlockFetch/Examples.hs +++ b/ouroboros-network/testlib/Ouroboros/Network/BlockFetch/Examples.hs @@ -148,8 +148,8 @@ blockFetchExample0 fetchMode decisionTracer clientStateTracer clientMsgTracer }) >> return () - headerForgeUTCTime = - convertSlotToTimeForTestsAssumingNoHardFork . headerSlot + headerForgeUTCTime (FromConsensus x) = + pure $ convertSlotToTimeForTestsAssumingNoHardFork (blockSlot x) driver :: TestFetchedBlockHeap m Block -> m () driver blockHeap = do @@ -262,8 +262,8 @@ blockFetchExample1 fetchMode decisionTracer clientStateTracer clientMsgTracer }) >> return () - headerForgeUTCTime = - convertSlotToTimeForTestsAssumingNoHardFork . headerSlot + headerForgeUTCTime (FromConsensus x) = + pure $ convertSlotToTimeForTestsAssumingNoHardFork (blockSlot x) -- | Terminates after 1 second per block in the candidate chains. downloadTimer :: m () @@ -277,7 +277,7 @@ blockFetchExample1 fetchMode decisionTracer clientStateTracer clientMsgTracer sampleBlockFetchPolicy1 :: (MonadSTM m, HasHeader header, HasHeader block) => FetchMode - -> (header -> UTCTime) + -> (forall x. HasHeader x => FromConsensus x -> STM m UTCTime) -> TestFetchedBlockHeap m block -> AnchoredFragment header -> Map peer (AnchoredFragment header) diff --git a/ouroboros-network/testlib/Test/Ouroboros/Network/BlockFetch.hs b/ouroboros-network/testlib/Test/Ouroboros/Network/BlockFetch.hs index 7693ef31b7..60b84955a2 100644 --- a/ouroboros-network/testlib/Test/Ouroboros/Network/BlockFetch.hs +++ b/ouroboros-network/testlib/Test/Ouroboros/Network/BlockFetch.hs @@ -788,7 +788,7 @@ unit_bracketSyncWithFetchClient step = do dummyPolicy :: forall b h m. (MonadSTM m) => STM m (FetchClientPolicy h b m) dummyPolicy = let addFetchedBlock _ _ = return () - forgeTime _ = read "2000-01-01 00:00:00 UTC" + forgeTime _ = return (read "2000-01-01 00:00:00 UTC") bfSize _ = 1024 matchesHeader _ _ = True in pure $ FetchClientPolicy diff --git a/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Node.hs b/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Node.hs index d224525c76..5bfdee3854 100644 --- a/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Node.hs +++ b/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Node.hs @@ -428,8 +428,9 @@ run blockGeneratorArgs limits ni na plausibleCandidateChain cur candidate = AF.headBlockNo candidate > AF.headBlockNo cur - headerForgeUTCTime = - convertSlotToTimeForTestsAssumingNoHardFork . headerSlot + headerForgeUTCTime (FromConsensus hdr) = + pure $ + convertSlotToTimeForTestsAssumingNoHardFork (headerSlot hdr) compareCandidateChains c1 c2 = AF.headBlockNo c1 `compare` AF.headBlockNo c2 From f8a79959f615f4f8c8276bb5163ada841e80e8eb Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Wed, 13 Aug 2025 23:08:36 +0200 Subject: [PATCH 07/16] mempool: compute set of txids in the mempool incrementally --- dmq-node/src/DMQ/Diffusion/NodeKernel.hs | 44 ++++++++++---- .../Network/TxSubmission/Mempool/Simple.hs | 59 ++++++++++++------- .../Network/Diffusion/Node/Kernel.hs | 4 +- .../Network/Diffusion/Node/MiniProtocols.hs | 4 +- .../Ouroboros/Network/TxSubmission/AppV1.hs | 4 +- .../Ouroboros/Network/TxSubmission/Types.hs | 13 ++-- 6 files changed, 83 insertions(+), 45 deletions(-) diff --git a/dmq-node/src/DMQ/Diffusion/NodeKernel.hs b/dmq-node/src/DMQ/Diffusion/NodeKernel.hs index 01b9b3805e..b5e479e334 100644 --- a/dmq-node/src/DMQ/Diffusion/NodeKernel.hs +++ b/dmq-node/src/DMQ/Diffusion/NodeKernel.hs @@ -14,6 +14,9 @@ import Control.Monad.Class.MonadTime.SI import Control.Monad.Class.MonadTimer.SI import Data.Function (on) +import Data.Set (Set) +import Data.Set qualified as Set +import Data.Sequence (Seq) import Data.Sequence qualified as Seq import Data.Time.Clock.POSIX (POSIXTime) import Data.Time.Clock.POSIX qualified as Time @@ -30,10 +33,11 @@ import Ouroboros.Network.PeerSharing (PeerSharingAPI, PeerSharingRegistry, newPeerSharingAPI, newPeerSharingRegistry, ps_POLICY_PEER_SHARE_MAX_PEERS, ps_POLICY_PEER_SHARE_STICKY_TIME) import Ouroboros.Network.TxSubmission.Inbound.V2.Registry -import Ouroboros.Network.TxSubmission.Mempool.Simple (Mempool (..)) +import Ouroboros.Network.TxSubmission.Mempool.Simple (Mempool (..), + MempoolSeq (..)) import Ouroboros.Network.TxSubmission.Mempool.Simple qualified as Mempool -import DMQ.Protocol.SigSubmission.Type (Sig (sigExpiresAt), SigId) +import DMQ.Protocol.SigSubmission.Type (Sig (sigExpiresAt, sigId), SigId) data NodeKernel crypto ntnAddr m = @@ -45,7 +49,7 @@ data NodeKernel crypto ntnAddr m = -- the PeerSharing protocol , peerSharingRegistry :: !(PeerSharingRegistry ntnAddr m) , peerSharingAPI :: !(PeerSharingAPI ntnAddr StdGen m) - , mempool :: !(Mempool m (Sig crypto)) + , mempool :: !(Mempool m SigId (Sig crypto)) , sigChannelVar :: !(TxChannelsVar m ntnAddr SigId (Sig crypto)) , sigMempoolSem :: !(TxMempoolSem m) , sigSharedTxStateVar :: !(SharedTxStateVar m ntnAddr SigId (Sig crypto)) @@ -113,22 +117,36 @@ mempoolWorker :: forall crypto m. , MonadSTM m , MonadTime m ) - => Mempool m (Sig crypto) + => Mempool m SigId (Sig crypto) -> m Void mempoolWorker (Mempool v) = loop where loop = do now <- getCurrentPOSIXTime rt <- atomically $ do - (sigs :: Seq.Seq (Sig crypto)) <- readTVar v - let sigs' :: Seq.Seq (Sig crypto) - (resumeTime, sigs') = - foldr (\a (rt, as) -> if sigExpiresAt a <= now - then (rt, as) - else (rt `min` sigExpiresAt a, a Seq.<| as)) - (now, Seq.empty) - sigs - writeTVar v sigs' + MempoolSeq { mempoolSeq, mempoolSet } <- readTVar v + let mempoolSeq' :: Seq (Sig crypto) + mempoolSet', expiredSet' :: Set SigId + + (resumeTime, expiredSet', mempoolSeq') = + foldr (\sig (rt, expiredSet, sigs) -> + if sigExpiresAt sig <= now + then ( rt + , sigId sig `Set.insert` expiredSet + , sigs + ) + else ( rt `min` sigExpiresAt sig + , expiredSet + , sig Seq.<| sigs + ) + ) + (now, Set.empty, Seq.empty) + mempoolSeq + + mempoolSet' = mempoolSet `Set.difference` expiredSet' + + writeTVar v MempoolSeq { mempoolSet = mempoolSet', + mempoolSeq = mempoolSeq' } return resumeTime now' <- getCurrentPOSIXTime diff --git a/ouroboros-network/src/Ouroboros/Network/TxSubmission/Mempool/Simple.hs b/ouroboros-network/src/Ouroboros/Network/TxSubmission/Mempool/Simple.hs index 94a4bece42..75e49ace3e 100644 --- a/ouroboros-network/src/Ouroboros/Network/TxSubmission/Mempool/Simple.hs +++ b/ouroboros-network/src/Ouroboros/Network/TxSubmission/Mempool/Simple.hs @@ -8,6 +8,7 @@ -- module Ouroboros.Network.TxSubmission.Mempool.Simple ( Mempool (..) + , MempoolSeq (..) , empty , new , read @@ -30,6 +31,7 @@ import Data.List (find, nubBy) import Data.Maybe (isJust) import Data.Sequence (Seq) import Data.Sequence qualified as Seq +import Data.Set (Set) import Data.Set qualified as Set import Data.Typeable (Typeable) @@ -38,25 +40,38 @@ import Ouroboros.Network.TxSubmission.Inbound.V2.Types import Ouroboros.Network.TxSubmission.Mempool.Reader +data MempoolSeq txid tx = MempoolSeq { + mempoolSet :: !(Set txid), + -- ^ cached set of `txid`s in the mempool + mempoolSeq :: !(Seq tx) + -- ^ sequence of all `tx`s + } + -- | A simple in-memory mempool implementation. -- -newtype Mempool m tx = Mempool (StrictTVar m (Seq tx)) +newtype Mempool m txid tx = Mempool (StrictTVar m (MempoolSeq txid tx)) -empty :: MonadSTM m => m (Mempool m tx) -empty = Mempool <$> newTVarIO Seq.empty +empty :: MonadSTM m => m (Mempool m txid tx) +empty = Mempool <$> newTVarIO (MempoolSeq Set.empty Seq.empty) -new :: MonadSTM m - => [tx] - -> m (Mempool m tx) -new = fmap Mempool - . newTVarIO - . Seq.fromList +new :: ( MonadSTM m + , Ord txid + ) + => (tx -> txid) + -> [tx] + -> m (Mempool m txid tx) +new getTxId txs = + fmap Mempool + . newTVarIO + $ MempoolSeq { mempoolSet = Set.fromList (getTxId <$> txs), + mempoolSeq = Seq.fromList txs + } -read :: MonadSTM m => Mempool m tx -> m [tx] -read (Mempool mempool) = toList <$> readTVarIO mempool +read :: MonadSTM m => Mempool m txid tx -> m [tx] +read (Mempool mempool) = toList . mempoolSeq <$> readTVarIO mempool getReader :: forall tx txid m. @@ -65,7 +80,7 @@ getReader :: forall tx txid m. ) => (tx -> txid) -> (tx -> SizeInBytes) - -> Mempool m tx + -> Mempool m txid tx -> TxSubmissionMempoolReader txid tx Int m getReader getTxId getTxSize (Mempool mempool) = -- Using `0`-based index. `mempoolZeroIdx = -1` so that @@ -75,7 +90,7 @@ getReader getTxId getTxSize (Mempool mempool) = } where mempoolGetSnapshot :: STM m (MempoolSnapshot txid tx Int) - mempoolGetSnapshot = getSnapshot <$> readTVar mempool + mempoolGetSnapshot = getSnapshot . mempoolSeq <$> readTVar mempool getSnapshot :: Seq tx -> MempoolSnapshot txid tx Int @@ -124,7 +139,7 @@ getWriter :: forall tx txid ctx failure m. -- ^ validate a tx, any failing `tx` throws an exception. -> (failure -> Bool) -- ^ return `True` when a failure should throw an exception - -> Mempool m tx + -> Mempool m txid tx -> TxSubmissionMempoolWriter txid tx Int m getWriter getTxId getValidationCtx validateTx failureFilterFn (Mempool mempool) = TxSubmissionMempoolWriter { @@ -133,11 +148,8 @@ getWriter getTxId getValidationCtx validateTx failureFilterFn (Mempool mempool) mempoolAddTxs = \txs -> do ctx <- getValidationCtx (invalidTxIds, validTxs) <- atomically $ do - mempoolTxs <- readTVar mempool - let -- TODO: set of current ids should be constructed incrementally, - -- e.g. it should be part of mempoolTxs - currentIds = Set.fromList (map getTxId (toList mempoolTxs)) - (invalidTxIds, validTxs) = + MempoolSeq { mempoolSet, mempoolSeq } <- readTVar mempool + let (invalidTxIds, validTxs) = bimap (filter (failureFilterFn . snd)) (nubBy (on (==) getTxId)) . partitionEithers @@ -145,9 +157,14 @@ getWriter getTxId getValidationCtx validateTx failureFilterFn (Mempool mempool) Left e -> Left (getTxId tx, e) Right _ -> Right tx ) - . filter (\tx -> getTxId tx `Set.notMember` currentIds) + . filter (\tx -> getTxId tx `Set.notMember` mempoolSet) $ txs - mempoolTxs' = Foldable.foldl' (Seq.|>) mempoolTxs validTxs + mempoolTxs' = MempoolSeq { + mempoolSet = Foldable.foldl' (\s tx -> getTxId tx `Set.insert` s) + mempoolSet + validTxs, + mempoolSeq = Foldable.foldl' (Seq.|>) mempoolSeq validTxs + } writeTVar mempool mempoolTxs' return (invalidTxIds, map getTxId validTxs) when (not (null invalidTxIds)) $ diff --git a/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Node/Kernel.hs b/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Node/Kernel.hs index c841d89882..5cd48c8f82 100644 --- a/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Node/Kernel.hs +++ b/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Node/Kernel.hs @@ -311,7 +311,7 @@ data NodeKernel header block s txid m = NodeKernel { :: StrictTVar m (PublicPeerSelectionState NtNAddr), nkMempool - :: Mempool m (Tx txid), + :: Mempool m txid (Tx txid), nkTxChannelsVar :: TxChannelsVar m NtNAddr txid (Tx txid), @@ -326,6 +326,7 @@ data NodeKernel header block s txid m = NodeKernel { newNodeKernel :: ( MonadSTM m , Strict.MonadMVar m , RandomGen rng + , Ord txid , Eq txid ) => rng @@ -427,6 +428,7 @@ withNodeKernelThread , HasFullHeader block , RandomGen seed , Eq txid + , Ord txid ) => NtNAddr -- ^ just for naming a thread diff --git a/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Node/MiniProtocols.hs b/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Node/MiniProtocols.hs index c113a81ad8..335b117aa5 100644 --- a/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Node/MiniProtocols.hs +++ b/ouroboros-network/testlib/Test/Ouroboros/Network/Diffusion/Node/MiniProtocols.hs @@ -681,7 +681,7 @@ applications debugTracer txSubmissionInboundTracer txSubmissionInboundDebug node txSubmissionInitiator :: TxDecisionPolicy - -> Mempool m (Tx TxId) + -> Mempool m TxId (Tx TxId) -> MiniProtocolCb (ExpandedInitiatorContext NtNAddr m) ByteString m () txSubmissionInitiator txDecisionPolicy mempool = MiniProtocolCb $ @@ -708,7 +708,7 @@ applications debugTracer txSubmissionInboundTracer txSubmissionInboundDebug node (txSubmissionClientPeer client) txSubmissionResponder - :: Mempool m (Tx TxId) + :: Mempool m TxId (Tx TxId) -> TxChannelsVar m NtNAddr Int (Tx Int) -> TxMempoolSem m -> SharedTxStateVar m NtNAddr Int (Tx Int) diff --git a/ouroboros-network/testlib/Test/Ouroboros/Network/TxSubmission/AppV1.hs b/ouroboros-network/testlib/Test/Ouroboros/Network/TxSubmission/AppV1.hs index 0ccfd07ce1..9c44519720 100644 --- a/ouroboros-network/testlib/Test/Ouroboros/Network/TxSubmission/AppV1.hs +++ b/ouroboros-network/testlib/Test/Ouroboros/Network/TxSubmission/AppV1.hs @@ -118,7 +118,7 @@ txSubmissionSimulation tracer maxUnacked outboundTxs return (inmp, outmp) where - outboundPeer :: Mempool m (Tx txid) -> TxSubmissionClient txid (Tx txid) m () + outboundPeer :: Mempool m txid (Tx txid) -> TxSubmissionClient txid (Tx txid) m () outboundPeer outboundMempool = txSubmissionOutbound nullTracer @@ -127,7 +127,7 @@ txSubmissionSimulation tracer maxUnacked outboundTxs (maxBound :: NodeToNodeVersion) controlMessageSTM - inboundPeer :: Mempool m (Tx txid) -> TxSubmissionServerPipelined txid (Tx txid) m () + inboundPeer :: Mempool m txid (Tx txid) -> TxSubmissionServerPipelined txid (Tx txid) m () inboundPeer inboundMempool = txSubmissionInbound nullTracer diff --git a/ouroboros-network/testlib/Test/Ouroboros/Network/TxSubmission/Types.hs b/ouroboros-network/testlib/Test/Ouroboros/Network/TxSubmission/Types.hs index 332bc12316..6f02a0b3d3 100644 --- a/ouroboros-network/testlib/Test/Ouroboros/Network/TxSubmission/Types.hs +++ b/ouroboros-network/testlib/Test/Ouroboros/Network/TxSubmission/Types.hs @@ -101,13 +101,14 @@ maxTxSize = 65536 type TxId = Int -emptyMempool :: MonadSTM m => m (Mempool m (Tx txid)) +emptyMempool :: MonadSTM m => m (Mempool m txid (Tx txid)) emptyMempool = Mempool.empty -newMempool :: MonadSTM m => [Tx txid] -> m (Mempool m (Tx txid)) -newMempool = Mempool.new +newMempool :: (MonadSTM m, Ord txid) + => [Tx txid] -> m (Mempool m txid (Tx txid)) +newMempool = Mempool.new getTxId -readMempool :: MonadSTM m => Mempool m (Tx txid) -> m [Tx txid] +readMempool :: MonadSTM m => Mempool m txid (Tx txid) -> m [Tx txid] readMempool = Mempool.read getMempoolReader :: forall txid m. @@ -115,7 +116,7 @@ getMempoolReader :: forall txid m. , Eq txid , Show txid ) - => Mempool m (Tx txid) + => Mempool m txid (Tx txid) -> TxSubmissionMempoolReader txid (Tx txid) Int m getMempoolReader = Mempool.getReader getTxId getTxAdvSize @@ -128,7 +129,7 @@ getMempoolWriter :: forall txid m. , Typeable txid , Show txid ) - => Mempool m (Tx txid) + => Mempool m txid (Tx txid) -> TxSubmissionMempoolWriter txid (Tx txid) Int m getMempoolWriter = Mempool.getWriter getTxId (pure ()) From 21e525a9813e74b63e3a22e1961fa053c01080f7 Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Fri, 5 Sep 2025 15:58:14 +0200 Subject: [PATCH 08/16] dmq: Sig validation --- dmq-node/src/DMQ/NodeToNode.hs | 13 +- .../src/DMQ/Protocol/SigSubmission/Codec.hs | 12 +- .../src/DMQ/Protocol/SigSubmission/Type.hs | 137 ++++- .../test/Test/DMQ/Protocol/SigSubmission.hs | 517 ++++++++++++------ 4 files changed, 494 insertions(+), 185 deletions(-) diff --git a/dmq-node/src/DMQ/NodeToNode.hs b/dmq-node/src/DMQ/NodeToNode.hs index c2dc8e4a45..9e54de6205 100644 --- a/dmq-node/src/DMQ/NodeToNode.hs +++ b/dmq-node/src/DMQ/NodeToNode.hs @@ -4,6 +4,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} module DMQ.NodeToNode ( RemoteAddress @@ -40,6 +41,7 @@ import Codec.CBOR.Encoding qualified as CBOR import Codec.CBOR.Read qualified as CBOR import Codec.CBOR.Term qualified as CBOR import Data.Aeson qualified as Aeson +import Data.ByteString qualified as BS import Data.ByteString.Lazy qualified as BL import Data.Functor.Contravariant ((>$<)) import Data.Hashable (Hashable) @@ -52,7 +54,10 @@ import Network.Mux.Types (Mode (..)) import Network.Mux.Types qualified as Mx import Network.TypedProtocol.Codec (AnnotatedCodec, Codec) +import Cardano.Crypto.DSIGN.Class qualified as DSIGN +import Cardano.Crypto.KES.Class qualified as KES import Cardano.KESAgent.KES.Crypto (Crypto (..)) +import Cardano.KESAgent.KES.OCert (OCertSignable) import DMQ.Configuration (Configuration, Configuration' (..), I (..)) import DMQ.Diffusion.NodeKernel (NodeKernel (..)) @@ -147,6 +152,10 @@ data Apps addr m a b = ntnApps :: forall crypto m addr . ( Crypto crypto + , DSIGN.ContextDSIGN (DSIGN crypto) ~ () + , DSIGN.Signable (DSIGN crypto) (OCertSignable crypto) + , KES.ContextKES (KES crypto) ~ () + , KES.Signable (KES crypto) BS.ByteString , Typeable crypto , Alternative (STM m) , MonadAsync m @@ -220,8 +229,8 @@ ntnApps -- connection if we receive one, rather than validate them in the -- mempool. mempoolWriter = Mempool.getWriter sigId - (pure ()) - (\_ _ -> Right () :: Either Void ()) + (pure ()) -- TODO not needed + (\_ -> validateSig) (\_ -> True) mempool diff --git a/dmq-node/src/DMQ/Protocol/SigSubmission/Codec.hs b/dmq-node/src/DMQ/Protocol/SigSubmission/Codec.hs index 25bef7387e..938d818579 100644 --- a/dmq-node/src/DMQ/Protocol/SigSubmission/Codec.hs +++ b/dmq-node/src/DMQ/Protocol/SigSubmission/Codec.hs @@ -1,7 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -33,8 +32,9 @@ import Codec.CBOR.Read qualified as CBOR import Network.TypedProtocol.Codec.CBOR import Cardano.Binary (FromCBOR (..), ToCBOR (..)) -import Cardano.Crypto.DSIGN.Class (decodeSignedDSIGN, encodeSignedDSIGN) -import Cardano.Crypto.KES.Class (decodeVerKeyKES, encodeVerKeyKES) +import Cardano.Crypto.DSIGN.Class (decodeSignedDSIGN, decodeVerKeyDSIGN, + encodeSignedDSIGN) +import Cardano.Crypto.KES.Class (decodeSigKES, decodeVerKeyKES, encodeVerKeyKES) import Cardano.KESAgent.KES.Crypto (Crypto (..)) import Cardano.KESAgent.KES.OCert (OCert (..)) @@ -154,14 +154,14 @@ decodeSig = do when (a /= 7) $ fail (printf "codecSigSubmission: unexpected number of parameters %d" a) sigRawId <- decodeSigId sigRawBody <- SigBody <$> CBOR.decodeBytes - sigRawKESPeriod <- CBOR.decodeWord + sigRawKESPeriod <- KESPeriod <$> CBOR.decodeWord sigRawExpiresAt <- realToFrac <$> CBOR.decodeWord32 -- end of signed data endOffset <- CBOR.peekByteOffset - sigRawKESSignature <- SigKESSignature <$> CBOR.decodeBytes + sigRawKESSignature <- SigKESSignature <$> decodeSigKES sigRawOpCertificate <- decodeSigOpCertificate - sigRawColdKey <- SigColdKey <$> CBOR.decodeBytes + sigRawColdKey <- SigColdKey <$> decodeVerKeyDSIGN return $ \bytes -- ^ full bytes of the message, not just the sig part -> SigRawWithSignedBytes { sigRawSignedBytes = Utils.bytesBetweenOffsets startOffset endOffset bytes, diff --git a/dmq-node/src/DMQ/Protocol/SigSubmission/Type.hs b/dmq-node/src/DMQ/Protocol/SigSubmission/Type.hs index 96e5647e9e..ea64299069 100644 --- a/dmq-node/src/DMQ/Protocol/SigSubmission/Type.hs +++ b/dmq-node/src/DMQ/Protocol/SigSubmission/Type.hs @@ -3,8 +3,9 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} module DMQ.Protocol.SigSubmission.Type @@ -13,26 +14,32 @@ module DMQ.Protocol.SigSubmission.Type , SigId (..) , SigBody (..) , SigKESSignature (..) - , SigKESPeriod , SigOpCertificate (..) , SigColdKey (..) , SigRaw (..) , SigRawWithSignedBytes (..) , Sig (Sig, SigWithBytes, sigRawWithSignedBytes, sigRawBytes, sigId, sigBody, sigExpiresAt, sigOpCertificate, sigKESPeriod, sigKESSignature, sigColdKey, sigSignedBytes, sigBytes) + , validateSig -- * `TxSubmission` mini-protocol , SigSubmission , module SigSubmission + -- * Re-exports from `kes-agent` + , KESPeriod (..) ) where +import Data.Bifunctor (first) import Data.ByteString (ByteString) import Data.ByteString.Lazy qualified as LBS import Data.Time.Clock.POSIX (POSIXTime) import Data.Typeable +import Data.Word (Word64) -import Cardano.Crypto.DSIGN.Class (DSIGNAlgorithm) -import Cardano.Crypto.KES.Class (VerKeyKES) +import Cardano.Crypto.DSIGN.Class (ContextDSIGN, DSIGNAlgorithm, VerKeyDSIGN) +import Cardano.Crypto.DSIGN.Class qualified as DSIGN +import Cardano.Crypto.KES.Class (KESAlgorithm (..), Signable) import Cardano.KESAgent.KES.Crypto as KES -import Cardano.KESAgent.KES.OCert (OCert) +import Cardano.KESAgent.KES.OCert (KESPeriod (..), OCert (..), OCertSignable, + validateOCert) import Ouroboros.Network.Protocol.TxSubmission2.Type as SigSubmission hiding (TxSubmission2) @@ -52,13 +59,13 @@ newtype SigBody = SigBody { getSigBody :: ByteString } deriving stock (Show, Eq) --- TODO: --- This type should be something like: `SignedKES (KES crypto) SigPayload` -newtype SigKESSignature = SigKESSignature { getSigKESSignature :: ByteString } - deriving stock (Show, Eq) +newtype SigKESSignature crypto = SigKESSignature { getSigKESSignature :: SigKES (KES crypto) } + +deriving instance Show (SigKES (KES crypto)) + => Show (SigKESSignature crypto) +deriving instance Eq (SigKES (KES crypto)) + => Eq (SigKESSignature crypto) --- TODO: --- This type should be more than just a `ByteString`. newtype SigOpCertificate crypto = SigOpCertificate { getSigOpCertificate :: OCert crypto } deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) @@ -67,13 +74,16 @@ deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) => Show (SigOpCertificate crypto) deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) , Eq (VerKeyKES (KES crypto)) - ) => Eq (SigOpCertificate crypto) + ) => Eq (SigOpCertificate crypto) -type SigKESPeriod = Word +newtype SigColdKey crypto = SigColdKey { getSigColdKey :: VerKeyDSIGN (KES.DSIGN crypto) } -newtype SigColdKey = SigColdKey { getSigColdKey :: ByteString } - deriving stock (Show, Eq) +deriving instance Show (VerKeyDSIGN (KES.DSIGN crypto)) + => Show (SigColdKey crypto) + +deriving instance Eq (VerKeyDSIGN (KES.DSIGN crypto)) + => Eq (SigColdKey crypto) -- | Sig type consists of payload and its KES signature. -- @@ -81,23 +91,28 @@ newtype SigColdKey = SigColdKey { getSigColdKey :: ByteString } data SigRaw crypto = SigRaw { sigRawId :: SigId, sigRawBody :: SigBody, - sigRawKESPeriod :: SigKESPeriod, + sigRawKESPeriod :: KESPeriod, -- ^ KES period when this signature was created. -- -- NOTE: `kes-agent` library is using `Word` for KES period, CIP-137 -- requires `Word64`, thus we're only supporting 64-bit architectures. - sigRawExpiresAt :: POSIXTime, - sigRawKESSignature :: SigKESSignature, sigRawOpCertificate :: SigOpCertificate crypto, - sigRawColdKey :: SigColdKey + sigRawColdKey :: SigColdKey crypto, + sigRawExpiresAt :: POSIXTime, + sigRawKESSignature :: SigKESSignature crypto + -- ^ KES signature of all previous fields. + -- + -- NOTE: this field must be lazy, otetherwise tests will fail. } deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) , Show (VerKeyKES (KES crypto)) + , Show (SigKES (KES crypto)) ) => Show (SigRaw crypto) deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) , Eq (VerKeyKES (KES crypto)) + , Eq (SigKES (KES crypto)) ) => Eq (SigRaw crypto) @@ -110,14 +125,15 @@ data SigRawWithSignedBytes crypto = SigRawWithSignedBytes { deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) , Show (VerKeyKES (KES crypto)) + , Show (SigKES (KES crypto)) ) => Show (SigRawWithSignedBytes crypto) deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) , Eq (VerKeyKES (KES crypto)) + , Eq (SigKES (KES crypto)) ) => Eq (SigRawWithSignedBytes crypto) - data Sig crypto = SigWithBytes { sigRawBytes :: LBS.ByteString, -- ^ encoded `SigRaw` data type @@ -127,10 +143,12 @@ data Sig crypto = SigWithBytes { deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) , Show (VerKeyKES (KES crypto)) + , Show (SigKES (KES crypto)) ) => Show (Sig crypto) deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) , Eq (VerKeyKES (KES crypto)) + , Eq (SigKES (KES crypto)) ) => Eq (Sig crypto) @@ -140,10 +158,10 @@ deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) pattern Sig :: SigId -> SigBody - -> SigKESSignature - -> SigKESPeriod + -> SigKESSignature crypto + -> KESPeriod -> SigOpCertificate crypto - -> SigColdKey + -> SigColdKey crypto -> POSIXTime -> LBS.ByteString -> LBS.ByteString @@ -206,4 +224,77 @@ pattern instance Typeable crypto => ShowProxy (Sig crypto) where + +data SigValidationError = + InvalidKESSignature KESPeriod KESPeriod String + | InvalidSignatureOCERT + !Word64 -- OCert counter + !KESPeriod -- OCert KES period + !String -- DSIGN error message + | KESBeforeStartOCERT KESPeriod KESPeriod + | KESAfterEndOCERT KESPeriod KESPeriod + deriving Show + +-- TODO: +-- We don't validate ocert numbers, since we might not have necessary +-- information to do so, but we can validate that they are growing. +validateSig :: forall crypto. + ( Crypto crypto + , ContextDSIGN (KES.DSIGN crypto) ~ () + , DSIGN.Signable (DSIGN crypto) (OCertSignable crypto) + , ContextKES (KES crypto) ~ () + , Signable (KES crypto) ByteString + ) + => Sig crypto + -> Either SigValidationError () +validateSig Sig { sigSignedBytes = signedBytes, + sigKESPeriod, + sigOpCertificate = SigOpCertificate ocert@OCert { + ocertKESPeriod, + ocertVkHot, + ocertN + }, + sigColdKey = SigColdKey coldKey, + sigKESSignature = SigKESSignature kesSig + } + = do + sigKESPeriod < endKESPeriod + ?! KESAfterEndOCERT endKESPeriod sigKESPeriod + sigKESPeriod >= startKESPeriod + ?! KESBeforeStartOCERT startKESPeriod sigKESPeriod + + -- validate OCert, which includes verifying its signature + validateOCert coldKey ocertVkHot ocert + ?!: InvalidSignatureOCERT ocertN sigKESPeriod + -- validate KES signature of the payload + verifyKES () ocertVkHot + (unKESPeriod sigKESPeriod - unKESPeriod startKESPeriod) + (LBS.toStrict signedBytes) + kesSig + ?!: InvalidKESSignature ocertKESPeriod sigKESPeriod + where + startKESPeriod, endKESPeriod :: KESPeriod + + startKESPeriod = ocertKESPeriod + -- TODO: is `totalPeriodsKES` the same as `praosMaxKESEvo` + -- or `sgMaxKESEvolution` in the genesis file? + endKESPeriod = KESPeriod $ unKESPeriod startKESPeriod + + totalPeriodsKES (Proxy :: Proxy (KES crypto)) + type SigSubmission crypto = TxSubmission2.TxSubmission2 SigId (Sig crypto) + + +-- +-- Utility functions +-- + +(?!:) :: Either e1 a -> (e1 -> e2) -> Either e2 a +(?!:) = flip first + +infix 1 ?!: + +(?!) :: Bool -> e -> Either e () +(?!) True _ = Right () +(?!) False e = Left e + +infix 1 ?! diff --git a/dmq-node/test/Test/DMQ/Protocol/SigSubmission.hs b/dmq-node/test/Test/DMQ/Protocol/SigSubmission.hs index 2d4c67f1fa..1cafe647b9 100644 --- a/dmq-node/test/Test/DMQ/Protocol/SigSubmission.hs +++ b/dmq-node/test/Test/DMQ/Protocol/SigSubmission.hs @@ -1,4 +1,5 @@ {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE BlockArguments #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} @@ -9,6 +10,7 @@ {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} @@ -20,11 +22,14 @@ module Test.DMQ.Protocol.SigSubmission where import Codec.CBOR.Encoding qualified as CBOR import Codec.CBOR.Read qualified as CBOR import Codec.CBOR.Write qualified as CBOR +import Control.Monad (zipWithM, (>=>)) import Control.Monad.ST (runST) import Data.Bifunctor (second) +import Data.ByteString (ByteString) import Data.ByteString.Lazy qualified as BL import Data.List.NonEmpty qualified as NonEmpty import Data.Time.Clock.POSIX (POSIXTime) +import Data.Typeable import Data.Word (Word32) import GHC.TypeNats (KnownNat) import System.IO.Unsafe (unsafePerformIO) @@ -32,12 +37,15 @@ import System.IO.Unsafe (unsafePerformIO) import Network.TypedProtocol.Codec import Network.TypedProtocol.Codec.Properties hiding (prop_codec) +import Cardano.Crypto.DSIGN.Class (DSIGNAlgorithm, SignKeyDSIGN, + deriveVerKeyDSIGN, encodeVerKeyDSIGN) import Cardano.Crypto.DSIGN.Class qualified as DSIGN -import Cardano.Crypto.KES.Class (KESAlgorithm (..), VerKeyKES) +import Cardano.Crypto.KES.Class (KESAlgorithm (..), VerKeyKES, encodeSigKES) import Cardano.Crypto.KES.Class qualified as KES import Cardano.Crypto.PinnedSizedBytes (PinnedSizedBytes, psbToByteString) import Cardano.Crypto.Seed (mkSeedFromBytes) import Cardano.KESAgent.KES.Crypto (Crypto (..)) +import Cardano.KESAgent.KES.Evolution qualified as KES import Cardano.KESAgent.KES.OCert (OCert (..)) import Cardano.KESAgent.KES.OCert qualified as KES import Cardano.KESAgent.Protocols.StandardCrypto (MockCrypto, StandardCrypto) @@ -60,36 +68,49 @@ tests :: TestTree tests = testGroup "DMQ.Protocol" [ testGroup "SigSubmission" - [ testGroup "mockcrypto" - [ testProperty "OCert" prop_codec_ocert_mockcrypto - , testProperty "Sig" prop_codec_sig_mockcrypto - , testProperty "codec" prop_codec_mockcrypto - , testProperty "codec id" prop_codec_id_mockcrypto - , testProperty "codec 2-splits" $ withMaxSize 20 - $ withMaxSuccess 20 - prop_codec_splits2_mockcrypto - , testProperty "codec 3-splits" $ withMaxSize 10 - $ withMaxSuccess 10 - prop_codec_splits3_mockcrypto - , testProperty "codec cbor" prop_codec_cbor_mockcrypto - , testProperty "codec valid cbor" prop_codec_valid_cbor_mockcrypto + [ testGroup "Codec" + [ testGroup "MockCrypto" + [ testProperty "OCert" prop_codec_ocert_mockcrypto + , testProperty "Sig" prop_codec_sig_mockcrypto + , testProperty "codec" prop_codec_mockcrypto + , testProperty "codec id" prop_codec_id_mockcrypto + , testProperty "codec 2-splits" $ withMaxSize 20 + $ withMaxSuccess 20 + prop_codec_splits2_mockcrypto + , testProperty "codec 3-splits" $ withMaxSize 10 + $ withMaxSuccess 10 + prop_codec_splits3_mockcrypto + , testProperty "codec cbor" prop_codec_cbor_mockcrypto + , testProperty "codec valid cbor" prop_codec_valid_cbor_mockcrypto + , testProperty "OCert" prop_codec_cbor_mockcrypto + ] + , testGroup "StandardCrypto" + [ testProperty "OCert" prop_codec_ocert_standardcrypto + , testProperty "Sig" prop_codec_sig_standardcrypto + , testProperty "codec" prop_codec_standardcrypto + , testProperty "codec id" prop_codec_id_standardcrypto + , testProperty "codec 2-splits" $ withMaxSize 20 + $ withMaxSuccess 20 + prop_codec_splits2_standardcrypto + -- StandardCrypt produces too large messages for this test to run: + {- + , testProperty "codec 3-splits" $ withMaxSize 10 + $ withMaxSuccess 10 + prop_codec_splits3_standardcrypto + -} + , testProperty "codec cbor" prop_codec_cbor_standardcrypto + , testProperty "codec valid cbor" prop_codec_valid_cbor_standardcrypto + ] + ] + ] + , testGroup "Crypto" + [ testGroup "MockCrypto" + [ testProperty "KES sign verify" prop_sign_verify_mockcrypto + , testProperty "validateSig" prop_validateSig_mockcrypto ] - , testGroup "standardcrypto" - [ testProperty "OCert" prop_codec_ocert_standardcrypto - , testProperty "Sig" prop_codec_sig_standardcrypto - , testProperty "codec" prop_codec_standardcrypto - , testProperty "codec id" prop_codec_id_standardcrypto - , testProperty "codec 2-splits" $ withMaxSize 20 - $ withMaxSuccess 20 - prop_codec_splits2_standardcrypto - -- StandardCrypt produces too large messages for this test to run: - {- - , testProperty "codec 3-splits" $ withMaxSize 10 - $ withMaxSuccess 10 - prop_codec_splits3_standardcrypto - -} - , testProperty "codec cbor" prop_codec_cbor_standardcrypto - , testProperty "codec valid cbor" prop_codec_valid_cbor_standardcrypto + , testGroup "StandardCrypto" + [ testProperty "KES sign verify" prop_sign_verify_standardcrypto + , testProperty "validateSig" prop_validateSig_standardcrypto ] ] ] @@ -112,31 +133,65 @@ instance Arbitrary POSIXTime where -- shrink via Word32 (e.g. in seconds) shrink posix = realToFrac <$> shrink (floor @_ @Word32 posix) -instance Arbitrary SigKESSignature where - arbitrary = SigKESSignature <$> arbitrary - shrink = map SigKESSignature . shrink . getSigKESSignature -mkVerKeyKES +-- | Make a KES key pair. +-- +mkKeysKES :: forall kesCrypto. KESAlgorithm kesCrypto => PinnedSizedBytes (SeedSizeKES kesCrypto) - -> IO (VerKeyKES kesCrypto) -mkVerKeyKES seed = do - withMLockedSeedFromPSB seed $ \mseed -> - KES.genKeyKES mseed >>= deriveVerKeyKES + -> IO (SignKeyKES kesCrypto, VerKeyKES kesCrypto) +mkKeysKES seed = + withMLockedSeedFromPSB seed $ \mseed -> do + snKESKey <- KES.genKeyKES mseed + (snKESKey,) <$> deriveVerKeyKES snKESKey +-- | The idea of this data type is to go around limitation of QuickCheck `Gen` +-- type, which does not allow IO actions. So instead we generate some random +-- context (e.g. key seed) and then the data is created when the property +-- runs. +-- +-- Keeping the `key` seprate allows to have access to it when shrinking, see +-- `shrinkWithConstr`, this is important when the signed data is shrinked and +-- we need to update a KES signature as well. +-- +-- However the limitation is shrinking: it requires `unsafePerformIO` anyway, +-- see `shrinkWithConstr`. +-- +-- TODO: to avoid complexity can we use `UnsoundPureKESAlgorithm` instead of +-- `KESAlgorithm`? +-- data WithConstr ctx key a = - WithConstr { constr :: key -> a, + WithConstr { constr :: key -> IO a, mkKey :: ctx -> IO key, ctx :: ctx } deriving instance Functor (WithConstr ctx key) +withConstrBind :: WithConstr ctx key a -> (a -> IO b) -> WithConstr ctx key b +withConstrBind WithConstr { constr, mkKey, ctx } fn = + WithConstr { constr = constr >=> fn, + mkKey, + ctx + } + +runWithConstr :: WithConstr ctx key a -> IO a +runWithConstr WithConstr { constr, mkKey, ctx } = mkKey ctx >>= constr + +constrWithKeys + :: (keys -> IO a) + -> WithConstr ctx keys keys + -> WithConstr ctx keys a +constrWithKeys f WithConstr { constr, mkKey, ctx } = + WithConstr { constr = constr >=> f, + mkKey, + ctx + } constWithConstr :: a -> WithConstr [ctx] [key] a constWithConstr a = - WithConstr { constr = const a, + WithConstr { constr = const (pure a), mkKey = \_ -> pure [], ctx = [] } @@ -147,16 +202,16 @@ listWithConstr :: forall ctx key a b. -> WithConstr [ctx] [key] b listWithConstr constr' as = WithConstr { - constr = \keys -> constr' (zipWith ($) constrs keys), - mkKey = \ctxs -> sequence (zipWith ($) mkKeys ctxs), + constr = \keys -> constr' <$> zipWithM ($) constrs keys, + mkKey = \ctxs -> zipWithM ($) mkKeys ctxs, ctx = ctx <$> as } where - constrs :: [(key -> a)] + constrs :: [key -> IO a] constrs = constr <$> as - mkKeys :: [(ctx -> IO key)] + mkKeys :: [ctx -> IO key] mkKeys = mkKey <$> as @@ -169,7 +224,7 @@ shrinkWithConstrCtx constr@WithConstr { ctx } = sequenceWithConstr - :: (a -> key -> a) + :: (a -> key -> IO a) -> WithConstr ctx key [a] -> IO [WithConstr ctx key a] sequenceWithConstr update constr@WithConstr { mkKey, ctx } = do @@ -181,33 +236,37 @@ sequenceWithConstr update constr@WithConstr { mkKey, ctx } = do -- unsafePerformIO :( shrinkWithConstr :: Arbitrary ctx - => (a -> key -> a) + => (a -> key -> IO a) -> (a -> [a]) - -> WithConstr ctx key a + -> WithConstr ctx key a -> [WithConstr ctx key a] shrinkWithConstr update shrinker constr = unsafePerformIO (sequenceWithConstr update $ shrinker <$> constr) ++ shrinkWithConstrCtx constr +shrinkWithConstr' :: Arbitrary ctx + => (a -> key -> a) + -> (a -> [a]) + -> WithConstr ctx key a + -> [WithConstr ctx key a] +shrinkWithConstr' update = shrinkWithConstr (\a k -> pure (update a k)) -runWithConstr :: WithConstr ctx key a -> IO a -runWithConstr WithConstr { constr, mkKey, ctx } = constr <$> mkKey ctx +type KESCTX size = PinnedSizedBytes size +type WithConstrKES size crypto a = WithConstr (KESCTX size) (SignKeyKES crypto, VerKeyKES crypto) a +type WithConstrKESList size crypto a = WithConstr [KESCTX size] [(SignKeyKES crypto, VerKeyKES crypto)] a -type VerKeyKESCTX size = PinnedSizedBytes size -type WithConstrVerKeyKES size crypto a = WithConstr (VerKeyKESCTX size) (VerKeyKES crypto) a -type WithConstrVerKeyKESList size crypto a = WithConstr [VerKeyKESCTX size] [VerKeyKES crypto] a -mkVerKeyKESConstr +mkKeysKESConstr :: forall kesCrypto. KESAlgorithm kesCrypto - => VerKeyKESCTX (SeedSizeKES kesCrypto) - -> WithConstrVerKeyKES (SeedSizeKES kesCrypto) - kesCrypto - (VerKeyKES kesCrypto) -mkVerKeyKESConstr ctx = - WithConstr { constr = id, - mkKey = mkVerKeyKES, + => KESCTX (SeedSizeKES kesCrypto) + -> WithConstrKES (SeedSizeKES kesCrypto) + kesCrypto + (SignKeyKES kesCrypto, VerKeyKES kesCrypto) +mkKeysKESConstr ctx = + WithConstr { constr = pure, + mkKey = mkKeysKES, ctx } @@ -215,71 +274,115 @@ instance ( size ~ SeedSizeKES kesCrypto , KnownNat size , KESAlgorithm kesCrypto ) - => Arbitrary (WithConstrVerKeyKES size kesCrypto (VerKeyKES kesCrypto)) where - arbitrary = mkVerKeyKESConstr <$> arbitrary + => Arbitrary (WithConstrKES size kesCrypto (SignKeyKES kesCrypto, VerKeyKES kesCrypto)) where + arbitrary = mkKeysKESConstr <$> arbitrary shrink = shrinkWithConstrCtx +-- | An auxiliary data type to hold KES keys along with an OCert, payload and +-- its KES signature. +data CryptoCtx crypto = CryptoCtx { + snKESKey :: SignKeyKES (KES crypto), + -- ^ signing KES key + vnKESKey :: VerKeyKES (KES crypto), + -- ^ verification KES key + coldKey :: SignKeyDSIGN (DSIGN crypto), + -- ^ signing cold key + ocert :: OCert crypto + -- ^ ocert + } + + instance ( Crypto crypto , DSIGN.Signable (DSIGN crypto) (KES.OCertSignable crypto) , DSIGN.ContextDSIGN (DSIGN crypto) ~ () + , ContextKES (KES crypto) ~ () , kesCrypto ~ KES crypto + , KESAlgorithm kesCrypto + , Signable kesCrypto ByteString , size ~ SeedSizeKES kesCrypto , KnownNat size ) - => Arbitrary (WithConstrVerKeyKES size kesCrypto (OCert crypto)) where + => Arbitrary (WithConstrKES size kesCrypto (CryptoCtx crypto)) where arbitrary = do - verKeyKES <- arbitrary + withKeys <- arbitrary n <- arbitrary seedColdKey :: PinnedSizedBytes (DSIGN.SeedSizeDSIGN (DSIGN crypto)) <- arbitrary - let !skCold = DSIGN.genKeyDSIGN (mkSeedFromBytes . psbToByteString $ seedColdKey) + let !coldKey = DSIGN.genKeyDSIGN (mkSeedFromBytes . psbToByteString $ seedColdKey) period <- KES.KESPeriod <$> arbitrary - return $ fmap (\vkKES -> KES.makeOCert vkKES n period skCold) verKeyKES - shrink = shrinkWithConstrCtx - - -instance ( kesCrypto ~ KES crypto - , size ~ SeedSizeKES kesCrypto - , KnownNat size - , Arbitrary (WithConstrVerKeyKES size kesCrypto (OCert crypto)) - ) - => Arbitrary (WithConstrVerKeyKES size kesCrypto (SigOpCertificate crypto)) where - arbitrary = fmap SigOpCertificate <$> arbitrary + return $ constrWithKeys + (\(snKESKey, vnKESKey) -> + return $ CryptoCtx { + snKESKey, + vnKESKey, + coldKey, + ocert = KES.makeOCert vnKESKey n period coldKey + }) + withKeys shrink = shrinkWithConstrCtx instance ( Crypto crypto , kesCrypto ~ KES crypto + , ContextKES kesCrypto ~ () , size ~ SeedSizeKES kesCrypto - , Arbitrary (WithConstrVerKeyKES size kesCrypto (OCert crypto)) + , Signable kesCrypto ByteString + , dsignCrypto ~ DSIGN crypto + , DSIGNAlgorithm dsignCrypto + , Arbitrary (WithConstrKES size kesCrypto (CryptoCtx crypto)) ) - => Arbitrary (WithConstrVerKeyKES size kesCrypto (SigRawWithSignedBytes crypto)) where + => Arbitrary (WithConstrKES size kesCrypto (SigRawWithSignedBytes crypto)) where arbitrary = do sigRawId <- arbitrary - sigRawBody <- arbitrary sigRawExpiresAt <- arbitrary - opCert <- arbitrary - sigRawKESPeriod <- arbitrary - sigRawKESSignature <- arbitrary - sigRawColdKey <- arbitrary - return $ fmap (\cert -> let sigRawOpCertificate = SigOpCertificate cert - sigRaw = SigRaw { - sigRawId, - sigRawBody, - sigRawKESPeriod, - sigRawOpCertificate, - sigRawColdKey, - sigRawExpiresAt, - sigRawKESSignature = undefined -- to be filled below - } - signedBytes = CBOR.toStrictByteString (encodeSigRaw' sigRaw) - in - SigRawWithSignedBytes { - sigRawSignedBytes = BL.fromStrict signedBytes, - sigRaw = sigRaw { sigRawKESSignature } - } - ) opCert + let maxKESOffset :: Word + maxKESOffset = totalPeriodsKES (Proxy :: Proxy kesCrypto) + -- offset since `ocertKESPeriod`, so that the signature is still valid + kesOffset <- arbitrary `suchThat` (< maxKESOffset) + payload <- arbitrary + crypto <- arbitrary + return $ withConstrBind crypto \CryptoCtx {ocert, coldKey, snKESKey} -> do + let sigRawOpCertificate :: SigOpCertificate crypto + sigRawOpCertificate = SigOpCertificate ocert + + sigRawBody :: SigBody + sigRawBody = SigBody payload + + sigRawColdKey :: SigColdKey crypto + sigRawColdKey = SigColdKey $ deriveVerKeyDSIGN coldKey + + sigRawKESPeriod :: KESPeriod + sigRawKESPeriod = KESPeriod $ unKESPeriod (ocertKESPeriod ocert) + + kesOffset + + sigRaw = SigRaw { + sigRawId, + sigRawBody, + sigRawKESPeriod, + sigRawOpCertificate, + sigRawColdKey, + sigRawExpiresAt, + sigRawKESSignature = undefined -- to be filled below + } + signedBytes = CBOR.toStrictByteString (encodeSigRaw' sigRaw) + + -- evolve the key to the target period + mbSnKESKey <- KES.updateKESTo () sigRawKESPeriod ocert (KES.SignKeyWithPeriodKES snKESKey 0) + case mbSnKESKey of + Just (KES.SignKeyWithPeriodKES snKESKey' _) -> do + -- signed bytes with the snKESKey' + sigRawKESSignature + <- SigKESSignature + <$> KES.signKES () kesOffset signedBytes snKESKey' + return SigRawWithSignedBytes { + sigRawSignedBytes = BL.fromStrict signedBytes, + sigRaw = sigRaw { sigRawKESSignature } + } + Nothing -> + error $ "arbitrary SigRawWithSignedBytes: could not evolve KES key to the target period by KES offset: " + ++ show kesOffset + shrink = shrinkWithConstrSigRawWithSignedBytes @@ -290,26 +393,38 @@ instance ( Crypto crypto -- shrinkWithConstrSigRawWithSignedBytes :: forall crypto. - Crypto crypto - => WithConstrVerKeyKES (SeedSizeKES (KES crypto)) (KES crypto) (SigRawWithSignedBytes crypto) - -> [WithConstrVerKeyKES (SeedSizeKES (KES crypto)) (KES crypto) (SigRawWithSignedBytes crypto)] + ( Crypto crypto + , ContextKES (KES crypto) ~ () + , Signable (KES crypto) ByteString + ) + => WithConstrKES (SeedSizeKES (KES crypto)) (KES crypto) (SigRawWithSignedBytes crypto) + -> [WithConstrKES (SeedSizeKES (KES crypto)) (KES crypto) (SigRawWithSignedBytes crypto)] shrinkWithConstrSigRawWithSignedBytes = shrinkWithConstr updateFn shrinkSigRawWithSignedBytesFn where updateFn :: SigRawWithSignedBytes crypto - -> VerKeyKES (KES crypto) - -> SigRawWithSignedBytes crypto + -> (SignKeyKES (KES crypto), VerKeyKES (KES crypto)) + -> IO (SigRawWithSignedBytes crypto) updateFn SigRawWithSignedBytes { - sigRaw = sigRaw@SigRaw { sigRawOpCertificate = SigOpCertificate ocert }, + sigRaw = sigRaw@SigRaw { sigRawOpCertificate = SigOpCertificate ocert, + sigRawKESPeriod + }, sigRawSignedBytes } - ocertVkHot - = + (snKeyKES, ocertVkHot) + = do let sigRaw' = sigRaw { sigRawOpCertificate = SigOpCertificate ocert { ocertVkHot } } - in SigRawWithSignedBytes { - sigRaw = sigRaw', + -- update KES key to sigRawKESPeriod + Just (KES.SignKeyWithPeriodKES snKeyKES' _) + <- KES.updateKESTo () sigRawKESPeriod ocert (KES.SignKeyWithPeriodKES snKeyKES 0) + -- sign the message + sign <- KES.signKES () (KES.unKESPeriod sigRawKESPeriod - KES.unKESPeriod (ocertKESPeriod ocert)) + (BL.toStrict sigRawSignedBytes) + snKeyKES' + pure $ SigRawWithSignedBytes { + sigRaw = sigRaw' { sigRawKESSignature = SigKESSignature sign }, sigRawSignedBytes } @@ -327,11 +442,18 @@ shrinkSigRawWithSignedBytesFn SigRawWithSignedBytes { sigRaw } = | sigRaw' <- shrinkSigRawFn sigRaw , let sigRawSignedBytes' = CBOR.toLazyByteString (encodeSigRaw' sigRaw') ] + + +-- | Pure shrinking function for `SigRaw`. It does not update the KES +-- signature. +-- shrinkSigRawFn :: SigRaw crypto -> [SigRaw crypto] shrinkSigRawFn sig@SigRaw { sigRawId, - sigRawBody, - sigRawExpiresAt - } = + sigRawBody, + sigRawKESPeriod, + sigRawExpiresAt, + sigRawOpCertificate = SigOpCertificate ocert + } = [ sig { sigRawId = sigRawId' } | sigRawId' <- shrink sigRawId ] @@ -340,23 +462,15 @@ shrinkSigRawFn sig@SigRaw { sigRawId, | sigRawBody' <- shrink sigRawBody ] ++ + [ sig { sigRawKESPeriod = sigRawKESPeriod' } + | sigRawKESPeriod' <- KESPeriod <$> shrink (unKESPeriod sigRawKESPeriod) + , sigRawKESPeriod' >= ocertKESPeriod ocert + ] + ++ [ sig { sigRawExpiresAt = sigRawExpiresAt' } | sigRawExpiresAt' <- shrink sigRawExpiresAt ] -instance Arbitrary SigColdKey where - arbitrary = SigColdKey <$> arbitrary - shrink = map SigColdKey . shrink . getSigColdKey - - -mkSigRawWithSignedBytes :: SigRaw crypto -> SigRawWithSignedBytes crypto -mkSigRawWithSignedBytes sigRaw = - SigRawWithSignedBytes { - sigRaw, - sigRawSignedBytes - } - where - sigRawSignedBytes = CBOR.toLazyByteString (encodeSigRaw' sigRaw) -- NOTE: this function is not exposed in the main library on purpose. We -- should never construct `Sig` by serialising `SigRaw`. @@ -369,7 +483,7 @@ mkSig sigRawWithSignedBytes@SigRawWithSignedBytes { sigRaw } = sigRawWithSignedBytes } where - sigRawBytes = CBOR.toLazyByteString (encodeSigRaw sigRaw) + sigRawBytes = CBOR.toLazyByteString (encodeSigRaw sigRaw) -- encode only signed part @@ -384,7 +498,7 @@ encodeSigRaw' SigRaw { = CBOR.encodeListLen 7 <> encodeSigId sigRawId <> CBOR.encodeBytes (getSigBody sigRawBody) - <> CBOR.encodeWord sigRawKESPeriod + <> CBOR.encodeWord (unKESPeriod sigRawKESPeriod) <> CBOR.encodeWord32 (floor sigRawExpiresAt) -- encode together with KES signature, OCert and cold key. @@ -393,41 +507,61 @@ encodeSigRaw :: Crypto crypto -> CBOR.Encoding encodeSigRaw sigRaw@SigRaw { sigRawKESSignature, sigRawOpCertificate, sigRawColdKey } = encodeSigRaw' sigRaw - <> CBOR.encodeBytes (getSigKESSignature sigRawKESSignature) + <> encodeSigKES (getSigKESSignature sigRawKESSignature) <> encodeSigOpCertificate sigRawOpCertificate - <> CBOR.encodeBytes (getSigColdKey sigRawColdKey) + <> encodeVerKeyDSIGN (getSigColdKey sigRawColdKey) - -shrinkSigFn :: forall crypto. Crypto crypto +-- note: KES signature is updated by updateSigFn +shrinkSigFn :: forall crypto. + ( Crypto crypto + ) => Sig crypto -> [Sig crypto] shrinkSigFn SigWithBytes {sigRawWithSignedBytes = SigRawWithSignedBytes { sigRaw, sigRawSignedBytes } } = mkSig . (\sigRaw' -> SigRawWithSignedBytes { sigRaw = sigRaw', sigRawSignedBytes }) <$> shrinkSigRawFn sigRaw + +updateSigFn :: forall crypto. + KESAlgorithm (KES crypto) + => ContextKES (KES crypto) ~ () + => Signable (KES crypto) ByteString + => Sig crypto + -> (SignKeyKES (KES crypto), VerKeyKES (KES crypto)) + -> IO (Sig crypto) +updateSigFn + sig@Sig { sigOpCertificate = SigOpCertificate opCert, + sigBody = SigBody body + } + (snKESKey, vnKESKey) + = do + signature <- KES.signKES () (KES.unKESPeriod (ocertKESPeriod opCert)) body snKESKey + return $ sig { sigOpCertificate = SigOpCertificate opCert { ocertVkHot = vnKESKey}, + sigKESSignature = SigKESSignature signature + } + + instance ( Crypto crypto , DSIGN.ContextDSIGN (DSIGN crypto) ~ () , DSIGN.Signable (DSIGN crypto) (KES.OCertSignable crypto) , kesCrypto ~ KES crypto + , ContextKES kesCrypto ~ () + , Signable kesCrypto ByteString , size ~ SeedSizeKES kesCrypto , KnownNat size ) - => Arbitrary (WithConstrVerKeyKES size kesCrypto (Sig crypto)) where + => Arbitrary (WithConstrKES size kesCrypto (Sig crypto)) where arbitrary = fmap mkSig <$> arbitrary shrink = shrinkWithConstr updateSigFn shrinkSigFn -updateSigFn :: Sig crypto -> VerKeyKES (KES crypto) -> Sig crypto -updateSigFn - sig@Sig {sigOpCertificate = SigOpCertificate opCert} - ocertVkHot - = - sig { sigOpCertificate = SigOpCertificate opCert { ocertVkHot } } - instance ( kesCrypto ~ KES crypto + , KESAlgorithm kesCrypto + , ContextKES kesCrypto ~ () + , Signable kesCrypto ByteString , size ~ SeedSizeKES kesCrypto , KnownNat size - , Arbitrary (WithConstrVerKeyKES size kesCrypto (Sig crypto)) + , Arbitrary (WithConstrKES size kesCrypto (Sig crypto)) ) - => Arbitrary (WithConstrVerKeyKESList size kesCrypto (AnyMessage (SigSubmission crypto))) where + => Arbitrary (WithConstrKESList size kesCrypto (AnyMessage (SigSubmission crypto))) where arbitrary = oneof [ pure . constWithConstr $ AnyMessage MsgInit , constWithConstr . AnyMessage <$> @@ -451,15 +585,19 @@ instance ( kesCrypto ~ KES crypto , constWithConstr . AnyMessage <$> MsgRequestTxs <$> arbitrary , listWithConstr (AnyMessage . MsgReplyTxs) - <$> (arbitrary :: Gen [WithConstrVerKeyKES size kesCrypto (Sig crypto)]) + <$> (arbitrary :: Gen [WithConstrKES size kesCrypto (Sig crypto)]) , constWithConstr . AnyMessage <$> pure MsgDone ] shrink = shrinkWithConstr updateFn shrinkFn where - updateFn :: AnyMessage (SigSubmission crypto) -> [VerKeyKES kesCrypto] -> AnyMessage (SigSubmission crypto) - updateFn (AnyMessage (MsgReplyTxs txs)) vkKeyKESs = AnyMessage (MsgReplyTxs (zipWith updateSigFn txs vkKeyKESs)) - updateFn msg _ = msg + updateFn :: AnyMessage (SigSubmission crypto) + -> [(SignKeyKES kesCrypto, VerKeyKES kesCrypto)] + -> IO (AnyMessage (SigSubmission crypto)) + updateFn (AnyMessage (MsgReplyTxs txs)) keys = do + sigs <- traverse (uncurry updateSigFn) (zip txs keys) + return $ AnyMessage (MsgReplyTxs sigs) + updateFn msg _ = pure msg shrinkFn :: AnyMessage (SigSubmission crypto) -> [AnyMessage (SigSubmission crypto)] shrinkFn = \case @@ -494,10 +632,10 @@ instance ( kesCrypto ~ KES crypto prop_codec_ocert :: forall crypto. Crypto crypto - => WithConstrVerKeyKES (SeedSizeKES (KES crypto)) (KES crypto) (OCert crypto) + => WithConstrKES (SeedSizeKES (KES crypto)) (KES crypto) (CryptoCtx crypto) -> Property prop_codec_ocert constr = ioProperty $ do - ocert <- runWithConstr constr + CryptoCtx { ocert } <- runWithConstr constr return . counterexample (show ocert) $ let encoded = CBOR.toLazyByteString (encodeSigOpCertificate (SigOpCertificate ocert)) in case CBOR.deserialiseFromBytes decodeSigOpCertificate encoded of @@ -507,12 +645,12 @@ prop_codec_ocert constr = ioProperty $ do .&&. BL.null bytes prop_codec_ocert_mockcrypto - :: Blind (WithConstrVerKeyKES (SeedSizeKES (KES MockCrypto)) (KES MockCrypto) (OCert MockCrypto)) + :: Blind (WithConstrKES (SeedSizeKES (KES MockCrypto)) (KES MockCrypto) (CryptoCtx MockCrypto)) -> Property prop_codec_ocert_mockcrypto = prop_codec_ocert . getBlind prop_codec_ocert_standardcrypto - :: Blind (WithConstrVerKeyKES (SeedSizeKES (KES StandardCrypto)) (KES StandardCrypto) (OCert StandardCrypto)) + :: Blind (WithConstrKES (SeedSizeKES (KES StandardCrypto)) (KES StandardCrypto) (CryptoCtx StandardCrypto)) -> Property prop_codec_ocert_standardcrypto = prop_codec_ocert . getBlind @@ -523,7 +661,7 @@ prop_codec_ocert_standardcrypto = prop_codec_ocert . getBlind -- * signed bytes match the encoding of `encodeSigRaw'`. prop_codec_sig :: forall crypto. Crypto crypto - => WithConstrVerKeyKES (SeedSizeKES (KES crypto)) (KES crypto) (Sig crypto) + => WithConstrKES (SeedSizeKES (KES crypto)) (KES crypto) (Sig crypto) -> Property prop_codec_sig constr = ioProperty $ do sig <- runWithConstr constr @@ -553,17 +691,17 @@ prop_codec_sig constr = ioProperty $ do .&&. BL.null leftovers prop_codec_sig_mockcrypto - :: Blind (WithConstrVerKeyKES (SeedSizeKES (KES MockCrypto)) (KES MockCrypto) (Sig MockCrypto)) + :: Blind (WithConstrKES (SeedSizeKES (KES MockCrypto)) (KES MockCrypto) (Sig MockCrypto)) -> Property prop_codec_sig_mockcrypto = prop_codec_sig . getBlind prop_codec_sig_standardcrypto - :: Blind (WithConstrVerKeyKES (SeedSizeKES (KES MockCrypto)) (KES MockCrypto) (Sig MockCrypto)) + :: Blind (WithConstrKES (SeedSizeKES (KES MockCrypto)) (KES MockCrypto) (Sig MockCrypto)) -> Property prop_codec_sig_standardcrypto = prop_codec_sig . getBlind -type AnySigMessage crypto = WithConstrVerKeyKESList (SeedSizeKES (KES crypto)) (KES crypto) (AnyMessage (SigSubmission crypto)) +type AnySigMessage crypto = WithConstrKESList (SeedSizeKES (KES crypto)) (KES crypto) (AnyMessage (SigSubmission crypto)) prop_codec :: forall crypto. Crypto crypto @@ -669,3 +807,74 @@ prop_codec_valid_cbor_standardcrypto :: Blind (AnySigMessage StandardCrypto) -> Property prop_codec_valid_cbor_standardcrypto = prop_codec_valid_cbor . getBlind + + +-- | Check that the KES signature is valid. +-- +prop_validateSig + :: ( Crypto crypto + , DSIGN.ContextDSIGN (DSIGN crypto) ~ () + , DSIGN.Signable (DSIGN crypto) (KES.OCertSignable crypto) + , KES.ContextKES (KES crypto) ~ () + , KES.Signable (KES crypto) ByteString + ) + => WithConstrKES size kesCrypt (Sig crypto) + -> Property +prop_validateSig constr = ioProperty $ do + sig <- runWithConstr constr + return $ case validateSig sig of + Left err -> counterexample ("KES seed: " ++ show (ctx constr)) + . counterexample ("KES vk key: " ++ show (ocertVkHot . getSigOpCertificate . sigOpCertificate $ sig)) + . counterexample (show sig) + . counterexample (show err) + $ False + Right () -> property True + +prop_validateSig_mockcrypto + :: Blind (WithConstrKES (SeedSizeKES (KES MockCrypto)) (KES MockCrypto) (Sig MockCrypto)) + -> Property +prop_validateSig_mockcrypto = prop_validateSig . getBlind + +-- TODO: FAILS, why? +prop_validateSig_standardcrypto + :: Blind (WithConstrKES (SeedSizeKES (KES StandardCrypto)) (KES StandardCrypto) (Sig StandardCrypto)) + -> Property +prop_validateSig_standardcrypto = prop_validateSig . getBlind + + +-- | Sign & verify a payload with KES keys. +-- +prop_sign_verify + :: ( Crypto crypto + , ContextKES (KES crypto) ~ () + , Signable (KES crypto) ByteString + ) + => WithConstrKES (SeedSizeKES (KES crypto)) (KES crypto) (CryptoCtx crypto) + -- ^ KES keys + -> ByteString + -- ^ payload + -> Property +prop_sign_verify constr payload = ioProperty $ do + CryptoCtx { snKESKey, vnKESKey } <- runWithConstr constr + signed <- KES.signKES () 0 payload snKESKey + let res = KES.verifyKES () vnKESKey 0 payload signed + return $ counterexample "KES signature does not verify" + $ case res of + Left err -> counterexample (show err) + . counterexample ("vnKESKey: " ++ show vnKESKey) + . counterexample ("signature: " ++ show signed) + $ False + Right () -> property True + + +prop_sign_verify_mockcrypto + :: Blind (WithConstrKES (SeedSizeKES (KES MockCrypto)) (KES MockCrypto) (CryptoCtx MockCrypto)) + -> ByteString + -> Property +prop_sign_verify_mockcrypto = prop_sign_verify . getBlind + +prop_sign_verify_standardcrypto + :: Blind (WithConstrKES (SeedSizeKES (KES StandardCrypto)) (KES StandardCrypto) (CryptoCtx StandardCrypto)) + -> ByteString + -> Property +prop_sign_verify_standardcrypto = prop_sign_verify . getBlind From d32270c46e87159743d255515940c74ff0494484 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20W=C3=B3jtowicz?= Date: Mon, 25 Aug 2025 12:41:03 +0200 Subject: [PATCH 09/16] diffusion: export withiomanager --- ouroboros-network/src/Ouroboros/Network/Diffusion.hs | 1 + 1 file changed, 1 insertion(+) diff --git a/ouroboros-network/src/Ouroboros/Network/Diffusion.hs b/ouroboros-network/src/Ouroboros/Network/Diffusion.hs index dd46c7761a..98057e4271 100644 --- a/ouroboros-network/src/Ouroboros/Network/Diffusion.hs +++ b/ouroboros-network/src/Ouroboros/Network/Diffusion.hs @@ -17,6 +17,7 @@ module Ouroboros.Network.Diffusion , runM , mkInterfaces , socketAddressType + , withIOManager , module Ouroboros.Network.Diffusion.Types ) where From 277a77ececfae2df099481fbea9c754dceb0b9d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20W=C3=B3jtowicz?= Date: Mon, 25 Aug 2025 12:42:53 +0200 Subject: [PATCH 10/16] dmq: node kernel to hold pool ids --- dmq-node/src/DMQ/Diffusion/NodeKernel.hs | 45 +++++++++++++++++++++--- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/dmq-node/src/DMQ/Diffusion/NodeKernel.hs b/dmq-node/src/DMQ/Diffusion/NodeKernel.hs index b5e479e334..dbd0f75900 100644 --- a/dmq-node/src/DMQ/Diffusion/NodeKernel.hs +++ b/dmq-node/src/DMQ/Diffusion/NodeKernel.hs @@ -1,9 +1,10 @@ -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE DataKinds #-} module DMQ.Diffusion.NodeKernel ( NodeKernel (..) , withNodeKernel + , PoolValidationCtx (..) + , StakePools (..) ) where import Control.Concurrent.Class.MonadMVar @@ -12,18 +13,23 @@ import Control.Monad.Class.MonadAsync import Control.Monad.Class.MonadThrow import Control.Monad.Class.MonadTime.SI import Control.Monad.Class.MonadTimer.SI +import Control.Tracer import Data.Function (on) -import Data.Set (Set) -import Data.Set qualified as Set +import Data.Map.Strict (Map) +import Data.Map.Strict qualified as Map import Data.Sequence (Seq) import Data.Sequence qualified as Seq +import Data.Set (Set) +import Data.Set qualified as Set import Data.Time.Clock.POSIX (POSIXTime) import Data.Time.Clock.POSIX qualified as Time import Data.Void (Void) import System.Random (StdGen) import System.Random qualified as Random +import Cardano.Ledger.Shelley.API +import Ouroboros.Consensus.Shelley.Ledger.Query import Ouroboros.Network.BlockFetch (FetchClientRegistry, newFetchClientRegistry) import Ouroboros.Network.ConnectionId (ConnectionId (..)) @@ -53,10 +59,30 @@ data NodeKernel crypto ntnAddr m = , sigChannelVar :: !(TxChannelsVar m ntnAddr SigId (Sig crypto)) , sigMempoolSem :: !(TxMempoolSem m) , sigSharedTxStateVar :: !(SharedTxStateVar m ntnAddr SigId (Sig crypto)) + , stakePools :: !(StakePools m) + , nextEpochVar :: !(StrictTVar m (Maybe UTCTime)) } +-- | Cardano pool id's are hashes of the cold verification key +-- +type PoolId = KeyHash StakePool + +data StakePools m = StakePools { + -- | contains map of cardano pool stake snapshot obtained + -- via local state query client + stakePoolsVar :: StrictTVar m (Map PoolId StakeSnapshot) + -- | acquires validation context for signature validation + , poolValidationCtx :: m PoolValidationCtx + } + +data PoolValidationCtx = + DMQPoolValidationCtx !UTCTime -- ^ time of context acquisition + !(Maybe UTCTime) -- ^ UTC time of next epoch boundary + !(Map PoolId StakeSnapshot) -- ^ for signature validation + newNodeKernel :: ( MonadLabelledSTM m , MonadMVar m + , MonadTime m , Ord ntnAddr ) => StdGen @@ -72,6 +98,15 @@ newNodeKernel rng = do sigMempoolSem <- newTxMempoolSem let (rng', rng'') = Random.split rng sigSharedTxStateVar <- newSharedTxStateVar rng' + nextEpochVar <- newTVarIO Nothing + stakePoolsVar <- newTVarIO Map.empty + let poolValidationCtx = do + (nextEpochBoundary, stakePools) <- + atomically $ (,) <$> readTVar nextEpochVar <*> readTVar stakePoolsVar + now <- getCurrentTime + return $ DMQPoolValidationCtx now nextEpochBoundary stakePools + + stakePools = StakePools { stakePoolsVar, poolValidationCtx } peerSharingAPI <- newPeerSharingAPI @@ -87,6 +122,8 @@ newNodeKernel rng = do , sigChannelVar , sigMempoolSem , sigSharedTxStateVar + , nextEpochVar + , stakePools } From a23fc411f9986618e8b87f8cf5c26599a6732f0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20W=C3=B3jtowicz?= Date: Wed, 3 Sep 2025 14:15:08 +0200 Subject: [PATCH 11/16] dmq: add cardano-node socket path to configuration & cli options --- dmq-node/app/Main.hs | 3 ++- dmq-node/src/DMQ/Configuration.hs | 6 ++++++ dmq-node/src/DMQ/Configuration/CLIOptions.hs | 14 ++++++++++---- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/dmq-node/app/Main.hs b/dmq-node/app/Main.hs index 44c6612779..7f964afb3d 100644 --- a/dmq-node/app/Main.hs +++ b/dmq-node/app/Main.hs @@ -56,7 +56,8 @@ runDMQ commandLineConfig = do dmqcPrettyLog = I prettyLog, dmqcTopologyFile = I topologyFile, dmqcHandshakeTracer = I handshakeTracer, - dmqcLocalHandshakeTracer = I localHandshakeTracer + dmqcLocalHandshakeTracer = I localHandshakeTracer, + dmqcCardanoNodeSocket = I snocketPath } = config' <> commandLineConfig `act` defaultConfiguration diff --git a/dmq-node/src/DMQ/Configuration.hs b/dmq-node/src/DMQ/Configuration.hs index fccab50be1..bda53a2d26 100644 --- a/dmq-node/src/DMQ/Configuration.hs +++ b/dmq-node/src/DMQ/Configuration.hs @@ -100,6 +100,7 @@ data Configuration' f = dmqcChurnInterval :: f DiffTime, dmqcPeerSharing :: f PeerSharing, dmqcNetworkMagic :: f NetworkMagic, + dmqcCardanoNodeSocket :: f FilePath, dmqcPrettyLog :: f Bool, dmqcMuxTracer :: f Bool, @@ -196,6 +197,7 @@ defaultConfiguration = Configuration { dmqcTopologyFile = I "dmq.topology.json", dmqcAcceptedConnectionsLimit = I defaultAcceptedConnectionsLimit, dmqcDiffusionMode = I InitiatorAndResponderDiffusionMode, + dmqcCardanoNodeSocket = I "cardano-node.socket", dmqcTargetOfRootPeers = I targetNumberOfRootPeers, dmqcTargetOfKnownPeers = I targetNumberOfKnownPeers, dmqcTargetOfEstablishedPeers = I targetNumberOfEstablishedPeers, @@ -271,6 +273,7 @@ instance FromJSON PartialConfig where dmqcNetworkMagic <- Last . fmap NetworkMagic <$> v .:? "NetworkMagic" dmqcDiffusionMode <- Last <$> v .:? "DiffusionMode" dmqcPeerSharing <- Last <$> v .:? "PeerSharing" + dmqcCardanoNodeSocket <- Last <$> v .:? "CardanoNodeSocket" dmqcTargetOfRootPeers <- Last <$> v .:? "TargetNumberOfRootPeers" dmqcTargetOfKnownPeers <- Last <$> v .:? "TargetNumberOfKnownPeers" @@ -324,6 +327,7 @@ instance FromJSON PartialConfig where Configuration { dmqcIPv4 = Last dmqcIPv4 , dmqcIPv6 = Last dmqcIPv6 + , dmqcCardanoNodeSocket , dmqcPortNumber , dmqcConfigFile = mempty , dmqcTopologyFile = mempty @@ -384,6 +388,7 @@ instance ToJSON Configuration where dmqcIPv6, dmqcPortNumber, dmqcConfigFile, + dmqcCardanoNodeSocket, dmqcTopologyFile, dmqcAcceptedConnectionsLimit, dmqcDiffusionMode, @@ -438,6 +443,7 @@ instance ToJSON Configuration where , "IPv6" .= (show <$> unI dmqcIPv6) , "PortNumber" .= unI dmqcPortNumber , "ConfigFile" .= unI dmqcConfigFile + , "CardanoNodeSocket" .= unI dmqcCardanoNodeSocket , "TopologyFile" .= unI dmqcTopologyFile , "AcceptedConnectionsLimit" .= unI dmqcAcceptedConnectionsLimit , "DiffusionMode" .= unI dmqcDiffusionMode diff --git a/dmq-node/src/DMQ/Configuration/CLIOptions.hs b/dmq-node/src/DMQ/Configuration/CLIOptions.hs index b9a5989170..bf0cbcdd32 100644 --- a/dmq-node/src/DMQ/Configuration/CLIOptions.hs +++ b/dmq-node/src/DMQ/Configuration/CLIOptions.hs @@ -46,13 +46,19 @@ parseCLIOptions = <> help "Topology file for DMQ Node" ) ) + <*> optional ( + strOption + ( long "cardano-node-socket" + <> metavar "Cardano node socket path" + <> help "Used for local connections to Cardano node" + ) + ) where - mkConfiguration ipv4 ipv6 portNumber configFile topologyFile = + mkConfiguration ipv4 ipv6 portNumber configFile topologyFile cardanoNodeSocket = mempty { dmqcIPv4 = Last (Just <$> ipv4), dmqcIPv6 = Last (Just <$> ipv6), dmqcPortNumber = Last portNumber, dmqcConfigFile = Last configFile, - dmqcTopologyFile = Last topologyFile + dmqcTopologyFile = Last topologyFile, + dmqcCardanoNodeSocket = Last cardanoNodeSocket } - - From 664a68cc5470e2011a90aec424e759210af84f75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20W=C3=B3jtowicz?= Date: Tue, 26 Aug 2025 14:55:59 +0200 Subject: [PATCH 12/16] dmq: local state query client for cardano-node interop * cabal.project: --- dmq-node/app/Main.hs | 104 ++++++----- dmq-node/dmq-node.cabal | 13 +- dmq-node/src/DMQ/Diffusion/NodeKernel.hs | 15 +- .../DMQ/NodeToClient/LocalStateQueryClient.hs | 163 ++++++++++++++++++ 4 files changed, 240 insertions(+), 55 deletions(-) create mode 100644 dmq-node/src/DMQ/NodeToClient/LocalStateQueryClient.hs diff --git a/dmq-node/app/Main.hs b/dmq-node/app/Main.hs index 7f964afb3d..a4c209e65c 100644 --- a/dmq-node/app/Main.hs +++ b/dmq-node/app/Main.hs @@ -1,9 +1,11 @@ -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE DisambiguateRecordFields #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} module Main where import Control.Monad (void) +import Control.Monad.Class.MonadAsync import Control.Tracer (Tracer (..), nullTracer, traceWith) import Data.Act @@ -20,7 +22,7 @@ import DMQ.Configuration.CLIOptions (parseCLIOptions) import DMQ.Configuration.Topology (readTopologyFileOrError) import DMQ.Diffusion.Applications (diffusionApplications) import DMQ.Diffusion.Arguments -import DMQ.Diffusion.NodeKernel (mempool, withNodeKernel) +import DMQ.Diffusion.NodeKernel import DMQ.NodeToClient qualified as NtC import DMQ.NodeToNode (dmqCodecs, dmqLimitsAndTimeouts, ntnApps) import DMQ.Protocol.LocalMsgSubmission.Codec @@ -28,11 +30,14 @@ import DMQ.Protocol.SigSubmission.Type (Sig (..)) import DMQ.Tracer import DMQ.Diffusion.PeerSelection (policy) +import DMQ.NodeToClient.LocalStateQueryClient import Ouroboros.Network.Diffusion qualified as Diffusion import Ouroboros.Network.PeerSelection.PeerSharing.Codec (decodeRemoteAddress, encodeRemoteAddress) +import Ouroboros.Network.Snocket import Ouroboros.Network.TxSubmission.Mempool.Simple qualified as Mempool + main :: IO () main = void . runDMQ =<< execParser opts where @@ -70,48 +75,53 @@ runDMQ commandLineConfig = do stdGen <- newStdGen let (psRng, policyRng) = split stdGen + diffusionTracers = dmqDiffusionTracers dmqConfig tracer + + Diffusion.withIOManager \iocp -> do + let localSnocket' = localSnocket iocp + + withNodeKernel @StandardCrypto psRng $ \nodeKernel -> do + dmqDiffusionConfiguration <- mkDiffusionConfiguration dmqConfig nt + + let stakePoolMonitor = connectToCardanoNode tracer localSnocket' snocketPath nodeKernel + + withAsync stakePoolMonitor \aid -> do + link aid + let dmqNtNApps = + ntnApps tracer + dmqConfig + nodeKernel + (dmqCodecs + -- TODO: `maxBound :: Cardano.Network.NodeToNode.NodeToNodeVersion` + -- is unsafe here! + (encodeRemoteAddress maxBound) + (decodeRemoteAddress maxBound)) + dmqLimitsAndTimeouts + defaultSigDecisionPolicy + dmqNtCApps = + let sigSize _ = 0 -- TODO + maxMsgs = 1000 -- TODO: make this negotiated in the handshake? + mempoolReader = Mempool.getReader sigId sigSize (mempool nodeKernel) + mempoolWriter = Mempool.getWriter sigId (const ()) (\_ _ -> pure True) (mempool nodeKernel) + in NtC.ntcApps mempoolReader mempoolWriter maxMsgs + (NtC.dmqCodecs encodeReject decodeReject) + dmqDiffusionArguments = + diffusionArguments (if handshakeTracer + then WithEventType "Handshake" >$< tracer + else nullTracer) + (if localHandshakeTracer + then WithEventType "Handshake" >$< tracer + else nullTracer) + dmqDiffusionApplications = + diffusionApplications nodeKernel + dmqConfig + dmqDiffusionConfiguration + dmqLimitsAndTimeouts + dmqNtNApps + dmqNtCApps + (policy policyRng) - withNodeKernel @StandardCrypto psRng $ \nodeKernel -> do - dmqDiffusionConfiguration <- mkDiffusionConfiguration dmqConfig nt - - let dmqNtNApps = - ntnApps tracer - dmqConfig - nodeKernel - (dmqCodecs - -- TODO: `maxBound :: Cardano.Network.NodeToNode.NodeToNodeVersion` - -- is unsafe here! - (encodeRemoteAddress maxBound) - (decodeRemoteAddress maxBound)) - dmqLimitsAndTimeouts - defaultSigDecisionPolicy - dmqNtCApps = - let sigSize _ = 0 -- TODO - maxMsgs = 1000 -- TODO: make this dynamic? - mempoolReader = Mempool.getReader sigId sigSize (mempool nodeKernel) - mempoolWriter = Mempool.getWriter sigId (pure ()) - (\_ _ -> Right () :: Either Void ()) - (\_ -> True) - (mempool nodeKernel) - in NtC.ntcApps mempoolReader mempoolWriter maxMsgs - (NtC.dmqCodecs encodeReject decodeReject) - dmqDiffusionArguments = - diffusionArguments (if handshakeTracer - then WithEventType "Handshake" >$< tracer - else nullTracer) - (if localHandshakeTracer - then WithEventType "Handshake" >$< tracer - else nullTracer) - dmqDiffusionApplications = - diffusionApplications nodeKernel - dmqConfig - dmqDiffusionConfiguration - dmqLimitsAndTimeouts - dmqNtNApps - dmqNtCApps - (policy policyRng) - - Diffusion.run dmqDiffusionArguments - (dmqDiffusionTracers dmqConfig tracer) - dmqDiffusionConfiguration - dmqDiffusionApplications + Diffusion.run dmqDiffusionArguments + diffusionTracers + dmqDiffusionConfiguration + dmqDiffusionApplications diff --git a/dmq-node/dmq-node.cabal b/dmq-node/dmq-node.cabal index a40f90253d..86199f79c1 100644 --- a/dmq-node/dmq-node.cabal +++ b/dmq-node/dmq-node.cabal @@ -56,6 +56,7 @@ library DMQ.NodeToClient DMQ.NodeToClient.LocalMsgNotification DMQ.NodeToClient.LocalMsgSubmission + DMQ.NodeToClient.LocalStateQueryClient DMQ.NodeToClient.Version DMQ.NodeToNode DMQ.NodeToNode.Version @@ -81,6 +82,10 @@ library bytestring >=0.10 && <0.13, cardano-binary, cardano-crypto-class, + cardano-crypto-wrapper, + cardano-ledger-byron, + cardano-ledger-shelley, + cardano-slotting, cborg >=0.2.1 && <0.3, containers >=0.5 && <0.8, contra-tracer >=0.1 && <0.3, @@ -96,7 +101,10 @@ library network ^>=3.2.7, network-mux ^>=0.9.1, optparse-applicative ^>=0.18, - ouroboros-network:{ouroboros-network, orphan-instances} ^>=0.23, + ouroboros-consensus, + ouroboros-consensus-cardano, + ouroboros-consensus-diffusion, + ouroboros-network:{cardano-diffusion, ouroboros-network, orphan-instances} ^>=0.23, ouroboros-network-api ^>=0.17, ouroboros-network-framework ^>=0.20, ouroboros-network-protocols ^>=0.16, @@ -104,6 +112,7 @@ library singletons, text >=1.2.4 && <2.2, time ^>=1.12, + transformers, typed-protocols:{typed-protocols, cborg} ^>=1.1, hs-source-dirs: src @@ -125,10 +134,12 @@ executable dmq-node base, contra-tracer >=0.1 && <0.3, dmq-node, + io-classes, kes-agent-crypto, optparse-applicative, ouroboros-network, ouroboros-network-api, + ouroboros-network-framework, random, hs-source-dirs: app diff --git a/dmq-node/src/DMQ/Diffusion/NodeKernel.hs b/dmq-node/src/DMQ/Diffusion/NodeKernel.hs index dbd0f75900..b964bca09b 100644 --- a/dmq-node/src/DMQ/Diffusion/NodeKernel.hs +++ b/dmq-node/src/DMQ/Diffusion/NodeKernel.hs @@ -13,7 +13,6 @@ import Control.Monad.Class.MonadAsync import Control.Monad.Class.MonadThrow import Control.Monad.Class.MonadTime.SI import Control.Monad.Class.MonadTimer.SI -import Control.Tracer import Data.Function (on) import Data.Map.Strict (Map) @@ -101,10 +100,10 @@ newNodeKernel rng = do nextEpochVar <- newTVarIO Nothing stakePoolsVar <- newTVarIO Map.empty let poolValidationCtx = do - (nextEpochBoundary, stakePools) <- + (nextEpochBoundary, stakePools') <- atomically $ (,) <$> readTVar nextEpochVar <*> readTVar stakePoolsVar now <- getCurrentTime - return $ DMQPoolValidationCtx now nextEpochBoundary stakePools + return $ DMQPoolValidationCtx now nextEpochBoundary stakePools' stakePools = StakePools { stakePoolsVar, poolValidationCtx } @@ -138,15 +137,17 @@ withNodeKernel :: forall crypto ntnAddr m a. , Ord ntnAddr ) => StdGen + -> (NodeKernel crypto ntnAddr m -> m (Either SomeException Void)) -> (NodeKernel crypto ntnAddr m -> m a) -- ^ as soon as the callback exits the `mempoolWorker` will be -- killed -> m a -withNodeKernel rng k = do +withNodeKernel rng mkStakePoolMonitor k = do nodeKernel@NodeKernel { mempool } <- newNodeKernel rng - withAsync (mempoolWorker mempool) - $ \thread -> link thread - >> k nodeKernel + withAsync (mempoolWorker mempool) \workerAid -> do + link workerAid + withAsync (mkStakePoolMonitor nodeKernel) \spmAid -> + link spmAid >> k nodeKernel mempoolWorker :: forall crypto m. diff --git a/dmq-node/src/DMQ/NodeToClient/LocalStateQueryClient.hs b/dmq-node/src/DMQ/NodeToClient/LocalStateQueryClient.hs new file mode 100644 index 0000000000..4381c16152 --- /dev/null +++ b/dmq-node/src/DMQ/NodeToClient/LocalStateQueryClient.hs @@ -0,0 +1,163 @@ +{-# LANGUAGE DisambiguateRecordFields #-} +{-# LANGUAGE TypeOperators #-} + +module DMQ.NodeToClient.LocalStateQueryClient + ( cardanoClient + , connectToCardanoNode + ) where + +import Control.Concurrent.Class.MonadSTM.Strict +import Control.Monad.Class.MonadThrow +import Control.Monad.Class.MonadTime.SI +import Control.Monad.Class.MonadTimer.SI +import Control.Monad.Trans.Except +import Control.Tracer (Tracer (..), nullTracer) +import Data.Functor.Contravariant ((>$<)) +import Data.Map.Strict qualified as Map +import Data.Proxy +import Data.Void + +import Cardano.Chain.Genesis +import Cardano.Chain.Slotting +import Cardano.Crypto.ProtocolMagic +import Cardano.Network.NodeToClient +import Cardano.Slotting.EpochInfo.API +import Cardano.Slotting.Time +import DMQ.Diffusion.NodeKernel +import DMQ.Tracer +import Ouroboros.Consensus.Cardano.Block +import Ouroboros.Consensus.Cardano.Node +import Ouroboros.Consensus.HardFork.Combinator.Ledger.Query +import Ouroboros.Consensus.HardFork.History.EpochInfo +import Ouroboros.Consensus.Ledger.Query +import Ouroboros.Consensus.Network.NodeToClient +import Ouroboros.Consensus.Node.NetworkProtocolVersion +import Ouroboros.Consensus.Node.ProtocolInfo +import Ouroboros.Consensus.Shelley.Ledger.Query +import Ouroboros.Consensus.Shelley.Ledger.SupportsProtocol () +import Ouroboros.Network.Block +import Ouroboros.Network.Magic +import Ouroboros.Network.Mux qualified as Mx +import Ouroboros.Network.Protocol.LocalStateQuery.Client +import Ouroboros.Network.Protocol.LocalStateQuery.Type +import Ouroboros.Network.Socket + +-- TODO generalize to handle ledger eras other than Conway +-- | connects the dmq node to cardano node via local state query +-- and updates the node kernel with stake pool data necessary to perform message +-- validation +cardanoClient + :: forall block query point crypto m. (MonadDelay m, MonadSTM m, MonadThrow m, MonadTime m) + => (block ~ CardanoBlock crypto, query ~ Query block, point ~ Point block) + => Tracer m String -- TODO: replace string with a proper type + -> StakePools m + -> StrictTVar m (Maybe UTCTime) -- ^ from node kernel + -> LocalStateQueryClient (CardanoBlock crypto) (Point block) (Query block) m Void +cardanoClient _tracer StakePools { stakePoolsVar } nextEpochVar = LocalStateQueryClient (idle Nothing) + where + idle mSystemStart = pure $ SendMsgAcquire ImmutableTip acquire + where + acquire :: ClientStAcquiring block point query m Void + acquire = ClientStAcquiring { + recvMsgAcquired = + let epochQry systemStart = pure $ + SendMsgQuery (BlockQuery . QueryIfCurrentConway $ GetEpochNo) + $ wrappingMismatch (handleEpoch systemStart) + in case mSystemStart of + Just systemStart -> epochQry systemStart + Nothing -> pure $ + SendMsgQuery GetSystemStart $ ClientStQuerying epochQry + + , recvMsgFailure = \failure -> + throwIO . userError $ "recvMsgFailure: " <> show failure + } + + wrappingMismatch k = ClientStQuerying $ + either (const . throwIO . userError $ "mismatch era info") k + + getInterpreter systemStart epoch = ClientStQuerying \interpreter -> do + let ei = interpreterToEpochInfo interpreter + res = + runExcept do + lastSlot <- snd <$> epochInfoRange ei epoch + lastSlotTime <- epochInfoSlotToRelativeTime ei lastSlot + lastSlotLength <- epochInfoSlotLength ei lastSlot + pure $ addRelativeTime (getSlotLength lastSlotLength) lastSlotTime + + case res of + Left _err -> pure $ SendMsgRelease do + threadDelay 86400 -- TODO fuzz this? + idle $ Just systemStart + Right relativeTime -> do + now <- getCurrentTime + let nextEpoch = fromRelativeTime systemStart relativeTime + toNextEpoch = diffUTCTime nextEpoch now + if toNextEpoch < 5 then + pure $ SendMsgRelease do + threadDelay $ realToFrac toNextEpoch + idle $ Just systemStart + else pure $ + SendMsgQuery (BlockQuery . QueryIfCurrentConway $ GetStakeSnapshots Nothing) + $ wrappingMismatch (handleStakeSnapshots systemStart nextEpoch toNextEpoch) + + handleEpoch systemStart epoch = pure + . SendMsgQuery (BlockQuery . QueryHardFork $ GetInterpreter) + $ getInterpreter systemStart epoch + + handleStakeSnapshots systemStart nextEpoch toNextEpoch StakeSnapshots { ssStakeSnapshots } = do + atomically do + writeTVar stakePoolsVar ssStakeSnapshots + writeTVar nextEpochVar $ Just nextEpoch + pure $ SendMsgRelease do + threadDelay $ min (realToFrac toNextEpoch) 86400 -- TODO fuzz this? + idle $ Just systemStart + +connectToCardanoNode :: Tracer IO (WithEventType String) + -> LocalSnocket + -> FilePath + -> NodeKernel crypto ntnAddr IO + -> IO (Either SomeException Void) +connectToCardanoNode tracer localSnocket' snocketPath nodeKernel = + connectTo + localSnocket' + debuggingNetworkConnectTracers --nullNetworkConnectTracers + (combineVersions + [ simpleSingletonVersions + version + NodeToClientVersionData { + networkMagic = + NetworkMagic -- 2 {- preview net -} + . unProtocolMagicId + $ mainnetProtocolMagicId + , query = False + } + \_version -> + Mx.OuroborosApplication + [ Mx.MiniProtocol + { miniProtocolNum = Mx.MiniProtocolNum 7 + , miniProtocolStart = Mx.StartEagerly + , miniProtocolLimits = + Mx.MiniProtocolLimits + { maximumIngressQueue = 0xffffffff + } + , miniProtocolRun = + Mx.InitiatorProtocolOnly + . Mx.mkMiniProtocolCbFromPeerSt + . const + $ ( nullTracer + , cStateQueryCodec + , StateIdle + , localStateQueryClientPeer + $ cardanoClient (WithEventType "LocalStateQuery" >$< tracer) + (stakePools nodeKernel) + (nextEpochVar nodeKernel) + ) + } + ] + | version <- [minBound..maxBound] + , let supportedVersionMap = supportedNodeToClientVersions (Proxy :: Proxy (CardanoBlock StandardCrypto)) + blk = supportedVersionMap Map.! version + Codecs {cStateQueryCodec} = + clientCodecs (pClientInfoCodecConfig . protocolClientInfoCardano $ EpochSlots 21600) blk version + ]) + snocketPath From 7767396ced684c3cd6ac0ac5ee54f6603f6f5c3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20W=C3=B3jtowicz?= Date: Fri, 12 Sep 2025 15:12:38 +0200 Subject: [PATCH 13/16] sig validation --- dmq-node/dmq-node.cabal | 3 + .../src/DMQ/Protocol/SigSubmission/Type.hs | 84 +-------- .../DMQ/Protocol/SigSubmission/Validate.hs | 176 ++++++++++++++++++ 3 files changed, 183 insertions(+), 80 deletions(-) create mode 100644 dmq-node/src/DMQ/Protocol/SigSubmission/Validate.hs diff --git a/dmq-node/dmq-node.cabal b/dmq-node/dmq-node.cabal index 86199f79c1..fad40d53b6 100644 --- a/dmq-node/dmq-node.cabal +++ b/dmq-node/dmq-node.cabal @@ -71,6 +71,7 @@ library DMQ.Protocol.LocalMsgSubmission.Type DMQ.Protocol.SigSubmission.Codec DMQ.Protocol.SigSubmission.Type + DMQ.Protocol.SigSubmission.Validate DMQ.Tracer build-depends: @@ -84,6 +85,7 @@ library cardano-crypto-class, cardano-crypto-wrapper, cardano-ledger-byron, + cardano-ledger-core, cardano-ledger-shelley, cardano-slotting, cborg >=0.2.1 && <0.3, @@ -113,6 +115,7 @@ library text >=1.2.4 && <2.2, time ^>=1.12, transformers, + transformers-except, typed-protocols:{typed-protocols, cborg} ^>=1.1, hs-source-dirs: src diff --git a/dmq-node/src/DMQ/Protocol/SigSubmission/Type.hs b/dmq-node/src/DMQ/Protocol/SigSubmission/Type.hs index ea64299069..ca31facb73 100644 --- a/dmq-node/src/DMQ/Protocol/SigSubmission/Type.hs +++ b/dmq-node/src/DMQ/Protocol/SigSubmission/Type.hs @@ -5,6 +5,7 @@ {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} @@ -19,7 +20,6 @@ module DMQ.Protocol.SigSubmission.Type , SigRaw (..) , SigRawWithSignedBytes (..) , Sig (Sig, SigWithBytes, sigRawWithSignedBytes, sigRawBytes, sigId, sigBody, sigExpiresAt, sigOpCertificate, sigKESPeriod, sigKESSignature, sigColdKey, sigSignedBytes, sigBytes) - , validateSig -- * `TxSubmission` mini-protocol , SigSubmission , module SigSubmission @@ -27,19 +27,15 @@ module DMQ.Protocol.SigSubmission.Type , KESPeriod (..) ) where -import Data.Bifunctor (first) import Data.ByteString (ByteString) import Data.ByteString.Lazy qualified as LBS import Data.Time.Clock.POSIX (POSIXTime) import Data.Typeable -import Data.Word (Word64) -import Cardano.Crypto.DSIGN.Class (ContextDSIGN, DSIGNAlgorithm, VerKeyDSIGN) -import Cardano.Crypto.DSIGN.Class qualified as DSIGN -import Cardano.Crypto.KES.Class (KESAlgorithm (..), Signable) +import Cardano.Crypto.DSIGN.Class (DSIGNAlgorithm, VerKeyDSIGN) +import Cardano.Crypto.KES.Class (KESAlgorithm (..)) import Cardano.KESAgent.KES.Crypto as KES -import Cardano.KESAgent.KES.OCert (KESPeriod (..), OCert (..), OCertSignable, - validateOCert) +import Cardano.KESAgent.KES.OCert (KESPeriod (..), OCert (..)) import Ouroboros.Network.Protocol.TxSubmission2.Type as SigSubmission hiding (TxSubmission2) @@ -225,76 +221,4 @@ pattern instance Typeable crypto => ShowProxy (Sig crypto) where -data SigValidationError = - InvalidKESSignature KESPeriod KESPeriod String - | InvalidSignatureOCERT - !Word64 -- OCert counter - !KESPeriod -- OCert KES period - !String -- DSIGN error message - | KESBeforeStartOCERT KESPeriod KESPeriod - | KESAfterEndOCERT KESPeriod KESPeriod - deriving Show - --- TODO: --- We don't validate ocert numbers, since we might not have necessary --- information to do so, but we can validate that they are growing. -validateSig :: forall crypto. - ( Crypto crypto - , ContextDSIGN (KES.DSIGN crypto) ~ () - , DSIGN.Signable (DSIGN crypto) (OCertSignable crypto) - , ContextKES (KES crypto) ~ () - , Signable (KES crypto) ByteString - ) - => Sig crypto - -> Either SigValidationError () -validateSig Sig { sigSignedBytes = signedBytes, - sigKESPeriod, - sigOpCertificate = SigOpCertificate ocert@OCert { - ocertKESPeriod, - ocertVkHot, - ocertN - }, - sigColdKey = SigColdKey coldKey, - sigKESSignature = SigKESSignature kesSig - } - = do - sigKESPeriod < endKESPeriod - ?! KESAfterEndOCERT endKESPeriod sigKESPeriod - sigKESPeriod >= startKESPeriod - ?! KESBeforeStartOCERT startKESPeriod sigKESPeriod - - -- validate OCert, which includes verifying its signature - validateOCert coldKey ocertVkHot ocert - ?!: InvalidSignatureOCERT ocertN sigKESPeriod - -- validate KES signature of the payload - verifyKES () ocertVkHot - (unKESPeriod sigKESPeriod - unKESPeriod startKESPeriod) - (LBS.toStrict signedBytes) - kesSig - ?!: InvalidKESSignature ocertKESPeriod sigKESPeriod - where - startKESPeriod, endKESPeriod :: KESPeriod - - startKESPeriod = ocertKESPeriod - -- TODO: is `totalPeriodsKES` the same as `praosMaxKESEvo` - -- or `sgMaxKESEvolution` in the genesis file? - endKESPeriod = KESPeriod $ unKESPeriod startKESPeriod - + totalPeriodsKES (Proxy :: Proxy (KES crypto)) - type SigSubmission crypto = TxSubmission2.TxSubmission2 SigId (Sig crypto) - - --- --- Utility functions --- - -(?!:) :: Either e1 a -> (e1 -> e2) -> Either e2 a -(?!:) = flip first - -infix 1 ?!: - -(?!) :: Bool -> e -> Either e () -(?!) True _ = Right () -(?!) False e = Left e - -infix 1 ?! diff --git a/dmq-node/src/DMQ/Protocol/SigSubmission/Validate.hs b/dmq-node/src/DMQ/Protocol/SigSubmission/Validate.hs new file mode 100644 index 0000000000..f76952a950 --- /dev/null +++ b/dmq-node/src/DMQ/Protocol/SigSubmission/Validate.hs @@ -0,0 +1,176 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} + +-- | Encapsulates signature validation utilities leveraged by the mempool writer +-- +module DMQ.Protocol.SigSubmission.Validate where + +import Control.Exception +import Control.Monad.Class.MonadTime.SI +import Control.Monad.Trans.Except +import Control.Monad.Trans.Except.Extra +import Data.ByteString (ByteString) +import Data.ByteString.Lazy qualified as LBS +import Data.Map.Strict qualified as Map +import Data.Maybe (isNothing) +import Data.Text (Text) +import Data.Text qualified as Text +import Data.Typeable +import Data.Word +import Text.Printf + +import Cardano.Crypto.DSIGN.Class (ContextDSIGN) +import Cardano.Crypto.DSIGN.Class qualified as DSIGN +import Cardano.Crypto.KES.Class (KESAlgorithm (..)) +import Cardano.KESAgent.KES.Crypto as KES +import Cardano.KESAgent.KES.OCert (OCert (..), OCertSignable, validateOCert) +import Cardano.Ledger.BaseTypes.NonZero +import Cardano.Ledger.Hashes + +import DMQ.Diffusion.NodeKernel (PoolValidationCtx (..)) +import DMQ.Protocol.SigSubmission.Type +import Ouroboros.Consensus.Shelley.Ledger.Query +import Ouroboros.Network.TxSubmission.Mempool.Simple +import Ouroboros.Network.Util.ShowProxy + + +-- | The type of non-fatal failures reported by the mempool writer +-- for invalid messages +-- +data instance MempoolAddFail (Sig crypto) = + SigInvalid Text + | SigDuplicate + | SigExpired + | SigResultOther Text + deriving (Eq, Show) + +instance (Typeable crypto) => ShowProxy (MempoolAddFail (Sig crypto)) + +-- | The type of exception raised by the mempool writer for invalid messages +-- as determined by the validation procedure and severity policy +-- +newtype instance InvalidTxsError SigValidationError = InvalidTxsError SigValidationError + +deriving instance Show (InvalidTxsError SigValidationError) +instance Exception (InvalidTxsError SigValidationError) + +-- | The policy which is realized by the mempool writer when encountering +-- an invalid message. +-- +data ValidationSeverity = + FailDefault | FailSoft + +data SigValidationError = + InvalidKESSignature KESPeriod KESPeriod String + | InvalidSignatureOCERT + !Word64 -- OCert counter + !KESPeriod -- OCert KES period + !String -- DSIGN error message + | KESBeforeStartOCERT KESPeriod KESPeriod + | KESAfterEndOCERT KESPeriod KESPeriod + | UnrecognizedPool + | NotInitialized + | ClockSkew + deriving Show + +-- TODO fine tune policy +sigValidationPolicy + :: SigValidationError + -> Either (MempoolAddFail (Sig crypto)) (MempoolAddFail (Sig crypto)) +sigValidationPolicy sve = case sve of + InvalidKESSignature {} -> Left . SigInvalid . Text.pack . show $ sve + InvalidSignatureOCERT {} -> Left . SigInvalid . Text.pack . show $ sve + KESAfterEndOCERT {} -> Left SigExpired + KESBeforeStartOCERT start sig -> + Left . SigResultOther . Text.pack + $ printf "KESBeforeStartOCERT %s %s" (show start) (show sig) + UnrecognizedPool -> Left . SigInvalid $ Text.pack "unrecognized pool id" + ClockSkew -> Left . SigInvalid $ Text.pack "clock skew out of range" + NotInitialized -> Right . SigResultOther $ Text.pack "not initialized yet" + +-- TODO: +-- We don't validate ocert numbers, since we might not have necessary +-- information to do so, but we can validate that they are growing. +validateSig :: forall crypto. + ( Crypto crypto + , ContextDSIGN (KES.DSIGN crypto) ~ () + , DSIGN.Signable (DSIGN crypto) (OCertSignable crypto) + , ContextKES (KES crypto) ~ () + , Signable (KES crypto) ByteString + ) + => ValidationSeverity + -> (DSIGN.VerKeyDSIGN (DSIGN crypto) -> KeyHash StakePool) + -> [Sig crypto] + -> PoolValidationCtx + -- ^ cardano pool id verification + -> Except (InvalidTxsError SigValidationError) [Either (MempoolAddFail (Sig crypto)) ()] +validateSig severity verKeyHashingFn sigs ctx = firstExceptT InvalidTxsError $ traverse process sigs + where + DMQPoolValidationCtx now mNextEpoch pools = ctx + + process Sig { sigSignedBytes = signedBytes, + sigKESPeriod, + sigOpCertificate = SigOpCertificate ocert@OCert { + ocertKESPeriod, + ocertVkHot, + ocertN + }, + sigColdKey = SigColdKey coldKey, + sigKESSignature = SigKESSignature kesSig + } = do + e1 <- sigKESPeriod < endKESPeriod + ?! KESAfterEndOCERT endKESPeriod sigKESPeriod + e2 <- sigKESPeriod >= startKESPeriod + ?! KESBeforeStartOCERT startKESPeriod sigKESPeriod + e3 <- case Map.lookup (verKeyHashingFn coldKey) pools of + Nothing | isNothing mNextEpoch -> classifyError NotInitialized + | otherwise -> classifyError UnrecognizedPool + Just ss | not (isZero (ssSetPool ss)) -> right $ Right () + | not (isZero (ssMarkPool ss)) + , Just nextEpoch <- mNextEpoch + -- TODO make this a constant + , diffUTCTime nextEpoch now <= 5 -> right $ Right () + | otherwise -> classifyError ClockSkew + -- validate OCert, which includes verifying its signature + e4 <- validateOCert coldKey ocertVkHot ocert + ?!: InvalidSignatureOCERT ocertN sigKESPeriod + -- validate KES signature of the payload + e5 <- verifyKES () ocertVkHot + (unKESPeriod sigKESPeriod - unKESPeriod startKESPeriod) + (LBS.toStrict signedBytes) + kesSig + ?!: InvalidKESSignature ocertKESPeriod sigKESPeriod + -- for eg. remember to run all results with possibly non-fatal errors + right $ e1 >> e2 >> e3 >> e4 >> e5 + where + startKESPeriod, endKESPeriod :: KESPeriod + + startKESPeriod = ocertKESPeriod + -- TODO: is `totalPeriodsKES` the same as `praosMaxKESEvo` + -- or `sgMaxKESEvolution` in the genesis file? + endKESPeriod = KESPeriod $ unKESPeriod startKESPeriod + + totalPeriodsKES (Proxy :: Proxy (KES crypto)) + + classifyError sigValidationError = case severity of + FailSoft -> + let mempoolAddFail = either id id (sigValidationPolicy sigValidationError) + in right . Left $ mempoolAddFail + FailDefault -> + either (const $ throwE sigValidationError) (right . Left) + (sigValidationPolicy sigValidationError) + + (?!:) :: Either e1 () + -> (e1 -> SigValidationError) + -> Except SigValidationError (Either (MempoolAddFail (Sig crypto)) ()) + (?!:) = (handleE classifyError .) . flip firstExceptT . hoistEither . fmap Right + + (?!) :: Bool + -> SigValidationError + -> Except SigValidationError (Either (MempoolAddFail (Sig crypto)) ()) + (?!) flag sve = if flag then right $ Right () else classifyError sve + + infix 1 ?! + infix 1 ?!: From e2b74a13d56fc54c4747fd66d65490c9732363b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20W=C3=B3jtowicz?= Date: Thu, 11 Sep 2025 16:14:44 +0200 Subject: [PATCH 14/16] mempool: adapt for generalized validation --- .../DMQ/Protocol/LocalMsgSubmission/Client.hs | 3 +- .../DMQ/Protocol/LocalMsgSubmission/Server.hs | 3 +- .../DMQ/Protocol/LocalMsgSubmission/Type.hs | 16 +- .../Network/TxSubmission/Mempool/Simple.hs | 170 +++++++++++------- 4 files changed, 110 insertions(+), 82 deletions(-) diff --git a/dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Client.hs b/dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Client.hs index 01429e66b8..dc23363fae 100644 --- a/dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Client.hs +++ b/dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Client.hs @@ -17,10 +17,11 @@ module DMQ.Protocol.LocalMsgSubmission.Client import DMQ.Protocol.LocalMsgSubmission.Type import Network.TypedProtocol.Peer.Client import Ouroboros.Network.Protocol.LocalTxSubmission.Client +import Ouroboros.Network.TxSubmission.Mempool.Simple -- | Type aliases for the high level client API -- -type LocalMsgSubmissionClient sig = LocalTxSubmissionClient sig SigMempoolFail +type LocalMsgSubmissionClient sig = LocalTxSubmissionClient sig (MempoolAddFail sig) type LocalMsgClientStIdle = LocalTxClientStIdle diff --git a/dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Server.hs b/dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Server.hs index 9a44d2b006..7936fd7894 100644 --- a/dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Server.hs +++ b/dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Server.hs @@ -18,10 +18,11 @@ module DMQ.Protocol.LocalMsgSubmission.Server import DMQ.Protocol.LocalMsgSubmission.Type import Network.TypedProtocol.Peer.Server import Ouroboros.Network.Protocol.LocalTxSubmission.Server as LocalTxSubmission +import Ouroboros.Network.TxSubmission.Mempool.Simple -- | Type aliases for the high level client API -- -type LocalMsgSubmissionServer sig = LocalTxSubmissionServer sig SigMempoolFail +type LocalMsgSubmissionServer sig = LocalTxSubmissionServer sig (MempoolAddFail sig) -- | A non-pipelined 'Peer' representing the 'LocalMsgSubmissionServer'. diff --git a/dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Type.hs b/dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Type.hs index 1271328c36..fa3f841447 100644 --- a/dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Type.hs +++ b/dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Type.hs @@ -12,22 +12,10 @@ module DMQ.Protocol.LocalMsgSubmission.Type , module Ouroboros ) where -import Data.Text (Text) import Network.TypedProtocol.Core as Core import Ouroboros.Network.Protocol.LocalTxSubmission.Type as Ouroboros -import Ouroboros.Network.Util.ShowProxy +import Ouroboros.Network.TxSubmission.Mempool.Simple -- | The LocalMsgSubmission protocol is an alias for the LocalTxSubmission -- -type LocalMsgSubmission sig = Ouroboros.LocalTxSubmission sig SigMempoolFail - --- | The type of failures when adding to the mempool --- -data SigMempoolFail = - SigInvalid Text - | SigDuplicate - | SigExpired - | SigResultOther Text - deriving (Eq, Show) - -instance ShowProxy SigMempoolFail where +type LocalMsgSubmission sig = Ouroboros.LocalTxSubmission sig (MempoolAddFail sig) diff --git a/ouroboros-network/src/Ouroboros/Network/TxSubmission/Mempool/Simple.hs b/ouroboros-network/src/Ouroboros/Network/TxSubmission/Mempool/Simple.hs index 75e49ace3e..f627ca7c98 100644 --- a/ouroboros-network/src/Ouroboros/Network/TxSubmission/Mempool/Simple.hs +++ b/ouroboros-network/src/Ouroboros/Network/TxSubmission/Mempool/Simple.hs @@ -1,40 +1,49 @@ -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DisambiguateRecordFields #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeFamilies #-} -- | The module should be imported qualified. -- module Ouroboros.Network.TxSubmission.Mempool.Simple - ( Mempool (..) + ( InvalidTxsError + , MempoolAddFail + , Mempool (..) , MempoolSeq (..) + , MempoolWriter (..) , empty , new , read , getReader , getWriter + , writerAdapter ) where import Prelude hiding (read, seq) import Control.Concurrent.Class.MonadSTM.Strict -import Control.Monad (when) +import Control.DeepSeq +import Control.Exception (assert) import Control.Monad.Class.MonadThrow - +import Control.Monad.Trans.Except import Data.Bifunctor (bimap) -import Data.Either (partitionEithers) +import Data.Either import Data.Foldable (toList) import Data.Foldable qualified as Foldable -import Data.Function (on) -import Data.List (find, nubBy) +import Data.List (find) import Data.Maybe (isJust) import Data.Sequence (Seq) import Data.Sequence qualified as Seq import Data.Set (Set) import Data.Set qualified as Set -import Data.Typeable (Typeable) +import Ouroboros.Network.Protocol.LocalTxSubmission.Type (SubmitResult (..)) import Ouroboros.Network.SizeInBytes import Ouroboros.Network.TxSubmission.Inbound.V2.Types import Ouroboros.Network.TxSubmission.Mempool.Reader @@ -105,69 +114,98 @@ getReader getTxId getTxSize (Mempool mempool) = f :: Int -> tx -> (txid, Int, SizeInBytes) f idx tx = (getTxId tx, idx, getTxSize tx) +-- | type of mempool validation errors which are thrown as exceptions +-- +data family InvalidTxsError failure -data InvalidTxsError where - InvalidTxsError :: forall txid failure. - ( Typeable txid - , Typeable failure - , Show txid - , Show failure - ) - => [(txid, failure)] - -> InvalidTxsError - -deriving instance Show InvalidTxsError -instance Exception InvalidTxsError - +-- | type of mempool validation errors which are non-fatal +-- +data family MempoolAddFail tx --- | A simple mempool writer. +-- | A mempool writer which generalizes the tx submission mempool writer +-- TODO: We could replace TxSubmissionMempoolWriter with this at some point +-- +data MempoolWriter txid tx failure idx m = + MempoolWriter { + + -- | Compute the transaction id from a transaction. + -- + -- This is used in the protocol handler to verify a full transaction + -- matches a previously given transaction id. + -- + txId :: tx -> txid, + + -- | Supply a batch of transactions to the mempool. They are either + -- accepted or rejected individually, but in the order supplied. + -- + -- The 'txid's of all transactions that were added successfully are + -- returned. + mempoolAddTxs :: [tx] -> m [(txid, SubmitResult (MempoolAddFail tx))] + } + + +-- | A mempool writer with validation harness +-- PRECONDITION: no duplicates given to mempoolAddTxs -- getWriter :: forall tx txid ctx failure m. ( MonadSTM m + , Exception (InvalidTxsError failure) , MonadThrow m + -- TODO: + -- , NFData txid + -- , NFData tx + -- , NFData (MempoolAddFail tx) , Ord txid - , Typeable txid - , Typeable failure - , Show txid - , Show failure ) => (tx -> txid) -- ^ get txid of a tx -> m ctx - -- ^ monadic validation ctx - -> (ctx -> tx -> Either failure ()) - -- ^ validate a tx, any failing `tx` throws an exception. - -> (failure -> Bool) - -- ^ return `True` when a failure should throw an exception + -- ^ acquire validation context + -> ([tx] -> ctx -> Except (InvalidTxsError failure) [(Either (MempoolAddFail tx) ())]) + -- ^ validation function which should evaluate its result to normal form + -- esp. if it is 'expensive' + -> MempoolAddFail tx + -- ^ replace duplicates -> Mempool m txid tx - -> TxSubmissionMempoolWriter txid tx Int m -getWriter getTxId getValidationCtx validateTx failureFilterFn (Mempool mempool) = - TxSubmissionMempoolWriter { - txId = getTxId, - - mempoolAddTxs = \txs -> do - ctx <- getValidationCtx - (invalidTxIds, validTxs) <- atomically $ do - MempoolSeq { mempoolSet, mempoolSeq } <- readTVar mempool - let (invalidTxIds, validTxs) = - bimap (filter (failureFilterFn . snd)) - (nubBy (on (==) getTxId)) - . partitionEithers - . map (\tx -> case validateTx ctx tx of - Left e -> Left (getTxId tx, e) - Right _ -> Right tx - ) - . filter (\tx -> getTxId tx `Set.notMember` mempoolSet) - $ txs - mempoolTxs' = MempoolSeq { - mempoolSet = Foldable.foldl' (\s tx -> getTxId tx `Set.insert` s) - mempoolSet - validTxs, - mempoolSeq = Foldable.foldl' (Seq.|>) mempoolSeq validTxs - } - writeTVar mempool mempoolTxs' - return (invalidTxIds, map getTxId validTxs) - when (not (null invalidTxIds)) $ - throwIO (InvalidTxsError invalidTxIds) - return validTxs - } + -> MempoolWriter txid tx failure Int m +getWriter getTxId acquireCtx validateTxs duplicateFail (Mempool mempool) = + MempoolWriter { + txId = getTxId, + + mempoolAddTxs = \txs -> assert (not . null $ txs) $ do + ctx <- acquireCtx + !vTxs <- case runExcept (validateTxs txs ctx) of + Left e -> throwIO e + Right r -> pure {-. force-} $ zipWith3 ((,,) . getTxId) txs txs r + + atomically $ do + MempoolSeq { mempoolSet, mempoolSeq } <- readTVar mempool + let result = + [if duplicate then + Left . (txid,) $ SubmitFail duplicateFail + else + bimap ((txid,) . SubmitFail) (const (txid, tx)) eErrTx + | (txid, tx, eErrTx) <- vTxs + , let duplicate = txid `Set.member` mempoolSet + ] + (validIds, validTxs) = unzip . rights $ result + mempoolTxs' = MempoolSeq { + mempoolSet = Set.union mempoolSet (Set.fromList validIds), + mempoolSeq = Foldable.foldl' (Seq.|>) mempoolSeq validTxs + } + writeTVar mempool mempoolTxs' + return $ either id ((,SubmitSuccess) . fst) <$> result + } + + +-- | Takes the general mempool writer defined here +-- and adapts it to the API of the tx submission mempool writer +-- to avoid more breaking changes for now. +-- +writerAdapter :: (Functor m) + => MempoolWriter txid tx failure idx m + -> TxSubmissionMempoolWriter txid tx idx m +writerAdapter MempoolWriter { txId, mempoolAddTxs } = + TxSubmissionMempoolWriter { txId, mempoolAddTxs = adapter } + where + adapter = fmap (fmap fst) . mempoolAddTxs From b7023f42655b7e7584ebccfc190bf826353009be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20W=C3=B3jtowicz?= Date: Sun, 14 Sep 2025 09:49:51 +0200 Subject: [PATCH 15/16] localmsgsubmission: codec and server changes --- .../DMQ/NodeToClient/LocalMsgSubmission.hs | 44 ++++++++++--------- .../DMQ/Protocol/LocalMsgSubmission/Codec.hs | 9 ++-- 2 files changed, 29 insertions(+), 24 deletions(-) diff --git a/dmq-node/src/DMQ/NodeToClient/LocalMsgSubmission.hs b/dmq-node/src/DMQ/NodeToClient/LocalMsgSubmission.hs index 5004a313f1..b5902d0208 100644 --- a/dmq-node/src/DMQ/NodeToClient/LocalMsgSubmission.hs +++ b/dmq-node/src/DMQ/NodeToClient/LocalMsgSubmission.hs @@ -1,42 +1,46 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE StandaloneDeriving #-} + module DMQ.NodeToClient.LocalMsgSubmission where import Control.Concurrent.Class.MonadSTM import Control.Tracer -import Data.Maybe import DMQ.Protocol.LocalMsgSubmission.Server import DMQ.Protocol.LocalMsgSubmission.Type -import Ouroboros.Network.TxSubmission.Inbound.V2 +import Ouroboros.Network.TxSubmission.Mempool.Simple -- | Local transaction submission server, for adding txs to the 'Mempool' -- localMsgSubmissionServer :: MonadSTM m - => Tracer m (TraceLocalMsgSubmission msg msgid SigMempoolFail) - -> TxSubmissionMempoolWriter msgid msg idx m - -> m (LocalMsgSubmissionServer msg m ()) -localMsgSubmissionServer tracer TxSubmissionMempoolWriter { mempoolAddTxs } = + => Tracer m (TraceLocalMsgSubmission sig sigid) + -> MempoolWriter sigid sig failure idx m + -- ^ duplicate error tag in case the mempool returns the empty list on failure + -> m (LocalMsgSubmissionServer sig m ()) +localMsgSubmissionServer tracer MempoolWriter { mempoolAddTxs } = pure server where - failure = - -- TODO remove dummy hardcode when mempool returns reason - (SubmitFail SigExpired, server) <$ traceWith tracer (TraceSubmitFailure SigExpired) - success msgid = - (SubmitSuccess, server) <$ traceWith tracer (TraceSubmitAccept msgid) + process (sigid, e@(SubmitFail reason)) = + (e, server) <$ traceWith tracer (TraceSubmitFailure sigid reason) + process (sigid, success) = + (success, server) <$ traceWith tracer (TraceSubmitAccept sigid) server = LocalTxSubmissionServer { - recvMsgSubmitTx = \msg -> do - traceWith tracer $ TraceReceivedMsg msg - -- TODO mempool should return 'SubmitResult' - maybe failure success . listToMaybe =<< mempoolAddTxs [msg] + recvMsgSubmitTx = \sig -> do + traceWith tracer $ TraceReceivedMsg sig + process . head =<< mempoolAddTxs [sig] , recvMsgDone = () } -data TraceLocalMsgSubmission msg msgid reject = - TraceReceivedMsg msg +data TraceLocalMsgSubmission sig sigid = + TraceReceivedMsg sig -- ^ A transaction was received. - | TraceSubmitFailure reject - | TraceSubmitAccept msgid - deriving Show + | TraceSubmitFailure sigid (MempoolAddFail sig) + | TraceSubmitAccept sigid + +deriving instance + (Show sig, Show sigid, Show (MempoolAddFail sig)) + => Show (TraceLocalMsgSubmission sig sigid) diff --git a/dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Codec.hs b/dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Codec.hs index 8a010bff40..5172495191 100644 --- a/dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Codec.hs +++ b/dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Codec.hs @@ -17,6 +17,7 @@ import Cardano.KESAgent.KES.Crypto (Crypto (..)) import DMQ.Protocol.LocalMsgSubmission.Type import DMQ.Protocol.SigSubmission.Codec qualified as SigSubmission import DMQ.Protocol.SigSubmission.Type (Sig (..)) +import DMQ.Protocol.SigSubmission.Validate import Network.TypedProtocol.Codec.CBOR import Ouroboros.Network.Protocol.LocalTxSubmission.Codec qualified as LTX @@ -26,13 +27,13 @@ codecLocalMsgSubmission ( MonadST m , Crypto crypto ) - => (SigMempoolFail -> CBOR.Encoding) - -> (forall s. CBOR.Decoder s SigMempoolFail) + => (MempoolAddFail (Sig crypto) -> CBOR.Encoding) + -> (forall s. CBOR.Decoder s (MempoolAddFail (Sig crypto))) -> AnnotatedCodec (LocalMsgSubmission (Sig crypto)) CBOR.DeserialiseFailure m ByteString codecLocalMsgSubmission = LTX.anncodecLocalTxSubmission' SigWithBytes SigSubmission.encodeSig SigSubmission.decodeSig -encodeReject :: SigMempoolFail -> CBOR.Encoding +encodeReject :: MempoolAddFail (Sig crypto) -> CBOR.Encoding encodeReject = \case SigInvalid reason -> CBOR.encodeListLen 2 <> CBOR.encodeWord 0 <> CBOR.encodeString reason SigDuplicate -> CBOR.encodeListLen 1 <> CBOR.encodeWord 1 @@ -40,7 +41,7 @@ encodeReject = \case SigResultOther reason -> CBOR.encodeListLen 2 <> CBOR.encodeWord 3 <> CBOR.encodeString reason -decodeReject :: CBOR.Decoder s SigMempoolFail +decodeReject :: CBOR.Decoder s (MempoolAddFail (Sig crypto)) decodeReject = do len <- CBOR.decodeListLen tag <- CBOR.decodeWord From d22bd657de37bfdab90c41196aded8d049d9d2da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20W=C3=B3jtowicz?= Date: Thu, 11 Sep 2025 16:14:51 +0200 Subject: [PATCH 16/16] app: integration --- dmq-node/app/Main.hs | 105 ++++++++++++++++++------------- dmq-node/dmq-node.cabal | 2 +- dmq-node/src/DMQ/NodeToClient.hs | 12 ++-- dmq-node/src/DMQ/NodeToNode.hs | 34 ++++------ 4 files changed, 80 insertions(+), 73 deletions(-) diff --git a/dmq-node/app/Main.hs b/dmq-node/app/Main.hs index a4c209e65c..fd4574fde4 100644 --- a/dmq-node/app/Main.hs +++ b/dmq-node/app/Main.hs @@ -5,7 +5,6 @@ module Main where import Control.Monad (void) -import Control.Monad.Class.MonadAsync import Control.Tracer (Tracer (..), nullTracer, traceWith) import Data.Act @@ -16,6 +15,8 @@ import Options.Applicative import System.Random (newStdGen, split) import Cardano.KESAgent.Protocols.StandardCrypto (StandardCrypto) +import Cardano.Ledger.Keys (VKey (..)) +import Cardano.Ledger.Hashes (hashKey) import DMQ.Configuration import DMQ.Configuration.CLIOptions (parseCLIOptions) @@ -31,9 +32,11 @@ import DMQ.Tracer import DMQ.Diffusion.PeerSelection (policy) import DMQ.NodeToClient.LocalStateQueryClient +import DMQ.Protocol.SigSubmission.Validate import Ouroboros.Network.Diffusion qualified as Diffusion import Ouroboros.Network.PeerSelection.PeerSharing.Codec (decodeRemoteAddress, encodeRemoteAddress) +import Ouroboros.Network.SizeInBytes import Ouroboros.Network.Snocket import Ouroboros.Network.TxSubmission.Mempool.Simple qualified as Mempool @@ -78,50 +81,62 @@ runDMQ commandLineConfig = do diffusionTracers = dmqDiffusionTracers dmqConfig tracer Diffusion.withIOManager \iocp -> do - let localSnocket' = localSnocket iocp + let localSnocket' = localSnocket iocp + mkStakePoolMonitor = connectToCardanoNode tracer localSnocket' snocketPath - withNodeKernel @StandardCrypto psRng $ \nodeKernel -> do + withNodeKernel @StandardCrypto psRng mkStakePoolMonitor \nodeKernel -> do dmqDiffusionConfiguration <- mkDiffusionConfiguration dmqConfig nt - let stakePoolMonitor = connectToCardanoNode tracer localSnocket' snocketPath nodeKernel - - withAsync stakePoolMonitor \aid -> do - link aid - let dmqNtNApps = - ntnApps tracer - dmqConfig - nodeKernel - (dmqCodecs - -- TODO: `maxBound :: Cardano.Network.NodeToNode.NodeToNodeVersion` - -- is unsafe here! - (encodeRemoteAddress maxBound) - (decodeRemoteAddress maxBound)) - dmqLimitsAndTimeouts - defaultSigDecisionPolicy - dmqNtCApps = - let sigSize _ = 0 -- TODO - maxMsgs = 1000 -- TODO: make this negotiated in the handshake? - mempoolReader = Mempool.getReader sigId sigSize (mempool nodeKernel) - mempoolWriter = Mempool.getWriter sigId (const ()) (\_ _ -> pure True) (mempool nodeKernel) - in NtC.ntcApps mempoolReader mempoolWriter maxMsgs - (NtC.dmqCodecs encodeReject decodeReject) - dmqDiffusionArguments = - diffusionArguments (if handshakeTracer - then WithEventType "Handshake" >$< tracer - else nullTracer) - (if localHandshakeTracer - then WithEventType "Handshake" >$< tracer - else nullTracer) - dmqDiffusionApplications = - diffusionApplications nodeKernel - dmqConfig - dmqDiffusionConfiguration - dmqLimitsAndTimeouts - dmqNtNApps - dmqNtCApps - (policy policyRng) - - Diffusion.run dmqDiffusionArguments - diffusionTracers - dmqDiffusionConfiguration - dmqDiffusionApplications + let sigSize :: Sig StandardCrypto -> SizeInBytes + sigSize _ = 0 -- TODO + mempoolReader = Mempool.getReader sigId sigSize (mempool nodeKernel) + dmqNtNApps = + let ntnMempoolWriter = Mempool.writerAdapter $ + Mempool.getWriter sigId + (poolValidationCtx $ stakePools nodeKernel) + (validateSig FailDefault (hashKey . VKey)) + SigDuplicate + (mempool nodeKernel) + in ntnApps tracer + dmqConfig + mempoolReader + ntnMempoolWriter + sigSize + nodeKernel + (dmqCodecs + -- TODO: `maxBound :: Cardano.Network.NodeToNode.NodeToNodeVersion` + -- is unsafe here! + (encodeRemoteAddress maxBound) + (decodeRemoteAddress maxBound)) + dmqLimitsAndTimeouts + defaultSigDecisionPolicy + dmqNtCApps = + let maxMsgs = 1000 -- TODO: make this negotiated in the handshake? + ntcMempoolWriter = + Mempool.getWriter sigId + (poolValidationCtx $ stakePools nodeKernel) + (validateSig FailSoft (hashKey . VKey)) + SigDuplicate + (mempool nodeKernel) + in NtC.ntcApps mempoolReader ntcMempoolWriter maxMsgs + (NtC.dmqCodecs encodeReject decodeReject) + dmqDiffusionArguments = + diffusionArguments (if handshakeTracer + then WithEventType "Handshake" >$< tracer + else nullTracer) + (if localHandshakeTracer + then WithEventType "Handshake" >$< tracer + else nullTracer) + dmqDiffusionApplications = + diffusionApplications nodeKernel + dmqConfig + dmqDiffusionConfiguration + dmqLimitsAndTimeouts + dmqNtNApps + dmqNtCApps + (policy policyRng) + + Diffusion.run dmqDiffusionArguments + diffusionTracers + dmqDiffusionConfiguration + dmqDiffusionApplications diff --git a/dmq-node/dmq-node.cabal b/dmq-node/dmq-node.cabal index fad40d53b6..d74a9e996c 100644 --- a/dmq-node/dmq-node.cabal +++ b/dmq-node/dmq-node.cabal @@ -135,9 +135,9 @@ executable dmq-node acts, aeson, base, + cardano-ledger-core, contra-tracer >=0.1 && <0.3, dmq-node, - io-classes, kes-agent-crypto, optparse-applicative, ouroboros-network, diff --git a/dmq-node/src/DMQ/NodeToClient.hs b/dmq-node/src/DMQ/NodeToClient.hs index edb252c0f8..fa79861370 100644 --- a/dmq-node/src/DMQ/NodeToClient.hs +++ b/dmq-node/src/DMQ/NodeToClient.hs @@ -43,6 +43,7 @@ import DMQ.Protocol.LocalMsgSubmission.Codec import DMQ.Protocol.LocalMsgSubmission.Server import DMQ.Protocol.LocalMsgSubmission.Type import DMQ.Protocol.SigSubmission.Type (Sig) +import DMQ.Protocol.SigSubmission.Validate import Ouroboros.Network.Context import Ouroboros.Network.Driver.Simple @@ -52,9 +53,8 @@ import Ouroboros.Network.Mux import Ouroboros.Network.Protocol.Handshake (Handshake, HandshakeArguments (..)) import Ouroboros.Network.Protocol.Handshake.Codec (cborTermVersionDataCodec, codecHandshake, noTimeLimitsHandshake) -import Ouroboros.Network.TxSubmission.Inbound.V2.Types - (TxSubmissionMempoolWriter) import Ouroboros.Network.TxSubmission.Mempool.Reader +import Ouroboros.Network.TxSubmission.Mempool.Simple import Ouroboros.Network.Util.ShowProxy @@ -95,8 +95,8 @@ data Codecs m sig = dmqCodecs :: ( MonadST m , Crypto crypto ) - => (SigMempoolFail -> CBOR.Encoding) - -> (forall s. CBOR.Decoder s SigMempoolFail) + => (MempoolAddFail (Sig crypto) -> CBOR.Encoding) + -> (forall s. CBOR.Decoder s (MempoolAddFail (Sig crypto))) -> Codecs m (Sig crypto) dmqCodecs encodeReject' decodeReject' = Codecs { @@ -127,9 +127,9 @@ data Apps ntcAddr m a = -- | Construct applications for the node-to-client protocols -- ntcApps - :: (MonadThrow m, MonadThread m, MonadSTM m, ShowProxy SigMempoolFail, ShowProxy sig) + :: (MonadThrow m, MonadThread m, MonadSTM m, ShowProxy (MempoolAddFail sig), ShowProxy sig) => TxSubmissionMempoolReader msgid sig idx m - -> TxSubmissionMempoolWriter msgid sig idx m + -> MempoolWriter msgid sig failure idx m -> Word16 -> Codecs m sig -> Apps ntcAddr m () diff --git a/dmq-node/src/DMQ/NodeToNode.hs b/dmq-node/src/DMQ/NodeToNode.hs index 9e54de6205..17da4f673d 100644 --- a/dmq-node/src/DMQ/NodeToNode.hs +++ b/dmq-node/src/DMQ/NodeToNode.hs @@ -90,7 +90,7 @@ import Ouroboros.Network.PeerSharing (bracketPeerSharingClient, peerSharingClient, peerSharingServer) import Ouroboros.Network.Snocket (RemoteAddress) import Ouroboros.Network.TxSubmission.Inbound.V2 as SigSubmission -import Ouroboros.Network.TxSubmission.Mempool.Simple qualified as Mempool +import Ouroboros.Network.TxSubmission.Mempool.Reader import Ouroboros.Network.TxSubmission.Outbound import Ouroboros.Network.OrphanInstances () @@ -150,12 +150,12 @@ data Apps addr m a b = } ntnApps - :: forall crypto m addr . + :: forall crypto m addr idx. ( Crypto crypto - , DSIGN.ContextDSIGN (DSIGN crypto) ~ () - , DSIGN.Signable (DSIGN crypto) (OCertSignable crypto) - , KES.ContextKES (KES crypto) ~ () - , KES.Signable (KES crypto) BS.ByteString + -- , DSIGN.ContextDSIGN (DSIGN crypto) ~ () + -- , DSIGN.Signable (DSIGN crypto) (OCertSignable crypto) + -- , KES.ContextKES (KES crypto) ~ () + -- , KES.Signable (KES crypto) BS.ByteString , Typeable crypto , Alternative (STM m) , MonadAsync m @@ -166,12 +166,16 @@ ntnApps , MonadThrow (STM m) , MonadTimer m , Ord addr + , Ord idx , Show addr , Hashable addr , Aeson.ToJSON addr ) => (forall ev. Aeson.ToJSON ev => Tracer m (WithEventType ev)) -> Configuration + -> TxSubmissionMempoolReader SigId (Sig crypto) idx m + -> TxSubmissionMempoolWriter SigId (Sig crypto) idx m + -> (Sig crypto -> SizeInBytes) -> NodeKernel crypto addr m -> Codecs crypto addr m -> LimitsAndTimeouts crypto addr @@ -187,11 +191,13 @@ ntnApps , dmqcPeerSharingClientTracer = I peerSharingClientTracer , dmqcPeerSharingServerTracer = I peerSharingServerTracer } + mempoolReader + mempoolWriter + sigSize NodeKernel { fetchClientRegistry , peerSharingRegistry , peerSharingAPI - , mempool , sigChannelVar , sigMempoolSem , sigSharedTxStateVar @@ -220,20 +226,6 @@ ntnApps , aPeerSharingServer } where - sigSize :: Sig crypto -> SizeInBytes - sigSize _ = 0 -- TODO - - mempoolReader = Mempool.getReader sigId sigSize mempool - -- TODO: invalid signatures are just omitted from the mempool. For DMQ - -- we need to validate signatures when we received them, and shutdown - -- connection if we receive one, rather than validate them in the - -- mempool. - mempoolWriter = Mempool.getWriter sigId - (pure ()) -- TODO not needed - (\_ -> validateSig) - (\_ -> True) - mempool - aSigSubmissionClient :: NodeToNodeVersion -> ExpandedInitiatorContext addr m