diff --git a/CHANGELOG.md b/CHANGELOG.md index bda379d047d..b701ad68919 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,29 @@ This project uses [*towncrier*](https://towncrier.readthedocs.io/) and the chang +## [3.131.0](https://github.com/metalbear-co/mirrord/tree/3.131.0) - 2025-01-27 + + +### Added + +- `statfs` support +- Support for in-cluster DNS resolution of IPv6 addresses. + [#2958](https://github.com/metalbear-co/mirrord/issues/2958) +- Prometheus metrics to the mirrord-agent. + [#2975](https://github.com/metalbear-co/mirrord/issues/2975) +- Kubernetes Service as a new type of mirrord target (requires mirrord + operator). + + +### Fixed + +- Misleading doc for `.target.namespace` config. + [#3009](https://github.com/metalbear-co/mirrord/issues/3009) +- Agent now correctly clears incoming port subscriptions of disconnected + clients. +- mirrord no longer uses the default `{"operator": "Exists"}` tolerations when + spawning targetless agent pods. 
+ ## [3.130.0](https://github.com/metalbear-co/mirrord/tree/3.130.0) - 2025-01-21 diff --git a/Cargo.lock b/Cargo.lock index e0339e18df0..d230c6c84f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -584,9 +584,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-config" -version = "1.5.14" +version = "1.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f40e82e858e02445402906e454a73e244c7f501fcae198977585946c48e8697" +checksum = "dc47e70fc35d054c8fcd296d47a61711f043ac80534a10b4f741904f81e73a90" dependencies = [ "aws-credential-types", "aws-runtime", @@ -676,9 +676,9 @@ dependencies = [ [[package]] name = "aws-sdk-sqs" -version = "1.55.0" +version = "1.56.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2db64ffe78706b344b7c9b620f96c3c0655745e006b87bad20f424562656a0dd" +checksum = "ca6b2f438a99c189b89279ca88ed05c2bcdcfcacd7d78a821b8650166040b9b4" dependencies = [ "aws-credential-types", "aws-runtime", @@ -698,9 +698,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.55.0" +version = "1.56.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33993c0b054f4251ff2946941b56c26b582677303eeca34087594eb901ece022" +checksum = "12e057fdcb8842de9b83592a70f5b4da0ee10bc0ad278247da1425a742a444d7" dependencies = [ "aws-credential-types", "aws-runtime", @@ -720,9 +720,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.56.0" +version = "1.57.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bd3ceba74a584337a8f3839c818f14f1a2288bfd24235120ff22d7e17a0dd54" +checksum = "a120ade4a44691b3c5c2ff2fa61b14ed331fdc218397f61ab48d66593012ae2a" dependencies = [ "aws-credential-types", "aws-runtime", @@ -742,9 +742,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.56.0" +version = "1.57.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"07835598e52dd354368429cb2abf447ce523ea446d0a533a63cb42cd0d2d9280" +checksum = "115fd4fb663817ed595a5ee4f1649d7aacd861d47462323cb37576ce89271b93" dependencies = [ "aws-credential-types", "aws-runtime", @@ -937,6 +937,7 @@ checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", "axum-core", + "axum-macros", "base64 0.22.1", "bytes", "futures-util", @@ -987,6 +988,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-macros" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57d123550fa8d071b7255cb0cc04dc302baa6c8c4a79f55701552684d8399bce" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "backoff" version = "0.4.0" @@ -1657,9 +1669,9 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpufeatures" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ "libc", ] @@ -2401,7 +2413,7 @@ dependencies = [ [[package]] name = "fileops" -version = "3.130.0" +version = "3.131.0" dependencies = [ "libc", ] @@ -3519,7 +3531,7 @@ checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" [[package]] name = "issue1317" -version = "3.130.0" +version = "3.131.0" dependencies = [ "actix-web", "env_logger 0.11.6", @@ -3529,7 +3541,7 @@ dependencies = [ [[package]] name = "issue1776" -version = "3.130.0" +version = "3.131.0" dependencies = [ "errno 0.3.10", "libc", @@ -3538,7 +3550,7 @@ dependencies = [ [[package]] name = "issue1776portnot53" -version = "3.130.0" +version = "3.131.0" dependencies = [ "libc", "socket2", @@ -3546,14 +3558,14 @@ dependencies = [ [[package]] name = "issue1899" -version = "3.130.0" +version = "3.131.0" dependencies = [ 
"libc", ] [[package]] name = "issue2001" -version = "3.130.0" +version = "3.131.0" dependencies = [ "libc", ] @@ -3874,7 +3886,7 @@ checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "listen_ports" -version = "3.130.0" +version = "3.131.0" [[package]] name = "litemap" @@ -4115,7 +4127,7 @@ dependencies = [ [[package]] name = "mirrord" -version = "3.130.0" +version = "3.131.0" dependencies = [ "actix-codec", "clap", @@ -4171,10 +4183,11 @@ dependencies = [ [[package]] name = "mirrord-agent" -version = "3.130.0" +version = "3.131.0" dependencies = [ "actix-codec", "async-trait", + "axum", "bollard", "bytes", "clap", @@ -4199,10 +4212,12 @@ dependencies = [ "nix 0.29.0", "oci-spec", "pnet", - "procfs", + "procfs 0.17.0", + "prometheus", "rand", "rawsocket", "rcgen", + "reqwest 0.12.12", "rstest", "rustls 0.23.21", "semver 1.0.25", @@ -4226,7 +4241,7 @@ dependencies = [ [[package]] name = "mirrord-analytics" -version = "3.130.0" +version = "3.131.0" dependencies = [ "assert-json-diff", "base64 0.22.1", @@ -4240,7 +4255,7 @@ dependencies = [ [[package]] name = "mirrord-auth" -version = "3.130.0" +version = "3.131.0" dependencies = [ "bcder", "chrono", @@ -4261,7 +4276,7 @@ dependencies = [ [[package]] name = "mirrord-config" -version = "3.130.0" +version = "3.131.0" dependencies = [ "base64 0.22.1", "bimap", @@ -4285,7 +4300,7 @@ dependencies = [ [[package]] name = "mirrord-config-derive" -version = "3.130.0" +version = "3.131.0" dependencies = [ "proc-macro2", "proc-macro2-diagnostics", @@ -4295,7 +4310,7 @@ dependencies = [ [[package]] name = "mirrord-console" -version = "3.130.0" +version = "3.131.0" dependencies = [ "bincode", "drain", @@ -4311,7 +4326,7 @@ dependencies = [ [[package]] name = "mirrord-intproxy" -version = "3.130.0" +version = "3.131.0" dependencies = [ "bytes", "exponential-backoff", @@ -4339,7 +4354,7 @@ dependencies = [ [[package]] name = "mirrord-intproxy-protocol" -version = "3.130.0" +version = 
"3.131.0" dependencies = [ "bincode", "mirrord-protocol", @@ -4349,7 +4364,7 @@ dependencies = [ [[package]] name = "mirrord-kube" -version = "3.130.0" +version = "3.131.0" dependencies = [ "actix-codec", "async-stream", @@ -4373,7 +4388,7 @@ dependencies = [ [[package]] name = "mirrord-layer" -version = "3.130.0" +version = "3.131.0" dependencies = [ "actix-codec", "base64 0.22.1", @@ -4416,7 +4431,7 @@ dependencies = [ [[package]] name = "mirrord-layer-macro" -version = "3.130.0" +version = "3.131.0" dependencies = [ "proc-macro2", "quote", @@ -4425,7 +4440,7 @@ dependencies = [ [[package]] name = "mirrord-macros" -version = "3.130.0" +version = "3.131.0" dependencies = [ "proc-macro2", "proc-macro2-diagnostics", @@ -4435,7 +4450,7 @@ dependencies = [ [[package]] name = "mirrord-operator" -version = "3.130.0" +version = "3.131.0" dependencies = [ "base64 0.22.1", "bincode", @@ -4468,7 +4483,7 @@ dependencies = [ [[package]] name = "mirrord-progress" -version = "3.130.0" +version = "3.131.0" dependencies = [ "enum_dispatch", "indicatif", @@ -4478,7 +4493,7 @@ dependencies = [ [[package]] name = "mirrord-protocol" -version = "1.15.1" +version = "1.16.1" dependencies = [ "actix-codec", "bincode", @@ -4502,7 +4517,7 @@ dependencies = [ [[package]] name = "mirrord-sip" -version = "3.130.0" +version = "3.131.0" dependencies = [ "apple-codesign", "object 0.36.7", @@ -4515,7 +4530,7 @@ dependencies = [ [[package]] name = "mirrord-vpn" -version = "3.130.0" +version = "3.131.0" dependencies = [ "futures", "ipnet", @@ -4842,9 +4857,9 @@ checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "openssl-probe" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" [[package]] name = "option-ext" @@ -4863,7 +4878,7 @@ dependencies = [ 
[[package]] name = "outgoing" -version = "3.130.0" +version = "3.131.0" [[package]] name = "outref" @@ -5409,6 +5424,19 @@ dependencies = [ "yansi", ] +[[package]] +name = "procfs" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "731e0d9356b0c25f16f33b5be79b1c57b562f141ebfcdb0ad8ac2c13a24293b4" +dependencies = [ + "bitflags 2.8.0", + "hex", + "lazy_static", + "procfs-core 0.16.0", + "rustix", +] + [[package]] name = "procfs" version = "0.17.0" @@ -5419,10 +5447,20 @@ dependencies = [ "chrono", "flate2", "hex", - "procfs-core", + "procfs-core 0.17.0", "rustix", ] +[[package]] +name = "procfs-core" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d3554923a69f4ce04c4a754260c338f505ce22642d3830e049a399fc2059a29" +dependencies = [ + "bitflags 2.8.0", + "hex", +] + [[package]] name = "procfs-core" version = "0.17.0" @@ -5434,6 +5472,23 @@ dependencies = [ "hex", ] +[[package]] +name = "prometheus" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d33c28a30771f7f96db69893f78b857f7450d7e0237e9c8fc6427a81bae7ed1" +dependencies = [ + "cfg-if", + "fnv", + "lazy_static", + "libc", + "memchr", + "parking_lot", + "procfs 0.16.0", + "protobuf", + "thiserror 1.0.69", +] + [[package]] name = "prost" version = "0.13.4" @@ -5486,6 +5541,12 @@ dependencies = [ "prost", ] +[[package]] +name = "protobuf" +version = "2.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94" + [[package]] name = "quick-error" version = "1.2.3" @@ -5924,14 +5985,14 @@ dependencies = [ [[package]] name = "rust-bypassed-unix-socket" -version = "3.130.0" +version = "3.131.0" dependencies = [ "tokio", ] [[package]] name = "rust-e2e-fileops" -version = "3.130.0" +version = "3.131.0" dependencies = [ "libc", ] @@ -5947,7 +6008,7 @@ dependencies = [ [[package]] name = 
"rust-unix-socket-client" -version = "3.130.0" +version = "3.131.0" dependencies = [ "tokio", ] @@ -6108,9 +6169,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" +checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" dependencies = [ "web-time", ] @@ -7589,9 +7650,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "11cd88e12b17c6494200a9c1b683a04fcac9573ed74cd1b62aeb2727c5592243" [[package]] name = "unicode-linebreak" diff --git a/Cargo.toml b/Cargo.toml index 9f01963dc6a..4288306f5ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ resolver = "2" # latest commits on rustls suppress certificate verification [workspace.package] -version = "3.130.0" +version = "3.131.0" edition = "2021" license = "MIT" readme = "README.md" diff --git a/changelog.d/2958.added.md b/changelog.d/2958.added.md deleted file mode 100644 index af12472d466..00000000000 --- a/changelog.d/2958.added.md +++ /dev/null @@ -1 +0,0 @@ -Support for in-cluster DNS resolution of IPv6 addresses. 
diff --git a/mirrord-schema.json b/mirrord-schema.json index 5beadff3c58..0cebfb20ce9 100644 --- a/mirrord-schema.json +++ b/mirrord-schema.json @@ -1,7 +1,7 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "title": "LayerFileConfig", - "description": "mirrord allows for a high degree of customization when it comes to which features you want to enable, and how they should function.\n\nAll of the configuration fields have a default value, so a minimal configuration would be no configuration at all.\n\nThe configuration supports templating using the [Tera](https://keats.github.io/tera/docs/) template engine. Currently we don't provide additional values to the context, if you have anything you want us to provide please let us know.\n\nTo use a configuration file in the CLI, use the `-f ` flag. Or if using VSCode Extension or JetBrains plugin, simply create a `.mirrord/mirrord.json` file or use the UI.\n\nTo help you get started, here are examples of a basic configuration file, and a complete configuration file containing all fields.\n\n### Basic `config.json` {#root-basic}\n\n```json { \"target\": \"pod/bear-pod\", \"feature\": { \"env\": true, \"fs\": \"read\", \"network\": true } } ```\n\n### Basic `config.json` with templating {#root-basic-templating}\n\n```json { \"target\": \"{{ get_env(name=\"TARGET\", default=\"pod/fallback\") }}\", \"feature\": { \"env\": true, \"fs\": \"read\", \"network\": true } } ```\n\n### Complete `config.json` {#root-complete}\n\nDon't use this example as a starting point, it's just here to show you all the available options. 
```json { \"accept_invalid_certificates\": false, \"skip_processes\": \"ide-debugger\", \"target\": { \"path\": \"pod/bear-pod\", \"namespace\": \"default\" }, \"connect_tcp\": null, \"agent\": { \"log_level\": \"info\", \"json_log\": false, \"labels\": { \"user\": \"meow\" }, \"annotations\": { \"cats.io/inject\": \"enabled\" }, \"namespace\": \"default\", \"image\": \"ghcr.io/metalbear-co/mirrord:latest\", \"image_pull_policy\": \"IfNotPresent\", \"image_pull_secrets\": [ { \"secret-key\": \"secret\" } ], \"ttl\": 30, \"ephemeral\": false, \"communication_timeout\": 30, \"startup_timeout\": 360, \"network_interface\": \"eth0\", \"flush_connections\": true }, \"feature\": { \"env\": { \"include\": \"DATABASE_USER;PUBLIC_ENV\", \"exclude\": \"DATABASE_PASSWORD;SECRET_ENV\", \"override\": { \"DATABASE_CONNECTION\": \"db://localhost:7777/my-db\", \"LOCAL_BEAR\": \"panda\" }, \"mapping\": { \".+_TIMEOUT\": \"1000\" } }, \"fs\": { \"mode\": \"write\", \"read_write\": \".+\\\\.json\" , \"read_only\": [ \".+\\\\.yaml\", \".+important-file\\\\.txt\" ], \"local\": [ \".+\\\\.js\", \".+\\\\.mjs\" ] }, \"network\": { \"incoming\": { \"mode\": \"steal\", \"http_filter\": { \"header_filter\": \"host: api\\\\..+\" }, \"port_mapping\": [[ 7777, 8888 ]], \"ignore_localhost\": false, \"ignore_ports\": [9999, 10000] }, \"outgoing\": { \"tcp\": true, \"udp\": true, \"filter\": { \"local\": [\"tcp://1.1.1.0/24:1337\", \"1.1.5.0/24\", \"google.com\", \":53\"] }, \"ignore_localhost\": false, \"unix_streams\": \"bear.+\" }, \"dns\": { \"enabled\": true, \"filter\": { \"local\": [\"1.1.1.0/24:1337\", \"1.1.5.0/24\", \"google.com\"] } } }, \"copy_target\": { \"scale_down\": false } }, \"operator\": true, \"kubeconfig\": \"~/.kube/config\", \"sip_binaries\": \"bash\", \"telemetry\": true, \"kube_context\": \"my-cluster\" } ```\n\n# Options {#root-options}", + "description": "mirrord allows for a high degree of customization when it comes to which features you want to enable, and how they 
should function.\n\nAll of the configuration fields have a default value, so a minimal configuration would be no configuration at all.\n\nThe configuration supports templating using the [Tera](https://keats.github.io/tera/docs/) template engine. Currently we don't provide additional values to the context, if you have anything you want us to provide please let us know.\n\nTo use a configuration file in the CLI, use the `-f ` flag. Or if using VSCode Extension or JetBrains plugin, simply create a `.mirrord/mirrord.json` file or use the UI.\n\nTo help you get started, here are examples of a basic configuration file, and a complete configuration file containing all fields.\n\n### Basic `config.json` {#root-basic}\n\n```json { \"target\": \"pod/bear-pod\", \"feature\": { \"env\": true, \"fs\": \"read\", \"network\": true } } ```\n\n### Basic `config.json` with templating {#root-basic-templating}\n\n```json { \"target\": \"{{ get_env(name=\"TARGET\", default=\"pod/fallback\") }}\", \"feature\": { \"env\": true, \"fs\": \"read\", \"network\": true } } ```\n\n### Complete `config.json` {#root-complete}\n\nDon't use this example as a starting point, it's just here to show you all the available options. 
```json { \"accept_invalid_certificates\": false, \"skip_processes\": \"ide-debugger\", \"target\": { \"path\": \"pod/bear-pod\", \"namespace\": \"default\" }, \"connect_tcp\": null, \"agent\": { \"log_level\": \"info\", \"json_log\": false, \"labels\": { \"user\": \"meow\" }, \"annotations\": { \"cats.io/inject\": \"enabled\" }, \"namespace\": \"default\", \"image\": \"ghcr.io/metalbear-co/mirrord:latest\", \"image_pull_policy\": \"IfNotPresent\", \"image_pull_secrets\": [ { \"secret-key\": \"secret\" } ], \"ttl\": 30, \"ephemeral\": false, \"communication_timeout\": 30, \"startup_timeout\": 360, \"network_interface\": \"eth0\", \"flush_connections\": true, \"metrics\": \"0.0.0.0:9000\", }, \"feature\": { \"env\": { \"include\": \"DATABASE_USER;PUBLIC_ENV\", \"exclude\": \"DATABASE_PASSWORD;SECRET_ENV\", \"override\": { \"DATABASE_CONNECTION\": \"db://localhost:7777/my-db\", \"LOCAL_BEAR\": \"panda\" }, \"mapping\": { \".+_TIMEOUT\": \"1000\" } }, \"fs\": { \"mode\": \"write\", \"read_write\": \".+\\\\.json\" , \"read_only\": [ \".+\\\\.yaml\", \".+important-file\\\\.txt\" ], \"local\": [ \".+\\\\.js\", \".+\\\\.mjs\" ] }, \"network\": { \"incoming\": { \"mode\": \"steal\", \"http_filter\": { \"header_filter\": \"host: api\\\\..+\" }, \"port_mapping\": [[ 7777, 8888 ]], \"ignore_localhost\": false, \"ignore_ports\": [9999, 10000] }, \"outgoing\": { \"tcp\": true, \"udp\": true, \"filter\": { \"local\": [\"tcp://1.1.1.0/24:1337\", \"1.1.5.0/24\", \"google.com\", \":53\"] }, \"ignore_localhost\": false, \"unix_streams\": \"bear.+\" }, \"dns\": { \"enabled\": true, \"filter\": { \"local\": [\"1.1.1.0/24:1337\", \"1.1.5.0/24\", \"google.com\"] } } }, \"copy_target\": { \"scale_down\": false } }, \"operator\": true, \"kubeconfig\": \"~/.kube/config\", \"sip_binaries\": \"bash\", \"telemetry\": true, \"kube_context\": \"my-cluster\" } ```\n\n# Options {#root-options}", "type": "object", "properties": { "accept_invalid_certificates": { @@ -255,7 +255,7 @@ "properties": { 
"annotations": { "title": "agent.annotations {#agent-annotations}", - "description": "Allows setting up custom annotations for the agent Job and Pod.\n\n```json { \"annotations\": { \"cats.io/inject\": \"enabled\" } } ```", + "description": "Allows setting up custom annotations for the agent Job and Pod.\n\n```json { \"annotations\": { \"cats.io/inject\": \"enabled\", \"prometheus.io/scrape\": \"true\", \"prometheus.io/port\": \"9000\" } } ```", "type": [ "object", "null" ] }, @@ -378,6 +378,14 @@ "null" ] }, + "metrics": { + "title": "agent.metrics {#agent-metrics}", + "description": "Enables prometheus metrics for the agent pod.\n\nYou might need to add annotations to the agent pod depending on how prometheus is configured to scrape for metrics.\n\n```json { \"metrics\": \"0.0.0.0:9000\" } ```", + "type": [ + "string", + "null" + ] + }, "namespace": { "title": "agent.namespace {#agent-namespace}", "description": "Namespace where the agent shall live. Note: Doesn't work with ephemeral containers. Defaults to the current kubernetes namespace.", "type": [ "string", "null" ] }, @@ -453,7 +461,7 @@ }, "tolerations": { "title": "agent.tolerations {#agent-tolerations}", - "description": "Set pod tolerations. 
(not with ephemeral agents).\n\nDefaults to `operator: Exists`.\n\n```json [ { \"key\": \"meow\", \"operator\": \"Exists\", \"effect\": \"NoSchedule\" } ] ```\n\nSet to an empty array to have no tolerations at all", "type": [ "array", "null" @@ -1681,6 +1689,24 @@ }, "additionalProperties": false }, + "ServiceTarget": { + "type": "object", + "required": [ + "service" + ], + "properties": { + "container": { + "type": [ + "string", + "null" + ] + }, + "service": { + "type": "string" + } + }, + "additionalProperties": false + }, "SplitQueuesConfig": { "description": "```json { \"feature\": { \"split_queues\": { \"first-queue\": { \"queue_type\": \"SQS\", \"message_filter\": { \"wows\": \"so wows\", \"coolz\": \"^very\" } }, \"second-queue\": { \"queue_type\": \"SQS\", \"message_filter\": { \"who\": \"you$\" } }, \"third-queue\": { \"queue_type\": \"Kafka\", \"message_filter\": { \"who\": \"you$\" } }, \"fourth-queue\": { \"queue_type\": \"Kafka\", \"message_filter\": { \"wows\": \"so wows\", \"coolz\": \"^very\" } }, } } } ```", "type": "object", @@ -1707,10 +1733,10 @@ "additionalProperties": false }, "Target": { - "description": " ## path\n\nSpecifies the running pod (or deployment) to mirror.\n\nSupports: - `pod/{sample-pod}`; - `deployment/{sample-deployment}`; - `container/{sample-container}`; - `containername/{sample-container}`. 
- `job/{sample-job}`; - `cronjob/{sample-cronjob}`; - `statefulset/{sample-statefulset}`;", + "description": " ## path\n\nSpecifies the running pod (or deployment) to mirror.\n\nSupports: - `targetless` - `pod/{pod-name}[/container/{container-name}]`; - `deployment/{deployment-name}[/container/{container-name}]`; - `rollout/{rollout-name}[/container/{container-name}]`; - `job/{job-name}[/container/{container-name}]`; - `cronjob/{cronjob-name}[/container/{container-name}]`; - `statefulset/{statefulset-name}[/container/{container-name}]`; - `service/{service-name}[/container/{container-name}]`;", "anyOf": [ { - "description": " Mirror a deployment.", + "description": " [Deployment](https://kubernetes.io/docs/concepts/workloads/controllers/deployment/).", "allOf": [ { "$ref": "#/definitions/DeploymentTarget" @@ -1718,7 +1744,7 @@ ] }, { - "description": " Mirror a pod.", + "description": " [Pod](https://kubernetes.io/docs/concepts/workloads/pods/).", "allOf": [ { "$ref": "#/definitions/PodTarget" @@ -1726,7 +1752,7 @@ ] }, { - "description": " Mirror a rollout.", + "description": " [Argo Rollout](https://argoproj.github.io/argo-rollouts/#how-does-it-work).", "allOf": [ { "$ref": "#/definitions/RolloutTarget" @@ -1734,7 +1760,7 @@ ] }, { - "description": " Mirror a Job.\n\nOnly supported when `copy_target` is enabled.", + "description": " [Job](https://kubernetes.io/docs/concepts/workloads/controllers/job/).\n\nOnly supported when `copy_target` is enabled.", "allOf": [ { "$ref": "#/definitions/JobTarget" @@ -1742,7 +1768,7 @@ ] }, { - "description": " Targets a [CronJob](https://kubernetes.io/docs/concepts/workloads/controllers/cron-jobs/).\n\nOnly supported when `copy_target` is enabled.", + "description": " [CronJob](https://kubernetes.io/docs/concepts/workloads/controllers/cron-jobs/).\n\nOnly supported when `copy_target` is enabled.", "allOf": [ { "$ref": "#/definitions/CronJobTarget" @@ -1750,13 +1776,21 @@ ] }, { - "description": " Targets a 
[StatefulSet](https://kubernetes.io/docs/concepts/workloads/controllers/statefulset/).\n\nOnly supported when `copy_target` is enabled.", + "description": " [StatefulSet](https://kubernetes.io/docs/concepts/workloads/controllers/statefulset/).", "allOf": [ { "$ref": "#/definitions/StatefulSetTarget" } ] }, + { + "description": " [Service](https://kubernetes.io/docs/concepts/services-networking/service/).", + "allOf": [ + { + "$ref": "#/definitions/ServiceTarget" + } + ] + }, { "description": " Spawn a new pod.", "type": "null" diff --git a/mirrord/agent/Cargo.toml b/mirrord/agent/Cargo.toml index cdba788acf3..07757b7900e 100644 --- a/mirrord/agent/Cargo.toml +++ b/mirrord/agent/Cargo.toml @@ -69,6 +69,8 @@ x509-parser = "0.16" rustls.workspace = true envy = "0.4" socket2.workspace = true +prometheus = { version = "0.13", features = ["process"] } +axum = { version = "0.7", features = ["macros"] } iptables = { git = "https://github.com/metalbear-co/rust-iptables.git", rev = "e66c7332e361df3c61a194f08eefe3f40763d624" } rawsocket = { git = "https://github.com/metalbear-co/rawsocket.git" } procfs = "0.17.0" @@ -78,3 +80,4 @@ rstest.workspace = true mockall = "0.13" test_bin = "0.4" rcgen.workspace = true +reqwest.workspace = true diff --git a/mirrord/agent/README.md b/mirrord/agent/README.md index bf077b5fdcf..d7456ead64c 100644 --- a/mirrord/agent/README.md +++ b/mirrord/agent/README.md @@ -6,3 +6,198 @@ Agent part of [mirrord](https://github.com/metalbear-co/mirrord) responsible for mirrord-agent is written in Rust for safety, low memory consumption and performance. mirrord-agent is distributed as a container image (currently only x86) that is published on [GitHub Packages publicly](https://github.com/metalbear-co/mirrord-agent/pkgs/container/mirrord-agent). 
+ +## Enabling prometheus metrics + +To start the metrics server, you'll need to add this config to your `mirrord.json`: + +```json +{ + "agent": { + "metrics": "0.0.0.0:9000", + "annotations": { + "prometheus.io/scrape": "true", + "prometheus.io/port": "9000" + } } +} +``` + +Remember to change the `port` in both `metrics` and `annotations`, they have to match, +otherwise prometheus will try to scrape on `port: 80` or other commonly used ports. + +### Installing prometheus + +Run `kubectl apply -f {file-name}.yaml` on these sequences of `yaml` files and you should +get prometheus running in your cluster. You can access the dashboard from your browser at +`http://{cluster-ip}:30909`, if you're using minikube it might be +`http://192.168.49.2:30909`. + +You'll get prometheus running under the `monitoring` namespace, but it'll be able to look +into resources from all namespaces. The config in `configmap.yaml` sets prometheus to look +at pods only, if you want to use it to scrape other stuff, check +[this example](https://github.com/prometheus/prometheus/blob/main/documentation/examples/prometheus-kubernetes.yml). + +1. `create-namespace.yaml` + +```yaml +apiVersion: v1 +kind: Namespace +metadata: + name: monitoring +``` + +2. `cluster-role.yaml` + +```yaml +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: prometheus +rules: +- apiGroups: [""] + resources: + - nodes + - services + - endpoints + - pods + verbs: ["get", "list", "watch"] +- apiGroups: + - extensions + resources: + - ingresses + verbs: ["get", "list", "watch"] +``` + +3. `service-account.yaml` + +```yaml +apiVersion: v1 +kind: ServiceAccount +metadata: + name: prometheus + namespace: monitoring +``` + +4. 
`cluster-role-binding.yaml` + +```yaml +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: prometheus +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: ClusterRole + name: prometheus +subjects: +- kind: ServiceAccount + name: prometheus + namespace: monitoring +``` + +5. `configmap.yaml` + +```yaml +apiVersion: v1 +kind: ConfigMap +metadata: + name: prometheus-config + namespace: monitoring +data: + prometheus.yml: | + global: + keep_dropped_targets: 100 + + scrape_configs: + - job_name: "kubernetes-pods" + + kubernetes_sd_configs: + - role: pod + + relabel_configs: + - source_labels: [__address__, __meta_kubernetes_pod_annotation_prometheus_io_port] + action: replace + regex: ([^:]+)(?::\d+)?;(\d+) + replacement: $1:$2 + target_label: __address__ + - action: labelmap + regex: __meta_kubernetes_pod_label_(.+) + - source_labels: [__meta_kubernetes_namespace] + action: replace + target_label: namespace + - source_labels: [__meta_kubernetes_pod_name] + action: replace + target_label: pod +``` + +- If you make any changes to the 5-configmap.yaml file, remember to `kubectl apply` it + **before** restarting the `prometheus` deployment. + +6. 
`deployment.yaml` + +```yaml +apiVersion: apps/v1 +kind: Deployment +metadata: + name: prometheus + namespace: monitoring + labels: + app: prometheus +spec: + replicas: 1 + strategy: + rollingUpdate: + maxSurge: 1 + maxUnavailable: 1 + type: RollingUpdate + selector: + matchLabels: + app: prometheus + template: + metadata: + labels: + app: prometheus + annotations: + prometheus.io/scrape: "true" + prometheus.io/port: "9090" + spec: + serviceAccountName: prometheus + containers: + - name: prometheus + image: prom/prometheus + args: + - '--config.file=/etc/prometheus/prometheus.yml' + ports: + - name: web + containerPort: 9090 + volumeMounts: + - name: prometheus-config-volume + mountPath: /etc/prometheus + restartPolicy: Always + volumes: + - name: prometheus-config-volume + configMap: + defaultMode: 420 + name: prometheus-config +``` + +7. `service.yaml` + +```yaml +apiVersion: v1 +kind: Service +metadata: + name: prometheus-service + namespace: monitoring + annotations: + prometheus.io/scrape: 'true' + prometheus.io/port: '9090' +spec: + selector: + app: prometheus + type: NodePort + ports: + - port: 8080 + targetPort: 9090 + nodePort: 30909 +``` diff --git a/mirrord/agent/src/cli.rs b/mirrord/agent/src/cli.rs index 07c8aa97503..bbcf23f1816 100644 --- a/mirrord/agent/src/cli.rs +++ b/mirrord/agent/src/cli.rs @@ -1,8 +1,11 @@ #![deny(missing_docs)] +use std::net::SocketAddr; + use clap::{Parser, Subcommand}; use mirrord_protocol::{ - MeshVendor, AGENT_IPV6_ENV, AGENT_NETWORK_INTERFACE_ENV, AGENT_OPERATOR_CERT_ENV, + MeshVendor, AGENT_IPV6_ENV, AGENT_METRICS_ENV, AGENT_NETWORK_INTERFACE_ENV, + AGENT_OPERATOR_CERT_ENV, }; const DEFAULT_RUNTIME: &str = "containerd"; @@ -28,6 +31,10 @@ pub struct Args { #[arg(short = 'i', long, env = AGENT_NETWORK_INTERFACE_ENV)] pub network_interface: Option, + /// Controls whether metrics are enabled, and the address to set up the metrics server. 
+ #[arg(long, env = AGENT_METRICS_ENV)] + pub metrics: Option, + /// Return an error after accepting the first client connection, in order to test agent error /// cleanup. /// diff --git a/mirrord/agent/src/client_connection.rs b/mirrord/agent/src/client_connection.rs index 8181e4baabd..7b484cc25da 100644 --- a/mirrord/agent/src/client_connection.rs +++ b/mirrord/agent/src/client_connection.rs @@ -208,7 +208,7 @@ enum ConnectionFramed { #[cfg(test)] mod test { - use std::sync::Arc; + use std::sync::{Arc, Once}; use futures::StreamExt; use mirrord_protocol::ClientCodec; @@ -220,10 +220,19 @@ mod test { use super::*; + static CRYPTO_PROVIDER: Once = Once::new(); + /// Verifies that [`AgentTlsConnector`] correctly accepts a /// connection from a server using the provided certificate. #[tokio::test] async fn agent_tls_connector_valid_cert() { + CRYPTO_PROVIDER.call_once(|| { + rustls::crypto::CryptoProvider::install_default( + rustls::crypto::aws_lc_rs::default_provider(), + ) + .expect("Failed to install crypto provider") + }); + let cert = rcgen::generate_simple_self_signed(vec!["operator".to_string()]).unwrap(); let cert_bytes = cert.cert.der(); let key_bytes = cert.key_pair.serialize_der(); @@ -269,6 +278,13 @@ mod test { /// connection from a server using some other certificate. 
#[tokio::test] async fn agent_tls_connector_invalid_cert() { + CRYPTO_PROVIDER.call_once(|| { + rustls::crypto::CryptoProvider::install_default( + rustls::crypto::aws_lc_rs::default_provider(), + ) + .expect("Failed to install crypto provider") + }); + let server_cert = rcgen::generate_simple_self_signed(vec!["operator".to_string()]).unwrap(); let cert_bytes = server_cert.cert.der(); let key_bytes = server_cert.key_pair.serialize_der(); diff --git a/mirrord/agent/src/container_handle.rs b/mirrord/agent/src/container_handle.rs index 6e8ba78173d..dd6755e766d 100644 --- a/mirrord/agent/src/container_handle.rs +++ b/mirrord/agent/src/container_handle.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ - error::Result, + error::AgentResult, runtime::{Container, ContainerInfo, ContainerRuntime}, }; @@ -22,7 +22,7 @@ pub(crate) struct ContainerHandle(Arc); impl ContainerHandle { /// Retrieve info about the container and initialize this struct. #[tracing::instrument(level = "trace")] - pub(crate) async fn new(container: Container) -> Result { + pub(crate) async fn new(container: Container) -> AgentResult { let ContainerInfo { pid, env: raw_env } = container.get_info().await?; let inner = Inner { pid, raw_env }; diff --git a/mirrord/agent/src/dns.rs b/mirrord/agent/src/dns.rs index 3240856275a..b92487594e0 100644 --- a/mirrord/agent/src/dns.rs +++ b/mirrord/agent/src/dns.rs @@ -16,10 +16,7 @@ use tokio::{ use tokio_util::sync::CancellationToken; use tracing::Level; -use crate::{ - error::{AgentError, Result}, - watched_task::TaskStatus, -}; +use crate::{error::AgentResult, metrics::DNS_REQUEST_COUNT, watched_task::TaskStatus}; #[derive(Debug)] pub(crate) enum ClientGetAddrInfoRequest { @@ -167,6 +164,9 @@ impl DnsWorker { let etc_path = self.etc_path.clone(); let timeout = self.timeout; let attempts = self.attempts; + + DNS_REQUEST_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let support_ipv6 = self.support_ipv6; let lookup_future = 
async move { let result = Self::do_lookup( @@ -181,15 +181,13 @@ impl DnsWorker { if let Err(result) = message.response_tx.send(result) { tracing::error!(?result, "Failed to send query response"); } + DNS_REQUEST_COUNT.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); }; tokio::spawn(lookup_future); } - pub(crate) async fn run( - mut self, - cancellation_token: CancellationToken, - ) -> Result<(), AgentError> { + pub(crate) async fn run(mut self, cancellation_token: CancellationToken) -> AgentResult<()> { loop { tokio::select! { _ = cancellation_token.cancelled() => break Ok(()), @@ -225,7 +223,7 @@ impl DnsApi { pub(crate) async fn make_request( &mut self, request: ClientGetAddrInfoRequest, - ) -> Result<(), AgentError> { + ) -> AgentResult<()> { let (response_tx, response_rx) = oneshot::channel(); let command = DnsCommand { @@ -244,7 +242,7 @@ impl DnsApi { /// Returns the result of the oldest outstanding DNS request issued with this struct (see /// [`Self::make_request`]). #[tracing::instrument(level = Level::TRACE, skip(self), ret, err)] - pub(crate) async fn recv(&mut self) -> Result { + pub(crate) async fn recv(&mut self) -> AgentResult { let Some(response) = self.responses.next().await else { return future::pending().await; }; diff --git a/mirrord/agent/src/entrypoint.rs b/mirrord/agent/src/entrypoint.rs index 407bf27c33f..ac9157897a0 100644 --- a/mirrord/agent/src/entrypoint.rs +++ b/mirrord/agent/src/entrypoint.rs @@ -12,6 +12,7 @@ use std::{ use client_connection::AgentTlsConnector; use dns::{ClientGetAddrInfoRequest, DnsCommand, DnsWorker}; use futures::TryFutureExt; +use metrics::{start_metrics, CLIENT_COUNT}; use mirrord_protocol::{ClientMessage, DaemonMessage, GetEnvVarsRequest, LogMessage}; use sniffer::tcp_capture::RawSocketTcpCapture; use tokio::{ @@ -32,7 +33,7 @@ use crate::{ client_connection::ClientConnection, container_handle::ContainerHandle, dns::DnsApi, - error::{AgentError, Result}, + error::{AgentError, AgentResult}, file::FileManager, 
outgoing::{TcpOutgoingApi, UdpOutgoingApi}, runtime::get_container, @@ -72,7 +73,7 @@ struct State { impl State { /// Return [`Err`] if container runtime operations failed. - pub async fn new(args: &Args) -> Result { + pub async fn new(args: &Args) -> AgentResult { let tls_connector = args .operator_tls_cert_pem .clone() @@ -205,6 +206,12 @@ struct ClientConnectionHandler { ready_for_logs: bool, } +impl Drop for ClientConnectionHandler { + fn drop(&mut self) { + CLIENT_COUNT.fetch_sub(1, Ordering::Relaxed); + } +} + impl ClientConnectionHandler { /// Initializes [`ClientConnectionHandler`]. pub async fn new( @@ -212,7 +219,7 @@ impl ClientConnectionHandler { mut connection: ClientConnection, bg_tasks: BackgroundTasks, state: State, - ) -> Result { + ) -> AgentResult { let pid = state.container_pid(); let file_manager = FileManager::new(pid.or_else(|| state.ephemeral.then_some(1))); @@ -238,6 +245,8 @@ impl ClientConnectionHandler { ready_for_logs: false, }; + CLIENT_COUNT.fetch_add(1, Ordering::Relaxed); + Ok(client_handler) } @@ -273,7 +282,7 @@ impl ClientConnectionHandler { id: ClientId, task: BackgroundTask, connection: &mut ClientConnection, - ) -> Result> { + ) -> AgentResult> { if let BackgroundTask::Running(stealer_status, stealer_sender) = task { match TcpStealerApi::new( id, @@ -313,7 +322,7 @@ impl ClientConnectionHandler { /// /// Breaks upon receiver/sender drop. #[tracing::instrument(level = "trace", skip(self))] - async fn start(mut self, cancellation_token: CancellationToken) -> Result<()> { + async fn start(mut self, cancellation_token: CancellationToken) -> AgentResult<()> { let error = loop { select! 
{ message = self.connection.receive() => { @@ -364,7 +373,7 @@ impl ClientConnectionHandler { Ok(message) => self.respond(DaemonMessage::TcpOutgoing(message)).await?, Err(e) => break e, }, - message = self.udp_outgoing_api.daemon_message() => match message { + message = self.udp_outgoing_api.recv_from_task() => match message { Ok(message) => self.respond(DaemonMessage::UdpOutgoing(message)).await?, Err(e) => break e, }, @@ -389,7 +398,7 @@ impl ClientConnectionHandler { /// Sends a [`DaemonMessage`] response to the connected client (`mirrord-layer`). #[tracing::instrument(level = "trace", skip(self))] - async fn respond(&mut self, response: DaemonMessage) -> Result<()> { + async fn respond(&mut self, response: DaemonMessage) -> AgentResult<()> { self.connection.send(response).await.map_err(Into::into) } @@ -397,7 +406,7 @@ impl ClientConnectionHandler { /// /// Returns `false` if the client disconnected. #[tracing::instrument(level = Level::TRACE, skip(self), ret, err(level = Level::DEBUG))] - async fn handle_client_message(&mut self, message: ClientMessage) -> Result { + async fn handle_client_message(&mut self, message: ClientMessage) -> AgentResult { match message { ClientMessage::FileRequest(req) => { if let Some(response) = self.file_manager.handle_message(req)? { @@ -415,7 +424,7 @@ impl ClientConnectionHandler { self.tcp_outgoing_api.send_to_task(layer_message).await? } ClientMessage::UdpOutgoing(layer_message) => { - self.udp_outgoing_api.layer_message(layer_message).await? + self.udp_outgoing_api.send_to_task(layer_message).await? } ClientMessage::GetEnvVarsRequest(GetEnvVarsRequest { env_vars_filter, @@ -495,8 +504,8 @@ impl ClientConnectionHandler { } /// Initializes the agent's [`State`], channels, threads, and runs [`ClientConnectionHandler`]s. 
-#[tracing::instrument(level = "trace", ret)] -async fn start_agent(args: Args) -> Result<()> { +#[tracing::instrument(level = Level::TRACE, ret, err)] +async fn start_agent(args: Args) -> AgentResult<()> { trace!("start_agent -> Starting agent with args: {args:?}"); // listen for client connections @@ -534,6 +543,18 @@ async fn start_agent(args: Args) -> Result<()> { // To make sure that background tasks are cancelled when we exit early from this function. let cancel_guard = cancellation_token.clone().drop_guard(); + if let Some(metrics_address) = args.metrics { + let cancellation_token = cancellation_token.clone(); + tokio::spawn(async move { + start_metrics(metrics_address, cancellation_token.clone()) + .await + .inspect_err(|fail| { + tracing::error!(?fail, "Failed starting metrics server!"); + cancellation_token.cancel(); + }) + }); + } + let (sniffer_command_tx, sniffer_command_rx) = mpsc::channel::(1000); let (stealer_command_tx, stealer_command_rx) = mpsc::channel::(1000); let (dns_command_tx, dns_command_rx) = mpsc::channel::(1000); @@ -755,7 +776,7 @@ async fn start_agent(args: Args) -> Result<()> { Ok(()) } -async fn clear_iptable_chain() -> Result<()> { +async fn clear_iptable_chain() -> AgentResult<()> { let ipt = new_iptables(); SafeIpTables::load(IPTablesWrapper::from(ipt), false) @@ -766,7 +787,7 @@ async fn clear_iptable_chain() -> Result<()> { Ok(()) } -async fn run_child_agent() -> Result<()> { +async fn run_child_agent() -> AgentResult<()> { let command_args = std::env::args().collect::>(); let (command, args) = command_args .split_first() @@ -790,7 +811,7 @@ async fn run_child_agent() -> Result<()> { /// /// Captures SIGTERM signals sent by Kubernetes when the pod is gracefully deleted. /// When a signal is captured, the child process is killed and the iptables are cleaned. 
-async fn start_iptable_guard(args: Args) -> Result<()> { +async fn start_iptable_guard(args: Args) -> AgentResult<()> { debug!("start_iptable_guard -> Initializing iptable-guard."); let state = State::new(&args).await?; @@ -827,7 +848,18 @@ async fn start_iptable_guard(args: Args) -> Result<()> { result } -pub async fn main() -> Result<()> { +/// The agent is somewhat started twice, first with [`start_iptable_guard`], and then the +/// proper agent with [`start_agent`]. +/// +/// ## Things to keep in mind due to the double initialization +/// +/// Since the _second_ agent gets spawned as a child of the _first_, they share resources, +/// like the `namespace`, which means: +/// +/// 1. If you try to `bind` a socket to some address before [`start_agent`], it'll actually be bound +/// **twice**, which incurs an error (address already in use). You could get around this by +/// `bind`ing on `0.0.0.0:0`, but this is most likely **not** what you want. +pub async fn main() -> AgentResult<()> { rustls::crypto::CryptoProvider::install_default(rustls::crypto::aws_lc_rs::default_provider()) .expect("Failed to install crypto provider"); diff --git a/mirrord/agent/src/env.rs b/mirrord/agent/src/env.rs index 26fa4681431..5a349709f2d 100644 --- a/mirrord/agent/src/env.rs +++ b/mirrord/agent/src/env.rs @@ -7,7 +7,7 @@ use mirrord_protocol::RemoteResult; use tokio::io::AsyncReadExt; use wildmatch::WildMatch; -use crate::error::Result; +use crate::error::AgentResult; struct EnvFilter { include: Vec, @@ -97,7 +97,7 @@ pub(crate) fn parse_raw_env<'a, S: AsRef + 'a + ?Sized, T: IntoIterator>() } -pub(crate) async fn get_proc_environ(path: PathBuf) -> Result> { +pub(crate) async fn get_proc_environ(path: PathBuf) -> AgentResult> { let mut environ_file = tokio::fs::File::open(path).await?; let mut raw_env_vars = String::with_capacity(8192); diff --git a/mirrord/agent/src/error.rs b/mirrord/agent/src/error.rs index d9ae7cb8b9d..88b811e590b 100644 --- a/mirrord/agent/src/error.rs +++ 
b/mirrord/agent/src/error.rs @@ -96,4 +96,4 @@ impl From> for AgentError { } } -pub(crate) type Result = std::result::Result; +pub(crate) type AgentResult = std::result::Result; diff --git a/mirrord/agent/src/file.rs b/mirrord/agent/src/file.rs index 0bc30afb151..571b2ad9d3c 100644 --- a/mirrord/agent/src/file.rs +++ b/mirrord/agent/src/file.rs @@ -18,7 +18,7 @@ use mirrord_protocol::{file::*, FileRequest, FileResponse, RemoteResult, Respons use nix::unistd::UnlinkatFlags; use tracing::{error, trace, Level}; -use crate::error::Result; +use crate::{error::AgentResult, metrics::OPEN_FD_COUNT}; #[derive(Debug)] pub enum RemoteFile { @@ -76,15 +76,11 @@ pub(crate) struct FileManager { fds_iter: RangeInclusive, } -impl Default for FileManager { - fn default() -> Self { - Self { - root_path: Default::default(), - open_files: Default::default(), - dir_streams: Default::default(), - getdents_streams: Default::default(), - fds_iter: (0..=u64::MAX), - } +impl Drop for FileManager { + fn drop(&mut self) { + let descriptors = + self.open_files.len() + self.dir_streams.len() + self.getdents_streams.len(); + OPEN_FD_COUNT.fetch_sub(descriptors as i64, std::sync::atomic::Ordering::Relaxed); } } @@ -152,7 +148,10 @@ pub fn resolve_path + std::fmt::Debug, R: AsRef + std::fmt: impl FileManager { /// Executes the request and returns the response. #[tracing::instrument(level = Level::TRACE, skip(self), ret, err(level = Level::DEBUG))] - pub fn handle_message(&mut self, request: FileRequest) -> Result> { + pub(crate) fn handle_message( + &mut self, + request: FileRequest, + ) -> AgentResult> { Ok(match request { FileRequest::Open(OpenFileRequest { path, open_options }) => { // TODO: maybe not agent error on this? 
@@ -206,10 +205,7 @@ impl FileManager { let write_result = self.write_limited(remote_fd, start_from, write_bytes); Some(FileResponse::WriteLimited(write_result)) } - FileRequest::Close(CloseFileRequest { fd }) => { - self.close(fd); - None - } + FileRequest::Close(CloseFileRequest { fd }) => self.close(fd), FileRequest::Access(AccessFileRequest { pathname, mode }) => { let pathname = pathname .strip_prefix("/") @@ -227,8 +223,12 @@ impl FileManager { Some(FileResponse::Xstat(xstat_result)) } FileRequest::XstatFs(XstatFsRequest { fd }) => { - let xstat_result = self.xstatfs(fd); - Some(FileResponse::XstatFs(xstat_result)) + let xstatfs_result = self.xstatfs(fd); + Some(FileResponse::XstatFs(xstatfs_result)) + } + FileRequest::StatFs(StatFsRequest { path }) => { + let statfs_result = self.statfs(path); + Some(FileResponse::XstatFs(statfs_result)) } // dir operations @@ -244,10 +244,7 @@ impl FileManager { let read_dir_result = self.read_dir_batch(remote_fd, amount); Some(FileResponse::ReadDirBatch(read_dir_result)) } - FileRequest::CloseDir(CloseDirRequest { remote_fd }) => { - self.close_dir(remote_fd); - None - } + FileRequest::CloseDir(CloseDirRequest { remote_fd }) => self.close_dir(remote_fd), FileRequest::GetDEnts64(GetDEnts64Request { remote_fd, buffer_size, @@ -280,10 +277,13 @@ impl FileManager { pub fn new(pid: Option) -> Self { let root_path = get_root_path_from_optional_pid(pid); trace!("Agent root path >> {root_path:?}"); + Self { - open_files: HashMap::new(), root_path, - ..Default::default() + open_files: Default::default(), + dir_streams: Default::default(), + getdents_streams: Default::default(), + fds_iter: (0..=u64::MAX), } } @@ -309,7 +309,9 @@ impl FileManager { RemoteFile::File(file) }; - self.open_files.insert(fd, remote_file); + if self.open_files.insert(fd, remote_file).is_none() { + OPEN_FD_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } Ok(OpenFileResponse { fd }) } @@ -343,7 +345,9 @@ impl FileManager { RemoteFile::File(file) 
}; - self.open_files.insert(fd, remote_file); + if self.open_files.insert(fd, remote_file).is_none() { + OPEN_FD_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } Ok(OpenFileResponse { fd }) } else { @@ -636,20 +640,36 @@ impl FileManager { }) } - pub(crate) fn close(&mut self, fd: u64) { - trace!("FileManager::close -> fd {:#?}", fd,); - + /// Always returns `None`, since we don't return any [`FileResponse`] back to mirrord + /// on `close` of an fd. + #[tracing::instrument(level = Level::TRACE, skip(self))] + pub(crate) fn close(&mut self, fd: u64) -> Option { if self.open_files.remove(&fd).is_none() { - error!("FileManager::close -> fd {:#?} not found", fd); + error!(fd, "fd not found!"); + } else { + OPEN_FD_COUNT.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); } - } - pub(crate) fn close_dir(&mut self, fd: u64) { - trace!("FileManager::close_dir -> fd {:#?}", fd,); + None + } - if self.dir_streams.remove(&fd).is_none() && self.getdents_streams.remove(&fd).is_none() { + /// Always returns `None`, since we don't return any [`FileResponse`] back to mirrord + /// on `close_dir` of an fd. 
+ #[tracing::instrument(level = Level::TRACE, skip(self))] + pub(crate) fn close_dir(&mut self, fd: u64) -> Option { + let closed_dir_stream = self.dir_streams.remove(&fd); + let closed_getdents_stream = self.getdents_streams.remove(&fd); + + if closed_dir_stream.is_some() && closed_getdents_stream.is_some() { + // Closed `dirstream` and `dentsstream` + OPEN_FD_COUNT.fetch_sub(2, std::sync::atomic::Ordering::Relaxed); + } else if closed_dir_stream.is_some() || closed_getdents_stream.is_some() { + OPEN_FD_COUNT.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + } else { error!("FileManager::close_dir -> fd {:#?} not found", fd); } + + None } pub(crate) fn access( @@ -754,6 +774,18 @@ impl FileManager { } #[tracing::instrument(level = "trace", skip(self))] + pub(crate) fn statfs(&mut self, path: PathBuf) -> RemoteResult { + let path = resolve_path(path, &self.root_path)?; + + let statfs = nix::sys::statfs::statfs(&path) + .map_err(|err| std::io::Error::from_raw_os_error(err as i32))?; + + Ok(XstatFsResponse { + metadata: statfs.into(), + }) + } + + #[tracing::instrument(level = Level::TRACE, skip(self), err(level = Level::DEBUG))] pub(crate) fn fdopen_dir(&mut self, fd: u64) -> RemoteResult { let path = match self .open_files @@ -770,7 +802,10 @@ impl FileManager { .ok_or_else(|| ResponseError::IdsExhausted("fdopen_dir".to_string()))?; let dir_stream = path.read_dir()?.enumerate(); - self.dir_streams.insert(fd, dir_stream); + + if self.dir_streams.insert(fd, dir_stream).is_none() { + OPEN_FD_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } Ok(OpenDirResponse { fd }) } @@ -819,7 +854,7 @@ impl FileManager { /// The possible remote errors are: /// [`ResponseError::NotFound`] if there is not such fd here. /// [`ResponseError::NotDirectory`] if the fd points to a file with a non-directory file type. 
- #[tracing::instrument(level = "trace", skip(self))] + #[tracing::instrument(level = Level::TRACE, skip(self))] pub(crate) fn get_or_create_getdents64_stream( &mut self, fd: u64, @@ -832,6 +867,7 @@ impl FileManager { let current_and_parent = Self::get_current_and_parent_entries(dir); let stream = GetDEnts64Stream::new(dir.read_dir()?, current_and_parent).peekable(); + OPEN_FD_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed); Ok(e.insert(stream)) } }, diff --git a/mirrord/agent/src/main.rs b/mirrord/agent/src/main.rs index 305ec50e0ed..e9f3e107907 100644 --- a/mirrord/agent/src/main.rs +++ b/mirrord/agent/src/main.rs @@ -22,6 +22,7 @@ mod env; mod error; mod file; mod http; +mod metrics; mod namespace; mod outgoing; mod runtime; @@ -31,7 +32,8 @@ mod util; mod vpn; mod watched_task; +#[cfg(target_os = "linux")] #[tokio::main(flavor = "current_thread")] -async fn main() -> crate::error::Result<()> { +async fn main() -> crate::error::AgentResult<()> { crate::entrypoint::main().await } diff --git a/mirrord/agent/src/metrics.rs b/mirrord/agent/src/metrics.rs new file mode 100644 index 00000000000..59fefbc2305 --- /dev/null +++ b/mirrord/agent/src/metrics.rs @@ -0,0 +1,366 @@ +use std::{ + net::SocketAddr, + sync::{atomic::AtomicI64, Arc}, +}; + +use axum::{extract::State, routing::get, Router}; +use http::StatusCode; +use prometheus::{proto::MetricFamily, IntGauge, Registry}; +use tokio::net::TcpListener; +use tokio_util::sync::CancellationToken; +use tracing::Level; + +use crate::error::AgentError; + +/// Incremented whenever we get a new client in `ClientConnectionHandler`, and decremented +/// when this client is dropped. +pub(crate) static CLIENT_COUNT: AtomicI64 = AtomicI64::new(0); + +/// Incremented whenever we handle a new `DnsCommand`, and decremented after the result of +/// `do_lookup` has been sent back through the response channel. 
+pub(crate) static DNS_REQUEST_COUNT: AtomicI64 = AtomicI64::new(0); + +/// Incremented and decremented in _open-ish_/_close-ish_ file operations in `FileManager`, +/// Also gets decremented when `FileManager` is dropped. +pub(crate) static OPEN_FD_COUNT: AtomicI64 = AtomicI64::new(0); + +/// Follows the amount of subscribed ports in `update_packet_filter`. We don't really +/// increment/decrement this one, and mostly `set` it to the latest amount of ports, zeroing it when +/// the `TcpConnectionSniffer` gets dropped. +pub(crate) static MIRROR_PORT_SUBSCRIPTION: AtomicI64 = AtomicI64::new(0); + +pub(crate) static MIRROR_CONNECTION_SUBSCRIPTION: AtomicI64 = AtomicI64::new(0); + +pub(crate) static STEAL_FILTERED_PORT_SUBSCRIPTION: AtomicI64 = AtomicI64::new(0); + +pub(crate) static STEAL_UNFILTERED_PORT_SUBSCRIPTION: AtomicI64 = AtomicI64::new(0); + +pub(crate) static STEAL_FILTERED_CONNECTION_SUBSCRIPTION: AtomicI64 = AtomicI64::new(0); + +pub(crate) static STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION: AtomicI64 = AtomicI64::new(0); + +pub(crate) static HTTP_REQUEST_IN_PROGRESS_COUNT: AtomicI64 = AtomicI64::new(0); + +pub(crate) static TCP_OUTGOING_CONNECTION: AtomicI64 = AtomicI64::new(0); + +pub(crate) static UDP_OUTGOING_CONNECTION: AtomicI64 = AtomicI64::new(0); + +/// The state with all the metrics [`IntGauge`]s and the prometheus [`Registry`] where we keep them. +/// +/// **Do not** modify the gauges directly! +/// +/// Instead rely on [`Metrics::gather_metrics`], as we actually use a bunch of [`AtomicI64`]s to +/// keep track of the values, they are the ones being (de|in)cremented. These gauges are just set +/// when it's time to send them via [`get_metrics`]. 
+#[derive(Debug)] +struct Metrics { + registry: Registry, + client_count: IntGauge, + dns_request_count: IntGauge, + open_fd_count: IntGauge, + mirror_port_subscription: IntGauge, + mirror_connection_subscription: IntGauge, + steal_filtered_port_subscription: IntGauge, + steal_unfiltered_port_subscription: IntGauge, + steal_filtered_connection_subscription: IntGauge, + steal_unfiltered_connection_subscription: IntGauge, + http_request_in_progress_count: IntGauge, + tcp_outgoing_connection: IntGauge, + udp_outgoing_connection: IntGauge, +} + +impl Metrics { + /// Creates a [`Registry`] to ... register our [`IntGauge`]s. + fn new() -> Self { + use prometheus::Opts; + + let registry = Registry::new(); + + let client_count = { + let opts = Opts::new( + "mirrord_agent_client_count", + "amount of connected clients to this mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let dns_request_count = { + let opts = Opts::new( + "mirrord_agent_dns_request_count", + "amount of in-progress dns requests in the mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let open_fd_count = { + let opts = Opts::new( + "mirrord_agent_open_fd_count", + "amount of open file descriptors in mirrord-agent file manager", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let mirror_port_subscription = { + let opts = Opts::new( + "mirrord_agent_mirror_port_subscription_count", + "amount of mirror port subscriptions in mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let mirror_connection_subscription = { + let opts = Opts::new( + "mirrord_agent_mirror_connection_subscription_count", + "amount of connections in mirror mode in mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let steal_filtered_port_subscription = { + let opts = Opts::new( + "mirrord_agent_steal_filtered_port_subscription_count", + "amount 
of filtered steal port subscriptions in mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let steal_unfiltered_port_subscription = { + let opts = Opts::new( + "mirrord_agent_steal_unfiltered_port_subscription_count", + "amount of unfiltered steal port subscriptions in mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let steal_filtered_connection_subscription = { + let opts = Opts::new( + "mirrord_agent_steal_filtered_connection_subscription_count", + "amount of filtered connections in steal mode in mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let steal_unfiltered_connection_subscription = { + let opts = Opts::new( + "mirrord_agent_steal_unfiltered_connection_subscription_count", + "amount of unfiltered connections in steal mode in mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let http_request_in_progress_count = { + let opts = Opts::new( + "mirrord_agent_http_request_in_progress_count", + "amount of in-progress http requests in the mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let tcp_outgoing_connection = { + let opts = Opts::new( + "mirrord_agent_tcp_outgoing_connection_count", + "amount of tcp outgoing connections in mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let udp_outgoing_connection = { + let opts = Opts::new( + "mirrord_agent_udp_outgoing_connection_count", + "amount of udp outgoing connections in mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + registry + .register(Box::new(client_count.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(dns_request_count.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(open_fd_count.clone())) + .expect("Register must be 
valid at initialization!"); + registry + .register(Box::new(mirror_port_subscription.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(mirror_connection_subscription.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(steal_filtered_port_subscription.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(steal_unfiltered_port_subscription.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(steal_filtered_connection_subscription.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(steal_unfiltered_connection_subscription.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(http_request_in_progress_count.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(tcp_outgoing_connection.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(udp_outgoing_connection.clone())) + .expect("Register must be valid at initialization!"); + + Self { + registry, + client_count, + dns_request_count, + open_fd_count, + mirror_port_subscription, + mirror_connection_subscription, + steal_filtered_port_subscription, + steal_unfiltered_port_subscription, + steal_filtered_connection_subscription, + steal_unfiltered_connection_subscription, + http_request_in_progress_count, + tcp_outgoing_connection, + udp_outgoing_connection, + } + } + + /// Calls [`IntGauge::set`] on every [`IntGauge`] of `Self`, setting it to the value of + /// the corresponding [`AtomicI64`] global (the uppercase named version of the gauge). + /// + /// Returns the list of [`MetricFamily`] registered in our [`Metrics::registry`], ready to be + /// encoded and sent to prometheus. 
+ fn gather_metrics(&self) -> Vec { + use std::sync::atomic::Ordering; + + let Self { + registry, + client_count, + dns_request_count, + open_fd_count, + mirror_port_subscription, + mirror_connection_subscription, + steal_filtered_port_subscription, + steal_unfiltered_port_subscription, + steal_filtered_connection_subscription, + steal_unfiltered_connection_subscription, + http_request_in_progress_count, + tcp_outgoing_connection, + udp_outgoing_connection, + } = self; + + client_count.set(CLIENT_COUNT.load(Ordering::Relaxed)); + dns_request_count.set(DNS_REQUEST_COUNT.load(Ordering::Relaxed)); + open_fd_count.set(OPEN_FD_COUNT.load(Ordering::Relaxed)); + mirror_port_subscription.set(MIRROR_PORT_SUBSCRIPTION.load(Ordering::Relaxed)); + mirror_connection_subscription.set(MIRROR_CONNECTION_SUBSCRIPTION.load(Ordering::Relaxed)); + steal_filtered_port_subscription + .set(STEAL_FILTERED_PORT_SUBSCRIPTION.load(Ordering::Relaxed)); + steal_unfiltered_port_subscription + .set(STEAL_UNFILTERED_PORT_SUBSCRIPTION.load(Ordering::Relaxed)); + steal_filtered_connection_subscription + .set(STEAL_FILTERED_CONNECTION_SUBSCRIPTION.load(Ordering::Relaxed)); + steal_unfiltered_connection_subscription + .set(STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.load(Ordering::Relaxed)); + http_request_in_progress_count.set(HTTP_REQUEST_IN_PROGRESS_COUNT.load(Ordering::Relaxed)); + tcp_outgoing_connection.set(TCP_OUTGOING_CONNECTION.load(Ordering::Relaxed)); + udp_outgoing_connection.set(UDP_OUTGOING_CONNECTION.load(Ordering::Relaxed)); + + registry.gather() + } +} + +/// `GET /metrics` +/// +/// Prepares all the metrics with [`Metrics::gather_metrics`], and responds to the prometheus +/// request. 
+#[tracing::instrument(level = Level::TRACE, ret)] +async fn get_metrics(State(state): State>) -> (StatusCode, String) { + use prometheus::TextEncoder; + + let metric_families = state.gather_metrics(); + match TextEncoder.encode_to_string(&metric_families) { + Ok(response) => (StatusCode::OK, response), + Err(fail) => { + tracing::error!(?fail, "Failed GET /metrics"); + (StatusCode::INTERNAL_SERVER_ERROR, fail.to_string()) + } + } +} + +/// Starts the mirrord-agent prometheus metrics service. +/// +/// You can get the metrics from `GET address/metrics`. +/// +/// - `address`: comes from a mirrord-agent config. +#[tracing::instrument(level = Level::TRACE, skip_all, ret ,err)] +pub(crate) async fn start_metrics( + address: SocketAddr, + cancellation_token: CancellationToken, +) -> Result<(), axum::BoxError> { + let metrics_state = Arc::new(Metrics::new()); + + let app = Router::new() + .route("/metrics", get(get_metrics)) + .with_state(metrics_state); + + let listener = TcpListener::bind(address) + .await + .map_err(AgentError::from) + .inspect_err(|fail| { + tracing::error!(?fail, "Failed to bind TCP socket for metrics server") + })?; + + let cancel_on_error = cancellation_token.clone(); + axum::serve(listener, app) + .with_graceful_shutdown(async move { cancellation_token.cancelled().await }) + .await + .inspect_err(|fail| { + tracing::error!(%fail, "Could not start agent metrics server!"); + cancel_on_error.cancel(); + })?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use std::{sync::atomic::Ordering, time::Duration}; + + use tokio_util::sync::CancellationToken; + + use super::OPEN_FD_COUNT; + use crate::metrics::start_metrics; + + #[tokio::test] + async fn test_metrics() { + let metrics_address = "127.0.0.1:9000".parse().unwrap(); + let cancellation_token = CancellationToken::new(); + + let metrics_cancellation = cancellation_token.child_token(); + tokio::spawn(async move { + start_metrics(metrics_address, metrics_cancellation) + .await + .unwrap() + }); + + 
OPEN_FD_COUNT.fetch_add(1, Ordering::Relaxed); + + // Give the server some time to start. + tokio::time::sleep(Duration::from_secs(1)).await; + + let get_all_metrics = reqwest::get("http://127.0.0.1:9000/metrics") + .await + .unwrap() + .error_for_status() + .unwrap() + .text() + .await + .unwrap(); + + assert!(get_all_metrics.contains("mirrord_agent_open_fd_count 1")); + + cancellation_token.drop_guard(); + } +} diff --git a/mirrord/agent/src/outgoing.rs b/mirrord/agent/src/outgoing.rs index 13e3a9e1e06..96a063d7a05 100644 --- a/mirrord/agent/src/outgoing.rs +++ b/mirrord/agent/src/outgoing.rs @@ -18,7 +18,8 @@ use tokio_util::io::ReaderStream; use tracing::Level; use crate::{ - error::Result, + error::AgentResult, + metrics::TCP_OUTGOING_CONNECTION, util::run_thread_in_namespace, watched_task::{TaskStatus, WatchedTask}, }; @@ -81,7 +82,7 @@ impl TcpOutgoingApi { /// Sends the [`LayerTcpOutgoing`] message to the background task. #[tracing::instrument(level = Level::TRACE, skip(self), err)] - pub(crate) async fn send_to_task(&mut self, message: LayerTcpOutgoing) -> Result<()> { + pub(crate) async fn send_to_task(&mut self, message: LayerTcpOutgoing) -> AgentResult<()> { if self.layer_tx.send(message).await.is_ok() { Ok(()) } else { @@ -91,7 +92,7 @@ impl TcpOutgoingApi { /// Receives a [`DaemonTcpOutgoing`] message from the background task. 
#[tracing::instrument(level = Level::TRACE, skip(self), err)] - pub(crate) async fn recv_from_task(&mut self) -> Result { + pub(crate) async fn recv_from_task(&mut self) -> AgentResult { match self.daemon_rx.recv().await { Some(msg) => Ok(msg), None => Err(self.task_status.unwrap_err().await), @@ -112,6 +113,13 @@ struct TcpOutgoingTask { daemon_tx: Sender, } +impl Drop for TcpOutgoingTask { + fn drop(&mut self) { + let connections = self.readers.keys().chain(self.writers.keys()).count(); + TCP_OUTGOING_CONNECTION.fetch_sub(connections as i64, std::sync::atomic::Ordering::Relaxed); + } +} + impl fmt::Debug for TcpOutgoingTask { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("TcpOutgoingTask") @@ -152,7 +160,7 @@ impl TcpOutgoingTask { /// Runs this task as long as the channels connecting it with [`TcpOutgoingApi`] are open. /// This routine never fails and returns [`Result`] only due to [`WatchedTask`] constraints. #[tracing::instrument(level = Level::TRACE, skip(self))] - async fn run(mut self) -> Result<()> { + async fn run(mut self) -> AgentResult<()> { loop { let channel_closed = select! 
{ biased; @@ -216,6 +224,7 @@ impl TcpOutgoingTask { self.readers.remove(&connection_id); self.writers.remove(&connection_id); + TCP_OUTGOING_CONNECTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); let daemon_message = DaemonTcpOutgoing::Close(connection_id); self.daemon_tx.send(daemon_message).await?; @@ -246,6 +255,8 @@ impl TcpOutgoingTask { "Layer connection is shut down as well, sending close message.", ); + TCP_OUTGOING_CONNECTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + self.daemon_tx .send(DaemonTcpOutgoing::Close(connection_id)) .await?; @@ -287,6 +298,7 @@ impl TcpOutgoingTask { connection_id, ReaderStream::with_capacity(read_half, Self::READ_BUFFER_SIZE), ); + TCP_OUTGOING_CONNECTION.fetch_add(1, std::sync::atomic::Ordering::Relaxed); Ok(DaemonConnect { connection_id, @@ -299,9 +311,12 @@ impl TcpOutgoingTask { result = ?daemon_connect, "Connection attempt finished.", ); + self.daemon_tx .send(DaemonTcpOutgoing::Connect(daemon_connect)) - .await + .await?; + + Ok(()) } // This message handles two cases: @@ -341,9 +356,14 @@ impl TcpOutgoingTask { connection_id, "Peer connection is shut down as well, sending close message to the client.", ); + TCP_OUTGOING_CONNECTION + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + self.daemon_tx .send(DaemonTcpOutgoing::Close(connection_id)) - .await + .await?; + + Ok(()) } } @@ -352,6 +372,7 @@ impl TcpOutgoingTask { Err(error) => { self.writers.remove(&connection_id); self.readers.remove(&connection_id); + TCP_OUTGOING_CONNECTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); tracing::trace!( connection_id, @@ -360,7 +381,9 @@ impl TcpOutgoingTask { ); self.daemon_tx .send(DaemonTcpOutgoing::Close(connection_id)) - .await + .await?; + + Ok(()) } } } @@ -370,6 +393,7 @@ impl TcpOutgoingTask { LayerTcpOutgoing::Close(LayerClose { connection_id }) => { self.writers.remove(&connection_id); self.readers.remove(&connection_id); + TCP_OUTGOING_CONNECTION.fetch_sub(1, 
std::sync::atomic::Ordering::Relaxed); Ok(()) } diff --git a/mirrord/agent/src/outgoing/udp.rs b/mirrord/agent/src/outgoing/udp.rs index b6baa5e537e..0dab137a92b 100644 --- a/mirrord/agent/src/outgoing/udp.rs +++ b/mirrord/agent/src/outgoing/udp.rs @@ -1,10 +1,11 @@ +use core::fmt; use std::{ collections::HashMap, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, thread, }; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; use futures::{ prelude::*, stream::{SplitSink, SplitStream}, @@ -15,21 +16,262 @@ use mirrord_protocol::{ }; use streammap_ext::StreamMap; use tokio::{ + io, net::UdpSocket, select, - sync::mpsc::{self, Receiver, Sender}, + sync::mpsc::{self, error::SendError, Receiver, Sender}, }; use tokio_util::{codec::BytesCodec, udp::UdpFramed}; -use tracing::{debug, trace, warn}; +use tracing::Level; use crate::{ - error::Result, + error::AgentResult, + metrics::UDP_OUTGOING_CONNECTION, util::run_thread_in_namespace, watched_task::{TaskStatus, WatchedTask}, }; -type Layer = LayerUdpOutgoing; -type Daemon = DaemonUdpOutgoing; +/// Task that handles [`LayerUdpOutgoing`] and [`DaemonUdpOutgoing`] messages. +/// +/// We start these tasks from the [`UdpOutgoingApi`] as a [`WatchedTask`]. +struct UdpOutgoingTask { + next_connection_id: ConnectionId, + /// Writing halves of peer connections made on layer's requests. + #[allow(clippy::type_complexity)] + writers: HashMap< + ConnectionId, + ( + SplitSink, (BytesMut, SocketAddr)>, + SocketAddr, + ), + >, + /// Reading halves of peer connections made on layer's requests. + readers: StreamMap>>, + /// Optional pid of agent's target. Used in `SocketStream::connect`. 
+ pid: Option, + layer_rx: Receiver, + daemon_tx: Sender, +} + +impl Drop for UdpOutgoingTask { + fn drop(&mut self) { + let connections = self.readers.keys().chain(self.writers.keys()).count(); + UDP_OUTGOING_CONNECTION.fetch_sub(connections as i64, std::sync::atomic::Ordering::Relaxed); + } +} + +impl fmt::Debug for UdpOutgoingTask { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("UdpOutgoingTask") + .field("next_connection_id", &self.next_connection_id) + .field("writers", &self.writers.len()) + .field("readers", &self.readers.len()) + .field("pid", &self.pid) + .finish() + } +} + +impl UdpOutgoingTask { + fn new( + pid: Option, + layer_rx: Receiver, + daemon_tx: Sender, + ) -> Self { + Self { + next_connection_id: 0, + writers: Default::default(), + readers: Default::default(), + pid, + layer_rx, + daemon_tx, + } + } + + /// Runs this task as long as the channels connecting it with [`UdpOutgoingApi`] are open. + /// This routine never fails and returns [`AgentResult`] only due to [`WatchedTask`] + /// constraints. + #[tracing::instrument(level = Level::TRACE, skip(self))] + pub(super) async fn run(mut self) -> AgentResult<()> { + loop { + let channel_closed = select! { + biased; + + message = self.layer_rx.recv() => match message { + // We have a message from the layer to be handled. + Some(message) => { + self.handle_layer_msg(message).await.is_err() + }, + // Our channel with the layer is closed, this task is no longer needed. + None => true, + }, + + // We have data coming from one of our peers. + Some((connection_id, remote_read)) = self.readers.next() => { + self.handle_connection_read(connection_id, remote_read.transpose().map(|remote| remote.map(|(read, _)| read.into()))).await.is_err() + }, + }; + + if channel_closed { + tracing::trace!("Client channel closed, exiting"); + break Ok(()); + } + } + } + + /// Returns [`Err`] only when the client has disconnected. 
+ #[tracing::instrument( + level = Level::TRACE, + skip(read), + fields(read = ?read.as_ref().map(|data| data.as_ref().map(Bytes::len).unwrap_or_default())) + err(level = Level::TRACE) + )] + async fn handle_connection_read( + &mut self, + connection_id: ConnectionId, + read: io::Result>, + ) -> Result<(), SendError> { + match read { + Ok(Some(read)) => { + let message = DaemonUdpOutgoing::Read(Ok(DaemonRead { + connection_id, + bytes: read.to_vec(), + })); + + self.daemon_tx.send(message).await? + } + // An error occurred when reading from a peer connection. + // We remove both io halves and inform the layer that the connection is closed. + // We remove the reader, because otherwise the `StreamMap` will produce an extra `None` + // item from the related stream. + Err(error) => { + tracing::trace!( + ?error, + connection_id, + "Reading from peer connection failed, sending close message.", + ); + + self.readers.remove(&connection_id); + self.writers.remove(&connection_id); + UDP_OUTGOING_CONNECTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + + let daemon_message = DaemonUdpOutgoing::Close(connection_id); + self.daemon_tx.send(daemon_message).await?; + } + Ok(None) => { + self.writers.remove(&connection_id); + self.readers.remove(&connection_id); + UDP_OUTGOING_CONNECTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + + let daemon_message = DaemonUdpOutgoing::Close(connection_id); + self.daemon_tx.send(daemon_message).await?; + } + } + + Ok(()) + } + + /// Returns [`Err`] only when the client has disconnected. + #[allow(clippy::type_complexity)] + #[tracing::instrument(level = Level::TRACE, ret)] + async fn handle_layer_msg( + &mut self, + message: LayerUdpOutgoing, + ) -> Result<(), SendError> { + match message { + // [user] -> [layer] -> [agent] -> [layer] + // `user` is asking us to connect to some remote host. 
+ LayerUdpOutgoing::Connect(LayerConnect { remote_address }) => { + let daemon_connect = + connect(remote_address.clone()) + .await + .and_then(|mirror_socket| { + let connection_id = self.next_connection_id; + self.next_connection_id += 1; + + let peer_address = mirror_socket.peer_addr()?; + let local_address = mirror_socket.local_addr()?; + let local_address = SocketAddress::Ip(local_address); + + let framed = UdpFramed::new(mirror_socket, BytesCodec::new()); + + let (sink, stream): ( + SplitSink, (BytesMut, SocketAddr)>, + SplitStream>, + ) = framed.split(); + + self.writers.insert(connection_id, (sink, peer_address)); + self.readers.insert(connection_id, stream); + UDP_OUTGOING_CONNECTION + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + Ok(DaemonConnect { + connection_id, + remote_address, + local_address, + }) + }); + + tracing::trace!( + result = ?daemon_connect, + "Connection attempt finished.", + ); + + self.daemon_tx + .send(DaemonUdpOutgoing::Connect(daemon_connect)) + .await?; + + Ok(()) + } + // [user] -> [layer] -> [agent] -> [remote] + // `user` wrote some message to the remote host. 
+ LayerUdpOutgoing::Write(LayerWrite { + connection_id, + bytes, + }) => { + let write_result = match self + .writers + .get_mut(&connection_id) + .ok_or(ResponseError::NotFound(connection_id)) + { + Ok((mirror, remote_address)) => mirror + .send((BytesMut::from(bytes.as_slice()), *remote_address)) + .await + .map_err(ResponseError::from), + Err(fail) => Err(fail), + }; + + match write_result { + Ok(()) => Ok(()), + Err(error) => { + self.writers.remove(&connection_id); + self.readers.remove(&connection_id); + UDP_OUTGOING_CONNECTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + + tracing::trace!( + connection_id, + ?error, + "Failed to handle layer write, sending close message to the client.", + ); + + let daemon_message = DaemonUdpOutgoing::Close(connection_id); + self.daemon_tx.send(daemon_message).await?; + + Ok(()) + } + } + } + // [layer] -> [agent] + // `layer` closed their interceptor stream. + LayerUdpOutgoing::Close(LayerClose { ref connection_id }) => { + self.writers.remove(connection_id); + self.readers.remove(connection_id); + UDP_OUTGOING_CONNECTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + + Ok(()) + } + } + } +} /// Handles (briefly) the `UdpOutgoingRequest` and `UdpOutgoingResponse` messages, mostly the /// passing of these messages to the `interceptor_task` thread. @@ -41,10 +283,10 @@ pub(crate) struct UdpOutgoingApi { task_status: TaskStatus, /// Sends the `Layer` message to the `interceptor_task`. - layer_tx: Sender, + layer_tx: Sender, /// Reads the `Daemon` message from the `interceptor_task`. - daemon_rx: Receiver, + daemon_rx: Receiver, } /// Performs an [`UdpSocket::connect`] that handles 3 situations: @@ -55,8 +297,9 @@ pub(crate) struct UdpOutgoingApi { /// read access to `/etc/resolv.conf`, otherwise they'll be getting a mismatched connection; /// 3. User is trying to use `sendto` and `recvfrom`, we use the same hack as in DNS to fake a /// connection. 
-#[tracing::instrument(level = "trace", ret)] -async fn connect(remote_address: SocketAddr) -> Result { +#[tracing::instrument(level = Level::TRACE, ret, err(level = Level::DEBUG))] +async fn connect(remote_address: SocketAddress) -> Result { + let remote_address = remote_address.try_into()?; let mirror_address = match remote_address { std::net::SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), std::net::SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0), @@ -75,8 +318,10 @@ impl UdpOutgoingApi { let (layer_tx, layer_rx) = mpsc::channel(1000); let (daemon_tx, daemon_rx) = mpsc::channel(1000); - let watched_task = - WatchedTask::new(Self::TASK_NAME, Self::interceptor_task(layer_rx, daemon_tx)); + let watched_task = WatchedTask::new( + Self::TASK_NAME, + UdpOutgoingTask::new(pid, layer_rx, daemon_tx).run(), + ); let task_status = watched_task.status(); let task = run_thread_in_namespace( @@ -94,150 +339,9 @@ impl UdpOutgoingApi { } } - /// The [`UdpOutgoingApi`] task. - /// - /// Receives [`LayerUdpOutgoing`] messages and replies with [`DaemonUdpOutgoing`]. - #[allow(clippy::type_complexity)] - async fn interceptor_task( - mut layer_rx: Receiver, - daemon_tx: Sender, - ) -> Result<()> { - let mut connection_ids = 0..=ConnectionId::MAX; - - // TODO: Right now we're manually keeping these 2 maps in sync (aviram suggested using - // `Weak` for `writers`). - let mut writers: HashMap< - ConnectionId, - ( - SplitSink, (BytesMut, SocketAddr)>, - SocketAddr, - ), - > = HashMap::default(); - - let mut readers: StreamMap>> = - StreamMap::default(); - - loop { - select! { - biased; - - // [layer] -> [agent] - Some(layer_message) = layer_rx.recv() => { - trace!("udp: interceptor_task -> layer_message {:?}", layer_message); - match layer_message { - // [user] -> [layer] -> [agent] -> [layer] - // `user` is asking us to connect to some remote host. 
- LayerUdpOutgoing::Connect(LayerConnect { remote_address }) => { - let daemon_connect = connect(remote_address.clone().try_into()?) - .await - .and_then(|mirror_socket| { - let connection_id = connection_ids - .next() - .ok_or_else(|| ResponseError::IdsExhausted("connect".into()))?; - - debug!("interceptor_task -> mirror_socket {:#?}", mirror_socket); - let peer_address = mirror_socket.peer_addr()?; - let local_address = mirror_socket.local_addr()?; - let local_address = SocketAddress::Ip(local_address); - let framed = UdpFramed::new(mirror_socket, BytesCodec::new()); - debug!("interceptor_task -> framed {:#?}", framed); - let (sink, stream): ( - SplitSink, (BytesMut, SocketAddr)>, - SplitStream>, - ) = framed.split(); - - writers.insert(connection_id, (sink, peer_address)); - readers.insert(connection_id, stream); - - Ok(DaemonConnect { - connection_id, - remote_address, - local_address - }) - }); - - let daemon_message = DaemonUdpOutgoing::Connect(daemon_connect); - debug!("interceptor_task -> daemon_message {:#?}", daemon_message); - daemon_tx.send(daemon_message).await? - } - // [user] -> [layer] -> [agent] -> [remote] - // `user` wrote some message to the remote host. - LayerUdpOutgoing::Write(LayerWrite { - connection_id, - bytes, - }) => { - let daemon_write = match writers - .get_mut(&connection_id) - .ok_or(ResponseError::NotFound(connection_id)) - { - Ok((mirror, remote_address)) => mirror - .send((BytesMut::from(bytes.as_slice()), *remote_address)) - .await - .map_err(ResponseError::from), - Err(fail) => Err(fail), - }; - - if let Err(fail) = daemon_write { - warn!("LayerUdpOutgoing::Write -> Failed with {:#?}", fail); - writers.remove(&connection_id); - readers.remove(&connection_id); - - let daemon_message = DaemonUdpOutgoing::Close(connection_id); - daemon_tx.send(daemon_message).await? - } - } - // [layer] -> [agent] - // `layer` closed their interceptor stream. 
- LayerUdpOutgoing::Close(LayerClose { ref connection_id }) => { - writers.remove(connection_id); - readers.remove(connection_id); - } - } - } - - // [remote] -> [agent] -> [layer] -> [user] - // Read the data from one of the connected remote hosts, and forward the result back - // to the `user`. - Some((connection_id, remote_read)) = readers.next() => { - trace!("interceptor_task -> read connection_id {:#?}", connection_id); - - match remote_read { - Some(read) => { - let daemon_read = read - .map_err(ResponseError::from) - .map(|(bytes, _)| DaemonRead { connection_id, bytes: bytes.to_vec() }); - - let daemon_message = DaemonUdpOutgoing::Read(daemon_read); - daemon_tx.send(daemon_message).await? - } - None => { - trace!("interceptor_task -> close connection {:#?}", connection_id); - writers.remove(&connection_id); - readers.remove(&connection_id); - - let daemon_message = DaemonUdpOutgoing::Close(connection_id); - daemon_tx.send(daemon_message).await? - } - } - } - else => { - // We have no more data coming from any of the remote hosts. - warn!("interceptor_task -> no messages left"); - break; - } - } - } - - Ok(()) - } - /// Sends a `UdpOutgoingRequest` to the `interceptor_task`. - pub(crate) async fn layer_message(&mut self, message: LayerUdpOutgoing) -> Result<()> { - trace!( - "UdpOutgoingApi::layer_message -> layer_message {:#?}", - message - ); - + #[tracing::instrument(level = Level::TRACE, skip(self), err)] + pub(crate) async fn send_to_task(&mut self, message: LayerUdpOutgoing) -> AgentResult<()> { if self.layer_tx.send(message).await.is_ok() { Ok(()) } else { @@ -246,7 +350,7 @@ impl UdpOutgoingApi { } /// Receives a `UdpOutgoingResponse` from the `interceptor_task`. 
- pub(crate) async fn daemon_message(&mut self) -> Result { + pub(crate) async fn recv_from_task(&mut self) -> AgentResult { match self.daemon_rx.recv().await { Some(msg) => Ok(msg), None => Err(self.task_status.unwrap_err().await), diff --git a/mirrord/agent/src/sniffer.rs b/mirrord/agent/src/sniffer.rs index b1c232eae6d..94c40e0ce67 100644 --- a/mirrord/agent/src/sniffer.rs +++ b/mirrord/agent/src/sniffer.rs @@ -24,8 +24,9 @@ use self::{ tcp_capture::RawSocketTcpCapture, }; use crate::{ - error::AgentError, + error::AgentResult, http::HttpVersion, + metrics::{MIRROR_CONNECTION_SUBSCRIPTION, MIRROR_PORT_SUBSCRIPTION}, util::{ChannelClosedFuture, ClientId, Subscriptions}, }; @@ -138,7 +139,14 @@ pub(crate) struct TcpConnectionSniffer { sessions: TCPSessionMap, client_txs: HashMap>, - clients_closed: FuturesUnordered>, + clients_closed: FuturesUnordered, +} + +impl Drop for TcpConnectionSniffer { + fn drop(&mut self) { + MIRROR_PORT_SUBSCRIPTION.store(0, std::sync::atomic::Ordering::Relaxed); + MIRROR_CONNECTION_SUBSCRIPTION.store(0, std::sync::atomic::Ordering::Relaxed); + } } impl fmt::Debug for TcpConnectionSniffer { @@ -163,7 +171,7 @@ impl TcpConnectionSniffer { command_rx: Receiver, network_interface: Option, is_mesh: bool, - ) -> Result { + ) -> AgentResult { let tcp_capture = RawSocketTcpCapture::new(network_interface, is_mesh).await?; Ok(Self { @@ -190,7 +198,7 @@ where /// Runs the sniffer loop, capturing packets. #[tracing::instrument(level = Level::DEBUG, skip(cancel_token), err)] - pub async fn start(mut self, cancel_token: CancellationToken) -> Result<(), AgentError> { + pub async fn start(mut self, cancel_token: CancellationToken) -> AgentResult<()> { loop { select! { command = self.command_rx.recv() => { @@ -232,7 +240,7 @@ where /// Removes the client with `client_id`, and also unsubscribes its port. /// Adjusts BPF filter if needed. 
#[tracing::instrument(level = Level::TRACE, err)] - fn handle_client_closed(&mut self, client_id: ClientId) -> Result<(), AgentError> { + fn handle_client_closed(&mut self, client_id: ClientId) -> AgentResult<()> { self.client_txs.remove(&client_id); if self.port_subscriptions.remove_client(client_id) { @@ -245,8 +253,9 @@ where /// Updates BPF filter used by [`Self::tcp_capture`] to match state of /// [`Self::port_subscriptions`]. #[tracing::instrument(level = Level::TRACE, err)] - fn update_packet_filter(&mut self) -> Result<(), AgentError> { + fn update_packet_filter(&mut self) -> AgentResult<()> { let ports = self.port_subscriptions.get_subscribed_topics(); + MIRROR_PORT_SUBSCRIPTION.store(ports.len() as i64, std::sync::atomic::Ordering::Relaxed); let filter = if ports.is_empty() { tracing::trace!("No ports subscribed, setting dummy bpf"); @@ -261,7 +270,7 @@ where } #[tracing::instrument(level = Level::TRACE, err)] - fn handle_command(&mut self, command: SnifferCommand) -> Result<(), AgentError> { + fn handle_command(&mut self, command: SnifferCommand) -> AgentResult<()> { match command { SnifferCommand { client_id, @@ -325,7 +334,7 @@ where &mut self, identifier: TcpSessionIdentifier, tcp_packet: TcpPacketData, - ) -> Result<(), AgentError> { + ) -> AgentResult<()> { let data_tx = match self.sessions.entry(identifier) { Entry::Occupied(e) => e, Entry::Vacant(e) => { @@ -394,6 +403,7 @@ where } } + MIRROR_CONNECTION_SUBSCRIPTION.fetch_add(1, std::sync::atomic::Ordering::Relaxed); e.insert_entry(data_tx) } }; @@ -422,7 +432,7 @@ mod test { atomic::{AtomicUsize, Ordering}, Arc, }, - time::Duration, + time::{Duration, Instant}, }; use api::TcpSnifferApi; @@ -430,6 +440,7 @@ mod test { tcp::{DaemonTcp, LayerTcp, NewTcpConnection, TcpClose, TcpData}, ConnectionId, LogLevel, }; + use rstest::rstest; use tcp_capture::test::TcpPacketsChannel; use tokio::sync::mpsc; @@ -448,6 +459,7 @@ mod test { async fn get_api(&mut self) -> TcpSnifferApi { let client_id = 
self.next_client_id; self.next_client_id += 1; + TcpSnifferApi::new(client_id, self.command_tx.clone(), self.task_status.clone()) .await .unwrap() @@ -845,4 +857,45 @@ mod test { }), ); } + + /// Verifies that [`TcpConnectionSniffer`] reacts to [`TcpSnifferApi`] being dropped + /// and clears the packet filter. + #[rstest] + #[timeout(Duration::from_secs(5))] + #[tokio::test] + async fn cleanup_on_client_closed() { + let mut setup = TestSnifferSetup::new(); + + let mut api = setup.get_api().await; + + api.handle_client_message(LayerTcp::PortSubscribe(80)) + .await + .unwrap(); + assert_eq!( + api.recv().await.unwrap(), + (DaemonTcp::SubscribeResult(Ok(80)), None), + ); + assert_eq!(setup.times_filter_changed(), 1); + + std::mem::drop(api); + let dropped_at = Instant::now(); + + loop { + match setup.times_filter_changed() { + 1 => { + println!( + "filter still not changed {}ms after client closed", + dropped_at.elapsed().as_millis() + ); + tokio::time::sleep(Duration::from_millis(20)).await; + } + + 2 => { + break; + } + + other => panic!("unexpected times filter changed {other}"), + } + } + } } diff --git a/mirrord/agent/src/sniffer/api.rs b/mirrord/agent/src/sniffer/api.rs index 31ec4107f97..08874e93124 100644 --- a/mirrord/agent/src/sniffer/api.rs +++ b/mirrord/agent/src/sniffer/api.rs @@ -14,12 +14,17 @@ use tokio_stream::{ StreamMap, StreamNotifyClose, }; -use super::messages::{SniffedConnection, SnifferCommand, SnifferCommandInner}; +use super::{ + messages::{SniffedConnection, SnifferCommand, SnifferCommandInner}, + AgentResult, +}; use crate::{error::AgentError, util::ClientId, watched_task::TaskStatus}; /// Interface used by clients to interact with the /// [`TcpConnectionSniffer`](super::TcpConnectionSniffer). Multiple instances of this struct operate /// on a single sniffer instance. +/// +/// Enabled by the `mirror` feature for incoming traffic. pub(crate) struct TcpSnifferApi { /// Id of the client using this struct. 
client_id: ClientId, @@ -55,7 +60,7 @@ impl TcpSnifferApi { client_id: ClientId, sniffer_sender: Sender, mut task_status: TaskStatus, - ) -> Result { + ) -> AgentResult { let (sender, receiver) = mpsc::channel(Self::CONNECTION_CHANNEL_SIZE); let command = SnifferCommand { @@ -79,7 +84,7 @@ impl TcpSnifferApi { /// Send the given command to the connected /// [`TcpConnectionSniffer`](super::TcpConnectionSniffer). - async fn send_command(&mut self, command: SnifferCommandInner) -> Result<(), AgentError> { + async fn send_command(&mut self, command: SnifferCommandInner) -> AgentResult<()> { let command = SnifferCommand { client_id: self.client_id, command, @@ -94,7 +99,7 @@ impl TcpSnifferApi { /// Return the next message from the connected /// [`TcpConnectionSniffer`](super::TcpConnectionSniffer). - pub async fn recv(&mut self) -> Result<(DaemonTcp, Option), AgentError> { + pub async fn recv(&mut self) -> AgentResult<(DaemonTcp, Option)> { tokio::select! { conn = self.receiver.recv() => match conn { Some(conn) => { @@ -158,27 +163,26 @@ impl TcpSnifferApi { } } - /// Tansform the given message into a [`SnifferCommand`] and pass it to the connected + /// Tansforms a [`LayerTcp`] message into a [`SnifferCommand`] and passes it to the connected /// [`TcpConnectionSniffer`](super::TcpConnectionSniffer). 
- pub async fn handle_client_message(&mut self, message: LayerTcp) -> Result<(), AgentError> { + pub async fn handle_client_message(&mut self, message: LayerTcp) -> AgentResult<()> { match message { LayerTcp::PortSubscribe(port) => { let (tx, rx) = oneshot::channel(); self.send_command(SnifferCommandInner::Subscribe(port, tx)) .await?; self.subscriptions_in_progress.push(rx); - Ok(()) } LayerTcp::PortUnsubscribe(port) => { self.send_command(SnifferCommandInner::UnsubscribePort(port)) - .await + .await?; + Ok(()) } LayerTcp::ConnectionUnsubscribe(connection_id) => { self.connections.remove(&connection_id); - Ok(()) } } diff --git a/mirrord/agent/src/sniffer/tcp_capture.rs b/mirrord/agent/src/sniffer/tcp_capture.rs index 1d8031d08b3..dc8fb2bba04 100644 --- a/mirrord/agent/src/sniffer/tcp_capture.rs +++ b/mirrord/agent/src/sniffer/tcp_capture.rs @@ -12,7 +12,7 @@ use rawsocket::{filter::SocketFilterProgram, RawCapture}; use tokio::net::UdpSocket; use tracing::Level; -use super::{TcpPacketData, TcpSessionIdentifier}; +use super::{AgentResult, TcpPacketData, TcpSessionIdentifier}; use crate::error::AgentError; /// Trait for structs that are able to sniff incoming Ethernet packets and filter TCP packets. @@ -36,7 +36,7 @@ impl RawSocketTcpCapture { /// /// Returned instance initially uses a BPF filter that drops every packet. #[tracing::instrument(level = Level::DEBUG, err)] - pub async fn new(network_interface: Option, is_mesh: bool) -> Result { + pub async fn new(network_interface: Option, is_mesh: bool) -> AgentResult { // Priority is whatever the user set as an option to mirrord, then we check if we're in a // mesh to use `lo` interface, otherwise we try to get the appropriate interface. 
let interface = match network_interface.or_else(|| is_mesh.then(|| "lo".to_string())) { diff --git a/mirrord/agent/src/steal/api.rs b/mirrord/agent/src/steal/api.rs index 15d2f265ba7..2fc5733f8fa 100644 --- a/mirrord/agent/src/steal/api.rs +++ b/mirrord/agent/src/steal/api.rs @@ -8,11 +8,11 @@ use mirrord_protocol::{ }; use tokio::sync::mpsc::{self, Receiver, Sender}; use tokio_stream::wrappers::ReceiverStream; +use tracing::Level; use super::{http::ReceiverStreamBody, *}; use crate::{ - error::{AgentError, Result}, - util::ClientId, + error::AgentResult, metrics::HTTP_REQUEST_IN_PROGRESS_COUNT, util::ClientId, watched_task::TaskStatus, }; @@ -50,17 +50,23 @@ pub(crate) struct TcpStealerApi { response_body_txs: HashMap<(ConnectionId, RequestId), ResponseBodyTx>, } +impl Drop for TcpStealerApi { + fn drop(&mut self) { + HTTP_REQUEST_IN_PROGRESS_COUNT.store(0, std::sync::atomic::Ordering::Relaxed); + } +} + impl TcpStealerApi { /// Initializes a [`TcpStealerApi`] and sends a message to [`TcpConnectionStealer`] signaling /// that we have a new client. - #[tracing::instrument(level = "trace")] + #[tracing::instrument(level = Level::TRACE, err)] pub(crate) async fn new( client_id: ClientId, command_tx: Sender, task_status: TaskStatus, channel_size: usize, protocol_version: semver::Version, - ) -> Result { + ) -> AgentResult { let (daemon_tx, daemon_rx) = mpsc::channel(channel_size); command_tx @@ -80,7 +86,7 @@ impl TcpStealerApi { } /// Send `command` to stealer, with the client id of the client that is using this API instance. - async fn send_command(&mut self, command: Command) -> Result<()> { + async fn send_command(&mut self, command: Command) -> AgentResult<()> { let command = StealerCommand { client_id: self.client_id, command, @@ -98,12 +104,16 @@ impl TcpStealerApi { /// /// Called in the `ClientConnectionHandler`. 
#[tracing::instrument(level = "trace", skip(self))] - pub(crate) async fn recv(&mut self) -> Result { + pub(crate) async fn recv(&mut self) -> AgentResult { match self.daemon_rx.recv().await { Some(msg) => { if let DaemonTcp::Close(close) = &msg { self.response_body_txs .retain(|(key_id, _), _| *key_id != close.connection_id); + HTTP_REQUEST_IN_PROGRESS_COUNT.store( + self.response_body_txs.len() as i64, + std::sync::atomic::Ordering::Relaxed, + ); } Ok(msg) } @@ -115,7 +125,7 @@ impl TcpStealerApi { /// agent, to an internal stealer command [`Command::PortSubscribe`]. /// /// The actual handling of this message is done in [`TcpConnectionStealer`]. - pub(crate) async fn port_subscribe(&mut self, port_steal: StealType) -> Result<(), AgentError> { + pub(crate) async fn port_subscribe(&mut self, port_steal: StealType) -> AgentResult<()> { self.send_command(Command::PortSubscribe(port_steal)).await } @@ -123,7 +133,7 @@ impl TcpStealerApi { /// agent, to an internal stealer command [`Command::PortUnsubscribe`]. /// /// The actual handling of this message is done in [`TcpConnectionStealer`]. - pub(crate) async fn port_unsubscribe(&mut self, port: Port) -> Result<(), AgentError> { + pub(crate) async fn port_unsubscribe(&mut self, port: Port) -> AgentResult<()> { self.send_command(Command::PortUnsubscribe(port)).await } @@ -134,7 +144,7 @@ impl TcpStealerApi { pub(crate) async fn connection_unsubscribe( &mut self, connection_id: ConnectionId, - ) -> Result<(), AgentError> { + ) -> AgentResult<()> { self.send_command(Command::ConnectionUnsubscribe(connection_id)) .await } @@ -143,7 +153,7 @@ impl TcpStealerApi { /// agent, to an internal stealer command [`Command::ResponseData`]. /// /// The actual handling of this message is done in [`TcpConnectionStealer`]. 
- pub(crate) async fn client_data(&mut self, tcp_data: TcpData) -> Result<(), AgentError> { + pub(crate) async fn client_data(&mut self, tcp_data: TcpData) -> AgentResult<()> { self.send_command(Command::ResponseData(tcp_data)).await } @@ -154,24 +164,32 @@ impl TcpStealerApi { pub(crate) async fn http_response( &mut self, response: HttpResponseFallback, - ) -> Result<(), AgentError> { + ) -> AgentResult<()> { self.send_command(Command::HttpResponse(response)).await } pub(crate) async fn switch_protocol_version( &mut self, version: semver::Version, - ) -> Result<(), AgentError> { + ) -> AgentResult<()> { self.send_command(Command::SwitchProtocolVersion(version)) .await } - pub(crate) async fn handle_client_message(&mut self, message: LayerTcpSteal) -> Result<()> { + pub(crate) async fn handle_client_message( + &mut self, + message: LayerTcpSteal, + ) -> AgentResult<()> { match message { LayerTcpSteal::PortSubscribe(port_steal) => self.port_subscribe(port_steal).await, LayerTcpSteal::ConnectionUnsubscribe(connection_id) => { self.response_body_txs .retain(|(key_id, _), _| *key_id != connection_id); + HTTP_REQUEST_IN_PROGRESS_COUNT.store( + self.response_body_txs.len() as i64, + std::sync::atomic::Ordering::Relaxed, + ); + self.connection_unsubscribe(connection_id).await } LayerTcpSteal::PortUnsubscribe(port) => self.port_unsubscribe(port).await, @@ -202,6 +220,10 @@ impl TcpStealerApi { let key = (response.connection_id, response.request_id); self.response_body_txs.insert(key, tx.clone()); + HTTP_REQUEST_IN_PROGRESS_COUNT.store( + self.response_body_txs.len() as i64, + std::sync::atomic::Ordering::Relaxed, + ); self.http_response(HttpResponseFallback::Streamed(http_response)) .await?; @@ -209,6 +231,10 @@ impl TcpStealerApi { for frame in response.internal_response.body { if let Err(err) = tx.send(Ok(frame.into())).await { self.response_body_txs.remove(&key); + HTTP_REQUEST_IN_PROGRESS_COUNT.store( + self.response_body_txs.len() as i64, + 
std::sync::atomic::Ordering::Relaxed, + ); tracing::trace!(?err, "error while sending streaming response frame"); } } @@ -231,12 +257,20 @@ impl TcpStealerApi { } if send_err || body.is_last { self.response_body_txs.remove(key); + HTTP_REQUEST_IN_PROGRESS_COUNT.store( + self.response_body_txs.len() as i64, + std::sync::atomic::Ordering::Relaxed, + ); }; Ok(()) } ChunkedResponse::Error(err) => { self.response_body_txs .remove(&(err.connection_id, err.request_id)); + HTTP_REQUEST_IN_PROGRESS_COUNT.store( + self.response_body_txs.len() as i64, + std::sync::atomic::Ordering::Relaxed, + ); tracing::trace!(?err, "ChunkedResponse error received"); Ok(()) } diff --git a/mirrord/agent/src/steal/connection.rs b/mirrord/agent/src/steal/connection.rs index 5e4b6b1219a..f6b4a9f2b7b 100644 --- a/mirrord/agent/src/steal/connection.rs +++ b/mirrord/agent/src/steal/connection.rs @@ -28,11 +28,12 @@ use tokio::{ sync::mpsc::{Receiver, Sender}, }; use tokio_util::sync::CancellationToken; -use tracing::warn; +use tracing::{warn, Level}; -use super::http::HttpResponseFallback; +use super::{http::HttpResponseFallback, subscriptions::PortRedirector}; use crate::{ - error::{AgentError, Result}, + error::{AgentError, AgentResult}, + metrics::HTTP_REQUEST_IN_PROGRESS_COUNT, steal::{ connections::{ ConnectionMessageIn, ConnectionMessageOut, StolenConnection, StolenConnections, @@ -55,6 +56,22 @@ struct MatchedHttpRequest { } impl MatchedHttpRequest { + fn new( + connection_id: ConnectionId, + port: Port, + request_id: RequestId, + request: Request, + ) -> Self { + HTTP_REQUEST_IN_PROGRESS_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + Self { + connection_id, + port, + request_id, + request, + } + } + async fn into_serializable(self) -> Result, hyper::Error> { let ( Parts { @@ -181,7 +198,7 @@ impl Client { let frames = frames .into_iter() .map(InternalHttpBodyFrame::try_from) - .filter_map(Result::ok) + .filter_map(AgentResult::ok) .collect(); let message = 
DaemonTcp::HttpRequestChunked(ChunkedRequest::Start(HttpRequest { @@ -210,7 +227,7 @@ impl Client { let frames = frames .into_iter() .map(InternalHttpBodyFrame::try_from) - .filter_map(Result::ok) + .filter_map(AgentResult::ok) .collect(); let message = DaemonTcp::HttpRequestChunked(ChunkedRequest::Body( ChunkedHttpBody { @@ -273,9 +290,11 @@ struct TcpStealerConfig { /// Meant to be run (see [`TcpConnectionStealer::start`]) in a separate thread while the agent /// lives. When handling port subscription requests, this struct manipulates iptables, so it should /// run in the same network namespace as the agent's target. -pub(crate) struct TcpConnectionStealer { +/// +/// Enabled by the `steal` feature for incoming traffic. +pub(crate) struct TcpConnectionStealer { /// For managing active subscriptions and port redirections. - port_subscriptions: PortSubscriptions, + port_subscriptions: PortSubscriptions, /// For receiving commands. /// The other end of this channel belongs to [`TcpStealerApi`](super::api::TcpStealerApi). @@ -285,7 +304,7 @@ pub(crate) struct TcpConnectionStealer { clients: HashMap, /// [`Future`](std::future::Future)s that resolve when stealer clients close. - clients_closed: FuturesUnordered>, + clients_closed: FuturesUnordered, /// Set of active connections stolen by [`Self::port_subscriptions`]. connections: StolenConnections, @@ -294,39 +313,53 @@ pub(crate) struct TcpConnectionStealer { support_ipv6: bool, } -impl TcpConnectionStealer { +impl TcpConnectionStealer { pub const TASK_NAME: &'static str = "Stealer"; /// Initializes a new [`TcpConnectionStealer`], but doesn't start the actual work. /// You need to call [`TcpConnectionStealer::start`] to do so. 
- #[tracing::instrument(level = "trace")] + #[tracing::instrument(level = Level::TRACE, err)] pub(crate) async fn new( command_rx: Receiver, support_ipv6: bool, - ) -> Result { + ) -> AgentResult { let config = envy::prefixed("MIRRORD_AGENT_") .from_env::() .unwrap_or_default(); - let port_subscriptions = { - let redirector = IpTablesRedirector::new( - config.stealer_flush_connections, - config.pod_ips, - support_ipv6, - ) - .await?; + let redirector = IpTablesRedirector::new( + config.stealer_flush_connections, + config.pod_ips, + support_ipv6, + ) + .await?; - PortSubscriptions::new(redirector, 4) - }; + Ok(Self::with_redirector(command_rx, support_ipv6, redirector)) + } +} - Ok(Self { - port_subscriptions, +impl TcpConnectionStealer +where + Redirector: PortRedirector, + Redirector::Error: std::error::Error + Into, + AgentError: From, +{ + /// Creates a new stealer. + /// + /// Given [`PortRedirector`] will be used to capture incoming connections. + pub(crate) fn with_redirector( + command_rx: Receiver, + support_ipv6: bool, + redirector: Redirector, + ) -> Self { + Self { + port_subscriptions: PortSubscriptions::new(redirector, 4), command_rx, clients: HashMap::with_capacity(8), clients_closed: Default::default(), connections: StolenConnections::with_capacity(8), support_ipv6, - }) + } } /// Runs the tcp traffic stealer loop. @@ -341,10 +374,7 @@ impl TcpConnectionStealer { /// /// 4. Handling the cancellation of the whole stealer thread (given `cancellation_token`). #[tracing::instrument(level = "trace", skip(self))] - pub(crate) async fn start( - mut self, - cancellation_token: CancellationToken, - ) -> Result<(), AgentError> { + pub(crate) async fn start(mut self, cancellation_token: CancellationToken) -> AgentResult<()> { loop { tokio::select! 
{ command = self.command_rx.recv() => { @@ -362,10 +392,12 @@ impl TcpConnectionStealer { }, accept = self.port_subscriptions.next_connection() => match accept { - Ok((stream, peer)) => self.incoming_connection(stream, peer).await?, + Ok((stream, peer)) => { + self.incoming_connection(stream, peer).await?; + } Err(error) => { tracing::error!(?error, "Failed to accept a stolen connection"); - break Err(error); + break Err(error.into()); } }, @@ -380,7 +412,11 @@ impl TcpConnectionStealer { /// Handles a new remote connection that was stolen by [`Self::port_subscriptions`]. #[tracing::instrument(level = "trace", skip(self))] - async fn incoming_connection(&mut self, stream: TcpStream, peer: SocketAddr) -> Result<()> { + async fn incoming_connection( + &mut self, + stream: TcpStream, + peer: SocketAddr, + ) -> AgentResult<()> { let mut real_address = orig_dst::orig_dst_addr(&stream)?; let localhost = if self.support_ipv6 && real_address.is_ipv6() { IpAddr::V6(Ipv6Addr::LOCALHOST) @@ -416,10 +452,7 @@ impl TcpConnectionStealer { /// Handles an update from one of the connections in [`Self::connections`]. #[tracing::instrument(level = "trace", skip(self))] - async fn handle_connection_update( - &mut self, - update: ConnectionMessageOut, - ) -> Result<(), AgentError> { + async fn handle_connection_update(&mut self, update: ConnectionMessageOut) -> AgentResult<()> { match update { ConnectionMessageOut::Closed { connection_id, @@ -526,12 +559,7 @@ impl TcpConnectionStealer { return Ok(()); } - let matched_request = MatchedHttpRequest { - connection_id, - request, - request_id: id, - port, - }; + let matched_request = MatchedHttpRequest::new(connection_id, port, id, request); if !client.send_request_async(matched_request) { self.connections @@ -550,11 +578,18 @@ impl TcpConnectionStealer { Ok(()) } - /// Helper function to handle [`Command::PortSubscribe`] messages. + /// Helper function to handle [`Command::PortSubscribe`] messages for the `TcpStealer`. 
/// - /// Inserts a subscription into [`Self::port_subscriptions`]. - #[tracing::instrument(level = "trace", skip(self))] - async fn port_subscribe(&mut self, client_id: ClientId, port_steal: StealType) -> Result<()> { + /// Checks if [`StealType`] is a valid [`HttpFilter`], then inserts a subscription into + /// [`Self::port_subscriptions`]. + /// + /// - Returns: `true` if this is an HTTP filtered subscription. + #[tracing::instrument(level = Level::TRACE, skip(self), err)] + async fn port_subscribe( + &mut self, + client_id: ClientId, + port_steal: StealType, + ) -> AgentResult { let spec = match port_steal { StealType::All(port) => Ok((port, None)), StealType::FilteredHttp(port, filter) => Regex::new(&format!("(?i){filter}")) @@ -565,6 +600,11 @@ impl TcpConnectionStealer { .map_err(|err| BadHttpFilterExRegex(filter, err.to_string())), }; + let filtered = spec + .as_ref() + .map(|(_, filter)| filter.is_some()) + .unwrap_or_default(); + let res = match spec { Ok((port, filter)) => self.port_subscriptions.add(client_id, port, filter).await?, Err(e) => Err(e.into()), @@ -573,18 +613,18 @@ impl TcpConnectionStealer { let client = self.clients.get(&client_id).expect("client not found"); let _ = client.tx.send(DaemonTcp::SubscribeResult(res)).await; - Ok(()) + Ok(filtered) } /// Removes the client with `client_id` from our list of clients (layers), and also removes /// their subscriptions from [`Self::port_subscriptions`] and all their open /// connections. 
#[tracing::instrument(level = "trace", skip(self))] - async fn close_client(&mut self, client_id: ClientId) -> Result<(), AgentError> { + async fn close_client(&mut self, client_id: ClientId) -> AgentResult<()> { self.port_subscriptions.remove_all(client_id).await?; let client = self.clients.remove(&client_id).expect("client not found"); - for connection in client.subscribed_connections.into_iter() { + for connection in client.subscribed_connections { self.connections .send(connection, ConnectionMessageIn::Unsubscribed { client_id }) .await; @@ -612,12 +652,14 @@ impl TcpConnectionStealer { } /// Handles [`Command`]s that were received by [`TcpConnectionStealer::command_rx`]. - #[tracing::instrument(level = "trace", skip(self))] - async fn handle_command(&mut self, command: StealerCommand) -> Result<(), AgentError> { + #[tracing::instrument(level = Level::TRACE, skip(self), err)] + async fn handle_command(&mut self, command: StealerCommand) -> AgentResult<()> { let StealerCommand { client_id, command } = command; match command { Command::NewClient(daemon_tx, protocol_version) => { + self.clients_closed + .push(ChannelClosedFuture::new(daemon_tx.clone(), client_id)); self.clients.insert( client_id, Client { @@ -644,7 +686,7 @@ impl TcpConnectionStealer { } Command::PortSubscribe(port_steal) => { - self.port_subscribe(client_id, port_steal).await? 
+ self.port_subscribe(client_id, port_steal).await?; } Command::PortUnsubscribe(port) => { @@ -682,7 +724,7 @@ impl TcpConnectionStealer { #[cfg(test)] mod test { - use std::net::SocketAddr; + use std::{net::SocketAddr, time::Duration}; use bytes::Bytes; use futures::{future::BoxFuture, FutureExt}; @@ -693,18 +735,75 @@ mod test { service::Service, }; use hyper_util::rt::TokioIo; - use mirrord_protocol::tcp::{ChunkedRequest, DaemonTcp, InternalHttpBodyFrame}; + use mirrord_protocol::{ + tcp::{ChunkedRequest, DaemonTcp, Filter, HttpFilter, InternalHttpBodyFrame, StealType}, + Port, + }; use rstest::rstest; use tokio::{ net::{TcpListener, TcpStream}, sync::{ mpsc::{self, Receiver, Sender}, - oneshot, + oneshot, watch, }, }; use tokio_stream::wrappers::ReceiverStream; + use tokio_util::sync::CancellationToken; + + use super::AgentError; + use crate::{ + steal::{ + connection::{Client, MatchedHttpRequest}, + subscriptions::PortRedirector, + TcpConnectionStealer, TcpStealerApi, + }, + watched_task::TaskStatus, + }; + + /// Notification about a requested redirection operation. + /// + /// Produced by [`NotifyingRedirector`]. + #[derive(Debug, PartialEq, Eq)] + enum RedirectNotification { + Added(Port), + Removed(Port), + Cleanup, + } + + /// Test [`PortRedirector`] that never fails and notifies about requested operations using an + /// [`mpsc::channel`]. 
+ struct NotifyingRedirector(Sender); + + #[async_trait::async_trait] + impl PortRedirector for NotifyingRedirector { + type Error = AgentError; + + async fn add_redirection(&mut self, port: Port) -> Result<(), Self::Error> { + self.0 + .send(RedirectNotification::Added(port)) + .await + .unwrap(); + Ok(()) + } + + async fn remove_redirection(&mut self, port: Port) -> Result<(), Self::Error> { + self.0 + .send(RedirectNotification::Removed(port)) + .await + .unwrap(); + Ok(()) + } + + async fn cleanup(&mut self) -> Result<(), Self::Error> { + self.0.send(RedirectNotification::Cleanup).await.unwrap(); + Ok(()) + } + + async fn next_connection(&mut self) -> Result<(TcpStream, SocketAddr), Self::Error> { + std::future::pending().await + } + } - use crate::steal::connection::{Client, MatchedHttpRequest}; async fn prepare_dummy_service() -> ( SocketAddr, Receiver<(Request, oneshot::Sender>>)>, @@ -881,4 +980,51 @@ mod test { let _ = response_tx.send(Response::new(Empty::default())); } + + /// Verifies that [`TcpConnectionStealer`] removes client's port subscriptions + /// when client's [`TcpStealerApi`] is dropped. 
+ #[rstest] + #[timeout(Duration::from_secs(5))] + #[tokio::test] + async fn cleanup_on_client_closed() { + let (command_tx, command_rx) = mpsc::channel(8); + let (redirect_tx, mut redirect_rx) = mpsc::channel(2); + let stealer = TcpConnectionStealer::with_redirector( + command_rx, + false, + NotifyingRedirector(redirect_tx), + ); + + tokio::spawn(stealer.start(CancellationToken::new())); + + let (_dummy_tx, dummy_rx) = watch::channel(None); + let task_status = TaskStatus::dummy(TcpConnectionStealer::TASK_NAME, dummy_rx); + let mut api = TcpStealerApi::new( + 0, + command_tx.clone(), + task_status, + 8, + mirrord_protocol::VERSION.clone(), + ) + .await + .unwrap(); + + api.port_subscribe(StealType::FilteredHttpEx( + 80, + HttpFilter::Header(Filter::new("user: test".into()).unwrap()), + )) + .await + .unwrap(); + + let response = api.recv().await.unwrap(); + assert_eq!(response, DaemonTcp::SubscribeResult(Ok(80))); + + let notification = redirect_rx.recv().await.unwrap(); + assert_eq!(notification, RedirectNotification::Added(80)); + + std::mem::drop(api); + + let notification = redirect_rx.recv().await.unwrap(); + assert_eq!(notification, RedirectNotification::Removed(80)); + } } diff --git a/mirrord/agent/src/steal/connections.rs b/mirrord/agent/src/steal/connections.rs index bc47eb80e4b..8e969b0a868 100644 --- a/mirrord/agent/src/steal/connections.rs +++ b/mirrord/agent/src/steal/connections.rs @@ -11,10 +11,14 @@ use tokio::{ sync::mpsc::{self, error::SendError, Receiver, Sender}, task::JoinSet, }; +use tracing::Level; use self::{filtered::DynamicBody, unfiltered::UnfilteredStealTask}; use super::{http::DefaultReversibleStream, subscriptions::PortSubscription}; -use crate::{http::HttpVersion, steal::connections::filtered::FilteredStealTask, util::ClientId}; +use crate::{ + http::HttpVersion, metrics::STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION, + steal::connections::filtered::FilteredStealTask, util::ClientId, +}; mod filtered; mod unfiltered; @@ -287,7 +291,7 @@ 
impl StolenConnections { /// Adds the given [`StolenConnection`] to this set. Spawns a new [`tokio::task`] that will /// manage it. - #[tracing::instrument(level = "trace", name = "manage_stolen_connection", skip(self))] + #[tracing::instrument(level = Level::TRACE, name = "manage_stolen_connection", skip(self))] pub fn manage(&mut self, connection: StolenConnection) { let connection_id = self.next_connection_id; self.next_connection_id += 1; @@ -458,13 +462,9 @@ impl ConnectionTask { }) .await?; - let task = UnfilteredStealTask { - connection_id: self.connection_id, - client_id, - stream: self.connection.stream, - }; - - task.run(self.tx, &mut self.rx).await + UnfilteredStealTask::new(self.connection_id, client_id, self.connection.stream) + .run(self.tx, &mut self.rx) + .await } PortSubscription::Filtered(filters) => { diff --git a/mirrord/agent/src/steal/connections/filtered.rs b/mirrord/agent/src/steal/connections/filtered.rs index b30e48a5757..ecc9f0064ad 100644 --- a/mirrord/agent/src/steal/connections/filtered.rs +++ b/mirrord/agent/src/steal/connections/filtered.rs @@ -1,5 +1,6 @@ use std::{ - collections::HashMap, future::Future, marker::PhantomData, net::SocketAddr, pin::Pin, sync::Arc, + collections::HashMap, future::Future, marker::PhantomData, net::SocketAddr, ops::Not, pin::Pin, + sync::Arc, }; use bytes::Bytes; @@ -28,9 +29,13 @@ use tokio::{ use tokio_util::sync::{CancellationToken, DropGuard}; use tracing::Level; -use super::{ConnectionMessageIn, ConnectionMessageOut, ConnectionTaskError}; +use super::{ + ConnectionMessageIn, ConnectionMessageOut, ConnectionTaskError, + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION, +}; use crate::{ http::HttpVersion, + metrics::STEAL_FILTERED_CONNECTION_SUBSCRIPTION, steal::{connections::unfiltered::UnfilteredStealTask, http::HttpFilter}, util::ClientId, }; @@ -368,6 +373,18 @@ pub struct FilteredStealTask { /// For safely downcasting the IO stream after an HTTP upgrade. See [`Upgraded::downcast`]. 
_io_type: PhantomData T>, + + /// Helps us figure out if we should update some metrics in the `Drop` implementation. + metrics_updated: bool, +} + +impl Drop for FilteredStealTask { + fn drop(&mut self) { + if self.metrics_updated.not() { + STEAL_FILTERED_CONNECTION_SUBSCRIPTION + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + } + } } impl FilteredStealTask where @@ -443,6 +460,8 @@ where } }; + STEAL_FILTERED_CONNECTION_SUBSCRIPTION.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + Self { connection_id, original_destination, @@ -453,6 +472,7 @@ where blocked_requests: Default::default(), next_request_id: Default::default(), _io_type: Default::default(), + metrics_updated: false, } } @@ -638,6 +658,8 @@ where queued_raw_data.remove(&client_id); self.subscribed.insert(client_id, false); self.blocked_requests.retain(|key, _| key.0 != client_id); + + STEAL_FILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + }, }, @@ -646,7 +668,10 @@ where // No more requests from the `FilteringService`. // HTTP connection is closed and possibly upgraded. 
- None => break, + None => { + STEAL_FILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + break + } } } } @@ -788,15 +813,18 @@ where ) -> Result<(), ConnectionTaskError> { let res = self.run_until_http_ends(tx.clone(), rx).await; + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + self.metrics_updated = true; + let res = match res { Ok(data) => self.run_after_http_ends(data, tx.clone(), rx).await, Err(e) => Err(e), }; - for (client_id, subscribed) in self.subscribed { - if subscribed { + for (client_id, subscribed) in self.subscribed.iter() { + if *subscribed { tx.send(ConnectionMessageOut::Closed { - client_id, + client_id: *client_id, connection_id: self.connection_id, }) .await?; diff --git a/mirrord/agent/src/steal/connections/unfiltered.rs b/mirrord/agent/src/steal/connections/unfiltered.rs index 5b6676094c3..ec54691315e 100644 --- a/mirrord/agent/src/steal/connections/unfiltered.rs +++ b/mirrord/agent/src/steal/connections/unfiltered.rs @@ -7,7 +7,10 @@ use tokio::{ sync::mpsc::{Receiver, Sender}, }; -use super::{ConnectionMessageIn, ConnectionMessageOut, ConnectionTaskError}; +use super::{ + ConnectionMessageIn, ConnectionMessageOut, ConnectionTaskError, + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION, +}; use crate::util::ClientId; /// Manages an unfiltered stolen connection. @@ -19,7 +22,23 @@ pub struct UnfilteredStealTask { pub stream: T, } +impl Drop for UnfilteredStealTask { + fn drop(&mut self) { + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + } +} + impl UnfilteredStealTask { + pub(crate) fn new(connection_id: ConnectionId, client_id: ClientId, stream: T) -> Self { + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + Self { + connection_id, + client_id, + stream, + } + } + /// Runs this task until the managed connection is closed. 
/// /// # Note @@ -40,6 +59,8 @@ impl UnfilteredStealTask { read = self.stream.read_buf(&mut buf), if !reading_closed => match read { Ok(..) => { if buf.is_empty() { + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + tracing::trace!( client_id = self.client_id, connection_id = self.connection_id, @@ -63,6 +84,8 @@ impl UnfilteredStealTask { Err(e) if e.kind() == ErrorKind::WouldBlock => {} Err(e) => { + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + tx.send(ConnectionMessageOut::Closed { client_id: self.client_id, connection_id: self.connection_id @@ -85,6 +108,8 @@ impl UnfilteredStealTask { ConnectionMessageIn::Raw { data, .. } => { let res = if data.is_empty() { + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + tracing::trace!( client_id = self.client_id, connection_id = self.connection_id, @@ -97,6 +122,8 @@ impl UnfilteredStealTask { }; if let Err(e) = res { + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + tx.send(ConnectionMessageOut::Closed { client_id: self.client_id, connection_id: self.connection_id @@ -115,6 +142,8 @@ impl UnfilteredStealTask { }, ConnectionMessageIn::Unsubscribed { .. 
} => { + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + return Ok(()); } } diff --git a/mirrord/agent/src/steal/ip_tables.rs b/mirrord/agent/src/steal/ip_tables.rs index 68bddb6a406..9c175220767 100644 --- a/mirrord/agent/src/steal/ip_tables.rs +++ b/mirrord/agent/src/steal/ip_tables.rs @@ -9,7 +9,7 @@ use rand::distributions::{Alphanumeric, DistString}; use tracing::warn; use crate::{ - error::{AgentError, Result}, + error::{AgentError, AgentResult}, steal::ip_tables::{ flush_connections::FlushConnections, mesh::{istio::AmbientRedirect, MeshRedirect, MeshVendorExt}, @@ -84,13 +84,13 @@ pub(crate) trait IPTables { where Self: Sized; - fn create_chain(&self, name: &str) -> Result<()>; - fn remove_chain(&self, name: &str) -> Result<()>; + fn create_chain(&self, name: &str) -> AgentResult<()>; + fn remove_chain(&self, name: &str) -> AgentResult<()>; - fn add_rule(&self, chain: &str, rule: &str) -> Result<()>; - fn insert_rule(&self, chain: &str, rule: &str, index: i32) -> Result<()>; - fn list_rules(&self, chain: &str) -> Result>; - fn remove_rule(&self, chain: &str, rule: &str) -> Result<()>; + fn add_rule(&self, chain: &str, rule: &str) -> AgentResult<()>; + fn insert_rule(&self, chain: &str, rule: &str, index: i32) -> AgentResult<()>; + fn list_rules(&self, chain: &str) -> AgentResult>; + fn remove_rule(&self, chain: &str, rule: &str) -> AgentResult<()>; } #[derive(Clone)] @@ -152,8 +152,13 @@ impl IPTables for IPTablesWrapper { } } - #[tracing::instrument(level = tracing::Level::TRACE, skip(self), ret, fields(table_name=%self.table_name))] - fn create_chain(&self, name: &str) -> Result<()> { + #[tracing::instrument( + level = tracing::Level::TRACE, + skip(self), + ret, + fields(table_name=%self.table_name + ))] + fn create_chain(&self, name: &str) -> AgentResult<()> { self.tables .new_chain(self.table_name, name) .map_err(|e| AgentError::IPTablesError(e.to_string()))?; @@ -165,7 +170,7 @@ impl IPTables for 
IPTablesWrapper { } #[tracing::instrument(level = "trace")] - fn remove_chain(&self, name: &str) -> Result<()> { + fn remove_chain(&self, name: &str) -> AgentResult<()> { self.tables .flush_chain(self.table_name, name) .map_err(|e| AgentError::IPTablesError(e.to_string()))?; @@ -177,28 +182,28 @@ impl IPTables for IPTablesWrapper { } #[tracing::instrument(level = "trace", ret)] - fn add_rule(&self, chain: &str, rule: &str) -> Result<()> { + fn add_rule(&self, chain: &str, rule: &str) -> AgentResult<()> { self.tables .append(self.table_name, chain, rule) .map_err(|e| AgentError::IPTablesError(e.to_string())) } #[tracing::instrument(level = "trace", ret)] - fn insert_rule(&self, chain: &str, rule: &str, index: i32) -> Result<()> { + fn insert_rule(&self, chain: &str, rule: &str, index: i32) -> AgentResult<()> { self.tables .insert(self.table_name, chain, rule, index) .map_err(|e| AgentError::IPTablesError(e.to_string())) } #[tracing::instrument(level = "trace")] - fn list_rules(&self, chain: &str) -> Result> { + fn list_rules(&self, chain: &str) -> AgentResult> { self.tables .list(self.table_name, chain) .map_err(|e| AgentError::IPTablesError(e.to_string())) } #[tracing::instrument(level = "trace")] - fn remove_rule(&self, chain: &str, rule: &str) -> Result<()> { + fn remove_rule(&self, chain: &str, rule: &str) -> AgentResult<()> { self.tables .delete(self.table_name, chain, rule) .map_err(|e| AgentError::IPTablesError(e.to_string())) @@ -233,7 +238,7 @@ where flush_connections: bool, pod_ips: Option<&str>, ipv6: bool, - ) -> Result { + ) -> AgentResult { let ipt = Arc::new(ipt); let mut redirect = if let Some(vendor) = MeshVendor::detect(ipt.as_ref())? { @@ -265,7 +270,7 @@ where Ok(Self { redirect }) } - pub(crate) async fn load(ipt: IPT, flush_connections: bool) -> Result { + pub(crate) async fn load(ipt: IPT, flush_connections: bool) -> AgentResult { let ipt = Arc::new(ipt); let mut redirect = if let Some(vendor) = MeshVendor::detect(ipt.as_ref())? 
{ @@ -299,7 +304,7 @@ where &self, redirected_port: Port, target_port: Port, - ) -> Result<()> { + ) -> AgentResult<()> { self.redirect .add_redirect(redirected_port, target_port) .await @@ -314,13 +319,13 @@ where &self, redirected_port: Port, target_port: Port, - ) -> Result<()> { + ) -> AgentResult<()> { self.redirect .remove_redirect(redirected_port, target_port) .await } - pub(crate) async fn cleanup(&self) -> Result<()> { + pub(crate) async fn cleanup(&self) -> AgentResult<()> { self.redirect.unmount_entrypoint().await } } diff --git a/mirrord/agent/src/steal/ip_tables/chain.rs b/mirrord/agent/src/steal/ip_tables/chain.rs index c5bc6d65404..c1c34715c85 100644 --- a/mirrord/agent/src/steal/ip_tables/chain.rs +++ b/mirrord/agent/src/steal/ip_tables/chain.rs @@ -4,7 +4,7 @@ use std::sync::{ }; use crate::{ - error::{AgentError, Result}, + error::{AgentError, AgentResult}, steal::ip_tables::IPTables, }; @@ -19,7 +19,7 @@ impl IPTableChain where IPT: IPTables, { - pub fn create(inner: Arc, chain_name: String) -> Result { + pub fn create(inner: Arc, chain_name: String) -> AgentResult { inner.create_chain(&chain_name)?; // Start with 1 because the chain will allways have atleast `-A ` as a rule @@ -32,7 +32,7 @@ where }) } - pub fn load(inner: Arc, chain_name: String) -> Result { + pub fn load(inner: Arc, chain_name: String) -> AgentResult { let existing_rules = inner.list_rules(&chain_name)?.len(); if existing_rules == 0 { @@ -59,7 +59,7 @@ where &self.inner } - pub fn add_rule(&self, rule: &str) -> Result { + pub fn add_rule(&self, rule: &str) -> AgentResult { self.inner .insert_rule( &self.chain_name, @@ -72,7 +72,7 @@ where }) } - pub fn remove_rule(&self, rule: &str) -> Result<()> { + pub fn remove_rule(&self, rule: &str) -> AgentResult<()> { self.inner.remove_rule(&self.chain_name, rule)?; self.chain_size.fetch_sub(1, Ordering::Relaxed); diff --git a/mirrord/agent/src/steal/ip_tables/flush_connections.rs b/mirrord/agent/src/steal/ip_tables/flush_connections.rs 
index 6675a40651f..c0f19c20b8d 100644 --- a/mirrord/agent/src/steal/ip_tables/flush_connections.rs +++ b/mirrord/agent/src/steal/ip_tables/flush_connections.rs @@ -13,7 +13,7 @@ use tokio::process::Command; use tracing::warn; use crate::{ - error::Result, + error::AgentResult, steal::ip_tables::{chain::IPTableChain, redirect::Redirect, IPTables, IPTABLE_INPUT}, }; @@ -33,7 +33,7 @@ where const ENTRYPOINT: &'static str = "INPUT"; #[tracing::instrument(level = "trace", skip(ipt, inner))] - pub fn create(ipt: Arc, inner: Box) -> Result { + pub fn create(ipt: Arc, inner: Box) -> AgentResult { let managed = IPTableChain::create(ipt.with_table("filter").into(), IPTABLE_INPUT.to_string())?; @@ -48,7 +48,7 @@ where } #[tracing::instrument(level = "trace", skip(ipt, inner))] - pub fn load(ipt: Arc, inner: Box) -> Result { + pub fn load(ipt: Arc, inner: Box) -> AgentResult { let managed = IPTableChain::load(ipt.with_table("filter").into(), IPTABLE_INPUT.to_string())?; @@ -63,7 +63,7 @@ where T: Redirect + Send + Sync, { #[tracing::instrument(level = "trace", skip(self), ret)] - async fn mount_entrypoint(&self) -> Result<()> { + async fn mount_entrypoint(&self) -> AgentResult<()> { self.inner.mount_entrypoint().await?; self.managed.inner().add_rule( @@ -75,7 +75,7 @@ where } #[tracing::instrument(level = "trace", skip(self), ret)] - async fn unmount_entrypoint(&self) -> Result<()> { + async fn unmount_entrypoint(&self) -> AgentResult<()> { self.inner.unmount_entrypoint().await?; self.managed.inner().remove_rule( @@ -87,7 +87,7 @@ where } #[tracing::instrument(level = "trace", skip(self), ret)] - async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { self.inner .add_redirect(redirected_port, target_port) .await?; @@ -115,7 +115,7 @@ where } #[tracing::instrument(level = "trace", skip(self), ret)] - async fn remove_redirect(&self, redirected_port: Port, 
target_port: Port) -> Result<()> { + async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { self.inner .remove_redirect(redirected_port, target_port) .await?; diff --git a/mirrord/agent/src/steal/ip_tables/mesh.rs b/mirrord/agent/src/steal/ip_tables/mesh.rs index 88fdff5d0b1..1a3e5acbe62 100644 --- a/mirrord/agent/src/steal/ip_tables/mesh.rs +++ b/mirrord/agent/src/steal/ip_tables/mesh.rs @@ -5,7 +5,7 @@ use fancy_regex::Regex; use mirrord_protocol::{MeshVendor, Port}; use crate::{ - error::Result, + error::AgentResult, steal::ip_tables::{ output::OutputRedirect, prerouting::PreroutingRedirect, redirect::Redirect, IPTables, IPTABLE_MESH, @@ -29,7 +29,7 @@ impl MeshRedirect where IPT: IPTables, { - pub fn create(ipt: Arc, vendor: MeshVendor, pod_ips: Option<&str>) -> Result { + pub fn create(ipt: Arc, vendor: MeshVendor, pod_ips: Option<&str>) -> AgentResult { let prerouting = PreroutingRedirect::create(ipt.clone())?; for port in Self::get_skip_ports(&ipt, &vendor)? 
{ @@ -45,7 +45,7 @@ where }) } - pub fn load(ipt: Arc, vendor: MeshVendor) -> Result { + pub fn load(ipt: Arc, vendor: MeshVendor) -> AgentResult { let prerouting = PreroutingRedirect::load(ipt.clone())?; let output = OutputRedirect::load(ipt, IPTABLE_MESH.to_string())?; @@ -56,7 +56,7 @@ where }) } - fn get_skip_ports(ipt: &IPT, vendor: &MeshVendor) -> Result> { + fn get_skip_ports(ipt: &IPT, vendor: &MeshVendor) -> AgentResult> { let chain_name = vendor.input_chain(); let lookup_regex = if let Some(regex) = vendor.skip_ports_regex() { regex @@ -86,21 +86,21 @@ impl Redirect for MeshRedirect where IPT: IPTables + Send + Sync, { - async fn mount_entrypoint(&self) -> Result<()> { + async fn mount_entrypoint(&self) -> AgentResult<()> { self.prerouting.mount_entrypoint().await?; self.output.mount_entrypoint().await?; Ok(()) } - async fn unmount_entrypoint(&self) -> Result<()> { + async fn unmount_entrypoint(&self) -> AgentResult<()> { self.prerouting.unmount_entrypoint().await?; self.output.unmount_entrypoint().await?; Ok(()) } - async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { if self.vendor != MeshVendor::IstioCni { self.prerouting .add_redirect(redirected_port, target_port) @@ -113,7 +113,7 @@ where Ok(()) } - async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { if self.vendor != MeshVendor::IstioCni { self.prerouting .remove_redirect(redirected_port, target_port) @@ -129,13 +129,13 @@ where /// Extends the [`MeshVendor`] type with methods that are only relevant for the agent. 
pub(super) trait MeshVendorExt: Sized { - fn detect(ipt: &IPT) -> Result>; + fn detect(ipt: &IPT) -> AgentResult>; fn input_chain(&self) -> &str; fn skip_ports_regex(&self) -> Option<&Regex>; } impl MeshVendorExt for MeshVendor { - fn detect(ipt: &IPT) -> Result> { + fn detect(ipt: &IPT) -> AgentResult> { if let Ok(val) = std::env::var("MIRRORD_AGENT_ISTIO_CNI") && val.to_lowercase() == "true" { diff --git a/mirrord/agent/src/steal/ip_tables/mesh/istio.rs b/mirrord/agent/src/steal/ip_tables/mesh/istio.rs index cd3d4b06fa9..01e513a6bf9 100644 --- a/mirrord/agent/src/steal/ip_tables/mesh/istio.rs +++ b/mirrord/agent/src/steal/ip_tables/mesh/istio.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use mirrord_protocol::Port; use crate::{ - error::Result, + error::AgentResult, steal::ip_tables::{ output::OutputRedirect, prerouting::PreroutingRedirect, redirect::Redirect, IPTables, IPTABLE_IPV4_ROUTE_LOCALNET_ORIGINAL, IPTABLE_MESH, @@ -20,14 +20,14 @@ impl AmbientRedirect where IPT: IPTables, { - pub fn create(ipt: Arc, pod_ips: Option<&str>) -> Result { + pub fn create(ipt: Arc, pod_ips: Option<&str>) -> AgentResult { let prerouting = PreroutingRedirect::create(ipt.clone())?; let output = OutputRedirect::create(ipt, IPTABLE_MESH.to_string(), pod_ips)?; Ok(AmbientRedirect { prerouting, output }) } - pub fn load(ipt: Arc) -> Result { + pub fn load(ipt: Arc) -> AgentResult { let prerouting = PreroutingRedirect::load(ipt.clone())?; let output = OutputRedirect::load(ipt, IPTABLE_MESH.to_string())?; @@ -40,7 +40,7 @@ impl Redirect for AmbientRedirect where IPT: IPTables + Send + Sync, { - async fn mount_entrypoint(&self) -> Result<()> { + async fn mount_entrypoint(&self) -> AgentResult<()> { tokio::fs::write("/proc/sys/net/ipv4/conf/all/route_localnet", "1".as_bytes()).await?; self.prerouting.mount_entrypoint().await?; @@ -49,7 +49,7 @@ where Ok(()) } - async fn unmount_entrypoint(&self) -> Result<()> { + async fn unmount_entrypoint(&self) -> AgentResult<()> { 
self.prerouting.unmount_entrypoint().await?; self.output.unmount_entrypoint().await?; @@ -62,7 +62,7 @@ where Ok(()) } - async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { self.prerouting .add_redirect(redirected_port, target_port) .await?; @@ -73,7 +73,7 @@ where Ok(()) } - async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { self.prerouting .remove_redirect(redirected_port, target_port) .await?; diff --git a/mirrord/agent/src/steal/ip_tables/output.rs b/mirrord/agent/src/steal/ip_tables/output.rs index 2286469c00c..9eebad0c9ae 100644 --- a/mirrord/agent/src/steal/ip_tables/output.rs +++ b/mirrord/agent/src/steal/ip_tables/output.rs @@ -6,7 +6,7 @@ use nix::unistd::getgid; use tracing::warn; use crate::{ - error::Result, + error::AgentResult, steal::ip_tables::{chain::IPTableChain, IPTables, Redirect}, }; @@ -20,8 +20,8 @@ where { const ENTRYPOINT: &'static str = "OUTPUT"; - #[tracing::instrument(skip(ipt), level = tracing::Level::TRACE)] - pub fn create(ipt: Arc, chain_name: String, pod_ips: Option<&str>) -> Result { + #[tracing::instrument(level = tracing::Level::TRACE, skip(ipt), err)] + pub fn create(ipt: Arc, chain_name: String, pod_ips: Option<&str>) -> AgentResult { let managed = IPTableChain::create(ipt, chain_name.clone()).inspect_err( |e| tracing::error!(%e, "Could not create iptables chain \"{chain_name}\"."), )?; @@ -42,7 +42,7 @@ where Ok(OutputRedirect { managed }) } - pub fn load(ipt: Arc, chain_name: String) -> Result { + pub fn load(ipt: Arc, chain_name: String) -> AgentResult { let managed = IPTableChain::load(ipt, chain_name)?; Ok(OutputRedirect { managed }) @@ -56,7 +56,7 @@ impl Redirect for OutputRedirect where IPT: IPTables + Send + Sync, { - async fn mount_entrypoint(&self) -> 
Result<()> { + async fn mount_entrypoint(&self) -> AgentResult<()> { if USE_INSERT { self.managed.inner().insert_rule( Self::ENTRYPOINT, @@ -73,7 +73,7 @@ where Ok(()) } - async fn unmount_entrypoint(&self) -> Result<()> { + async fn unmount_entrypoint(&self) -> AgentResult<()> { self.managed.inner().remove_rule( Self::ENTRYPOINT, &format!("-j {}", self.managed.chain_name()), @@ -82,7 +82,7 @@ where Ok(()) } - async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { let redirect_rule = format!( "-o lo -m tcp -p tcp --dport {redirected_port} -j REDIRECT --to-ports {target_port}" ); @@ -92,7 +92,7 @@ where Ok(()) } - async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { let redirect_rule = format!( "-o lo -m tcp -p tcp --dport {redirected_port} -j REDIRECT --to-ports {target_port}" ); diff --git a/mirrord/agent/src/steal/ip_tables/prerouting.rs b/mirrord/agent/src/steal/ip_tables/prerouting.rs index 486b0ca1b51..29d5de06103 100644 --- a/mirrord/agent/src/steal/ip_tables/prerouting.rs +++ b/mirrord/agent/src/steal/ip_tables/prerouting.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use mirrord_protocol::Port; use crate::{ - error::Result, + error::AgentResult, steal::ip_tables::{chain::IPTableChain, IPTables, Redirect, IPTABLE_PREROUTING}, }; @@ -18,13 +18,13 @@ where { const ENTRYPOINT: &'static str = "PREROUTING"; - pub fn create(ipt: Arc) -> Result { + pub fn create(ipt: Arc) -> AgentResult { let managed = IPTableChain::create(ipt, IPTABLE_PREROUTING.to_string())?; Ok(PreroutingRedirect { managed }) } - pub fn load(ipt: Arc) -> Result { + pub fn load(ipt: Arc) -> AgentResult { let managed = IPTableChain::load(ipt, IPTABLE_PREROUTING.to_string())?; Ok(PreroutingRedirect { managed }) @@ -36,7 +36,7 @@ impl 
Redirect for PreroutingRedirect where IPT: IPTables + Send + Sync, { - async fn mount_entrypoint(&self) -> Result<()> { + async fn mount_entrypoint(&self) -> AgentResult<()> { self.managed.inner().add_rule( Self::ENTRYPOINT, &format!("-j {}", self.managed.chain_name()), @@ -45,7 +45,7 @@ where Ok(()) } - async fn unmount_entrypoint(&self) -> Result<()> { + async fn unmount_entrypoint(&self) -> AgentResult<()> { self.managed.inner().remove_rule( Self::ENTRYPOINT, &format!("-j {}", self.managed.chain_name()), @@ -54,7 +54,7 @@ where Ok(()) } - async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { let redirect_rule = format!("-m tcp -p tcp --dport {redirected_port} -j REDIRECT --to-ports {target_port}"); @@ -63,7 +63,7 @@ where Ok(()) } - async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { let redirect_rule = format!("-m tcp -p tcp --dport {redirected_port} -j REDIRECT --to-ports {target_port}"); diff --git a/mirrord/agent/src/steal/ip_tables/redirect.rs b/mirrord/agent/src/steal/ip_tables/redirect.rs index d18aeb1d7ea..fe52d90fc1e 100644 --- a/mirrord/agent/src/steal/ip_tables/redirect.rs +++ b/mirrord/agent/src/steal/ip_tables/redirect.rs @@ -2,17 +2,17 @@ use async_trait::async_trait; use enum_dispatch::enum_dispatch; use mirrord_protocol::Port; -use crate::error::Result; +use crate::error::AgentResult; #[async_trait] #[enum_dispatch] pub(crate) trait Redirect { - async fn mount_entrypoint(&self) -> Result<()>; + async fn mount_entrypoint(&self) -> AgentResult<()>; - async fn unmount_entrypoint(&self) -> Result<()>; + async fn unmount_entrypoint(&self) -> AgentResult<()>; /// Create port redirection - async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()>; + async fn 
add_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()>; /// Remove port redirection - async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()>; + async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()>; } diff --git a/mirrord/agent/src/steal/ip_tables/standard.rs b/mirrord/agent/src/steal/ip_tables/standard.rs index 3302b05c02e..47b9bf0c0af 100644 --- a/mirrord/agent/src/steal/ip_tables/standard.rs +++ b/mirrord/agent/src/steal/ip_tables/standard.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use mirrord_protocol::Port; use crate::{ - error::Result, + error::AgentResult, steal::ip_tables::{ output::OutputRedirect, prerouting::PreroutingRedirect, IPTables, Redirect, IPTABLE_STANDARD, @@ -20,14 +20,14 @@ impl StandardRedirect where IPT: IPTables, { - pub fn create(ipt: Arc, pod_ips: Option<&str>) -> Result { + pub fn create(ipt: Arc, pod_ips: Option<&str>) -> AgentResult { let prerouting = PreroutingRedirect::create(ipt.clone())?; let output = OutputRedirect::create(ipt, IPTABLE_STANDARD.to_string(), pod_ips)?; Ok(StandardRedirect { prerouting, output }) } - pub fn load(ipt: Arc) -> Result { + pub fn load(ipt: Arc) -> AgentResult { let prerouting = PreroutingRedirect::load(ipt.clone())?; let output = OutputRedirect::load(ipt, IPTABLE_STANDARD.to_string())?; @@ -42,21 +42,21 @@ impl Redirect for StandardRedirect where IPT: IPTables + Send + Sync, { - async fn mount_entrypoint(&self) -> Result<()> { + async fn mount_entrypoint(&self) -> AgentResult<()> { self.prerouting.mount_entrypoint().await?; self.output.mount_entrypoint().await?; Ok(()) } - async fn unmount_entrypoint(&self) -> Result<()> { + async fn unmount_entrypoint(&self) -> AgentResult<()> { self.prerouting.unmount_entrypoint().await?; self.output.unmount_entrypoint().await?; Ok(()) } - async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn add_redirect(&self, 
redirected_port: Port, target_port: Port) -> AgentResult<()> { self.prerouting .add_redirect(redirected_port, target_port) .await?; @@ -67,7 +67,7 @@ where Ok(()) } - async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { self.prerouting .remove_redirect(redirected_port, target_port) .await?; diff --git a/mirrord/agent/src/steal/subscriptions.rs b/mirrord/agent/src/steal/subscriptions.rs index 0468719bc9c..901ecd725ef 100644 --- a/mirrord/agent/src/steal/subscriptions.rs +++ b/mirrord/agent/src/steal/subscriptions.rs @@ -16,7 +16,11 @@ use super::{ http::HttpFilter, ip_tables::{new_ip6tables, new_iptables, IPTablesWrapper, SafeIpTables}, }; -use crate::{error::AgentError, util::ClientId}; +use crate::{ + error::{AgentError, AgentResult}, + metrics::{STEAL_FILTERED_PORT_SUBSCRIPTION, STEAL_UNFILTERED_PORT_SUBSCRIPTION}, + util::ClientId, +}; /// For stealing incoming TCP connections. #[async_trait::async_trait] @@ -149,7 +153,7 @@ impl IpTablesRedirector { flush_connections: bool, pod_ips: Option, support_ipv6: bool, - ) -> Result { + ) -> AgentResult { let (pod_ips4, pod_ips6) = pod_ips.map_or_else( || (None, None), |ips| { @@ -310,6 +314,13 @@ pub struct PortSubscriptions { subscriptions: HashMap, } +impl Drop for PortSubscriptions { + fn drop(&mut self) { + STEAL_FILTERED_PORT_SUBSCRIPTION.store(0, std::sync::atomic::Ordering::Relaxed); + STEAL_UNFILTERED_PORT_SUBSCRIPTION.store(0, std::sync::atomic::Ordering::Relaxed); + } +} + impl PortSubscriptions { /// Create an empty instance of this struct. 
/// @@ -351,7 +362,16 @@ impl PortSubscriptions { ) -> Result, R::Error> { let add_redirect = match self.subscriptions.entry(port) { Entry::Occupied(mut e) => { + let filtered = filter.is_some(); if e.get_mut().try_extend(client_id, filter) { + if filtered { + STEAL_FILTERED_PORT_SUBSCRIPTION + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } else { + STEAL_UNFILTERED_PORT_SUBSCRIPTION + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + Ok(false) } else { Err(ResponseError::PortAlreadyStolen(port)) @@ -359,6 +379,14 @@ impl PortSubscriptions { } Entry::Vacant(e) => { + if filter.is_some() { + STEAL_FILTERED_PORT_SUBSCRIPTION + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } else { + STEAL_UNFILTERED_PORT_SUBSCRIPTION + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + e.insert(PortSubscription::new(client_id, filter)); Ok(true) } @@ -395,11 +423,17 @@ impl PortSubscriptions { let remove_redirect = match e.get_mut() { PortSubscription::Unfiltered(subscribed_client) if *subscribed_client == client_id => { e.remove(); + STEAL_UNFILTERED_PORT_SUBSCRIPTION + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + true } PortSubscription::Unfiltered(..) 
=> false, PortSubscription::Filtered(filters) => { - filters.remove(&client_id); + if filters.remove(&client_id).is_some() { + STEAL_FILTERED_PORT_SUBSCRIPTION + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + } if filters.is_empty() { e.remove(); diff --git a/mirrord/agent/src/util.rs b/mirrord/agent/src/util.rs index 9dcbc6cd892..c5a002979e9 100644 --- a/mirrord/agent/src/util.rs +++ b/mirrord/agent/src/util.rs @@ -8,11 +8,12 @@ use std::{ thread::JoinHandle, }; +use futures::{future::BoxFuture, FutureExt}; use tokio::sync::mpsc; use tracing::error; use crate::{ - error::AgentError, + error::AgentResult, namespace::{set_namespace, NamespaceType}, }; @@ -151,7 +152,7 @@ where /// Many of the agent's TCP/UDP connections require that they're made from the `pid`'s namespace to /// work. #[tracing::instrument(level = "trace")] -pub(crate) fn enter_namespace(pid: Option, namespace: &str) -> Result<(), AgentError> { +pub(crate) fn enter_namespace(pid: Option, namespace: &str) -> AgentResult<()> { if let Some(pid) = pid { Ok(set_namespace(pid, NamespaceType::Net).inspect_err(|fail| { error!("Failed setting pid {pid:#?} namespace {namespace:#?} with {fail:#?}") @@ -162,27 +163,25 @@ pub(crate) fn enter_namespace(pid: Option, namespace: &str) -> Result<(), A } /// [`Future`] that resolves to [`ClientId`] when the client drops their [`mpsc::Receiver`]. 
-pub(crate) struct ChannelClosedFuture { - tx: mpsc::Sender, - client_id: ClientId, -} +pub(crate) struct ChannelClosedFuture(BoxFuture<'static, ClientId>); + +impl ChannelClosedFuture { + pub(crate) fn new(tx: mpsc::Sender, client_id: ClientId) -> Self { + let future = async move { + tx.closed().await; + client_id + } + .boxed(); -impl ChannelClosedFuture { - pub(crate) fn new(tx: mpsc::Sender, client_id: ClientId) -> Self { - Self { tx, client_id } + Self(future) } } -impl Future for ChannelClosedFuture { +impl Future for ChannelClosedFuture { type Output = ClientId; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let client_id = self.client_id; - - let future = std::pin::pin!(self.get_mut().tx.closed()); - std::task::ready!(future.poll(cx)); - - Poll::Ready(client_id) + self.get_mut().0.as_mut().poll(cx) } } @@ -264,3 +263,52 @@ mod subscription_tests { assert_eq!(subscriptions.get_subscribed_topics(), Vec::::new()); } } + +#[cfg(test)] +mod channel_closed_tests { + use std::time::Duration; + + use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; + use rstest::rstest; + + use super::*; + + /// Verifies that [`ChannelClosedFuture`] resolves when the related [`mpsc::Receiver`] is + /// dropped. + #[rstest] + #[timeout(Duration::from_secs(5))] + #[tokio::test] + async fn channel_closed_resolves() { + let (tx, rx) = mpsc::channel::<()>(1); + let future = ChannelClosedFuture::new(tx, 0); + std::mem::drop(rx); + assert_eq!(future.await, 0); + } + + /// Verifies that [`ChannelClosedFuture`] works fine when used in [`FuturesUnordered`]. + /// + /// The future used to hold the [`mpsc::Sender`] and call poll [`mpsc::Sender::closed`] in it's + /// [`Future::poll`] implementation. This worked fine when the future was used in a simple way + /// ([`channel_closed_resolves`] test was passing). + /// + /// However, [`FuturesUnordered::next`] was hanging forever due to [`mpsc::Sender::closed`] + /// implementation details. 
+ /// + /// New implementation of [`ChannelClosedFuture`] uses a [`BoxFuture`] internally, which works + /// fine. + #[rstest] + #[timeout(Duration::from_secs(5))] + #[tokio::test] + async fn channel_closed_works_in_futures_unordered() { + let mut unordered: FuturesUnordered = FuturesUnordered::new(); + + let (tx, rx) = mpsc::channel::<()>(1); + let future = ChannelClosedFuture::new(tx, 0); + + unordered.push(future); + + assert!(unordered.next().now_or_never().is_none()); + std::mem::drop(rx); + assert_eq!(unordered.next().await.unwrap(), 0); + } +} diff --git a/mirrord/agent/src/vpn.rs b/mirrord/agent/src/vpn.rs index dd8c3a5133f..d7d30d5ca6f 100644 --- a/mirrord/agent/src/vpn.rs +++ b/mirrord/agent/src/vpn.rs @@ -17,7 +17,7 @@ use tokio::{ }; use crate::{ - error::{AgentError, Result}, + error::{AgentError, AgentResult}, util::run_thread_in_namespace, watched_task::{TaskStatus, WatchedTask}, }; @@ -75,7 +75,7 @@ impl VpnApi { /// Sends the [`ClientVpn`] message to the background task. #[tracing::instrument(level = "trace", skip(self))] - pub(crate) async fn layer_message(&mut self, message: ClientVpn) -> Result<()> { + pub(crate) async fn layer_message(&mut self, message: ClientVpn) -> AgentResult<()> { if self.layer_tx.send(message).await.is_ok() { Ok(()) } else { @@ -84,7 +84,7 @@ impl VpnApi { } /// Receives a [`ServerVpn`] message from the background task. 
- pub(crate) async fn daemon_message(&mut self) -> Result { + pub(crate) async fn daemon_message(&mut self) -> AgentResult { match self.daemon_rx.recv().await { Some(msg) => Ok(msg), None => Err(self.task_status.unwrap_err().await), @@ -121,7 +121,7 @@ impl AsyncRawSocket { } } -async fn create_raw_socket() -> Result { +async fn create_raw_socket() -> AgentResult { let index = nix::net::if_::if_nametoindex("eth0") .map_err(|err| AgentError::VpnError(err.to_string()))?; @@ -139,7 +139,7 @@ async fn create_raw_socket() -> Result { } #[tracing::instrument(level = "debug", ret)] -async fn resolve_interface() -> Result<(IpAddr, IpAddr, IpAddr)> { +async fn resolve_interface() -> AgentResult<(IpAddr, IpAddr, IpAddr)> { // Connect to a remote address so we can later get the default network interface. let temporary_socket = UdpSocket::bind("0.0.0.0:0").await?; temporary_socket.connect("8.8.8.8:53").await?; @@ -209,7 +209,7 @@ impl fmt::Debug for VpnTask { } } -fn interface_index_to_sock_addr(index: i32) -> Result { +fn interface_index_to_sock_addr(index: i32) -> AgentResult { let mut addr_storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; let len = std::mem::size_of::() as libc::socklen_t; let macs = procfs::net::arp().map_err(|err| AgentError::VpnError(err.to_string()))?; @@ -245,7 +245,7 @@ impl VpnTask { } #[allow(clippy::indexing_slicing)] - async fn run(mut self) -> Result<()> { + async fn run(mut self) -> AgentResult<()> { // so host won't respond with RST to our packets. // TODO: need to do it for UDP as well to avoid ICMP unreachable. let output = std::process::Command::new("iptables") @@ -318,7 +318,7 @@ impl VpnTask { &mut self, message: ClientVpn, network_configuration: &NetworkConfiguration, - ) -> Result<()> { + ) -> AgentResult<()> { match message { // We make connection to the requested address, split the stream into halves with // `io::split`, and put them into respective maps. 
diff --git a/mirrord/agent/src/watched_task.rs b/mirrord/agent/src/watched_task.rs index 0212f279163..2e7370b262c 100644 --- a/mirrord/agent/src/watched_task.rs +++ b/mirrord/agent/src/watched_task.rs @@ -2,7 +2,7 @@ use std::future::Future; use tokio::sync::watch::{self, Receiver, Sender}; -use crate::error::AgentError; +use crate::error::{AgentError, AgentResult}; /// A shared clonable view on a background task's status. #[derive(Debug, Clone)] @@ -83,7 +83,7 @@ impl WatchedTask { impl WatchedTask where - T: Future>, + T: Future>, { /// Execute the wrapped task. /// Store its result in the inner [`TaskStatus`]. @@ -94,9 +94,21 @@ where } #[cfg(test)] -mod test { +pub(crate) mod test { use super::*; + impl TaskStatus { + pub fn dummy( + task_name: &'static str, + result_rx: Receiver>>, + ) -> Self { + Self { + task_name, + result_rx, + } + } + } + #[tokio::test] async fn simple_successful() { let task = WatchedTask::new("task", async move { Ok(()) }); diff --git a/mirrord/cli/src/verify_config.rs b/mirrord/cli/src/verify_config.rs index 56d719dacbc..4628118360b 100644 --- a/mirrord/cli/src/verify_config.rs +++ b/mirrord/cli/src/verify_config.rs @@ -8,7 +8,8 @@ use mirrord_config::{ feature::FeatureConfig, target::{ cron_job::CronJobTarget, deployment::DeploymentTarget, job::JobTarget, pod::PodTarget, - rollout::RolloutTarget, stateful_set::StatefulSetTarget, Target, TargetConfig, + rollout::RolloutTarget, service::ServiceTarget, stateful_set::StatefulSetTarget, Target, + TargetConfig, }, }; use serde::Serialize; @@ -43,6 +44,9 @@ enum VerifiedTarget { #[serde(untagged)] StatefulSet(StatefulSetTarget), + + #[serde(untagged)] + Service(ServiceTarget), } impl From for VerifiedTarget { @@ -54,6 +58,7 @@ impl From for VerifiedTarget { Target::Job(target) => Self::Job(target), Target::CronJob(target) => Self::CronJob(target), Target::StatefulSet(target) => Self::StatefulSet(target), + Target::Service(target) => Self::Service(target), Target::Targetless => 
Self::Targetless, } } @@ -69,6 +74,7 @@ impl From for TargetType { VerifiedTarget::Job(_) => TargetType::Job, VerifiedTarget::CronJob(_) => TargetType::CronJob, VerifiedTarget::StatefulSet(_) => TargetType::StatefulSet, + VerifiedTarget::Service(_) => TargetType::Service, } } } @@ -99,6 +105,7 @@ enum TargetType { Job, CronJob, StatefulSet, + Service, } impl core::fmt::Display for TargetType { @@ -111,6 +118,7 @@ impl core::fmt::Display for TargetType { TargetType::Job => "job", TargetType::CronJob => "cronjob", TargetType::StatefulSet => "statefulset", + TargetType::Service => "service", }; f.write_str(stringifed) @@ -127,6 +135,7 @@ impl TargetType { Self::Job, Self::CronJob, Self::StatefulSet, + Self::Service, ] .into_iter() } @@ -136,6 +145,7 @@ impl TargetType { Self::Targetless | Self::Rollout => !config.copy_target.enabled, Self::Pod => !(config.copy_target.enabled && config.copy_target.scale_down), Self::Job | Self::CronJob => config.copy_target.enabled, + Self::Service => !config.copy_target.enabled, Self::Deployment | Self::StatefulSet => true, } } diff --git a/mirrord/config/configuration.md b/mirrord/config/configuration.md index 8e8b9ea6aee..20d3dbc0e0e 100644 --- a/mirrord/config/configuration.md +++ b/mirrord/config/configuration.md @@ -68,7 +68,8 @@ configuration file containing all fields. "communication_timeout": 30, "startup_timeout": 360, "network_interface": "eth0", - "flush_connections": true + "flush_connections": true, + "metrics": "0.0.0.0:9000", }, "feature": { "env": { @@ -166,7 +167,11 @@ Allows setting up custom annotations for the agent Job and Pod. ```json { - "annotations": { "cats.io/inject": "enabled" } + "annotations": { + "cats.io/inject": "enabled" + "prometheus.io/scrape": "true", + "prometheus.io/port": "9000" + } } ``` @@ -299,6 +304,19 @@ with `RUST_LOG`. } ``` +### agent.metrics {#agent-metrics} + +Enables prometheus metrics for the agent pod. 
+ +You might need to add annotations to the agent pod depending on how prometheus is +configured to scrape for metrics. + +```json +{ + "metrics": "0.0.0.0:9000" +} +``` + ### agent.namespace {#agent-namespace} Namespace where the agent shall live. @@ -376,12 +394,14 @@ Defaults to `60`. ### agent.tolerations {#agent-tolerations} -Set pod tolerations. (not with ephemeral agents) -Default is +Set pod tolerations. (not with ephemeral agents). + +Defaults to `operator: Exists`. + ```json [ { - "operator": "Exists" + "key": "meow", "operator": "Exists", "effect": "NoSchedule" } ] ``` @@ -1543,13 +1563,23 @@ Accepts a single value, or multiple values separated by `;`. ## target {#root-target} -Specifies the target and namespace to mirror, see [`path`](#target-path) for a list of -accepted values for the `target` option. +Specifies the target and namespace to target. The simplified configuration supports: -- `pod/{sample-pod}/[container]/{sample-container}`; -- `deployment/{sample-deployment}/[container]/{sample-container}`; +- `targetless` +- `pod/{pod-name}[/container/{container-name}]`; +- `deployment/{deployment-name}[/container/{container-name}]`; +- `rollout/{rollout-name}[/container/{container-name}]`; +- `job/{job-name}[/container/{container-name}]`; +- `cronjob/{cronjob-name}[/container/{container-name}]`; +- `statefulset/{statefulset-name}[/container/{container-name}]`; +- `service/{service-name}[/container/{container-name}]`; + +Please note that: + +- `job`, `cronjob`, `statefulset` and `service` targets require the mirrord Operator +- `job` and `cronjob` targets require the [`copy_target`](#feature-copy_target) feature Shortened setup: @@ -1559,38 +1589,63 @@ Shortened setup: } ``` +The setup above will result in a session targeting the `bear-pod` Kubernetes pod +in the user's default namespace. A target container will be chosen by mirrord. 
+ +Shortened setup with target container: + +```json +{ + "target": "pod/bear-pod/container/bear-pod-container" +} +``` + +The setup above will result in a session targeting the `bear-pod-container` container +in the `bear-pod` Kubernetes pod in the user's default namespace. + Complete setup: ```json { "target": { "path": { - "pod": "bear-pod" + "pod": "bear-pod", + "container": "bear-pod-container" }, - "namespace": "default" + "namespace": "bear-pod-namespace" } } ``` +The setup above will result in a session targeting the `bear-pod-container` container +in the `bear-pod` Kubernetes pod in the `bear-pod-namespace` namespace. + ### target.namespace {#target-namespace} Namespace where the target lives. -Defaults to `"default"`. +Defaults to the Kubernetes user's default namespace (defined in Kubernetes context). ### target.path {#target-path} -Specifies the running pod (or deployment) to mirror. +Specifies the Kubernetes resource to target. -Note: Deployment level steal/mirroring is available only in mirrord for Teams -If you use it without it, it will choose a random pod replica to work with. +Note: targeting services and whole workloads is available only in mirrord for Teams. +If you target a workload without the mirrord Operator, it will choose a random pod replica +to work with. Supports: -- `pod/{sample-pod}`; -- `deployment/{sample-deployment}`; -- `container/{sample-container}`; -- `containername/{sample-container}`. -- `job/{sample-job}` (only when [`copy_target`](#feature-copy_target) is enabled). 
+- `targetless` +- `pod/{pod-name}[/container/{container-name}]`; +- `deployment/{deployment-name}[/container/{container-name}]`; +- `rollout/{rollout-name}[/container/{container-name}]`; +- `job/{job-name}[/container/{container-name}]`; (requires mirrord Operator and the + [`copy_target`](#feature-copy_target) feature) +- `cronjob/{cronjob-name}[/container/{container-name}]`; (requires mirrord Operator and the + [`copy_target`](#feature-copy_target) feature) +- `statefulset/{statefulset-name}[/container/{container-name}]`; (requires mirrord + Operator) +- `service/{service-name}[/container/{container-name}]`; (requires mirrord Operator) ## telemetry {#root-telemetry} Controls whether or not mirrord sends telemetry data to MetalBear cloud. diff --git a/mirrord/config/src/agent.rs b/mirrord/config/src/agent.rs index 3dff5adfefc..caf1d131a25 100644 --- a/mirrord/config/src/agent.rs +++ b/mirrord/config/src/agent.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, fmt, path::Path}; +use std::{collections::HashMap, fmt, net::SocketAddr, path::Path}; use k8s_openapi::api::core::v1::{ResourceRequirements, Toleration}; use mirrord_analytics::CollectAnalytics; @@ -240,12 +240,14 @@ pub struct AgentConfig { /// ### agent.tolerations {#agent-tolerations} /// - /// Set pod tolerations. (not with ephemeral agents) - /// Default is + /// Set pod tolerations. (not with ephemeral agents). + /// + /// Defaults to `operator: Exists`. 
+ /// /// ```json /// [ /// { - /// "operator": "Exists" + /// "key": "meow", "operator": "Exists", "effect": "NoSchedule" /// } /// ] /// ``` @@ -322,7 +324,11 @@ pub struct AgentConfig { /// /// ```json /// { - /// "annotations": { "cats.io/inject": "enabled" } + /// "annotations": { + /// "cats.io/inject": "enabled" + /// "prometheus.io/scrape": "true", + /// "prometheus.io/port": "9000" + /// } /// } /// ``` pub annotations: Option>, @@ -350,6 +356,20 @@ pub struct AgentConfig { /// ``` pub service_account: Option, + /// ### agent.metrics {#agent-metrics} + /// + /// Enables prometheus metrics for the agent pod. + /// + /// You might need to add annotations to the agent pod depending on how prometheus is + /// configured to scrape for metrics. + /// + /// ```json + /// { + /// "metrics": "0.0.0.0:9000" + /// } + /// ``` + pub metrics: Option, + /// /// Create an agent that returns an error after accepting the first client. For testing /// purposes. Only supported with job agents (not with ephemeral agents). diff --git a/mirrord/config/src/lib.rs b/mirrord/config/src/lib.rs index d3f8fae7bc6..d0384d834b5 100644 --- a/mirrord/config/src/lib.rs +++ b/mirrord/config/src/lib.rs @@ -119,7 +119,8 @@ pub static MIRRORD_RESOLVED_CONFIG_ENV: &str = "MIRRORD_RESOLVED_CONFIG"; /// "communication_timeout": 30, /// "startup_timeout": 360, /// "network_interface": "eth0", -/// "flush_connections": true +/// "flush_connections": true, +/// "metrics": "0.0.0.0:9000", /// }, /// "feature": { /// "env": { @@ -539,6 +540,14 @@ impl LayerConfig { )); } + if matches!(self.target.path, Some(Target::Service(..))) { + return Err(ConfigError::Conflict( + "The copy target feature is not yet supported with service targets, \ + please either disable this option or specify an exact workload covered by this service." 
+ .into() + )); + } + if !self.feature.network.incoming.is_steal() { context.add_warning( "Using copy target feature without steal mode \ diff --git a/mirrord/config/src/target.rs b/mirrord/config/src/target.rs index e262ee4cafc..d62e0ae1053 100644 --- a/mirrord/config/src/target.rs +++ b/mirrord/config/src/target.rs @@ -5,9 +5,11 @@ use cron_job::CronJobTarget; use mirrord_analytics::CollectAnalytics; use schemars::{gen::SchemaGenerator, schema::SchemaObject, JsonSchema}; use serde::{Deserialize, Serialize}; -use stateful_set::StatefulSetTarget; -use self::{deployment::DeploymentTarget, job::JobTarget, pod::PodTarget, rollout::RolloutTarget}; +use self::{ + deployment::DeploymentTarget, job::JobTarget, pod::PodTarget, rollout::RolloutTarget, + service::ServiceTarget, stateful_set::StatefulSetTarget, +}; use crate::{ config::{ from_env::{FromEnv, FromEnvWithError}, @@ -22,6 +24,7 @@ pub mod deployment; pub mod job; pub mod pod; pub mod rollout; +pub mod service; pub mod stateful_set; #[derive(Deserialize, PartialEq, Eq, Clone, Debug, JsonSchema)] @@ -65,19 +68,23 @@ fn make_simple_target_custom_schema(gen: &mut SchemaGenerator) -> schemars::sche schema.into() } -// - Only path is `Some` -> use current namespace. -// - Only namespace is `Some` -> this should only happen in `mirrord ls`. In `mirrord exec` -// namespace without a path does not mean anything and therefore should be prevented by returning -// an error. The error is not returned when parsing the configuration because it's not an error -// for `mirrord ls`. -// - Both are `None` -> targetless. -/// Specifies the target and namespace to mirror, see [`path`](#target-path) for a list of -/// accepted values for the `target` option. +/// Specifies the target and namespace to target. 
/// /// The simplified configuration supports: /// -/// - `pod/{sample-pod}/[container]/{sample-container}`; -/// - `deployment/{sample-deployment}/[container]/{sample-container}`; +/// - `targetless` +/// - `pod/{pod-name}[/container/{container-name}]`; +/// - `deployment/{deployment-name}[/container/{container-name}]`; +/// - `rollout/{rollout-name}[/container/{container-name}]`; +/// - `job/{job-name}[/container/{container-name}]`; +/// - `cronjob/{cronjob-name}[/container/{container-name}]`; +/// - `statefulset/{statefulset-name}[/container/{container-name}]`; +/// - `service/{service-name}[/container/{container-name}]`; +/// +/// Please note that: +/// +/// - `job`, `cronjob`, `statefulset` and `service` targets require the mirrord Operator +/// - `job` and `cronjob` targets require the [`copy_target`](#feature-copy_target) feature /// /// Shortened setup: /// @@ -87,34 +94,59 @@ fn make_simple_target_custom_schema(gen: &mut SchemaGenerator) -> schemars::sche /// } /// ``` /// +/// The setup above will result in a session targeting the `bear-pod` Kubernetes pod +/// in the user's default namespace. A target container will be chosen by mirrord. +/// +/// Shortened setup with target container: +/// +/// ```json +/// { +/// "target": "pod/bear-pod/container/bear-pod-container" +/// } +/// ``` +/// +/// The setup above will result in a session targeting the `bear-pod-container` container +/// in the `bear-pod` Kubernetes pod in the user's default namespace. +/// /// Complete setup: /// /// ```json /// { /// "target": { /// "path": { -/// "pod": "bear-pod" +/// "pod": "bear-pod", +/// "container": "bear-pod-container" /// }, -/// "namespace": "default" +/// "namespace": "bear-pod-namespace" /// } /// } /// ``` +/// +/// The setup above will result in a session targeting the `bear-pod-container` container +/// in the `bear-pod` Kubernetes pod in the `bear-pod-namespace` namespace. 
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Hash, Debug)] #[serde(deny_unknown_fields)] pub struct TargetConfig { /// ### target.path {#target-path} /// - /// Specifies the running pod (or deployment) to mirror. + /// Specifies the Kubernetes resource to target. /// - /// Note: Deployment level steal/mirroring is available only in mirrord for Teams - /// If you use it without it, it will choose a random pod replica to work with. + /// Note: targeting services and whole workloads is available only in mirrord for Teams. + /// If you target a workload without the mirrord Operator, it will choose a random pod replica + /// to work with. /// /// Supports: - /// - `pod/{sample-pod}`; - /// - `deployment/{sample-deployment}`; - /// - `container/{sample-container}`; - /// - `containername/{sample-container}`. - /// - `job/{sample-job}` (only when [`copy_target`](#feature-copy_target) is enabled). + /// - `targetless` + /// - `pod/{pod-name}[/container/{container-name}]`; + /// - `deployment/{deployment-name}[/container/{container-name}]`; + /// - `rollout/{rollout-name}[/container/{container-name}]`; + /// - `job/{job-name}[/container/{container-name}]`; (requires mirrord Operator and the + /// [`copy_target`](#feature-copy_target) feature) + /// - `cronjob/{cronjob-name}[/container/{container-name}]`; (requires mirrord Operator and the + /// [`copy_target`](#feature-copy_target) feature) + /// - `statefulset/{statefulset-name}[/container/{container-name}]`; (requires mirrord + /// Operator) + /// - `service/{service-name}[/container/{container-name}]`; (requires mirrord Operator) #[serde(skip_serializing_if = "Option::is_none")] pub path: Option, @@ -122,7 +154,7 @@ pub struct TargetConfig { /// /// Namespace where the target lives. /// - /// Defaults to `"default"`. + /// Defaults to the Kubernetes user's default namespace (defined in Kubernetes context). 
#[serde(skip_serializing_if = "Option::is_none")] pub namespace: Option, } @@ -181,16 +213,18 @@ const FAIL_PARSE_DEPLOYMENT_OR_POD: &str = r#" mirrord-layer failed to parse the provided target! - Valid format: - >> deployment/[/container/container-name] - >> deploy/[/container/container-name] - >> pod/[/container/container-name] - >> job/[/container/container-name] - >> cronjob/[/container/container-name] - >> statefulset/[/container/container-name] + >> `targetless` + >> `pod/{pod-name}[/container/{container-name}]`; + >> `deployment/{deployment-name}[/container/{container-name}]`; + >> `rollout/{rollout-name}[/container/{container-name}]`; + >> `job/{job-name}[/container/{container-name}]`; + >> `cronjob/{cronjob-name}[/container/{container-name}]`; + >> `statefulset/{statefulset-name}[/container/{container-name}]`; + >> `service/{service-name}[/container/{container-name}]`; - Note: - >> specifying container name is optional, defaults to the first container in the provided pod/deployment target. - >> specifying the pod name is optional, defaults to the first pod in case the target is a deployment. + >> specifying container name is optional, defaults to a container chosen by mirrord + >> targeting a workload without the mirrord Operator results in a session targeting a random pod replica - Suggestions: >> check for typos in the provided target. @@ -204,49 +238,50 @@ mirrord-layer failed to parse the provided target! /// Specifies the running pod (or deployment) to mirror. /// /// Supports: -/// - `pod/{sample-pod}`; -/// - `deployment/{sample-deployment}`; -/// - `container/{sample-container}`; -/// - `containername/{sample-container}`. 
-/// - `job/{sample-job}`; -/// - `cronjob/{sample-cronjob}`; -/// - `statefulset/{sample-statefulset}`; +/// - `targetless` +/// - `pod/{pod-name}[/container/{container-name}]`; +/// - `deployment/{deployment-name}[/container/{container-name}]`; +/// - `rollout/{rollout-name}[/container/{container-name}]`; +/// - `job/{job-name}[/container/{container-name}]`; +/// - `cronjob/{cronjob-name}[/container/{container-name}]`; +/// - `statefulset/{statefulset-name}[/container/{container-name}]`; +/// - `service/{service-name}[/container/{container-name}]`; #[warn(clippy::wildcard_enum_match_arm)] #[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Hash, Debug, JsonSchema)] #[serde(untagged, deny_unknown_fields)] pub enum Target { /// - /// Mirror a deployment. + /// [Deployment](https://kubernetes.io/docs/concepts/workloads/controllers/deployment/). Deployment(deployment::DeploymentTarget), /// - /// Mirror a pod. + /// [Pod](https://kubernetes.io/docs/concepts/workloads/pods/). Pod(pod::PodTarget), /// - /// Mirror a rollout. + /// [Argo Rollout](https://argoproj.github.io/argo-rollouts/#how-does-it-work). Rollout(rollout::RolloutTarget), /// - /// Mirror a Job. + /// [Job](https://kubernetes.io/docs/concepts/workloads/controllers/job/). /// /// Only supported when `copy_target` is enabled. Job(job::JobTarget), /// - /// Targets a /// [CronJob](https://kubernetes.io/docs/concepts/workloads/controllers/cron-jobs/). /// /// Only supported when `copy_target` is enabled. CronJob(cron_job::CronJobTarget), /// - /// Targets a /// [StatefulSet](https://kubernetes.io/docs/concepts/workloads/controllers/statefulset/). - /// - /// Only supported when `copy_target` is enabled. StatefulSet(stateful_set::StatefulSetTarget), + /// + /// [Service](https://kubernetes.io/docs/concepts/services-networking/service/). + Service(service::ServiceTarget), + /// /// Spawn a new pod. 
Targetless, @@ -269,6 +304,7 @@ impl FromStr for Target { Some("job") => job::JobTarget::from_split(&mut split).map(Target::Job), Some("cronjob") => cron_job::CronJobTarget::from_split(&mut split).map(Target::CronJob), Some("statefulset") => stateful_set::StatefulSetTarget::from_split(&mut split).map(Target::StatefulSet), + Some("service") => service::ServiceTarget::from_split(&mut split).map(Target::Service), _ => Err(ConfigError::InvalidTarget(format!( "Provided target: {target} is unsupported. Did you remember to add a prefix, e.g. pod/{target}? \n{FAIL_PARSE_DEPLOYMENT_OR_POD}", ))), @@ -286,6 +322,7 @@ impl Target { Target::Job(target) => target.job.clone(), Target::CronJob(target) => target.cron_job.clone(), Target::StatefulSet(target) => target.stateful_set.clone(), + Target::Service(target) => target.service.clone(), Target::Targetless => { unreachable!("this shouldn't happen - called from operator on a flow where it's not targetless.") } @@ -301,7 +338,7 @@ impl Target { pub(super) fn requires_operator(&self) -> bool { matches!( self, - Target::Job(_) | Target::CronJob(_) | Target::StatefulSet(_) + Target::Job(_) | Target::CronJob(_) | Target::StatefulSet(_) | Target::Service(_) ) } } @@ -361,6 +398,7 @@ impl_target_display!(RolloutTarget, rollout, "rollout"); impl_target_display!(JobTarget, job, "job"); impl_target_display!(CronJobTarget, cron_job, "cronjob"); impl_target_display!(StatefulSetTarget, stateful_set, "statefulset"); +impl_target_display!(ServiceTarget, service, "service"); impl fmt::Display for Target { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -372,6 +410,7 @@ impl fmt::Display for Target { Target::Job(target) => target.fmt(f), Target::CronJob(target) => target.fmt(f), Target::StatefulSet(target) => target.fmt(f), + Target::Service(target) => target.fmt(f), } } } @@ -386,6 +425,7 @@ impl TargetDisplay for Target { Target::Job(target) => target.type_(), Target::CronJob(target) => target.type_(), Target::StatefulSet(target) 
=> target.type_(), + Target::Service(target) => target.type_(), } } @@ -398,6 +438,7 @@ impl TargetDisplay for Target { Target::Job(target) => target.name(), Target::CronJob(target) => target.name(), Target::StatefulSet(target) => target.name(), + Target::Service(target) => target.name(), } } @@ -410,6 +451,7 @@ impl TargetDisplay for Target { Target::Job(target) => target.container(), Target::CronJob(target) => target.container(), Target::StatefulSet(target) => target.container(), + Target::Service(target) => target.container(), } } } @@ -426,6 +468,7 @@ bitflags::bitflags! { const JOB = 32; const CRON_JOB = 64; const STATEFUL_SET = 128; + const SERVICE = 256; } } @@ -473,6 +516,12 @@ impl CollectAnalytics for &TargetConfig { flags |= TargetAnalyticFlags::CONTAINER; } } + Target::Service(target) => { + flags |= TargetAnalyticFlags::SERVICE; + if target.container.is_some() { + flags |= TargetAnalyticFlags::CONTAINER; + } + } Target::Targetless => { // Targetless is essentially 0, so no need to set any flags. 
} diff --git a/mirrord/config/src/target/service.rs b/mirrord/config/src/target/service.rs new file mode 100644 index 00000000000..8d41e0053fe --- /dev/null +++ b/mirrord/config/src/target/service.rs @@ -0,0 +1,36 @@ +use std::str::Split; + +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +use super::{FromSplit, FAIL_PARSE_DEPLOYMENT_OR_POD}; +use crate::config::{ConfigError, Result}; + +#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Hash, Debug, JsonSchema)] +#[serde(deny_unknown_fields)] +pub struct ServiceTarget { + pub service: String, + pub container: Option, +} + +impl FromSplit for ServiceTarget { + fn from_split(split: &mut Split) -> Result { + let service = split + .next() + .ok_or_else(|| ConfigError::InvalidTarget(FAIL_PARSE_DEPLOYMENT_OR_POD.to_string()))?; + + match (split.next(), split.next()) { + (Some("container"), Some(container)) => Ok(Self { + service: service.to_string(), + container: Some(container.to_string()), + }), + (None, None) => Ok(Self { + service: service.to_string(), + container: None, + }), + _ => Err(ConfigError::InvalidTarget( + FAIL_PARSE_DEPLOYMENT_OR_POD.to_string(), + )), + } + } +} diff --git a/mirrord/intproxy/protocol/src/lib.rs b/mirrord/intproxy/protocol/src/lib.rs index e51fcdf773e..7648f3d6cf6 100644 --- a/mirrord/intproxy/protocol/src/lib.rs +++ b/mirrord/intproxy/protocol/src/lib.rs @@ -387,6 +387,13 @@ impl_request!( res_path = ProxyToLayerMessage::File => FileResponse::XstatFs, ); +impl_request!( + req = StatFsRequest, + res = RemoteResult, + req_path = LayerToProxyMessage::File => FileRequest::StatFs, + res_path = ProxyToLayerMessage::File => FileResponse::XstatFs, +); + impl_request!( req = FdOpenDirRequest, res = RemoteResult, diff --git a/mirrord/intproxy/src/proxies/files.rs b/mirrord/intproxy/src/proxies/files.rs index 517c743ce12..55c1b2f98f0 100644 --- a/mirrord/intproxy/src/proxies/files.rs +++ b/mirrord/intproxy/src/proxies/files.rs @@ -6,7 +6,7 @@ use mirrord_protocol::{ file::{ 
CloseDirRequest, CloseFileRequest, DirEntryInternal, ReadDirBatchRequest, ReadDirResponse, ReadFileResponse, ReadLimitedFileRequest, SeekFromInternal, MKDIR_VERSION, - READDIR_BATCH_VERSION, READLINK_VERSION, RMDIR_VERSION, + READDIR_BATCH_VERSION, READLINK_VERSION, RMDIR_VERSION, STATFS_VERSION, }, ClientMessage, DaemonMessage, ErrorKindInternal, FileRequest, FileResponse, RemoteIOError, ResponseError, @@ -259,21 +259,27 @@ impl FilesProxy { match request { FileRequest::ReadLink(..) - if protocol_version.is_some_and(|version| !READLINK_VERSION.matches(version)) => + if protocol_version.is_none_or(|version| !READLINK_VERSION.matches(version)) => { Err(FileResponse::ReadLink(Err(ResponseError::NotImplemented))) } FileRequest::MakeDir(..) | FileRequest::MakeDirAt(..) - if protocol_version.is_some_and(|version| !MKDIR_VERSION.matches(version)) => + if protocol_version.is_none_or(|version| !MKDIR_VERSION.matches(version)) => { Err(FileResponse::MakeDir(Err(ResponseError::NotImplemented))) } FileRequest::RemoveDir(..) | FileRequest::Unlink(..) | FileRequest::UnlinkAt(..) if protocol_version - .is_some_and(|version: &Version| !RMDIR_VERSION.matches(version)) => + .is_none_or(|version: &Version| !RMDIR_VERSION.matches(version)) => { Err(FileResponse::RemoveDir(Err(ResponseError::NotImplemented))) } + FileRequest::StatFs(..) 
+ if protocol_version + .is_none_or(|version: &Version| !STATFS_VERSION.matches(version)) => + { + Err(FileResponse::XstatFs(Err(ResponseError::NotImplemented))) + } _ => Ok(()), } } diff --git a/mirrord/kube/src/api/container/job.rs b/mirrord/kube/src/api/container/job.rs index d9958e6620b..7c3247a091f 100644 --- a/mirrord/kube/src/api/container/job.rs +++ b/mirrord/kube/src/api/container/job.rs @@ -285,7 +285,6 @@ mod test { "restartPolicy": "Never", "imagePullSecrets": agent.image_pull_secrets, "nodeSelector": {}, - "tolerations": *DEFAULT_TOLERATIONS, "serviceAccountName": agent.service_account, "containers": [ { diff --git a/mirrord/kube/src/api/container/pod.rs b/mirrord/kube/src/api/container/pod.rs index f8461e8a002..984a3fd8afa 100644 --- a/mirrord/kube/src/api/container/pod.rs +++ b/mirrord/kube/src/api/container/pod.rs @@ -20,6 +20,7 @@ use crate::api::{ runtime::RuntimeData, }; +/// The `targetless` agent variant is created by this, see its [`PodVariant::as_update`]. pub struct PodVariant<'c> { agent: &'c AgentConfig, command_line: Vec, @@ -67,8 +68,6 @@ impl ContainerVariant for PodVariant<'_> { .. } = self; - let tolerations = agent.tolerations.as_ref().unwrap_or(&DEFAULT_TOLERATIONS); - let resources = agent.resources.clone().unwrap_or_else(|| { serde_json::from_value(serde_json::json!({ "requests": @@ -124,7 +123,7 @@ impl ContainerVariant for PodVariant<'_> { spec: Some(PodSpec { restart_policy: Some("Never".to_string()), image_pull_secrets, - tolerations: Some(tolerations.clone()), + tolerations: agent.tolerations.clone(), node_selector: Some(node_selector), service_account_name: agent.service_account.clone(), containers: vec![Container { @@ -148,6 +147,10 @@ impl ContainerVariant for PodVariant<'_> { } } +/// The `targeted` agent variant is created by this. +/// +/// It builds on top of [`PodVariant`], merging spec, etc from there. See +/// [`PodTargetedVariant::as_update`]. 
pub struct PodTargetedVariant<'c> { inner: PodVariant<'c>, runtime_data: &'c RuntimeData, @@ -195,6 +198,8 @@ impl ContainerVariant for PodTargetedVariant<'_> { let agent = self.agent_config(); let params = self.params(); + let tolerations = agent.tolerations.as_ref().unwrap_or(&DEFAULT_TOLERATIONS); + let env = self.runtime_data.mesh.map(|mesh_vendor| { let mut env = vec![EnvVar { name: "MIRRORD_AGENT_IN_SERVICE_MESH".into(), @@ -214,6 +219,7 @@ impl ContainerVariant for PodTargetedVariant<'_> { let update = Pod { spec: Some(PodSpec { restart_policy: Some("Never".to_string()), + tolerations: Some(tolerations.clone()), host_pid: Some(true), node_name: Some(runtime_data.node_name.clone()), volumes: Some(vec![ diff --git a/mirrord/kube/src/api/container/util.rs b/mirrord/kube/src/api/container/util.rs index 23fd752181b..d40949d268c 100644 --- a/mirrord/kube/src/api/container/util.rs +++ b/mirrord/kube/src/api/container/util.rs @@ -4,7 +4,9 @@ use futures::{AsyncBufReadExt, TryStreamExt}; use k8s_openapi::api::core::v1::{EnvVar, Pod, Toleration}; use kube::{api::LogParams, Api}; use mirrord_config::agent::{AgentConfig, LinuxCapability}; -use mirrord_protocol::{AGENT_IPV6_ENV, AGENT_NETWORK_INTERFACE_ENV, AGENT_OPERATOR_CERT_ENV}; +use mirrord_protocol::{ + AGENT_IPV6_ENV, AGENT_METRICS_ENV, AGENT_NETWORK_INTERFACE_ENV, AGENT_OPERATOR_CERT_ENV, +}; use regex::Regex; use tracing::warn; @@ -59,7 +61,9 @@ pub(super) fn agent_env(agent: &AgentConfig, params: &&ContainerParams) -> Vec Vec NodeCheck { let node_api: Api = Api::all(client.clone()); let pod_api: Api = Api::all(client.clone()); @@ -271,20 +274,61 @@ where } pub trait RuntimeDataProvider { - #[allow(async_fn_in_trait)] - async fn runtime_data(&self, client: &Client, namespace: Option<&str>) -> Result; + fn runtime_data( + &self, + client: &Client, + namespace: Option<&str>, + ) -> impl Future>; } +/// Trait for resources that abstract a set of pods +/// defined by a label selector. 
+/// +/// Implementors are provided with an implementation of [`RuntimeDataProvider`]. +/// When resolving [`RuntimeData`], the set of pods is fetched and [`RuntimeData`] is extracted from +/// the first pod on the list. If the set is empty, resolution fails. pub trait RuntimeDataFromLabels { type Resource: Resource + Clone + DeserializeOwned + fmt::Debug; - #[allow(async_fn_in_trait)] - async fn get_selector_match_labels( + fn get_selector_match_labels(resource: &Self::Resource) -> Result>; + + /// Returns a list of pods matching the selector of the given `resource`. + fn get_pods( resource: &Self::Resource, - ) -> Result>; + client: &Client, + ) -> impl Future>> { + async { + let api: Api<::Resource> = + get_k8s_resource_api(client, resource.meta().namespace.as_deref()); + let name = resource + .meta() + .name + .as_deref() + .ok_or_else(|| KubeApiError::missing_field(resource, ".metadata.name"))?; + let resource = api.get(name).await?; + + let labels = Self::get_selector_match_labels(&resource)?; + + let formatted_labels = labels + .iter() + .map(|(key, value)| format!("{key}={value}")) + .collect::>() + .join(","); + let list_params = ListParams { + label_selector: Some(formatted_labels), + ..Default::default() + }; + + let pod_api: Api = + get_k8s_resource_api(client, resource.meta().namespace.as_deref()); + let pods = pod_api.list(&list_params).await?; + + Ok(pods.items) + } + } fn name(&self) -> Cow; @@ -299,37 +343,22 @@ where let api: Api<::Resource> = get_k8s_resource_api(client, namespace); let resource = api.get(&self.name()).await?; + let pods = Self::get_pods(&resource, client).await?; - let labels = Self::get_selector_match_labels(&resource).await?; - - let formatted_labels = labels - .iter() - .map(|(key, value)| format!("{key}={value}")) - .collect::>() - .join(","); - let list_params = ListParams { - label_selector: Some(formatted_labels), - ..Default::default() - }; - - let pod_api: Api = get_k8s_resource_api(client, namespace); - let pods = 
pod_api.list(&list_params).await?; - - if pods.items.is_empty() { + if pods.is_empty() { return Err(KubeApiError::invalid_state( &resource, - "no pods matching labels found", + "no pods matching the labels were found", )); } - pods.items - .iter() + pods.iter() .filter_map(|pod| RuntimeData::from_pod(pod, self.container()).ok()) .next() .ok_or_else(|| { KubeApiError::invalid_state( &resource, - "no pod matching labels is ready to be targeted", + "no pod matching the labels is ready to be targeted", ) }) } @@ -344,6 +373,7 @@ impl RuntimeDataProvider for Target { Target::Job(target) => target.runtime_data(client, namespace).await, Target::CronJob(target) => target.runtime_data(client, namespace).await, Target::StatefulSet(target) => target.runtime_data(client, namespace).await, + Target::Service(target) => target.runtime_data(client, namespace).await, Target::Targetless => Err(KubeApiError::MissingRuntimeData), } } @@ -358,6 +388,7 @@ impl RuntimeDataProvider for ResolvedTarget { Self::Job(target) => target.runtime_data(client, namespace).await, Self::CronJob(target) => target.runtime_data(client, namespace).await, Self::StatefulSet(target) => target.runtime_data(client, namespace).await, + Self::Service(target) => target.runtime_data(client, namespace).await, Self::Targetless(_) => Err(KubeApiError::MissingRuntimeData), } } @@ -365,7 +396,9 @@ impl RuntimeDataProvider for ResolvedTarget { #[cfg(test)] mod tests { - use mirrord_config::target::{deployment::DeploymentTarget, job::JobTarget, pod::PodTarget}; + use mirrord_config::target::{ + deployment::DeploymentTarget, job::JobTarget, pod::PodTarget, service::ServiceTarget, + }; use rstest::rstest; use super::*; @@ -378,6 +411,8 @@ mod tests { #[case("deployment/nginx-deployment/container/container-name", Target::Deployment(DeploymentTarget {deployment: "nginx-deployment".to_string(), container: Some("container-name".to_string())}))] #[case("job/foo", Target::Job(JobTarget { job: "foo".to_string(), container: None 
}))] #[case("job/foo/container/baz", Target::Job(JobTarget { job: "foo".to_string(), container: Some("baz".to_string()) }))] + #[case("service/foo", Target::Service(ServiceTarget { service: "foo".into(), container: None }))] + #[case("service/foo/container/baz", Target::Service(ServiceTarget { service: "foo".into(), container: Some("baz".into()) }))] fn target_parses(#[case] target: &str, #[case] expected: Target) { let target = target.parse::().unwrap(); assert_eq!(target, expected) diff --git a/mirrord/kube/src/api/runtime/cron_job.rs b/mirrord/kube/src/api/runtime/cron_job.rs index 31b1f09f1b2..0c4dc16e3c9 100644 --- a/mirrord/kube/src/api/runtime/cron_job.rs +++ b/mirrord/kube/src/api/runtime/cron_job.rs @@ -17,9 +17,7 @@ impl RuntimeDataFromLabels for CronJobTarget { self.container.as_deref() } - async fn get_selector_match_labels( - resource: &Self::Resource, - ) -> Result> { + fn get_selector_match_labels(resource: &Self::Resource) -> Result> { resource .spec .as_ref() diff --git a/mirrord/kube/src/api/runtime/deployment.rs b/mirrord/kube/src/api/runtime/deployment.rs index 11c3383651a..5dab483cad5 100644 --- a/mirrord/kube/src/api/runtime/deployment.rs +++ b/mirrord/kube/src/api/runtime/deployment.rs @@ -17,9 +17,7 @@ impl RuntimeDataFromLabels for DeploymentTarget { self.container.as_deref() } - async fn get_selector_match_labels( - resource: &Self::Resource, - ) -> Result> { + fn get_selector_match_labels(resource: &Self::Resource) -> Result> { resource .spec .as_ref() diff --git a/mirrord/kube/src/api/runtime/job.rs b/mirrord/kube/src/api/runtime/job.rs index 2bd06af576e..9d75e6afb88 100644 --- a/mirrord/kube/src/api/runtime/job.rs +++ b/mirrord/kube/src/api/runtime/job.rs @@ -17,9 +17,7 @@ impl RuntimeDataFromLabels for JobTarget { self.container.as_deref() } - async fn get_selector_match_labels( - resource: &Self::Resource, - ) -> Result> { + fn get_selector_match_labels(resource: &Self::Resource) -> Result> { resource .spec .as_ref() diff --git 
a/mirrord/kube/src/api/runtime/rollout.rs b/mirrord/kube/src/api/runtime/rollout.rs index ee556117dba..d75963ce974 100644 --- a/mirrord/kube/src/api/runtime/rollout.rs +++ b/mirrord/kube/src/api/runtime/rollout.rs @@ -20,9 +20,7 @@ impl RuntimeDataFromLabels for RolloutTarget { } /// Digs into `resource` to return its `.spec.selector.matchLabels`. - async fn get_selector_match_labels( - resource: &Self::Resource, - ) -> Result> { + fn get_selector_match_labels(resource: &Self::Resource) -> Result> { resource .spec .clone() diff --git a/mirrord/kube/src/api/runtime/service.rs b/mirrord/kube/src/api/runtime/service.rs new file mode 100644 index 00000000000..4b18cdcf648 --- /dev/null +++ b/mirrord/kube/src/api/runtime/service.rs @@ -0,0 +1,27 @@ +use std::{borrow::Cow, collections::BTreeMap}; + +use k8s_openapi::api::core::v1::Service; +use mirrord_config::target::service::ServiceTarget; + +use super::RuntimeDataFromLabels; +use crate::error::{KubeApiError, Result}; + +impl RuntimeDataFromLabels for ServiceTarget { + type Resource = Service; + + fn name(&self) -> Cow { + Cow::from(&self.service) + } + + fn container(&self) -> Option<&str> { + self.container.as_deref() + } + + fn get_selector_match_labels(resource: &Self::Resource) -> Result> { + resource + .spec + .as_ref() + .and_then(|spec| spec.selector.clone()) + .ok_or_else(|| KubeApiError::missing_field(resource, ".spec.selector")) + } +} diff --git a/mirrord/kube/src/api/runtime/stateful_set.rs b/mirrord/kube/src/api/runtime/stateful_set.rs index eae8846af5a..5bb70dbf09c 100644 --- a/mirrord/kube/src/api/runtime/stateful_set.rs +++ b/mirrord/kube/src/api/runtime/stateful_set.rs @@ -17,9 +17,7 @@ impl RuntimeDataFromLabels for StatefulSetTarget { self.container.as_deref() } - async fn get_selector_match_labels( - resource: &Self::Resource, - ) -> Result> { + fn get_selector_match_labels(resource: &Self::Resource) -> Result> { resource .spec .as_ref() diff --git a/mirrord/kube/src/resolved.rs 
b/mirrord/kube/src/resolved.rs index 943198fdf9c..70ba34d7535 100644 --- a/mirrord/kube/src/resolved.rs +++ b/mirrord/kube/src/resolved.rs @@ -3,7 +3,7 @@ use std::collections::BTreeMap; use k8s_openapi::api::{ apps::v1::{Deployment, StatefulSet}, batch::v1::{CronJob, Job}, - core::v1::Pod, + core::v1::{Pod, Service}, }; use kube::{Client, Resource, ResourceExt}; use mirrord_config::{feature::network::incoming::ConcurrentSteal, target::Target}; @@ -20,6 +20,7 @@ pub mod deployment; pub mod job; pub mod pod; pub mod rollout; +pub mod service; pub mod stateful_set; /// Helper struct for resolving user-provided [`Target`] to Kubernetes resources. @@ -30,7 +31,7 @@ pub mod stateful_set; /// 1. A generic implementation with helper methods for getting strings such as names, types and so /// on; /// 2. `CHECKED = false` that may be used to build the struct, and to call -/// `assert_valid_mirrord_target` (along with the generic methods); +/// [`ResolvedTarget::assert_valid_mirrord_target`] (along with the generic methods); /// 3. `CHECKED = true` which is how we get a connection url for the target; #[derive(Debug, Clone)] pub enum ResolvedTarget { @@ -39,13 +40,17 @@ pub enum ResolvedTarget { Job(ResolvedResource), CronJob(ResolvedResource), StatefulSet(ResolvedResource), + Service(ResolvedResource), /// [`Pod`] is a special case, in that it does not implement [`RuntimeDataFromLabels`], /// and instead we implement a `runtime_data` method directly in its /// [`ResolvedResource`] impl. Pod(ResolvedResource), - /// Holds the `namespace` for this target. - Targetless(String), + + Targetless( + /// Agent pod's namespace. + String, + ), } /// A kubernetes [`Resource`], and container pair to be used based on the target we @@ -84,6 +89,9 @@ impl ResolvedTarget { ResolvedTarget::StatefulSet(ResolvedResource { resource, .. }) => { resource.metadata.name.as_deref() } + ResolvedTarget::Service(ResolvedResource { resource, .. 
}) => { + resource.metadata.name.as_deref() + } ResolvedTarget::Targetless(_) => None, } } @@ -96,6 +104,7 @@ impl ResolvedTarget { ResolvedTarget::Job(ResolvedResource { resource, .. }) => resource.name_any(), ResolvedTarget::CronJob(ResolvedResource { resource, .. }) => resource.name_any(), ResolvedTarget::StatefulSet(ResolvedResource { resource, .. }) => resource.name_any(), + ResolvedTarget::Service(ResolvedResource { resource, .. }) => resource.name_any(), ResolvedTarget::Targetless(..) => "targetless".to_string(), } } @@ -120,6 +129,9 @@ impl ResolvedTarget { ResolvedTarget::StatefulSet(ResolvedResource { resource, .. }) => { resource.metadata.namespace.as_deref() } + ResolvedTarget::Service(ResolvedResource { resource, .. }) => { + resource.metadata.namespace.as_deref() + } ResolvedTarget::Targetless(namespace) => Some(namespace), } } @@ -137,6 +149,7 @@ impl ResolvedTarget { ResolvedTarget::StatefulSet(ResolvedResource { resource, .. }) => { resource.metadata.labels } + ResolvedTarget::Service(ResolvedResource { resource, .. }) => resource.metadata.labels, ResolvedTarget::Targetless(_) => None, } } @@ -149,22 +162,11 @@ impl ResolvedTarget { ResolvedTarget::Job(_) => "job", ResolvedTarget::CronJob(_) => "cronjob", ResolvedTarget::StatefulSet(_) => "statefulset", + ResolvedTarget::Service(_) => "service", ResolvedTarget::Targetless(_) => "targetless", } } - pub fn get_container(&self) -> Option<&str> { - match self { - ResolvedTarget::Deployment(ResolvedResource { container, .. }) - | ResolvedTarget::Rollout(ResolvedResource { container, .. }) - | ResolvedTarget::Job(ResolvedResource { container, .. }) - | ResolvedTarget::CronJob(ResolvedResource { container, .. }) - | ResolvedTarget::StatefulSet(ResolvedResource { container, .. }) - | ResolvedTarget::Pod(ResolvedResource { container, .. }) => container.as_deref(), - ResolvedTarget::Targetless(..) => None, - } - } - /// Convenient way of getting the container from this target. 
pub fn container(&self) -> Option<&str> { match self { @@ -173,6 +175,7 @@ impl ResolvedTarget { | ResolvedTarget::Job(ResolvedResource { container, .. }) | ResolvedTarget::CronJob(ResolvedResource { container, .. }) | ResolvedTarget::StatefulSet(ResolvedResource { container, .. }) + | ResolvedTarget::Service(ResolvedResource { container, .. }) | ResolvedTarget::Pod(ResolvedResource { container, .. }) => container.as_deref(), ResolvedTarget::Targetless(..) => None, } @@ -190,45 +193,6 @@ impl ResolvedTarget { false } } - - /// Returns the number of containers for this [`ResolvedTarget`], defaulting to 1. - pub fn containers_status(&self) -> usize { - match self { - ResolvedTarget::Deployment(ResolvedResource { resource, .. }) => resource - .spec - .as_ref() - .and_then(|spec| spec.template.spec.as_ref()) - .map(|pod_spec| pod_spec.containers.len()), - ResolvedTarget::Rollout(ResolvedResource { resource, .. }) => resource - .spec - .as_ref() - .and_then(|spec| spec.template.as_ref()) - .and_then(|pod_template| pod_template.spec.as_ref()) - .map(|pod_spec| pod_spec.containers.len()), - ResolvedTarget::StatefulSet(ResolvedResource { resource, .. }) => resource - .spec - .as_ref() - .and_then(|spec| spec.template.spec.as_ref()) - .map(|pod_spec| pod_spec.containers.len()), - ResolvedTarget::CronJob(ResolvedResource { resource, .. }) => resource - .spec - .as_ref() - .and_then(|spec| spec.job_template.spec.as_ref()) - .and_then(|job_spec| job_spec.template.spec.as_ref()) - .map(|pod_spec| pod_spec.containers.len()), - ResolvedTarget::Job(ResolvedResource { resource, .. }) => resource - .spec - .as_ref() - .and_then(|spec| spec.template.spec.as_ref()) - .map(|pod_spec| pod_spec.containers.len()), - ResolvedTarget::Pod(ResolvedResource { resource, .. }) => resource - .spec - .as_ref() - .map(|pod_spec| pod_spec.containers.len()), - ResolvedTarget::Targetless(..) 
=> Some(1), - } - .unwrap_or(1) - } } impl ResolvedTarget { @@ -295,6 +259,15 @@ impl ResolvedTarget { container: target.container.clone(), }) }), + Target::Service(target) => get_k8s_resource_api::(client, namespace) + .get(&target.service) + .await + .map(|resource| { + ResolvedTarget::Service(ResolvedResource { + resource, + container: target.container.clone(), + }) + }), Target::Targetless => Ok(ResolvedTarget::Targetless( namespace.unwrap_or("default").to_string(), )), @@ -303,13 +276,20 @@ impl ResolvedTarget { Ok(target) } - /// Check if the target can be used as a mirrord target. + /// Checks if the target can be used via the mirrord Operator. + /// + /// This is implemented in the CLI only to improve the UX (skip roundtrip to the operator). /// - /// 1. [`ResolvedTarget::Deployment`] or [`ResolvedTarget::Rollout`] - has available replicas - /// and the target container, if specified, is found in the spec + /// Performs only basic checks: + /// 1. [`ResolvedTarget::Deployment`], [`ResolvedTarget::Rollout`], + /// [`ResolvedTarget::StatefulSet`] - has available replicas and the target container, if + /// specified, is found in the spec /// 2. [`ResolvedTarget::Pod`] - passes target-readiness check, see [`RuntimeData::from_pod`]. - /// 3. [`ResolvedTarget::Job`] - error, as this is `copy_target` exclusive - /// 4. [`ResolvedTarget::Targetless`] - no check + /// 3. [`ResolvedTarget::Job`] and [`ResolvedTarget::CronJob`] - error, as this is `copy_target` + /// exclusive + /// 4. [`ResolvedTarget::Targetless`] - no check (not applicable) + /// 5. 
[`ResolvedTarget::Service`] - has available replicas and the target container, if + /// specified, is found in at least one of them #[tracing::instrument(level = Level::DEBUG, skip(client), ret, err)] pub async fn assert_valid_mirrord_target( self, @@ -355,6 +335,7 @@ impl ResolvedTarget { container, })) } + ResolvedTarget::Pod(ResolvedResource { resource, container, @@ -404,9 +385,11 @@ impl ResolvedTarget { ResolvedTarget::Job(..) => { return Err(KubeApiError::requires_copy::()); } + ResolvedTarget::CronJob(..) => { return Err(KubeApiError::requires_copy::()); } + ResolvedTarget::StatefulSet(ResolvedResource { resource, container, @@ -447,6 +430,39 @@ impl ResolvedTarget { })) } + ResolvedTarget::Service(ResolvedResource { + resource, + container, + }) => { + let pods = ResolvedResource::::get_pods(&resource, client).await?; + + if pods.is_empty() { + return Err(KubeApiError::invalid_state( + &resource, + "no pods matching the labels were found", + )); + } + + if let Some(container) = &container { + let exists_in_a_pod = pods + .iter() + .flat_map(|pod| pod.spec.as_ref()) + .flat_map(|spec| &spec.containers) + .any(|found_container| found_container.name == *container); + if !exists_in_a_pod { + return Err(KubeApiError::invalid_state( + &resource, + format_args!("none of the pods that match the labels contain the target container `{container}`" + ))); + } + } + + Ok(ResolvedTarget::Service(ResolvedResource { + resource, + container, + })) + } + ResolvedTarget::Targetless(namespace) => { // no check needed here Ok(ResolvedTarget::Targetless(namespace)) diff --git a/mirrord/kube/src/resolved/cron_job.rs b/mirrord/kube/src/resolved/cron_job.rs index 9fcf4fd5027..60ecabc862e 100644 --- a/mirrord/kube/src/resolved/cron_job.rs +++ b/mirrord/kube/src/resolved/cron_job.rs @@ -24,9 +24,7 @@ impl RuntimeDataFromLabels for ResolvedResource { self.container.as_deref() } - async fn get_selector_match_labels( - resource: &Self::Resource, - ) -> Result> { + fn 
get_selector_match_labels(resource: &Self::Resource) -> Result> { resource .spec .as_ref() diff --git a/mirrord/kube/src/resolved/deployment.rs b/mirrord/kube/src/resolved/deployment.rs index af06682c4f2..f85334a3d71 100644 --- a/mirrord/kube/src/resolved/deployment.rs +++ b/mirrord/kube/src/resolved/deployment.rs @@ -21,7 +21,7 @@ impl RuntimeDataFromLabels for ResolvedResource { self.container.as_deref() } - async fn get_selector_match_labels( + fn get_selector_match_labels( resource: &Self::Resource, ) -> Result, KubeApiError> { resource diff --git a/mirrord/kube/src/resolved/job.rs b/mirrord/kube/src/resolved/job.rs index 039d3afec30..28555431d0c 100644 --- a/mirrord/kube/src/resolved/job.rs +++ b/mirrord/kube/src/resolved/job.rs @@ -24,9 +24,7 @@ impl RuntimeDataFromLabels for ResolvedResource { self.container.as_deref() } - async fn get_selector_match_labels( - resource: &Self::Resource, - ) -> Result> { + fn get_selector_match_labels(resource: &Self::Resource) -> Result> { resource .spec .as_ref() diff --git a/mirrord/kube/src/resolved/rollout.rs b/mirrord/kube/src/resolved/rollout.rs index 4f7ab615c8f..5411b02d5a3 100644 --- a/mirrord/kube/src/resolved/rollout.rs +++ b/mirrord/kube/src/resolved/rollout.rs @@ -22,9 +22,7 @@ impl RuntimeDataFromLabels for ResolvedResource { self.container.as_deref() } - async fn get_selector_match_labels( - resource: &Self::Resource, - ) -> Result> { + fn get_selector_match_labels(resource: &Self::Resource) -> Result> { resource .spec .as_ref() diff --git a/mirrord/kube/src/resolved/service.rs b/mirrord/kube/src/resolved/service.rs new file mode 100644 index 00000000000..1e7427de74f --- /dev/null +++ b/mirrord/kube/src/resolved/service.rs @@ -0,0 +1,31 @@ +use std::{borrow::Cow, collections::BTreeMap}; + +use k8s_openapi::api::core::v1::Service; + +use super::{ResolvedResource, RuntimeDataFromLabels}; +use crate::error::{KubeApiError, Result}; + +impl RuntimeDataFromLabels for ResolvedResource { + type Resource = Service; + + 
fn name(&self) -> Cow { + self.resource + .metadata + .name + .as_ref() + .map(Cow::from) + .unwrap_or_default() + } + + fn container(&self) -> Option<&str> { + self.container.as_deref() + } + + fn get_selector_match_labels(resource: &Self::Resource) -> Result> { + resource + .spec + .as_ref() + .and_then(|spec| spec.selector.clone()) + .ok_or_else(|| KubeApiError::missing_field(resource, ".spec.selector")) + } +} diff --git a/mirrord/kube/src/resolved/stateful_set.rs b/mirrord/kube/src/resolved/stateful_set.rs index ccc0edeb7a1..aebe74ba317 100644 --- a/mirrord/kube/src/resolved/stateful_set.rs +++ b/mirrord/kube/src/resolved/stateful_set.rs @@ -21,9 +21,7 @@ impl RuntimeDataFromLabels for ResolvedResource { self.container.as_deref() } - async fn get_selector_match_labels( - resource: &Self::Resource, - ) -> Result> { + fn get_selector_match_labels(resource: &Self::Resource) -> Result> { resource .spec .as_ref() diff --git a/mirrord/layer/src/file/hooks.rs b/mirrord/layer/src/file/hooks.rs index 7c46165b37a..3fcfc3c1280 100644 --- a/mirrord/layer/src/file/hooks.rs +++ b/mirrord/layer/src/file/hooks.rs @@ -904,8 +904,9 @@ unsafe extern "C" fn fstatat_detour( }) } +/// Hook for `libc::fstatfs`. #[hook_guard_fn] -unsafe extern "C" fn fstatfs_detour(fd: c_int, out_stat: *mut statfs) -> c_int { +pub(crate) unsafe extern "C" fn fstatfs_detour(fd: c_int, out_stat: *mut statfs) -> c_int { if out_stat.is_null() { return HookError::BadPointer.into(); } @@ -919,6 +920,25 @@ unsafe extern "C" fn fstatfs_detour(fd: c_int, out_stat: *mut statfs) -> c_int { .unwrap_or_bypass_with(|_| FN_FSTATFS(fd, out_stat)) } +/// Hook for `libc::statfs`. 
+#[hook_guard_fn] +pub(crate) unsafe extern "C" fn statfs_detour( + raw_path: *const c_char, + out_stat: *mut statfs, +) -> c_int { + if out_stat.is_null() { + return HookError::BadPointer.into(); + } + + crate::file::ops::statfs(raw_path.checked_into()) + .map(|res| { + let res = res.metadata; + fill_statfs(out_stat, &res); + 0 + }) + .unwrap_or_bypass_with(|_| FN_STATFS(raw_path, out_stat)) +} + unsafe fn realpath_logic( source_path: *const c_char, output_path: *mut c_char, @@ -1333,6 +1353,8 @@ pub(crate) unsafe fn enable_file_hooks(hook_manager: &mut HookManager) { FnFstatfs, FN_FSTATFS ); + replace!(hook_manager, "statfs", statfs_detour, FnStatfs, FN_STATFS); + replace!( hook_manager, "fdopendir", @@ -1415,6 +1437,13 @@ pub(crate) unsafe fn enable_file_hooks(hook_manager: &mut HookManager) { FnFstatfs, FN_FSTATFS ); + replace!( + hook_manager, + "statfs$INODE64", + statfs_detour, + FnStatfs, + FN_STATFS + ); replace!( hook_manager, "fdopendir$INODE64", diff --git a/mirrord/layer/src/file/ops.rs b/mirrord/layer/src/file/ops.rs index bc6bb63670c..c47f7cad0d4 100644 --- a/mirrord/layer/src/file/ops.rs +++ b/mirrord/layer/src/file/ops.rs @@ -9,7 +9,8 @@ use mirrord_protocol::{ file::{ MakeDirAtRequest, MakeDirRequest, OpenFileRequest, OpenFileResponse, OpenOptionsInternal, ReadFileResponse, ReadLinkFileRequest, ReadLinkFileResponse, RemoveDirRequest, - SeekFileResponse, UnlinkAtRequest, WriteFileResponse, XstatFsResponse, XstatResponse, + SeekFileResponse, StatFsRequest, UnlinkAtRequest, WriteFileResponse, XstatFsResponse, + XstatResponse, }, ResponseError, }; @@ -738,6 +739,16 @@ pub(crate) fn xstatfs(fd: RawFd) -> Detour { Detour::Success(response) } +#[mirrord_layer_macro::instrument(level = "trace")] +pub(crate) fn statfs(path: Detour) -> Detour { + let path = path?; + let lstatfs = StatFsRequest { path }; + + let response = common::make_proxy_request_with_response(lstatfs)??; + + Detour::Success(response) +} + #[cfg(target_os = "linux")] 
#[mirrord_layer_macro::instrument(level = "trace")] pub(crate) fn getdents64(fd: RawFd, buffer_size: u64) -> Detour { diff --git a/mirrord/layer/src/go/linux_x64.rs b/mirrord/layer/src/go/linux_x64.rs index 18b36700cbe..622a24383d7 100644 --- a/mirrord/layer/src/go/linux_x64.rs +++ b/mirrord/layer/src/go/linux_x64.rs @@ -340,6 +340,8 @@ unsafe extern "C" fn c_abi_syscall_handler( faccessat_detour(param1 as _, param2 as _, param3 as _, 0) as i64 } libc::SYS_fstat => fstat_detour(param1 as _, param2 as _) as i64, + libc::SYS_statfs => statfs_detour(param1 as _, param2 as _) as i64, + libc::SYS_fstatfs => fstatfs_detour(param1 as _, param2 as _) as i64, libc::SYS_getdents64 => getdents64_detour(param1 as _, param2 as _, param3 as _) as i64, #[cfg(all(target_os = "linux", not(target_arch = "aarch64")))] libc::SYS_mkdir => mkdir_detour(param1 as _, param2 as _) as i64, diff --git a/mirrord/layer/src/go/mod.rs b/mirrord/layer/src/go/mod.rs index df810bbdcf9..6a28c3ebfd9 100644 --- a/mirrord/layer/src/go/mod.rs +++ b/mirrord/layer/src/go/mod.rs @@ -101,6 +101,8 @@ unsafe extern "C" fn c_abi_syscall6_handler( .into() } libc::SYS_fstat => fstat_detour(param1 as _, param2 as _) as i64, + libc::SYS_statfs => statfs_detour(param1 as _, param2 as _) as i64, + libc::SYS_fstatfs => fstatfs_detour(param1 as _, param2 as _) as i64, libc::SYS_fsync => fsync_detour(param1 as _) as i64, libc::SYS_fdatasync => fsync_detour(param1 as _) as i64, libc::SYS_openat => { diff --git a/mirrord/layer/tests/apps/fileops/go/main.go b/mirrord/layer/tests/apps/fileops/go/main.go index 5973e013d5e..69db336fb3e 100644 --- a/mirrord/layer/tests/apps/fileops/go/main.go +++ b/mirrord/layer/tests/apps/fileops/go/main.go @@ -7,10 +7,21 @@ import ( func main() { tempFile := "/tmp/test_file.txt" - syscall.Open(tempFile, syscall.O_CREAT|syscall.O_WRONLY, 0644) + fd, _ := syscall.Open(tempFile, syscall.O_CREAT|syscall.O_WRONLY, 0644) var stat syscall.Stat_t err := syscall.Stat(tempFile, &stat) if err != nil { 
panic(err) } + + var statfs syscall.Statfs_t + err = syscall.Statfs(tempFile, &statfs) + if err != nil { + panic(err) + } + + err = syscall.Fstatfs(fd, &statfs) + if err != nil { + panic(err) + } } diff --git a/mirrord/layer/tests/apps/statfs_fstatfs/statfs_fstatfs.c b/mirrord/layer/tests/apps/statfs_fstatfs/statfs_fstatfs.c new file mode 100644 index 00000000000..6c474cce4e3 --- /dev/null +++ b/mirrord/layer/tests/apps/statfs_fstatfs/statfs_fstatfs.c @@ -0,0 +1,52 @@ +#include +#include +#include +#include +#include +#include +#include + +#if defined(__APPLE__) && defined(__MACH__) +#include +#include +#else +#include +#endif + +/// Test `statfs / fstatfs`. +/// +/// Gets information about a mounted filesystem +/// +int main() +{ + char *tmp_test_path = "/statfs_fstatfs_test_path"; + mkdir(tmp_test_path, 0777); + + // statfs + struct statfs statfs_buf; + if (statfs(tmp_test_path, &statfs_buf) == -1) + { + perror("statfs failed"); + return EXIT_FAILURE; + } + + // fstatfs + int fd = open(tmp_test_path, O_RDONLY); + + if (fd == -1) + { + perror("Error opening tmp_test_path"); + return 1; + } + + struct statfs fstatfs_buf; + if (fstatfs(fd, &fstatfs_buf) == -1) + { + perror("fstatfs failed"); + close(fd); + return EXIT_FAILURE; + } + + close(fd); + return 0; +} diff --git a/mirrord/layer/tests/common/mod.rs b/mirrord/layer/tests/common/mod.rs index 207f1c6c7a6..a4652433768 100644 --- a/mirrord/layer/tests/common/mod.rs +++ b/mirrord/layer/tests/common/mod.rs @@ -17,7 +17,7 @@ use mirrord_intproxy::{agent_conn::AgentConnection, IntProxy}; use mirrord_protocol::{ file::{ AccessFileRequest, AccessFileResponse, OpenFileRequest, OpenOptionsInternal, - ReadFileRequest, SeekFromInternal, XstatRequest, XstatResponse, + ReadFileRequest, SeekFromInternal, XstatFsResponse, XstatRequest, XstatResponse, }, tcp::{DaemonTcp, LayerTcp, NewTcpConnection, TcpClose, TcpData}, ClientMessage, DaemonCodec, DaemonMessage, FileRequest, FileResponse, @@ -489,6 +489,48 @@ impl TestIntProxy { 
.unwrap(); } + /// Makes a [`FileRequest::StatFs`] and answers it. + pub async fn expect_statfs(&mut self, expected_path: &str) { + // Expecting `statfs` call with path. + assert_matches!( + self.recv().await, + ClientMessage::FileRequest(FileRequest::StatFs( + mirrord_protocol::file::StatFsRequest { path } + )) if path.to_str().unwrap() == expected_path + ); + + // Answer `statfs`. + self.codec + .send(DaemonMessage::File(FileResponse::XstatFs(Ok( + XstatFsResponse { + metadata: Default::default(), + }, + )))) + .await + .unwrap(); + } + + /// Makes a [`FileRequest::XstatFs`] and answers it. + pub async fn expect_fstatfs(&mut self, expected_fd: u64) { + // Expecting `fstatfs` call with fd. + assert_matches!( + self.recv().await, + ClientMessage::FileRequest(FileRequest::XstatFs( + mirrord_protocol::file::XstatFsRequest { fd } + )) if expected_fd == fd + ); + + // Answer `fstatfs`. + self.codec + .send(DaemonMessage::File(FileResponse::XstatFs(Ok( + XstatFsResponse { + metadata: Default::default(), + }, + )))) + .await + .unwrap(); + } + /// Makes a [`FileRequest::RemoveDir`] and answers it. pub async fn expect_remove_dir(&mut self, expected_dir_name: &str) { // Expecting `rmdir` call with path. 
@@ -784,6 +826,7 @@ pub enum Application { Fork, ReadLink, MakeDir, + StatfsFstatfs, RemoveDir, OpenFile, CIssue2055, @@ -841,6 +884,7 @@ impl Application { Application::Fork => String::from("tests/apps/fork/out.c_test_app"), Application::ReadLink => String::from("tests/apps/readlink/out.c_test_app"), Application::MakeDir => String::from("tests/apps/mkdir/out.c_test_app"), + Application::StatfsFstatfs => String::from("tests/apps/statfs_fstatfs/out.c_test_app"), Application::RemoveDir => String::from("tests/apps/rmdir/out.c_test_app"), Application::Realpath => String::from("tests/apps/realpath/out.c_test_app"), Application::NodeHTTP | Application::NodeIssue2283 | Application::NodeIssue2807 => { @@ -1080,6 +1124,7 @@ impl Application { | Application::Fork | Application::ReadLink | Application::MakeDir + | Application::StatfsFstatfs | Application::RemoveDir | Application::Realpath | Application::RustFileOps @@ -1159,6 +1204,7 @@ impl Application { | Application::Fork | Application::ReadLink | Application::MakeDir + | Application::StatfsFstatfs | Application::RemoveDir | Application::Realpath | Application::Go21Issue834 diff --git a/mirrord/layer/tests/fileops.rs b/mirrord/layer/tests/fileops.rs index 5daacdadc50..de26b318f40 100644 --- a/mirrord/layer/tests/fileops.rs +++ b/mirrord/layer/tests/fileops.rs @@ -345,6 +345,9 @@ async fn go_stat( )))) .await; + intproxy.expect_statfs("/tmp/test_file.txt").await; + intproxy.expect_fstatfs(fd).await; + test_process.wait_assert_success().await; test_process.assert_no_error_in_stderr().await; } diff --git a/mirrord/layer/tests/statfs_fstatfs.rs b/mirrord/layer/tests/statfs_fstatfs.rs new file mode 100644 index 00000000000..38f48c8495f --- /dev/null +++ b/mirrord/layer/tests/statfs_fstatfs.rs @@ -0,0 +1,45 @@ +#![feature(assert_matches)] +use std::{path::Path, time::Duration}; + +use rstest::rstest; + +mod common; +pub use common::*; + +/// Test for the [`libc::statfs`] and [`libc::fstatfs`] functions. 
+#[rstest] +#[tokio::test] +#[timeout(Duration::from_secs(60))] +async fn statfs_fstatfs(dylib_path: &Path) { + let application = Application::StatfsFstatfs; + + let (mut test_process, mut intproxy) = application + .start_process_with_layer(dylib_path, Default::default(), None) + .await; + + println!("waiting for file request (mkdir)."); + intproxy + .expect_make_dir("/statfs_fstatfs_test_path", 0o777) + .await; + + println!("waiting for file request (statfs)."); + intproxy.expect_statfs("/statfs_fstatfs_test_path").await; + + println!("waiting for file request (open)."); + let fd: u64 = 1; + intproxy + .expect_file_open_for_reading("/statfs_fstatfs_test_path", fd) + .await; + + println!("waiting for file request (fstatfs)."); + intproxy.expect_fstatfs(fd).await; + + println!("waiting for file request (close)."); + intproxy.expect_file_close(fd).await; + + assert_eq!(intproxy.try_recv().await, None); + + test_process.wait_assert_success().await; + test_process.assert_no_error_in_stderr().await; + test_process.assert_no_error_in_stdout().await; +} diff --git a/mirrord/operator/src/client.rs b/mirrord/operator/src/client.rs index e6f9230c492..26f3d66e5b1 100644 --- a/mirrord/operator/src/client.rs +++ b/mirrord/operator/src/client.rs @@ -614,6 +614,8 @@ impl OperatorApi { // `targetless` has no `RuntimeData`! if matches!(target, ResolvedTarget::Targetless(_)).not() { + // Extracting runtime data asserts that the user can see at least one pod from the + // workload/service targets. 
let runtime_data = target .runtime_data(self.client(), target.namespace()) .await?; diff --git a/mirrord/operator/src/crd.rs b/mirrord/operator/src/crd.rs index b6c3ddd334b..efae6aa6751 100644 --- a/mirrord/operator/src/crd.rs +++ b/mirrord/operator/src/crd.rs @@ -63,6 +63,7 @@ impl TargetCrd { Target::Job(target) => ("job", &target.job, &target.container), Target::CronJob(target) => ("cronjob", &target.cron_job, &target.container), Target::StatefulSet(target) => ("statefulset", &target.stateful_set, &target.container), + Target::Service(target) => ("service", &target.service, &target.container), Target::Targetless => return TARGETLESS_TARGET_NAME.to_string(), }; diff --git a/mirrord/operator/src/setup.rs b/mirrord/operator/src/setup.rs index 18edfa63c45..80c09c0b36c 100644 --- a/mirrord/operator/src/setup.rs +++ b/mirrord/operator/src/setup.rs @@ -544,6 +544,7 @@ impl OperatorClusterRole { "cronjobs".to_owned(), "statefulsets".to_owned(), "statefulsets/scale".to_owned(), + "services".to_owned(), ]), verbs: vec!["get".to_owned(), "list".to_owned(), "watch".to_owned()], ..Default::default() diff --git a/mirrord/protocol/Cargo.toml b/mirrord/protocol/Cargo.toml index 7daa0201505..7b491c83e85 100644 --- a/mirrord/protocol/Cargo.toml +++ b/mirrord/protocol/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mirrord-protocol" -version = "1.15.1" +version = "1.16.1" authors.workspace = true description.workspace = true documentation.workspace = true diff --git a/mirrord/protocol/src/codec.rs b/mirrord/protocol/src/codec.rs index 39961b3bce2..8e41d9acab1 100644 --- a/mirrord/protocol/src/codec.rs +++ b/mirrord/protocol/src/codec.rs @@ -94,6 +94,7 @@ pub enum FileRequest { RemoveDir(RemoveDirRequest), Unlink(UnlinkRequest), UnlinkAt(UnlinkAtRequest), + StatFs(StatFsRequest), } /// Minimal mirrord-protocol version that allows `ClientMessage::ReadyForLogs` message. 
@@ -104,9 +105,27 @@ pub static CLIENT_READY_FOR_LOGS: LazyLock = #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub enum ClientMessage { Close, + /// TCP sniffer message. + /// + /// These are the messages used by the `mirror` feature, and handled by the + /// `TcpSnifferApi` in the agent. Tcp(LayerTcp), + + /// TCP stealer message. + /// + /// These are the messages used by the `steal` feature, and handled by the `TcpStealerApi` in + /// the agent. TcpSteal(LayerTcpSteal), + /// TCP outgoing message. + /// + /// These are the messages used by the `outgoing` feature (tcp), and handled by the + /// `TcpOutgoingApi` in the agent. TcpOutgoing(LayerTcpOutgoing), + + /// UDP outgoing message. + /// + /// These are the messages used by the `outgoing` feature (udp), and handled by the + /// `UdpOutgoingApi` in the agent. UdpOutgoing(LayerUdpOutgoing), FileRequest(FileRequest), GetEnvVarsRequest(GetEnvVarsRequest), diff --git a/mirrord/protocol/src/error.rs b/mirrord/protocol/src/error.rs index f77ac610d91..1047c9059cd 100644 --- a/mirrord/protocol/src/error.rs +++ b/mirrord/protocol/src/error.rs @@ -44,7 +44,7 @@ pub enum ResponseError { #[error("Remote operation expected fd `{0}` to be a file, but it's a directory!")] NotFile(u64), - #[error("IO failed for remote operation with `{0}!")] + #[error("IO failed for remote operation: `{0}!")] RemoteIO(RemoteIOError), #[error(transparent)] diff --git a/mirrord/protocol/src/file.rs b/mirrord/protocol/src/file.rs index 9a8622731fb..4aa25069bb3 100644 --- a/mirrord/protocol/src/file.rs +++ b/mirrord/protocol/src/file.rs @@ -34,6 +34,9 @@ pub static RMDIR_VERSION: LazyLock = pub static OPEN_LOCAL_VERSION: LazyLock = LazyLock::new(|| ">=1.13.3".parse().expect("Bad Identifier")); +pub static STATFS_VERSION: LazyLock = + LazyLock::new(|| ">=1.16.0".parse().expect("Bad Identifier")); + /// Internal version of Metadata across operating system (macOS, Linux) /// Only mutual attributes #[derive(Encode, Decode, Debug, PartialEq, 
Clone, Copy, Eq, Default)] @@ -413,6 +416,11 @@ pub struct XstatFsRequest { pub fd: u64, } +#[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] +pub struct StatFsRequest { + pub path: PathBuf, +} + #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub struct XstatResponse { pub metadata: MetadataInternal, diff --git a/mirrord/protocol/src/lib.rs b/mirrord/protocol/src/lib.rs index 983fcd3536b..f1a3cc1e5cc 100644 --- a/mirrord/protocol/src/lib.rs +++ b/mirrord/protocol/src/lib.rs @@ -112,4 +112,6 @@ pub const AGENT_OPERATOR_CERT_ENV: &str = "MIRRORD_AGENT_OPERATOR_CERT"; pub const AGENT_NETWORK_INTERFACE_ENV: &str = "MIRRORD_AGENT_INTERFACE"; +pub const AGENT_METRICS_ENV: &str = "MIRRORD_AGENT_METRICS"; + pub const AGENT_IPV6_ENV: &str = "MIRRORD_AGENT_SUPPORT_IPV6"; diff --git a/mirrord/protocol/src/outgoing/tcp.rs b/mirrord/protocol/src/outgoing/tcp.rs index e38fa0c44d0..877e0d2f6c0 100644 --- a/mirrord/protocol/src/outgoing/tcp.rs +++ b/mirrord/protocol/src/outgoing/tcp.rs @@ -3,14 +3,43 @@ use crate::RemoteResult; #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub enum LayerTcpOutgoing { + /// User is interested in connecting via tcp to some remote address, specified in + /// [`LayerConnect`]. + /// + /// The layer will get a mirrord managed address that it'll `connect` to, meanwhile + /// in the agent we `connect` to the actual remote address. Connect(LayerConnect), + + /// Write data to the remote address the agent is `connect`ed to. + /// + /// There's no `Read` message, as we're calling `read` in the agent, and we send + /// a [`DaemonTcpOutgoing::Read`] message in case we get some data from this connection. Write(LayerWrite), + + /// The layer closed the connection, this message syncs up the agent, closing it + /// over there as well. + /// + /// Connections in the agent may be closed in other ways, such as when an error happens + /// when reading or writing. 
Which means that this message is not the only way of + /// closing outgoing tcp connections. Close(LayerClose), } #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub enum DaemonTcpOutgoing { + /// The agent attempted a connection to the remote address specified by + /// [`LayerTcpOutgoing::Connect`], and it might've been successful or not. Connect(RemoteResult), + + /// Read data from the connection. + /// + /// There's no `Write` message, as `write`s come from the user (layer). The agent sending + /// a `write` to the layer like this would make no sense, since it could just `write` it + /// to the remote connection itself. Read(RemoteResult), + + /// Tell the layer that this connection has been `close`d, either by a request from + /// the user with [`LayerTcpOutgoing::Close`], or from some error in the agent when + /// writing or reading from the connection. Close(ConnectionId), } diff --git a/mirrord/protocol/src/outgoing/udp.rs b/mirrord/protocol/src/outgoing/udp.rs index 02b4d97f830..f58378beeea 100644 --- a/mirrord/protocol/src/outgoing/udp.rs +++ b/mirrord/protocol/src/outgoing/udp.rs @@ -3,14 +3,50 @@ use crate::RemoteResult; #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub enum LayerUdpOutgoing { + /// User is interested in connecting via udp to some remote address, specified in + /// [`LayerConnect`]. + /// + /// The layer will get a mirrord managed address that it'll `connect` to, meanwhile + /// in the agent we `connect` to the actual remote address. + /// + /// Saying that we have an _udp connection_ is a bit weird, considering it's a + /// _connectionless_ protocol, but in mirrord we use a _fakeish_ connection mechanism + /// when dealing with outgoing udp traffic. Connect(LayerConnect), + + /// Write data to the remote address the agent is `connect`ed to. + /// + /// There's no `Read` message, as we're calling `read` in the agent, and we send + /// a [`DaemonUdpOutgoing::Read`] message in case we get some data from this connection. 
Write(LayerWrite), + + /// The layer closed the connection, this message syncs up the agent, closing it + /// over there as well. + /// + /// Connections in the agent may be closed in other ways, such as when an error happens + /// when reading or writing. Which means that this message is not the only way of + /// closing outgoing udp connections. Close(LayerClose), } #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub enum DaemonUdpOutgoing { + /// The agent attempted a connection to the remote address specified by + /// [`LayerUdpOutgoing::Connect`], and it might've been successful or not. + /// + /// See the docs for [`LayerUdpOutgoing::Connect`] for a bit more information on the + /// weird idea of `connect` and udp in mirrord. Connect(RemoteResult), + + /// Read data from the connection. + /// + /// There's no `Write` message, as `write`s come from the user (layer). The agent sending + /// a `write` to the layer like this would make no sense, since it could just `write` it + /// to the remote connection itself. Read(RemoteResult), + + /// Tell the layer that this connection has been `close`d, either by a request from + /// the user with [`LayerUdpOutgoing::Close`], or from some error in the agent when + /// writing or reading from the connection. Close(ConnectionId), } diff --git a/mirrord/protocol/src/tcp.rs b/mirrord/protocol/src/tcp.rs index acf3d734121..e98077a62ec 100644 --- a/mirrord/protocol/src/tcp.rs +++ b/mirrord/protocol/src/tcp.rs @@ -52,14 +52,31 @@ pub struct TcpClose { } /// Messages related to Tcp handler from client. +/// +/// Part of the `mirror` feature. #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub enum LayerTcp { + /// User is interested in mirroring traffic on this `Port`, so add it to the list of + /// ports that the sniffer is filtering. PortSubscribe(Port), + + /// User is not interested in the connection with `ConnectionId` anymore. 
+ /// + /// This means that their app has closed the connection they were `listen`ning on. + /// + /// There is no `ConnectionSubscribe` counter-part of this variant, the subscription + /// happens when the sniffer receives an (agent) internal `SniffedConnection`. ConnectionUnsubscribe(ConnectionId), + + /// Removes this `Port` from the sniffer's filter, the traffic won't be cloned to mirrord + /// anymore. PortUnsubscribe(Port), } /// Messages related to Tcp handler from server. +/// +/// They are the same for both `steal` and `mirror` modes, even though their layer +/// counterparts ([`LayerTcpSteal`] and [`LayerTcp`]) are different. #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub enum DaemonTcp { NewConnection(NewTcpConnection), @@ -214,10 +231,38 @@ impl StealType { } /// Messages related to Steal Tcp handler from client. +/// +/// `PortSubscribe`, `PortUnsubscribe`, and `ConnectionUnsubscribe` variants are similar +/// to what you'll find in [`LayerTcp`], but they're handled by different tasks in +/// the agent. +/// +/// Stolen traffic might have an additional overhead when compared to mirrored traffic, as +/// we have an intermediate HTTP server to handle filtering (based on HTTP headers, etc). #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub enum LayerTcpSteal { + /// User is interested in stealing traffic on this `Port`, so add it to the list of + /// ports that the stealer is filtering. + /// + /// The `TcpConnectionStealer` supports an [`HttpFilter`] granting the ability to steal + /// only traffic that matches the user configured filter. It's also possible to just steal + /// all traffic (which we refer to as `Unfiltered`). For more info see [`StealType`]. 
+ /// + /// This variant is somewhat related to [`LayerTcpSteal::ConnectionUnsubscribe`], since + /// we don't have a `ConnectionSubscribe` message anywhere, instead what we do is: when + /// a new connection comes in one of the ports we are subscribed to, we consider it a + /// connection subscription (so this mechanism represents the **non-existing** + /// `ConnectionSubscribe` variant). PortSubscribe(StealType), + + /// User has stopped stealing from this connection with [`ConnectionId`]. + /// + /// We do **not** have a `ConnectionSubscribe` variant/message. What happens instead is that we + /// call a _connection subscription_ the act of `accept`ing a new connection on one of the + /// ports we are subscribed to. See the [`LayerTcpSteal::PortSubscribe`] for more info. ConnectionUnsubscribe(ConnectionId), + + /// Removes this `Port` from the stealer's filter, the traffic won't be stolen by mirrord + /// anymore. PortUnsubscribe(Port), Data(TcpData), HttpResponse(HttpResponse>), diff --git a/tests/python-e2e/ops.py b/tests/python-e2e/ops.py index c107ebb2375..36c7ba5fb8c 100644 --- a/tests/python-e2e/ops.py +++ b/tests/python-e2e/ops.py @@ -88,6 +88,19 @@ def test_mkdir_errors(self): os.close(dir) + def test_statfs_and_fstatvfs_success(self): + """ + Test statfs / fstatfs + """ + file_path, _ = self._create_new_tmp_file() + + statvfs_result = os.statvfs(file_path) + self.assertIsNotNone(statvfs_result) + + fd = os.open(file_path, os.O_RDONLY) + fstatvfs_result = os.fstatvfs(fd) + self.assertIsNotNone(fstatvfs_result) + def test_rmdir(self): """ Creates a new directory in "/tmp" and removes it using rmdir. @@ -106,6 +119,5 @@ def _create_new_tmp_file(self): w_file.write(TEXT) return file_path, file_name - if __name__ == "__main__": unittest.main()