diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8c8f87ddf1b..38b12606aec 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -147,6 +147,8 @@ jobs: - run: cargo-zigbuild clippy --lib --bins --all-features --target x86_64-unknown-linux-gnu -- -Wclippy::indexing_slicing -D warnings # Check that compiles for the supported linux targets (aarch64) - run: cargo-zigbuild clippy --lib --bins --all-features --target aarch64-unknown-linux-gnu -- -Wclippy::indexing_slicing -D warnings + # Check whether `mirrord-operator` crate compiles the way it's used in the operator + - run: cargo-zigbuild check -p mirrord-operator --features crd --target x86_64-unknown-linux-gnu # if the branch is named is `x.x.x`, x ∈ [0, 9], then it's a release branch # the output of this test is a boolean indicating if it's a release branch @@ -322,6 +324,10 @@ jobs: run: cargo test --target x86_64-unknown-linux-gnu -p mirrord-kube --all-features - name: mirrord intproxy UT run: cargo test --target x86_64-unknown-linux-gnu -p mirrord-intproxy + - name: mirrord auth UT + run: cargo test --target x86_64-unknown-linux-gnu -p mirrord-auth + - name: mirrord operator UT + run: cargo test --target x86_64-unknown-linux-gnu -p mirrord-operator --features "crd, client" macos_tests: runs-on: macos-13 diff --git a/CHANGELOG.md b/CHANGELOG.md index fa143d55f48..04391931837 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,112 @@ This project uses [*towncrier*](https://towncrier.readthedocs.io/) and the chang +## [3.111.0](https://github.com/metalbear-co/mirrord/tree/3.111.0) - 2024-07-17 + + +### Added + +- Extended `feature.network.dns` config with an optional local/remote filter, + following `feature.network.outgoing` pattern. 
+ [#2581](https://github.com/metalbear-co/mirrord/issues/2581) + + +### Fixed + +- Update loopback detection to include pod ip's + [#2572](https://github.com/metalbear-co/mirrord/issues/2572) +- Fixed a bug where enabling remote DNS prevented making a local connection + with telnet. [#2579](https://github.com/metalbear-co/mirrord/issues/2579) +- Remove automatic ignore of incoming/outgoing traffic for ports 50000-60000 + [#2597](https://github.com/metalbear-co/mirrord/issues/2597) + + +### Internal + +- Add test to ensure empty streamed request doesn't hang if empty + [#2593](https://github.com/metalbear-co/mirrord/issues/2593) + +## [3.110.0](https://github.com/metalbear-co/mirrord/tree/3.110.0) - 2024-07-12 + + +### Added + +- Added experimental.trust_any_certificate to enable making app trust any + certificate on macOS + [#2576](https://github.com/metalbear-co/mirrord/issues/2576) + + +### Fixed + +- Fix empty request streaming hanging forever + [#2590](https://github.com/metalbear-co/mirrord/issues/2590) + +## [3.109.0](https://github.com/metalbear-co/mirrord/tree/3.109.0) - 2024-07-10 + + +### Changed + +- mirrord commands now provide a nicer error message when the operator required + but not installed. + [#1730](https://github.com/metalbear-co/mirrord/issues/1730) +- Add Unknown target variant for forwards compatibility. + [#2515](https://github.com/metalbear-co/mirrord/issues/2515) + + +### Fixed + +- Improved agent performance when mirroring is under high load. + [#2529](https://github.com/metalbear-co/mirrord/issues/2529) +- Don't include non-running pods in node capacity check + [#2582](https://github.com/metalbear-co/mirrord/issues/2582) +- Add exclusion for DOTNET_EnableDiagnostics to make DotNet debugging work by + default + + +### Internal + +- CLI now sends additional headers with each request to the mirrord operator. 
+ [#2466](https://github.com/metalbear-co/mirrord/issues/2466) +- Add mirrord-operator-apiserver-authentication `Role` and `RoleBinding` to + fetch `extension-apiserver-authentication` configmap from "kube-system". +- Fixed compilation errors in `mirrord-operator` crate with only `crd` feature + enabled. +- Fixed compilation of `mirrord-operator` crate with no features. +- Updated `x509-certificate` dependency. + + +## [3.108.0](https://github.com/metalbear-co/mirrord/tree/3.108.0) - 2024-07-02 + + +### Added + +- Added support for streaming HTTP responses. + [#2557](https://github.com/metalbear-co/mirrord/issues/2557) + + +### Changed + +- Changed http path filter to include query params in match + [#2551](https://github.com/metalbear-co/mirrord/issues/2551) +- Configuration documentation contents order. +- Errors that occur when using discovery API to detect mirrord operator are no + longer fatal. When such error is encountered, mirrord command falls back to + using the OSS version. + + +### Fixed + +- When using mesh use `lo` interface for mirroring traffic. + [#2452](https://github.com/metalbear-co/mirrord/issues/2452) + + +### Internal + +- Correct version of HTTP response is sent based on agent protocol version. + [#2562](https://github.com/metalbear-co/mirrord/issues/2562) +- `mirrord-intproxy` crate unit tests are now part of the CI. 
+ + ## [3.107.0](https://github.com/metalbear-co/mirrord/tree/3.107.0) - 2024-06-25 diff --git a/Cargo.lock b/Cargo.lock index a371188366a..72dddfdb8f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -362,7 +362,7 @@ dependencies = [ "once_cell", "p12", "p256", - "pem 3.0.4", + "pem", "pkcs1", "pkcs8", "plist", @@ -371,7 +371,7 @@ dependencies = [ "rayon", "regex", "reqwest 0.11.27", - "ring 0.17.8", + "ring", "rsa", "scroll", "security-framework", @@ -395,7 +395,7 @@ dependencies = [ "widestring", "windows-sys 0.52.0", "x509", - "x509-certificate 0.23.1", + "x509-certificate", "xml-rs", "yasna", "zeroize", @@ -443,7 +443,7 @@ dependencies = [ "signature", "thiserror", "url", - "x509-certificate 0.23.1", + "x509-certificate", "xml-rs", "xz2", ] @@ -1493,11 +1493,11 @@ dependencies = [ "bytes", "chrono", "hex", - "pem 3.0.4", + "pem", "reqwest 0.11.27", - "ring 0.17.8", + "ring", "signature", - "x509-certificate 0.23.1", + "x509-certificate", ] [[package]] @@ -1976,6 +1976,15 @@ dependencies = [ "log", ] +[[package]] +name = "envy" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f47e0157f2cb54f5ae1bd371b30a2ae4311e1c028f575cd4e81de7353215965" +dependencies = [ + "serde", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -2104,7 +2113,7 @@ dependencies = [ [[package]] name = "fileops" -version = "3.107.0" +version = "3.111.0" dependencies = [ "libc", ] @@ -3218,7 +3227,7 @@ checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" [[package]] name = "issue1317" -version = "3.107.0" +version = "3.111.0" dependencies = [ "actix-web", "env_logger 0.11.3", @@ -3229,7 +3238,7 @@ dependencies = [ [[package]] name = "issue1776" -version = "3.107.0" +version = "3.111.0" dependencies = [ "errno 0.3.9", "libc", @@ -3238,7 +3247,7 @@ dependencies = [ [[package]] name = "issue1776portnot53" -version = "3.107.0" +version = "3.111.0" dependencies = [ "libc", "socket2", @@ -3246,14 +3255,14 @@ dependencies = 
[ [[package]] name = "issue1899" -version = "3.107.0" +version = "3.111.0" dependencies = [ "libc", ] [[package]] name = "issue2001" -version = "3.107.0" +version = "3.111.0" dependencies = [ "libc", ] @@ -3409,8 +3418,7 @@ dependencies = [ [[package]] name = "kube" version = "0.92.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12dc4487eda98835dcaa7ac92a14165446db29dbd67a743c79fe9f41bf38ee72" +source = "git+https://github.com/kube-rs/kube?rev=f9902f1439b3c0baafc2ece1680644c2bfade742#f9902f1439b3c0baafc2ece1680644c2bfade742" dependencies = [ "k8s-openapi", "kube-client", @@ -3422,8 +3430,7 @@ dependencies = [ [[package]] name = "kube-client" version = "0.92.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "408f35eab36927d3b883e4ad54c3080ea8c49f899ac84a7856e7182e4ee3b392" +source = "git+https://github.com/kube-rs/kube?rev=f9902f1439b3c0baafc2ece1680644c2bfade742#f9902f1439b3c0baafc2ece1680644c2bfade742" dependencies = [ "base64 0.22.1", "bytes", @@ -3444,7 +3451,7 @@ dependencies = [ "jsonpath-rust", "k8s-openapi", "kube-core", - "pem 3.0.4", + "pem", "rand", "rustls 0.23.10", "rustls-pemfile 2.1.2", @@ -3464,8 +3471,7 @@ dependencies = [ [[package]] name = "kube-core" version = "0.92.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f776624097c1e09e72eb1e9e0c2bb5d17d97c27a6a87734390a9fba246a8f67f" +source = "git+https://github.com/kube-rs/kube?rev=f9902f1439b3c0baafc2ece1680644c2bfade742#f9902f1439b3c0baafc2ece1680644c2bfade742" dependencies = [ "chrono", "form_urlencoded", @@ -3481,8 +3487,7 @@ dependencies = [ [[package]] name = "kube-derive" version = "0.92.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ae07adfd7d21b7fa582789206391243f98e155b46c806eb494839569853bcfd" +source = "git+https://github.com/kube-rs/kube?rev=f9902f1439b3c0baafc2ece1680644c2bfade742#f9902f1439b3c0baafc2ece1680644c2bfade742" dependencies = [ "darling", 
"proc-macro2", @@ -3494,8 +3499,7 @@ dependencies = [ [[package]] name = "kube-runtime" version = "0.92.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12e5933f2d429f3a05d4cb67f935b25c94a133b0baeb558ab3917c270a11f6ef" +source = "git+https://github.com/kube-rs/kube?rev=f9902f1439b3c0baafc2ece1680644c2bfade742#f9902f1439b3c0baafc2ece1680644c2bfade742" dependencies = [ "ahash", "async-broadcast", @@ -3592,7 +3596,7 @@ checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" [[package]] name = "listen_ports" -version = "3.107.0" +version = "3.111.0" [[package]] name = "litemap" @@ -3814,7 +3818,7 @@ checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1" [[package]] name = "mirrord" -version = "3.107.0" +version = "3.111.0" dependencies = [ "actix-codec", "anyhow", @@ -3860,7 +3864,7 @@ dependencies = [ [[package]] name = "mirrord-agent" -version = "3.107.0" +version = "3.111.0" dependencies = [ "actix-codec", "async-trait", @@ -3871,6 +3875,7 @@ dependencies = [ "dashmap", "drain", "enum_dispatch", + "envy", "faccess", "fancy-regex", "futures", @@ -3894,6 +3899,7 @@ dependencies = [ "rawsocket", "rcgen", "regex", + "rstest", "rustls 0.23.10", "semver 1.0.23", "serde", @@ -3916,7 +3922,7 @@ dependencies = [ [[package]] name = "mirrord-analytics" -version = "3.107.0" +version = "3.111.0" dependencies = [ "assert-json-diff", "base64 0.22.1", @@ -3930,28 +3936,28 @@ dependencies = [ [[package]] name = "mirrord-auth" -version = "3.107.0" +version = "3.111.0" dependencies = [ + "bcder", "chrono", "fs4", "home", "k8s-openapi", "kube", - "pem 3.0.4", + "pem", "reqwest 0.12.5", - "ring 0.16.20", "serde", "serde_yaml", "thiserror", "tokio", "tracing", "whoami", - "x509-certificate 0.21.0", + "x509-certificate", ] [[package]] name = "mirrord-config" -version = "3.107.0" +version = "3.111.0" dependencies = [ "bimap", "bitflags 2.5.0", @@ -3973,7 +3979,7 @@ dependencies = [ [[package]] name = 
"mirrord-config-derive" -version = "3.107.0" +version = "3.111.0" dependencies = [ "proc-macro2", "proc-macro2-diagnostics", @@ -3983,7 +3989,7 @@ dependencies = [ [[package]] name = "mirrord-console" -version = "3.107.0" +version = "3.111.0" dependencies = [ "bincode", "drain", @@ -3999,7 +4005,7 @@ dependencies = [ [[package]] name = "mirrord-intproxy" -version = "3.107.0" +version = "3.111.0" dependencies = [ "bytes", "futures", @@ -4013,6 +4019,7 @@ dependencies = [ "mirrord-operator", "mirrord-protocol", "rand", + "semver 1.0.23", "serde", "thiserror", "tokio", @@ -4022,7 +4029,7 @@ dependencies = [ [[package]] name = "mirrord-intproxy-protocol" -version = "3.107.0" +version = "3.111.0" dependencies = [ "bincode", "mirrord-protocol", @@ -4032,7 +4039,7 @@ dependencies = [ [[package]] name = "mirrord-kube" -version = "3.107.0" +version = "3.111.0" dependencies = [ "actix-codec", "base64 0.22.1", @@ -4059,7 +4066,7 @@ dependencies = [ [[package]] name = "mirrord-layer" -version = "3.107.0" +version = "3.111.0" dependencies = [ "actix-codec", "anyhow", @@ -4115,7 +4122,7 @@ dependencies = [ [[package]] name = "mirrord-layer-macro" -version = "3.107.0" +version = "3.111.0" dependencies = [ "proc-macro2", "quote", @@ -4124,7 +4131,7 @@ dependencies = [ [[package]] name = "mirrord-macros" -version = "3.107.0" +version = "3.111.0" dependencies = [ "proc-macro2", "proc-macro2-diagnostics", @@ -4134,7 +4141,7 @@ dependencies = [ [[package]] name = "mirrord-operator" -version = "3.107.0" +version = "3.111.0" dependencies = [ "actix-codec", "async-trait", @@ -4171,7 +4178,7 @@ dependencies = [ [[package]] name = "mirrord-progress" -version = "3.107.0" +version = "3.111.0" dependencies = [ "enum_dispatch", "indicatif", @@ -4181,7 +4188,7 @@ dependencies = [ [[package]] name = "mirrord-protocol" -version = "1.7.0" +version = "1.8.1" dependencies = [ "actix-codec", "bincode", @@ -4205,7 +4212,7 @@ dependencies = [ [[package]] name = "mirrord-sip" -version = "3.107.0" 
+version = "3.111.0" dependencies = [ "apple-codesign", "memchr", @@ -4542,7 +4549,7 @@ dependencies = [ [[package]] name = "outgoing" -version = "3.107.0" +version = "3.111.0" [[package]] name = "overload" @@ -4651,16 +4658,6 @@ dependencies = [ "syn 2.0.66", ] -[[package]] -name = "pem" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b13fe415cdf3c8e44518e18a7c95a13431d9bdf6d15367d82b23c377fdd441a" -dependencies = [ - "base64 0.21.7", - "serde", -] - [[package]] name = "pem" version = "3.0.4" @@ -5180,7 +5177,7 @@ checksum = "ddf517c03a109db8100448a4be38d498df8a210a99fe0e1b9eaf39e78c640efe" dependencies = [ "bytes", "rand", - "ring 0.17.8", + "ring", "rustc-hash", "rustls 0.23.10", "slab", @@ -5333,8 +5330,8 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "54077e1872c46788540de1ea3d7f4ccb1983d12f9aa909b234468676c1a36779" dependencies = [ - "pem 3.0.4", - "ring 0.17.8", + "pem", + "ring", "rustls-pki-types", "time", "yasna", @@ -5530,21 +5527,6 @@ dependencies = [ "quick-error", ] -[[package]] -name = "ring" -version = "0.16.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" -dependencies = [ - "cc", - "libc", - "once_cell", - "spin 0.5.2", - "untrusted 0.7.1", - "web-sys", - "winapi", -] - [[package]] name = "ring" version = "0.17.8" @@ -5556,7 +5538,7 @@ dependencies = [ "getrandom", "libc", "spin 0.9.8", - "untrusted 0.9.0", + "untrusted", "windows-sys 0.52.0", ] @@ -5612,21 +5594,21 @@ dependencies = [ [[package]] name = "rust-bypassed-unix-socket" -version = "3.107.0" +version = "3.111.0" dependencies = [ "tokio", ] [[package]] name = "rust-e2e-fileops" -version = "3.107.0" +version = "3.111.0" dependencies = [ "libc", ] [[package]] name = "rust-unix-socket-client" -version = "3.107.0" +version = "3.111.0" dependencies = [ "tokio", ] @@ -5689,7 +5671,7 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" dependencies = [ "log", - "ring 0.17.8", + "ring", "rustls-webpki 0.101.7", "sct", ] @@ -5701,7 +5683,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" dependencies = [ "log", - "ring 0.17.8", + "ring", "rustls-pki-types", "rustls-webpki 0.102.4", "subtle", @@ -5717,7 +5699,7 @@ dependencies = [ "aws-lc-rs", "log", "once_cell", - "ring 0.17.8", + "ring", "rustls-pki-types", "rustls-webpki 0.102.4", "subtle", @@ -5780,8 +5762,8 @@ version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ - "ring 0.17.8", - "untrusted 0.9.0", + "ring", + "untrusted", ] [[package]] @@ -5791,9 +5773,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e" dependencies = [ "aws-lc-rs", - "ring 0.17.8", + "ring", "rustls-pki-types", - "untrusted 0.9.0", + "untrusted", ] [[package]] @@ -5913,8 +5895,8 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ - "ring 0.17.8", - "untrusted 0.9.0", + "ring", + "untrusted", ] [[package]] @@ -6783,6 +6765,7 @@ dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] @@ -7251,12 +7234,6 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" -[[package]] -name = "untrusted" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" - 
[[package]] name = "untrusted" version = "0.9.0" @@ -7773,24 +7750,6 @@ dependencies = [ "cookie-factory", ] -[[package]] -name = "x509-certificate" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5d27c90840e84503cf44364de338794d5d5680bdd1da6272d13f80b0769ee0" -dependencies = [ - "bcder", - "bytes", - "chrono", - "der", - "hex", - "pem 2.0.1", - "ring 0.16.20", - "signature", - "spki", - "thiserror", -] - [[package]] name = "x509-certificate" version = "0.23.1" @@ -7802,8 +7761,8 @@ dependencies = [ "chrono", "der", "hex", - "pem 3.0.4", - "ring 0.17.8", + "pem", + "ring", "signature", "spki", "thiserror", diff --git a/Cargo.toml b/Cargo.toml index ad889ca8029..84c217bdc73 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,7 @@ resolver = "2" # latest commits on rustls suppress certificate verification [workspace.package] -version = "3.107.0" +version = "3.111.0" edition = "2021" license = "MIT" readme = "README.md" @@ -56,7 +56,7 @@ actix-codec = "0.5" bincode = { version = "2.0.0-rc.2", features = ["serde"] } bytes = "1" tokio = { version = "1" } -tokio-stream = "0.1" +tokio-stream = { version = "0.1", features = ["sync"] } serde = { version = "1", features = ["derive"] } serde_json = "1" anyhow = "1" @@ -74,7 +74,7 @@ reqwest = { version = "0.12", default-features = false, features = [ "socks", "http2", ] } -kube = { version = "0.92", default-features = false, features = [ +kube = { git = "https://github.com/kube-rs/kube", rev = "f9902f1439b3c0baafc2ece1680644c2bfade742", default-features = false, features = [ "runtime", "derive", "client", @@ -83,7 +83,7 @@ kube = { version = "0.92", default-features = false, features = [ "oidc", "socks5", "http-proxy", -]} +] } hickory-resolver = { version = "0.24", features = [ "serde-config", "tokio-runtime", diff --git a/changelog.d/+fallible-operator-detection.changed.md b/changelog.d/+fallible-operator-detection.changed.md deleted file mode 100644 index 
7c49409bfc5..00000000000 --- a/changelog.d/+fallible-operator-detection.changed.md +++ /dev/null @@ -1 +0,0 @@ -Errors that occur when using discovery API to detect mirrord operator are no longer fatal. When such error is encountered, mirrord command falls back to using the OSS version. \ No newline at end of file diff --git a/changelog.d/+medschool-order.changed.md b/changelog.d/+medschool-order.changed.md deleted file mode 100644 index c9c24b55104..00000000000 --- a/changelog.d/+medschool-order.changed.md +++ /dev/null @@ -1 +0,0 @@ -Configuration documentation contents order. diff --git a/changelog.d/+run-intproxy-tests-in-ci.internal.md b/changelog.d/+run-intproxy-tests-in-ci.internal.md deleted file mode 100644 index 760ea58f769..00000000000 --- a/changelog.d/+run-intproxy-tests-in-ci.internal.md +++ /dev/null @@ -1 +0,0 @@ -`mirrord-intproxy` crate unit tests are now part of the CI. \ No newline at end of file diff --git a/changelog.d/2452.fixed.md b/changelog.d/2452.fixed.md deleted file mode 100644 index de8dd48623d..00000000000 --- a/changelog.d/2452.fixed.md +++ /dev/null @@ -1 +0,0 @@ -When using mesh use `lo` interface for mirroring traffic. 
diff --git a/changelog.d/2551.changed.md b/changelog.d/2551.changed.md deleted file mode 100644 index 08e21ec7550..00000000000 --- a/changelog.d/2551.changed.md +++ /dev/null @@ -1 +0,0 @@ -Changed http path filter to include query params in match \ No newline at end of file diff --git a/mirrord-schema.json b/mirrord-schema.json index aa2bf6752ea..8cc3f592cc5 100644 --- a/mirrord-schema.json +++ b/mirrord-schema.json @@ -1,7 +1,7 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "title": "LayerFileConfig", - "description": "mirrord allows for a high degree of customization when it comes to which features you want to enable, and how they should function.\n\nAll of the configuration fields have a default value, so a minimal configuration would be no configuration at all.\n\nThe configuration supports templating using the [Tera](https://keats.github.io/tera/docs/) template engine. Currently we don't provide additional values to the context, if you have anything you want us to provide please let us know.\n\nTo use a configuration file in the CLI, use the `-f ` flag. Or if using VSCode Extension or JetBrains plugin, simply create a `.mirrord/mirrord.json` file or use the UI.\n\nTo help you get started, here are examples of a basic configuration file, and a complete configuration file containing all fields.\n\n### Basic `config.json` {#root-basic}\n\n```json { \"target\": \"pod/bear-pod\", \"feature\": { \"env\": true, \"fs\": \"read\", \"network\": true } } ```\n\n### Basic `config.json` with templating {#root-basic-templating}\n\n```json { \"target\": \"{{ get_env(name=\"TARGET\", default=\"pod/fallback\") }}\", \"feature\": { \"env\": true, \"fs\": \"read\", \"network\": true } } ```\n\n### Complete `config.json` {#root-complete}\n\nDon't use this example as a starting point, it's just here to show you all the available options. 
```json { \"accept_invalid_certificates\": false, \"skip_processes\": \"ide-debugger\", \"target\": { \"path\": \"pod/bear-pod\", \"namespace\": \"default\" }, \"connect_tcp\": null, \"agent\": { \"log_level\": \"info\", \"json_log\": false, \"labels\": { \"user\": \"meow\" }, \"annotations\": { \"cats.io/inject\": \"enabled\" }, \"namespace\": \"default\", \"image\": \"ghcr.io/metalbear-co/mirrord:latest\", \"image_pull_policy\": \"IfNotPresent\", \"image_pull_secrets\": [ { \"secret-key\": \"secret\" } ], \"ttl\": 30, \"ephemeral\": false, \"communication_timeout\": 30, \"startup_timeout\": 360, \"network_interface\": \"eth0\", \"flush_connections\": true }, \"feature\": { \"env\": { \"include\": \"DATABASE_USER;PUBLIC_ENV\", \"exclude\": \"DATABASE_PASSWORD;SECRET_ENV\", \"override\": { \"DATABASE_CONNECTION\": \"db://localhost:7777/my-db\", \"LOCAL_BEAR\": \"panda\" } }, \"fs\": { \"mode\": \"write\", \"read_write\": \".+\\\\.json\" , \"read_only\": [ \".+\\\\.yaml\", \".+important-file\\\\.txt\" ], \"local\": [ \".+\\\\.js\", \".+\\\\.mjs\" ] }, \"network\": { \"incoming\": { \"mode\": \"steal\", \"http_filter\": { \"header_filter\": \"host: api\\\\..+\" }, \"port_mapping\": [[ 7777, 8888 ]], \"ignore_localhost\": false, \"ignore_ports\": [9999, 10000] }, \"outgoing\": { \"tcp\": true, \"udp\": true, \"filter\": { \"local\": [\"tcp://1.1.1.0/24:1337\", \"1.1.5.0/24\", \"google.com\", \":53\"] }, \"ignore_localhost\": false, \"unix_streams\": \"bear.+\" }, \"dns\": false }, \"copy_target\": { \"scale_down\": false } }, \"operator\": true, \"kubeconfig\": \"~/.kube/config\", \"sip_binaries\": \"bash\", \"telemetry\": true, \"kube_context\": \"my-cluster\" } ```\n\n# Options {#root-options}", + "description": "mirrord allows for a high degree of customization when it comes to which features you want to enable, and how they should function.\n\nAll of the configuration fields have a default value, so a minimal configuration would be no configuration at all.\n\nThe 
configuration supports templating using the [Tera](https://keats.github.io/tera/docs/) template engine. Currently we don't provide additional values to the context, if you have anything you want us to provide please let us know.\n\nTo use a configuration file in the CLI, use the `-f ` flag. Or if using VSCode Extension or JetBrains plugin, simply create a `.mirrord/mirrord.json` file or use the UI.\n\nTo help you get started, here are examples of a basic configuration file, and a complete configuration file containing all fields.\n\n### Basic `config.json` {#root-basic}\n\n```json { \"target\": \"pod/bear-pod\", \"feature\": { \"env\": true, \"fs\": \"read\", \"network\": true } } ```\n\n### Basic `config.json` with templating {#root-basic-templating}\n\n```json { \"target\": \"{{ get_env(name=\"TARGET\", default=\"pod/fallback\") }}\", \"feature\": { \"env\": true, \"fs\": \"read\", \"network\": true } } ```\n\n### Complete `config.json` {#root-complete}\n\nDon't use this example as a starting point, it's just here to show you all the available options. 
```json { \"accept_invalid_certificates\": false, \"skip_processes\": \"ide-debugger\", \"target\": { \"path\": \"pod/bear-pod\", \"namespace\": \"default\" }, \"connect_tcp\": null, \"agent\": { \"log_level\": \"info\", \"json_log\": false, \"labels\": { \"user\": \"meow\" }, \"annotations\": { \"cats.io/inject\": \"enabled\" }, \"namespace\": \"default\", \"image\": \"ghcr.io/metalbear-co/mirrord:latest\", \"image_pull_policy\": \"IfNotPresent\", \"image_pull_secrets\": [ { \"secret-key\": \"secret\" } ], \"ttl\": 30, \"ephemeral\": false, \"communication_timeout\": 30, \"startup_timeout\": 360, \"network_interface\": \"eth0\", \"flush_connections\": true }, \"feature\": { \"env\": { \"include\": \"DATABASE_USER;PUBLIC_ENV\", \"exclude\": \"DATABASE_PASSWORD;SECRET_ENV\", \"override\": { \"DATABASE_CONNECTION\": \"db://localhost:7777/my-db\", \"LOCAL_BEAR\": \"panda\" } }, \"fs\": { \"mode\": \"write\", \"read_write\": \".+\\\\.json\" , \"read_only\": [ \".+\\\\.yaml\", \".+important-file\\\\.txt\" ], \"local\": [ \".+\\\\.js\", \".+\\\\.mjs\" ] }, \"network\": { \"incoming\": { \"mode\": \"steal\", \"http_filter\": { \"header_filter\": \"host: api\\\\..+\" }, \"port_mapping\": [[ 7777, 8888 ]], \"ignore_localhost\": false, \"ignore_ports\": [9999, 10000] }, \"outgoing\": { \"tcp\": true, \"udp\": true, \"filter\": { \"local\": [\"tcp://1.1.1.0/24:1337\", \"1.1.5.0/24\", \"google.com\", \":53\"] }, \"ignore_localhost\": false, \"unix_streams\": \"bear.+\" }, \"dns\": { \"enabled\": true, \"filter\": { \"local\": [\"1.1.1.0/24:1337\", \"1.1.5.0/24\", \"google.com\"] } } }, \"copy_target\": { \"scale_down\": false } }, \"operator\": true, \"kubeconfig\": \"~/.kube/config\", \"sip_binaries\": \"bash\", \"telemetry\": true, \"kube_context\": \"my-cluster\" } ```\n\n# Options {#root-options}", "type": "object", "properties": { "accept_invalid_certificates": { @@ -561,6 +561,62 @@ }, "additionalProperties": false }, + "DnsFileConfig": { + "description": "Resolve DNS 
via the remote pod.\n\nDefaults to `true`.\n\nMind that: - DNS resolving can be done in multiple ways. Some frameworks use `getaddrinfo`/`gethostbyname` functions, while others communicate directly with the DNS server at port `53` and perform a sort of manual resolution. Just enabling the `dns` feature in mirrord might not be enough. If you see an address resolution error, try enabling the [`fs`](#feature-fs) feature, and setting `read_only: [\"/etc/resolv.conf\"]`. - DNS filter currently works only with frameworks that use `getaddrinfo`/`gethostbyname` functions.", + "type": "object", + "properties": { + "enabled": { + "type": [ + "boolean", + "null" + ] + }, + "filter": { + "title": "feature.network.dns.filter {#feature-network-dns-filter}", + "description": "Unstable: the precise syntax of this config is subject to change.", + "anyOf": [ + { + "$ref": "#/definitions/DnsFilterConfig" + }, + { + "type": "null" + } + ] + } + }, + "additionalProperties": false + }, + "DnsFilterConfig": { + "description": "List of addresses/ports/subnets that should be resolved through either the remote pod or local app, depending how you set this up with either `remote` or `local`.\n\nYou may use this option to specify when DNS resolution is done from the remote pod (which is the default behavior when you enable remote DNS), or from the local app (default when you have remote DNS disabled).\n\nTakes a list of values, such as:\n\n- Only queries for hostname `my-service-in-cluster` will go through the remote pod.\n\n```json { \"remote\": [\"my-service-in-cluster\"] } ```\n\n- Only queries for addresses in subnet `1.1.1.0/24` with service port `1337`` will go through the remote pod.\n\n```json { \"remote\": [\"1.1.1.0/24:1337\"] } ```\n\n- Only queries for hostname `google.com` with service port `1337` or `7331` will go through the remote pod.\n\n```json { \"remote\": [\"google.com:1337\", \"google.com:7331\"] } ```\n\n- Only queries for `localhost` with service port `1337` will go 
through the local app.\n\n```json { \"local\": [\"localhost:1337\"] } ```\n\n- Only queries with service port `1337` or `7331` will go through the local app.\n\n```json { \"local\": [\":1337\", \":7331\"] } ```\n\nValid values follow this pattern: `[name|address|subnet/mask][:port]`.", + "oneOf": [ + { + "description": "DNS queries matching what is specified here will go through the remote pod, everything else will go through local.", + "type": "object", + "required": [ + "remote" + ], + "properties": { + "remote": { + "$ref": "#/definitions/VecOrSingle_for_String" + } + }, + "additionalProperties": false + }, + { + "description": "DNS queries matching what is specified here will go through the local app, everything else will go through the remote pod.", + "type": "object", + "required": [ + "local" + ], + "properties": { + "local": { + "$ref": "#/definitions/VecOrSingle_for_String" + } + }, + "additionalProperties": false + } + ] + }, "EnvFileConfig": { "description": "Allows the user to set or override the local process' environment variables with the ones from the remote pod.\n\nWhich environment variables to load from the remote pod are controlled by setting either [`include`](#feature-env-include) or [`exclude`](#feature-env-exclude).\n\nSee the environment variables [reference](https://mirrord.dev/docs/reference/env/) for more details.\n\n```json { \"feature\": { \"env\": { \"include\": \"DATABASE_USER;PUBLIC_ENV;MY_APP_*\", \"exclude\": \"DATABASE_PASSWORD;SECRET_ENV\", \"override\": { \"DATABASE_CONNECTION\": \"db://localhost:7777/my-db\", \"LOCAL_BEAR\": \"panda\" } } } } ```", "type": "object", @@ -628,7 +684,7 @@ "type": "object", "properties": { "readlink": { - "title": "_experimental_ readlink {#fexperimental-readlink}", + "title": "_experimental_ readlink {#experimental-readlink}", "description": "Enables the `readlink` hook.", "type": [ "boolean", @@ -636,12 +692,20 @@ ] }, "tcp_ping4_mock": { - "title": "_experimental_ tcp_ping4_mock 
{#fexperimental-tcp_ping4_mock}", + "title": "_experimental_ tcp_ping4_mock {#experimental-tcp_ping4_mock}", "description": "", "type": [ "boolean", "null" ] + }, + "trust_any_certificate": { + "title": "_experimental_ trust_any_certificate {#experimental-trust_any_certificate}", + "description": "Enables trusting any certificate on macOS, useful for ", + "type": [ + "boolean", + "null" + ] } }, "additionalProperties": false @@ -1070,15 +1134,18 @@ ] }, "NetworkFileConfig": { - "description": "Controls mirrord network operations.\n\nSee the network traffic [reference](https://mirrord.dev/docs/reference/traffic/) for more details.\n\n```json { \"feature\": { \"network\": { \"incoming\": { \"mode\": \"steal\", \"http_filter\": { \"header_filter\": \"host: api\\\\..+\" }, \"port_mapping\": [[ 7777, 8888 ]], \"ignore_localhost\": false, \"ignore_ports\": [9999, 10000] }, \"outgoing\": { \"tcp\": true, \"udp\": true, \"filter\": { \"local\": [\"tcp://1.1.1.0/24:1337\", \"1.1.5.0/24\", \"google.com\", \":53\"] }, \"ignore_localhost\": false, \"unix_streams\": \"bear.+\" }, \"dns\": false } } } ```", + "description": "Controls mirrord network operations.\n\nSee the network traffic [reference](https://mirrord.dev/docs/reference/traffic/) for more details.\n\n```json { \"feature\": { \"network\": { \"incoming\": { \"mode\": \"steal\", \"http_filter\": { \"header_filter\": \"host: api\\\\..+\" }, \"port_mapping\": [[ 7777, 8888 ]], \"ignore_localhost\": false, \"ignore_ports\": [9999, 10000] }, \"outgoing\": { \"tcp\": true, \"udp\": true, \"filter\": { \"local\": [\"tcp://1.1.1.0/24:1337\", \"1.1.5.0/24\", \"google.com\", \":53\"] }, \"ignore_localhost\": false, \"unix_streams\": \"bear.+\" }, \"dns\": { \"enabled\": true, \"filter\": { \"local\": [\"1.1.1.0/24:1337\", \"1.1.5.0/24\", \"google.com\"] } } } } } ```", "type": "object", "properties": { "dns": { "title": "feature.network.dns {#feature-network-dns}", - "description": "Resolve DNS via the remote pod.\n\nDefaults 
to `true`.\n\n- Caveats: DNS resolving can be done in multiple ways, some frameworks will use `getaddrinfo`, while others will create a connection on port `53` and perform a sort of manual resolution. Just enabling the `dns` feature in mirrord might not be enough. If you see an address resolution error, try enabling the [`fs`](#feature-fs) feature, and setting `read_only: [\"/etc/resolv.conf\"]`.", - "type": [ - "boolean", - "null" + "anyOf": [ + { + "$ref": "#/definitions/ToggleableConfig_for_DnsFileConfig" + }, + { + "type": "null" + } ] }, "incoming": { @@ -1356,6 +1423,16 @@ } ] }, + "ToggleableConfig_for_DnsFileConfig": { + "anyOf": [ + { + "type": "boolean" + }, + { + "$ref": "#/definitions/DnsFileConfig" + } + ] + }, "ToggleableConfig_for_EnvFileConfig": { "anyOf": [ { diff --git a/mirrord/agent/Cargo.toml b/mirrord/agent/Cargo.toml index 62b9f1c2a7e..65d9891edea 100644 --- a/mirrord/agent/Cargo.toml +++ b/mirrord/agent/Cargo.toml @@ -20,13 +20,20 @@ workspace = true [dependencies] containerd-client = "0.5" -tokio = { workspace = true, features = ["rt", "net", "macros", "fs", "process", "signal"] } +tokio = { workspace = true, features = [ + "rt", + "net", + "macros", + "fs", + "process", + "signal", +] } serde.workspace = true serde_json.workspace = true pnet = "0.35" -nix = { workspace = true, features = ["mount", "sched", "user"] } +nix = { workspace = true, features = ["mount", "sched", "user"] } clap = { workspace = true, features = ["env"] } -mirrord-protocol = { path = "../protocol"} +mirrord-protocol = { path = "../protocol" } actix-codec.workspace = true futures.workspace = true tracing.workspace = true @@ -64,13 +71,15 @@ drain.workspace = true tokio-rustls = "0.26" x509-parser = "0.16" rustls.workspace = true +envy = "0.4" [target.'cfg(target_os = "linux")'.dependencies] -iptables = {git = "https://github.com/metalbear-co/rust-iptables.git", rev = "e66c7332e361df3c61a194f08eefe3f40763d624"} -rawsocket = {git = 
"https://github.com/metalbear-co/rawsocket.git"} +iptables = { git = "https://github.com/metalbear-co/rust-iptables.git", rev = "e66c7332e361df3c61a194f08eefe3f40763d624" } +rawsocket = { git = "https://github.com/metalbear-co/rawsocket.git" } [dev-dependencies] +rstest = "0.21" mockall = "0.12" # 0.11.3 is broken test_bin = "0.4" rcgen = "0.13" diff --git a/mirrord/agent/src/entrypoint.rs b/mirrord/agent/src/entrypoint.rs index 0abb9ba2d11..16867856dc0 100644 --- a/mirrord/agent/src/entrypoint.rs +++ b/mirrord/agent/src/entrypoint.rs @@ -13,6 +13,7 @@ use client_connection::AgentTlsConnector; use dns::{DnsCommand, DnsWorker}; use futures::TryFutureExt; use mirrord_protocol::{ClientMessage, DaemonMessage, GetEnvVarsRequest, LogMessage}; +use sniffer::tcp_capture::RawSocketTcpCapture; use tokio::{ net::{TcpListener, TcpStream}, process::Command, @@ -35,7 +36,7 @@ use crate::{ file::FileManager, outgoing::{TcpOutgoingApi, UdpOutgoingApi}, runtime::get_container, - sniffer::{SnifferCommand, TcpConnectionSniffer, TcpSnifferApi}, + sniffer::{api::TcpSnifferApi, messages::SnifferCommand, TcpConnectionSniffer}, steal::{ ip_tables::{ new_iptables, IPTablesWrapper, SafeIpTables, IPTABLE_MESH, IPTABLE_MESH_ENV, @@ -48,8 +49,7 @@ use crate::{ *, }; -/// Size of [`mpsc`] channels connecting [`TcpStealerApi`] and [`TcpSnifferApi`] with their -/// background tasks. +/// Size of [`mpsc`] channels connecting [`TcpStealerApi`] with the background task. const CHANNEL_SIZE: usize = 1024; /// Keeps track of next client id. @@ -201,6 +201,8 @@ struct ClientConnectionHandler { udp_outgoing_api: UdpOutgoingApi, dns_api: DnsApi, state: State, + /// Whether the client has sent us [`ClientMessage::ReadyForLogs`]. 
+ ready_for_logs: bool, } impl ClientConnectionHandler { @@ -233,6 +235,7 @@ impl ClientConnectionHandler { udp_outgoing_api, dns_api, state, + ready_for_logs: false, }; Ok(client_handler) @@ -244,7 +247,7 @@ impl ClientConnectionHandler { connection: &mut ClientConnection, ) -> Option { if let BackgroundTask::Running(sniffer_status, sniffer_sender) = task { - match TcpSnifferApi::new(id, sniffer_sender, sniffer_status, CHANNEL_SIZE).await { + match TcpSnifferApi::new(id, sniffer_sender, sniffer_status).await { Ok(api) => Some(api), Err(e) => { let message = format!( @@ -338,7 +341,13 @@ impl ClientConnectionHandler { unreachable!() } }, if self.tcp_sniffer_api.is_some() => match message { - Ok(message) => self.respond(DaemonMessage::Tcp(message)).await?, + Ok((message, Some(log))) if self.ready_for_logs => { + self.respond(DaemonMessage::LogMessage(log)).await?; + self.respond(DaemonMessage::Tcp(message)).await?; + } + Ok((message, _)) => { + self.respond(DaemonMessage::Tcp(message)).await?; + }, Err(e) => break e, }, message = async { @@ -461,7 +470,9 @@ impl ClientConnectionHandler { )) .await?; } - ClientMessage::ReadyForLogs => {} + ClientMessage::ReadyForLogs => { + self.ready_for_logs = true; + } } Ok(true) @@ -498,7 +509,7 @@ async fn start_agent(args: Args) -> Result<()> { let mesh = args.mode.mesh(); let watched_task = WatchedTask::new( - TcpConnectionSniffer::TASK_NAME, + TcpConnectionSniffer::::TASK_NAME, TcpConnectionSniffer::new(sniffer_command_rx, args.network_interface, mesh).and_then( |sniffer| async move { let res = sniffer.start(cancellation_token).await; @@ -512,7 +523,7 @@ async fn start_agent(args: Args) -> Result<()> { let status = watched_task.status(); let task = run_thread_in_namespace( watched_task.start(), - TcpConnectionSniffer::TASK_NAME.to_string(), + TcpConnectionSniffer::::TASK_NAME.to_string(), state.container_pid(), "net", ); diff --git a/mirrord/agent/src/env.rs b/mirrord/agent/src/env.rs index e6990dcfa61..3fafbc10714 100644 --- 
a/mirrord/agent/src/env.rs +++ b/mirrord/agent/src/env.rs @@ -60,6 +60,7 @@ impl EnvFilter { WildMatch::new("RUBYOPT"), WildMatch::new("RUST_LOG"), WildMatch::new("_JAVA_OPTIONS"), + WildMatch::new("DOTNET_EnableDiagnostics"), ]; for selector in &filter_env_vars { diff --git a/mirrord/agent/src/error.rs b/mirrord/agent/src/error.rs index 7cd5121497c..6c3b89cab5d 100644 --- a/mirrord/agent/src/error.rs +++ b/mirrord/agent/src/error.rs @@ -12,7 +12,7 @@ use mirrord_protocol::{ use thiserror::Error; use crate::{ - client_connection::TlsSetupError, namespace::NamespaceError, sniffer::SnifferCommand, + client_connection::TlsSetupError, namespace::NamespaceError, sniffer::messages::SnifferCommand, steal::StealerCommand, }; @@ -135,6 +135,9 @@ pub(crate) enum AgentError { /// Child agent process spawned in `main` failed. #[error("Agent child process failed: {0}")] AgentFailed(ExitStatus), + + #[error("Exhausted possible identifiers for incoming connections.")] + ExhaustedConnectionId, } pub(crate) type Result = std::result::Result; diff --git a/mirrord/agent/src/main.rs b/mirrord/agent/src/main.rs index 8eca727707b..bdfea92bac3 100644 --- a/mirrord/agent/src/main.rs +++ b/mirrord/agent/src/main.rs @@ -1,6 +1,7 @@ #![feature(hash_extract_if)] #![feature(let_chains)] #![feature(type_alias_impl_trait)] +#![feature(entry_insert)] #![cfg_attr(target_os = "linux", feature(tcp_quickack))] #![feature(lazy_cell)] #![warn(clippy::indexing_slicing)] diff --git a/mirrord/agent/src/sniffer.rs b/mirrord/agent/src/sniffer.rs index d226be3a2f7..b03df01a23a 100644 --- a/mirrord/agent/src/sniffer.rs +++ b/mirrord/agent/src/sniffer.rs @@ -1,37 +1,64 @@ use std::{ - collections::{HashMap, HashSet}, + collections::{hash_map::Entry, HashMap}, + fmt, + future::Future, hash::{Hash, Hasher}, - net::{IpAddr, Ipv4Addr, SocketAddr}, + net::Ipv4Addr, + pin::Pin, + task::{Context, Poll}, }; -use mirrord_protocol::{ - tcp::{DaemonTcp, LayerTcp, NewTcpConnection, TcpClose, TcpData}, - ConnectionId, 
MeshVendor, Port, -}; -use nix::sys::socket::SockaddrStorage; -use pnet::packet::{ - ethernet::{EtherTypes, EthernetPacket}, - ip::IpNextHeaderProtocols, - ipv4::Ipv4Packet, - tcp::{TcpFlags, TcpPacket}, - Packet, -}; -use rawsocket::RawCapture; +use futures::{stream::FuturesUnordered, StreamExt}; +use mirrord_protocol::{MeshVendor, Port}; +use pnet::packet::tcp::TcpFlags; +use tcp_capture::TcpCapture; use tokio::{ - net::UdpSocket, select, - sync::mpsc::{self, Receiver, Sender}, + sync::{ + broadcast, + mpsc::{error::TrySendError, Receiver, Sender}, + }, }; use tokio_util::sync::CancellationToken; -use tracing::{debug, error, trace, warn}; +use tracing::Level; +use self::{ + messages::{SniffedConnection, SnifferCommand, SnifferCommandInner}, + tcp_capture::RawSocketTcpCapture, +}; use crate::{ error::AgentError, http::HttpVersion, - util::{ClientId, IndexAllocator, Subscriptions}, - watched_task::TaskStatus, + util::{ClientId, Subscriptions}, }; +pub(crate) mod api; +pub(crate) mod messages; +pub(crate) mod tcp_capture; + +/// [`Future`] that resolves to [`ClientId`] when the [`TcpConnectionSniffer`] client drops their +/// [`TcpSnifferApi`](api::TcpSnifferApi). +struct ClientClosed { + /// [`Sender`] used by [`TcpConnectionSniffer`] to send data to the client. + /// Here used only to poll [`Sender::closed`]. + client_tx: Sender, + /// Id of the client. + client_id: ClientId, +} + +impl Future for ClientClosed { + type Output = ClientId; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let client_id = self.client_id; + + let future = std::pin::pin!(self.get_mut().client_tx.closed()); + std::task::ready!(future.poll(cx)); + + Poll::Ready(client_id) + } +} + #[derive(Debug, Eq, Copy, Clone)] pub(crate) struct TcpSessionIdentifier { /// The remote address that is sending a packet to the impersonated pod. 
@@ -40,7 +67,7 @@ pub(crate) struct TcpSessionIdentifier { /// /// If you were to `curl {impersonated_pod_ip}:{port}`, this would be the address of whoever /// is making the request. - source_addr: Ipv4Addr, + pub(crate) source_addr: Ipv4Addr, /// Local address of the impersonated pod. /// @@ -53,9 +80,9 @@ pub(crate) struct TcpSessionIdentifier { /// NAME READY STATUS IP /// happy-pod 1/1 Running 1.2.3.4 /// ``` - dest_addr: Ipv4Addr, - source_port: u16, - dest_port: u16, + pub(crate) dest_addr: Ipv4Addr, + pub(crate) source_port: u16, + pub(crate) dest_port: u16, } impl PartialEq for TcpSessionIdentifier { @@ -91,13 +118,7 @@ impl Hash for TcpSessionIdentifier { } } -#[derive(Debug)] -struct TCPSession { - id: ConnectionId, - clients: HashSet, -} - -type TCPSessionMap = HashMap; +type TCPSessionMap = HashMap>>; const fn is_new_connection(flags: u8) -> bool { 0 != (flags & TcpFlags::SYN) && 0 == (flags & (TcpFlags::ACK | TcpFlags::RST | TcpFlags::FIN)) @@ -107,408 +128,207 @@ fn is_closed_connection(flags: u8) -> bool { 0 != (flags & (TcpFlags::FIN | TcpFlags::RST)) } -/// Connects to a remote address (`8.8.8.8:53`) so we can find which network interface to use. -/// -/// Used when no `user_interface` is specified in [`prepare_sniffer`] to prevent mirrord from -/// defaulting to the wrong network interface (`eth0`), as sometimes the user's machine doesn't have -/// it available (i.e. their default network is `enp2s0`). -#[tracing::instrument(level = "trace")] -async fn resolve_interface() -> Result, AgentError> { - // Connect to a remote address so we can later get the default network interface. - let temporary_socket = UdpSocket::bind("0.0.0.0:0").await?; - temporary_socket.connect("8.8.8.8:53").await?; - - // Create comparison address here with `port: 0`, to match the network interface's address of - // `sin_port: 0`. 
- let local_address = SocketAddr::new(temporary_socket.local_addr()?.ip(), 0); - let raw_local_address = SockaddrStorage::from(local_address); - - // Try to find an interface that matches the local ip we have. - let usable_interface_name = nix::ifaddrs::getifaddrs()? - .find_map(|iface| (raw_local_address == iface.address?).then_some(iface.interface_name)); - - Ok(usable_interface_name) -} - -// TODO(alex): Errors here are not reported back anywhere, we end up with a generic fail of: -// "ERROR ThreadId(03) mirrord_agent: ClientConnectionHandler::start -> Client 0 disconnected with -// error: SnifferCommand sender failed with `channel closed`" -// -// And to make matters worse, the error reported back to the user is the very generic: -// "mirrord-layer received an unexpected response from the agent pod!" -#[tracing::instrument(level = "trace")] -async fn prepare_sniffer( - network_interface: Option, - mesh: Option, -) -> Result { - // Priority is whatever the user set as an option to mirrord, then we check if we're in a mesh - // to use `lo` interface, otherwise we try to get the appropriate interface. - let interface = match network_interface.or_else(|| mesh.map(|_| "lo".to_string())) { - Some(interface) => interface, - None => resolve_interface() - .await? - .unwrap_or_else(|| "eth0".to_string()), - }; - - trace!("Using {interface:#?} interface."); - let capture = RawCapture::from_interface_name(&interface)?; - // We start with a BPF that drops everything so we won't receive *EVERYTHING* - // as we don't know what the layer will ask us to listen for, so this is essentially setting - // it to none - // we ofc could've done this when a layer connected, but I (A.H) thought it'd make more sense - // to have this shared among layers (potentially, in the future) - fme. 
- capture.set_filter(rawsocket::filter::build_drop_always())?; - capture - .ignore_outgoing() - .map_err(AgentError::PacketIgnoreOutgoing)?; - Ok(capture) -} - #[derive(Debug)] -struct TcpPacketData { +pub(crate) struct TcpPacketData { bytes: Vec, flags: u8, } -#[tracing::instrument(skip(eth_packet), level = "trace", fields(bytes = %eth_packet.len()))] -fn get_tcp_packet(eth_packet: Vec) -> Option<(TcpSessionIdentifier, TcpPacketData)> { - let eth_packet = EthernetPacket::new(ð_packet[..])?; - let ip_packet = match eth_packet.get_ethertype() { - EtherTypes::Ipv4 => Ipv4Packet::new(eth_packet.payload())?, - _ => return None, - }; - - let tcp_packet = match ip_packet.get_next_level_protocol() { - IpNextHeaderProtocols::Tcp => TcpPacket::new(ip_packet.payload())?, - _ => return None, - }; - - let dest_port = tcp_packet.get_destination(); - let source_port = tcp_packet.get_source(); - - let identifier = TcpSessionIdentifier { - source_addr: ip_packet.get_source(), - dest_addr: ip_packet.get_destination(), - source_port, - dest_port, - }; +/// Main struct implementing incoming traffic mirroring feature. +/// Utilizes [`TcpCapture`] for sniffing on incoming TCP packets. Transforms them into +/// incoming TCP data streams and sends copy of the traffic to all subscribed clients. +/// +/// Can be easily used via [`api::TcpSnifferApi`]. +/// +/// # Notes on behavior under high load +/// +/// Because this struct does not talk directly with the remote peers, we can't apply any back +/// pressure on the incoming connections. There is no reliable mechanism to ensure that all +/// subscribed clients receive all of the traffic. If we wait too long when distributing data +/// between the clients, raw socket's recv buffer will overflow and we'll lose packets. +/// +/// Having this in mind, this struct distributes incoming data using [`broadcast`] channels. 
If the +/// clients are not fast enough to pick up TCP packets, they will lose them +/// ([`broadcast::error::RecvError::Lagged`]). +/// +/// At the same time, notifying clients about new connections (and distributing +/// [`broadcast::Receiver`]s) is done with [`tokio::sync::mpsc`] channels (one per client). +/// To prevent global packet loss, this struct does not use the blocking [`Sender::send`] method. It +/// uses the non-blocking [`Sender::try_send`] method, so if the client is not fast enough to pick +/// up the [`broadcast::Receiver`], they will miss the whole connection. +pub(crate) struct TcpConnectionSniffer { + command_rx: Receiver, + tcp_capture: T, - trace!("identifier {identifier:?}"); - Some(( - identifier, - TcpPacketData { - flags: tcp_packet.get_flags(), - bytes: tcp_packet.payload().to_vec(), - }, - )) -} + port_subscriptions: Subscriptions, + sessions: TCPSessionMap, -#[derive(Debug)] -enum SnifferCommands { - NewAgent(Sender), - Subscribe(Port), - UnsubscribePort(Port), - UnsubscribeConnection(ConnectionId), - AgentClosed, + client_txs: HashMap>, + clients_closed: FuturesUnordered, } -impl From for SnifferCommands { - fn from(value: LayerTcp) -> Self { - match value { - LayerTcp::PortSubscribe(port) => Self::Subscribe(port), - LayerTcp::PortUnsubscribe(port) => Self::UnsubscribePort(port), - LayerTcp::ConnectionUnsubscribe(id) => Self::UnsubscribeConnection(id), - } +impl fmt::Debug for TcpConnectionSniffer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TcpConnectionSniffer") + .field("clients", &self.client_txs.keys()) + .field("port_subscriptions", &self.port_subscriptions) + .field("open_tcp_sessions", &self.sessions.keys()) + .finish() } } -#[derive(Debug)] -pub(crate) struct SnifferCommand { - client_id: ClientId, - command: SnifferCommands, -} - -/// Interface used by clients to interact with the [`TcpConnectionSniffer`]. -/// Multiple instances of this struct operate on a single sniffer instance. 
-pub(crate) struct TcpSnifferApi { - /// Id of the client using this struct. - client_id: ClientId, - /// Channel used to send commands to the [`TcpConnectionSniffer`]. - sender: Sender, - /// Channel used to receive messages from the [`TcpConnectionSniffer`]. - receiver: Receiver, - /// View on the sniffer task's status. - task_status: TaskStatus, -} - -impl TcpSnifferApi { - /// Create a new instance of this struct and connect it to a [`TcpConnectionSniffer`] instance. - /// * `client_id` - id of the client using this struct - /// * `sniffer_sender` - channel used to send commands to the [`TcpConnectionSniffer`] - /// * `task_status` - handle to the [`TcpConnectionSniffer`] exit status - /// * `channel_size` - capacity of the channel connecting [`TcpConnectionSniffer`] back to this - /// struct +impl TcpConnectionSniffer { + /// Creates and prepares a new [`TcpConnectionSniffer`] that uses BPF filters to capture network + /// packets. + /// + /// The capture uses a network interface specified by the user, if there is none, then it tries + /// to find a proper one by starting a connection. If this fails, we use "eth0" as a last + /// resort. + #[tracing::instrument(level = Level::TRACE, skip(command_rx), err)] pub async fn new( - client_id: ClientId, - sniffer_sender: Sender, - task_status: TaskStatus, - channel_size: usize, - ) -> Result { - let (sender, receiver) = mpsc::channel(channel_size); - - sniffer_sender - .send(SnifferCommand { - client_id, - command: SnifferCommands::NewAgent(sender), - }) - .await?; + command_rx: Receiver, + network_interface: Option, + mesh: Option, + ) -> Result { + let tcp_capture = RawSocketTcpCapture::new(network_interface, mesh).await?; Ok(Self { - client_id, - sender: sniffer_sender, - receiver, - task_status, - }) - } - - /// Send the given command to the connected [`TcpConnectionSniffer`]. 
- async fn send_command(&mut self, command: SnifferCommands) -> Result<(), AgentError> { - let command = SnifferCommand { - client_id: self.client_id, - command, - }; - - if self.sender.send(command).await.is_ok() { - Ok(()) - } else { - Err(self.task_status.unwrap_err().await) - } - } - - /// Return the next message from the connected [`TcpConnectionSniffer`]. - pub async fn recv(&mut self) -> Result { - match self.receiver.recv().await { - Some(msg) => Ok(msg), - None => Err(self.task_status.unwrap_err().await), - } - } + command_rx, + tcp_capture, - /// Tansform the given message into a [`SnifferCommands`] and pass it to the connected - /// [`TcpConnectionSniffer`]. - pub async fn handle_client_message(&mut self, message: LayerTcp) -> Result<(), AgentError> { - self.send_command(message.into()).await - } -} + port_subscriptions: Default::default(), + sessions: TCPSessionMap::new(), -impl Drop for TcpSnifferApi { - fn drop(&mut self) { - self.sender - .try_send(SnifferCommand { - client_id: self.client_id, - command: SnifferCommands::AgentClosed, - }) - .unwrap(); + client_txs: HashMap::new(), + clients_closed: Default::default(), + }) } } -pub(crate) struct TcpConnectionSniffer { - port_subscriptions: Subscriptions, - receiver: Receiver, - client_senders: HashMap>, - raw_capture: RawCapture, - sessions: TCPSessionMap, - //todo: impl drop for index allocator and connection id.. - connection_id_to_tcp_identifier: HashMap, - index_allocator: IndexAllocator, -} - -impl TcpConnectionSniffer { +impl TcpConnectionSniffer +where + R: TcpCapture, +{ pub const TASK_NAME: &'static str = "Sniffer"; + /// Capacity of [`broadcast`] channels used to distribute incoming TCP packets between clients. + const CONNECTION_DATA_CHANNEL_CAPACITY: usize = 512; + /// Runs the sniffer loop, capturing packets. 
- #[tracing::instrument(level = "trace", skip(self))] + #[tracing::instrument(level = Level::DEBUG, skip(cancel_token), err)] pub async fn start(mut self, cancel_token: CancellationToken) -> Result<(), AgentError> { loop { select! { - command = self.receiver.recv() => { - if let Some(command) = command { - self.handle_command(command).await?; - } else { break; } + command = self.command_rx.recv() => { + let Some(command) = command else { + tracing::debug!("command channel closed, exiting"); + break; + }; + + self.handle_command(command)?; }, - packet = self.raw_capture.next() => { - self.handle_packet(packet?).await?; + + Some(client_id) = self.clients_closed.next() => { + self.handle_client_closed(client_id)?; + } + + result = self.tcp_capture.next() => { + let (identifier, packet_data) = result?; + self.handle_packet(identifier, packet_data)?; } + _ = cancel_token.cancelled() => { + tracing::debug!("token cancelled, exiting"); break; } } } - debug!("TCPConnectionSniffer exiting"); - Ok(()) - } - /// Creates and prepares a new [`TcpConnectionSniffer`] that uses BPF filters to capture network - /// packets. - /// - /// The capture uses a network interface specified by the user, if there is none, then it tries - /// to find a proper one by starting a connection. If this fails, we use "eth0" as a last - /// resort. - #[tracing::instrument(level = "trace")] - pub async fn new( - receiver: Receiver, - network_interface: Option, - mesh: Option, - ) -> Result { - let raw_capture = prepare_sniffer(network_interface, mesh).await?; - - Ok(Self { - receiver, - raw_capture, - port_subscriptions: Default::default(), - client_senders: HashMap::new(), - sessions: TCPSessionMap::new(), - //todo: impl drop for index allocator and connection id.. - connection_id_to_tcp_identifier: HashMap::new(), - index_allocator: Default::default(), - }) + Ok(()) } /// New layer is connecting to this agent sniffer. 
- #[tracing::instrument(level = "trace", ret, skip(self, sender))] - fn handle_new_client(&mut self, client_id: ClientId, sender: Sender) { - self.client_senders.insert(client_id, sender); - } - - /// layer with `client_id` wants to sniff on `port`. - #[tracing::instrument(level = "trace", ret, skip(self))] - async fn handle_subscribe( - &mut self, - client_id: ClientId, - port: Port, - ) -> Result<(), AgentError> { - self.port_subscriptions.subscribe(client_id, port); - self.update_sniffer()?; - self.send_message_to_client(&client_id, DaemonTcp::SubscribeResult(Ok(port))) - .await + #[tracing::instrument(level = Level::TRACE, skip(sender))] + fn handle_new_client(&mut self, client_id: ClientId, sender: Sender) { + self.client_txs.insert(client_id, sender.clone()); + self.clients_closed.push(ClientClosed { + client_tx: sender.clone(), + client_id, + }); } /// Removes the client with `client_id`, and also unsubscribes its port. - #[tracing::instrument(level = "trace", skip(self))] + /// Adjusts BPF filter if needed. + #[tracing::instrument(level = Level::TRACE, err)] fn handle_client_closed(&mut self, client_id: ClientId) -> Result<(), AgentError> { - self.client_senders.remove(&client_id); - self.port_subscriptions.remove_client(client_id); - self.update_sniffer() + self.client_txs.remove(&client_id); + + if self.port_subscriptions.remove_client(client_id) { + self.update_packet_filter()?; + } + + Ok(()) } - /// Updates the sniffer's internal state. - /// - /// Called when the sniffer receives a new command. - #[tracing::instrument(level = "trace", ret, skip(self))] - fn update_sniffer(&mut self) -> Result<(), AgentError> { + /// Updates BPF filter used by [`Self::tcp_capture`] to match state of + /// [`Self::port_subscriptions`]. 
+ #[tracing::instrument(level = Level::TRACE, err)] + fn update_packet_filter(&mut self) -> Result<(), AgentError> { let ports = self.port_subscriptions.get_subscribed_topics(); - if ports.is_empty() { - trace!("Empty ports, setting dummy bpf"); - self.raw_capture - .set_filter(rawsocket::filter::build_drop_always())? + let filter = if ports.is_empty() { + tracing::trace!("No ports subscribed, setting dummy bpf"); + rawsocket::filter::build_drop_always() } else { - self.raw_capture - .set_filter(rawsocket::filter::build_tcp_port_filter(&ports))? + rawsocket::filter::build_tcp_port_filter(&ports) }; - Ok(()) - } - fn qualified_port(&self, port: u16) -> bool { - self.port_subscriptions - .get_subscribed_topics() - .contains(&port) + self.tcp_capture.set_filter(filter)?; + + Ok(()) } - #[tracing::instrument(level = "trace", ret, skip(self))] - async fn handle_command(&mut self, command: SnifferCommand) -> Result<(), AgentError> { + #[tracing::instrument(level = Level::TRACE, err)] + fn handle_command(&mut self, command: SnifferCommand) -> Result<(), AgentError> { match command { SnifferCommand { client_id, - command: SnifferCommands::NewAgent(sender), + command: SnifferCommandInner::NewClient(sender), } => { self.handle_new_client(client_id, sender); } + SnifferCommand { client_id, - command: SnifferCommands::Subscribe(port), - } => { - self.handle_subscribe(client_id, port).await?; - } - SnifferCommand { - client_id, - command: SnifferCommands::AgentClosed, - } => { - self.handle_client_closed(client_id)?; - } - SnifferCommand { - client_id, - command: SnifferCommands::UnsubscribeConnection(connection_id), + command: SnifferCommandInner::Subscribe(port, tx), } => { - self.connection_id_to_tcp_identifier - .get(&connection_id) - .and_then(|identifier| { - self.sessions - .get_mut(identifier) - .map(|session| session.clients.remove(&client_id)) - }); + if self.port_subscriptions.subscribe(client_id, port) { + self.update_packet_filter()?; + } + + let _ = tx.send(port); 
} + SnifferCommand { client_id, - command: SnifferCommands::UnsubscribePort(port), + command: SnifferCommandInner::UnsubscribePort(port), } => { - self.port_subscriptions.unsubscribe(client_id, port); - self.update_sniffer()?; + if self.port_subscriptions.unsubscribe(client_id, port) { + self.update_packet_filter()?; + } } } - Ok(()) - } - - async fn send_message_to_clients( - &mut self, - clients: impl Iterator, - message: DaemonTcp, - ) -> Result<(), AgentError> { - trace!("TcpConnectionSniffer::send_message_to_clients"); - - for client_id in clients { - self.send_message_to_client(client_id, message.clone()) - .await?; - } - Ok(()) - } - /// Sends a [`DaemonTcp`] message back to the client with `client_id`. - #[tracing::instrument(level = "trace", ret, skip(self, message))] - async fn send_message_to_client( - &mut self, - client_id: &ClientId, - message: DaemonTcp, - ) -> Result<(), AgentError> { - if let Some(sender) = self.client_senders.get(client_id) { - sender.send(message).await.map_err(|err| { - warn!( - "Failed to send message to client {} with {:#?}!", - client_id, err - ); - let _ = self.handle_client_closed(*client_id); - err - })?; - } Ok(()) } /// First it checks the `tcp_flags` with [`is_new_connection`], if that's not the case, meaning /// we have traffic from some existing connection from before mirrord started, then it tries to - /// see if `bytes` contains an HTTP request (HTTP/1) of some sort. When an HTTP request is + /// see if `bytes` contains an HTTP request of some sort. When an HTTP request is /// detected, then the agent should start mirroring as if it was a new connection. /// /// tl;dr: checks packet flags, or if it's an HTTP packet, then begins a new sniffing session. 
- #[tracing::instrument(level = "trace", ret, skip(bytes))] + #[tracing::instrument(level = Level::TRACE, ret, skip(bytes), fields(bytes = bytes.len()), ret)] fn treat_as_new_session(tcp_flags: u8, bytes: &[u8]) -> bool { is_new_connection(tcp_flags) || matches!( @@ -517,100 +337,540 @@ impl TcpConnectionSniffer { ) } - #[tracing::instrument(level = "trace", ret, skip(self, eth_packet), fields(bytes = %eth_packet.len()))] - async fn handle_packet(&mut self, eth_packet: Vec) -> Result<(), AgentError> { - let (identifier, tcp_packet) = match get_tcp_packet(eth_packet) { - Some(res) => res, - None => return Ok(()), - }; - - let dest_port = identifier.dest_port; - let source_port = identifier.source_port; - let tcp_flags = tcp_packet.flags; - trace!( - "dest_port {:#?} | source_port {:#?} | tcp_flags {:#?}", - dest_port, - source_port, - tcp_flags - ); - - let is_client_packet = self.qualified_port(dest_port); - - let session = match self.sessions.remove(&identifier) { - Some(session) => session, - None => { + /// Handles TCP packet sniffed by [`Self::tcp_capture`]. + #[tracing::instrument( + level = Level::TRACE, + ret, + skip(self), + fields( + destination_port = identifier.dest_port, + source_port = identifier.source_port, + tcp_flags = tcp_packet.flags, + bytes = tcp_packet.bytes.len(), + ) + )] + fn handle_packet( + &mut self, + identifier: TcpSessionIdentifier, + tcp_packet: TcpPacketData, + ) -> Result<(), AgentError> { + let data_tx = match self.sessions.entry(identifier) { + Entry::Occupied(e) => e, + Entry::Vacant(e) => { // Performs a check on the `tcp_flags` and on the packet contents to see if this // should be treated as a new connection. - if !Self::treat_as_new_session(tcp_flags, &tcp_packet.bytes) { + if !Self::treat_as_new_session(tcp_packet.flags, &tcp_packet.bytes) { // Either it's an existing session, or some sort of existing traffic we don't // care to start mirroring. 
return Ok(()); } - if !is_client_packet { + let Some(client_ids) = self + .port_subscriptions + .get_topic_subscribers(identifier.dest_port) + .filter(|ids| !ids.is_empty()) + else { return Ok(()); - } + }; + + tracing::trace!( + ?client_ids, + "TCP packet should be treated as new session and start connections for clients" + ); - let id = match self.index_allocator.next_index() { - Some(id) => id, - None => { - error!("connection index exhausted, dropping new connection"); - return Ok(()); + let (data_tx, _) = broadcast::channel(Self::CONNECTION_DATA_CHANNEL_CAPACITY); + + for client_id in client_ids { + let Some(client_tx) = self.client_txs.get(client_id) else { + tracing::error!( + client_id, + destination_port = identifier.dest_port, + source_port = identifier.source_port, + tcp_flags = tcp_packet.flags, + bytes = tcp_packet.bytes.len(), + "Failed to find client while handling new sniffed TCP connection, this is a bug", + ); + + continue; + }; + + let connection = SniffedConnection { + session_id: identifier, + data: data_tx.subscribe(), + }; + + match client_tx.try_send(connection) { + Ok(()) => {} + + Err(TrySendError::Closed(..)) => { + // Client closed. + // State will be cleaned up when `self.clients_closed` picks it up. 
+ } + + Err(TrySendError::Full(..)) => { + tracing::warn!( + client_id, + destination_port = identifier.dest_port, + source_port = identifier.source_port, + tcp_flags = tcp_packet.flags, + bytes = tcp_packet.bytes.len(), + "Client queue of new sniffed TCP connections is full, dropping", + ); + + continue; + } } - }; + } - let client_ids = self.port_subscriptions.get_topic_subscribers(dest_port); - trace!("client_ids {:#?}", client_ids); + e.insert_entry(data_tx) + } + }; - let message = DaemonTcp::NewConnection(NewTcpConnection { - destination_port: dest_port, - source_port, - connection_id: id, - remote_address: IpAddr::V4(identifier.source_addr), - local_address: IpAddr::V4(identifier.dest_addr), - }); - trace!("message {:#?}", message); + tracing::trace!("Resolved data broadcast channel"); - self.send_message_to_clients(client_ids.iter(), message) - .await?; + if !tcp_packet.bytes.is_empty() && data_tx.get().send(tcp_packet.bytes).is_err() { + tracing::trace!("All data receivers are dead, dropping data broadcast sender"); + data_tx.remove(); + return Ok(()); + } - self.connection_id_to_tcp_identifier.insert(id, identifier); + if is_closed_connection(tcp_packet.flags) { + tracing::trace!("TCP packet closes connection, dropping data broadcast channel"); + data_tx.remove(); + } - TCPSession { - id, - clients: client_ids.into_iter().collect(), - } + Ok(()) + } +} + +#[cfg(test)] +mod test { + use std::{ + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + time::Duration, + }; + + use api::TcpSnifferApi; + use mirrord_protocol::{ + tcp::{DaemonTcp, LayerTcp, NewTcpConnection, TcpClose, TcpData}, + ConnectionId, LogLevel, + }; + use tcp_capture::test::TcpPacketsChannel; + use tokio::sync::mpsc; + + use super::*; + use crate::watched_task::{TaskStatus, WatchedTask}; + + struct TestSnifferSetup { + command_tx: Sender, + task_status: TaskStatus, + packet_tx: Sender<(TcpSessionIdentifier, TcpPacketData)>, + times_filter_changed: Arc, + next_client_id: ClientId, + } 
+ + impl TestSnifferSetup { + async fn get_api(&mut self) -> TcpSnifferApi { + let client_id = self.next_client_id; + self.next_client_id += 1; + TcpSnifferApi::new(client_id, self.command_tx.clone(), self.task_status.clone()) + .await + .unwrap() + } + + fn times_filter_changed(&self) -> usize { + self.times_filter_changed.load(Ordering::Relaxed) + } + + fn new() -> Self { + let (packet_tx, packet_rx) = mpsc::channel(128); + let (command_tx, command_rx) = mpsc::channel(16); + let times_filter_changed = Arc::new(AtomicUsize::default()); + + let sniffer = TcpConnectionSniffer { + command_rx, + tcp_capture: TcpPacketsChannel { + times_filter_changed: times_filter_changed.clone(), + receiver: packet_rx, + }, + port_subscriptions: Default::default(), + sessions: Default::default(), + client_txs: Default::default(), + clients_closed: Default::default(), + }; + let watched_task = WatchedTask::new( + TcpConnectionSniffer::::TASK_NAME, + sniffer.start(CancellationToken::new()), + ); + let task_status = watched_task.status(); + tokio::spawn(watched_task.start()); + + Self { + command_tx, + task_status, + packet_tx, + times_filter_changed, + next_client_id: 0, } + } + } + + /// Simulates two sniffed connections, only one matching client's subscription. 
+ #[tokio::test] + async fn one_client() { + let mut setup = TestSnifferSetup::new(); + let mut api = setup.get_api().await; + + api.handle_client_message(LayerTcp::PortSubscribe(80)) + .await + .unwrap(); + + assert_eq!( + api.recv().await.unwrap(), + (DaemonTcp::SubscribeResult(Ok(80)), None), + ); + + for dest_port in [80, 81] { + setup + .packet_tx + .send(( + TcpSessionIdentifier { + source_addr: "1.1.1.1".parse().unwrap(), + dest_addr: "127.0.0.1".parse().unwrap(), + source_port: 3133, + dest_port, + }, + TcpPacketData { + bytes: b"hello_1".into(), + flags: TcpFlags::SYN, + }, + )) + .await + .unwrap(); + + setup + .packet_tx + .send(( + TcpSessionIdentifier { + source_addr: "1.1.1.1".parse().unwrap(), + dest_addr: "127.0.0.1".parse().unwrap(), + source_port: 3133, + dest_port: 80, + }, + TcpPacketData { + bytes: b"hello_2".into(), + flags: TcpFlags::FIN, + }, + )) + .await + .unwrap(); + } + + let (message, log) = api.recv().await.unwrap(); + assert_eq!( + message, + DaemonTcp::NewConnection(NewTcpConnection { + connection_id: 0, + remote_address: "1.1.1.1".parse().unwrap(), + destination_port: 80, + source_port: 3133, + local_address: "127.0.0.1".parse().unwrap(), + }), + ); + assert_eq!(log, None); + + let (message, log) = api.recv().await.unwrap(); + assert_eq!( + message, + DaemonTcp::Data(TcpData { + connection_id: 0, + bytes: b"hello_1".into(), + }), + ); + assert_eq!(log, None); + + let (message, log) = api.recv().await.unwrap(); + assert_eq!( + message, + DaemonTcp::Data(TcpData { + connection_id: 0, + bytes: b"hello_2".into(), + }), + ); + assert_eq!(log, None); + + let (message, log) = api.recv().await.unwrap(); + assert_eq!(message, DaemonTcp::Close(TcpClose { connection_id: 0 }),); + assert_eq!(log, None); + } + + /// Tests that [`TcpCapture`] filter is replaced only when needed. 
+ /// + /// # Note + /// + /// Due to fact that [`LayerTcp::PortUnsubscribe`] request does not generate any response, this + /// test does some sleeping to give the sniffer time to process. + #[tokio::test] + async fn filter_replace() { + let mut setup = TestSnifferSetup::new(); + + let mut api_1 = setup.get_api().await; + let mut api_2 = setup.get_api().await; + + api_1 + .handle_client_message(LayerTcp::PortSubscribe(80)) + .await + .unwrap(); + assert_eq!( + api_1.recv().await.unwrap(), + (DaemonTcp::SubscribeResult(Ok(80)), None), + ); + assert_eq!(setup.times_filter_changed(), 1); + + api_2 + .handle_client_message(LayerTcp::PortSubscribe(80)) + .await + .unwrap(); + assert_eq!( + api_2.recv().await.unwrap(), + (DaemonTcp::SubscribeResult(Ok(80)), None), + ); + assert_eq!(setup.times_filter_changed(), 1); // api_1 already subscribed `80` + + api_2 + .handle_client_message(LayerTcp::PortSubscribe(81)) + .await + .unwrap(); + assert_eq!( + api_2.recv().await.unwrap(), + (DaemonTcp::SubscribeResult(Ok(81)), None), + ); + assert_eq!(setup.times_filter_changed(), 2); + + api_1 + .handle_client_message(LayerTcp::PortSubscribe(81)) + .await + .unwrap(); + assert_eq!( + api_1.recv().await.unwrap(), + (DaemonTcp::SubscribeResult(Ok(81)), None), + ); + assert_eq!(setup.times_filter_changed(), 2); // api_2 already subscribed `81` + + api_1 + .handle_client_message(LayerTcp::PortUnsubscribe(80)) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + assert_eq!(setup.times_filter_changed(), 2); // api_2 still subscribes `80` + + api_2 + .handle_client_message(LayerTcp::PortUnsubscribe(81)) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + assert_eq!(setup.times_filter_changed(), 2); // api_1 still subscribes `81` + + api_1 + .handle_client_message(LayerTcp::PortUnsubscribe(81)) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + assert_eq!(setup.times_filter_changed(), 3); + + api_2 + 
.handle_client_message(LayerTcp::PortUnsubscribe(80)) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + assert_eq!(setup.times_filter_changed(), 4); + } + + /// Simulates scenario where client does not read connection data fast enough. + /// Packet buffer should overflow in the [`broadcast`] channel and the client should see the + /// connection being closed. + #[tokio::test] + async fn client_lagging_on_data() { + let mut setup = TestSnifferSetup::new(); + let mut api = setup.get_api().await; + + api.handle_client_message(LayerTcp::PortSubscribe(80)) + .await + .unwrap(); + + assert_eq!( + api.recv().await.unwrap(), + (DaemonTcp::SubscribeResult(Ok(80)), None), + ); + + let session_id = TcpSessionIdentifier { + source_addr: "1.1.1.1".parse().unwrap(), + dest_addr: "127.0.0.1".parse().unwrap(), + source_port: 3133, + dest_port: 80, }; - trace!("session {:#?}", session); - if is_client_packet && !tcp_packet.bytes.is_empty() { - let message = DaemonTcp::Data(TcpData { - bytes: tcp_packet.bytes, - connection_id: session.id, - }); - self.send_message_to_clients(session.clients.iter(), message) - .await?; + setup + .packet_tx + .send(( + session_id, + TcpPacketData { + bytes: b"hello".into(), + flags: TcpFlags::SYN, + }, + )) + .await + .unwrap(); + + let (message, log) = api.recv().await.unwrap(); + assert_eq!( + message, + DaemonTcp::NewConnection(NewTcpConnection { + connection_id: 0, + remote_address: session_id.source_addr.into(), + destination_port: session_id.dest_port, + source_port: session_id.source_port, + local_address: session_id.dest_addr.into(), + }), + ); + assert_eq!(log, None); + + let (message, log) = api.recv().await.unwrap(); + assert_eq!( + message, + DaemonTcp::Data(TcpData { + connection_id: 0, + bytes: b"hello".to_vec(), + }), + ); + assert_eq!(log, None); + + for _ in 0..TcpConnectionSniffer::::CONNECTION_DATA_CHANNEL_CAPACITY + 2 + { + setup + .packet_tx + .send(( + session_id, + TcpPacketData { + bytes: vec![0], 
+ flags: 0, + }, + )) + .await + .unwrap(); } - if is_closed_connection(tcp_flags) { - self.index_allocator.free_index(session.id); - self.connection_id_to_tcp_identifier.remove(&session.id); - let message = DaemonTcp::Close(TcpClose { - connection_id: session.id, - }); + // Wait until sniffer consumes all messages. + setup + .packet_tx + .reserve_many(setup.packet_tx.max_capacity()) + .await + .unwrap(); - debug!( - "TcpConnectionSniffer::handle_packet -> message {:#?}", - message - ); + let (message, log) = api.recv().await.unwrap(); + assert_eq!(message, DaemonTcp::Close(TcpClose { connection_id: 0 }),); + let log = log.unwrap(); + assert_eq!(log.level, LogLevel::Error); + } - self.send_message_to_clients(session.clients.iter(), message) - .await?; - } else { - self.sessions.insert(identifier, session); + /// Simulates scenario where client does not read notifications about new connections fast + /// enough. Client should miss new connections. + #[tokio::test] + async fn client_lagging_on_new_connections() { + let mut setup = TestSnifferSetup::new(); + let mut api = setup.get_api().await; + + api.handle_client_message(LayerTcp::PortSubscribe(80)) + .await + .unwrap(); + + assert_eq!( + api.recv().await.unwrap(), + (DaemonTcp::SubscribeResult(Ok(80)), None), + ); + + let source_addr = "1.1.1.1".parse().unwrap(); + let dest_addr = "127.0.0.1".parse().unwrap(); + + // First send `TcpSnifferApi::CONNECTION_CHANNEL_SIZE` + 2 first connections. + let session_ids = + (0..=TcpSnifferApi::CONNECTION_CHANNEL_SIZE).map(|idx| TcpSessionIdentifier { + source_addr, + dest_addr, + source_port: 3000 + idx as u16, + dest_port: 80, + }); + for session in session_ids { + setup + .packet_tx + .send(( + session, + TcpPacketData { + bytes: Default::default(), + flags: TcpFlags::SYN, + }, + )) + .await + .unwrap(); } - Ok(()) + // Wait until sniffer processes all packets. 
+ let permit = setup + .packet_tx + .reserve_many(setup.packet_tx.max_capacity()) + .await + .unwrap(); + std::mem::drop(permit); + + // Verify that we picked up `TcpSnifferApi::CONNECTION_CHANNEL_SIZE` first connections. + for i in 0..TcpSnifferApi::CONNECTION_CHANNEL_SIZE { + let (msg, log) = api.recv().await.unwrap(); + assert_eq!(log, None); + assert_eq!( + msg, + DaemonTcp::NewConnection(NewTcpConnection { + connection_id: i as ConnectionId, + remote_address: source_addr.into(), + destination_port: 80, + source_port: 3000 + i as u16, + local_address: dest_addr.into(), + }) + ) + } + + // Send one more connection. + setup + .packet_tx + .send(( + TcpSessionIdentifier { + source_addr, + dest_addr, + source_port: 3222, + dest_port: 80, + }, + TcpPacketData { + bytes: Default::default(), + flags: TcpFlags::SYN, + }, + )) + .await + .unwrap(); + + // Verify that we missed the last connections from the first batch. + let (msg, log) = api.recv().await.unwrap(); + assert_eq!(log, None); + assert_eq!( + msg, + DaemonTcp::NewConnection(NewTcpConnection { + connection_id: TcpSnifferApi::CONNECTION_CHANNEL_SIZE as ConnectionId, + remote_address: source_addr.into(), + destination_port: 80, + source_port: 3222, + local_address: dest_addr.into(), + }), + ); } } diff --git a/mirrord/agent/src/sniffer/api.rs b/mirrord/agent/src/sniffer/api.rs new file mode 100644 index 00000000000..35ee0f6d21d --- /dev/null +++ b/mirrord/agent/src/sniffer/api.rs @@ -0,0 +1,185 @@ +use futures::{stream::FuturesUnordered, StreamExt}; +use mirrord_protocol::{ + tcp::{DaemonTcp, LayerTcp, NewTcpConnection, TcpClose, TcpData}, + ConnectionId, LogMessage, Port, +}; +use tokio::sync::{ + mpsc::{self, Receiver, Sender}, + oneshot, +}; +use tokio_stream::{ + wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}, + StreamMap, StreamNotifyClose, +}; + +use super::messages::{SniffedConnection, SnifferCommand, SnifferCommandInner}; +use crate::{error::AgentError, util::ClientId, 
watched_task::TaskStatus}; + +/// Interface used by clients to interact with the +/// [`TcpConnectionSniffer`](super::TcpConnectionSniffer). Multiple instances of this struct operate +/// on a single sniffer instance. +pub(crate) struct TcpSnifferApi { + /// Id of the client using this struct. + client_id: ClientId, + /// Channel used to send commands to the [`TcpConnectionSniffer`](super::TcpConnectionSniffer). + sender: Sender, + /// Channel used to receive messages from the + /// [`TcpConnectionSniffer`](super::TcpConnectionSniffer). + receiver: Receiver, + /// View on the sniffer task's status. + task_status: TaskStatus, + /// Currently sniffed connections. + connections: StreamMap>>>, + /// Id for the next sniffed connection. + next_connection_id: Option, + /// [`LayerTcp::PortSubscribe`] requests in progress. + subscriptions_in_progress: FuturesUnordered>, +} + +impl TcpSnifferApi { + /// Capacity for channel that will be used by + /// [`TcpConnectionSniffer`](super::TcpConnectionSniffer) to notify this struct about new + /// connections. + pub const CONNECTION_CHANNEL_SIZE: usize = 128; + + /// Create a new instance of this struct and connect it to a + /// [`TcpConnectionSniffer`](super::TcpConnectionSniffer) instance. 
+ /// * `client_id` - id of the client using this struct + /// * `sniffer_sender` - channel used to send commands to the + /// [`TcpConnectionSniffer`](super::TcpConnectionSniffer) + /// * `task_status` - handle to the [`TcpConnectionSniffer`](super::TcpConnectionSniffer) exit + /// status + pub async fn new( + client_id: ClientId, + sniffer_sender: Sender, + mut task_status: TaskStatus, + ) -> Result { + let (sender, receiver) = mpsc::channel(Self::CONNECTION_CHANNEL_SIZE); + + let command = SnifferCommand { + client_id, + command: SnifferCommandInner::NewClient(sender), + }; + if sniffer_sender.send(command).await.is_err() { + return Err(task_status.unwrap_err().await); + } + + Ok(Self { + client_id, + sender: sniffer_sender, + receiver, + task_status, + connections: Default::default(), + next_connection_id: Some(0), + subscriptions_in_progress: Default::default(), + }) + } + + /// Send the given command to the connected + /// [`TcpConnectionSniffer`](super::TcpConnectionSniffer). + async fn send_command(&mut self, command: SnifferCommandInner) -> Result<(), AgentError> { + let command = SnifferCommand { + client_id: self.client_id, + command, + }; + + if self.sender.send(command).await.is_ok() { + Ok(()) + } else { + Err(self.task_status.unwrap_err().await) + } + } + + /// Return the next message from the connected + /// [`TcpConnectionSniffer`](super::TcpConnectionSniffer). + pub async fn recv(&mut self) -> Result<(DaemonTcp, Option), AgentError> { + tokio::select! 
{
+ conn = self.receiver.recv() => match conn {
+ Some(conn) => {
+ let id = self.next_connection_id.ok_or(AgentError::ExhaustedConnectionId)?;
+ self.next_connection_id = id.checked_add(1);
+
+ self.connections.insert(id, StreamNotifyClose::new(BroadcastStream::new(conn.data)));
+
+ Ok((
+ DaemonTcp::NewConnection(NewTcpConnection {
+ connection_id: id,
+ remote_address: conn.session_id.source_addr.into(),
+ local_address: conn.session_id.dest_addr.into(),
+ source_port: conn.session_id.source_port,
+ destination_port: conn.session_id.dest_port,
+ }),
+ None,
+ ))
+ },
+
+ None => {
+ Err(self.task_status.unwrap_err().await)
+ },
+ },
+
+ Some((connection_id, bytes)) = self.connections.next() => match bytes {
+ Some(Ok(bytes)) => {
+ Ok((
+ DaemonTcp::Data(TcpData {
+ connection_id,
+ bytes,
+ }),
+ None,
+ ))
+ }
+
+ Some(Err(BroadcastStreamRecvError::Lagged(missed_packets))) => {
+ let log = LogMessage::error(format!(
+ "failed to process on time {missed_packets} packet(s) from mirrored connection {connection_id}, closing connection"
+ ));
+
+ Ok((
+ DaemonTcp::Close(TcpClose { connection_id }),
+ Some(log),
+ ))
+ }
+
+ None => {
+ Ok((
+ DaemonTcp::Close(TcpClose { connection_id }),
+ None
+ ))
+ }
+ },
+
+ Some(result) = self.subscriptions_in_progress.next() => match result {
+ Ok(port) => Ok((DaemonTcp::SubscribeResult(Ok(port)), None)),
+ Err(..) => {
+ Err(self.task_status.unwrap_err().await)
+ }
+ }
+ }
+ }
+
+ /// Transform the given message into a [`SnifferCommand`] and pass it to the connected
+ /// [`TcpConnectionSniffer`](super::TcpConnectionSniffer).
+ pub async fn handle_client_message(&mut self, message: LayerTcp) -> Result<(), AgentError> {
+ match message {
+ LayerTcp::PortSubscribe(port) => {
+ let (tx, rx) = oneshot::channel();
+ self.send_command(SnifferCommandInner::Subscribe(port, tx))
+ .await?;
+ self.subscriptions_in_progress.push(rx);
+
+ Ok(())
+ }
+
+ LayerTcp::PortUnsubscribe(port) => {
+ self.send_command(SnifferCommandInner::UnsubscribePort(port))
+ .await
+ }
+
+ LayerTcp::ConnectionUnsubscribe(connection_id) => {
+ self.connections.remove(&connection_id);
+
+ Ok(())
+ }
+ }
+ }
+}
diff --git a/mirrord/agent/src/sniffer/messages.rs b/mirrord/agent/src/sniffer/messages.rs
new file mode 100644
index 00000000000..83556c1d971
--- /dev/null
+++ b/mirrord/agent/src/sniffer/messages.rs
@@ -0,0 +1,45 @@
+use mirrord_protocol::Port;
+use tokio::sync::{broadcast, mpsc::Sender, oneshot};
+
+use super::TcpSessionIdentifier;
+use crate::util::ClientId;
+
+/// Command for [`TcpConnectionSniffer`](super::TcpConnectionSniffer).
+#[derive(Debug)]
+pub(crate) enum SnifferCommandInner {
+ /// New client wants to use the sniffer.
+ NewClient(
+ /// For notifying the client about new incoming connections.
+ Sender<SniffedConnection>,
+ ),
+ /// Client wants to start receiving connections incoming to a specific port.
+ Subscribe(
+ /// Number of port to subscribe.
+ Port,
+ /// Channel to notify with the same port number when the operation is done.
+ oneshot::Sender<Port>,
+ ),
+ /// Client no longer wants to receive connections incoming to a specific port.
+ UnsubscribePort(
+ /// Number of port to unsubscribe.
+ Port,
+ ),
+}
+
+/// Client's command for [`TcpConnectionSniffer`](super::TcpConnectionSniffer).
+#[derive(Debug)]
+pub(crate) struct SnifferCommand {
+ /// Id of the client.
+ pub client_id: ClientId,
+ /// Actual command.
+ pub command: SnifferCommandInner,
+}
+
+/// New TCP connection picked up by [`TcpConnectionSniffer`](super::TcpConnectionSniffer).
+pub(crate) struct SniffedConnection { + /// Parameters of this connection's TCP session. + /// Can be used to create [`NewTcpConnection`](mirrord_protocol::tcp::NewTcpConnection). + pub session_id: TcpSessionIdentifier, + /// For receiving data from this connection. + pub data: broadcast::Receiver>, +} diff --git a/mirrord/agent/src/sniffer/tcp_capture.rs b/mirrord/agent/src/sniffer/tcp_capture.rs new file mode 100644 index 00000000000..c15e9b940df --- /dev/null +++ b/mirrord/agent/src/sniffer/tcp_capture.rs @@ -0,0 +1,173 @@ +use std::{io, net::SocketAddr}; + +use mirrord_protocol::MeshVendor; +use nix::sys::socket::SockaddrStorage; +use pnet::packet::{ + ethernet::{EtherTypes, EthernetPacket}, + ip::IpNextHeaderProtocols, + ipv4::Ipv4Packet, + tcp::TcpPacket, + Packet, +}; +use rawsocket::{filter::SocketFilterProgram, RawCapture}; +use tokio::net::UdpSocket; +use tracing::Level; + +use super::{TcpPacketData, TcpSessionIdentifier}; +use crate::error::AgentError; + +/// Trait for structs that are able to sniff incoming Ethernet packets and filter TCP packets. +pub trait TcpCapture { + /// Sets a filter for incoming Ethernet packets. + fn set_filter(&mut self, filter: SocketFilterProgram) -> io::Result<()>; + + /// Returns the next sniffed TCP packet. + async fn next(&mut self) -> io::Result<(TcpSessionIdentifier, TcpPacketData)>; +} + +/// Implementor of [`TcpCapture`] that uses a raw OS socket and a BPF filter. +pub struct RawSocketTcpCapture { + /// Raw OS socket. + inner: RawCapture, +} + +impl RawSocketTcpCapture { + /// Creates a new instance. `network_interface` and `mesh` will be used to determine correct + /// network interface for the raw OS socket. + /// + /// Returned instance initially uses a BPF filter that drops every packet. 
+ #[tracing::instrument(level = Level::DEBUG, err)] + pub async fn new( + network_interface: Option, + mesh: Option, + ) -> Result { + // Priority is whatever the user set as an option to mirrord, then we check if we're in a + // mesh to use `lo` interface, otherwise we try to get the appropriate interface. + let interface = match network_interface.or_else(|| mesh.map(|_| "lo".to_string())) { + Some(interface) => interface, + None => Self::resolve_interface() + .await? + .unwrap_or_else(|| "eth0".to_string()), + }; + + tracing::debug!( + resolved_interface = interface, + "Resolved raw capture interface" + ); + + let capture = RawCapture::from_interface_name(&interface)?; + capture.set_filter(rawsocket::filter::build_drop_always())?; + capture + .ignore_outgoing() + .map_err(AgentError::PacketIgnoreOutgoing)?; + Ok(Self { inner: capture }) + } + + /// Connects to a remote address (`8.8.8.8:53`) so we can find which network interface to use. + /// + /// Used when no `user_interface` is specified in [`Self::new`] to prevent mirrord from + /// defaulting to the wrong network interface (`eth0`), as sometimes the user's machine doesn't + /// have it available (i.e. their default network is `enp2s0`). + #[tracing::instrument(level = Level::DEBUG, err)] + async fn resolve_interface() -> io::Result> { + // Connect to a remote address so we can later get the default network interface. + let temporary_socket = UdpSocket::bind("0.0.0.0:0").await?; + temporary_socket.connect("8.8.8.8:53").await?; + + // Create comparison address here with `port: 0`, to match the network interface's address + // of `sin_port: 0`. + let local_address = SocketAddr::new(temporary_socket.local_addr()?.ip(), 0); + let raw_local_address = SockaddrStorage::from(local_address); + + // Try to find an interface that matches the local ip we have. 
+ let usable_interface_name: Option = nix::ifaddrs::getifaddrs()?.find_map(|iface| { + (raw_local_address == iface.address?).then_some(iface.interface_name) + }); + + Ok(usable_interface_name) + } + + /// Extracts TCP packet from the raw Ethernet packet given as bytes. + /// If the given Ethernet packet is not TCP, returns [`None`]. + #[tracing::instrument(skip(eth_packet), level = Level::TRACE, fields(bytes = %eth_packet.len()))] + fn get_tcp_packet(eth_packet: Vec) -> Option<(TcpSessionIdentifier, TcpPacketData)> { + let eth_packet = EthernetPacket::new(ð_packet[..])?; + let ip_packet = match eth_packet.get_ethertype() { + EtherTypes::Ipv4 => Ipv4Packet::new(eth_packet.payload())?, + _ => return None, + }; + + let tcp_packet = match ip_packet.get_next_level_protocol() { + IpNextHeaderProtocols::Tcp => TcpPacket::new(ip_packet.payload())?, + _ => return None, + }; + + let dest_port = tcp_packet.get_destination(); + let source_port = tcp_packet.get_source(); + + let identifier = TcpSessionIdentifier { + source_addr: ip_packet.get_source(), + dest_addr: ip_packet.get_destination(), + source_port, + dest_port, + }; + + tracing::trace!(session_identifier = ?identifier, "Got TCP packet"); + + Some(( + identifier, + TcpPacketData { + flags: tcp_packet.get_flags(), + bytes: tcp_packet.payload().to_vec(), + }, + )) + } +} + +impl TcpCapture for RawSocketTcpCapture { + fn set_filter(&mut self, filter: SocketFilterProgram) -> io::Result<()> { + self.inner.set_filter(filter) + } + + async fn next(&mut self) -> io::Result<(TcpSessionIdentifier, TcpPacketData)> { + loop { + let raw = self.inner.next().await?; + + if let Some(tcp) = Self::get_tcp_packet(raw) { + break Ok(tcp); + } + } + } +} + +#[cfg(test)] +pub mod test { + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }; + + use tokio::sync::mpsc::Receiver; + + use super::*; + + /// Implementor of [`TcpCapture`] that returns packets received from an + /// [`mpsc`](tokio::sync::mpsc) channel. 
+ pub struct TcpPacketsChannel { + pub times_filter_changed: Arc, + pub receiver: Receiver<(TcpSessionIdentifier, TcpPacketData)>, + } + + impl TcpCapture for TcpPacketsChannel { + /// Filter is ignored, we don't want to execute BPF programs in tests. + fn set_filter(&mut self, _filter: SocketFilterProgram) -> io::Result<()> { + self.times_filter_changed.fetch_add(1, Ordering::Relaxed); + + Ok(()) + } + + async fn next(&mut self) -> io::Result<(TcpSessionIdentifier, TcpPacketData)> { + Ok(self.receiver.recv().await.expect("channel closed")) + } + } +} diff --git a/mirrord/agent/src/steal/api.rs b/mirrord/agent/src/steal/api.rs index 170e1442a37..44e9d10b564 100644 --- a/mirrord/agent/src/steal/api.rs +++ b/mirrord/agent/src/steal/api.rs @@ -1,5 +1,16 @@ -use mirrord_protocol::tcp::{DaemonTcp, HttpResponseFallback, LayerTcpSteal, TcpData}; +use std::collections::HashMap; + +use bytes::Bytes; +use hyper::body::Frame; +use mirrord_protocol::{ + tcp::{ + ChunkedResponse, DaemonTcp, HttpResponse, HttpResponseFallback, InternalHttpResponse, + LayerTcpSteal, ReceiverStreamBody, TcpData, + }, + RequestId, +}; use tokio::sync::mpsc::{self, OwnedPermit, Receiver, Sender}; +use tokio_stream::wrappers::ReceiverStream; use super::*; use crate::{ @@ -31,6 +42,8 @@ pub(crate) struct TcpStealerApi { /// View on the stealer task's status. 
task_status: TaskStatus, + + response_body_txs: HashMap<(ConnectionId, RequestId), Sender>>>, } impl TcpStealerApi { @@ -65,6 +78,7 @@ impl TcpStealerApi { close_permit: Some(close_permit), daemon_rx, task_status, + response_body_txs: HashMap::new(), }) } @@ -89,7 +103,13 @@ impl TcpStealerApi { #[tracing::instrument(level = "trace", skip(self))] pub(crate) async fn recv(&mut self) -> Result { match self.daemon_rx.recv().await { - Some(msg) => Ok(msg), + Some(msg) => { + if let DaemonTcp::Close(close) = &msg { + self.response_body_txs + .retain(|(key_id, _), _| *key_id != close.connection_id); + } + Ok(msg) + } None => Err(self.task_status.unwrap_err().await), } } @@ -153,6 +173,8 @@ impl TcpStealerApi { match message { LayerTcpSteal::PortSubscribe(port_steal) => self.port_subscribe(port_steal).await, LayerTcpSteal::ConnectionUnsubscribe(connection_id) => { + self.response_body_txs + .retain(|(key_id, _), _| *key_id != connection_id); self.connection_unsubscribe(connection_id).await } LayerTcpSteal::PortUnsubscribe(port) => self.port_unsubscribe(port).await, @@ -165,6 +187,63 @@ impl TcpStealerApi { self.http_response(HttpResponseFallback::Framed(response)) .await } + LayerTcpSteal::HttpResponseChunked(inner) => match inner { + ChunkedResponse::Start(response) => { + let (tx, rx) = mpsc::channel(12); + let body = ReceiverStreamBody::new(ReceiverStream::from(rx)); + let http_response: HttpResponse = HttpResponse { + port: response.port, + connection_id: response.connection_id, + request_id: response.request_id, + internal_response: InternalHttpResponse { + status: response.internal_response.status, + version: response.internal_response.version, + headers: response.internal_response.headers, + body, + }, + }; + + let key = (response.connection_id, response.request_id); + self.response_body_txs.insert(key, tx.clone()); + + self.http_response(HttpResponseFallback::Streamed(http_response)) + .await?; + + for frame in response.internal_response.body { + if let Err(err) = 
tx.send(Ok(frame.into())).await { + self.response_body_txs.remove(&key); + tracing::trace!(?err, "error while sending streaming response frame"); + } + } + Ok(()) + } + ChunkedResponse::Body(body) => { + let key = &(body.connection_id, body.request_id); + let mut send_err = false; + if let Some(tx) = self.response_body_txs.get(key) { + for frame in body.frames { + if let Err(err) = tx.send(Ok(frame.into())).await { + send_err = true; + tracing::trace!( + ?err, + "error while sending streaming response body" + ); + break; + } + } + } + if send_err || body.is_last { + self.response_body_txs.remove(key); + }; + Ok(()) + } + ChunkedResponse::Error(err) => { + self.response_body_txs + .remove(&(err.connection_id, err.request_id)); + tracing::trace!(?err, "ChunkedResponse error received"); + Ok(()) + } + }, } } } diff --git a/mirrord/agent/src/steal/connection.rs b/mirrord/agent/src/steal/connection.rs index 4eb754be4c0..d175a58936e 100644 --- a/mirrord/agent/src/steal/connection.rs +++ b/mirrord/agent/src/steal/connection.rs @@ -11,21 +11,24 @@ use hyper::{ http::{header::UPGRADE, request::Parts}, }; use mirrord_protocol::{ + body_chunks::{BodyExt as _, Frames}, tcp::{ - ChunkedRequest, ChunkedRequestBody, ChunkedRequestError, DaemonTcp, HttpRequest, + ChunkedHttpBody, ChunkedHttpError, ChunkedRequest, DaemonTcp, HttpRequest, HttpResponseFallback, InternalHttpBody, InternalHttpBodyFrame, InternalHttpRequest, - StealType, TcpClose, TcpData, HTTP_CHUNKED_VERSION, HTTP_FILTERED_UPGRADE_VERSION, + StealType, TcpClose, TcpData, HTTP_CHUNKED_REQUEST_VERSION, HTTP_FILTERED_UPGRADE_VERSION, HTTP_FRAMED_VERSION, }, ConnectionId, Port, RemoteError::{BadHttpFilterExRegex, BadHttpFilterRegex}, RequestId, }; +use serde::Deserialize; use tokio::{ net::TcpStream, sync::mpsc::{Receiver, Sender}, }; use tokio_util::sync::CancellationToken; +use tracing::warn; use crate::{ error::{AgentError, Result}, @@ -33,7 +36,7 @@ use crate::{ connections::{ ConnectionMessageIn, 
ConnectionMessageOut, StolenConnection, StolenConnections, }, - http::{Frames, HttpFilter, IncomingExt}, + http::HttpFilter, orig_dst, subscriptions::{IpTablesRedirector, PortSubscriptions}, Command, StealerCommand, @@ -149,10 +152,11 @@ impl Client { } let framed = HTTP_FRAMED_VERSION.matches(&self.protocol_version); - let chunked = HTTP_CHUNKED_VERSION.matches(&self.protocol_version); + let chunked = HTTP_CHUNKED_REQUEST_VERSION.matches(&self.protocol_version); let tx = self.tx.clone(); tokio::spawn(async move { + tracing::trace!(?request.connection_id, ?request.request_id, ?chunked, ?framed, "starting request"); // Chunked data is preferred over framed data if chunked { // Send headers @@ -170,7 +174,9 @@ impl Client { ) = request.request.into_parts(); match body.next_frames(true).await { Err(..) => return, - Ok(Frames { frames, is_last }) => { + // We don't check is_last here since loop will finish when body.next_frames() + // returns None + Ok(Frames { frames, .. }) => { let frames = frames .into_iter() .map(InternalHttpBodyFrame::try_from) @@ -189,7 +195,9 @@ impl Client { request_id, port: request.port, })); - if tx.send(message).await.is_err() || is_last { + + if let Err(e) = tx.send(message).await { + warn!(?e, ?connection_id, ?request_id, ?request.port, "failed to send chunked request start"); return; } } @@ -204,21 +212,27 @@ impl Client { .filter_map(Result::ok) .collect(); let message = DaemonTcp::HttpRequestChunked(ChunkedRequest::Body( - ChunkedRequestBody { + ChunkedHttpBody { frames, is_last, connection_id, request_id, }, )); - if tx.send(message).await.is_err() || is_last { + + if let Err(e) = tx.send(message).await { + warn!(?e, ?connection_id, ?request_id, ?request.port, "failed to send chunked request body"); + return; + } + + if is_last { return; } } Err(_) => { let _ = tx .send(DaemonTcp::HttpRequestChunked(ChunkedRequest::Error( - ChunkedRequestError { + ChunkedHttpError { connection_id, request_id, }, @@ -247,6 +261,12 @@ impl Client { } } 
+#[derive(Deserialize, Debug, Default)] +struct TcpStealerConfig { + stealer_flush_connections: bool, + pod_ips: Option, +} + /// Created once per agent during initialization. /// /// Meant to be run (see [`TcpConnectionStealer::start`]) in a separate thread while the agent @@ -274,12 +294,13 @@ impl TcpConnectionStealer { /// You need to call [`TcpConnectionStealer::start`] to do so. #[tracing::instrument(level = "trace")] pub(crate) async fn new(command_rx: Receiver) -> Result { + let config = envy::prefixed("MIRRORD_AGENT_") + .from_env::() + .unwrap_or_default(); + let port_subscriptions = { - let flush_connections = std::env::var("MIRRORD_AGENT_STEALER_FLUSH_CONNECTIONS") - .ok() - .and_then(|var| var.parse::().ok()) - .unwrap_or_default(); - let redirector = IpTablesRedirector::new(flush_connections).await?; + let redirector = + IpTablesRedirector::new(config.stealer_flush_connections, config.pod_ips).await?; PortSubscriptions::new(redirector, 4) }; @@ -674,6 +695,7 @@ mod test { }; use hyper_util::rt::TokioIo; use mirrord_protocol::tcp::{ChunkedRequest, DaemonTcp, InternalHttpBodyFrame}; + use rstest::rstest; use tokio::{ net::{TcpListener, TcpStream}, sync::{ @@ -781,7 +803,7 @@ mod test { }; assert_eq!( x.internal_request.body, - vec![InternalHttpBodyFrame::Data(b"string".to_vec().into())] + vec![InternalHttpBodyFrame::Data(b"string".to_vec())] ); let x = client_rx.recv().now_or_never(); assert!(x.is_none()); @@ -797,13 +819,67 @@ mod test { }; assert_eq!( x.frames, - vec![InternalHttpBodyFrame::Data( - b"another_string".to_vec().into() - )] + vec![InternalHttpBodyFrame::Data(b"another_string".to_vec())] ); let x = client_rx.recv().now_or_never(); assert!(x.is_none()); let _ = response_tx.send(Response::new(Empty::default())); } + + #[rstest] + #[tokio::test] + #[timeout(std::time::Duration::from_secs(5))] + async fn test_empty_streaming_request() { + let (addr, mut request_rx) = prepare_dummy_service().await; + let conn = 
TcpStream::connect(addr).await.unwrap(); + let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(conn)) + .await + .unwrap(); + tokio::spawn(conn); + + tokio::spawn( + sender.send_request( + Request::builder() + .method(Method::POST) + .uri("/") + .version(Version::HTTP_11) + .body(http_body_util::Empty::::new()) + .unwrap(), + ), + ); + + let (client_tx, mut client_rx) = mpsc::channel::(4); + let client = Client { + tx: client_tx, + protocol_version: "1.7.0".parse().unwrap(), + subscribed_connections: Default::default(), + }; + + let (request, response_tx) = request_rx.recv().await.unwrap(); + client.send_request_async(MatchedHttpRequest { + connection_id: 0, + port: 80, + request_id: 0, + request, + }); + + // Verify that ChunkedRequest::Start request is as expected + let msg = client_rx.recv().await.unwrap(); + let DaemonTcp::HttpRequestChunked(ChunkedRequest::Start(_)) = msg else { + panic!("unexpected type received: {msg:?}") + }; + + // Verify that empty ChunkedRequest::Body request is as expected + let msg = client_rx.recv().await.unwrap(); + let DaemonTcp::HttpRequestChunked(ChunkedRequest::Body(x)) = msg else { + panic!("unexpected type received: {msg:?}") + }; + assert_eq!(x.frames, vec![]); + assert!(x.is_last); + let x = client_rx.recv().now_or_never(); + assert!(x.is_none()); + + let _ = response_tx.send(Response::new(Empty::default())); + } } diff --git a/mirrord/agent/src/steal/http.rs b/mirrord/agent/src/steal/http.rs index 53da34db228..159d9c9aac8 100644 --- a/mirrord/agent/src/steal/http.rs +++ b/mirrord/agent/src/steal/http.rs @@ -2,16 +2,12 @@ use crate::http::HttpVersion; -mod body_chunks; mod filter; mod reversible_stream; pub use filter::HttpFilter; -pub(crate) use self::{ - body_chunks::{Frames, IncomingExt}, - reversible_stream::ReversibleStream, -}; +pub(crate) use self::reversible_stream::ReversibleStream; /// Handy alias due to [`ReversibleStream`] being generic, avoiding value mismatches. 
pub(crate) type DefaultReversibleStream = ReversibleStream<{ HttpVersion::MINIMAL_HEADER_SIZE }>; diff --git a/mirrord/agent/src/steal/ip_tables.rs b/mirrord/agent/src/steal/ip_tables.rs index 9078169895d..b12440cc792 100644 --- a/mirrord/agent/src/steal/ip_tables.rs +++ b/mirrord/agent/src/steal/ip_tables.rs @@ -236,13 +236,17 @@ impl SafeIpTables where IPT: IPTables + Send + Sync, { - pub(super) async fn create(ipt: IPT, flush_connections: bool) -> Result { + pub(super) async fn create( + ipt: IPT, + flush_connections: bool, + pod_ips: Option<&str>, + ) -> Result { let ipt = Arc::new(ipt); let mut redirect = if let Some(vendor) = MeshVendor::detect(ipt.as_ref())? { - Redirects::Mesh(MeshRedirect::create(ipt.clone(), vendor)?) + Redirects::Mesh(MeshRedirect::create(ipt.clone(), vendor, pod_ips)?) } else { - match StandardRedirect::create(ipt.clone()) { + match StandardRedirect::create(ipt.clone(), pod_ips) { Err(err) => { warn!("Unable to create StandardRedirect chain: {err}"); @@ -416,7 +420,7 @@ mod tests { .times(1) .returning(|_| Ok(())); - let ipt = SafeIpTables::create(mock, false) + let ipt = SafeIpTables::create(mock, false, None) .await .expect("Create Failed"); @@ -549,7 +553,7 @@ mod tests { .times(1) .returning(|_| Ok(())); - let ipt = SafeIpTables::create(mock, false) + let ipt = SafeIpTables::create(mock, false, None) .await .expect("Create Failed"); diff --git a/mirrord/agent/src/steal/ip_tables/mesh.rs b/mirrord/agent/src/steal/ip_tables/mesh.rs index 7ce85e0e0e9..f365aeb8b8e 100644 --- a/mirrord/agent/src/steal/ip_tables/mesh.rs +++ b/mirrord/agent/src/steal/ip_tables/mesh.rs @@ -27,14 +27,14 @@ impl MeshRedirect where IPT: IPTables, { - pub fn create(ipt: Arc, vendor: MeshVendor) -> Result { + pub fn create(ipt: Arc, vendor: MeshVendor, pod_ips: Option<&str>) -> Result { let prerouteing = PreroutingRedirect::create(ipt.clone())?; for port in Self::get_skip_ports(&ipt, &vendor)? { prerouteing.add_rule(&format!("-m multiport -p tcp ! 
--dports {port} -j RETURN"))?; } - let output = OutputRedirect::create(ipt, IPTABLE_MESH.to_string())?; + let output = OutputRedirect::create(ipt, IPTABLE_MESH.to_string(), pod_ips)?; Ok(MeshRedirect { prerouteing, @@ -220,7 +220,7 @@ mod tests { mock.expect_insert_rule() .with( eq(IPTABLE_MESH.as_str()), - eq(format!("-m owner --gid-owner {gid} -p tcp -j RETURN")), + eq(format!("-m owner --gid-owner {gid} -p tcp -j RETURN")), eq(1), ) .times(1) @@ -245,8 +245,8 @@ mod tests { .times(1) .returning(|_| Ok(())); - let prerouting = - MeshRedirect::create(Arc::new(mock), MeshVendor::Linkerd).expect("Unable to create"); + let prerouting = MeshRedirect::create(Arc::new(mock), MeshVendor::Linkerd, None) + .expect("Unable to create"); assert!(prerouting.add_redirect(69, 420).await.is_ok()); } diff --git a/mirrord/agent/src/steal/ip_tables/output.rs b/mirrord/agent/src/steal/ip_tables/output.rs index 19ef3758753..96f21eccb32 100644 --- a/mirrord/agent/src/steal/ip_tables/output.rs +++ b/mirrord/agent/src/steal/ip_tables/output.rs @@ -20,12 +20,18 @@ where { const ENTRYPOINT: &'static str = "OUTPUT"; - pub fn create(ipt: Arc, chain_name: String) -> Result { + pub fn create(ipt: Arc, chain_name: String, pod_ips: Option<&str>) -> Result { let managed = IPTableChain::create(ipt, chain_name)?; + let exclude_source_ips = pod_ips + .map(|pod_ips| format!("! 
-s {pod_ips}")) + .unwrap_or_default(); + let gid = getgid(); managed - .add_rule(&format!("-m owner --gid-owner {gid} -p tcp -j RETURN")) + .add_rule(&format!( + "-m owner --gid-owner {gid} -p tcp {exclude_source_ips} -j RETURN" + )) .inspect_err(|_| { warn!("Unable to create iptable rule with \"--gid-owner {gid}\" filter") })?; @@ -34,7 +40,7 @@ where } pub fn load(ipt: Arc, chain_name: String) -> Result { - let managed = IPTableChain::create(ipt, chain_name)?; + let managed = IPTableChain::load(ipt, chain_name)?; Ok(OutputRedirect { managed }) } diff --git a/mirrord/agent/src/steal/ip_tables/standard.rs b/mirrord/agent/src/steal/ip_tables/standard.rs index 15155c2f3bb..79ac0e897ed 100644 --- a/mirrord/agent/src/steal/ip_tables/standard.rs +++ b/mirrord/agent/src/steal/ip_tables/standard.rs @@ -20,9 +20,9 @@ impl StandardRedirect where IPT: IPTables, { - pub fn create(ipt: Arc) -> Result { + pub fn create(ipt: Arc, pod_ips: Option<&str>) -> Result { let prerouteing = PreroutingRedirect::create(ipt.clone())?; - let output = OutputRedirect::create(ipt, IPTABLE_STANDARD.to_string())?; + let output = OutputRedirect::create(ipt, IPTABLE_STANDARD.to_string(), pod_ips)?; Ok(StandardRedirect { prerouteing, diff --git a/mirrord/agent/src/steal/subscriptions.rs b/mirrord/agent/src/steal/subscriptions.rs index 6ca96a81585..1fc0c04c756 100644 --- a/mirrord/agent/src/steal/subscriptions.rs +++ b/mirrord/agent/src/steal/subscriptions.rs @@ -58,6 +58,8 @@ pub(crate) struct IpTablesRedirector { redirect_to: Port, /// Listener to which redirect all connections. 
listener: TcpListener, + + pod_ips: Option, } impl IpTablesRedirector { @@ -73,7 +75,10 @@ impl IpTablesRedirector { /// /// * `flush_connections` - whether exisitng connections should be flushed when adding new /// redirects - pub(crate) async fn new(flush_connections: bool) -> Result { + pub(crate) async fn new( + flush_connections: bool, + pod_ips: Option, + ) -> Result { let listener = TcpListener::bind((Ipv4Addr::UNSPECIFIED, 0)).await?; let redirect_to = listener.local_addr()?.port(); @@ -82,6 +87,7 @@ impl IpTablesRedirector { flush_connections, redirect_to, listener, + pod_ips, }) } } @@ -95,7 +101,12 @@ impl PortRedirector for IpTablesRedirector { Some(iptables) => iptables, None => { let iptables = new_iptables(); - let safe = SafeIpTables::create(iptables.into(), self.flush_connections).await?; + let safe = SafeIpTables::create( + iptables.into(), + self.flush_connections, + self.pod_ips.as_deref(), + ) + .await?; self.iptables.insert(safe) } }; diff --git a/mirrord/agent/src/util.rs b/mirrord/agent/src/util.rs index eb87921b42a..ebcaab0e058 100644 --- a/mirrord/agent/src/util.rs +++ b/mirrord/agent/src/util.rs @@ -1,6 +1,6 @@ use std::{ clone::Clone, - collections::{HashMap, HashSet, VecDeque}, + collections::{hash_map::Entry, HashMap, HashSet, VecDeque}, future::Future, hash::Hash, thread::JoinHandle, @@ -19,7 +19,7 @@ use crate::{ /// When a topic has no subscribers, it is removed. #[derive(Debug, Default)] pub struct Subscriptions { - _inner: HashMap>, + inner: HashMap>, } /// Id of an agent's client. Each new client connection is assigned with a unique id. @@ -31,57 +31,66 @@ where C: Eq + Hash + Clone + Copy, { /// Add a new subscription to a topic for a given client. - pub fn subscribe(&mut self, client: C, topic: T) { - self._inner.entry(topic).or_default().insert(client); + /// Returns whether this resulted in adding a new topic to this mapping. 
+ pub fn subscribe(&mut self, client: C, topic: T) -> bool { + match self.inner.entry(topic) { + Entry::Occupied(mut e) => { + e.get_mut().insert(client); + false + } + Entry::Vacant(e) => { + e.insert([client].into()); + true + } + } } /// Remove a subscription of given client from the topic. - /// topic is removed if no subscribers left. - pub fn unsubscribe(&mut self, client: C, topic: T) { - if let Some(set) = self._inner.get_mut(&topic) { - set.remove(&client); - if set.is_empty() { - self._inner.remove(&topic); + /// Topic is removed if no subscribers left. + /// Return whether the topic was removed. + pub fn unsubscribe(&mut self, client: C, topic: T) -> bool { + match self.inner.entry(topic) { + Entry::Occupied(mut e) => { + e.get_mut().remove(&client); + if e.get().is_empty() { + e.remove(); + true + } else { + false + } } + Entry::Vacant(..) => false, } } /// Get a vector of clients subscribed to a specific topic - pub fn get_topic_subscribers(&self, topic: T) -> Vec { - match self._inner.get(&topic) { - Some(clients_set) => clients_set.iter().cloned().collect(), - None => Vec::new(), - } + pub fn get_topic_subscribers(&self, topic: T) -> Option<&HashSet> { + self.inner.get(&topic) } /// Get subscribed topics pub fn get_subscribed_topics(&self) -> Vec { - self._inner.keys().cloned().collect() + self.inner.keys().cloned().collect() } - /// Get topics subscribed by a client - pub fn get_client_topics(&self, client: C) -> Vec { - let mut result = Vec::new(); - for (topic, client_set) in self._inner.iter() { - if client_set.contains(&client) { - result.push(*topic) - } - } - result - } + /// Remove all subscriptions of a client. + /// Topics are removed if no subscribers left. + /// Returns whether any topic was removed. 
+ pub fn remove_client(&mut self, client: C) -> bool { + let prev_length = self.inner.len(); - /// Remove all subscriptions of a client - pub fn remove_client(&mut self, client: C) { - let topics = self.get_client_topics(client); - for topic in topics { - self.unsubscribe(client, topic) - } + self.inner.retain(|_, client_set| { + client_set.remove(&client); + !client_set.is_empty() + }); + + self.inner.len() != prev_length } /// Removes a topic and all of it's clients #[allow(dead_code)] // we might want it later on pub fn remove_topic(&mut self, topic: T) { - self._inner.remove(&topic); + self.inner.remove(&topic); } } @@ -205,10 +214,29 @@ pub(crate) fn enter_namespace(pid: Option, namespace: &str) -> Result<(), A #[cfg(test)] mod subscription_tests { + use std::hash::Hash; + use mirrord_protocol::Port; use super::Subscriptions; + impl Subscriptions + where + C: Hash + Eq, + T: Copy, + { + /// Get topics subscribed by a client + fn get_client_topics(&self, client: C) -> Vec { + let mut result = Vec::new(); + for (topic, client_set) in self.inner.iter() { + if client_set.contains(&client) { + result.push(*topic) + } + } + result + } + } + #[test] fn sanity() { let mut subscriptions: Subscriptions = Default::default(); @@ -219,12 +247,25 @@ mod subscription_tests { subscriptions.subscribe(2, 1); subscriptions.subscribe(2, 1); subscriptions.subscribe(3, 1); - let mut subscribers = subscriptions.get_topic_subscribers(1); + let mut subscribers = subscriptions + .get_topic_subscribers(1) + .into_iter() + .flatten() + .copied() + .collect::>(); subscribers.sort(); assert_eq!(subscribers, vec![2, 3]); - assert_eq!(subscriptions.get_topic_subscribers(4), vec![1]); - assert_eq!(subscriptions.get_topic_subscribers(10), Vec::::new()); + assert_eq!( + subscriptions + .get_topic_subscribers(4) + .into_iter() + .flatten() + .copied() + .collect::>(), + vec![1] + ); + assert!(subscriptions.get_topic_subscribers(10).is_none()); let mut topics = 
subscriptions.get_subscribed_topics(); topics.sort(); assert_eq!(topics, vec![1, 2, 3, 4]); @@ -232,9 +273,17 @@ mod subscription_tests { topics.sort(); assert_eq!(topics, vec![1, 2, 3]); subscriptions.unsubscribe(3, 1); - assert_eq!(subscriptions.get_topic_subscribers(1), vec![2]); + assert_eq!( + subscriptions + .get_topic_subscribers(1) + .into_iter() + .flatten() + .copied() + .collect::>(), + vec![2] + ); subscriptions.unsubscribe(1, 4); - assert_eq!(subscriptions.get_topic_subscribers(4), Vec::::new()); + assert!(subscriptions.get_topic_subscribers(4).is_none()); subscriptions.remove_client(3); assert_eq!(subscriptions.get_client_topics(3), Vec::::new()); subscriptions.remove_topic(1); diff --git a/mirrord/auth/Cargo.toml b/mirrord/auth/Cargo.toml index f75753f8bd5..ede8d9eeac8 100644 --- a/mirrord/auth/Cargo.toml +++ b/mirrord/auth/Cargo.toml @@ -42,12 +42,10 @@ serde = { version = "1", features = ["derive"] } serde_yaml = { version = "0.9", optional = true } tokio = { workspace = true, features = ["fs"], optional = true } thiserror = "1" -# don't upgrade it due to https://github.com/metalbear-co/operator/issues/556 -# unless you know what you're doing!!! -x509-certificate = "0.21" +x509-certificate = "0.23.1" # not direct dependency, but if we don't put it here it'll use openssl :( reqwest = { workspace = true, features=["json", "rustls-tls-native-roots"], default-features = false, optional = true } tracing.workspace = true -# don't upgrade it due to https://github.com/metalbear-co/operator/issues/556 -# unless you know what you're doing!!! 
-ring = "0.16" \ No newline at end of file + +[dev-dependencies] +bcder = "0.7" diff --git a/mirrord/auth/src/certificate.rs b/mirrord/auth/src/certificate.rs index e4fbd2a7560..b51b53b9178 100644 --- a/mirrord/auth/src/certificate.rs +++ b/mirrord/auth/src/certificate.rs @@ -80,3 +80,44 @@ impl Deref for Certificate { &self.0 } } + +#[cfg(test)] +mod test { + use chrono::{TimeZone, Utc}; + use x509_certificate::asn1time::Time; + + use super::Certificate; + + /// Verifies that [`Certificate`] properly deserializes from value produced by old code. + #[test] + fn deserialize_from_old_format() { + const SERIALIZED: &'static str = r#""-----BEGIN CERTIFICATE-----\r\nMIICGTCCAcmgAwIBAgIBATAHBgMrZXAFADBwMUIwQAYDVQQDDDlUaGUgTWljaGHF\r\ngiBTbW9sYXJlayBPcmdhbml6YXRpb25gcyBUZWFtcyBMaWNlbnNlIChUcmlhbCkx\r\nKjAoBgNVBAoMIVRoZSBNaWNoYcWCIFNtb2xhcmVrIE9yZ2FuaXphdGlvbjAeFw0y\r\nNDAyMDgxNTUwNDFaFw0yNDEyMjQwMDAwMDBaMBsxGTAXBgNVBAMMEHJheno0Nzgw\r\nLW1hY2hpbmUwLDAHBgMrZW4FAAMhAAfxTouyk5L5lB3eFwC5Rg9iI4KmQaFpnGVM\r\n2sYpv9HOo4HYMIHVMIHSBhcvbWV0YWxiZWFyL2xpY2Vuc2UvaW5mbwEB/wSBs3si\r\ndHlwZSI6InRlYW1zIiwibWF4X3NlYXRzIjpudWxsLCJzdWJzY3JpcHRpb25faWQi\r\nOiJmMWIxZDI2ZS02NGQzLTQ4YjYtYjVkMi05MzAxMzAwNWE3MmUiLCJvcmdhbml6\r\nYXRpb25faWQiOiIzNTdhZmE4MS0yN2QxLTQ3YjEtYTFiYS1hYzM1ZjlhM2MyNjMi\r\nLCJ0cmlhbCI6dHJ1ZSwidmVyc2lvbiI6IjMuNzMuMCJ9MAcGAytlcAUAA0EAJbbo\r\nu42KnHJBbPMYspMdv9ZdTQMixJgQUheNEs/o4+XfwgYOaRjCVQTzYs1m9f720WQ9\r\n4J04GdQvcu7B/oTgDQ==\r\n-----END CERTIFICATE-----\r\n""#; + let cert: Certificate = serde_yaml::from_str(SERIALIZED).unwrap(); + + assert_eq!( + cert.as_ref().signature.octet_bytes().as_ref(), + b"%\xb6\xe8\xbb\x8d\x8a\x9crAl\xf3\x18\xb2\x93\x1d\xbf\xd6]M\x03\"\xc4\x98\x10R\x17\x8d\x12\xcf\xe8\xe3\xe5\xdf\xc2\x06\x0ei\x18\xc2U\x04\xf3b\xcdf\xf5\xfe\xf6\xd1d=\xe0\x9d8\x19\xd4/r\xee\xc1\xfe\x84\xe0\r", + ); + + assert_eq!( + cert.as_ref().tbs_certificate.subject_public_key_info.subject_public_key.octet_bytes().as_ref(), + 
b"\x07\xf1N\x8b\xb2\x93\x92\xf9\x94\x1d\xde\x17\0\xb9F\x0fb#\x82\xa6A\xa1i\x9ceL\xda\xc6)\xbf\xd1\xce", + ); + assert_eq!( + cert.as_ref() + .tbs_certificate + .subject + .user_friendly_str() + .unwrap(), + "CN=razz4780-machine", + ); + assert_eq!( + cert.as_ref().tbs_certificate.validity.not_before, + Time::from(Utc.with_ymd_and_hms(2024, 2, 8, 15, 50, 41).unwrap()) + ); + assert_eq!( + cert.as_ref().tbs_certificate.validity.not_after, + Time::from(Utc.with_ymd_and_hms(2024, 12, 24, 00, 00, 00).unwrap()) + ); + } +} diff --git a/mirrord/auth/src/credential_store.rs b/mirrord/auth/src/credential_store.rs index a6589680ce6..8b9624d6aec 100644 --- a/mirrord/auth/src/credential_store.rs +++ b/mirrord/auth/src/credential_store.rs @@ -12,13 +12,10 @@ use tokio::{ fs, io::{AsyncRead, AsyncReadExt, AsyncSeekExt, AsyncWrite, AsyncWriteExt, SeekFrom}, }; -use tracing::info; use whoami::fallible; use crate::{ - certificate::Certificate, - credentials::Credentials, - error::{AuthenticationError, CertificateStoreError, Result}, + certificate::Certificate, credentials::Credentials, error::CredentialStoreError, key_pair::KeyPair, }; @@ -69,28 +66,27 @@ impl UserIdentity { impl CredentialStore { /// Load contents of store from file - async fn load(source: &mut R) -> Result { + async fn load(source: &mut R) -> Result { let mut buffer = Vec::new(); - source .read_to_end(&mut buffer) .await - .map_err(CertificateStoreError::from)?; - - serde_yaml::from_slice(&buffer) - .map_err(CertificateStoreError::from) - .map_err(AuthenticationError::from) + .map_err(CredentialStoreError::FileAccess)?; + serde_yaml::from_slice(&buffer).map_err(From::from) } /// Save contents of store to file - async fn save(&self, writer: &mut W) -> Result<()> { - let buffer = serde_yaml::to_string(&self).map_err(CertificateStoreError::from)?; - + async fn save( + &self, + writer: &mut W, + ) -> Result<(), CredentialStoreError> { + let buffer = serde_yaml::to_string(&self)?; writer .write_all(buffer.as_bytes()) 
.await - .map_err(CertificateStoreError::from) - .map_err(AuthenticationError::from) + .map_err(CredentialStoreError::FileAccess)?; + + Ok(()) } /// Get hostname to be used as common name in a certification request. @@ -124,7 +120,7 @@ impl CredentialStore { client: &Client, operator_fingerprint: String, operator_subscription_id: Option, - ) -> Result<&mut Credentials> + ) -> Result<&mut Credentials, CredentialStoreError> where R: Resource + Clone + Debug, R: for<'de> Deserialize<'de>, @@ -173,11 +169,11 @@ pub struct CredentialStoreSync { } impl CredentialStoreSync { - pub async fn open() -> Result { + pub async fn open() -> Result { if !CREDENTIALS_DIR.exists() { fs::create_dir_all(&*CREDENTIALS_DIR) .await - .map_err(CertificateStoreError::from)?; + .map_err(CredentialStoreError::ParentDir)?; } let store_file = fs::OpenOptions::new() @@ -187,7 +183,7 @@ impl CredentialStoreSync { .truncate(false) .open(&*CREDENTIALS_PATH) .await - .map_err(CertificateStoreError::from)?; + .map_err(CredentialStoreError::FileAccess)?; Ok(Self { store_file }) } @@ -200,7 +196,7 @@ impl CredentialStoreSync { operator_fingerprint: String, operator_subscription_id: Option, callback: C, - ) -> Result + ) -> Result where R: Resource + Clone + Debug, R: for<'de> Deserialize<'de>, @@ -209,7 +205,7 @@ impl CredentialStoreSync { { let mut store = CredentialStore::load(&mut self.store_file) .await - .inspect_err(|err| info!("CredentialStore Load Error {err:?}")) + .inspect_err(|error| tracing::warn!(%error, "CredentialStore load failed")) .unwrap_or_default(); let value = callback( @@ -222,7 +218,7 @@ impl CredentialStoreSync { self.store_file .seek(SeekFrom::Start(0)) .await - .map_err(CertificateStoreError::from)?; + .map_err(CredentialStoreError::FileAccess)?; store.save(&mut self.store_file).await?; @@ -235,7 +231,7 @@ impl CredentialStoreSync { client: &Client, operator_fingerprint: String, operator_subscription_id: Option, - ) -> Result + ) -> Result where R: Resource + Clone + Debug, 
R: for<'de> Deserialize<'de>, @@ -243,7 +239,7 @@ impl CredentialStoreSync { { self.store_file .lock_exclusive() - .map_err(CertificateStoreError::Lockfile)?; + .map_err(CredentialStoreError::Lockfile)?; let result = self .access_credential::( @@ -256,7 +252,7 @@ impl CredentialStoreSync { self.store_file .unlock() - .map_err(CertificateStoreError::Lockfile)?; + .map_err(CredentialStoreError::Lockfile)?; result } diff --git a/mirrord/auth/src/credentials.rs b/mirrord/auth/src/credentials.rs index 0e49ac0f520..33b2ce887ad 100644 --- a/mirrord/auth/src/credentials.rs +++ b/mirrord/auth/src/credentials.rs @@ -4,14 +4,11 @@ use chrono::{DateTime, NaiveDate, NaiveTime, Utc}; use serde::{Deserialize, Serialize}; pub use x509_certificate; use x509_certificate::{ - asn1time::Time, rfc2986, rfc5280, InMemorySigningKeyPair, KeyAlgorithm, X509CertificateBuilder, + asn1time::Time, rfc2986, rfc5280, InMemorySigningKeyPair, X509CertificateBuilder, + X509CertificateError, }; -use crate::{ - certificate::Certificate, - error::{AuthenticationError, Result}, - key_pair::KeyPair, -}; +use crate::{certificate::Certificate, key_pair::KeyPair}; /// Client credentials container for authentication with the operator. /// Contains a local [`KeyPair`] and an optional [`Certificate`]. 
@@ -44,16 +41,14 @@ impl Credentials { fn certificate_request( common_name: &str, key_pair: &InMemorySigningKeyPair, - ) -> Result { - let mut builder = X509CertificateBuilder::new(KeyAlgorithm::Ed25519); + ) -> Result { + let mut builder = X509CertificateBuilder::default(); let _ = builder .subject() .append_common_name_utf8_string(common_name); - builder - .create_certificate_signing_request(key_pair) - .map_err(AuthenticationError::from) + builder.create_certificate_signing_request(key_pair) } } @@ -166,9 +161,9 @@ impl DateValidityExt for rfc5280::Validity { #[cfg(feature = "client")] pub mod client { use kube::{api::PostParams, Api, Client, Resource}; - use ring::rand::SystemRandom; use super::*; + use crate::error::CredentialStoreError; impl Credentials { /// Create a [`rfc2986::CertificationRequest`] and send it to the operator. @@ -177,7 +172,7 @@ pub mod client { client: Client, common_name: &str, key_pair: Option, - ) -> Result + ) -> Result where R: Resource + Clone + Debug, R: for<'de> Deserialize<'de>, @@ -185,18 +180,12 @@ pub mod client { { let key_pair = match key_pair { Some(key_pair) => key_pair, - None => { - let rng = SystemRandom::new(); - let document = ring::signature::Ed25519KeyPair::generate_pkcs8(&rng) - .map_err(|_| AuthenticationError::KeyGenerationError)?; - let pem_key = pem::Pem::new("PRIVATE KEY", document.as_ref()); - pem::encode(&pem_key).into() - } + None => KeyPair::new_random()?, }; let certificate_request = Self::certificate_request(common_name, &key_pair)? .encode_pem() - .map_err(x509_certificate::X509CertificateError::from)?; + .map_err(X509CertificateError::from)?; let api: Api = Api::all(client); @@ -217,7 +206,11 @@ pub mod client { /// Create [`rfc2986::CertificationRequest`] and send it to the operator. /// Returned certificate replaces the [`Certificate`] stored in this struct. 
- pub async fn refresh(&mut self, client: Client, common_name: &str) -> Result<()> + pub async fn refresh( + &mut self, + client: Client, + common_name: &str, + ) -> Result<(), CredentialStoreError> where R: Resource + Clone + Debug, R: for<'de> Deserialize<'de>, @@ -225,7 +218,7 @@ pub mod client { { let certificate_request = Self::certificate_request(common_name, &self.key_pair)? .encode_pem() - .map_err(x509_certificate::X509CertificateError::from)?; + .map_err(X509CertificateError::from)?; let api: Api = Api::all(client); @@ -244,3 +237,55 @@ pub mod client { } } } + +#[cfg(test)] +mod test { + use bcder::{ + decode::{BytesSource, Constructed}, + Mode, + }; + use x509_certificate::rfc2986::CertificationRequest; + + /// Verifies that [`CertificationRequest`] properly decodes from value produced by old code. + #[test] + fn decode_old_certificate_request() { + const REQUEST: &'static str = "PEM: -----BEGIN CERTIFICATE REQUEST----- +MIGXMEkCAQAwFDESMBAGA1UEAwwJc29tZV9uYW1lMCwwBwYDK2VuBQADIQDhLn5T +fFTb4xOq+a1HyC3T7ScFiQGBy+oUcwFiCVCUI6AAMAcGAytlcAUAA0EAPBRvsUHo ++J/INwq6tn5kgcE9vMo48kRkyhWSp3XmfuUvxW/b7LufrlTcjw+4RG8pdugMXhcz +5+u20nm4VY+sCg== +-----END CERTIFICATE REQUEST----- +"; + const PUBLIC_KEY: &'static [u8] = b"\xe1.~S|T\xdb\xe3\x13\xaa\xf9\xadG\xc8-\xd3\xed'\x05\x89\x01\x81\xcb\xea\x14s\x01b\tP\x94#"; + + let certification_request_pem = pem::parse(REQUEST).unwrap(); + let certification_request_source = + BytesSource::new(certification_request_pem.into_contents().into()); + let certification_request = Constructed::decode( + certification_request_source, + Mode::Der, + CertificationRequest::take_from, + ) + .unwrap(); + + assert_eq!( + certification_request + .certificate_request_info + .subject + .iter_common_name() + .next() + .unwrap() + .to_string() + .unwrap(), + "some_name" + ); + assert_eq!( + certification_request + .certificate_request_info + .subject_public_key_info + .subject_public_key + .octet_bytes(), + PUBLIC_KEY + ); + } +} diff --git 
a/mirrord/auth/src/error.rs b/mirrord/auth/src/error.rs index 6a3107292bc..3488568907d 100644 --- a/mirrord/auth/src/error.rs +++ b/mirrord/auth/src/error.rs @@ -1,47 +1,26 @@ use thiserror::Error; use x509_certificate::X509CertificateError; -/// Wrapper error for errors in mirrord-auth library -#[derive(Debug, Error)] -pub enum AuthenticationError { - /// Error from parsing `pem` wrapped certificate/key-pair - #[error(transparent)] - Pem(std::io::Error), - - /// Error from from generating sha256 fingerprint for certificate/key-pair - #[error(transparent)] - Fingerprint(std::io::Error), - - /// Error from `x509_certificate` library - #[error(transparent)] - X509Certificate(#[from] X509CertificateError), - - #[cfg(feature = "client")] - #[error(transparent)] - CertificateStore(#[from] CertificateStoreError), - - #[cfg(feature = "client")] - #[error(transparent)] - Kube(#[from] kube::Error), - - /// Failed to generate key pair - #[error("Failed to generate key pair")] - KeyGenerationError, -} - -/// Error from CredentialStore operations +/// Errors from [`CredentialStore`](crate::credential_store::CredentialStore) and +/// [`CredentialStoreSync`](crate::credential_store::CredentialStoreSync) operations #[cfg(feature = "client")] #[derive(Debug, Error)] -pub enum CertificateStoreError { - #[error("Unable to save/load CertificateStore: {0}")] - Io(#[from] std::io::Error), +pub enum CredentialStoreError { + #[error("failed to parent directory for credential store file: {0}")] + ParentDir(std::io::Error), - #[error("Unable to create CertificateStore lockfile: {0}")] + #[error("IO on credential store file failed: {0}")] + FileAccess(std::io::Error), + + #[error("failed to lock/unlock credential store file: {0}")] Lockfile(std::io::Error), - #[error("Unable serialize/deserialize CertificateStore: {0}")] + #[error("failed to serialize/deserialize credentials: {0}")] Yaml(#[from] serde_yaml::Error), -} -/// `Result` with `AuthenticationError` as default error -pub type 
Result = std::result::Result; + #[error("x509 certificate error: {0}")] + X509Certificate(#[from] X509CertificateError), + + #[error("certification request failed: {0}")] + Kube(#[from] kube::Error), +} diff --git a/mirrord/auth/src/key_pair.rs b/mirrord/auth/src/key_pair.rs index dc7190587bd..06d2bc44a6b 100644 --- a/mirrord/auth/src/key_pair.rs +++ b/mirrord/auth/src/key_pair.rs @@ -1,23 +1,124 @@ -use std::{ops::Deref, sync::OnceLock}; +use std::{borrow::Cow, ops::Deref, sync::Arc}; -use serde::{Deserialize, Serialize}; -use x509_certificate::InMemorySigningKeyPair; +use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer}; +use x509_certificate::{InMemorySigningKeyPair, KeyAlgorithm, X509CertificateError}; -/// Wrapps `InMemorySigningKeyPair` & the underlying pkcs8 formatted key -#[derive(Debug, Serialize, Deserialize)] -#[serde(transparent)] -pub struct KeyPair(String, #[serde(skip)] OnceLock); - -impl Clone for KeyPair { - fn clone(&self) -> Self { - Self(self.0.clone(), Default::default()) - } +/// Wrapper over [`InMemorySigningKeyPair`]. +/// +/// Can be (de)serialized from/to either valid or buggy format. The format can also be switched in +/// memory with [`Self::bug_der`] and [`Self::fix_der`]. See . +#[derive(Debug, Clone)] +pub struct KeyPair { + /// PEM-encoded document containing the key pair. + pem: String, + /// Whether the DER-encoded key pair container in [`Self::pem`] is in buggy format. + der_bugged: bool, + /// Deserialized and initialized key pair for signing. + /// The key pair is wrapped in [`Arc`] only because [`InMemorySigningKeyPair`] is not + /// cloneable. + key_pair: Arc, } impl KeyPair { - /// Access the PEM encoded SigningKeyPair + /// Buggy prefix of DER encoding used in old version `ring` crate. + const BUGGED_PREFIX: [u8; 16] = [ + 0x30, 0x53, 0x02, 0x01, 0x01, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x04, 0x22, 0x04, + 0x20, + ]; + + /// Buggy middle part of DER encoding used in old version of `ring` crate. 
+ const BUGGED_MIDDLE: [u8; 5] = [0xa1, 0x23, 0x03, 0x21, 0x00]; + + /// Valid prefix of DER encoding, used for patching buggy DERs. + const FIXED_PREFIX: [u8; 16] = [ + 0x30, 0x51, 0x02, 0x01, 0x01, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x04, 0x22, 0x04, + 0x20, + ]; + + /// Valid middle part of DER encoding, used for patching buggy DERs. + const FIXED_MIDDLE: [u8; 3] = [0x81, 0x21, 0x00]; + + /// If the given DER was produced by old version of `ring` crate, returns a patched valid + /// version that contains the same key pair. If not, returns [`None`]. + fn patch_buggy_der(der: &[u8]) -> Option> { + if der.len() != 85 + || der.get(..16).expect("length was checked") != Self::BUGGED_PREFIX + || der.get(16 + 32..16 + 32 + 5).expect("length was checked") != Self::BUGGED_MIDDLE + { + return None; + } + + let seed = der + .get(Self::BUGGED_PREFIX.len()..Self::BUGGED_PREFIX.len() + 32) + .expect("length was checked"); + let public_key = der + .get(Self::BUGGED_PREFIX.len() + 32 + Self::BUGGED_MIDDLE.len()..) + .expect("length was checked"); + + [&Self::FIXED_PREFIX, seed, &Self::FIXED_MIDDLE, public_key] + .concat() + .into() + } + + /// Generates a new random [`KeyPair`]. + /// The new [`KeyPair`] initially has valid format. + pub fn new_random() -> Result { + let key_pair = InMemorySigningKeyPair::generate_random(KeyAlgorithm::Ed25519)?; + let der = key_pair.to_pkcs8_one_asymmetric_key_der(); + let pem_key = pem::Pem::new("PRIVATE KEY", der.to_vec()); + let pem_document = pem::encode(&pem_key); + + Ok(Self { + pem: pem_document, + key_pair: key_pair.into(), + der_bugged: false, + }) + } + + /// Exposes this key pair as a PEM-encoded document. pub fn document(&self) -> &str { - &self.0 + &self.pem + } + + /// Changes format to the buggy one. + /// Old `ring` and [`x509_certificate`] versions will accept it. + /// New `ring` and [`x509_certificate`] versions will reject it. 
+ pub fn bug_der(&mut self) { + if self.der_bugged { + return; + } + + let pem_key = pem::parse(&self.pem).expect("PEM was verified when creating this KeyPair"); + let der = pem_key.contents(); + let seed = der + .get(Self::FIXED_PREFIX.len()..Self::FIXED_PREFIX.len() + 32) + .expect("PEM was verified when creating this KeyPair"); + let public_key = der + .get(Self::FIXED_PREFIX.len() + 32 + Self::FIXED_MIDDLE.len()..) + .expect("PEM was verified when creating this KeyPair"); + + let bugged_der = [&Self::BUGGED_PREFIX, seed, &Self::BUGGED_MIDDLE, public_key].concat(); + let pem_key = pem::Pem::new("PRIVATE KEY", bugged_der); + let pem_document = pem::encode(&pem_key); + + self.pem = pem_document; + self.der_bugged = true; + } + + /// Changes format to the valid one. + /// Old `ring` and [`x509_certificate`] versions will reject it. + /// New `ring` and [`x509_certificate`] versions will accept it. + pub fn fix_der(&mut self) { + if !self.der_bugged { + return; + } + + let der = self.key_pair.to_pkcs8_one_asymmetric_key_der(); + let pem_key = pem::Pem::new("PRIVATE KEY", der.to_vec()); + let pem_document = pem::encode(&pem_key); + + self.pem = pem_document; + self.der_bugged = false; } } @@ -25,20 +126,110 @@ impl Deref for KeyPair { type Target = InMemorySigningKeyPair; fn deref(&self) -> &Self::Target { - self.1.get_or_init(|| { - InMemorySigningKeyPair::from_pkcs8_pem(&self.0).expect("Invalid pkcs8 key stored") + &self.key_pair + } +} + +impl TryFrom for KeyPair { + type Error = X509CertificateError; + + fn try_from(value: String) -> Result { + let pem_key = pem::parse(&value).map_err(X509CertificateError::PemDecode)?; + let (contents, der_bugged) = match Self::patch_buggy_der(pem_key.contents()) { + Some(contents) => (Cow::Owned(contents), true), + None => (Cow::Borrowed(pem_key.contents()), false), + }; + + let key_pair = InMemorySigningKeyPair::from_pkcs8_der(contents)?; + + Ok(Self { + pem: value, + key_pair: key_pair.into(), + der_bugged, }) } } -impl 
From<&str> for KeyPair { - fn from(key: &str) -> Self { - KeyPair(key.to_owned(), Default::default()) +impl Serialize for KeyPair { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + self.pem.serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for KeyPair { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let pem = String::deserialize(deserializer)?; + Self::try_from(pem).map_err(D::Error::custom) } } -impl From for KeyPair { - fn from(key: String) -> Self { - KeyPair(key, Default::default()) +#[cfg(test)] +mod test { + use x509_certificate::Signer; + + use super::KeyPair; + + /// Verifies that [`KeyPair`] properly deserializes from old buggy format. + #[test] + fn deserialize_old_format() { + // Produced with previous version of this crate. + const SERIALIZED: &'static str = "-----BEGIN PRIVATE KEY-----\r\nMFMCAQEwBQYDK2VwBCIEIAnnKqvgSX5b4p2WZhe/hQOpt/D7z4P1H9UHJ2iiIat1\r\noSMDIQCQaTis0CQ62Y8+pePb3+x7umYRY0368BNyD5UrLZCMqA==\r\n-----END PRIVATE KEY-----\r\n"; + const EXPECTED_SIGNATURE: &'static [u8] = &[ + 138, 4, 156, 91, 93, 73, 133, 216, 66, 25, 175, 249, 20, 105, 24, 28, 39, 28, 188, 63, + 249, 207, 106, 200, 98, 81, 184, 66, 241, 182, 24, 77, 2, 112, 208, 30, 189, 192, 138, + 69, 77, 143, 244, 61, 250, 18, 241, 254, 230, 160, 250, 208, 66, 217, 124, 86, 186, + 188, 139, 24, 152, 16, 185, 4, + ]; + const MESSAGE_TO_SIGN: &'static [u8] = b"hello"; + + // Verify that we're able to deserialize. + let key_pair: KeyPair = serde_yaml::from_str(SERIALIZED).unwrap(); + assert!(key_pair.der_bugged); + + // Verify that signature is the same - we deserialized the exact same key pair. + let signature = key_pair.sign(MESSAGE_TO_SIGN); + assert_eq!(signature.as_ref(), EXPECTED_SIGNATURE); + } + + /// Verifies that switching [`KeyPair`] format works fine - key pair is not changed and we're + /// de(serializing) correct variant. 
+ #[test] + fn format_conversion() { + const MESSAGE_TO_SIGN: &'static [u8] = b"hello there"; + + let key_pair = KeyPair::new_random().unwrap(); + let expected_signature = key_pair.sign(MESSAGE_TO_SIGN); + + // Serialize and deserialize without changing format, check key pair identity. + assert!(!key_pair.der_bugged); + let serialized = serde_yaml::to_string(&key_pair).unwrap(); + let mut deserialized: KeyPair = serde_yaml::from_str(&serialized).unwrap(); + assert!(!deserialized.der_bugged); + let signature = deserialized.sign(MESSAGE_TO_SIGN); + assert_eq!(signature.as_ref(), expected_signature.as_ref()); + + // Switch to bugged format, serialize and deserialize. Check key pair identity. + deserialized.bug_der(); + assert!(deserialized.der_bugged); + let serialized = serde_yaml::to_string(&deserialized).unwrap(); + let mut deserialized: KeyPair = serde_yaml::from_str(&serialized).unwrap(); + assert!(deserialized.der_bugged); + let signature = deserialized.sign(MESSAGE_TO_SIGN); + assert_eq!(signature.as_ref(), expected_signature.as_ref()); + + // Switch back to fixed format, serialize and deserialize. Check key pair identity. 
+ deserialized.fix_der(); + assert!(!deserialized.der_bugged); + let serialized = serde_yaml::to_string(&deserialized).unwrap(); + let deserialized: KeyPair = serde_yaml::from_str(&serialized).unwrap(); + assert!(!deserialized.der_bugged); + let signature = deserialized.sign(MESSAGE_TO_SIGN); + assert_eq!(signature.as_ref(), expected_signature.as_ref()); } } diff --git a/mirrord/cli/src/connection.rs b/mirrord/cli/src/connection.rs index 9ce4c6bf96c..587f1554fda 100644 --- a/mirrord/cli/src/connection.rs +++ b/mirrord/cli/src/connection.rs @@ -1,20 +1,13 @@ use std::{collections::HashSet, time::Duration}; -use kube::{api::GroupVersionKind, discovery, Resource}; use mirrord_analytics::Reporter; use mirrord_config::LayerConfig; use mirrord_intproxy::agent_conn::AgentConnectInfo; use mirrord_kube::{ - api::{ - kubernetes::{create_kube_api, KubernetesAPI}, - wrap_raw_connection, - }, + api::{kubernetes::KubernetesAPI, wrap_raw_connection}, error::KubeApiError, }; -use mirrord_operator::{ - client::{OperatorApi, OperatorApiError, OperatorOperation}, - crd::MirrordOperatorCrd, -}; +use mirrord_operator::client::{OperatorApi, OperatorSessionConnection}; use mirrord_progress::{ messages::MULTIPOD_WARNING, IdeAction, IdeMessage, NotificationLevel, Progress, }; @@ -28,36 +21,77 @@ pub(crate) struct AgentConnection { pub receiver: mpsc::Receiver, } -#[tracing::instrument(level = "trace", skip(config), ret, err)] -async fn check_if_operator_resource_exists(config: &LayerConfig) -> Result { - let client = create_kube_api( - config.accept_invalid_certificates, - config.kubeconfig.clone(), - config.kube_context.clone(), - ) - .await?; +/// 1. If mirrord-operator is explicitly enabled in the given [`LayerConfig`], makes a connection +/// with the target using the mirrord-operator. +/// 2. If mirrord-operator is explicitly disabled in the given [`LayerConfig`], returns [`None`]. +/// 3. 
Otherwise, attempts to use the mirrord-operator and returns [`None`] in case mirrord-operator +/// is not found or its license is invalid. +async fn try_connect_using_operator( + config: &LayerConfig, + progress: &P, + analytics: &mut R, +) -> Result> +where + P: Progress, + R: Reporter, +{ + let mut operator_subtask = progress.subtask("checking operator"); + if config.operator == Some(false) { + operator_subtask.success(Some("operator disabled")); + return Ok(None); + } - let gvk = GroupVersionKind { - group: MirrordOperatorCrd::group(&()).into_owned(), - version: MirrordOperatorCrd::version(&()).into_owned(), - kind: MirrordOperatorCrd::kind(&()).into_owned(), + let api = match OperatorApi::try_new(config, analytics).await? { + Some(api) => api, + None if config.operator == Some(true) => return Err(CliError::OperatorNotInstalled), + None => { + operator_subtask.success(Some("operator not found")); + return Ok(None); + } }; - match discovery::oneshot::pinned_kind(&client, &gvk).await { - Ok(..) 
=> Ok(true), - Err(kube::Error::Api(response)) if response.code == 404 => Ok(false), - Err(error) => Err(error.into()), + let mut version_cmp_subtask = operator_subtask.subtask("checking version compatibility"); + let compatible = api.check_operator_version(&version_cmp_subtask); + if compatible { + version_cmp_subtask.success(Some("operator version compatible")); + } else { + version_cmp_subtask.failure(Some("operator version may not be compatible")); + } + + let mut license_subtask = operator_subtask.subtask("checking license"); + match api.check_license_validity(&license_subtask) { + Ok(()) => license_subtask.success(Some("operator license valid")), + Err(error) => { + license_subtask.failure(Some("operator license expired")); + + if config.operator == Some(true) { + return Err(error.into()); + } else { + operator_subtask.failure(Some("proceeding without operator")); + return Ok(None); + } + } } + + let mut user_cert_subtask = operator_subtask.subtask("preparing user credentials"); + let api = api.prepare_client_cert(analytics).await.into_certified()?; + user_cert_subtask.success(Some("user credentials prepared")); + + let mut session_subtask = operator_subtask.subtask("starting session"); + let connection = api.connect_in_new_session(config, &session_subtask).await?; + session_subtask.success(Some("session started")); + + operator_subtask.success(Some("using operator")); + + Ok(Some(connection)) } -/// Creates an agent if needed then connects to it. -/// -/// First it checks if we have an `operator` in the [`config`](LayerConfig), which we do if the -/// user has installed the mirrord-operator in their cluster, even without a valid license. And -/// then we create a session with the operator with [`OperatorApi::create_session`]. -/// -/// If there is no operator, or the license is not good enough for starting an operator session, -/// then we create the mirrord-agent and run mirrord by itself, without the operator. +/// 1. 
If mirrord-operator is explicitly enabled in the given [`LayerConfig`], makes a connection +/// with the target using the mirrord-operator. +/// 2. If mirrord-operator is explicitly disabled in the given [`LayerConfig`], creates a +/// mirrord-agent and runs session without the mirrord-operator. +/// 3. Otherwise, attempts to use the mirrord-operator and falls back to OSS flow in case +/// mirrord-operator is not found or its license is invalid. /// /// Here is where we start interactions with the kubernetes API. #[tracing::instrument(level = "trace", skip_all)] @@ -69,56 +103,14 @@ pub(crate) async fn create_and_connect( where P: Progress + Send + Sync, { - if config.operator != Some(false) { - let mut subtask = progress.subtask("checking operator"); - - match OperatorApi::create_session(config, &subtask, analytics).await { - Ok(session) => { - subtask.success(Some("connected to the operator")); - - return Ok(( - AgentConnectInfo::Operator(session.info), - AgentConnection { - sender: session.tx, - receiver: session.rx, - }, - )); - } - - Err(OperatorApiError::NoLicense) if config.operator.is_none() => { - tracing::trace!("operator license expired"); - subtask.success(Some("operator license expired")); - } - - Err( - e @ OperatorApiError::KubeError { - operation: OperatorOperation::FindingOperator, - .. - }, - ) if config.operator.is_none() => { - // We need to check if the operator is really installed or not. 
- match check_if_operator_resource_exists(config).await { - // Operator is installed yet we failed to use it, abort - Ok(true) => { - return Err(e.into()); - } - // Operator is not installed, fallback to OSS - Ok(false) => { - subtask.success(Some("operator not found")); - } - // We don't know if operator is installed or not, - // prompt a warning and fallback to OSS - Err(error) => { - let message = "failed to check if operator is installed"; - tracing::debug!(%error, message); - subtask.warning(message); - subtask.success(Some("operator not found")); - } - } - } - - Err(e) => return Err(e.into()), - } + if let Some(connection) = try_connect_using_operator(config, progress, analytics).await? { + return Ok(( + AgentConnectInfo::Operator(connection.session), + AgentConnection { + sender: connection.tx, + receiver: connection.rx, + }, + )); } if config.feature.copy_target.enabled { diff --git a/mirrord/cli/src/error.rs b/mirrord/cli/src/error.rs index 5d03d90a0ee..a4ede3764ea 100644 --- a/mirrord/cli/src/error.rs +++ b/mirrord/cli/src/error.rs @@ -6,7 +6,7 @@ use mirrord_config::config::ConfigError; use mirrord_console::error::ConsoleError; use mirrord_intproxy::error::IntProxyError; use mirrord_kube::error::KubeApiError; -use mirrord_operator::client::{HttpError, OperatorApiError, OperatorOperation}; +use mirrord_operator::client::error::{HttpError, OperatorApiError, OperatorOperation}; use reqwest::StatusCode; use thiserror::Error; @@ -232,6 +232,21 @@ pub(crate) enum CliError { "This usually means that connectivity was lost while pinging.{GENERAL_HELP}" ))] PingPongFailed(String), + + #[error("Failed to prepare mirrord operator client certificate: {0}")] + #[diagnostic(help("{GENERAL_BUG}"))] + OperatorClientCertError(String), + + #[error("mirrord operator was not found in the cluster.")] + #[diagnostic(help( + "Command requires the mirrord operator or operator usage was explicitly enabled in the configuration file. 
+ Read more here: https://mirrord.dev/docs/overview/quick-start/#operator.{GENERAL_HELP}" + ))] + OperatorNotInstalled, + + #[error("mirrord returned a target resource of unknown type: {0}")] + #[diagnostic(help("{GENERAL_BUG}"))] + OperatorReturnedUnknownTargetType(String), } impl From for CliError { @@ -244,7 +259,7 @@ impl From for CliError { feature, operator_version, }, - OperatorApiError::CreateApiError(e) => Self::CreateKubeApiFailed(e), + OperatorApiError::CreateKubeClient(e) => Self::CreateKubeApiFailed(e), OperatorApiError::ConnectRequestBuildError(e) => Self::ConnectRequestBuildError(e), OperatorApiError::KubeError { error: kube::Error::Api(ErrorResponse { message, code, .. }), @@ -269,6 +284,10 @@ impl From for CliError { Self::OperatorApiFailed(operation, error) } OperatorApiError::NoLicense => Self::OperatorLicenseExpired, + OperatorApiError::ClientCertError(error) => Self::OperatorClientCertError(error), + OperatorApiError::FetchedUnknownTargetType(error) => { + Self::OperatorReturnedUnknownTargetType(error.0) + } } } } diff --git a/mirrord/cli/src/main.rs b/mirrord/cli/src/main.rs index 99ba6a9140f..144ee9b16be 100644 --- a/mirrord/cli/src/main.rs +++ b/mirrord/cli/src/main.rs @@ -17,18 +17,26 @@ use k8s_openapi::{ api::{apps::v1::Deployment, core::v1::Pod}, Metadata, NamespaceResourceScope, }; -use kube::api::ListParams; +use kube::{api::ListParams, Client}; use miette::JSONReportHandler; -use mirrord_analytics::{AnalyticsError, AnalyticsReporter, CollectAnalytics, Reporter}; +use mirrord_analytics::{ + AnalyticsError, AnalyticsReporter, CollectAnalytics, NullReporter, Reporter, +}; use mirrord_config::{ config::{ConfigContext, MirrordConfig}, - feature::{fs::FsModeConfig, network::incoming::IncomingMode}, + feature::{ + fs::FsModeConfig, + network::{ + dns::{DnsConfig, DnsFilterConfig}, + incoming::IncomingMode, + }, + }, target::TargetDisplay, LayerConfig, LayerFileConfig, }; use mirrord_kube::api::{ container::SKIP_NAMES, - 
kubernetes::{create_kube_api, get_k8s_resource_api, rollout::Rollout}, + kubernetes::{create_kube_config, get_k8s_resource_api, rollout::Rollout}, }; use mirrord_operator::client::OperatorApi; use mirrord_progress::{Progress, ProgressTracker}; @@ -267,9 +275,28 @@ fn print_config

( }; messages.push(format!("outgoing: forwarding is {}", outgoing_info)); - let dns_info = match config.feature.network.dns { - true => "remotely", - false => "locally", + let dns_info = match &config.feature.network.dns { + DnsConfig { enabled: false, .. } => "locally", + DnsConfig { + enabled: true, + filter: None, + } => "remotely", + DnsConfig { + enabled: true, + filter: Some(DnsFilterConfig::Remote(filters)), + } if filters.is_empty() => "locally", + DnsConfig { + enabled: true, + filter: Some(DnsFilterConfig::Local(filters)), + } if filters.is_empty() => "remotely", + DnsConfig { + enabled: true, + filter: Some(DnsFilterConfig::Remote(..)), + } => "locally with exceptions", + DnsConfig { + enabled: true, + filter: Some(DnsFilterConfig::Local(..)), + } => "remotely with exceptions", }; messages.push(format!("dns: DNS will be resolved {}", dns_info)); @@ -500,12 +527,13 @@ where } async fn list_pods(layer_config: &LayerConfig, args: &ListTargetArgs) -> Result> { - let client = create_kube_api( + let client = create_kube_config( layer_config.accept_invalid_certificates, layer_config.kubeconfig.clone(), layer_config.kube_context.clone(), ) .await + .and_then(|config| Client::try_from(config).map_err(From::from)) .map_err(CliError::CreateKubeApiFailed)?; let namespace = args @@ -563,39 +591,36 @@ async fn print_targets(args: &ListTargetArgs) -> Result<()> { } // Try operator first if relevant - let mut targets = match &layer_config.operator { - Some(true) | None => { - let operator_targets = OperatorApi::list_targets(&layer_config).await; - match operator_targets { - Ok(targets) => { - // adjust format to match non-operator output - targets - .iter() - .filter_map(|target_crd| { - let target = target_crd.spec.target.as_ref()?; - if let Some(container) = target.container_name() { - if SKIP_NAMES.contains(container.as_str()) { - return None; - } - } - Some(format!("{target}")) - }) - .collect::>() - } + let operator_api = if layer_config.operator == Some(false) { + 
None + } else { + OperatorApi::try_new(&layer_config, &mut NullReporter::default()).await? + }; - Err(error) => { - if layer_config.operator.is_some() { - error!( - ?error, - "Operator was explicitly enabled and we failed to list targets" - ); - return Err(error.into()); + let mut targets = match operator_api { + Some(api) => { + let api = api.prepare_client_cert(&mut NullReporter::default()).await; + api.inspect_cert_error( + |error| tracing::error!(%error, "failed to prepare client certificate"), + ); + api.list_targets(layer_config.target.namespace.as_deref()) + .await? + .iter() + .filter_map(|target_crd| { + let target = target_crd.spec.target.as_known().ok()?; + if let Some(container) = target.container() { + if SKIP_NAMES.contains(container.as_str()) { + return None; + } } - list_pods(&layer_config, args).await? - } - } + Some(format!("{target}")) + }) + .collect() } - Some(false) => list_pods(&layer_config, args).await?, + + None if layer_config.operator == Some(true) => return Err(CliError::OperatorNotInstalled), + + None => list_pods(&layer_config, args).await?, }; targets.sort(); @@ -705,9 +730,9 @@ async fn prompt_outdated_version(progress: &ProgressTracker) { let command = if is_homebrew { "brew upgrade metalbear-co/mirrord/mirrord" } else { "curl -fsSL https://raw.githubusercontent.com/metalbear-co/mirrord/main/scripts/install.sh | bash" }; progress.print(&format!("New mirrord version available: {}. 
To update, run: `{:?}`.", latest_version, command)); progress.print("To disable version checks, set env variable MIRRORD_CHECK_VERSION to 'false'."); - progress.success(Some(&format!("Update to {latest_version} available"))); + progress.success(Some(&format!("update to {latest_version} available"))); } else { - progress.success(Some(&format!("Running on latest ({CURRENT_VERSION})!"))); + progress.success(Some(&format!("running on latest ({CURRENT_VERSION})!"))); } } } diff --git a/mirrord/cli/src/operator.rs b/mirrord/cli/src/operator.rs index 856233f46f2..56f250c46ab 100644 --- a/mirrord/cli/src/operator.rs +++ b/mirrord/cli/src/operator.rs @@ -5,15 +5,16 @@ use std::{ }; use futures::TryFutureExt; -use kube::Api; +use kube::{Api, Client}; +use mirrord_analytics::NullReporter; use mirrord_config::{ config::{ConfigContext, MirrordConfig}, LayerConfig, LayerFileConfig, }; -use mirrord_kube::api::kubernetes::create_kube_api; +use mirrord_kube::api::kubernetes::create_kube_config; use mirrord_operator::{ - client::{OperatorApiError, OperatorOperation}, - crd::{MirrordOperatorCrd, MirrordOperatorSpec, OPERATOR_STATUS_NAME}, + client::OperatorApi, + crd::{MirrordOperatorCrd, MirrordOperatorSpec}, setup::{LicenseType, Operator, OperatorNamespace, OperatorSetup, SetupOptions}, types::LicenseInfoOwned, }; @@ -130,42 +131,43 @@ async fn get_status_api(config: Option<&Path>) -> Result remove_proxy_env(); } - let kube_api = create_kube_api( + let client = create_kube_config( layer_config.accept_invalid_certificates, layer_config.kubeconfig, layer_config.kube_context, ) .await + .and_then(|config| Client::try_from(config).map_err(From::from)) .map_err(CliError::CreateKubeApiFailed)?; - Ok(Api::all(kube_api)) + Ok(Api::all(client)) } #[tracing::instrument(level = "trace", ret)] async fn operator_status(config: Option<&Path>) -> Result<()> { let mut progress = ProgressTracker::from_env("Operator Status"); - let status_api = get_status_api(config).await?; + let layer_config = if 
let Some(config) = config { + let mut cfg_context = ConfigContext::default(); + LayerFileConfig::from_path(config)?.generate_config(&mut cfg_context)? + } else { + LayerConfig::from_env()? + }; - let mut status_progress = progress.subtask("fetching status"); + if !layer_config.use_proxy { + remove_proxy_env(); + } - let mirrord_status = match status_api - .get(OPERATOR_STATUS_NAME) + let mut status_progress = progress.subtask("fetching status"); + let api = OperatorApi::try_new(&layer_config, &mut NullReporter::default()) .await - .map_err(|error| OperatorApiError::KubeError { - error, - operation: OperatorOperation::GettingStatus, - }) - .map_err(CliError::from) - { - Ok(status) => status, - Err(err) => { - status_progress.failure(Some("unable to get status")); - - return Err(err); - } + .inspect_err(|_| { + status_progress.failure(Some("failed to get status")); + })?; + let Some(api) = api else { + status_progress.failure(Some("operator not found")); + return Err(CliError::OperatorNotInstalled); }; - status_progress.success(Some("fetched status")); progress.success(None); @@ -181,7 +183,7 @@ async fn operator_status(config: Option<&Path>) -> Result<()> { .. }, .. 
- } = mirrord_status.spec; + } = &api.operator().spec; let expire_at = expire_at.format("%e-%b-%Y"); @@ -196,11 +198,11 @@ Operator License "# ); - let Some(status) = mirrord_status.status else { + let Some(status) = &api.operator().status else { return Ok(()); }; - if let Some(copy_targets) = status.copy_targets { + if let Some(copy_targets) = status.copy_targets.as_ref() { if copy_targets.is_empty() { println!("No active copy targets."); } else { @@ -217,7 +219,11 @@ Operator License for (pod_name, copy_target_resource) in copy_targets { copy_targets_table.add_row(row![ copy_target_resource.spec.target.to_string(), - copy_target_resource.metadata.namespace.unwrap_or_default(), + copy_target_resource + .metadata + .namespace + .as_deref() + .unwrap_or_default(), pod_name, if copy_target_resource.spec.scale_down { "*" @@ -232,7 +238,7 @@ Operator License println!(); } - if let Some(statistics) = status.statistics { + if let Some(statistics) = status.statistics.as_ref() { println!("Operator Daily Users: {}", statistics.dau); println!("Operator Monthly Users: {}", statistics.mau); } diff --git a/mirrord/cli/src/operator/session.rs b/mirrord/cli/src/operator/session.rs index 75fcdaf3f8e..2c090adf622 100644 --- a/mirrord/cli/src/operator/session.rs +++ b/mirrord/cli/src/operator/session.rs @@ -1,12 +1,16 @@ use kube::{core::ErrorResponse, Api}; +use mirrord_analytics::NullReporter; +use mirrord_config::LayerConfig; use mirrord_operator::{ - client::{session_api, OperatorApiError, OperatorOperation}, - crd::{MirrordOperatorCrd, SessionCrd, OPERATOR_STATUS_NAME}, + client::{ + error::{OperatorApiError, OperatorOperation}, + MaybeClientCert, OperatorApi, + }, + crd::SessionCrd, }; use mirrord_progress::{Progress, ProgressTracker}; -use super::get_status_api; -use crate::{Result, SessionCommand}; +use crate::{CliError, Result, SessionCommand}; /// Handles the [`SessionCommand`]s that deal with session management in the operator. 
pub(super) struct SessionCommandHandler { @@ -17,11 +21,8 @@ pub(super) struct SessionCommandHandler { /// operation is going. sub_progress: ProgressTracker, - /// Kube API to talk with session routes in the operator. - operator_api: Api, - - /// Kube API to talk with session routes in the operator. - session_api: Api, + /// Api to talk with session routes in the operator. + operator_api: OperatorApi, /// The command the user is trying to execute from the cli. command: SessionCommand, @@ -33,13 +34,23 @@ impl SessionCommandHandler { pub(super) async fn new(command: SessionCommand) -> Result { let mut progress = ProgressTracker::from_env("Operator session action"); - let operator_api = get_status_api(None).await.inspect_err(|fail| { - progress.failure(Some(&format!("Failed to create operator API with {fail}!"))) + let config = LayerConfig::from_env().inspect_err(|error| { + progress.failure(Some(&format!("failed to read config from env: {error}"))); })?; - let session_api = session_api(None).await.inspect_err(|fail| { - progress.failure(Some(&format!("Failed to create session API with {fail}!"))) - })?; + let mut subtask = progress.subtask("checking operator"); + let operator_api = match OperatorApi::try_new(&config, &mut NullReporter::default()).await? 
+ { + Some(api) => api.prepare_client_cert(&mut NullReporter::default()).await, + None => { + subtask.failure(Some("operator not found")); + return Err(CliError::OperatorNotInstalled); + } + }; + + operator_api.inspect_cert_error(|error| { + progress.warning(&format!("Failed to prepare user certificate: {error}")); + }); let sub_progress = progress.subtask("preparing..."); @@ -47,7 +58,6 @@ impl SessionCommandHandler { progress, sub_progress, operator_api, - session_api, command, }) } @@ -60,21 +70,13 @@ impl SessionCommandHandler { mut progress, mut sub_progress, operator_api, - session_api, command, } = self; - let operator_version = operator_api - .get(OPERATOR_STATUS_NAME) - .await - .map_err(|error| OperatorApiError::KubeError { - error, - operation: OperatorOperation::GettingStatus, - }) - .map(|crd| crd.spec.operator_version)?; - sub_progress.print(&format!("executing `{command}`")); + let session_api: Api = Api::all(operator_api.client().clone()); + // We're interested in the `Status`es, so we map the results into those. match command { SessionCommand::Kill { id } => session_api @@ -97,7 +99,7 @@ impl SessionCommandHandler { { OperatorApiError::UnsupportedFeature { feature: "session management".to_string(), - operator_version, + operator_version: operator_api.operator().spec.operator_version.clone(), } } // Something actually went wrong. 
diff --git a/mirrord/cli/src/verify_config.rs b/mirrord/cli/src/verify_config.rs index eca074bcd13..4b04c8d7c4d 100644 --- a/mirrord/cli/src/verify_config.rs +++ b/mirrord/cli/src/verify_config.rs @@ -27,6 +27,7 @@ use crate::{config::VerifyConfigArgs, error, LayerFileConfig}; enum VerifiedTarget { #[serde(rename = "targetless")] Targetless, + #[serde(untagged)] Pod(PodTarget), #[serde(untagged)] diff --git a/mirrord/config/configuration.md b/mirrord/config/configuration.md index 2d818882baa..a54eeb2d93a 100644 --- a/mirrord/config/configuration.md +++ b/mirrord/config/configuration.md @@ -104,7 +104,12 @@ configuration file containing all fields. "ignore_localhost": false, "unix_streams": "bear.+" }, - "dns": false + "dns": { + "enabled": true, + "filter": { + "local": ["1.1.1.0/24:1337", "1.1.5.0/24", "google.com"] + } + } }, "copy_target": { "scale_down": false @@ -382,14 +387,18 @@ IP:PORT to connect to instead of using k8s api, for testing purposes. mirrord Experimental features. This shouldn't be used unless someone from MetalBear/mirrord tells you to. -## _experimental_ readlink {#fexperimental-readlink} +## _experimental_ readlink {#experimental-readlink} Enables the `readlink` hook. -## _experimental_ tcp_ping4_mock {#fexperimental-tcp_ping4_mock} +## _experimental_ tcp_ping4_mock {#experimental-tcp_ping4_mock} +# _experimental_ trust_any_certificate {#experimental-trust_any_certificate} + +Enables trusting any certificate on macOS, useful for + # feature {#root-feature} Controls mirrord features. @@ -686,7 +695,12 @@ for more details. "ignore_localhost": false, "unix_streams": "bear.+" }, - "dns": false + "dns": { + "enabled": true, + "filter": { + "local": ["1.1.1.0/24:1337", "1.1.5.0/24", "google.com"] + } + } } } } @@ -698,11 +712,71 @@ Resolve DNS via the remote pod. Defaults to `true`. 
-- Caveats: DNS resolving can be done in multiple ways, some frameworks will use
-`getaddrinfo`, while others will create a connection on port `53` and perform a sort
-of manual resolution. Just enabling the `dns` feature in mirrord might not be enough.
-If you see an address resolution error, try enabling the [`fs`](#feature-fs) feature,
-and setting `read_only: ["/etc/resolv.conf"]`.
+Mind that:
+- DNS resolving can be done in multiple ways. Some frameworks use
+`getaddrinfo`/`gethostbyname` functions, while others communicate directly with the DNS server
+at port `53` and perform a sort of manual resolution. Just enabling the `dns` feature in mirrord
+might not be enough. If you see an address resolution error, try enabling the
+[`fs`](#feature-fs) feature, and setting `read_only: ["/etc/resolv.conf"]`.
+- DNS filter currently works only with frameworks that use `getaddrinfo`/`gethostbyname`
+  functions.
+
+#### feature.network.dns.filter {#feature-network-dns-filter}
+
+Unstable: the precise syntax of this config is subject to change.
+
+List of addresses/ports/subnets that should be resolved through either the remote pod or local
+app, depending how you set this up with either `remote` or `local`.
+
+You may use this option to specify when DNS resolution is done from the remote pod (which
+is the default behavior when you enable remote DNS), or from the local app (default when
+you have remote DNS disabled).
+
+Takes a list of values, such as:
+
+- Only queries for hostname `my-service-in-cluster` will go through the remote pod.
+
+```json
+{
+  "remote": ["my-service-in-cluster"]
+}
+```
+
+- Only queries for addresses in subnet `1.1.1.0/24` with service port `1337` will go through
+  the remote pod.
+
+```json
+{
+  "remote": ["1.1.1.0/24:1337"]
+}
+```
+
+- Only queries for hostname `google.com` with service port `1337` or `7331`
+will go through the remote pod.
+ +```json +{ + "remote": ["google.com:1337", "google.com:7331"] +} +``` + +- Only queries for `localhost` with service port `1337` will go through the local app. + +```json +{ + "local": ["localhost:1337"] +} +``` + +- Only queries with service port `1337` or `7331` will go through the local app. + +```json +{ + "local": [":1337", ":7331"] +} +``` + +Valid values follow this pattern: `[name|address|subnet/mask][:port]`. ### feature.network.incoming {#feature-network-incoming} diff --git a/mirrord/config/src/config.rs b/mirrord/config/src/config.rs index 88c7f2bfe3b..b038efaf084 100644 --- a/mirrord/config/src/config.rs +++ b/mirrord/config/src/config.rs @@ -17,8 +17,15 @@ pub enum ConfigError { #[error("value for {1:?} not provided in {0:?} (env override {2:?})")] ValueNotProvided(&'static str, &'static str, Option<&'static str>), - #[error("value {0:?} for {1:?} is invalid.")] - InvalidValue(String, &'static str), + #[error("invalid {} value `{}`: {}", .name, .provided, .error)] + InvalidValue { + // Name of parsed env var or field path in the config. + name: &'static str, + // Value provided by the user. + provided: String, + // Error that occurred when processing the value. 
+ error: Box, + }, #[error("mirrord-config: IO operation failed with `{0}`")] Io(#[from] std::io::Error), diff --git a/mirrord/config/src/config/from_env.rs b/mirrord/config/src/config/from_env.rs index 607203c6633..9770456721a 100644 --- a/mirrord/config/src/config/from_env.rs +++ b/mirrord/config/src/config/from_env.rs @@ -1,3 +1,4 @@ +use core::fmt; use std::{marker::PhantomData, str::FromStr}; use super::ConfigContext; @@ -15,13 +16,18 @@ impl FromEnv { impl MirrordConfigSource for FromEnv where T: FromStr, + T::Err: 'static + Send + Sync + fmt::Display + std::error::Error, { type Value = T; fn source_value(self, _context: &mut ConfigContext) -> Option> { std::env::var(self.0).ok().map(|var| { - var.parse() - .map_err(|_| ConfigError::InvalidValue(var.to_string(), self.0)) + var.parse::() + .map_err(|err| ConfigError::InvalidValue { + name: self.0, + provided: var, + error: Box::new(err), + }) }) } } diff --git a/mirrord/config/src/experimental.rs b/mirrord/config/src/experimental.rs index d8239ef3095..8c50d191ab5 100644 --- a/mirrord/config/src/experimental.rs +++ b/mirrord/config/src/experimental.rs @@ -10,22 +10,29 @@ use crate::config::source::MirrordConfigSource; #[config(map_to = "ExperimentalFileConfig", derive = "JsonSchema")] #[cfg_attr(test, config(derive = "PartialEq, Eq"))] pub struct ExperimentalConfig { - /// ## _experimental_ tcp_ping4_mock {#fexperimental-tcp_ping4_mock} + /// ## _experimental_ tcp_ping4_mock {#experimental-tcp_ping4_mock} /// /// #[config(default = true)] pub tcp_ping4_mock: bool, - /// ## _experimental_ readlink {#fexperimental-readlink} + /// ## _experimental_ readlink {#experimental-readlink} /// /// Enables the `readlink` hook. 
#[config(default = false)] pub readlink: bool, + + /// # _experimental_ trust_any_certificate {#experimental-trust_any_certificate} + /// + /// Enables trusting any certificate on macOS, useful for + #[config(default = false)] + pub trust_any_certificate: bool, } impl CollectAnalytics for &ExperimentalConfig { fn collect_analytics(&self, analytics: &mut mirrord_analytics::Analytics) { analytics.add("tcp_ping4_mock", self.tcp_ping4_mock); analytics.add("readlink", self.readlink); + analytics.add("trust_any_certificate", self.trust_any_certificate); } } diff --git a/mirrord/config/src/feature/network.rs b/mirrord/config/src/feature/network.rs index 913b907262e..ab3ffc8d3db 100644 --- a/mirrord/config/src/feature/network.rs +++ b/mirrord/config/src/feature/network.rs @@ -1,13 +1,16 @@ +use dns::{DnsConfig, DnsFileConfig}; use mirrord_analytics::CollectAnalytics; use mirrord_config_derive::MirrordConfig; use schemars::JsonSchema; use self::{incoming::*, outgoing::*}; use crate::{ - config::{from_env::FromEnv, source::MirrordConfigSource, ConfigContext, ConfigError}, + config::{ConfigContext, ConfigError}, util::MirrordToggleableConfig, }; +pub mod dns; +pub mod filter; pub mod incoming; pub mod outgoing; @@ -38,7 +41,12 @@ pub mod outgoing; /// "ignore_localhost": false, /// "unix_streams": "bear.+" /// }, -/// "dns": false +/// "dns": { +/// "enabled": true, +/// "filter": { +/// "local": ["1.1.1.0/24:1337", "1.1.5.0/24", "google.com"] +/// } +/// } /// } /// } /// } @@ -56,30 +64,15 @@ pub struct NetworkConfig { pub outgoing: OutgoingConfig, /// ### feature.network.dns {#feature-network-dns} - /// - /// Resolve DNS via the remote pod. - /// - /// Defaults to `true`. - /// - /// - Caveats: DNS resolving can be done in multiple ways, some frameworks will use - /// `getaddrinfo`, while others will create a connection on port `53` and perform a sort - /// of manual resolution. Just enabling the `dns` feature in mirrord might not be enough. 
- /// If you see an address resolution error, try enabling the [`fs`](#feature-fs) feature, - /// and setting `read_only: ["/etc/resolv.conf"]`. - #[config(env = "MIRRORD_REMOTE_DNS", default = true)] - pub dns: bool, + #[config(toggleable, nested)] + pub dns: DnsConfig, } impl MirrordToggleableConfig for NetworkFileConfig { fn disabled_config(context: &mut ConfigContext) -> Result { - let dns = FromEnv::new("MIRRORD_REMOTE_DNS") - .source_value(context) - .transpose()? - .unwrap_or(false); - Ok(NetworkConfig { incoming: IncomingFileConfig::disabled_config(context)?, - dns, + dns: DnsFileConfig::disabled_config(context)?, outgoing: OutgoingFileConfig::disabled_config(context)?, }) } @@ -89,7 +82,7 @@ impl CollectAnalytics for &NetworkConfig { fn collect_analytics(&self, analytics: &mut mirrord_analytics::Analytics) { analytics.add("incoming", &self.incoming); analytics.add("outgoing", &self.outgoing); - analytics.add("dns", self.dns); + analytics.add("dns", &self.dns); } } @@ -131,7 +124,7 @@ mod tests { .unwrap(); assert_eq!(env.incoming, incoming.1); - assert_eq!(env.dns, dns.1); + assert_eq!(env.dns.enabled, dns.1); }, ); } diff --git a/mirrord/config/src/feature/network/dns.rs b/mirrord/config/src/feature/network/dns.rs new file mode 100644 index 00000000000..22aa217b572 --- /dev/null +++ b/mirrord/config/src/feature/network/dns.rs @@ -0,0 +1,172 @@ +use std::ops::Deref; + +use mirrord_analytics::CollectAnalytics; +use mirrord_config_derive::MirrordConfig; +use schemars::JsonSchema; +use serde::Deserialize; + +use super::filter::AddressFilter; +use crate::{ + config::{from_env::FromEnv, source::MirrordConfigSource, ConfigContext, ConfigError}, + util::{MirrordToggleableConfig, VecOrSingle}, +}; + +/// List of addresses/ports/subnets that should be resolved through either the remote pod or local +/// app, depending how you set this up with either `remote` or `local`. 
+/// +/// You may use this option to specify when DNS resolution is done from the remote pod (which +/// is the default behavior when you enable remote DNS), or from the local app (default when +/// you have remote DNS disabled). +/// +/// Takes a list of values, such as: +/// +/// - Only queries for hostname `my-service-in-cluster` will go through the remote pod. +/// +/// ```json +/// { +/// "remote": ["my-service-in-cluster"] +/// } +/// ``` +/// +/// - Only queries for addresses in subnet `1.1.1.0/24` with service port `1337`` will go through +/// the remote pod. +/// +/// ```json +/// { +/// "remote": ["1.1.1.0/24:1337"] +/// } +/// ``` +/// +/// - Only queries for hostname `google.com` with service port `1337` or `7331` +/// will go through the remote pod. +/// +/// ```json +/// { +/// "remote": ["google.com:1337", "google.com:7331"] +/// } +/// ``` +/// +/// - Only queries for `localhost` with service port `1337` will go through the local app. +/// +/// ```json +/// { +/// "local": ["localhost:1337"] +/// } +/// ``` +/// +/// - Only queries with service port `1337` or `7331` will go through the local app. +/// +/// ```json +/// { +/// "local": [":1337", ":7331"] +/// } +/// ``` +/// +/// Valid values follow this pattern: `[name|address|subnet/mask][:port]`. +#[derive(Deserialize, PartialEq, Eq, Clone, Debug, JsonSchema)] +#[serde(deny_unknown_fields, rename_all = "lowercase")] +pub enum DnsFilterConfig { + /// DNS queries matching what is specified here will go through the remote pod, everything else + /// will go through local. + Remote(VecOrSingle), + + /// DNS queries matching what is specified here will go through the local app, everything else + /// will go through the remote pod. + Local(VecOrSingle), +} + +/// Resolve DNS via the remote pod. +/// +/// Defaults to `true`. +/// +/// Mind that: +/// - DNS resolving can be done in multiple ways. 
Some frameworks use +/// `getaddrinfo`/`gethostbyname` functions, while others communicate directly with the DNS server +/// at port `53` and perform a sort of manual resolution. Just enabling the `dns` feature in mirrord +/// might not be enough. If you see an address resolution error, try enabling the +/// [`fs`](#feature-fs) feature, and setting `read_only: ["/etc/resolv.conf"]`. +/// - DNS filter currently works only with frameworks that use `getaddrinfo`/`gethostbyname` +/// functions. +#[derive(MirrordConfig, Default, PartialEq, Eq, Clone, Debug)] +#[config(map_to = "DnsFileConfig", derive = "JsonSchema")] +#[cfg_attr(test, config(derive = "PartialEq, Eq"))] +pub struct DnsConfig { + #[config(env = "MIRRORD_REMOTE_DNS", default = true)] + pub enabled: bool, + + /// #### feature.network.dns.filter {#feature-network-dns-filter} + /// + /// Unstable: the precise syntax of this config is subject to change. + #[config(default, unstable)] + pub filter: Option, +} + +impl DnsConfig { + pub fn verify(&self, context: &mut ConfigContext) -> Result<(), ConfigError> { + let filters = match &self.filter { + Some(..) 
if !self.enabled => { + context.add_warning( + "Remote DNS resolution is disabled, provided DNS filter will be ignored" + .to_string(), + ); + return Ok(()); + } + None => return Ok(()), + Some(DnsFilterConfig::Local(filters)) if filters.is_empty() => { + context.add_warning( + "Local DNS filter is empty, all DNS resolution will be done remotely" + .to_string(), + ); + return Ok(()); + } + Some(DnsFilterConfig::Remote(filters)) if filters.is_empty() => { + context.add_warning( + "Remote DNS filter is empty, all DNS resolution will be done locally" + .to_string(), + ); + return Ok(()); + } + Some(DnsFilterConfig::Local(filters)) => filters.deref(), + Some(DnsFilterConfig::Remote(filters)) => filters.deref(), + }; + + for filter in filters { + let Err(error) = filter.parse::() else { + continue; + }; + + return Err(ConfigError::InvalidValue { + name: "feature.network.dns.filter", + provided: filter.to_string(), + error: Box::new(error), + }); + } + + Ok(()) + } +} + +impl MirrordToggleableConfig for DnsFileConfig { + fn disabled_config(context: &mut ConfigContext) -> Result { + Ok(DnsConfig { + enabled: FromEnv::new("MIRRORD_REMOTE_DNS") + .source_value(context) + .unwrap_or(Ok(false))?, + ..Default::default() + }) + } +} + +impl CollectAnalytics for &DnsConfig { + fn collect_analytics(&self, analytics: &mut mirrord_analytics::Analytics) { + analytics.add("enabled", self.enabled); + + if let Some(filter) = self.filter.as_ref() { + match filter { + DnsFilterConfig::Remote(value) => analytics.add("dns_filter_remote", value.len()), + + DnsFilterConfig::Local(value) => analytics.add("dns_filter_local", value.len()), + } + } + } +} diff --git a/mirrord/config/src/feature/network/filter.rs b/mirrord/config/src/feature/network/filter.rs new file mode 100644 index 00000000000..c09896076e1 --- /dev/null +++ b/mirrord/config/src/feature/network/filter.rs @@ -0,0 +1,459 @@ +use std::{ + net::{IpAddr, SocketAddr}, + num::ParseIntError, + str::FromStr, +}; + +use nom::{ + 
branch::alt, + bytes::complete::{tag, take_until}, + character::complete::{alphanumeric1, digit1}, + combinator::opt, + multi::many1, + sequence::{delimited, preceded, terminated}, + IResult, +}; +use thiserror::Error; + +/// The protocols we support in [`ProtocolAndAddressFilter`]. +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ProtocolFilter { + #[default] + Any, + Tcp, + Udp, +} + +#[derive(Error, Debug)] +#[error("invalid protocol: {0}")] +pub struct ParseProtocolError(String); + +impl FromStr for ProtocolFilter { + type Err = ParseProtocolError; + + fn from_str(s: &str) -> Result { + let lowercase = s.to_lowercase(); + + match lowercase.as_str() { + "any" => Ok(Self::Any), + "tcp" => Ok(Self::Tcp), + "udp" => Ok(Self::Udp), + invalid => Err(ParseProtocolError(invalid.to_string())), + } + } +} + +/// +/// Parsed addresses can be one of these 3 variants. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum AddressFilter { + /// Only port was specified. + Port(u16), + + /// Just a plain old [`SocketAddr`], specified as `a.b.c.d:e`. + /// + /// We treat `0`s here as if it meant **any**, so `0.0.0.0` means we filter any IP, and `:0` + /// means any port. + Socket(SocketAddr), + + /// A named address, as we cannot resolve it here, specified as `name:a`. + /// + /// We can only resolve such names on the mirrord layer `connect` call, as we have to check if + /// the user enabled the DNS feature or not (and thus, resolve it through the remote pod, or + /// the local app). + Name(String, u16), + + /// Just a plain old subnet and a port, specified as `a.b.c.d/e:f`. 
+ Subnet(ipnet::IpNet, u16), +} + +impl AddressFilter { + pub fn port(&self) -> u16 { + match self { + Self::Port(port) => *port, + Self::Name(_, port) => *port, + Self::Socket(socket) => socket.port(), + Self::Subnet(_, port) => *port, + } + } +} + +#[derive(Error, Debug)] +pub enum AddressFilterError { + #[error("parsing with nom failed: {0}")] + Nom(nom::Err>), + + #[error("parsing port number failed: {0}")] + ParsePort(ParseIntError), + + #[error("parsing left trailing value: {0}")] + TrailingValue(String), + + #[error("parsing subnet prefix length failed: {0}")] + ParseSubnetPrefixLength(ParseIntError), + + #[error("parsing subnet base IP address failed")] + ParseSubnetBaseAddress, + + #[error("invalid subnet: {0}")] + SubnetPrefixLen(#[from] ipnet::PrefixLenError), + + #[error("provided empty string")] + Empty, +} + +impl From>> for AddressFilterError { + fn from(value: nom::Err>) -> Self { + Self::Nom(value.to_owned()) + } +} + +impl FromStr for AddressFilter { + type Err = AddressFilterError; + + fn from_str(input: &str) -> Result { + // Perform the basic parsing. + let (rest, address) = address(input)?; + let (rest, subnet) = subnet(rest)?; + let (rest, port) = port(rest)?; + + if !rest.is_empty() { + return Err(Self::Err::TrailingValue(rest.to_string())); + } + + match (address, subnet, port) { + // Only port specified. + (None, None, Some(port)) => { + let port = port.parse::().map_err(AddressFilterError::ParsePort)?; + + Ok(Self::Port(port)) + } + + // Subnet specified. Address must be IP. + (Some(address), Some(subnet), port) => { + let as_ip = address + .parse::() + .map_err(|_| AddressFilterError::ParseSubnetBaseAddress)?; + let prefix_len = subnet + .parse::() + .map_err(AddressFilterError::ParseSubnetPrefixLength)?; + let ip_net = ipnet::IpNet::new(as_ip, prefix_len)?; + + let port = port + .map(u16::from_str) + .transpose() + .map_err(AddressFilterError::ParsePort)? 
+ .unwrap_or(0); + + Ok(Self::Subnet(ip_net, port)) + } + + // Subnet not specified. Address can be a name or an IP. + (Some(address), None, _) => { + let port = port + .map(u16::from_str) + .transpose() + .map_err(AddressFilterError::ParsePort)? + .unwrap_or(0); + + let result = address + .parse::() + .map(|ip| Self::Socket(SocketAddr::new(ip, port))) + .unwrap_or(Self::Name(address, port)); + + Ok(result) + } + + // Subnet specified but address is missing, error. + (None, Some(_), _) => Err(AddressFilterError::ParseSubnetBaseAddress), + + // Nothing is specified, error. + (None, None, None) => Err(AddressFilterError::Empty), + } + } +} + +/// +/// The parsed filter with its [`ProtocolFilter`] and [`AddressFilter`]. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ProtocolAndAddressFilter { + /// Valid protocol types. + pub protocol: ProtocolFilter, + + /// Address|name|subnet we're going to filter. + pub address: AddressFilter, +} + +#[derive(Error, Debug)] +pub enum ProtocolAndAddressFilterError { + #[error(transparent)] + Address(#[from] AddressFilterError), + #[error(transparent)] + Protocol(#[from] ParseProtocolError), +} + +impl From>> for ProtocolAndAddressFilterError { + fn from(value: nom::Err>) -> Self { + Self::Address(value.into()) + } +} + +impl FromStr for ProtocolAndAddressFilter { + type Err = ProtocolAndAddressFilterError; + + fn from_str(input: &str) -> Result { + // Perform the basic parsing. + let (rest, protocol) = protocol(input)?; + let protocol = protocol.parse()?; + + let address = rest.parse().or_else(|error| match error { + AddressFilterError::Empty => Ok(AddressFilter::Port(0)), + other => Err(other), + })?; + + Ok(Self { protocol, address }) + } +} + +/// +/// +/// Parses `tcp://`, extracting the `tcp` part, and discarding the `://`. 
+fn protocol(input: &str) -> IResult<&str, &str> { + let (rest, protocol) = opt(terminated(take_until("://"), tag("://")))(input)?; + let protocol = protocol.unwrap_or("any"); + + Ok((rest, protocol)) +} + +/// +/// +/// We try to parse 3 different kinds of values here: +/// +/// 1. `name.with.dots`; +/// 2. `1.2.3.4.5.6`; +/// 3. `[dad:1337:fa57::0]` +/// +/// Where 1 and 2 are handled by `dotted_address`. +/// +/// The parser is not interested in only eating correct values here for hostnames, ip addresses, +/// etc., it just tries to get a good enough string that could be parsed by +/// `SocketAddr::parse`, or `IpNet::parse`. +fn address(input: &str) -> IResult<&str, Option> { + let ipv6 = many1(alt((alphanumeric1, tag(":")))); + let ipv6_host = delimited(tag("["), ipv6, tag("]")); + + let host_char = alt((alphanumeric1, tag("-"), tag("_"), tag("."))); + let dotted_address = many1(host_char); + + let (rest, address) = opt(alt((dotted_address, ipv6_host)))(input)?; + + let address = address.map(|addr| addr.concat()); + + Ok((rest, address)) +} + +/// +/// +/// Parses `/24`, extracting the `24` part, and discarding the `/`. +fn subnet(input: &str) -> IResult<&str, Option<&str>> { + let subnet_parser = preceded(tag("/"), digit1); + let (rest, subnet) = opt(subnet_parser)(input)?; + + Ok((rest, subnet)) +} + +/// +/// +/// Parses `:1337`, extracting the `1337` part, and discarding the `:`. +/// +/// Returns [`None`] if it doesn't parse anything. +fn port(input: &str) -> IResult<&str, Option<&str>> { + let port_parser = preceded(tag(":"), digit1); + let (rest, port) = opt(port_parser)(input)?; + + Ok((rest, port)) +} + +#[cfg(test)] +mod tests { + use ipnet::IpNet; + use rstest::{fixture, rstest}; + + use super::*; + + // Valid configs. 
+ #[fixture] + fn full() -> &'static str { + "tcp://1.2.3.0/24:7777" + } + + #[fixture] + fn full_converted() -> ProtocolAndAddressFilter { + ProtocolAndAddressFilter { + protocol: ProtocolFilter::Tcp, + address: AddressFilter::Subnet(IpNet::from_str("1.2.3.0/24").unwrap(), 7777), + } + } + + #[fixture] + fn ipv6() -> &'static str { + "tcp://[2800:3f0:4001:81e::2004]:7777" + } + + #[fixture] + fn ipv6_converted() -> ProtocolAndAddressFilter { + ProtocolAndAddressFilter { + protocol: ProtocolFilter::Tcp, + address: AddressFilter::Socket( + SocketAddr::from_str("[2800:3f0:4001:81e::2004]:7777").unwrap(), + ), + } + } + + #[fixture] + fn protocol_only() -> &'static str { + "tcp://" + } + + #[fixture] + fn protocol_only_converted() -> ProtocolAndAddressFilter { + ProtocolAndAddressFilter { + protocol: ProtocolFilter::Tcp, + address: AddressFilter::Port(0), + } + } + + #[fixture] + fn name() -> &'static str { + "tcp://google.com:7777" + } + + #[fixture] + fn name_converted() -> ProtocolAndAddressFilter { + ProtocolAndAddressFilter { + protocol: ProtocolFilter::Tcp, + address: AddressFilter::Name("google.com".to_string(), 7777), + } + } + + #[fixture] + fn name_only() -> &'static str { + "rust-lang.org" + } + + #[fixture] + fn name_only_converted() -> ProtocolAndAddressFilter { + ProtocolAndAddressFilter { + protocol: ProtocolFilter::Any, + address: AddressFilter::Name("rust-lang.org".to_string(), 0), + } + } + + #[fixture] + fn localhost() -> &'static str { + "localhost" + } + + #[fixture] + fn localhost_converted() -> ProtocolAndAddressFilter { + ProtocolAndAddressFilter { + protocol: ProtocolFilter::Any, + address: AddressFilter::Name("localhost".to_string(), 0), + } + } + + #[fixture] + fn subnet_port() -> &'static str { + "1.2.3.0/24:7777" + } + + #[fixture] + fn subnet_port_converted() -> ProtocolAndAddressFilter { + ProtocolAndAddressFilter { + protocol: ProtocolFilter::Any, + address: AddressFilter::Subnet(IpNet::from_str("1.2.3.0/24").unwrap(), 7777), + } + } + 
+ #[fixture] + fn subnet_only() -> &'static str { + "1.2.3.0/24" + } + + #[fixture] + fn subnet_only_converted() -> ProtocolAndAddressFilter { + ProtocolAndAddressFilter { + protocol: ProtocolFilter::Any, + address: AddressFilter::Subnet(IpNet::from_str("1.2.3.0/24").unwrap(), 0), + } + } + + #[fixture] + fn protocol_port() -> &'static str { + "udp://:7777" + } + + #[fixture] + fn protocol_port_converted() -> ProtocolAndAddressFilter { + ProtocolAndAddressFilter { + protocol: ProtocolFilter::Udp, + address: AddressFilter::Port(7777), + } + } + + #[fixture] + fn port_only() -> &'static str { + ":7777" + } + + #[fixture] + fn port_only_converted() -> ProtocolAndAddressFilter { + ProtocolAndAddressFilter { + protocol: ProtocolFilter::Any, + address: AddressFilter::Port(7777), + } + } + + // Bad configs. + #[fixture] + fn name_with_subnet() -> &'static str { + "tcp://google.com/24:7777" + } + + #[fixture] + fn port_protocol() -> &'static str { + ":7777udp://" + } + + #[fixture] + fn fake_protocol() -> &'static str { + "meow://" + } + + #[rstest] + #[case(full(), full_converted())] + #[case(ipv6(), ipv6_converted())] + #[case(protocol_only(), protocol_only_converted())] + #[case(name(), name_converted())] + #[case(name_only(), name_only_converted())] + #[case(localhost(), localhost_converted())] + #[case(subnet_port(), subnet_port_converted())] + #[case(subnet_only(), subnet_only_converted())] + #[case(protocol_port(), protocol_port_converted())] + #[case(port_only(), port_only_converted())] + fn valid_filters(#[case] input: &'static str, #[case] converted: ProtocolAndAddressFilter) { + assert_eq!( + ProtocolAndAddressFilter::from_str(input).unwrap(), + converted + ); + } + + #[rstest] + #[case(name_with_subnet())] + #[case(port_protocol())] + #[case(fake_protocol())] + #[should_panic] + fn invalid_filters(#[case] input: &'static str) { + ProtocolAndAddressFilter::from_str(input).unwrap(); + } +} diff --git a/mirrord/config/src/feature/network/outgoing.rs 
b/mirrord/config/src/feature/network/outgoing.rs index 9892cb07e56..1a5790460b9 100644 --- a/mirrord/config/src/feature/network/outgoing.rs +++ b/mirrord/config/src/feature/network/outgoing.rs @@ -1,12 +1,11 @@ -use core::str::FromStr; -use std::net::SocketAddr; +use std::ops::Deref; use mirrord_analytics::CollectAnalytics; use mirrord_config_derive::MirrordConfig; use schemars::JsonSchema; use serde::Deserialize; -use thiserror::Error; +use super::filter::ProtocolAndAddressFilter; use crate::{ config::{from_env::FromEnv, source::MirrordConfigSource, ConfigContext, ConfigError}, util::{MirrordToggleableConfig, VecOrSingle}, @@ -153,213 +152,6 @@ impl MirrordToggleableConfig for OutgoingFileConfig { } } -/// -/// Errors related to parsing an [`OutgoingFilter`]. -#[derive(Debug, Error)] -pub enum OutgoingFilterError { - #[error("Nom: failed parsing with {0}!")] - Nom2(nom::Err>), - - #[error("Subnet: Failed parsing with {0}!")] - Subnet(#[from] ipnet::AddrParseError), - - #[error("ParseInt: Failed converting string into `u16` with {0}!")] - ParseInt(#[from] std::num::ParseIntError), - - #[error("Failed parsing protocol value of {0}!")] - InvalidProtocol(String), - - #[error("Found trailing value after parsing {0}!")] - TrailingValue(String), -} - -impl From>> for OutgoingFilterError { - fn from(value: nom::Err>) -> Self { - Self::Nom2(value.to_owned()) - } -} - -/// -/// The protocols we support on [`OutgoingFilter`]. -#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum ProtocolFilter { - #[default] - Any, - Tcp, - Udp, -} - -impl FromStr for ProtocolFilter { - type Err = OutgoingFilterError; - - fn from_str(s: &str) -> Result { - let lowercase = s.to_lowercase(); - - match lowercase.as_str() { - "any" => Ok(Self::Any), - "tcp" => Ok(Self::Tcp), - "udp" => Ok(Self::Udp), - invalid => Err(OutgoingFilterError::InvalidProtocol(invalid.to_string())), - } - } -} - -/// -/// Parsed addresses can be one of these 3 variants. 
-#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum AddressFilter { - /// Just a plain old [`SocketAddr`], specified as `a.b.c.d:e`. - /// - /// We treat `0`s here as if it meant **any**, so `0.0.0.0` means we filter any IP, and `:0` - /// means any port. - Socket(SocketAddr), - - /// A named address, as we cannot resolve it here, specified as `name:a`. - /// - /// We can only resolve such names on the mirrord layer `connect` call, as we have to check if - /// the user enabled the DNS feature or not (and thus, resolve it through the remote pod, or - /// the local app). - Name((String, u16)), - - /// Just a plain old subnet and a port, specified as `a.b.c.d/e:f`. - Subnet((ipnet::IpNet, u16)), -} - -/// -/// The parsed filter with its [`ProtocolFilter`] and [`AddressFilter`]. -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct OutgoingFilter { - /// Valid protocol types. - pub protocol: ProtocolFilter, - - /// Address|name|subnet we're going to filter. - pub address: AddressFilter, -} - -/// -/// It's dangerous to go alone! -/// Take [this](https://github.com/rust-bakery/nom/blob/main/doc/choosing_a_combinator.md). -/// -/// [`nom`] works better with `u8` slices, instead of `str`s. -mod parser { - use nom::{ - branch::alt, - bytes::complete::{tag, take_until}, - character::complete::{alphanumeric1, digit1}, - combinator::opt, - multi::many1, - sequence::{delimited, preceded, terminated}, - IResult, - }; - - /// - /// - /// Parses `tcp://`, extracting the `tcp` part, and discarding the `://`. - pub(super) fn protocol(input: &str) -> IResult<&str, &str> { - let (rest, protocol) = opt(terminated(take_until("://"), tag("://")))(input)?; - let protocol = protocol.unwrap_or("any"); - - Ok((rest, protocol)) - } - - /// - /// - /// We try to parse 3 different kinds of values here: - /// - /// 1. `name.with.dots`; - /// 2. `1.2.3.4.5.6`; - /// 3. `[dad:1337:fa57::0]` - /// - /// Where 1 and 2 are handled by `dotted_address`. 
- /// - /// The parser is not interested in only eating correct values here for hostnames, ip addresses, - /// etc., it just tries to get a good enough string that could be parsed by - /// `SocketAddr::parse`, or `IpNet::parse`. - /// - /// Returns `0.0.0.0` if it doesn't parse anything. - pub(super) fn address(input: &str) -> IResult<&str, String> { - let ipv6 = many1(alt((alphanumeric1, tag(":")))); - let ipv6_host = delimited(tag("["), ipv6, tag("]")); - - let host_char = alt((alphanumeric1, tag("-"), tag("_"), tag("."))); - let dotted_address = many1(host_char); - - let (rest, address) = opt(alt((dotted_address, ipv6_host)))(input)?; - - let address = address - .map(|addr| addr.concat()) - .unwrap_or(String::from("0.0.0.0")); - - Ok((rest, address)) - } - - /// - /// - /// Parses `/24`, extracting the `24` part, and discarding the `/`. - pub(super) fn subnet(input: &str) -> IResult<&str, Option<&str>> { - let subnet_parser = preceded(tag("/"), digit1); - let (rest, subnet) = opt(subnet_parser)(input)?; - - Ok((rest, subnet)) - } - - /// - /// - /// Parses `:1337`, extracting the `1337` part, and discarding the `:`. - /// - /// Returns `0` if it doesn't parse anything. - pub(super) fn port(input: &str) -> IResult<&str, &str> { - let port_parser = preceded(tag(":"), digit1); - let (rest, port) = opt(port_parser)(input)?; - - let port = port.unwrap_or("0"); - - Ok((rest, port)) - } -} - -impl FromStr for OutgoingFilter { - type Err = OutgoingFilterError; - - #[tracing::instrument(level = "trace", ret)] - fn from_str(input: &str) -> Result { - use crate::feature::network::outgoing::parser::*; - - // Perform the basic parsing. - let (rest, protocol) = protocol(input)?; - let (rest, address) = address(rest)?; - let (rest, subnet) = subnet(rest)?; - let (rest, port) = port(rest)?; - - // Stringify and convert to proper types. 
- let protocol = protocol.parse()?; - let port = port.parse::()?; - - let address = subnet - .map(|subnet| format!("{address}/{subnet}").parse::()) - .transpose()? - .map_or_else( - // Try to parse as an IPv4 address. - || { - format!("{address}:{port}") - .parse::() - // Try again as IPv6. - .or_else(|_| format!("[{address}]:{port}").parse()) - .map(AddressFilter::Socket) - // Neither IPv4 nor IPv6, it's probably a name. - .unwrap_or(AddressFilter::Name((address.to_string(), port))) - }, - |subnet| AddressFilter::Subnet((subnet, port)), - ); - - if rest.is_empty() { - Ok(Self { protocol, address }) - } else { - Err(OutgoingFilterError::TrailingValue(rest.to_string())) - } - } -} - impl CollectAnalytics for &OutgoingConfig { fn collect_analytics(&self, analytics: &mut mirrord_analytics::Analytics) { analytics.add("tcp", self.tcp); @@ -386,14 +178,37 @@ impl CollectAnalytics for &OutgoingConfig { } } +impl OutgoingConfig { + pub fn verify(&self, _: &mut ConfigContext) -> Result<(), ConfigError> { + let filters = match self.filter.as_ref() { + None => return Ok(()), + Some(OutgoingFilterConfig::Local(filters)) => filters.deref(), + Some(OutgoingFilterConfig::Remote(filters)) => filters.deref(), + }; + + for filter in filters { + let Err(error) = filter.parse::() else { + continue; + }; + + return Err(ConfigError::InvalidValue { + name: "feature.network.outgoing.filter", + provided: filter.to_string(), + error: Box::new(error), + }); + } + + Ok(()) + } +} + #[cfg(test)] mod tests { - use ipnet::IpNet; - use rstest::{fixture, rstest}; + use rstest::rstest; - use super::*; use crate::{ config::{ConfigContext, MirrordConfig}, + feature::network::OutgoingFileConfig, util::{testing::with_env_vars, ToggleableConfig}, }; @@ -452,177 +267,4 @@ mod tests { }, ); } - - // Valid configs. 
- #[fixture] - fn full() -> &'static str { - "tcp://1.2.3.0/24:7777" - } - - #[fixture] - fn full_converted() -> OutgoingFilter { - OutgoingFilter { - protocol: ProtocolFilter::Tcp, - address: AddressFilter::Subnet((IpNet::from_str("1.2.3.0/24").unwrap(), 7777)), - } - } - - #[fixture] - fn ipv6() -> &'static str { - "tcp://[2800:3f0:4001:81e::2004]:7777" - } - - #[fixture] - fn ipv6_converted() -> OutgoingFilter { - OutgoingFilter { - protocol: ProtocolFilter::Tcp, - address: AddressFilter::Socket( - SocketAddr::from_str("[2800:3f0:4001:81e::2004]:7777").unwrap(), - ), - } - } - - #[fixture] - fn protocol_only() -> &'static str { - "tcp://" - } - - #[fixture] - fn protocol_only_converted() -> OutgoingFilter { - OutgoingFilter { - protocol: ProtocolFilter::Tcp, - address: AddressFilter::Socket(SocketAddr::from_str("0.0.0.0:0").unwrap()), - } - } - - #[fixture] - fn name() -> &'static str { - "tcp://google.com:7777" - } - - #[fixture] - fn name_converted() -> OutgoingFilter { - OutgoingFilter { - protocol: ProtocolFilter::Tcp, - address: AddressFilter::Name(("google.com".to_string(), 7777)), - } - } - - #[fixture] - fn name_only() -> &'static str { - "rust-lang.org" - } - - #[fixture] - fn name_only_converted() -> OutgoingFilter { - OutgoingFilter { - protocol: ProtocolFilter::Any, - address: AddressFilter::Name(("rust-lang.org".to_string(), 0)), - } - } - - #[fixture] - fn localhost() -> &'static str { - "localhost" - } - - #[fixture] - fn localhost_converted() -> OutgoingFilter { - OutgoingFilter { - protocol: ProtocolFilter::Any, - address: AddressFilter::Name(("localhost".to_string(), 0)), - } - } - - #[fixture] - fn subnet_port() -> &'static str { - "1.2.3.0/24:7777" - } - - #[fixture] - fn subnet_port_converted() -> OutgoingFilter { - OutgoingFilter { - protocol: ProtocolFilter::Any, - address: AddressFilter::Subnet((IpNet::from_str("1.2.3.0/24").unwrap(), 7777)), - } - } - - #[fixture] - fn subnet_only() -> &'static str { - "1.2.3.0/24" - } - - #[fixture] - 
fn subnet_only_converted() -> OutgoingFilter { - OutgoingFilter { - protocol: ProtocolFilter::Any, - address: AddressFilter::Subnet((IpNet::from_str("1.2.3.0/24").unwrap(), 0)), - } - } - - #[fixture] - fn protocol_port() -> &'static str { - "udp://:7777" - } - - #[fixture] - fn protocol_port_converted() -> OutgoingFilter { - OutgoingFilter { - protocol: ProtocolFilter::Udp, - address: AddressFilter::Socket(SocketAddr::from_str("0.0.0.0:7777").unwrap()), - } - } - - #[fixture] - fn port_only() -> &'static str { - ":7777" - } - - #[fixture] - fn port_only_converted() -> OutgoingFilter { - OutgoingFilter { - protocol: ProtocolFilter::Any, - address: AddressFilter::Socket(SocketAddr::from_str("0.0.0.0:7777").unwrap()), - } - } - - // Bad configs. - #[fixture] - fn name_with_subnet() -> &'static str { - "tcp://google.com/24:7777" - } - - #[fixture] - fn port_protocol() -> &'static str { - ":7777udp://" - } - - #[fixture] - fn fake_protocol() -> &'static str { - "meow://" - } - - #[rstest] - #[case(full(), full_converted())] - #[case(ipv6(), ipv6_converted())] - #[case(protocol_only(), protocol_only_converted())] - #[case(name(), name_converted())] - #[case(name_only(), name_only_converted())] - #[case(localhost(), localhost_converted())] - #[case(subnet_port(), subnet_port_converted())] - #[case(subnet_only(), subnet_only_converted())] - #[case(protocol_port(), protocol_port_converted())] - #[case(port_only(), port_only_converted())] - fn valid_filters(#[case] input: &'static str, #[case] converted: OutgoingFilter) { - assert_eq!(OutgoingFilter::from_str(input).unwrap(), converted); - } - - #[rstest] - #[case(name_with_subnet())] - #[case(port_protocol())] - #[case(fake_protocol())] - #[should_panic] - fn invalid_filters(#[case] input: &'static str) { - OutgoingFilter::from_str(input).unwrap(); - } } diff --git a/mirrord/config/src/lib.rs b/mirrord/config/src/lib.rs index 0bf89539903..cd1f7dc490f 100644 --- a/mirrord/config/src/lib.rs +++ b/mirrord/config/src/lib.rs @@ 
-23,6 +23,7 @@ use feature::network::outgoing::OutgoingFilterConfig; use mirrord_analytics::CollectAnalytics; use mirrord_config_derive::MirrordConfig; use schemars::JsonSchema; +use target::Target; use tera::Tera; use tracing::warn; @@ -137,7 +138,12 @@ use crate::{ /// "ignore_localhost": false, /// "unix_streams": "bear.+" /// }, -/// "dns": false +/// "dns": { +/// "enabled": true, +/// "filter": { +/// "local": ["1.1.1.0/24:1337", "1.1.5.0/24", "google.com"] +/// } +/// } /// }, /// "copy_target": { /// "scale_down": false @@ -337,7 +343,7 @@ impl LayerConfig { if matches!( self.feature.network.outgoing.filter, Some(OutgoingFilterConfig::Remote(_)) - ) && !self.feature.network.dns + ) && !self.feature.network.dns.enabled { context.add_warning( "The mirrord outgoing traffic filter includes host names to be connected remotely, \ @@ -402,14 +408,7 @@ impl LayerConfig { .target .path .as_ref() - .map(|target| { - matches!( - target, - target::Target::Job(_) - | target::Target::CronJob(_) - | target::Target::StatefulSet(_) - ) - }) + .map(Target::requires_copy) .unwrap_or_default() { Err(ConfigError::TargetJobWithoutCopyTarget)? 
@@ -490,6 +489,9 @@ impl LayerConfig { )); } + self.feature.network.dns.verify(context)?; + self.feature.network.outgoing.verify(context)?; + Ok(()) } } @@ -751,7 +753,7 @@ mod tests { env: ToggleableConfig::Enabled(true).into(), fs: ToggleableConfig::Config(FsUserConfig::Simple(FsModeConfig::Write)).into(), network: Some(ToggleableConfig::Config(NetworkFileConfig { - dns: Some(false), + dns: Some(ToggleableConfig::Enabled(false)), incoming: Some(ToggleableConfig::Config(IncomingFileConfig::Advanced( Box::new(IncomingAdvancedFileConfig { mode: Some(IncomingMode::Mirror), diff --git a/mirrord/config/src/target.rs b/mirrord/config/src/target.rs index 560a7792b94..be58a7d0775 100644 --- a/mirrord/config/src/target.rs +++ b/mirrord/config/src/target.rs @@ -1,7 +1,5 @@ -use std::{ - fmt::{self}, - str::FromStr, -}; +use core::fmt; +use std::str::FromStr; use cron_job::CronJobTarget; use mirrord_analytics::CollectAnalytics; @@ -215,6 +213,7 @@ mirrord-layer failed to parse the provided target! /// - `job/{sample-job}`; /// - `cronjob/{sample-cronjob}`; /// - `statefulset/{sample-statefulset}`; +#[warn(clippy::wildcard_enum_match_arm)] #[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Hash, Debug, JsonSchema)] #[serde(untagged, deny_unknown_fields)] pub enum Target { @@ -294,43 +293,61 @@ impl Target { } } } + + /// `true` if this [`Target`] is only supported when the copy target feature is enabled. + pub(super) fn requires_copy(&self) -> bool { + matches!( + self, + Target::Job(_) | Target::CronJob(_) | Target::StatefulSet(_) + ) + } } +/// Trait used to convert different aspects of a [`Target`] into a string. +/// +/// It's mainly implemented using the `impl_target_display` macro, except for [`Target`] +/// and `TargetHandle`, which manually implement this. pub trait TargetDisplay { - fn target_type(&self) -> &str; - - fn target_name(&self) -> &str; + /// The string version of a [`Target`]'s type, e.g. `Pod` -> `"Pod"`. 
+ fn type_(&self) -> &str; - fn container_name(&self) -> Option<&String>; + /// The `name` of a [`Target`], e.g. `"pod-of-beans"`. + fn name(&self) -> &str; - fn fmt_display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{}/{}{}", - self.target_type(), - self.target_name(), - self.container_name() - .map(|name| format!("/container/{name}")) - .unwrap_or_default() - ) - } + /// The optional name of a [`Target`]'s container, e.g. `"can-of-beans"`. + fn container(&self) -> Option<&String>; } +/// Implements the [`TargetDisplay`] and [`fmt::Display`] traits for a target type. macro_rules! impl_target_display { ($struct_name:ident, $target_type:ident) => { impl TargetDisplay for $struct_name { - fn target_type(&self) -> &str { + fn type_(&self) -> &str { stringify!($target_type) } - fn target_name(&self) -> &str { + fn name(&self) -> &str { self.$target_type.as_str() } - fn container_name(&self) -> Option<&String> { + fn container(&self) -> Option<&String> { self.container.as_ref() } } + + impl fmt::Display for $struct_name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}/{}{}", + self.type_(), + self.name(), + self.container() + .map(|name| format!("/container/{name}")) + .unwrap_or_default() + ) + } + } }; } @@ -345,50 +362,53 @@ impl fmt::Display for Target { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Target::Targetless => write!(f, "targetless"), - Target::Pod(target) => target.fmt_display(f), - Target::Deployment(target) => target.fmt_display(f), - Target::Rollout(target) => target.fmt_display(f), - Target::Job(target) => target.fmt_display(f), - Target::CronJob(target) => target.fmt_display(f), - Target::StatefulSet(target) => target.fmt_display(f), + Target::Pod(target) => target.fmt(f), + Target::Deployment(target) => target.fmt(f), + Target::Rollout(target) => target.fmt(f), + Target::Job(target) => target.fmt(f), + Target::CronJob(target) => target.fmt(f), + 
Target::StatefulSet(target) => target.fmt(f), } } } impl TargetDisplay for Target { - fn target_type(&self) -> &str { + #[tracing::instrument(level = "trace", ret)] + fn type_(&self) -> &str { match self { Target::Targetless => "targetless", - Target::Deployment(target) => target.target_type(), - Target::Pod(target) => target.target_type(), - Target::Rollout(target) => target.target_type(), - Target::Job(target) => target.target_type(), - Target::CronJob(target) => target.target_type(), - Target::StatefulSet(target) => target.target_type(), + Target::Deployment(target) => target.type_(), + Target::Pod(target) => target.type_(), + Target::Rollout(target) => target.type_(), + Target::Job(target) => target.type_(), + Target::CronJob(target) => target.type_(), + Target::StatefulSet(target) => target.type_(), } } - fn target_name(&self) -> &str { + #[tracing::instrument(level = "trace", ret)] + fn name(&self) -> &str { match self { Target::Targetless => "targetless", - Target::Deployment(target) => target.target_name(), - Target::Pod(target) => target.target_name(), - Target::Rollout(target) => target.target_name(), - Target::Job(target) => target.target_name(), - Target::CronJob(target) => target.target_name(), - Target::StatefulSet(target) => target.target_name(), + Target::Deployment(target) => target.name(), + Target::Pod(target) => target.name(), + Target::Rollout(target) => target.name(), + Target::Job(target) => target.name(), + Target::CronJob(target) => target.name(), + Target::StatefulSet(target) => target.name(), } } - fn container_name(&self) -> Option<&String> { + #[tracing::instrument(level = "trace", ret)] + fn container(&self) -> Option<&String> { match self { Target::Targetless => None, - Target::Deployment(target) => target.container_name(), - Target::Pod(target) => target.container_name(), - Target::Rollout(target) => target.container_name(), - Target::Job(target) => target.container_name(), - Target::CronJob(target) => target.container_name(), - 
Target::StatefulSet(target) => target.container_name(), + Target::Deployment(target) => target.container(), + Target::Pod(target) => target.container(), + Target::Rollout(target) => target.container(), + Target::Job(target) => target.container(), + Target::CronJob(target) => target.container(), + Target::StatefulSet(target) => target.container(), } } } diff --git a/mirrord/config/src/target/pod.rs b/mirrord/config/src/target/pod.rs index eb6d207deab..5161778c2d9 100644 --- a/mirrord/config/src/target/pod.rs +++ b/mirrord/config/src/target/pod.rs @@ -1,6 +1,3 @@ -use core::fmt; -use std::fmt::Display; - use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -18,20 +15,6 @@ pub struct PodTarget { pub container: Option, } -impl Display for PodTarget { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{}{}", - self.container - .as_ref() - .map(|c| format!("{c}/")) - .unwrap_or_default(), - self.pod.clone() - ) - } -} - impl FromSplit for PodTarget { fn from_split(split: &mut std::str::Split) -> config::Result { let pod = split diff --git a/mirrord/intproxy/Cargo.toml b/mirrord/intproxy/Cargo.toml index b5deb0bf71b..eb234716534 100644 --- a/mirrord/intproxy/Cargo.toml +++ b/mirrord/intproxy/Cargo.toml @@ -26,6 +26,7 @@ mirrord-protocol = { path = "../protocol" } mirrord-intproxy-protocol = { path = "./protocol", features = ["codec-async"] } mirrord-analytics = { path = "../analytics" } +semver.workspace = true serde.workspace = true thiserror.workspace = true tokio.workspace = true @@ -35,8 +36,6 @@ hyper = { workspace = true, features = ["client", "http1", "http2"] } hyper-util.workspace = true http-body-util.workspace = true bytes.workspace = true +futures.workspace = true rand = "0.8" - -[dev-dependencies] -futures.workspace = true diff --git a/mirrord/intproxy/src/agent_conn.rs b/mirrord/intproxy/src/agent_conn.rs index 6cf52b90f8c..76b9e52aae4 100644 --- a/mirrord/intproxy/src/agent_conn.rs +++ 
b/mirrord/intproxy/src/agent_conn.rs @@ -12,7 +12,7 @@ use mirrord_kube::{ }, error::KubeApiError, }; -use mirrord_operator::client::{OperatorApi, OperatorApiError, OperatorSessionInformation}; +use mirrord_operator::client::{error::OperatorApiError, OperatorApi, OperatorSession}; use mirrord_protocol::{ClientMessage, DaemonMessage}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -47,7 +47,7 @@ pub enum AgentConnectionError { #[derive(Debug, Clone, Serialize, Deserialize)] pub enum AgentConnectInfo { /// Connect to the agent through the operator. - Operator(OperatorSessionInformation), + Operator(OperatorSession), /// Connect directly to the agent by name and port using k8s port forward. DirectKubernetes(AgentKubernetesConnectInfo), } @@ -74,12 +74,10 @@ impl AgentConnection { analytics: &mut R, ) -> Result { let (agent_tx, agent_rx) = match connect_info { - Some(AgentConnectInfo::Operator(operator_session_information)) => { - let session = OperatorApi::connect(config, operator_session_information, analytics) - .await - .map_err(AgentConnectionError::Operator)?; - - (session.tx, session.rx) + Some(AgentConnectInfo::Operator(session)) => { + let connection = + OperatorApi::connect_in_existing_session(config, session, analytics).await?; + (connection.tx, connection.rx) } Some(AgentConnectInfo::DirectKubernetes(connect_info)) => { diff --git a/mirrord/intproxy/src/lib.rs b/mirrord/intproxy/src/lib.rs index 99231497cd7..d1a2fd991c5 100644 --- a/mirrord/intproxy/src/lib.rs +++ b/mirrord/intproxy/src/lib.rs @@ -299,6 +299,11 @@ impl IntProxy { if CLIENT_READY_FOR_LOGS.matches(&protocol_version) { self.task_txs.agent.send(ClientMessage::ReadyForLogs).await; } + + self.task_txs + .incoming + .send(IncomingProxyMessage::AgentProtocolVersion(protocol_version)) + .await; } DaemonMessage::LogMessage(log) => match log.level { LogLevel::Error => tracing::error!("agent log: {}", log.message), diff --git a/mirrord/intproxy/src/proxies/incoming.rs 
b/mirrord/intproxy/src/proxies/incoming.rs index 6a52a3ba8e8..d356e6a7fac 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -6,22 +6,28 @@ use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, }; +use futures::StreamExt; use mirrord_intproxy_protocol::{ ConnMetadataRequest, ConnMetadataResponse, IncomingRequest, IncomingResponse, LayerId, MessageId, PortSubscribe, PortSubscription, PortUnsubscribe, ProxyToLayerMessage, }; use mirrord_protocol::{ + body_chunks::BodyExt as _, tcp::{ - ChunkedRequest, DaemonTcp, HttpRequest, HttpRequestFallback, InternalHttpBodyFrame, - InternalHttpRequest, NewTcpConnection, StreamingBody, + ChunkedHttpBody, ChunkedHttpError, ChunkedRequest, ChunkedResponse, DaemonTcp, HttpRequest, + HttpRequestFallback, HttpResponse, HttpResponseFallback, InternalHttpBodyFrame, + InternalHttpRequest, InternalHttpResponse, LayerTcpSteal, NewTcpConnection, + ReceiverStreamBody, StreamingBody, TcpData, }, - ConnectionId, RequestId, ResponseError, + ClientMessage, ConnectionId, RequestId, ResponseError, }; use thiserror::Error; use tokio::{ net::TcpSocket, sync::mpsc::{self, Sender}, }; +use tokio_stream::{StreamMap, StreamNotifyClose}; +use tracing::debug; use self::{ interceptor::{Interceptor, InterceptorError, MessageOut}, @@ -97,6 +103,8 @@ pub enum IncomingProxyMessage { LayerClosed(LayerClosed), AgentMirror(DaemonTcp), AgentSteal(DaemonTcp), + /// Agent responded to [`ClientMessage::SwitchProtocolVersion`]. + AgentProtocolVersion(semver::Version), } /// Handle for an [`Interceptor`]. @@ -159,6 +167,10 @@ pub struct IncomingProxy { metadata_store: MetadataStore, /// For managing streamed [`DaemonTcp::HttpRequestChunked`] request channels. request_body_txs: HashMap<(ConnectionId, RequestId), Sender>, + /// For managing streamed [`LayerTcpSteal::HttpResponseChunked`] response streams. 
+ response_body_rxs: StreamMap<(ConnectionId, RequestId), StreamNotifyClose>, + /// Version of [`mirrord_protocol`] negotiated with the agent. + agent_protocol_version: Option, } impl IncomingProxy { @@ -226,7 +238,11 @@ impl IncomingProxy { let interceptor_socket = bind_similar(subscription.listening_on)?; let interceptor = self.background_tasks.register( - Interceptor::new(interceptor_socket, subscription.listening_on), + Interceptor::new( + interceptor_socket, + subscription.listening_on, + self.agent_protocol_version.clone(), + ), id, Self::CHANNEL_SIZE, ); @@ -253,7 +269,16 @@ impl IncomingProxy { self.interceptors .remove(&InterceptorId(close.connection_id)); self.request_body_txs - .retain(|(connection_id, _), _| *connection_id != close.connection_id) + .retain(|(connection_id, _), _| *connection_id != close.connection_id); + let keys: Vec<(ConnectionId, RequestId)> = self + .response_body_rxs + .keys() + .filter(|key| key.0 == close.connection_id) + .cloned() + .collect(); + for key in keys.iter() { + self.response_body_rxs.remove(key); + } } DaemonTcp::Data(data) => { if let Some(interceptor) = self.interceptors.get(&InterceptorId(data.connection_id)) @@ -365,7 +390,11 @@ impl IncomingProxy { ); let interceptor = self.background_tasks.register( - Interceptor::new(interceptor_socket, subscription.listening_on), + Interceptor::new( + interceptor_socket, + subscription.listening_on, + self.agent_protocol_version.clone(), + ), id, Self::CHANNEL_SIZE, ); @@ -418,6 +447,47 @@ impl BackgroundTask for IncomingProxy { async fn run(mut self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { loop { tokio::select! 
{ + Some(((connection_id, request_id), stream_item)) = self.response_body_rxs.next() => match stream_item { + Some(Ok(frame)) => { + let int_frame = InternalHttpBodyFrame::from(frame); + let res = ChunkedResponse::Body(ChunkedHttpBody { + frames: vec![int_frame], + is_last: false, + connection_id, + request_id, + }); + message_bus + .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( + res, + ))) + .await; + }, + Some(Err(error)) => { + debug!(%error, "Error while reading streamed response body"); + let res = ChunkedResponse::Error(ChunkedHttpError {connection_id, request_id}); + message_bus + .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( + res, + ))) + .await; + self.response_body_rxs.remove(&(connection_id, request_id)); + }, + None => { + let res = ChunkedResponse::Body(ChunkedHttpBody { + frames: vec![], + is_last: true, + connection_id, + request_id, + }); + message_bus + .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( + res, + ))) + .await; + self.response_body_rxs.remove(&(connection_id, request_id)); + } + }, + msg = message_bus.recv() => match msg { None => { tracing::trace!("message bus closed, exiting"); @@ -439,6 +509,9 @@ impl BackgroundTask for IncomingProxy { } Some(IncomingProxyMessage::LayerClosed(msg)) => self.handle_layer_close(msg, message_bus).await, Some(IncomingProxyMessage::LayerForked(msg)) => self.handle_layer_fork(msg), + Some(IncomingProxyMessage::AgentProtocolVersion(version)) => { + self.agent_protocol_version.replace(version); + } }, Some(task_update) = self.background_tasks.next() => match task_update { @@ -456,10 +529,50 @@ impl BackgroundTask for IncomingProxy { }, (id, TaskUpdate::Message(msg)) => { - let msg = self.get_subscription(id).and_then(|s| s.wrap_response(msg, id.0)); - if let Some(msg) = msg { - message_bus.send(msg).await; - } + let Some(PortSubscription::Steal(_)) = self.get_subscription(id) else { + continue; + }; + let msg = match msg { + MessageOut::Raw(bytes) => 
{ + ClientMessage::TcpSteal(LayerTcpSteal::Data(TcpData { + connection_id: id.0, + bytes, + })) + }, + MessageOut::Http(HttpResponseFallback::Fallback(res)) => { + ClientMessage::TcpSteal(LayerTcpSteal::HttpResponse(res)) + }, + MessageOut::Http(HttpResponseFallback::Framed(res)) => { + ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseFramed(res)) + }, + MessageOut::Http(HttpResponseFallback::Streamed(mut res)) => { + let mut body = vec![]; + let key = (res.connection_id, res.request_id); + + match res.internal_response.body.next_frames(false).await { + Ok(frames) => { + frames.frames.into_iter().map(From::from).for_each(|frame| body.push(frame)); + }, + Err(error) => { + debug!(%error, "Error while receiving streamed response frames"); + let res = ChunkedResponse::Error(ChunkedHttpError { connection_id: key.0, request_id: key.1 }); + message_bus.send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked(res))).await; + continue; + }, + } + + self.response_body_rxs.insert(key, StreamNotifyClose::new(res.internal_response.body)); + + let internal_response = InternalHttpResponse { + status: res.internal_response.status, version: res.internal_response.version, headers: res.internal_response.headers, body + }; + let res = ChunkedResponse::Start(HttpResponse { + port: res.port , connection_id: res.connection_id, request_id: res.request_id, internal_response + }); + ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked(res)) + }, + }; + message_bus.send(msg).await; }, }, } diff --git a/mirrord/intproxy/src/proxies/incoming/interceptor.rs b/mirrord/intproxy/src/proxies/incoming/interceptor.rs index 6f30645e282..749c4387185 100644 --- a/mirrord/intproxy/src/proxies/incoming/interceptor.rs +++ b/mirrord/intproxy/src/proxies/incoming/interceptor.rs @@ -11,7 +11,8 @@ use bytes::BytesMut; use hyper::{upgrade::OnUpgrade, StatusCode, Version}; use hyper_util::rt::TokioIo; use mirrord_protocol::tcp::{ - HttpRequestFallback, HttpResponse, HttpResponseFallback, 
InternalHttpBody, + HttpRequestFallback, HttpResponse, HttpResponseFallback, InternalHttpBody, ReceiverStreamBody, + HTTP_CHUNKED_RESPONSE_VERSION, }; use thiserror::Error; use tokio::{ @@ -86,8 +87,12 @@ pub type InterceptorResult = core::result::Result /// When it receives [`MessageIn::Raw`], it starts acting as a simple proxy. /// When it received [`MessageIn::Http`], it starts acting as an HTTP gateway. pub struct Interceptor { + /// Socket that should be used to make the first connection (should already be bound). socket: TcpSocket, + /// Address of user app's listener. peer: SocketAddr, + /// Version of [`mirrord_protocol`] negotiated with the agent. + agent_protocol_version: Option, } impl Interceptor { @@ -97,8 +102,16 @@ impl Interceptor { /// # Note /// /// The socket can be replaced when retrying HTTP requests. - pub fn new(socket: TcpSocket, peer: SocketAddr) -> Self { - Self { socket, peer } + pub fn new( + socket: TcpSocket, + peer: SocketAddr, + agent_protocol_version: Option, + ) -> Self { + Self { + socket, + peer, + agent_protocol_version, + } } } @@ -139,6 +152,7 @@ impl BackgroundTask for Interceptor { let mut http_conn = HttpConnection { sender, peer: self.peer, + agent_protocol_version: self.agent_protocol_version.clone(), }; let (response, on_upgrade) = http_conn.send(request).await?; message_bus.send(MessageOut::Http(response)).await; @@ -176,11 +190,27 @@ struct HttpConnection { peer: SocketAddr, /// Handle to the HTTP connection between the [`Interceptor`] the server. sender: HttpSender, + /// Version of [`mirrord_protocol`] negotiated with the agent. + /// Determines which variant of [`LayerTcpSteal`](mirrord_protocol::tcp::LayerTcpSteal) + /// we use when sending HTTP responses. + agent_protocol_version: Option, } impl HttpConnection { + /// Returns whether the agent supports + /// [`LayerTcpSteal::HttpResponseChunked`](mirrord_protocol::tcp::LayerTcpSteal::HttpResponseChunked). 
+ pub fn agent_supports_streaming_response(&self) -> bool { + self.agent_protocol_version + .as_ref() + .map(|version| HTTP_CHUNKED_RESPONSE_VERSION.matches(version)) + .unwrap_or(false) + } + /// Handles the result of sending an HTTP request. /// Returns an [`HttpResponseFallback`] to be returned to the client or an [`InterceptorError`]. + /// + /// See [`HttpResponseFallback::response_from_request`] for notes on picking the correct + /// [`HttpResponseFallback`] variant. async fn handle_response( &self, request: HttpRequestFallback, @@ -209,6 +239,7 @@ impl HttpConnection { request, StatusCode::BAD_GATEWAY, &body_message, + self.agent_protocol_version.as_ref(), ), None, )) @@ -224,6 +255,7 @@ impl HttpConnection { request, StatusCode::BAD_GATEWAY, &body_message, + self.agent_protocol_version.as_ref(), ), None, )) @@ -257,9 +289,19 @@ impl HttpConnection { .await .map(HttpResponseFallback::Fallback) } + HttpRequestFallback::Streamed(..) + if self.agent_supports_streaming_response() => + { + HttpResponse::::from_hyper_response( + res, + self.peer.port(), + request.connection_id(), + request.request_id(), + ) + .await + .map(HttpResponseFallback::Streamed) + } HttpRequestFallback::Streamed(..) 
=> { - // Returning `HttpResponseFallback::Framed` variant is safe - streaming - // requests require a strictly higher mirrord-protocol version HttpResponse::::from_hyper_response( res, self.peer.port(), @@ -285,6 +327,7 @@ impl HttpConnection { request, StatusCode::BAD_GATEWAY, "mirrord", + self.agent_protocol_version.as_ref(), ), None, ) @@ -437,10 +480,7 @@ impl RawConnection { #[cfg(test)] mod test { - use std::{ - convert::Infallible, - sync::{Arc, Mutex}, - }; + use std::sync::{Arc, Mutex}; use bytes::Bytes; use futures::future::FutureExt; @@ -567,7 +607,15 @@ mod test { let interceptor = { let socket = TcpSocket::new_v4().unwrap(); socket.bind("127.0.0.1:0".parse().unwrap()).unwrap(); - tasks.register(Interceptor::new(socket, local_destination), (), 8) + tasks.register( + Interceptor::new( + socket, + local_destination, + Some(mirrord_protocol::VERSION.clone()), + ), + (), + 8, + ) }; interceptor @@ -594,7 +642,7 @@ mod test { match update { TaskUpdate::Message(MessageOut::Http(res)) => { let res = res - .into_hyper::() + .into_hyper::() .expect("failed to convert into hyper response"); assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS); println!("{:?}", res.headers()); @@ -643,7 +691,11 @@ mod test { let mut tasks: BackgroundTasks<(), MessageOut, InterceptorError> = Default::default(); let socket = TcpSocket::new_v4().unwrap(); socket.bind("127.0.0.1:0".parse().unwrap()).unwrap(); - let interceptor = Interceptor::new(socket, local_destination); + let interceptor = Interceptor::new( + socket, + local_destination, + Some(mirrord_protocol::VERSION.clone()), + ); let sender = tasks.register(interceptor, (), 8); let (tx, rx) = tokio::sync::mpsc::channel(12); diff --git a/mirrord/intproxy/src/proxies/incoming/port_subscription_ext.rs b/mirrord/intproxy/src/proxies/incoming/port_subscription_ext.rs index 4c9e4cedb65..e928be69ace 100644 --- a/mirrord/intproxy/src/proxies/incoming/port_subscription_ext.rs +++ 
b/mirrord/intproxy/src/proxies/incoming/port_subscription_ext.rs @@ -2,12 +2,10 @@ use mirrord_intproxy_protocol::PortSubscription; use mirrord_protocol::{ - tcp::{HttpResponseFallback, LayerTcp, LayerTcpSteal, StealType, TcpData}, + tcp::{LayerTcp, LayerTcpSteal, StealType}, ClientMessage, ConnectionId, Port, }; -use super::interceptor::MessageOut; - /// Retrieves subscribed port from the given [`StealType`]. fn get_port(steal_type: &StealType) -> Port { match steal_type { @@ -31,10 +29,6 @@ pub trait PortSubscriptionExt { /// Returns an unsubscribe connection request to be sent to the agent. fn wrap_agent_unsubscribe_connection(&self, connection_id: ConnectionId) -> ClientMessage; - - /// Returns a message to be sent to the agent in response to data coming from an interceptor. - /// [`None`] means that the data should be discarded. - fn wrap_response(&self, res: MessageOut, connection_id: ConnectionId) -> Option; } impl PortSubscriptionExt for PortSubscription { @@ -74,26 +68,4 @@ impl PortSubscriptionExt for PortSubscription { } } } - - /// Always [`None`] for the `mirror` mode - data coming from the layer is discarded. - /// Corrent [`LayerTcpSteal`] variant for the `steal` mode. - fn wrap_response(&self, res: MessageOut, connection_id: ConnectionId) -> Option { - match self { - Self::Mirror(..) => None, - Self::Steal(..) 
=> match res { - MessageOut::Raw(bytes) => { - Some(ClientMessage::TcpSteal(LayerTcpSteal::Data(TcpData { - connection_id, - bytes, - }))) - } - MessageOut::Http(HttpResponseFallback::Fallback(res)) => { - Some(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponse(res))) - } - MessageOut::Http(HttpResponseFallback::Framed(res)) => Some( - ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseFramed(res)), - ), - }, - } - } } diff --git a/mirrord/kube/src/api/container.rs b/mirrord/kube/src/api/container.rs index ad34966a270..289b4aee49a 100644 --- a/mirrord/kube/src/api/container.rs +++ b/mirrord/kube/src/api/container.rs @@ -39,10 +39,11 @@ pub struct ContainerParams { /// Value for [`AGENT_OPERATOR_CERT_ENV`](mirrord_protocol::AGENT_OPERATOR_CERT_ENV) set in /// the agent container. pub tls_cert: Option, + pub pod_ips: Option, } impl ContainerParams { - pub fn new() -> ContainerParams { + pub fn new(tls_cert: Option, pod_ips: Option) -> ContainerParams { let port: u16 = rand::thread_rng().gen_range(30000..=65535); let gid: u16 = rand::thread_rng().gen_range(3000..u16::MAX); @@ -57,17 +58,12 @@ impl ContainerParams { name, gid, port, - tls_cert: None, + tls_cert, + pod_ips, } } } -impl Default for ContainerParams { - fn default() -> Self { - Self::new() - } -} - pub trait ContainerVariant { type Update; diff --git a/mirrord/kube/src/api/container/job.rs b/mirrord/kube/src/api/container/job.rs index 96deb958c0d..c6ebc150cb3 100644 --- a/mirrord/kube/src/api/container/job.rs +++ b/mirrord/kube/src/api/container/job.rs @@ -245,6 +245,7 @@ mod test { port: 3000, gid: 13, tls_cert: None, + pod_ips: None, }; let update = JobVariant::new(&agent, ¶ms).as_update(); @@ -327,6 +328,7 @@ mod test { port: 3000, gid: 13, tls_cert: None, + pod_ips: None, }; let update = JobTargetedVariant::new( @@ -335,6 +337,7 @@ mod test { &RuntimeData { mesh: None, pod_name: "pod".to_string(), + pod_ips: vec![], pod_namespace: None, node_name: "foobaz".to_string(), container_id: 
"container".to_string(), diff --git a/mirrord/kube/src/api/container/util.rs b/mirrord/kube/src/api/container/util.rs index 80b4a27a6c6..77f917378ce 100644 --- a/mirrord/kube/src/api/container/util.rs +++ b/mirrord/kube/src/api/container/util.rs @@ -63,6 +63,10 @@ pub(super) fn agent_env(agent: &AgentConfig, params: &&ContainerParams) -> Vec Result { - let client = create_kube_api( + let client = create_kube_config( config.accept_invalid_certificates, config.kubeconfig.clone(), config.kube_context.clone(), ) - .await?; + .await? + .try_into()?; Ok(KubernetesAPI::new(client, config.agent.clone())) } @@ -189,8 +190,12 @@ impl KubernetesAPI { .into(), }; - let mut params = ContainerParams::new(); - params.tls_cert = tls_cert; + let pod_ips = runtime_data + .as_ref() + .filter(|runtime_data| !runtime_data.pod_ips.is_empty()) + .map(|runtime_data| runtime_data.pod_ips.join(",")); + + let params = ContainerParams::new(tls_cert, pod_ips); Ok((params, runtime_data)) } @@ -285,11 +290,11 @@ pub struct AgentKubernetesConnectInfo { pub agent_version: Option, } -pub async fn create_kube_api

( +pub async fn create_kube_config

( accept_invalid_certificates: bool, kubeconfig: Option

, kube_context: Option, -) -> Result +) -> Result where P: AsRef, { @@ -312,10 +317,11 @@ where Config::infer().await? }; config.accept_invalid_certs = accept_invalid_certificates; - Client::try_from(config).map_err(KubeApiError::from) + + Ok(config) } -#[tracing::instrument(level = "debug", skip(client))] +#[tracing::instrument(level = "trace", skip(client))] pub fn get_k8s_resource_api(client: &Client, namespace: Option<&str>) -> Api where K: kube::Resource, diff --git a/mirrord/kube/src/api/runtime.rs b/mirrord/kube/src/api/runtime.rs index f10edbb62ee..0d78cfccfea 100644 --- a/mirrord/kube/src/api/runtime.rs +++ b/mirrord/kube/src/api/runtime.rs @@ -65,6 +65,7 @@ impl Display for ContainerRuntime { #[derive(Debug)] pub struct RuntimeData { pub pod_name: String, + pub pod_ips: Vec, pub pod_namespace: Option, pub node_name: String, pub container_id: String, @@ -109,6 +110,16 @@ impl RuntimeData { .ok_or_else(|| KubeApiError::missing_field(pod, ".spec.nodeName"))? .to_owned(); + let pod_ips = pod + .status + .as_ref() + .and_then(|spec| spec.pod_ips.as_ref()) + .ok_or_else(|| KubeApiError::missing_field(pod, ".status.podIPs"))? + .iter() + .filter_map(|pod_ip| pod_ip.ip.as_ref()) + .cloned() + .collect(); + let container_statuses = pod .status .as_ref() @@ -155,6 +166,7 @@ impl RuntimeData { }; Ok(RuntimeData { + pod_ips, pod_name, pod_namespace: pod.metadata.namespace.clone(), node_name, @@ -184,7 +196,10 @@ impl RuntimeData { let mut pod_count = 0; let mut list_params = ListParams { - field_selector: Some(format!("spec.nodeName={}", self.node_name)), + field_selector: Some(format!( + "status.phase=Running,spec.nodeName={}", + self.node_name + )), ..Default::default() }; diff --git a/mirrord/layer/src/detour.rs b/mirrord/layer/src/detour.rs index a5e44093eb2..b4f12c1c0f9 100644 --- a/mirrord/layer/src/detour.rs +++ b/mirrord/layer/src/detour.rs @@ -206,6 +206,9 @@ pub(crate) enum Bypass { /// Hostname should be resolved locally. 
/// Currently this is the case only when the layer operates in the `trace only` mode. LocalHostname, + + /// DNS query should be done locally. + LocalDns, } impl Bypass { diff --git a/mirrord/layer/src/lib.rs b/mirrord/layer/src/lib.rs index 801a1d9755b..06e58cdfd57 100644 --- a/mirrord/layer/src/lib.rs +++ b/mirrord/layer/src/lib.rs @@ -112,6 +112,8 @@ mod macros; mod proxy_connection; mod setup; mod socket; +#[cfg(target_os = "macos")] +mod tls; #[cfg(all( any(target_arch = "x86_64", target_arch = "aarch64"), @@ -323,7 +325,7 @@ fn layer_start(mut config: LayerConfig) { // Disable all features that require the agent if trace_only { config.feature.fs.mode = FsModeConfig::Local; - config.feature.network.dns = false; + config.feature.network.dns.enabled = false; config.feature.network.incoming.mode = IncomingMode::Off; config.feature.network.outgoing.tcp = false; config.feature.network.outgoing.udp = false; @@ -341,11 +343,7 @@ fn layer_start(mut config: LayerConfig) { SETUP.set(state).unwrap(); let state = setup(); - enable_hooks( - state.fs_config().is_active(), - state.remote_dns_enabled(), - state.sip_binaries(), - ); + enable_hooks(state); let _detour_guard = DetourGuard::new(); tracing::info!("Initializing mirrord-layer!"); @@ -476,7 +474,12 @@ fn sip_only_layer_start(mut config: LayerConfig, patch_binaries: Vec) { /// `true`, see [`NetworkConfig`](mirrord_config::feature::network::NetworkConfig), and /// [`hooks::enable_socket_hooks`](socket::hooks::enable_socket_hooks). 
#[mirrord_layer_macro::instrument(level = "trace")] -fn enable_hooks(enabled_file_ops: bool, enabled_remote_dns: bool, patch_binaries: Vec) { +fn enable_hooks(state: &LayerSetup) { + let enabled_file_ops = state.fs_config().is_active(); + let enabled_remote_dns = state.remote_dns_enabled(); + #[cfg(target_os = "macos")] + let patch_binaries = state.sip_binaries(); + let mut hook_manager = HookManager::default(); unsafe { @@ -527,6 +530,11 @@ fn enable_hooks(enabled_file_ops: bool, enabled_remote_dns: bool, patch_binaries exec_utils::enable_execve_hook(&mut hook_manager, patch_binaries) }; + #[cfg(target_os = "macos")] + if state.experimental().trust_any_certificate { + unsafe { tls::enable_tls_hooks(&mut hook_manager) }; + } + if enabled_file_ops { unsafe { file::hooks::enable_file_hooks(&mut hook_manager) }; } diff --git a/mirrord/layer/src/setup.rs b/mirrord/layer/src/setup.rs index aeecc3e2ebe..eafca7b931e 100644 --- a/mirrord/layer/src/setup.rs +++ b/mirrord/layer/src/setup.rs @@ -20,7 +20,7 @@ use regex::RegexSet; use crate::{ debugger_ports::DebuggerPorts, file::{filter::FileFilter, mapper::FileRemapper}, - socket::OutgoingSelector, + socket::{dns_selector::DnsSelector, OutgoingSelector}, }; /// Complete layer setup. 
@@ -34,6 +34,7 @@ pub struct LayerSetup { debugger_ports: DebuggerPorts, remote_unix_streams: RegexSet, outgoing_selector: OutgoingSelector, + dns_selector: DnsSelector, proxy_address: SocketAddr, incoming_mode: IncomingMode, local_hostname: bool, @@ -59,8 +60,9 @@ impl LayerSetup { .expect("invalid unix stream regex set") .unwrap_or_default(); - let outgoing_selector: OutgoingSelector = - OutgoingSelector::new(&config.feature.network.outgoing); + let outgoing_selector = OutgoingSelector::new(&config.feature.network.outgoing); + + let dns_selector = DnsSelector::from(&config.feature.network.dns); let proxy_address = config .connect_tcp @@ -82,6 +84,7 @@ impl LayerSetup { debugger_ports, remote_unix_streams, outgoing_selector, + dns_selector, proxy_address, incoming_mode, local_hostname, @@ -119,7 +122,7 @@ impl LayerSetup { } pub fn remote_dns_enabled(&self) -> bool { - self.config.feature.network.dns + self.config.feature.network.dns.enabled } pub fn targetless(&self) -> bool { @@ -131,6 +134,7 @@ impl LayerSetup { .unwrap_or(true) } + #[cfg(target_os = "macos")] pub fn sip_binaries(&self) -> Vec { self.config .sip_binaries @@ -147,6 +151,10 @@ impl LayerSetup { &self.outgoing_selector } + pub fn dns_selector(&self) -> &DnsSelector { + &self.dns_selector + } + pub fn remote_unix_streams(&self) -> &RegexSet { &self.remote_unix_streams } diff --git a/mirrord/layer/src/socket.rs b/mirrord/layer/src/socket.rs index 56e0be718dc..a4a4e9af2ec 100644 --- a/mirrord/layer/src/socket.rs +++ b/mirrord/layer/src/socket.rs @@ -1,7 +1,7 @@ //! We implement each hook function in a safe function as much as possible, having the unsafe do the //! 
absolute minimum use std::{ - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs}, + net::{SocketAddr, ToSocketAddrs}, os::unix::io::RawFd, str::FromStr, sync::{Arc, LazyLock}, @@ -10,8 +10,9 @@ use std::{ use dashmap::DashMap; use hashbrown::hash_set::HashSet; use libc::{c_int, sockaddr, socklen_t}; -use mirrord_config::feature::network::outgoing::{ - AddressFilter, OutgoingConfig, OutgoingFilter, OutgoingFilterConfig, ProtocolFilter, +use mirrord_config::feature::network::{ + filter::{AddressFilter, ProtocolAndAddressFilter, ProtocolFilter}, + outgoing::{OutgoingConfig, OutgoingFilterConfig}, }; use mirrord_intproxy_protocol::{NetProtocol, PortUnsubscribe}; use mirrord_protocol::{ @@ -27,6 +28,7 @@ use crate::{ socket::ops::{remote_getaddrinfo, REMOTE_DNS_REVERSE_MAPPING}, }; +pub(crate) mod dns_selector; pub(super) mod hooks; pub(crate) mod ops; @@ -188,16 +190,16 @@ enum ConnectionThrough { Remote(SocketAddr), } -/// Holds the [`OutgoingFilter`]s set up by the user. +/// Holds the [`ProtocolAndAddressFilter`]s set up by the user in the [`OutgoingFilterConfig`]. #[derive(Debug, Default, Clone, PartialEq, Eq)] pub(crate) enum OutgoingSelector { #[default] Unfiltered, /// If the address from `connect` matches this, then we send the connection through the /// remote pod. - Remote(HashSet), + Remote(HashSet), /// If the address from `connect` matches this, then we send the connection from the local app. - Local(HashSet), + Local(HashSet), } impl OutgoingSelector { @@ -205,12 +207,14 @@ impl OutgoingSelector { filters: I, tcp_enabled: bool, udp_enabled: bool, - ) -> HashSet { + ) -> HashSet { filters - .map(|filter| OutgoingFilter::from_str(filter).expect("invalid outgoing filter")) + .map(|filter| { + ProtocolAndAddressFilter::from_str(filter).expect("invalid outgoing filter") + }) .collect::>() .into_iter() - .filter(|OutgoingFilter { protocol, .. }| match protocol { + .filter(|ProtocolAndAddressFilter { protocol, .. 
}| match protocol { ProtocolFilter::Any => tcp_enabled || udp_enabled, ProtocolFilter::Tcp => tcp_enabled, ProtocolFilter::Udp => udp_enabled, @@ -339,14 +343,14 @@ impl OutgoingSelector { } } -/// [`OutgoingFilter`] extension. -trait OutgoingFilterExt { +/// [`ProtocolAndAddressFilter`] extension. +trait ProtocolAndAddressFilterExt { /// Matches the outgoing connection request (given as [[`SocketAddr`], [`NetProtocol`]] pair) /// against this filter. /// /// # Note on DNS resolution /// - /// This method may require a DNS resolution (when [`OutgoingFilter::address`] is + /// This method may require a DNS resolution (when [`ProtocolAndAddressFilter::address`] is /// [`AddressFilter::Name`]). If remote DNS is disabled or `force_local_dns` /// flag is used, the method uses local resolution [`ToSocketAddrs`]. Otherwise, it uses /// remote resolution [`remote_getaddrinfo`]. @@ -358,7 +362,7 @@ trait OutgoingFilterExt { ) -> HookResult; } -impl OutgoingFilterExt for OutgoingFilter { +impl ProtocolAndAddressFilterExt for ProtocolAndAddressFilter { fn matches( &self, address: SocketAddr, @@ -371,17 +375,13 @@ impl OutgoingFilterExt for OutgoingFilter { return Ok(false); }; - let port = match &self.address { - AddressFilter::Name((_, port)) => *port, - AddressFilter::Socket(addr) => addr.port(), - AddressFilter::Subnet((_, port)) => *port, - }; + let port = self.address.port(); if port != 0 && port != address.port() { return Ok(false); } match &self.address { - AddressFilter::Name((name, port)) => { + AddressFilter::Name(name, port) => { let resolved_ips = if crate::setup().remote_dns_enabled() && !force_local_dns { match remote_getaddrinfo(name.to_string()) { Ok(res) => res.into_iter().map(|(_, ip)| ip).collect(), @@ -420,22 +420,18 @@ impl OutgoingFilterExt for OutgoingFilter { Ok(resolved_ips.into_iter().any(|ip| ip == address.ip())) } - AddressFilter::Socket(addr) - if addr.ip().is_unspecified() || addr.ip() == address.ip() => - { - Ok(true) + 
AddressFilter::Socket(addr) => { + Ok(addr.ip().is_unspecified() || addr.ip() == address.ip()) } - AddressFilter::Subnet((net, _)) if net.contains(&address.ip()) => Ok(true), - _ => Ok(false), + AddressFilter::Subnet(net, _) => Ok(net.contains(&address.ip())), + AddressFilter::Port(..) => Ok(true), } } } #[inline] fn is_ignored_port(addr: &SocketAddr) -> bool { - let (ip, port) = (addr.ip(), addr.port()); - let ignored_ip = ip == IpAddr::V4(Ipv4Addr::LOCALHOST) || ip == IpAddr::V6(Ipv6Addr::LOCALHOST); - port == 0 || ignored_ip && (port > 50000 && port < 60000) + addr.port() == 0 } /// Fill in the sockaddr structure for the given address. diff --git a/mirrord/layer/src/socket/dns_selector.rs b/mirrord/layer/src/socket/dns_selector.rs new file mode 100644 index 00000000000..e83dbba2d9f --- /dev/null +++ b/mirrord/layer/src/socket/dns_selector.rs @@ -0,0 +1,86 @@ +use std::{net::IpAddr, ops::Deref}; + +use mirrord_config::feature::network::{ + dns::{DnsConfig, DnsFilterConfig}, + filter::AddressFilter, +}; +use tracing::Level; + +use crate::detour::{Bypass, Detour}; + +/// Generated from [`DnsConfig`] provided in the [`LayerConfig`](mirrord_config::LayerConfig). +/// Decides whether DNS queries are done locally or remotely. +#[derive(Debug)] +pub struct DnsSelector { + /// Filters provided in the config. + filters: Vec, + /// Whether a query matching one of [`Self::filters`] should be done locally. + filter_is_local: bool, +} + +impl DnsSelector { + /// Bypasses queries that should be done locally. + #[tracing::instrument(level = Level::DEBUG, ret)] + pub fn check_query(&self, node: &str, port: u16) -> Detour<()> { + let matched = self + .filters + .iter() + .filter(|filter| { + let filter_port = filter.port(); + filter_port == 0 || filter_port == port + }) + .any(|filter| match filter { + AddressFilter::Port(..) 
=> true, + AddressFilter::Name(filter_name, _) => filter_name == node, + AddressFilter::Socket(filter_socket) => { + filter_socket.ip().is_unspecified() + || Some(filter_socket.ip()) == node.parse().ok() + } + AddressFilter::Subnet(filter_subnet, _) => { + let Ok(ip) = node.parse::() else { + return false; + }; + + filter_subnet.contains(&ip) + } + }); + + if matched == self.filter_is_local { + Detour::Bypass(Bypass::LocalDns) + } else { + Detour::Success(()) + } + } +} + +impl From<&DnsConfig> for DnsSelector { + fn from(value: &DnsConfig) -> Self { + if !value.enabled { + return Self { + filters: Default::default(), + filter_is_local: false, + }; + } + + let (filters, filter_is_local) = match &value.filter { + Some(DnsFilterConfig::Local(filters)) => (Some(filters.deref()), true), + Some(DnsFilterConfig::Remote(filters)) => (Some(filters.deref()), false), + None => (None, true), + }; + + let filters = filters + .into_iter() + .flatten() + .map(|filter| { + filter + .parse::() + .expect("bad address filter, should be verified in the CLI") + }) + .collect(); + + Self { + filters, + filter_is_local, + } + } +} diff --git a/mirrord/layer/src/socket/ops.rs b/mirrord/layer/src/socket/ops.rs index 4a4449429b2..f4933a61843 100644 --- a/mirrord/layer/src/socket/ops.rs +++ b/mirrord/layer/src/socket/ops.rs @@ -873,6 +873,7 @@ pub(super) fn getaddrinfo( })? .into(); + // Convert `service` to port let service = rawish_service .map(CStr::to_str) .transpose() @@ -884,7 +885,10 @@ pub(super) fn getaddrinfo( Bypass::CStrConversion })? - .map(String::from); + .and_then(|service| service.parse::().ok()) + .unwrap_or(0); + + crate::setup().dns_selector().check_query(&node, service)?; let raw_hints = raw_hints .cloned() @@ -897,9 +901,6 @@ pub(super) fn getaddrinfo( .. } = raw_hints; - // Convert `service` into a port. 
- let service = service.map_or(0, |s| s.parse().unwrap_or_default()); - // Some apps (gRPC on Python) use `::` to listen on all interfaces, and usually that just means // resolve on unspecified. So we just return that in IpV4 because we don't support ipv6. let resolved_addr = if node == "::" { @@ -1003,6 +1004,8 @@ pub(super) fn gethostbyname(raw_name: Option<&CStr>) -> Detour<*mut hostent> { })? .into(); + crate::setup().dns_selector().check_query(&name, 0)?; + let hosts_and_ips = remote_getaddrinfo(name.clone())?; // We could `unwrap` here, as this would have failed on the previous conversion. diff --git a/mirrord/layer/src/tls.rs b/mirrord/layer/src/tls.rs new file mode 100644 index 00000000000..6b77c17c82f --- /dev/null +++ b/mirrord/layer/src/tls.rs @@ -0,0 +1,24 @@ +use libc::c_void; +use mirrord_layer_macro::hook_guard_fn; + +use crate::{hooks::HookManager, replace}; + +// https://developer.apple.com/documentation/security/2980705-sectrustevaluatewitherror +#[hook_guard_fn] +pub(crate) unsafe extern "C" fn sec_trust_evaluate_with_error_detour( + trust: *const c_void, + error: *const c_void, +) -> bool { + tracing::trace!("sec_trust_evaluate_with_error_detour called"); + true +} + +pub(crate) unsafe fn enable_tls_hooks(hook_manager: &mut HookManager) { + replace!( + hook_manager, + "SecTrustEvaluateWithError", + sec_trust_evaluate_with_error_detour, + FnSec_trust_evaluate_with_error, + FN_SEC_TRUST_EVALUATE_WITH_ERROR + ); +} diff --git a/mirrord/operator/src/client.rs b/mirrord/operator/src/client.rs index b3346edec59..51dcdfb85ce 100644 --- a/mirrord/operator/src/client.rs +++ b/mirrord/operator/src/client.rs @@ -1,557 +1,708 @@ -use std::{ - fmt::{self, Display}, - io, -}; +use std::fmt; -use base64::{engine::general_purpose, Engine as _}; +use base64::{engine::general_purpose, Engine}; use chrono::{DateTime, Utc}; -use futures::{SinkExt, StreamExt}; -use http::request::Request; +use conn_wrapper::ConnectionWrapper; +use error::{OperatorApiError, 
OperatorApiResult, OperatorOperation}; +use http::{request::Request, HeaderName, HeaderValue}; use kube::{ api::{ListParams, PostParams}, - Api, Client, Resource, + Api, Client, Config, Resource, }; use mirrord_analytics::{AnalyticsHash, AnalyticsOperatorProperties, Reporter}; use mirrord_auth::{ certificate::Certificate, credential_store::{CredentialStoreSync, UserIdentity}, credentials::LicenseValidity, - error::AuthenticationError, -}; -use mirrord_config::{ - feature::network::incoming::ConcurrentSteal, - target::{Target, TargetConfig}, - LayerConfig, -}; -use mirrord_kube::{ - api::kubernetes::{create_kube_api, get_k8s_resource_api}, - error::KubeApiError, }; +use mirrord_config::{feature::network::incoming::ConcurrentSteal, target::Target, LayerConfig}; +use mirrord_kube::{api::kubernetes::create_kube_config, error::KubeApiError}; use mirrord_progress::Progress; use mirrord_protocol::{ClientMessage, DaemonMessage}; use semver::Version; use serde::{Deserialize, Serialize}; -use thiserror::Error; -use tokio::sync::mpsc::{self, Receiver, Sender}; -use tokio_tungstenite::tungstenite::{Error as TungsteniteError, Message}; -use tracing::{debug, error, info, warn}; - -use crate::crd::{ - CopyTargetCrd, CopyTargetSpec, MirrordOperatorCrd, OperatorFeatures, SessionCrd, TargetCrd, - OPERATOR_STATUS_NAME, +use tokio::sync::mpsc::{Receiver, Sender}; +use tracing::Level; + +use crate::{ + crd::{ + CopyTargetCrd, CopyTargetSpec, MirrordOperatorCrd, OperatorFeatures, TargetCrd, + OPERATOR_STATUS_NAME, + }, + types::{ + CLIENT_CERT_HEADER, CLIENT_HOSTNAME_HEADER, CLIENT_NAME_HEADER, MIRRORD_CLI_VERSION_HEADER, + SESSION_ID_HEADER, + }, }; -static CONNECTION_CHANNEL_SIZE: usize = 1000; +mod conn_wrapper; +mod discovery; +pub mod error; +mod upgrade; -pub use http::Error as HttpError; +/// State of client's [`Certificate`] the should be attached to some operator requests. 
+pub trait ClientCertificateState: fmt::Debug {} -/// Operations performed on the operator via [`kube`] API. +/// Represents a [`ClientCertificateState`] where we don't have the certificate. #[derive(Debug)] -pub enum OperatorOperation { - FindingOperator, - FindingTarget, - WebsocketConnection, - CopyingTarget, - GettingStatus, - SessionManagement, - ListingTargets, +pub struct NoClientCert { + /// [`Config::headers`] here contain some extra entries: + /// 1. [`CLIENT_HOSTNAME_HEADER`] (if available) + /// 2. [`CLIENT_NAME_HEADER`] (if available) + /// 3. [`MIRRORD_CLI_VERSION_HEADER`] + /// + /// Can be used to create a certified [`Client`] when the [`Certificate`] is available. + base_config: Config, } -impl Display for OperatorOperation { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let as_str = match self { - Self::FindingOperator => "finding operator", - Self::FindingTarget => "finding target", - Self::WebsocketConnection => "creating a websocket connection", - Self::CopyingTarget => "copying target", - Self::GettingStatus => "getting status", - Self::SessionManagement => "session management", - Self::ListingTargets => "listing targets", - }; +impl ClientCertificateState for NoClientCert {} - f.write_str(as_str) - } +/// Represents a [`ClientCertificateState`] where have the certificate. +pub struct PreparedClientCert { + /// Prepared client certificate. 
+ cert: Certificate, } -#[derive(Debug, Error)] -pub enum OperatorApiError { - #[error("failed to build a websocket connect request: {0}")] - ConnectRequestBuildError(HttpError), - - #[error("failed to create mirrord operator API: {0}")] - CreateApiError(KubeApiError), - - #[error("{operation} failed: {error}")] - KubeError { - error: kube::Error, - operation: OperatorOperation, - }, - - #[error("mirrord operator {operator_version} does not support feature {feature}")] - UnsupportedFeature { - feature: String, - operator_version: String, - }, - - #[error("{operation} failed with code {}: {}", status.code, status.reason)] - StatusFailure { - operation: OperatorOperation, - status: Box, - }, - - #[error("mirrord operator license expired")] - NoLicense, +impl fmt::Debug for PreparedClientCert { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PreparedClientCert") + .field("cert_public_key_data", &self.cert.public_key_data()) + .finish() + } } -type Result = std::result::Result; +impl ClientCertificateState for PreparedClientCert {} -#[derive(Serialize, Deserialize, Clone, Debug)] -pub struct OperatorSessionMetadata { - client_certificate: Option, - session_id: u64, - fingerprint: Option, - operator_features: Vec, - protocol_version: Option, - copy_pod_enabled: Option, +/// Represents a [`ClientCertificateState`] where we attempted to prepare the certificate and we may +/// have failed. 
+pub struct MaybeClientCert { + cert_result: Result, } -impl OperatorSessionMetadata { - fn new( - client_certificate: Option, - fingerprint: Option, - operator_features: Vec, - protocol_version: Option, - copy_pod_enabled: Option, - ) -> Self { - Self { - client_certificate, - session_id: rand::random(), - fingerprint, - operator_features, - protocol_version, - copy_pod_enabled, - } - } - - fn client_credentials(&self) -> io::Result> { - self.client_certificate - .as_ref() - .map(|cert| { - cert.encode_der() - .map(|bytes| general_purpose::STANDARD.encode(bytes)) - }) - .transpose() +impl fmt::Debug for MaybeClientCert { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MaybeClientCert") + .field("cert_result", &self.cert_result) + .finish() } +} - fn set_operator_properties(&self, analytics: &mut R) { - let client_hash = self - .client_certificate - .as_ref() - .map(|cert| cert.public_key_data()) - .as_deref() - .map(AnalyticsHash::from_bytes); +impl ClientCertificateState for MaybeClientCert {} + +/// Created operator session. Can be obtained from [`OperatorApi::connect_in_new_session`] and later +/// used in [`OperatorApi::connect_in_existing_session`]. +/// +/// # Note +/// +/// Contains enough information to enable connecting with target without fetching +/// [`MirrordOperatorCrd`] again. +#[derive(Clone, Serialize, Deserialize)] +pub struct OperatorSession { + /// Random session id, generated locally. + id: u64, + /// URL where websocket connection request should be sent. + connect_url: String, + /// Client certificate, should be included as header in the websocket connection request. + client_cert: Certificate, + /// Operator license fingerprint, right now only for setting [`Reporter`] properties. + operator_license_fingerprint: Option, + /// Version of [`mirrord_protocol`] used by the operator. + /// Used to create [`ConnectionWrapper`]. 
+ operator_protocol_version: Option, +} - analytics.set_operator_properties(AnalyticsOperatorProperties { - client_hash, - license_hash: self.fingerprint.as_deref().map(AnalyticsHash::from_base64), - }); +impl fmt::Debug for OperatorSession { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OperatorSession") + .field("id", &self.id) + .field("connect_url", &self.connect_url) + .field("cert_public_key_data", &self.client_cert.public_key_data()) + .field( + "operator_license_fingerprint", + &self.operator_license_fingerprint, + ) + .field("operator_protocol_version", &self.operator_protocol_version) + .finish() } +} - fn proxy_feature_enabled(&self) -> bool { - self.operator_features.contains(&OperatorFeatures::ProxyApi) +/// Connection to an operator target. +pub struct OperatorSessionConnection { + /// Session of this connection. + pub session: OperatorSession, + /// Used to send [`ClientMessage`]s to the operator. + pub tx: Sender, + /// Used to receive [`DaemonMessage`]s from the operator. + pub rx: Receiver, +} + +impl fmt::Debug for OperatorSessionConnection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let tx_queued_messages = self.tx.max_capacity() - self.tx.capacity(); + let rx_queued_messages = self.rx.len(); + + f.debug_struct("OperatorSessionConnection") + .field("session", &self.session) + .field("tx_closed", &self.tx.is_closed()) + .field("tx_queued_messages", &tx_queued_messages) + .field("rx_closed", &self.rx.is_closed()) + .field("rx_queued_messages", &rx_queued_messages) + .finish() } } -#[derive(Serialize, Deserialize, Clone, Debug)] -pub enum OperatorSessionTarget { +/// Prepared target of an operator session. +#[derive(Debug)] +enum OperatorSessionTarget { + /// CRD of an immediate target validated and fetched from the operator. Raw(TargetCrd), + /// CRD of a copied target created by the operator. 
Copied(CopyTargetCrd), } -#[derive(Serialize, Deserialize, Clone, Debug)] -pub struct OperatorSessionInformation { - target: OperatorSessionTarget, - metadata: OperatorSessionMetadata, +impl OperatorSessionTarget { + /// Returns a connection url for the given [`OperatorSessionTarget`]. + /// This can be used to create a websocket connection with the operator. + fn connect_url( + &self, + use_proxy: bool, + concurrent_steal: ConcurrentSteal, + ) -> Result { + Ok(match (use_proxy, self) { + (true, OperatorSessionTarget::Raw(crd)) => { + let name = TargetCrd::urlfied_name(crd.spec.target.as_known()?); + let namespace = crd + .meta() + .namespace + .as_deref() + .expect("missing 'TargetCrd' namespace"); + let api_version = TargetCrd::api_version(&()); + let plural = TargetCrd::plural(&()); + + format!("/apis/{api_version}/proxy/namespaces/{namespace}/{plural}/{name}?on_concurrent_steal={concurrent_steal}&connect=true") + } + + (false, OperatorSessionTarget::Raw(crd)) => { + let name = TargetCrd::urlfied_name(crd.spec.target.as_known()?); + let namespace = crd + .meta() + .namespace + .as_deref() + .expect("missing 'TargetCrd' namespace"); + let url_path = TargetCrd::url_path(&(), Some(namespace)); + + format!("{url_path}/{name}?on_concurrent_steal={concurrent_steal}&connect=true") + } + (true, OperatorSessionTarget::Copied(crd)) => { + let name = crd + .meta() + .name + .as_deref() + .expect("missing 'CopyTargetCrd' name"); + let namespace = crd + .meta() + .namespace + .as_deref() + .expect("missing 'CopyTargetCrd' namespace"); + let api_version = CopyTargetCrd::api_version(&()); + let plural = CopyTargetCrd::plural(&()); + + format!( + "/apis/{api_version}/proxy/namespaces/{namespace}/{plural}/{name}?connect=true" + ) + } + (false, OperatorSessionTarget::Copied(crd)) => { + let name = crd + .meta() + .name + .as_deref() + .expect("missing 'CopyTargetCrd' name"); + let namespace = crd + .meta() + .namespace + .as_deref() + .expect("missing 'CopyTargetCrd' namespace"); 
+ let url_path = CopyTargetCrd::url_path(&(), Some(namespace)); + + format!("{url_path}/{name}?connect=true") + } + }) + } } -pub struct OperatorApi { +/// Wrapper over mirrord operator API. +pub struct OperatorApi { + /// For making requests to kubernetes API server. client: Client, - target_api: Api, - copy_target_api: Api, - target_namespace: Option, - target_config: TargetConfig, - on_concurrent_steal: ConcurrentSteal, + /// Prepared client certificate. If present, [`Self::client`] sends [`CLIENT_CERT_HEADER`] with + /// each request. + client_cert: C, + /// Fetched operator resource. + operator: MirrordOperatorCrd, } -/// Connection to existing operator session. -pub struct OperatorSessionConnection { - /// For sending messages to the operator. - pub tx: Sender, - /// For receiving messages from the operator. - pub rx: Receiver, - /// Additional data about the session. - pub info: OperatorSessionInformation, +impl fmt::Debug for OperatorApi +where + C: ClientCertificateState, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OperatorApi") + .field("default_namespace", &self.client.default_namespace()) + .field("client_cert", &self.client_cert) + .field("operator_version", &self.operator.spec.operator_version) + .field( + "operator_protocol_version", + &self.operator.spec.protocol_version, + ) + .field( + "operator_license_fingerprint", + &self.operator.spec.license.fingerprint, + ) + .finish() + } } -/// Allows us to access the operator's [`SessionCrd`] [`Api`]. -pub async fn session_api(config: Option) -> Result> { - let kube_api: Client = create_kube_api(false, config, None) - .await - .map_err(OperatorApiError::CreateApiError)?; +impl OperatorApi { + /// Attempts to fetch the [`MirrordOperatorCrd`] resource and create an instance of this API. + /// In case of error response from the Kubernetes API server, executes an extra API discovery + /// step to confirm that the operator is not installed. 
+ /// + /// If certain that the operator is not installed, returns [`None`]. + #[tracing::instrument(level = Level::TRACE, skip_all, err)] + pub async fn try_new( + config: &LayerConfig, + reporter: &mut R, + ) -> OperatorApiResult> + where + R: Reporter, + { + let base_config = Self::base_client_config(config).await?; + let client = Client::try_from(base_config.clone()) + .map_err(KubeApiError::from) + .map_err(OperatorApiError::CreateKubeClient)?; + + let operator: Result = + Api::all(client.clone()).get(OPERATOR_STATUS_NAME).await; + + let error = match operator { + Ok(operator) => { + reporter.set_operator_properties(AnalyticsOperatorProperties { + client_hash: None, + license_hash: operator + .spec + .license + .fingerprint + .as_deref() + .map(AnalyticsHash::from_base64), + }); + + return Ok(Some(Self { + client, + client_cert: NoClientCert { base_config }, + operator, + })); + } - Ok(Api::all(kube_api)) -} + Err(error @ kube::Error::Api(..)) => { + match discovery::operator_installed(&client).await { + Ok(false) | Err(..) => { + return Ok(None); + } + Ok(true) => error, + } + } -impl OperatorApi { - /// We allow copied pods to live only for 30 seconds before the internal proxy connects. - const COPIED_POD_IDLE_TTL: u32 = 30; + Err(error) => error, + }; - /// Checks used config against operator specification. - fn check_config(config: &LayerConfig, operator: &MirrordOperatorCrd) -> Result<()> { - if config.feature.copy_target.enabled && !operator.spec.copy_target_enabled.unwrap_or(false) - { - return Err(OperatorApiError::UnsupportedFeature { - feature: "copy target".into(), - operator_version: operator.spec.operator_version.clone(), + Err(OperatorApiError::KubeError { + error, + operation: OperatorOperation::FindingOperator, + }) + } + + /// Prepares client [`Certificate`] to be sent in all subsequent requests to the operator. + /// In case of failure, state of this API instance does not change. 
+ #[tracing::instrument(level = Level::TRACE, skip(reporter))] + pub async fn prepare_client_cert(self, reporter: &mut R) -> OperatorApi + where + R: Reporter, + { + let previous_client = self.client.clone(); + + let result = try { + let certificate = self.get_client_certificate().await?; + + reporter.set_operator_properties(AnalyticsOperatorProperties { + client_hash: Some(AnalyticsHash::from_bytes(&certificate.public_key_data())), + license_hash: self + .operator + .spec + .license + .fingerprint + .as_deref() + .map(AnalyticsHash::from_base64), }); + + let header = Self::make_client_cert_header(&certificate)?; + + let mut config = self.client_cert.base_config; + config + .headers + .push((HeaderName::from_static(CLIENT_CERT_HEADER), header)); + let client = Client::try_from(config) + .map_err(KubeApiError::from) + .map_err(OperatorApiError::CreateKubeClient)?; + + (client, certificate) + }; + + match result { + Ok((new_client, cert)) => OperatorApi { + client: new_client, + client_cert: MaybeClientCert { + cert_result: Ok(cert), + }, + operator: self.operator, + }, + + Err(error) => OperatorApi { + client: previous_client, + client_cert: MaybeClientCert { + cert_result: Err(error), + }, + operator: self.operator, + }, } + } +} - Ok(()) +impl OperatorApi { + pub fn inspect_cert_error(&self, f: F) { + if let Err(e) = &self.client_cert.cert_result { + f(e); + } } - #[tracing::instrument(level = "trace", skip(api))] - pub async fn get_client_certificate( - api: &OperatorApi, - operator: &MirrordOperatorCrd, - ) -> Result, AuthenticationError> { - let Some(fingerprint) = operator.spec.license.fingerprint.clone() else { - return Ok(None); - }; + pub fn into_certified(self) -> OperatorApiResult> { + let cert = self.client_cert.cert_result?; - let subscription_id = operator.spec.license.subscription_id.clone(); + Ok(OperatorApi { + client: self.client, + client_cert: PreparedClientCert { cert }, + operator: self.operator, + }) + } +} - let mut credential_store = 
CredentialStoreSync::open().await?; - credential_store - .get_client_certificate::(&api.client, fingerprint, subscription_id) - .await - .map(Some) +impl OperatorApi +where + C: ClientCertificateState, +{ + /// Lists targets in the given namespace. + #[tracing::instrument(level = Level::TRACE, ret, err)] + pub async fn list_targets(&self, namespace: Option<&str>) -> OperatorApiResult> { + Api::namespaced( + self.client.clone(), + namespace.unwrap_or(self.client.default_namespace()), + ) + .list(&ListParams::default()) + .await + .map_err(|error| OperatorApiError::KubeError { + error, + operation: OperatorOperation::ListingTargets, + }) + .map(|list| list.items) } - /// Creates new [`OperatorSessionConnection`] based on the given [`LayerConfig`]. - /// Keep in mind that some failures here won't stop mirrord from hooking into the process - /// and working, it'll just work without the operator. - /// - /// For a fuller documentation, see the docs in `operator/service/src/main.rs::listen`. - /// - /// - `copy_target`: When this feature is enabled, `target` validation is done in the operator. - #[tracing::instrument(level = "trace", skip_all)] - pub async fn create_session( - config: &LayerConfig, - progress: &P, - analytics: &mut R, - ) -> Result + pub fn check_license_validity

(&self, progress: &P) -> OperatorApiResult<()> where - P: Progress + Send + Sync, + P: Progress, { - let operator_api = OperatorApi::new(config).await?; - - let operator = operator_api.fetch_operator().await?; - - // Warns the user if their license is close to expiring or fallback to OSS if expired let Some(days_until_expiration) = - DateTime::from_naive_date(operator.spec.license.expire_at).days_until_expiration() + DateTime::from_naive_date(self.operator.spec.license.expire_at).days_until_expiration() else { - let no_license_message = "No valid license found for mirrord for Teams, falling back to OSS usage. Visit https://app.metalbear.co to purchase or renew your license."; + let no_license_message = "No valid license found for mirrord for Teams. Visit https://app.metalbear.co to purchase or renew your license"; progress.warning(no_license_message); - warn!(no_license_message); + tracing::warn!(no_license_message); return Err(OperatorApiError::NoLicense); }; let expires_soon = days_until_expiration <= as LicenseValidity>::CLOSE_TO_EXPIRATION_DAYS; - let is_trial = operator.spec.license.name.contains("(Trial)"); + let is_trial = self.operator.spec.license.name.contains("(Trial)"); if is_trial && expires_soon { - let expiring_soon = (days_until_expiration > 0) - .then(|| { - format!( - "soon, in {days_until_expiration} day{}", - if days_until_expiration > 1 { "s" } else { "" } - ) - }) - .unwrap_or_else(|| "today".to_string()); - - let expiring_message = format!("Operator license will expire {expiring_soon}!",); - - progress.warning(&expiring_message); - warn!(expiring_message); + let expiring_soon = if days_until_expiration > 0 { + format!( + "soon, in {days_until_expiration} day{}", + if days_until_expiration > 1 { "s" } else { "" } + ) + } else { + "today".to_string() + }; + let message = format!("Operator license will expire {expiring_soon}!",); + progress.warning(&message); } else if is_trial { - let good_validity_message = + let message = format!("Operator 
license is valid for {days_until_expiration} more days."); - - progress.info(&good_validity_message); - info!(good_validity_message); + progress.info(&message); } - Self::check_config(config, &operator)?; - - let client_certificate = Self::get_client_certificate(&operator_api, &operator) - .await - .ok() - .flatten(); - let metadata = OperatorSessionMetadata::new( - client_certificate, - operator.spec.license.fingerprint, - operator.spec.features.unwrap_or_default(), - operator - .spec - .protocol_version - .and_then(|str_version| str_version.parse().ok()), - operator.spec.copy_target_enabled, - ); - - metadata.set_operator_properties(analytics); + Ok(()) + } - let mut version_progress = progress.subtask("comparing versions"); - let operator_version = Version::parse(&operator.spec.operator_version) - .expect("failed to parse operator version from operator crd"); // TODO: Remove expect + pub fn check_operator_version

(&self, progress: &P) -> bool + where + P: Progress, + { + match Version::parse(&self.operator.spec.operator_version) { + Ok(operator_version) => { + let mirrord_version = Version::parse(env!("CARGO_PKG_VERSION")).unwrap(); - let mirrord_version = Version::parse(env!("CARGO_PKG_VERSION")).unwrap(); - if operator_version > mirrord_version { - // we make two sub tasks since it looks best this way - version_progress.warning( - &format!( - "Your mirrord plugin/CLI version {} does not match the operator version {}. This can lead to unforeseen issues.", + if operator_version > mirrord_version { + let message = format!( + "mirrord binary version {} does not match the operator version {}. Consider updating your mirrord binary.", mirrord_version, - operator_version)); - version_progress.success(None); - version_progress = progress.subtask("comparing versions"); - version_progress.warning( - "Consider updating your mirrord plugin/CLI to match the operator version.", - ); - } - version_progress.success(None); - - let target_to_connect = if config.feature.copy_target.enabled { - // We do not validate the `target` here, it's up to the operator. 
- let mut copy_progress = progress.subtask("copying target"); - let copied = operator_api - .copy_target( - &metadata, - config.target.path.clone().unwrap_or(Target::Targetless), - config.feature.copy_target.scale_down, - ) - .await?; - copy_progress.success(None); - - OperatorSessionTarget::Copied(copied) - } else { - let raw_target = operator_api.fetch_target().await?; - OperatorSessionTarget::Raw(raw_target) - }; - - let session_info = OperatorSessionInformation { - target: target_to_connect, - metadata, - }; - let connection = operator_api.connect_target(session_info).await?; + operator_version + ); + progress.warning(&message); + false + } else { + true + } + } - Ok(connection) + Err(error) => { + tracing::debug!(%error, "failed to parse operator version"); + progress.warning("Failed to parse operator version."); + false + } + } } - /// Connects to exisiting operator session based on the given [`LayerConfig`] and - /// [`OperatorSessionInformation`]. - pub async fn connect( - config: &LayerConfig, - session_information: OperatorSessionInformation, - analytics: &mut R, - ) -> Result { - session_information - .metadata - .set_operator_properties(analytics); - - let operator_api = OperatorApi::new(config).await?; - operator_api.connect_target(session_information).await + /// Returns a reference to the operator resource fetched from the cluster. + pub fn operator(&self) -> &MirrordOperatorCrd { + &self.operator } - pub async fn new(config: &LayerConfig) -> Result { - let target_config = config.target.clone(); - let on_concurrent_steal = config.feature.network.incoming.on_concurrent_steal; + /// Returns a reference to the [`Client`] used by this instance. + pub fn client(&self) -> &Client { + &self.client + } - let client = create_kube_api( - config.accept_invalid_certificates, - config.kubeconfig.clone(), - config.kube_context.clone(), + /// Creates a base [`Config`] for creating kube [`Client`]s. 
+ /// Adds extra headers that we send to the operator with each request: + /// 1. [`MIRRORD_CLI_VERSION_HEADER`] + /// 2. [`CLIENT_NAME_HEADER`] + /// 3. [`CLIENT_HOSTNAME_HEADER`] + async fn base_client_config(layer_config: &LayerConfig) -> OperatorApiResult { + let mut client_config = create_kube_config( + layer_config.accept_invalid_certificates, + layer_config.kubeconfig.clone(), + layer_config.kube_context.clone(), ) .await - .map_err(OperatorApiError::CreateApiError)?; + .map_err(KubeApiError::from) + .map_err(OperatorApiError::CreateKubeClient)?; - let target_namespace = if target_config.path.is_some() { - target_config.namespace.clone() - } else { - // When targetless, pass agent namespace to operator so that it knows where to create - // the agent (the operator does not get the agent config). - config.agent.namespace.clone() - }; + client_config.headers.push(( + HeaderName::from_static(MIRRORD_CLI_VERSION_HEADER), + HeaderValue::from_static(env!("CARGO_PKG_VERSION")), + )); - let target_api: Api = get_k8s_resource_api(&client, target_namespace.as_deref()); - let copy_target_api: Api = - get_k8s_resource_api(&client, target_namespace.as_deref()); + let UserIdentity { name, hostname } = UserIdentity::load(); - Ok(OperatorApi { - client, - target_api, - copy_target_api, - target_namespace, - target_config, - on_concurrent_steal, - }) + let headers = [ + (CLIENT_NAME_HEADER, name), + (CLIENT_HOSTNAME_HEADER, hostname), + ]; + for (name, raw_value) in headers { + let Some(raw_value) = raw_value else { + continue; + }; + + // Replace non-ascii (not supported in headers) chars and trim. 
+ let cleaned = raw_value + .replace(|c: char| !c.is_ascii(), "") + .trim() + .to_string(); + let value = HeaderValue::from_str(&cleaned); + match value { + Ok(value) => client_config + .headers + .push((HeaderName::from_static(name), value)), + Err(error) => { + tracing::debug!(%error, %name, raw_value = raw_value, cleaned, "Invalid header value"); + } + } + } + + Ok(client_config) } - #[tracing::instrument(level = "trace", skip(self), ret)] - async fn fetch_operator(&self) -> Result { - let api: Api = Api::all(self.client.clone()); - api.get(OPERATOR_STATUS_NAME) - .await - .map_err(|error| OperatorApiError::KubeError { - error, - operation: OperatorOperation::FindingOperator, - }) + /// If `copy_target` feature is enabled in the given [`LayerConfig`], checks that the operator + /// supports it. + fn check_copy_target_feature_support(&self, config: &LayerConfig) -> OperatorApiResult<()> { + let client_wants_copy = config.feature.copy_target.enabled; + let operator_supports_copy = self.operator.spec.copy_target_enabled.unwrap_or(false); + if client_wants_copy && !operator_supports_copy { + return Err(OperatorApiError::UnsupportedFeature { + feature: "copy target".into(), + operator_version: self.operator.spec.operator_version.clone(), + }); + } + + Ok(()) } - /// See `operator/controller/src/target.rs::TargetProvider::get_resource`. - #[tracing::instrument(level = "trace", fields(self.target_config), skip(self))] - async fn fetch_target(&self) -> Result { - let target_name = TargetCrd::target_name_by_config(&self.target_config); - self.target_api - .get(&target_name) + /// Retrieves client [`Certificate`] from local credential store or requests one from the + /// operator. 
+ #[tracing::instrument(level = Level::TRACE, err)] + async fn get_client_certificate(&self) -> Result { + let Some(fingerprint) = self.operator.spec.license.fingerprint.clone() else { + return Err(OperatorApiError::ClientCertError( + "license fingerprint is missing from the mirrord operator resource".to_string(), + )); + }; + + let subscription_id = self.operator.spec.license.subscription_id.clone(); + + let mut credential_store = CredentialStoreSync::open().await.map_err(|error| { + OperatorApiError::ClientCertError(format!( + "failed to access local credential store: {error}" + )) + })?; + + credential_store + .get_client_certificate::( + &self.client, + fingerprint, + subscription_id, + ) .await - .map_err(|error| OperatorApiError::KubeError { - error, - operation: OperatorOperation::FindingTarget, + .map_err(|error| { + OperatorApiError::ClientCertError(format!( + "failed to get client cerfificate: {error}" + )) }) } - /// Returns a namespace of the target. - fn namespace(&self) -> &str { - self.target_namespace - .as_deref() - .unwrap_or_else(|| self.client.default_namespace()) + /// Transforms the given client [`Certificate`] into a [`HeaderValue`]. + fn make_client_cert_header(certificate: &Certificate) -> Result { + let as_der = certificate.encode_der().map_err(|error| { + OperatorApiError::ClientCertError(format!( + "failed to encode client certificate: {error}" + )) + })?; + let as_base64 = general_purpose::STANDARD.encode(as_der); + HeaderValue::try_from(as_base64) + .map_err(|error| OperatorApiError::ClientCertError(error.to_string())) } - /// Returns a connection url for the given [`OperatorSessionInformation`]. - /// This can be used to create a websocket connection with the operator. 
- #[tracing::instrument(level = "debug", skip(self), ret)] - fn connect_url(&self, session: &OperatorSessionInformation) -> String { - match (session.metadata.proxy_feature_enabled(), &session.target) { - (true, OperatorSessionTarget::Raw(target)) => { - let dt = &(); - let namespace = self.namespace(); - let api_version = TargetCrd::api_version(dt); - let plural = TargetCrd::plural(dt); - - format!( - "/apis/{api_version}/proxy/namespaces/{namespace}/{plural}/{}?on_concurrent_steal={}&connect=true", - target.name(), - self.on_concurrent_steal, - ) - } - (false, OperatorSessionTarget::Raw(target)) => { - format!( - "{}/{}?on_concurrent_steal={}&connect=true", - self.target_api.resource_url(), - target.name(), - self.on_concurrent_steal, - ) - } - (true, OperatorSessionTarget::Copied(target)) => { - let dt = &(); - let namespace = self.namespace(); - let api_version = CopyTargetCrd::api_version(dt); - let plural = CopyTargetCrd::plural(dt); + /// Returns a namespace of the target based on the given [`LayerConfig`] and default namespace + /// of [`Client`] used by this instance. + fn target_namespace<'a>(&'a self, config: &'a LayerConfig) -> &'a str { + let namespace_opt = if config.target.path.is_some() { + // Not a targetless run, we use target's namespace. + config.target.namespace.as_deref() + } else { + // A targetless run, we use the namespace where the agent should live. + config.agent.namespace.as_deref() + }; - format!( - "/apis/{api_version}/proxy/namespaces/{namespace}/{plural}/{}?connect=true", - target - .meta() - .name - .as_ref() - .expect("missing 'copytarget' name"), - ) - } - (false, OperatorSessionTarget::Copied(target)) => { - format!( - "{}/{}?connect=true", - self.copy_target_api.resource_url(), - target - .meta() - .name - .as_ref() - .expect("missing 'copytarget' name"), - ) - } - } + namespace_opt.unwrap_or(self.client.default_namespace()) } +} - /// Create websocket connection to operator. 
- #[tracing::instrument(level = "trace", skip(self))] - async fn connect_target( +impl OperatorApi { + /// We allow copied pods to live only for 30 seconds before the internal proxy connects. + const COPIED_POD_IDLE_TTL: u32 = 30; + + /// Starts a new operator session and connects to the target. + /// Returned [`OperatorSessionConnection::session`] can be later used to create another + /// connection in the same session with [`OperatorApi::connect_in_existing_session`]. + #[tracing::instrument( + level = Level::TRACE, + skip(config, progress), + fields( + target_config = ?config.target, + copy_target_config = ?config.feature.copy_target, + on_concurrent_steal = ?config.feature.network.incoming.on_concurrent_steal, + ), + ret, + err + )] + pub async fn connect_in_new_session

( &self, - session_info: OperatorSessionInformation, - ) -> Result { - let UserIdentity { name, hostname } = UserIdentity::load(); + config: &LayerConfig, + progress: &P, + ) -> OperatorApiResult + where + P: Progress, + { + self.check_copy_target_feature_support(config)?; - let request = { - let mut builder = Request::builder() - .uri(self.connect_url(&session_info)) - .header("x-session-id", session_info.metadata.session_id.to_string()); - - // Replace non-ascii (not supported in headers) chars and trim headers. - if let Some(name) = name { - builder = builder.header( - "x-client-name", - name.replace(|c: char| !c.is_ascii(), "").trim(), - ); - }; + let target = if config.feature.copy_target.enabled { + let mut copy_subtask = progress.subtask("copying target"); - if let Some(hostname) = hostname { - builder = builder.header( - "x-client-hostname", - hostname.replace(|c: char| !c.is_ascii(), "").trim(), - ); - }; + // We do not validate the `target` here, it's up to the operator. + let target = config.target.path.clone().unwrap_or(Target::Targetless); + let scale_down = config.feature.copy_target.scale_down; + let namespace = self.target_namespace(config); + let copied = self.copy_target(target, scale_down, namespace).await?; - match session_info.metadata.client_credentials() { - Ok(Some(credentials)) => { - builder = builder.header("x-client-der", credentials); - } - Ok(None) => {} - Err(err) => { - debug!("CredentialStore error: {err}"); - } - } + copy_subtask.success(Some("target copied")); - builder - .body(vec![]) - .map_err(OperatorApiError::ConnectRequestBuildError)? 
- }; + OperatorSessionTarget::Copied(copied) + } else { + let mut fetch_subtask = progress.subtask("fetching target"); - let connection = upgrade::connect_ws(&self.client, request) - .await - .map_err(|error| OperatorApiError::KubeError { - error, - operation: OperatorOperation::WebsocketConnection, - })?; + let target_name = + TargetCrd::urlfied_name(config.target.path.as_ref().unwrap_or(&Target::Targetless)); + let raw_target = Api::namespaced(self.client.clone(), self.target_namespace(config)) + .get(&target_name) + .await + .map_err(|error| OperatorApiError::KubeError { + error, + operation: OperatorOperation::FindingTarget, + })?; - let (tx, rx) = - ConnectionWrapper::wrap(connection, session_info.metadata.protocol_version.clone()); + fetch_subtask.success(Some("target fetched")); - Ok(OperatorSessionConnection { - tx, - rx, - info: session_info, - }) + OperatorSessionTarget::Raw(raw_target) + }; + let use_proxy_api = self + .operator + .spec + .features + .as_ref() + .map(|features| features.contains(&OperatorFeatures::ProxyApi)) + .unwrap_or(false); + let connect_url = target.connect_url( + use_proxy_api, + config.feature.network.incoming.on_concurrent_steal, + )?; + + let session = OperatorSession { + id: rand::random(), + connect_url, + client_cert: self.client_cert.cert.clone(), + operator_license_fingerprint: self.operator.spec.license.fingerprint.clone(), + operator_protocol_version: self + .operator + .spec + .protocol_version + .as_ref() + .and_then(|version| version.parse().ok()), + }; + + let mut connection_subtask = progress.subtask("connecting to the target"); + let (tx, rx) = Self::connect_target(&self.client, &session).await?; + connection_subtask.success(Some("connected to the target")); + + Ok(OperatorSessionConnection { session, tx, rx }) } /// Creates a new [`CopyTargetCrd`] resource using the operator. @@ -561,14 +712,14 @@ impl OperatorApi { /// /// `copy_target` feature is not available for all target types. 
/// Target type compatibility is checked by the operator. - #[tracing::instrument(level = "trace", skip(self))] + #[tracing::instrument(level = "trace", err)] async fn copy_target( &self, - session_metadata: &OperatorSessionMetadata, target: Target, scale_down: bool, - ) -> Result { - let name = TargetCrd::target_name(&target); + namespace: &str, + ) -> OperatorApiResult { + let name = TargetCrd::urlfied_name(&target); let requested = CopyTargetCrd::new( &name, @@ -579,7 +730,7 @@ impl OperatorApi { }, ); - self.copy_target_api + Api::namespaced(self.client.clone(), namespace) .create(&PostParams::default(), &requested) .await .map_err(|error| OperatorApiError::KubeError { @@ -588,303 +739,63 @@ impl OperatorApi { }) } - /// List targets using the operator - #[tracing::instrument(level = "trace", ret)] - pub async fn list_targets(config: &LayerConfig) -> Result> { - let client = create_kube_api( - config.accept_invalid_certificates, - config.kubeconfig.clone(), - config.kube_context.clone(), - ) - .await - .map_err(OperatorApiError::CreateApiError)?; - - let target_api: Api = - get_k8s_resource_api(&client, config.target.namespace.as_deref()); - target_api - .list(&ListParams::default()) - .await - .map_err(|error| OperatorApiError::KubeError { - error, - operation: OperatorOperation::ListingTargets, - }) - .map(|list| list.items) - } -} - -#[derive(Error, Debug)] -enum ConnectionWrapperError { - #[error(transparent)] - DecodeError(#[from] bincode::error::DecodeError), - #[error(transparent)] - EncodeError(#[from] bincode::error::EncodeError), - #[error(transparent)] - WsError(#[from] TungsteniteError), - #[error("invalid message: {0:?}")] - InvalidMessage(Message), - #[error("message channel is closed")] - ChannelClosed, -} - -pub struct ConnectionWrapper { - connection: T, - client_rx: Receiver, - daemon_tx: Sender, - protocol_version: Option, -} - -impl ConnectionWrapper -where - for<'stream> T: StreamExt> - + SinkExt - + Send - + Unpin - + 'stream, -{ - fn 
wrap( - connection: T, - protocol_version: Option, - ) -> (Sender, Receiver) { - let (client_tx, client_rx) = mpsc::channel(CONNECTION_CHANNEL_SIZE); - let (daemon_tx, daemon_rx) = mpsc::channel(CONNECTION_CHANNEL_SIZE); - - let connection_wrapper = ConnectionWrapper { - protocol_version, - connection, - client_rx, - daemon_tx, - }; - - tokio::spawn(async move { - if let Err(err) = connection_wrapper.start().await { - error!("{err:?}") - } + /// Connects to the target, reusing the given [`OperatorSession`]. + #[tracing::instrument(level = Level::TRACE, skip(layer_config, reporter), ret, err)] + pub async fn connect_in_existing_session( + layer_config: &LayerConfig, + session: OperatorSession, + reporter: &mut R, + ) -> OperatorApiResult + where + R: Reporter, + { + reporter.set_operator_properties(AnalyticsOperatorProperties { + client_hash: Some(AnalyticsHash::from_bytes( + session.client_cert.public_key_data().as_ref(), + )), + license_hash: session + .operator_license_fingerprint + .as_ref() + .map(|fingerprint| AnalyticsHash::from_base64(fingerprint)), }); - (client_tx, daemon_rx) - } - - async fn handle_client_message( - &mut self, - client_message: ClientMessage, - ) -> Result<(), ConnectionWrapperError> { - let payload = bincode::encode_to_vec(client_message, bincode::config::standard())?; - - self.connection.send(payload.into()).await?; - - Ok(()) - } - - async fn handle_daemon_message( - &mut self, - daemon_message: Result, - ) -> Result<(), ConnectionWrapperError> { - match daemon_message? { - Message::Binary(payload) => { - let (daemon_message, _) = bincode::decode_from_slice::( - &payload, - bincode::config::standard(), - )?; - - self.daemon_tx - .send(daemon_message) - .await - .map_err(|_| ConnectionWrapperError::ChannelClosed) - } - message => Err(ConnectionWrapperError::InvalidMessage(message)), - } - } - - async fn start(mut self) -> Result<(), ConnectionWrapperError> { - loop { - tokio::select! 
{ - client_message = self.client_rx.recv() => { - match client_message { - Some(ClientMessage::SwitchProtocolVersion(version)) => { - if let Some(operator_protocol_version) = self.protocol_version.as_ref() { - self.handle_client_message(ClientMessage::SwitchProtocolVersion(operator_protocol_version.min(&version).clone())).await?; - } else { - self.daemon_tx - .send(DaemonMessage::SwitchProtocolVersionResponse( - "1.2.1".parse().expect("Bad static version"), - )) - .await - .map_err(|_| ConnectionWrapperError::ChannelClosed)?; - } - } - Some(client_message) => self.handle_client_message(client_message).await?, - None => break, - } - } - daemon_message = self.connection.next() => { - match daemon_message { - Some(daemon_message) => self.handle_daemon_message(daemon_message).await?, - None => break, - } - } - } - } - - let _ = self.connection.send(Message::Close(None)).await; - - Ok(()) - } -} - -mod upgrade { - //! Code copied from [`kube::client`] and adjusted. - //! - //! Just like original [`Client::connect`] function, [`connect_ws`] creates a - //! WebSockets connection. However, original function swallows - //! [`ErrorResponse`] sent by the operator and returns flat - //! [`UpgradeConnectionError`]. [`connect_ws`] attempts to - //! recover the [`ErrorResponse`] - if operator response code is not - //! [`StatusCode::SWITCHING_PROTOCOLS`], it tries to read - //! response body and deserialize it. - - use base64::Engine; - use http::{HeaderValue, Request, Response, StatusCode}; - use http_body_util::BodyExt; - use hyper_util::rt::TokioIo; - use kube::{ - client::{Body, UpgradeConnectionError}, - core::ErrorResponse, - Client, Error, Result, - }; - use tokio_tungstenite::{tungstenite::protocol::Role, WebSocketStream}; - - const WS_PROTOCOL: &str = "v4.channel.k8s.io"; - - // Verify upgrade response according to RFC6455. - // Based on `tungstenite` and added subprotocol verification. 
- async fn verify_response(res: Response, key: &HeaderValue) -> Result> { - let status = res.status(); - - if status != StatusCode::SWITCHING_PROTOCOLS { - if status.is_client_error() || status.is_server_error() { - let error_response = res - .into_body() - .collect() - .await - .ok() - .map(|body| body.to_bytes()) - .and_then(|body_bytes| { - serde_json::from_slice::(&body_bytes).ok() - }); - - if let Some(error_response) = error_response { - return Err(Error::Api(error_response)); - } - } - - return Err(Error::UpgradeConnection( - UpgradeConnectionError::ProtocolSwitch(status), - )); - } - - let headers = res.headers(); - if !headers - .get(http::header::UPGRADE) - .and_then(|h| h.to_str().ok()) - .map(|h| h.eq_ignore_ascii_case("websocket")) - .unwrap_or(false) - { - return Err(Error::UpgradeConnection( - UpgradeConnectionError::MissingUpgradeWebSocketHeader, - )); - } + let mut config = Self::base_client_config(layer_config).await?; + let cert_header = Self::make_client_cert_header(&session.client_cert)?; + config + .headers + .push((HeaderName::from_static(CLIENT_CERT_HEADER), cert_header)); - if !headers - .get(http::header::CONNECTION) - .and_then(|h| h.to_str().ok()) - .map(|h| h.eq_ignore_ascii_case("Upgrade")) - .unwrap_or(false) - { - return Err(Error::UpgradeConnection( - UpgradeConnectionError::MissingConnectionUpgradeHeader, - )); - } + let client = Client::try_from(config) + .map_err(KubeApiError::from) + .map_err(OperatorApiError::CreateKubeClient)?; - let accept_key = tokio_tungstenite::tungstenite::handshake::derive_accept_key(key.as_ref()); - if !headers - .get(http::header::SEC_WEBSOCKET_ACCEPT) - .map(|h| h == &accept_key) - .unwrap_or(false) - { - return Err(Error::UpgradeConnection( - UpgradeConnectionError::SecWebSocketAcceptKeyMismatch, - )); - } + let (tx, rx) = Self::connect_target(&client, &session).await?; - // Make sure that the server returned the correct subprotocol. 
- if !headers - .get(http::header::SEC_WEBSOCKET_PROTOCOL) - .map(|h| h == WS_PROTOCOL) - .unwrap_or(false) - { - return Err(Error::UpgradeConnection( - UpgradeConnectionError::SecWebSocketProtocolMismatch, - )); - } - - Ok(res) + Ok(OperatorSessionConnection { tx, rx, session }) } - /// Generate a random key for the `Sec-WebSocket-Key` header. - /// This must be nonce consisting of a randomly selected 16-byte value in base64. - fn sec_websocket_key() -> HeaderValue { - let random: [u8; 16] = rand::random(); - base64::engine::general_purpose::STANDARD - .encode(random) - .parse() - .expect("should be valid") - } - - pub async fn connect_ws( + /// Creates websocket connection to the operator target. + #[tracing::instrument(level = Level::TRACE, skip(client), err)] + async fn connect_target( client: &Client, - request: Request>, - ) -> kube::Result>> { - let (mut parts, body) = request.into_parts(); - parts.headers.insert( - http::header::CONNECTION, - HeaderValue::from_static("Upgrade"), - ); - parts - .headers - .insert(http::header::UPGRADE, HeaderValue::from_static("websocket")); - parts.headers.insert( - http::header::SEC_WEBSOCKET_VERSION, - HeaderValue::from_static("13"), - ); - let key = sec_websocket_key(); - parts - .headers - .insert(http::header::SEC_WEBSOCKET_KEY, key.clone()); - // Use the binary subprotocol v4, to get JSON `Status` object in `error` channel (3). - // There's no official documentation about this protocol, but it's described in - // [`k8s.io/apiserver/pkg/util/wsstream/conn.go`](https://git.io/JLQED). - // There's a comment about v4 and `Status` object in - // [`kublet/cri/streaming/remotecommand/httpstream.go`](https://git.io/JLQEh). 
- parts.headers.insert( - http::header::SEC_WEBSOCKET_PROTOCOL, - HeaderValue::from_static(WS_PROTOCOL), - ); - - let res = client - .send(Request::from_parts(parts, Body::from(body))) - .await?; - let res = verify_response(res, &key).await?; - match hyper::upgrade::on(res).await { - Ok(upgraded) => { - Ok( - WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Client, None) - .await, - ) - } + session: &OperatorSession, + ) -> OperatorApiResult<(Sender, Receiver)> { + let request = Request::builder() + .uri(&session.connect_url) + .header(SESSION_ID_HEADER, session.id.to_string()) + .body(vec![]) + .map_err(OperatorApiError::ConnectRequestBuildError)?; + + let connection = upgrade::connect_ws(client, request) + .await + .map_err(|error| OperatorApiError::KubeError { + error, + operation: OperatorOperation::WebsocketConnection, + })?; - Err(e) => Err(Error::UpgradeConnection( - UpgradeConnectionError::GetPendingUpgrade(e), - )), - } + Ok(ConnectionWrapper::wrap( + connection, + session.operator_protocol_version.clone(), + )) } } diff --git a/mirrord/operator/src/client/conn_wrapper.rs b/mirrord/operator/src/client/conn_wrapper.rs new file mode 100644 index 00000000000..cd3208c5b1b --- /dev/null +++ b/mirrord/operator/src/client/conn_wrapper.rs @@ -0,0 +1,126 @@ +use futures::{Sink, SinkExt, Stream, StreamExt}; +use mirrord_protocol::{ClientMessage, DaemonMessage}; +use thiserror::Error; +use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio_tungstenite::tungstenite::{self, Message}; + +#[derive(Error, Debug)] +enum ConnectionWrapperError { + #[error(transparent)] + DecodeError(#[from] bincode::error::DecodeError), + #[error(transparent)] + EncodeError(#[from] bincode::error::EncodeError), + #[error(transparent)] + WsError(#[from] tungstenite::Error), + #[error("invalid message: {0:?}")] + InvalidMessage(Message), + #[error("message channel is closed")] + ChannelClosed, +} + +pub struct ConnectionWrapper { + connection: T, + client_rx: Receiver, + 
daemon_tx: Sender, + protocol_version: Option, +} + +impl ConnectionWrapper +where + for<'stream> T: Stream> + + Sink + + Send + + Unpin + + 'stream, +{ + const CONNECTION_CHANNEL_SIZE: usize = 1000; + + pub fn wrap( + connection: T, + protocol_version: Option, + ) -> (Sender, Receiver) { + let (client_tx, client_rx) = mpsc::channel(Self::CONNECTION_CHANNEL_SIZE); + let (daemon_tx, daemon_rx) = mpsc::channel(Self::CONNECTION_CHANNEL_SIZE); + + let connection_wrapper = ConnectionWrapper { + protocol_version, + connection, + client_rx, + daemon_tx, + }; + + tokio::spawn(async move { + match connection_wrapper.start().await { + Ok(()) | Err(ConnectionWrapperError::ChannelClosed) => {} + Err(error) => tracing::error!(%error, "Operator connection failed"), + } + }); + + (client_tx, daemon_rx) + } + + async fn handle_client_message( + &mut self, + client_message: ClientMessage, + ) -> Result<(), ConnectionWrapperError> { + let payload = bincode::encode_to_vec(client_message, bincode::config::standard())?; + + self.connection.send(payload.into()).await?; + + Ok(()) + } + + async fn handle_daemon_message( + &mut self, + daemon_message: Result, + ) -> Result<(), ConnectionWrapperError> { + match daemon_message? { + Message::Binary(payload) => { + let (daemon_message, _) = bincode::decode_from_slice::( + &payload, + bincode::config::standard(), + )?; + + self.daemon_tx + .send(daemon_message) + .await + .map_err(|_| ConnectionWrapperError::ChannelClosed) + } + message => Err(ConnectionWrapperError::InvalidMessage(message)), + } + } + + async fn start(mut self) -> Result<(), ConnectionWrapperError> { + loop { + tokio::select! 
{ + client_message = self.client_rx.recv() => { + match client_message { + Some(ClientMessage::SwitchProtocolVersion(version)) => { + if let Some(operator_protocol_version) = self.protocol_version.as_ref() { + self.handle_client_message(ClientMessage::SwitchProtocolVersion(operator_protocol_version.min(&version).clone())).await?; + } else { + self.daemon_tx + .send(DaemonMessage::SwitchProtocolVersionResponse( + "1.2.1".parse().expect("Bad static version"), + )) + .await + .map_err(|_| ConnectionWrapperError::ChannelClosed)?; + } + } + Some(client_message) => self.handle_client_message(client_message).await?, + None => break, + } + } + + daemon_message = self.connection.next() => match daemon_message { + Some(daemon_message) => self.handle_daemon_message(daemon_message).await?, + None => break, + }, + } + } + + let _ = self.connection.send(Message::Close(None)).await; + + Ok(()) + } +} diff --git a/mirrord/operator/src/client/discovery.rs b/mirrord/operator/src/client/discovery.rs new file mode 100644 index 00000000000..374d93ac984 --- /dev/null +++ b/mirrord/operator/src/client/discovery.rs @@ -0,0 +1,18 @@ +use kube::{api::GroupVersionKind, discovery, Client, Resource}; + +use crate::crd::MirrordOperatorCrd; + +#[tracing::instrument(level = "trace", skip_all, ret, err)] +pub async fn operator_installed(client: &Client) -> kube::Result { + let gvk = GroupVersionKind { + group: MirrordOperatorCrd::group(&()).into_owned(), + version: MirrordOperatorCrd::version(&()).into_owned(), + kind: MirrordOperatorCrd::kind(&()).into_owned(), + }; + + match discovery::oneshot::pinned_kind(client, &gvk).await { + Ok(..) 
=> Ok(true), + Err(kube::Error::Api(response)) if response.code == 404 => Ok(false), + Err(error) => Err(error), + } +} diff --git a/mirrord/operator/src/client/error.rs b/mirrord/operator/src/client/error.rs new file mode 100644 index 00000000000..83e46eeebaf --- /dev/null +++ b/mirrord/operator/src/client/error.rs @@ -0,0 +1,73 @@ +use std::fmt; + +pub use http::Error as HttpError; +use mirrord_kube::error::KubeApiError; +use thiserror::Error; + +use crate::crd::kube_target::UnknownTargetType; + +/// Operations performed on the operator via [`kube`] API. +#[derive(Debug)] +pub enum OperatorOperation { + FindingOperator, + FindingTarget, + WebsocketConnection, + CopyingTarget, + GettingStatus, + SessionManagement, + ListingTargets, +} + +impl fmt::Display for OperatorOperation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let as_str = match self { + Self::FindingOperator => "finding operator", + Self::FindingTarget => "finding target", + Self::WebsocketConnection => "creating a websocket connection", + Self::CopyingTarget => "copying target", + Self::GettingStatus => "getting status", + Self::SessionManagement => "session management", + Self::ListingTargets => "listing targets", + }; + + f.write_str(as_str) + } +} + +#[derive(Debug, Error)] +pub enum OperatorApiError { + #[error("failed to build a websocket connect request: {0}")] + ConnectRequestBuildError(HttpError), + + #[error("failed to create Kubernetes client: {0}")] + CreateKubeClient(KubeApiError), + + #[error("{operation} failed: {error}")] + KubeError { + error: kube::Error, + operation: OperatorOperation, + }, + + #[error("mirrord operator {operator_version} does not support feature {feature}")] + UnsupportedFeature { + feature: String, + operator_version: String, + }, + + #[error("{operation} failed with code {}: {}", status.code, status.reason)] + StatusFailure { + operation: OperatorOperation, + status: Box, + }, + + #[error("mirrord operator license expired")] + NoLicense, + + 
#[error("failed to prepare client certificate: {0}")] + ClientCertError(String), + + #[error("mirrord operator returned a target of unknown type: {}", .0 .0)] + FetchedUnknownTargetType(#[from] UnknownTargetType), +} + +pub type OperatorApiResult = Result; diff --git a/mirrord/operator/src/client/upgrade.rs b/mirrord/operator/src/client/upgrade.rs new file mode 100644 index 00000000000..b75d71f12c2 --- /dev/null +++ b/mirrord/operator/src/client/upgrade.rs @@ -0,0 +1,150 @@ +//! Code copied from [`kube::client`] and adjusted. +//! +//! Just like original [`Client::connect`] function, [`connect_ws`] creates a +//! WebSockets connection. However, original function swallows +//! [`ErrorResponse`] sent by the operator and returns flat +//! [`UpgradeConnectionError`]. [`connect_ws`] attempts to +//! recover the [`ErrorResponse`] - if operator response code is not +//! [`StatusCode::SWITCHING_PROTOCOLS`], it tries to read +//! response body and deserialize it. + +use base64::Engine; +use http::{HeaderValue, Request, Response, StatusCode}; +use http_body_util::BodyExt; +use hyper_util::rt::TokioIo; +use kube::{ + client::{Body, UpgradeConnectionError}, + core::ErrorResponse, + Client, Error, Result, +}; +use tokio_tungstenite::{tungstenite::protocol::Role, WebSocketStream}; + +const WS_PROTOCOL: &str = "v4.channel.k8s.io"; + +// Verify upgrade response according to RFC6455. +// Based on `tungstenite` and added subprotocol verification. 
+async fn verify_response(res: Response, key: &HeaderValue) -> Result> { + let status = res.status(); + + if status != StatusCode::SWITCHING_PROTOCOLS { + if status.is_client_error() || status.is_server_error() { + let error_response = res + .into_body() + .collect() + .await + .ok() + .map(|body| body.to_bytes()) + .and_then(|body_bytes| serde_json::from_slice::(&body_bytes).ok()); + + if let Some(error_response) = error_response { + return Err(Error::Api(error_response)); + } + } + + return Err(Error::UpgradeConnection( + UpgradeConnectionError::ProtocolSwitch(status), + )); + } + + let headers = res.headers(); + if !headers + .get(http::header::UPGRADE) + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("websocket")) + .unwrap_or(false) + { + return Err(Error::UpgradeConnection( + UpgradeConnectionError::MissingUpgradeWebSocketHeader, + )); + } + + if !headers + .get(http::header::CONNECTION) + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("Upgrade")) + .unwrap_or(false) + { + return Err(Error::UpgradeConnection( + UpgradeConnectionError::MissingConnectionUpgradeHeader, + )); + } + + let accept_key = tokio_tungstenite::tungstenite::handshake::derive_accept_key(key.as_ref()); + if !headers + .get(http::header::SEC_WEBSOCKET_ACCEPT) + .map(|h| h == &accept_key) + .unwrap_or(false) + { + return Err(Error::UpgradeConnection( + UpgradeConnectionError::SecWebSocketAcceptKeyMismatch, + )); + } + + // Make sure that the server returned the correct subprotocol. + if !headers + .get(http::header::SEC_WEBSOCKET_PROTOCOL) + .map(|h| h == WS_PROTOCOL) + .unwrap_or(false) + { + return Err(Error::UpgradeConnection( + UpgradeConnectionError::SecWebSocketProtocolMismatch, + )); + } + + Ok(res) +} + +/// Generate a random key for the `Sec-WebSocket-Key` header. +/// This must be nonce consisting of a randomly selected 16-byte value in base64. 
+fn sec_websocket_key() -> HeaderValue { + let random: [u8; 16] = rand::random(); + base64::engine::general_purpose::STANDARD + .encode(random) + .parse() + .expect("should be valid") +} + +pub async fn connect_ws( + client: &Client, + request: Request>, +) -> kube::Result>> { + let (mut parts, body) = request.into_parts(); + parts.headers.insert( + http::header::CONNECTION, + HeaderValue::from_static("Upgrade"), + ); + parts + .headers + .insert(http::header::UPGRADE, HeaderValue::from_static("websocket")); + parts.headers.insert( + http::header::SEC_WEBSOCKET_VERSION, + HeaderValue::from_static("13"), + ); + let key = sec_websocket_key(); + parts + .headers + .insert(http::header::SEC_WEBSOCKET_KEY, key.clone()); + // Use the binary subprotocol v4, to get JSON `Status` object in `error` channel (3). + // There's no official documentation about this protocol, but it's described in + // [`k8s.io/apiserver/pkg/util/wsstream/conn.go`](https://git.io/JLQED). + // There's a comment about v4 and `Status` object in + // [`kublet/cri/streaming/remotecommand/httpstream.go`](https://git.io/JLQEh). 
+ parts.headers.insert( + http::header::SEC_WEBSOCKET_PROTOCOL, + HeaderValue::from_static(WS_PROTOCOL), + ); + + let res = client + .send(Request::from_parts(parts, Body::from(body))) + .await?; + let res = verify_response(res, &key).await?; + match hyper::upgrade::on(res).await { + Ok(upgraded) => { + Ok(WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Client, None).await) + } + + Err(e) => Err(Error::UpgradeConnection( + UpgradeConnectionError::GetPendingUpgrade(e), + )), + } +} diff --git a/mirrord/operator/src/crd.rs b/mirrord/operator/src/crd.rs index 86984ae2b8a..b61fb5a0e38 100644 --- a/mirrord/operator/src/crd.rs +++ b/mirrord/operator/src/crd.rs @@ -1,4 +1,5 @@ use kube::CustomResource; +use kube_target::{KubeTarget, UnknownTargetType}; use mirrord_config::target::{Target, TargetConfig}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -6,6 +7,7 @@ use serde::{Deserialize, Serialize}; use self::label_selector::LabelSelector; use crate::types::LicenseInfoOwned; +pub mod kube_target; pub mod label_selector; pub const TARGETLESS_TARGET_NAME: &str = "targetless"; @@ -19,16 +21,20 @@ pub const TARGETLESS_TARGET_NAME: &str = "targetless"; namespaced )] pub struct TargetSpec { - /// None when targetless. - pub target: Option, + /// The kubernetes resource to target. + pub target: KubeTarget, } impl TargetCrd { - /// Creates target name in format of target_type.target_name.[container.container_name] + /// Creates a target name in format of `target_type.target_name.[container.container_name]` /// for example: - /// deploy.nginx - /// deploy.nginx.container.nginx - pub fn target_name(target: &Target) -> String { + /// + /// - `DeploymentTarget { deployment: "nginx", container: None }` -> `deploy.nginx`; + /// - `DeploymentTarget { deployment: "nginx", container: Some("pyrex") }` -> + /// `deploy.nginx.container.pyrex`; + /// + /// It's used to connect to a resource through the operator. 
+ pub fn urlfied_name(target: &Target) -> String { let (type_name, target, container) = match target { Target::Deployment(target) => ("deploy", &target.deployment, &target.container), Target::Pod(target) => ("pod", &target.pod, &target.container), @@ -38,6 +44,7 @@ impl TargetCrd { Target::StatefulSet(target) => ("statefulset", &target.stateful_set, &target.container), Target::Targetless => return TARGETLESS_TARGET_NAME.to_string(), }; + if let Some(container) = container { format!("{}.{}.container.{}", type_name, target, container) } else { @@ -51,33 +58,21 @@ impl TargetCrd { target_config .path .as_ref() - .map_or_else(|| TARGETLESS_TARGET_NAME.to_string(), Self::target_name) - } - - pub fn name(&self) -> String { - self.spec - .target - .as_ref() - .map(Self::target_name) - .unwrap_or(TARGETLESS_TARGET_NAME.to_string()) + .map_or_else(|| TARGETLESS_TARGET_NAME.to_string(), Self::urlfied_name) } } -impl From for TargetConfig { - fn from(crd: TargetCrd) -> Self { - TargetConfig { - path: crd.spec.target, +impl TryFrom for TargetConfig { + type Error = UnknownTargetType; + + fn try_from(crd: TargetCrd) -> Result { + Ok(TargetConfig { + path: Some(Target::try_from(crd.spec.target)?), namespace: crd.metadata.namespace, - } + }) } } -#[derive(Clone, Debug, Default, Deserialize, Serialize, JsonSchema)] -pub struct TargetPortLock { - pub target_hash: String, - pub port: u16, -} - pub static OPERATOR_STATUS_NAME: &str = "operator"; #[derive(CustomResource, Clone, Debug, Deserialize, Serialize, JsonSchema)] diff --git a/mirrord/operator/src/crd/kube_target.rs b/mirrord/operator/src/crd/kube_target.rs new file mode 100644 index 00000000000..8d9c88b841c --- /dev/null +++ b/mirrord/operator/src/crd/kube_target.rs @@ -0,0 +1,168 @@ +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use super::Target; + +#[derive(Error, Debug)] +#[error("unknown target type: {0}")] +pub struct UnknownTargetType(pub String); + +/// Holds either a 
kubernetes target that we know about, (de)serializing it into a +/// [`Target`], or a target we do not know about. +/// +/// Its main purpose is to provide forward compatibility with targets between the operator +/// and mirrord, so when we add new targets over there, they can be reported back through +/// `mirrord ls` (or other ways of listing targets). +/// +/// You should avoid passing this type around, instead try to get the `Known` variant +/// out, and potentially throw an error if it's an `Unknown` target. If you feel compelled +/// to write methods for this type, think again, you probably don't want to do that. +/// +/// ## Why not an `Option` +/// +/// Due to how we used to treat a `None` `Option` as meaning [`Target::Targetless`], +/// we can't just change it to `None` meaning _unknown_, so this type is basically acting +/// as a custom `Option` for this purpose. +/// +/// ## `serde` implementation +/// +/// [`Deserialize`] is _manually-ish_ implemented to handle the `Unknown` variant. +/// +/// [`Deserialize`] happens in two steps: +/// 1. deserialize the type as a [`serde_json::Value`], where an error here means +/// an actual deserialization issue; +/// 2. convert the [`serde_json::Value`] into a [`Target`], turning an error into +/// [`KubeTarget::Unknown`]. +#[derive(Serialize, Clone, Debug, JsonSchema)] +#[serde(untagged)] +pub enum KubeTarget { +    /// A target that we know of in both mirrord and the operator. +    #[serde(serialize_with = "Target::serialize")] +    Known(Target), + +    /// A target that has been added in the operator, but the current version of mirrord +    /// doesn't know about. +    /// +    /// Should be ignored in most cases. 
+ #[serde(skip_serializing)] + Unknown(String), +} + +impl KubeTarget { + pub fn as_known(&self) -> Result<&Target, UnknownTargetType> { + match self { + KubeTarget::Known(target) => Ok(target), + KubeTarget::Unknown(unknown) => Err(UnknownTargetType(unknown.clone())), + } + } +} + +impl TryFrom for Target { + type Error = UnknownTargetType; + + fn try_from(kube_target: KubeTarget) -> Result { + match kube_target { + KubeTarget::Known(target) => Ok(target), + KubeTarget::Unknown(unknown) => Err(UnknownTargetType(unknown)), + } + } +} + +impl From for KubeTarget { + fn from(target: Target) -> Self { + Self::Known(target) + } +} + +impl core::fmt::Display for KubeTarget { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + KubeTarget::Known(target) => target.fmt(f), + KubeTarget::Unknown(unknown) => write!(f, "{}", unknown), + } + } +} + +impl<'de> Deserialize<'de> for KubeTarget { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let deserialized = serde_json::Value::deserialize(deserializer)?; + let maybe_unknown = deserialized.to_string(); + + let target = serde_json::from_value::(deserialized); + match target { + Ok(target) => Ok(KubeTarget::Known(target)), + Err(_) => Ok(KubeTarget::Unknown(maybe_unknown)), + } + } +} + +#[cfg(test)] +mod tests { + use kube::CustomResource; + use mirrord_config::target::Target; + use schemars::JsonSchema; + use serde::{Deserialize, Serialize}; + + use crate::crd::{kube_target::KubeTarget, TargetSpec}; + + #[derive(CustomResource, Clone, Debug, Deserialize, Serialize, JsonSchema)] + #[kube( + group = "operator.metalbear.co", + version = "v1", + kind = "Target", + root = "LegacyTargetCrd", + namespaced + )] + struct LegacyTargetSpec { + target: Option, + } + + #[test] + fn none_into_kube_target() { + let legacy = serde_json::to_string_pretty(&LegacyTargetSpec { target: None }).unwrap(); + serde_json::from_str::(&legacy).expect("Deserialization from old to 
new!"); + } + + #[test] + fn some_into_kube_target() { + let legacy = serde_json::to_string_pretty(&LegacyTargetSpec { + target: Some(Target::Targetless), + }) + .unwrap(); + serde_json::from_str::(&legacy).expect("Deserialization from old to new!"); + } + + #[test] + fn kube_target_unknown() { + let new = serde_json::from_str::(&r#"{"target": "Bolesław the Great"}"#) + .expect("Deserialization of unknown!"); + + assert!(matches!( + new, + TargetSpec { + target: KubeTarget::Unknown(_) + } + )) + } + + #[test] + fn kube_target_to_legacy() { + let new = serde_json::to_string_pretty(&TargetSpec { + target: KubeTarget::Known(Target::Targetless), + }) + .unwrap(); + + serde_json::from_str::(&new).expect("Deserialization from new to old!"); + } + + #[test] + #[should_panic] + fn bonkers_kube_target_fails() { + serde_json::from_str::(&r#"{"king": "Sigismund II"}"#) + .expect("Kings are not deserializible!"); + } +} diff --git a/mirrord/operator/src/lib.rs b/mirrord/operator/src/lib.rs index 6029966697e..4832aa50496 100644 --- a/mirrord/operator/src/lib.rs +++ b/mirrord/operator/src/lib.rs @@ -1,5 +1,6 @@ #![feature(let_chains)] #![feature(lazy_cell)] +#![feature(try_blocks)] #![warn(clippy::indexing_slicing)] #[cfg(feature = "client")] diff --git a/mirrord/operator/src/setup.rs b/mirrord/operator/src/setup.rs index ba8940844ae..27ada24a593 100644 --- a/mirrord/operator/src/setup.rs +++ b/mirrord/operator/src/setup.rs @@ -8,7 +8,9 @@ use k8s_openapi::{ Probe, ResourceRequirements, Secret, SecretVolumeSource, SecurityContext, Service, ServiceAccount, ServicePort, ServiceSpec, Volume, VolumeMount, }, - rbac::v1::{ClusterRole, ClusterRoleBinding, PolicyRule, RoleRef, Subject}, + rbac::v1::{ + ClusterRole, ClusterRoleBinding, PolicyRule, Role, RoleBinding, RoleRef, Subject, + }, }, apiextensions_apiserver::pkg::apis::apiextensions::v1::CustomResourceDefinition, apimachinery::pkg::{ @@ -29,6 +31,7 @@ static OPERATOR_NAME: &str = "mirrord-operator"; static OPERATOR_PORT: i32 = 
3000; static OPERATOR_ROLE_NAME: &str = "mirrord-operator"; static OPERATOR_ROLE_BINDING_NAME: &str = "mirrord-operator"; +static OPERATOR_CLIENT_CA_ROLE_NAME: &str = "mirrord-operator-apiserver-authentication"; static OPERATOR_CLUSTER_USER_ROLE_NAME: &str = "mirrord-operator-user"; static OPERATOR_LICENSE_SECRET_NAME: &str = "mirrord-operator-license"; static OPERATOR_LICENSE_SECRET_FILE_NAME: &str = "license.pem"; @@ -95,6 +98,8 @@ pub struct Operator { service: OperatorService, service_account: OperatorServiceAccount, user_cluster_role: OperatorClusterUserRole, + client_ca_role: OperatorClientCaRole, + client_ca_role_binding: OperatorClientCaRoleBinding, } impl Operator { @@ -118,6 +123,10 @@ impl Operator { let role_binding = OperatorRoleBinding::new(&role, &service_account); let user_cluster_role = OperatorClusterUserRole::new(); + let client_ca_role = OperatorClientCaRole::new(); + let client_ca_role_binding = + OperatorClientCaRoleBinding::new(&client_ca_role, &service_account); + let deployment = OperatorDeployment::new( &namespace, &service_account, @@ -140,6 +149,8 @@ impl Operator { service, service_account, user_cluster_role, + client_ca_role, + client_ca_role_binding, } } } @@ -162,9 +173,15 @@ impl OperatorSetup for Operator { writer.write_all(b"---\n")?; self.user_cluster_role.to_writer(&mut writer)?; + writer.write_all(b"---\n")?; + self.client_ca_role.to_writer(&mut writer)?; + writer.write_all(b"---\n")?; self.role_binding.to_writer(&mut writer)?; + writer.write_all(b"---\n")?; + self.client_ca_role_binding.to_writer(&mut writer)?; + writer.write_all(b"---\n")?; self.deployment.to_writer(&mut writer)?; @@ -653,6 +670,63 @@ impl Default for OperatorClusterUserRole { } } +#[derive(Debug)] +pub struct OperatorClientCaRole(Role); + +impl OperatorClientCaRole { + pub fn new() -> Self { + let role = Role { + metadata: ObjectMeta { + name: Some(OPERATOR_CLIENT_CA_ROLE_NAME.to_owned()), + namespace: Some("kube-system".to_owned()), + ..Default::default() + 
}, + rules: Some(vec![PolicyRule { + api_groups: Some(vec!["".to_owned()]), + resources: Some(vec!["configmaps".to_owned()]), + verbs: vec!["get".to_owned()], + resource_names: Some(vec!["extension-apiserver-authentication".to_owned()]), + ..Default::default() + }]), + }; + + OperatorClientCaRole(role) + } + + fn as_role_ref(&self) -> RoleRef { + RoleRef { + api_group: "rbac.authorization.k8s.io".to_owned(), + kind: "Role".to_owned(), + name: self.0.metadata.name.clone().unwrap_or_default(), + } + } +} + +impl Default for OperatorClientCaRole { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug)] +pub struct OperatorClientCaRoleBinding(RoleBinding); + +impl OperatorClientCaRoleBinding { + pub fn new(role: &OperatorClientCaRole, sa: &OperatorServiceAccount) -> Self { + let role = RoleBinding { + metadata: ObjectMeta { + name: Some(OPERATOR_CLIENT_CA_ROLE_NAME.to_owned()), + namespace: role.0.metadata.namespace.clone(), + ..Default::default() + }, + role_ref: role.as_role_ref(), + subjects: Some(vec![sa.as_subject()]), + }; + + OperatorClientCaRoleBinding(role) + } +} + impl OperatorSetup for CustomResourceDefinition { fn to_writer(&self, writer: W) -> Result<()> { serde_yaml::to_writer(writer, &self).map_err(SetupWriteError::from) @@ -668,5 +742,7 @@ writer_impl![ OperatorLicenseSecret, OperatorService, OperatorApiService, - OperatorClusterUserRole + OperatorClusterUserRole, + OperatorClientCaRole, + OperatorClientCaRoleBinding ]; diff --git a/mirrord/operator/src/types.rs b/mirrord/operator/src/types.rs index c21c1ab8282..78107221307 100644 --- a/mirrord/operator/src/types.rs +++ b/mirrord/operator/src/types.rs @@ -12,3 +12,26 @@ pub struct LicenseInfoOwned { /// Subscription id encoded in the operator license extension. pub subscription_id: Option, } + +/// Name of HTTP header containing CLI version. +/// Sent with each request to the mirrord operator. 
+pub const MIRRORD_CLI_VERSION_HEADER: &str = "x-mirrord-cli-version"; + +/// Name of HTTP header containing client certificate. +/// Sent with each request to the mirrord operator (if available) except: +/// 1. Initial GET on the operator resource +/// 2. User certificate request +/// Required for making the target connection request. +pub const CLIENT_CERT_HEADER: &str = "x-client-der"; + +/// Name of HTTP header containing client hostname. +/// Sent with each request to the mirrord operator (if available). +pub const CLIENT_HOSTNAME_HEADER: &str = "x-client-hostname"; + +/// Name of HTTP header containing client name. +/// Sent with each request to the mirrord operator (if available). +pub const CLIENT_NAME_HEADER: &str = "x-client-name"; + +/// Name of HTTP header containing operator session id. +/// Sent with target connection request. +pub const SESSION_ID_HEADER: &str = "x-session-id"; diff --git a/mirrord/protocol/Cargo.toml b/mirrord/protocol/Cargo.toml index 43c44b00006..1a1482277e2 100644 --- a/mirrord/protocol/Cargo.toml +++ b/mirrord/protocol/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mirrord-protocol" -version = "1.7.0" +version = "1.8.1" authors.workspace = true description.workspace = true documentation.workspace = true diff --git a/mirrord/agent/src/steal/http/body_chunks.rs b/mirrord/protocol/src/body_chunks.rs similarity index 77% rename from mirrord/agent/src/steal/http/body_chunks.rs rename to mirrord/protocol/src/body_chunks.rs index 83dfe3a1758..e9e2a6cc073 100644 --- a/mirrord/agent/src/steal/http/body_chunks.rs +++ b/mirrord/protocol/src/body_chunks.rs @@ -5,14 +5,17 @@ use std::{ }; use bytes::Bytes; -use hyper::body::{Body, Frame, Incoming}; +use hyper::body::{Body, Frame}; -pub trait IncomingExt { - fn next_frames(&mut self, no_wait: bool) -> FramesFut<'_>; +pub trait BodyExt { + fn next_frames(&mut self, no_wait: bool) -> FramesFut<'_, B>; } -impl IncomingExt for Incoming { - fn next_frames(&mut self, no_wait: bool) -> FramesFut<'_> { 
+impl BodyExt for B +where + B: Body, +{ + fn next_frames(&mut self, no_wait: bool) -> FramesFut<'_, B> { FramesFut { body: self, no_wait, @@ -20,12 +23,15 @@ impl IncomingExt for Incoming { } } -pub struct FramesFut<'a> { - body: &'a mut Incoming, +pub struct FramesFut<'a, B> { + body: &'a mut B, no_wait: bool, } -impl<'a> Future for FramesFut<'a> { +impl<'a, B> Future for FramesFut<'a, B> +where + B: Body + Unpin, +{ type Output = hyper::Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { diff --git a/mirrord/protocol/src/lib.rs b/mirrord/protocol/src/lib.rs index c01e88ad2bb..afe503d5d2b 100644 --- a/mirrord/protocol/src/lib.rs +++ b/mirrord/protocol/src/lib.rs @@ -3,6 +3,7 @@ #![feature(lazy_cell)] #![warn(clippy::indexing_slicing)] +pub mod body_chunks; pub mod codec; pub mod dns; pub mod error; diff --git a/mirrord/protocol/src/tcp.rs b/mirrord/protocol/src/tcp.rs index 170d3d2dbf2..351e285e2bc 100644 --- a/mirrord/protocol/src/tcp.rs +++ b/mirrord/protocol/src/tcp.rs @@ -11,7 +11,7 @@ use std::{ use bincode::{Decode, Encode}; use bytes::Bytes; -use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody}; use hyper::{ body::{Body, Frame, Incoming}, http, @@ -22,9 +22,10 @@ use mirrord_macros::protocol_break; use semver::VersionReq; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc::Receiver; +use tokio_stream::wrappers::ReceiverStream; use tracing::error; -use crate::{ConnectionId, Port, RemoteResult, RequestId}; +use crate::{body_chunks::BodyExt as _, ConnectionId, Port, RemoteResult, RequestId}; #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub struct NewTcpConnection { @@ -81,13 +82,13 @@ pub enum DaemonTcp { #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub enum ChunkedRequest { Start(HttpRequest>), - Body(ChunkedRequestBody), - Error(ChunkedRequestError), + Body(ChunkedHttpBody), + Error(ChunkedHttpError), } /// Contents of a chunked 
message body frame from server. #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] -pub struct ChunkedRequestBody { +pub struct ChunkedHttpBody { #[bincode(with_serde)] pub frames: Vec, pub is_last: bool, @@ -106,7 +107,7 @@ impl From for Frame { /// An error occurred while processing chunked data from server. #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] -pub struct ChunkedRequestError { +pub struct ChunkedHttpError { pub connection_id: ConnectionId, pub request_id: RequestId, } @@ -191,6 +192,14 @@ pub enum LayerTcpSteal { Data(TcpData), HttpResponse(HttpResponse>), HttpResponseFramed(HttpResponse), + HttpResponseChunked(ChunkedResponse), +} + +#[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] +pub enum ChunkedResponse { + Start(HttpResponse>), + Body(ChunkedHttpBody), + Error(ChunkedHttpError), } /// (De-)Serializable HTTP request. @@ -388,18 +397,22 @@ impl HttpRequestFallback { } } -/// Minimal mirrord-protocol version that allows [`DaemonTcp::HttpRequestFramed`] instead of -/// [`DaemonTcp::HttpRequest`]. +/// Minimal mirrord-protocol version that allows [`DaemonTcp::HttpRequestFramed`] and +/// [`LayerTcpSteal::HttpResponseFramed`]. pub static HTTP_FRAMED_VERSION: LazyLock = LazyLock::new(|| ">=1.3.0".parse().expect("Bad Identifier")); -/// Minimal mirrord-protocol version that allows [`DaemonTcp::HttpRequestChunked`] instead of -/// [`DaemonTcp::HttpRequest`]. -pub static HTTP_CHUNKED_VERSION: LazyLock = +/// Minimal mirrord-protocol version that allows [`DaemonTcp::HttpRequestChunked`]. +pub static HTTP_CHUNKED_REQUEST_VERSION: LazyLock = LazyLock::new(|| ">=1.7.0".parse().expect("Bad Identifier")); +/// Minimal mirrord-protocol version that allows [`LayerTcpSteal::HttpResponseChunked`]. 
+pub static HTTP_CHUNKED_RESPONSE_VERSION: LazyLock = + LazyLock::new(|| ">=1.8.1".parse().expect("Bad Identifier")); + /// Minimal mirrord-protocol version that allows [`DaemonTcp::Data`] to be sent in the same -/// connection as [`DaemonTcp::HttpRequestFramed`] and [`DaemonTcp::HttpRequest`]. +/// connection as +/// [`DaemonTcp::HttpRequestChunked`]/[`DaemonTcp::HttpRequestFramed`]/[`DaemonTcp::HttpRequest`]. pub static HTTP_FILTERED_UPGRADE_VERSION: LazyLock = LazyLock::new(|| ">=1.5.0".parse().expect("Bad Identifier")); @@ -429,15 +442,15 @@ impl HttpRequest { #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)] pub struct InternalHttpResponse { #[serde(with = "http_serde::status_code")] - status: StatusCode, + pub status: StatusCode, #[serde(with = "http_serde::version")] - version: Version, + pub version: Version, #[serde(with = "http_serde::header_map")] - headers: HeaderMap, + pub headers: HeaderMap, - body: Body, + pub body: Body, } impl InternalHttpResponse { @@ -536,10 +549,13 @@ impl fmt::Debug for InternalHttpBodyFrame { } } +pub type ReceiverStreamBody = StreamBody>>>; + #[derive(Debug)] pub enum HttpResponseFallback { Framed(HttpResponse), Fallback(HttpResponse>), + Streamed(HttpResponse), } impl HttpResponseFallback { @@ -547,6 +563,7 @@ impl HttpResponseFallback { match self { HttpResponseFallback::Framed(req) => req.connection_id, HttpResponseFallback::Fallback(req) => req.connection_id, + HttpResponseFallback::Streamed(req) => req.connection_id, } } @@ -554,28 +571,70 @@ impl HttpResponseFallback { match self { HttpResponseFallback::Framed(req) => req.request_id, HttpResponseFallback::Fallback(req) => req.request_id, + HttpResponseFallback::Streamed(req) => req.request_id, } } - pub fn into_hyper(self) -> Result>, http::Error> { + pub fn into_hyper(self) -> Result>, http::Error> + where + E: From, + { match self { HttpResponseFallback::Framed(req) => req.internal_response.try_into(), HttpResponseFallback::Fallback(req) => 
req.internal_response.try_into(), + HttpResponseFallback::Streamed(req) => req.internal_response.try_into(), } } + /// Produces an [`HttpResponseFallback`] to the given [`HttpRequestFallback`]. + /// + /// # Note on picking response variant + /// + /// Variant of returned [`HttpResponseFallback`] is picked based on the variant of given + /// [`HttpRequestFallback`] and agent protocol version. We need to consider both due + /// to: + /// 1. Old agent versions always responding with client's `mirrord_protocol` version to + /// [`ClientMessage::SwitchProtocolVersion`](super::ClientMessage::SwitchProtocolVersion), + /// 2. [`LayerTcpSteal::HttpResponseChunked`] being introduced after + /// [`DaemonTcp::HttpRequestChunked`]. pub fn response_from_request( request: HttpRequestFallback, status: StatusCode, message: &str, + agent_protocol_version: Option<&semver::Version>, ) -> Self { + let agent_supports_streaming_response = agent_protocol_version + .map(|version| HTTP_CHUNKED_RESPONSE_VERSION.matches(version)) + .unwrap_or(false); + match request { + // We received `DaemonTcp::HttpRequestFramed` from the agent, + // so we know it supports `LayerTcpSteal::HttpResponseFramed` (both were introduced in + // the same `mirrord_protocol` version). HttpRequestFallback::Framed(request) => HttpResponseFallback::Framed( HttpResponse::::response_from_request(request, status, message), ), + + // We received `DaemonTcp::HttpRequest` from the agent, so we assume it only supports + // `LayerTcpSteal::HttpResponse`. HttpRequestFallback::Fallback(request) => HttpResponseFallback::Fallback( HttpResponse::>::response_from_request(request, status, message), ), + + // We received `DaemonTcp::HttpRequestChunked` and the agent supports + // `LayerTcpSteal::HttpResponseChunked`. 
+ HttpRequestFallback::Streamed(request) if agent_supports_streaming_response => { + HttpResponseFallback::Streamed( + HttpResponse::::response_from_request( + request, status, message, + ), + ) + } + + // We received `DaemonTcp::HttpRequestChunked` from the agent, + // but the agent does not support `LayerTcpSteal::HttpResponseChunked`. + // However, it must support the older `LayerTcpSteal::HttpResponseFramed` + // variant (was introduced before `DaemonTcp::HttpRequestChunked`). HttpRequestFallback::Streamed(request) => HttpResponseFallback::Framed( HttpResponse::::response_from_request(request, status, message), ), @@ -585,10 +644,7 @@ impl HttpResponseFallback { #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] #[bincode(bounds = "for<'de> Body: Serialize + Deserialize<'de>")] -pub struct HttpResponse -where - for<'de> Body: Serialize + Deserialize<'de>, -{ +pub struct HttpResponse { /// This is used to make sure the response is sent in its turn, after responses to all earlier /// requests were already sent. pub port: Port, @@ -789,6 +845,88 @@ impl HttpResponse> { } } +impl HttpResponse { + pub async fn from_hyper_response( + response: Response, + port: Port, + connection_id: ConnectionId, + request_id: RequestId, + ) -> Result, hyper::Error> { + let ( + Parts { + status, + version, + headers, + .. 
+ }, + mut body, + ) = response.into_parts(); + + let frames = body.next_frames(true).await?; + let (tx, rx) = tokio::sync::mpsc::channel(frames.frames.len().max(12)); + for frame in frames.frames { + tx.try_send(Ok(frame)) + .expect("Channel is open, capacity sufficient") + } + if !frames.is_last { + tokio::spawn(async move { + while let Some(frame) = body.frame().await { + if tx.send(frame).await.is_err() { + return; + } + } + }); + }; + + let body = StreamBody::new(ReceiverStream::from(rx)); + + let internal_response = InternalHttpResponse { + status, + headers, + version, + body, + }; + + Ok(HttpResponse { + request_id, + port, + connection_id, + internal_response, + }) + } + + pub fn response_from_request( + request: HttpRequest, + status: StatusCode, + message: &str, + ) -> Self { + let HttpRequest { + internal_request: InternalHttpRequest { version, .. }, + connection_id, + request_id, + port, + } = request; + + let (tx, rx) = tokio::sync::mpsc::channel(1); + let frame = Frame::data(Bytes::copy_from_slice(message.as_bytes())); + tx.try_send(Ok(frame)) + .expect("channel is open, capacity is sufficient"); + let body = StreamBody::new(ReceiverStream::new(rx)); + + Self { + port, + connection_id, + request_id, + internal_response: InternalHttpResponse { + status, + version, + headers: Default::default(), + body, + }, + } + } +} + impl TryFrom> for Response> { type Error = http::Error; @@ -830,3 +968,26 @@ impl TryFrom>> for Response> { )) } } + +impl TryFrom> for Response> +where + E: From, +{ + type Error = http::Error; + + fn try_from(value: InternalHttpResponse) -> Result { + let InternalHttpResponse { + status, + version, + headers, + body, + } = value; + + let mut builder = Response::builder().status(status).version(version); + if let Some(h) = builder.headers_mut() { + *h = headers; + } + + builder.body(BoxBody::new(body.map_err(|e| e.into()))) + } +}