diff --git a/CHANGELOG.md b/CHANGELOG.md index c52d3d77a21..bda379d047d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,81 @@ This project uses [*towncrier*](https://towncrier.readthedocs.io/) and the chang +## [3.130.0](https://github.com/metalbear-co/mirrord/tree/3.130.0) - 2025-01-21 + + +### Added + +- Added support for `rmdir`, `unlink` and `unlinkat`. + [#2221](https://github.com/metalbear-co/mirrord/issues/2221) + + +### Changed + +- Updated `configuration.md` and improved `.feature.env.mapping` doc. + + +### Fixed + +- Stopped mirrord from entering a crash loop when trying to load into some processes + like VSCode's `watchdog.js` when the user config contained a call to + `get_env()`, which occurred due to a missing env var - the config is now only + rendered once and set into an env var. + [#2936](https://github.com/metalbear-co/mirrord/issues/2936) +- Fixed an issue where HTTP requests stolen with a filter would hang with a + single-threaded local HTTP server. + Improved handling of incoming connections on the local machine (e.g. + introduced reuse of local HTTP connections). + [#3013](https://github.com/metalbear-co/mirrord/issues/3013) + + +### Internal + +- Extended `mirrord-protocol` with info logs from the agent. + +## [3.129.0](https://github.com/metalbear-co/mirrord/tree/3.129.0) - 2025-01-14 + + +### Added + +- Support for stealing incoming connections that are over IPv6. + [#2956](https://github.com/metalbear-co/mirrord/issues/2956) +- mirrord policy to control file ops from the operator. +- mirrord policy to restrict fetching remote environment variables. + + +### Changed + +- Updated how intproxy outputs its logfile when using container mode; logs + are now written on the host machine. + [#2868](https://github.com/metalbear-co/mirrord/issues/2868) +- Changed log level for debugger ports detection. 
+ [#2986](https://github.com/metalbear-co/mirrord/issues/2986) +- Readonly file buffering is now enabled by default to improve performance. + [#3004](https://github.com/metalbear-co/mirrord/issues/3004) +- Extended docs for HTTP filter in the mirrord config. + + +### Fixed + +- Fixed panic when Go >=1.23.3 verifies pidfd support on Linux. + [#2988](https://github.com/metalbear-co/mirrord/issues/2988) +- Fixed misleading agent IO operation error that always mentioned getaddrinfo. + [#2992](https://github.com/metalbear-co/mirrord/issues/2992) +- Fixed a bug where a port mirroring block (due to active mirrord policies) would + terminate the mirrord session. + + +### Internal + +- Added lint for unused crate dependencies. + [#2843](https://github.com/metalbear-co/mirrord/issues/2843) +- Fixed fs policy E2E test. +- Pinned `cargo-chef` version to `0.1.68` in the dockerfiles. +- Added available namespaces to `mirrord ls` output. The new output format is + enabled with a flag in an environment variable. + [#2999](https://github.com/metalbear-co/mirrord/issues/2999) + ## [3.128.0](https://github.com/metalbear-co/mirrord/tree/3.128.0) - 2024-12-19 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 818a39e3a89..0583782b4b1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -113,6 +113,40 @@ For example, a test which only tests sanity of the ephemeral container feature s On Linux, running tests may exhaust a large amount of RAM and crash the machine. To prevent this, limit the number of concurrent jobs by running the command with e.g. `-j 4` +### IPv6 + +Some tests create a single-stack IPv6 service. They can only be run on clusters with IPv6 enabled. +In order to test IPv6 on a local cluster on macOS, you can use Kind: + +1. `brew install kind` +2. ```shell + cat >kind-config.yaml <=1.23.3 verifies pidfd support on Linux. 
\ No newline at end of file diff --git a/changelog.d/statfs.added.md b/changelog.d/statfs.added.md new file mode 100644 index 00000000000..b1cea16a410 --- /dev/null +++ b/changelog.d/statfs.added.md @@ -0,0 +1 @@ +Add statfs support \ No newline at end of file diff --git a/mirrord-schema.json b/mirrord-schema.json index 82a1538b9a7..19577b952af 100644 --- a/mirrord-schema.json +++ b/mirrord-schema.json @@ -1,7 +1,7 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "title": "LayerFileConfig", - "description": "mirrord allows for a high degree of customization when it comes to which features you want to enable, and how they should function.\n\nAll of the configuration fields have a default value, so a minimal configuration would be no configuration at all.\n\nThe configuration supports templating using the [Tera](https://keats.github.io/tera/docs/) template engine. Currently we don't provide additional values to the context, if you have anything you want us to provide please let us know.\n\nTo use a configuration file in the CLI, use the `-f ` flag. Or if using VSCode Extension or JetBrains plugin, simply create a `.mirrord/mirrord.json` file or use the UI.\n\nTo help you get started, here are examples of a basic configuration file, and a complete configuration file containing all fields.\n\n### Basic `config.json` {#root-basic}\n\n```json { \"target\": \"pod/bear-pod\", \"feature\": { \"env\": true, \"fs\": \"read\", \"network\": true } } ```\n\n### Basic `config.json` with templating {#root-basic-templating}\n\n```json { \"target\": \"{{ get_env(name=\"TARGET\", default=\"pod/fallback\") }}\", \"feature\": { \"env\": true, \"fs\": \"read\", \"network\": true } } ```\n\n### Complete `config.json` {#root-complete}\n\nDon't use this example as a starting point, it's just here to show you all the available options. 
```json { \"accept_invalid_certificates\": false, \"skip_processes\": \"ide-debugger\", \"target\": { \"path\": \"pod/bear-pod\", \"namespace\": \"default\" }, \"connect_tcp\": null, \"agent\": { \"log_level\": \"info\", \"json_log\": false, \"labels\": { \"user\": \"meow\" }, \"annotations\": { \"cats.io/inject\": \"enabled\" }, \"namespace\": \"default\", \"image\": \"ghcr.io/metalbear-co/mirrord:latest\", \"image_pull_policy\": \"IfNotPresent\", \"image_pull_secrets\": [ { \"secret-key\": \"secret\" } ], \"ttl\": 30, \"ephemeral\": false, \"communication_timeout\": 30, \"startup_timeout\": 360, \"network_interface\": \"eth0\", \"flush_connections\": true }, \"feature\": { \"env\": { \"include\": \"DATABASE_USER;PUBLIC_ENV\", \"exclude\": \"DATABASE_PASSWORD;SECRET_ENV\", \"override\": { \"DATABASE_CONNECTION\": \"db://localhost:7777/my-db\", \"LOCAL_BEAR\": \"panda\" }, \"mapping\": { \".+_TIMEOUT\": \"1000\" } }, \"fs\": { \"mode\": \"write\", \"read_write\": \".+\\\\.json\" , \"read_only\": [ \".+\\\\.yaml\", \".+important-file\\\\.txt\" ], \"local\": [ \".+\\\\.js\", \".+\\\\.mjs\" ] }, \"network\": { \"incoming\": { \"mode\": \"steal\", \"http_filter\": { \"header_filter\": \"host: api\\\\..+\" }, \"port_mapping\": [[ 7777, 8888 ]], \"ignore_localhost\": false, \"ignore_ports\": [9999, 10000] }, \"outgoing\": { \"tcp\": true, \"udp\": true, \"filter\": { \"local\": [\"tcp://1.1.1.0/24:1337\", \"1.1.5.0/24\", \"google.com\", \":53\"] }, \"ignore_localhost\": false, \"unix_streams\": \"bear.+\" }, \"dns\": { \"enabled\": true, \"filter\": { \"local\": [\"1.1.1.0/24:1337\", \"1.1.5.0/24\", \"google.com\"] } } }, \"copy_target\": { \"scale_down\": false } }, \"operator\": true, \"kubeconfig\": \"~/.kube/config\", \"sip_binaries\": \"bash\", \"telemetry\": true, \"kube_context\": \"my-cluster\" } ```\n\n# Options {#root-options}", + "description": "mirrord allows for a high degree of customization when it comes to which features you want to enable, and how they 
should function.\n\nAll of the configuration fields have a default value, so a minimal configuration would be no configuration at all.\n\nThe configuration supports templating using the [Tera](https://keats.github.io/tera/docs/) template engine. Currently we don't provide additional values to the context, if you have anything you want us to provide please let us know.\n\nTo use a configuration file in the CLI, use the `-f ` flag. Or if using VSCode Extension or JetBrains plugin, simply create a `.mirrord/mirrord.json` file or use the UI.\n\nTo help you get started, here are examples of a basic configuration file, and a complete configuration file containing all fields.\n\n### Basic `config.json` {#root-basic}\n\n```json { \"target\": \"pod/bear-pod\", \"feature\": { \"env\": true, \"fs\": \"read\", \"network\": true } } ```\n\n### Basic `config.json` with templating {#root-basic-templating}\n\n```json { \"target\": \"{{ get_env(name=\"TARGET\", default=\"pod/fallback\") }}\", \"feature\": { \"env\": true, \"fs\": \"read\", \"network\": true } } ```\n\n### Complete `config.json` {#root-complete}\n\nDon't use this example as a starting point, it's just here to show you all the available options. 
```json { \"accept_invalid_certificates\": false, \"skip_processes\": \"ide-debugger\", \"target\": { \"path\": \"pod/bear-pod\", \"namespace\": \"default\" }, \"connect_tcp\": null, \"agent\": { \"log_level\": \"info\", \"json_log\": false, \"labels\": { \"user\": \"meow\" }, \"annotations\": { \"cats.io/inject\": \"enabled\" }, \"namespace\": \"default\", \"image\": \"ghcr.io/metalbear-co/mirrord:latest\", \"image_pull_policy\": \"IfNotPresent\", \"image_pull_secrets\": [ { \"secret-key\": \"secret\" } ], \"ttl\": 30, \"ephemeral\": false, \"communication_timeout\": 30, \"startup_timeout\": 360, \"network_interface\": \"eth0\", \"flush_connections\": true, \"metrics\": \"0.0.0.0:9000\" }, \"feature\": { \"env\": { \"include\": \"DATABASE_USER;PUBLIC_ENV\", \"exclude\": \"DATABASE_PASSWORD;SECRET_ENV\", \"override\": { \"DATABASE_CONNECTION\": \"db://localhost:7777/my-db\", \"LOCAL_BEAR\": \"panda\" }, \"mapping\": { \".+_TIMEOUT\": \"1000\" } }, \"fs\": { \"mode\": \"write\", \"read_write\": \".+\\\\.json\" , \"read_only\": [ \".+\\\\.yaml\", \".+important-file\\\\.txt\" ], \"local\": [ \".+\\\\.js\", \".+\\\\.mjs\" ] }, \"network\": { \"incoming\": { \"mode\": \"steal\", \"http_filter\": { \"header_filter\": \"host: api\\\\..+\" }, \"port_mapping\": [[ 7777, 8888 ]], \"ignore_localhost\": false, \"ignore_ports\": [9999, 10000] }, \"outgoing\": { \"tcp\": true, \"udp\": true, \"filter\": { \"local\": [\"tcp://1.1.1.0/24:1337\", \"1.1.5.0/24\", \"google.com\", \":53\"] }, \"ignore_localhost\": false, \"unix_streams\": \"bear.+\" }, \"dns\": { \"enabled\": true, \"filter\": { \"local\": [\"1.1.1.0/24:1337\", \"1.1.5.0/24\", \"google.com\"] } } }, \"copy_target\": { \"scale_down\": false } }, \"operator\": true, \"kubeconfig\": \"~/.kube/config\", \"sip_binaries\": \"bash\", \"telemetry\": true, \"kube_context\": \"my-cluster\" } ```\n\n# Options {#root-options}", "type": "object", "properties": { "accept_invalid_certificates": { @@ -255,7 +255,7 @@ "properties": { 
"annotations": { "title": "agent.annotations {#agent-annotations}", - "description": "Allows setting up custom annotations for the agent Job and Pod.\n\n```json { \"annotations\": { \"cats.io/inject\": \"enabled\" } } ```", + "description": "Allows setting up custom annotations for the agent Job and Pod.\n\n```json { \"annotations\": { \"cats.io/inject\": \"enabled\", \"prometheus.io/scrape\": \"true\", \"prometheus.io/port\": \"9000\" } } ```", "type": [ "object", "null" ], @@ -378,6 +378,14 @@ "null" ] }, + "metrics": { + "title": "agent.metrics {#agent-metrics}", + "description": "Enables prometheus metrics for the agent pod.\n\nYou might need to add annotations to the agent pod depending on how prometheus is configured to scrape for metrics.\n\n```json { \"metrics\": \"0.0.0.0:9000\" } ```", + "type": [ + "string", + "null" + ] + }, "namespace": { "title": "agent.namespace {#agent-namespace}", "description": "Namespace where the agent shall live. Note: Doesn't work with ephemeral containers. 
Defaults to the current kubernetes namespace.", @@ -752,7 +760,7 @@ }, "mapping": { "title": "feature.env.mapping {#feature-env-mapping}", - "description": "Specify map of patterns that if matched will replace the value according to specification.\n\n*Capture groups are allowed.*\n\nExample: ```json { \".+_TIMEOUT\": \"10000\" \"LOG_.+_VERBOSITY\": \"debug\" \"(\\w+)_(\\d+)\": \"magic-value\" } ```\n\nWill do the next replacements for environment variables that match:\n\n`CONNECTION_TIMEOUT: 500` => `CONNECTION_TIMEOUT: 10000` `LOG_FILE_VERBOSITY: info` => `LOG_FILE_VERBOSITY: debug` `DATA_1234: common-value` => `DATA_1234: magic-value`", + "description": "Specify a map of patterns that, if matched, will replace the value according to the specification.\n\n*Capture groups are allowed.*\n\nExample: ```json { \".+_TIMEOUT\": \"10000\", \"LOG_.+_VERBOSITY\": \"debug\", \"(\\w+)_(\\d+)\": \"magic-value\" } ```\n\nWill perform the following replacements for environment variables that match:\n\n* `CONNECTION_TIMEOUT: 500` => `CONNECTION_TIMEOUT: 10000`\n\n* `LOG_FILE_VERBOSITY: info` => `LOG_FILE_VERBOSITY: debug`\n\n* `DATA_1234: common-value` => `DATA_1234: magic-value`", "type": [ "object", "null" @@ -1268,7 +1276,7 @@ }, "IncomingFileConfig": { "title": "incoming (network)", - "description": "Controls the incoming TCP traffic feature.\n\nSee the incoming [reference](https://mirrord.dev/docs/reference/traffic/#incoming) for more details.\n\nIncoming traffic supports 2 modes of operation:\n\n1. Mirror (**default**): Sniffs the TCP data from a port, and forwards a copy to the interested listeners;\n\n2. 
Steal: Captures the TCP data from a port, and forwards it to the local process, see [`steal`](##steal);\n\n### Minimal `incoming` config\n\n```json { \"feature\": { \"network\": { \"incoming\": \"steal\" } } } ```\n\n### Advanced `incoming` config\n\n```json { \"feature\": { \"network\": { \"incoming\": { \"mode\": \"steal\", \"http_filter\": { \"header_filter\": \"host: api\\\\..+\" }, \"port_mapping\": [[ 7777, 8888 ]], \"ignore_localhost\": false, \"ignore_ports\": [9999, 10000] \"listen_ports\": [[80, 8111]] } } } } ```", + "description": "Controls the incoming TCP traffic feature.\n\nSee the incoming [reference](https://mirrord.dev/docs/reference/traffic/#incoming) for more details.\n\nIncoming traffic supports 2 modes of operation:\n\n1. Mirror (**default**): Sniffs the TCP data from a port, and forwards a copy to the interested listeners;\n\n2. Steal: Captures the TCP data from a port, and forwards it to the local process, see [`steal`](##steal);\n\n### Minimal `incoming` config\n\n```json { \"feature\": { \"network\": { \"incoming\": \"steal\" } } } ```\n\n### Advanced `incoming` config\n\n```json { \"feature\": { \"network\": { \"incoming\": { \"mode\": \"steal\", \"http_filter\": { \"header_filter\": \"host: api\\\\..+\" }, \"port_mapping\": [[ 7777, 8888 ]], \"ignore_localhost\": false, \"ignore_ports\": [9999, 10000], \"listen_ports\": [[80, 8111]] } } } } ```", "anyOf": [ { "anyOf": [ @@ -1474,6 +1482,14 @@ } ] }, + "ipv6": { + "title": "feature.network.ipv6 {#feature-network-ipv6}", + "description": "Enable ipv6 support. 
Turn on if your application listens to incoming traffic over IPv6.", + "type": [ + "boolean", + "null" + ] + }, "outgoing": { "title": "feature.network.outgoing {#feature-network-outgoing}", "anyOf": [ diff --git a/mirrord/agent/Cargo.toml b/mirrord/agent/Cargo.toml index cdba788acf3..07757b7900e 100644 --- a/mirrord/agent/Cargo.toml +++ b/mirrord/agent/Cargo.toml @@ -69,6 +69,8 @@ x509-parser = "0.16" rustls.workspace = true envy = "0.4" socket2.workspace = true +prometheus = { version = "0.13", features = ["process"] } +axum = { version = "0.7", features = ["macros"] } iptables = { git = "https://github.com/metalbear-co/rust-iptables.git", rev = "e66c7332e361df3c61a194f08eefe3f40763d624" } rawsocket = { git = "https://github.com/metalbear-co/rawsocket.git" } procfs = "0.17.0" @@ -78,3 +80,4 @@ rstest.workspace = true mockall = "0.13" test_bin = "0.4" rcgen.workspace = true +reqwest.workspace = true diff --git a/mirrord/agent/Dockerfile b/mirrord/agent/Dockerfile index 83c97871cae..e9de8df84d6 100644 --- a/mirrord/agent/Dockerfile +++ b/mirrord/agent/Dockerfile @@ -8,7 +8,8 @@ RUN ./platform.sh # this takes around 1 minute since libgit2 is slow https://github.com/rust-lang/cargo/issues/9167 ENV CARGO_NET_GIT_FETCH_WITH_CLI=true -RUN cargo install cargo-chef +# cargo-chef 0.1.69 breaks the build +RUN cargo install cargo-chef@0.1.68 FROM chef AS planner diff --git a/mirrord/agent/README.md b/mirrord/agent/README.md index bf077b5fdcf..d7456ead64c 100644 --- a/mirrord/agent/README.md +++ b/mirrord/agent/README.md @@ -6,3 +6,198 @@ Agent part of [mirrord](https://github.com/metalbear-co/mirrord) responsible for mirrord-agent is written in Rust for safety, low memory consumption and performance. mirrord-agent is distributed as a container image (currently only x86) that is published on [GitHub Packages publicly](https://github.com/metalbear-co/mirrord-agent/pkgs/container/mirrord-agent). 
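The `Cargo.toml` hunk above pulls in the `prometheus` and `axum` crates for the agent's new metrics endpoint. As a rough illustration of what such an endpoint serves, here is a stdlib-only sketch that renders a gauge in the Prometheus text exposition format; the metric name and helper are hypothetical, not taken from the agent's code:

```rust
use std::sync::atomic::{AtomicI64, Ordering};

// Hypothetical gauge, in the style of the agent's atomic counters.
static CLIENT_COUNT: AtomicI64 = AtomicI64::new(0);

/// Render a single gauge in the Prometheus text exposition format:
/// a `# HELP` line, a `# TYPE` line, then `name value`.
fn render_metrics() -> String {
    format!(
        "# HELP mirrord_agent_client_count Connected clients.\n\
         # TYPE mirrord_agent_client_count gauge\n\
         mirrord_agent_client_count {}\n",
        CLIENT_COUNT.load(Ordering::Relaxed)
    )
}

fn main() {
    CLIENT_COUNT.fetch_add(2, Ordering::Relaxed);
    let body = render_metrics();
    assert!(body.contains("mirrord_agent_client_count 2"));
    println!("{body}");
}
```

A real exporter would serve this string over HTTP (which is what `axum` is for here); Prometheus then scrapes the address configured in `agent.metrics`.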
+ +## Enabling prometheus metrics + +To start the metrics server, you'll need to add this config to your `mirrord.json`: + +```json +{ + "agent": { + "metrics": "0.0.0.0:9000", + "annotations": { + "prometheus.io/scrape": "true", + "prometheus.io/port": "9000" + } + } +} +``` + +Remember to change the `port` in both `metrics` and `annotations`; they have to match, +otherwise prometheus will try to scrape on `port: 80` or another commonly used port. + +### Installing prometheus + +Run `kubectl apply -f {file-name}.yaml` on this sequence of `yaml` files and you should +get prometheus running in your cluster. You can access the dashboard from your browser at +`http://{cluster-ip}:30909`; if you're using minikube it might be +`http://192.168.49.2:30909`. + +You'll get prometheus running under the `monitoring` namespace, but it'll be able to look +into resources from all namespaces. The config in `configmap.yaml` sets prometheus to look +at pods only; if you want to use it to scrape other resources, check +[this example](https://github.com/prometheus/prometheus/blob/main/documentation/examples/prometheus-kubernetes.yml). + +1. `create-namespace.yaml` + +```yaml +apiVersion: v1 +kind: Namespace +metadata: + name: monitoring +``` + +2. `cluster-role.yaml` + +```yaml +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: prometheus +rules: +- apiGroups: [""] + resources: + - nodes + - services + - endpoints + - pods + verbs: ["get", "list", "watch"] +- apiGroups: + - extensions + resources: + - ingresses + verbs: ["get", "list", "watch"] +``` + +3. `service-account.yaml` + +```yaml +apiVersion: v1 +kind: ServiceAccount +metadata: + name: prometheus + namespace: monitoring +``` + +4. 
`cluster-role-binding.yaml` + +```yaml +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: prometheus +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: ClusterRole + name: prometheus +subjects: +- kind: ServiceAccount + name: prometheus + namespace: monitoring +``` + +5. `configmap.yaml` + +```yaml +apiVersion: v1 +kind: ConfigMap +metadata: + name: prometheus-config + namespace: monitoring +data: + prometheus.yml: | + global: + keep_dropped_targets: 100 + + scrape_configs: + - job_name: "kubernetes-pods" + + kubernetes_sd_configs: + - role: pod + + relabel_configs: + - source_labels: [__address__, __meta_kubernetes_pod_annotation_prometheus_io_port] + action: replace + regex: ([^:]+)(?::\d+)?;(\d+) + replacement: $1:$2 + target_label: __address__ + - action: labelmap + regex: __meta_kubernetes_pod_label_(.+) + - source_labels: [__meta_kubernetes_namespace] + action: replace + target_label: namespace + - source_labels: [__meta_kubernetes_pod_name] + action: replace + target_label: pod +``` + +- If you make any changes to the `configmap.yaml` file, remember to `kubectl apply` it + **before** restarting the `prometheus` deployment. + +6. 
`deployment.yaml` + +```yaml +apiVersion: apps/v1 +kind: Deployment +metadata: + name: prometheus + namespace: monitoring + labels: + app: prometheus +spec: + replicas: 1 + strategy: + rollingUpdate: + maxSurge: 1 + maxUnavailable: 1 + type: RollingUpdate + selector: + matchLabels: + app: prometheus + template: + metadata: + labels: + app: prometheus + annotations: + prometheus.io/scrape: "true" + prometheus.io/port: "9090" + spec: + serviceAccountName: prometheus + containers: + - name: prometheus + image: prom/prometheus + args: + - '--config.file=/etc/prometheus/prometheus.yml' + ports: + - name: web + containerPort: 9090 + volumeMounts: + - name: prometheus-config-volume + mountPath: /etc/prometheus + restartPolicy: Always + volumes: + - name: prometheus-config-volume + configMap: + defaultMode: 420 + name: prometheus-config +``` + +7. `service.yaml` + +```yaml +apiVersion: v1 +kind: Service +metadata: + name: prometheus-service + namespace: monitoring + annotations: + prometheus.io/scrape: 'true' + prometheus.io/port: '9090' +spec: + selector: + app: prometheus + type: NodePort + ports: + - port: 8080 + targetPort: 9090 + nodePort: 30909 +``` diff --git a/mirrord/agent/src/cli.rs b/mirrord/agent/src/cli.rs index 6c5b11e65a2..bbcf23f1816 100644 --- a/mirrord/agent/src/cli.rs +++ b/mirrord/agent/src/cli.rs @@ -1,7 +1,12 @@ #![deny(missing_docs)] +use std::net::SocketAddr; + use clap::{Parser, Subcommand}; -use mirrord_protocol::{MeshVendor, AGENT_NETWORK_INTERFACE_ENV, AGENT_OPERATOR_CERT_ENV}; +use mirrord_protocol::{ + MeshVendor, AGENT_IPV6_ENV, AGENT_METRICS_ENV, AGENT_NETWORK_INTERFACE_ENV, + AGENT_OPERATOR_CERT_ENV, +}; const DEFAULT_RUNTIME: &str = "containerd"; @@ -26,6 +31,10 @@ pub struct Args { #[arg(short = 'i', long, env = AGENT_NETWORK_INTERFACE_ENV)] pub network_interface: Option, + /// Controls whether metrics are enabled, and the address to set up the metrics server. 
+ #[arg(long, env = AGENT_METRICS_ENV)] + pub metrics: Option<SocketAddr>, + /// Return an error after accepting the first client connection, in order to test agent error /// cleanup. /// @@ -50,6 +59,13 @@ pub struct Args { env = "MIRRORD_AGENT_IN_SERVICE_MESH" )] pub is_mesh: bool, + + /// Enable support for IPv6-only clusters. + /// + /// Only when this option is set will the agent take the steps needed to run on an IPv6 + /// single-stack cluster. + #[arg(long, default_value_t = false, env = AGENT_IPV6_ENV)] + pub ipv6: bool, } impl Args { diff --git a/mirrord/agent/src/client_connection.rs b/mirrord/agent/src/client_connection.rs index 8181e4baabd..7b484cc25da 100644 --- a/mirrord/agent/src/client_connection.rs +++ b/mirrord/agent/src/client_connection.rs @@ -208,7 +208,7 @@ enum ConnectionFramed { #[cfg(test)] mod test { - use std::sync::Arc; + use std::sync::{Arc, Once}; use futures::StreamExt; use mirrord_protocol::ClientCodec; @@ -220,10 +220,19 @@ mod test { use super::*; + static CRYPTO_PROVIDER: Once = Once::new(); + /// Verifies that [`AgentTlsConnector`] correctly accepts a /// connection from a server using the provided certificate. #[tokio::test] async fn agent_tls_connector_valid_cert() { + CRYPTO_PROVIDER.call_once(|| { + rustls::crypto::CryptoProvider::install_default( + rustls::crypto::aws_lc_rs::default_provider(), + ) + .expect("Failed to install crypto provider") + }); + let cert = rcgen::generate_simple_self_signed(vec!["operator".to_string()]).unwrap(); let cert_bytes = cert.cert.der(); let key_bytes = cert.key_pair.serialize_der(); @@ -269,6 +278,13 @@ mod test { /// connection from a server using some other certificate. 
#[tokio::test] async fn agent_tls_connector_invalid_cert() { + CRYPTO_PROVIDER.call_once(|| { + rustls::crypto::CryptoProvider::install_default( + rustls::crypto::aws_lc_rs::default_provider(), + ) + .expect("Failed to install crypto provider") + }); + let server_cert = rcgen::generate_simple_self_signed(vec!["operator".to_string()]).unwrap(); let cert_bytes = server_cert.cert.der(); let key_bytes = server_cert.key_pair.serialize_der(); diff --git a/mirrord/agent/src/container_handle.rs b/mirrord/agent/src/container_handle.rs index 6e8ba78173d..dd6755e766d 100644 --- a/mirrord/agent/src/container_handle.rs +++ b/mirrord/agent/src/container_handle.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ - error::Result, + error::AgentResult, runtime::{Container, ContainerInfo, ContainerRuntime}, }; @@ -22,7 +22,7 @@ pub(crate) struct ContainerHandle(Arc); impl ContainerHandle { /// Retrieve info about the container and initialize this struct. #[tracing::instrument(level = "trace")] - pub(crate) async fn new(container: Container) -> Result { + pub(crate) async fn new(container: Container) -> AgentResult { let ContainerInfo { pid, env: raw_env } = container.get_info().await?; let inner = Inner { pid, raw_env }; diff --git a/mirrord/agent/src/dns.rs b/mirrord/agent/src/dns.rs index 0ad44c76934..b92487594e0 100644 --- a/mirrord/agent/src/dns.rs +++ b/mirrord/agent/src/dns.rs @@ -3,7 +3,7 @@ use std::{future, path::PathBuf, time::Duration}; use futures::{stream::FuturesOrdered, StreamExt}; use hickory_resolver::{system_conf::parse_resolv_conf, Hosts, Resolver}; use mirrord_protocol::{ - dns::{DnsLookup, GetAddrInfoRequest, GetAddrInfoResponse}, + dns::{DnsLookup, GetAddrInfoRequest, GetAddrInfoRequestV2, GetAddrInfoResponse}, DnsLookupError, RemoteResult, ResolveErrorKindInternal, ResponseError, }; use tokio::{ @@ -16,14 +16,26 @@ use tokio::{ use tokio_util::sync::CancellationToken; use tracing::Level; -use crate::{ - error::{AgentError, Result}, - 
watched_task::TaskStatus, -}; +use crate::{error::AgentResult, metrics::DNS_REQUEST_COUNT, watched_task::TaskStatus}; + +#[derive(Debug)] +pub(crate) enum ClientGetAddrInfoRequest { + V1(GetAddrInfoRequest), + V2(GetAddrInfoRequestV2), +} + +impl ClientGetAddrInfoRequest { + pub(crate) fn into_v2(self) -> GetAddrInfoRequestV2 { + match self { + ClientGetAddrInfoRequest::V1(old_req) => old_req.into(), + ClientGetAddrInfoRequest::V2(v2_req) => v2_req, + } + } +} #[derive(Debug)] pub(crate) struct DnsCommand { - request: GetAddrInfoRequest, + request: ClientGetAddrInfoRequest, response_tx: oneshot::Sender>, } @@ -34,6 +46,7 @@ pub(crate) struct DnsWorker { request_rx: Receiver, attempts: usize, timeout: Duration, + support_ipv6: bool, } impl DnsWorker { @@ -45,7 +58,11 @@ impl DnsWorker { /// # Note /// /// `pid` is used to find the correct path of `etc` directory. - pub(crate) fn new(pid: Option, request_rx: Receiver) -> Self { + pub(crate) fn new( + pid: Option, + request_rx: Receiver, + support_ipv6: bool, + ) -> Self { let etc_path = pid .map(|pid| { PathBuf::from("/proc") @@ -66,6 +83,7 @@ impl DnsWorker { .ok() .and_then(|attempts| attempts.parse().ok()) .unwrap_or(1), + support_ipv6, } } @@ -79,9 +97,10 @@ impl DnsWorker { #[tracing::instrument(level = Level::TRACE, ret, err(level = Level::TRACE))] async fn do_lookup( etc_path: PathBuf, - host: String, + request: GetAddrInfoRequestV2, attempts: usize, timeout: Duration, + support_ipv6: bool, ) -> RemoteResult { // Prepares the `Resolver` after reading some `/etc` DNS files. 
// @@ -94,13 +113,32 @@ impl DnsWorker { let hosts_conf = fs::read(hosts_path).await?; let (config, mut options) = parse_resolv_conf(resolv_conf)?; + tracing::debug!(?config, ?options, "parsed config options"); options.server_ordering_strategy = hickory_resolver::config::ServerOrderingStrategy::UserProvidedOrder; options.timeout = timeout; options.attempts = attempts; - options.ip_strategy = hickory_resolver::config::LookupIpStrategy::Ipv4Only; + options.ip_strategy = if support_ipv6 { + tracing::debug!("IPv6 support enabled. Respecting client IP family."); + request + .family + .try_into() + .inspect_err(|e| { + tracing::error!(%e, + "Unknown address family in addrinfo request. Using IPv4 and IPv6.") + }) + // If the agent gets some new, unknown variant of family address, it's the + // client's fault, so the agent queries both IPv4 and IPv6 and if that's not + // good enough for the client, the client can error out. + .unwrap_or(hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6) + } else { + tracing::debug!("IPv6 support disabled. Resolving IPv4 only."); + hickory_resolver::config::LookupIpStrategy::Ipv4Only + }; + tracing::debug!(?config, ?options, "updated config options"); let mut resolver = Resolver::tokio(config, options); + tracing::debug!(?resolver, "tokio resolver"); let mut hosts = Hosts::default(); hosts.read_hosts_conf(hosts_conf.as_slice())?; @@ -111,9 +149,10 @@ impl DnsWorker { let lookup = resolver .inspect_err(|fail| tracing::error!(?fail, "Failed to build DNS resolver"))? - .lookup_ip(host) + .lookup_ip(request.node) .await - .inspect(|lookup| tracing::trace!(?lookup, "Lookup finished"))? + .inspect(|lookup| tracing::trace!(?lookup, "Lookup finished")) + .inspect_err(|e| tracing::trace!(%e, "lookup failed"))? 
.into(); Ok(lookup) @@ -125,21 +164,30 @@ impl DnsWorker { let etc_path = self.etc_path.clone(); let timeout = self.timeout; let attempts = self.attempts; + + DNS_REQUEST_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + let support_ipv6 = self.support_ipv6; let lookup_future = async move { - let result = Self::do_lookup(etc_path, message.request.node, attempts, timeout).await; + let result = Self::do_lookup( + etc_path, + message.request.into_v2(), + attempts, + timeout, + support_ipv6, + ) + .await; if let Err(result) = message.response_tx.send(result) { tracing::error!(?result, "Failed to send query response"); } + DNS_REQUEST_COUNT.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); }; tokio::spawn(lookup_future); } - pub(crate) async fn run( - mut self, - cancellation_token: CancellationToken, - ) -> Result<(), AgentError> { + pub(crate) async fn run(mut self, cancellation_token: CancellationToken) -> AgentResult<()> { loop { tokio::select! { _ = cancellation_token.cancelled() => break Ok(()), @@ -174,8 +222,8 @@ impl DnsApi { /// Results of scheduled requests are available via [`Self::recv`] (order is preserved). pub(crate) async fn make_request( &mut self, - request: GetAddrInfoRequest, - ) -> Result<(), AgentError> { + request: ClientGetAddrInfoRequest, + ) -> AgentResult<()> { let (response_tx, response_rx) = oneshot::channel(); let command = DnsCommand { @@ -194,7 +242,7 @@ impl DnsApi { /// Returns the result of the oldest outstanding DNS request issued with this struct (see /// [`Self::make_request`]). 
#[tracing::instrument(level = Level::TRACE, skip(self), ret, err)] - pub(crate) async fn recv(&mut self) -> Result { + pub(crate) async fn recv(&mut self) -> AgentResult { let Some(response) = self.responses.next().await else { return future::pending().await; }; diff --git a/mirrord/agent/src/entrypoint.rs b/mirrord/agent/src/entrypoint.rs index 28eb5f26634..ac9157897a0 100644 --- a/mirrord/agent/src/entrypoint.rs +++ b/mirrord/agent/src/entrypoint.rs @@ -1,7 +1,7 @@ use std::{ collections::HashMap, mem, - net::{Ipv4Addr, SocketAddrV4}, + net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}, path::PathBuf, sync::{ atomic::{AtomicU32, Ordering}, @@ -10,8 +10,9 @@ use std::{ }; use client_connection::AgentTlsConnector; -use dns::{DnsCommand, DnsWorker}; +use dns::{ClientGetAddrInfoRequest, DnsCommand, DnsWorker}; use futures::TryFutureExt; +use metrics::{start_metrics, CLIENT_COUNT}; use mirrord_protocol::{ClientMessage, DaemonMessage, GetEnvVarsRequest, LogMessage}; use sniffer::tcp_capture::RawSocketTcpCapture; use tokio::{ @@ -24,7 +25,7 @@ use tokio::{ time::{timeout, Duration}, }; use tokio_util::sync::CancellationToken; -use tracing::{debug, error, info, trace, warn}; +use tracing::{debug, error, info, trace, warn, Level}; use tracing_subscriber::{fmt::format::FmtSpan, prelude::*}; use crate::{ @@ -32,7 +33,7 @@ use crate::{ client_connection::ClientConnection, container_handle::ContainerHandle, dns::DnsApi, - error::{AgentError, Result}, + error::{AgentError, AgentResult}, file::FileManager, outgoing::{TcpOutgoingApi, UdpOutgoingApi}, runtime::get_container, @@ -72,7 +73,7 @@ struct State { impl State { /// Return [`Err`] if container runtime operations failed. 
- pub async fn new(args: &Args) -> Result { + pub async fn new(args: &Args) -> AgentResult { let tls_connector = args .operator_tls_cert_pem .clone() @@ -205,6 +206,12 @@ struct ClientConnectionHandler { ready_for_logs: bool, } +impl Drop for ClientConnectionHandler { + fn drop(&mut self) { + CLIENT_COUNT.fetch_sub(1, Ordering::Relaxed); + } +} + impl ClientConnectionHandler { /// Initializes [`ClientConnectionHandler`]. pub async fn new( @@ -212,7 +219,7 @@ impl ClientConnectionHandler { mut connection: ClientConnection, bg_tasks: BackgroundTasks, state: State, - ) -> Result { + ) -> AgentResult { let pid = state.container_pid(); let file_manager = FileManager::new(pid.or_else(|| state.ephemeral.then_some(1))); @@ -238,6 +245,8 @@ impl ClientConnectionHandler { ready_for_logs: false, }; + CLIENT_COUNT.fetch_add(1, Ordering::Relaxed); + Ok(client_handler) } @@ -273,7 +282,7 @@ impl ClientConnectionHandler { id: ClientId, task: BackgroundTask, connection: &mut ClientConnection, - ) -> Result> { + ) -> AgentResult> { if let BackgroundTask::Running(stealer_status, stealer_sender) = task { match TcpStealerApi::new( id, @@ -313,7 +322,7 @@ impl ClientConnectionHandler { /// /// Breaks upon receiver/sender drop. #[tracing::instrument(level = "trace", skip(self))] - async fn start(mut self, cancellation_token: CancellationToken) -> Result<()> { + async fn start(mut self, cancellation_token: CancellationToken) -> AgentResult<()> { let error = loop { select! 
{ message = self.connection.receive() => { @@ -364,7 +373,7 @@ impl ClientConnectionHandler { Ok(message) => self.respond(DaemonMessage::TcpOutgoing(message)).await?, Err(e) => break e, }, - message = self.udp_outgoing_api.daemon_message() => match message { + message = self.udp_outgoing_api.recv_from_task() => match message { Ok(message) => self.respond(DaemonMessage::UdpOutgoing(message)).await?, Err(e) => break e, }, @@ -389,15 +398,15 @@ impl ClientConnectionHandler { /// Sends a [`DaemonMessage`] response to the connected client (`mirrord-layer`). #[tracing::instrument(level = "trace", skip(self))] - async fn respond(&mut self, response: DaemonMessage) -> Result<()> { + async fn respond(&mut self, response: DaemonMessage) -> AgentResult<()> { self.connection.send(response).await.map_err(Into::into) } /// Handles incoming messages from the connected client (`mirrord-layer`). /// /// Returns `false` if the client disconnected. - #[tracing::instrument(level = "trace", skip(self))] - async fn handle_client_message(&mut self, message: ClientMessage) -> Result { + #[tracing::instrument(level = Level::TRACE, skip(self), ret, err(level = Level::DEBUG))] + async fn handle_client_message(&mut self, message: ClientMessage) -> AgentResult { match message { ClientMessage::FileRequest(req) => { if let Some(response) = self.file_manager.handle_message(req)? { @@ -415,7 +424,7 @@ impl ClientConnectionHandler { self.tcp_outgoing_api.send_to_task(layer_message).await? } ClientMessage::UdpOutgoing(layer_message) => { - self.udp_outgoing_api.layer_message(layer_message).await? + self.udp_outgoing_api.send_to_task(layer_message).await? } ClientMessage::GetEnvVarsRequest(GetEnvVarsRequest { env_vars_filter, @@ -433,7 +442,14 @@ impl ClientConnectionHandler { .await? 
} ClientMessage::GetAddrInfoRequest(request) => { - self.dns_api.make_request(request).await?; + self.dns_api + .make_request(ClientGetAddrInfoRequest::V1(request)) + .await?; + } + ClientMessage::GetAddrInfoRequestV2(request) => { + self.dns_api + .make_request(ClientGetAddrInfoRequest::V2(request)) + .await?; } ClientMessage::Ping => self.respond(DaemonMessage::Pong).await?, ClientMessage::Tcp(message) => { @@ -488,15 +504,37 @@ impl ClientConnectionHandler { } /// Initializes the agent's [`State`], channels, threads, and runs [`ClientConnectionHandler`]s. -#[tracing::instrument(level = "trace", ret)] -async fn start_agent(args: Args) -> Result<()> { +#[tracing::instrument(level = Level::TRACE, ret, err)] +async fn start_agent(args: Args) -> AgentResult<()> { trace!("start_agent -> Starting agent with args: {args:?}"); - let listener = TcpListener::bind(SocketAddrV4::new( + // listen for client connections + let ipv4_listener_result = TcpListener::bind(SocketAddrV4::new( Ipv4Addr::UNSPECIFIED, args.communicate_port, )) - .await?; + .await; + + let listener = if args.ipv6 && ipv4_listener_result.is_err() { + debug!("IPv6 Support enabled, and IPv4 bind failed, binding IPv6 listener"); + TcpListener::bind(SocketAddrV6::new( + Ipv6Addr::UNSPECIFIED, + args.communicate_port, + 0, + 0, + )) + .await + } else { + ipv4_listener_result + }?; + + match listener.local_addr() { + Ok(addr) => debug!( + client_listener_address = addr.to_string(), + "Created listener." + ), + Err(err) => error!(%err, "listener local address error"), + } let state = State::new(&args).await?; @@ -505,6 +543,18 @@ async fn start_agent(args: Args) -> Result<()> { // To make sure that background tasks are cancelled when we exit early from this function. 
let cancel_guard = cancellation_token.clone().drop_guard(); + if let Some(metrics_address) = args.metrics { + let cancellation_token = cancellation_token.clone(); + tokio::spawn(async move { + start_metrics(metrics_address, cancellation_token.clone()) + .await + .inspect_err(|fail| { + tracing::error!(?fail, "Failed starting metrics server!"); + cancellation_token.cancel(); + }) + }); + } + let (sniffer_command_tx, sniffer_command_rx) = mpsc::channel::(1000); let (stealer_command_tx, stealer_command_rx) = mpsc::channel::(1000); let (dns_command_tx, dns_command_rx) = mpsc::channel::(1000); @@ -566,13 +616,15 @@ async fn start_agent(args: Args) -> Result<()> { let cancellation_token = cancellation_token.clone(); let watched_task = WatchedTask::new( TcpConnectionStealer::TASK_NAME, - TcpConnectionStealer::new(stealer_command_rx).and_then(|stealer| async move { - let res = stealer.start(cancellation_token).await; - if let Err(err) = res.as_ref() { - error!("Stealer failed: {err}"); - } - res - }), + TcpConnectionStealer::new(stealer_command_rx, args.ipv6).and_then( + |stealer| async move { + let res = stealer.start(cancellation_token).await; + if let Err(err) = res.as_ref() { + error!("Stealer failed: {err}"); + } + res + }, + ), ); let status = watched_task.status(); let task = run_thread_in_namespace( @@ -589,7 +641,8 @@ async fn start_agent(args: Args) -> Result<()> { let cancellation_token = cancellation_token.clone(); let watched_task = WatchedTask::new( DnsWorker::TASK_NAME, - DnsWorker::new(state.container_pid(), dns_command_rx).run(cancellation_token), + DnsWorker::new(state.container_pid(), dns_command_rx, args.ipv6) + .run(cancellation_token), ); let status = watched_task.status(); let task = run_thread_in_namespace( @@ -723,7 +776,7 @@ async fn start_agent(args: Args) -> Result<()> { Ok(()) } -async fn clear_iptable_chain() -> Result<()> { +async fn clear_iptable_chain() -> AgentResult<()> { let ipt = new_iptables(); 
SafeIpTables::load(IPTablesWrapper::from(ipt), false)
@@ -734,7 +787,7 @@ async fn clear_iptable_chain() -> Result<()> {
     Ok(())
 }
 
-async fn run_child_agent() -> Result<()> {
+async fn run_child_agent() -> AgentResult<()> {
     let command_args = std::env::args().collect::<Vec<_>>();
     let (command, args) = command_args
         .split_first()
@@ -758,7 +811,7 @@ async fn run_child_agent() -> Result<()> {
 ///
 /// Captures SIGTERM signals sent by Kubernetes when the pod is gracefully deleted.
 /// When a signal is captured, the child process is killed and the iptables are cleaned.
-async fn start_iptable_guard(args: Args) -> Result<()> {
+async fn start_iptable_guard(args: Args) -> AgentResult<()> {
     debug!("start_iptable_guard -> Initializing iptable-guard.");
 
     let state = State::new(&args).await?;
@@ -795,7 +848,18 @@ async fn start_iptable_guard(args: Args) -> Result<()> {
     result
 }
 
-pub async fn main() -> Result<()> {
+/// The agent is somewhat started twice, first with [`start_iptable_guard`], and then the
+/// proper agent with [`start_agent`].
+///
+/// ## Things to keep in mind due to the double initialization
+///
+/// Since the _second_ agent gets spawned as a child of the _first_, they share resources,
+/// like the `namespace`, which means:
+///
+/// 1. If you try to `bind` a socket to some address before [`start_agent`], it'll actually be bound
+///    **twice**, which incurs an error (address already in use). You could get around this by
+///    `bind`ing on `0.0.0.0:0`, but this is most likely **not** what you want.
+pub async fn main() -> AgentResult<()> { rustls::crypto::CryptoProvider::install_default(rustls::crypto::aws_lc_rs::default_provider()) .expect("Failed to install crypto provider"); diff --git a/mirrord/agent/src/env.rs b/mirrord/agent/src/env.rs index 26fa4681431..5a349709f2d 100644 --- a/mirrord/agent/src/env.rs +++ b/mirrord/agent/src/env.rs @@ -7,7 +7,7 @@ use mirrord_protocol::RemoteResult; use tokio::io::AsyncReadExt; use wildmatch::WildMatch; -use crate::error::Result; +use crate::error::AgentResult; struct EnvFilter { include: Vec, @@ -97,7 +97,7 @@ pub(crate) fn parse_raw_env<'a, S: AsRef + 'a + ?Sized, T: IntoIterator>() } -pub(crate) async fn get_proc_environ(path: PathBuf) -> Result> { +pub(crate) async fn get_proc_environ(path: PathBuf) -> AgentResult> { let mut environ_file = tokio::fs::File::open(path).await?; let mut raw_env_vars = String::with_capacity(8192); diff --git a/mirrord/agent/src/error.rs b/mirrord/agent/src/error.rs index ad04e49c8c5..88b811e590b 100644 --- a/mirrord/agent/src/error.rs +++ b/mirrord/agent/src/error.rs @@ -84,6 +84,10 @@ pub(crate) enum AgentError { /// Temporary error for vpn feature #[error("Generic error in vpn: {0}")] VpnError(String), + + /// When we neither create a redirector for IPv4, nor for IPv6 + #[error("Could not create a listener for stolen connections")] + CannotListenForStolenConnections, } impl From> for AgentError { @@ -92,4 +96,4 @@ impl From> for AgentError { } } -pub(crate) type Result = std::result::Result; +pub(crate) type AgentResult = std::result::Result; diff --git a/mirrord/agent/src/file.rs b/mirrord/agent/src/file.rs index 9261b0bcb69..571b2ad9d3c 100644 --- a/mirrord/agent/src/file.rs +++ b/mirrord/agent/src/file.rs @@ -5,16 +5,20 @@ use std::{ io::{self, prelude::*, BufReader, SeekFrom}, iter::{Enumerate, Peekable}, ops::RangeInclusive, - os::unix::{fs::MetadataExt, prelude::FileExt}, + os::{ + fd::RawFd, + unix::{fs::MetadataExt, prelude::FileExt}, + }, path::{Path, PathBuf}, }; use 
faccess::{AccessMode, PathExt}; use libc::DT_DIR; use mirrord_protocol::{file::*, FileRequest, FileResponse, RemoteResult, ResponseError}; +use nix::unistd::UnlinkatFlags; use tracing::{error, trace, Level}; -use crate::error::Result; +use crate::{error::AgentResult, metrics::OPEN_FD_COUNT}; #[derive(Debug)] pub enum RemoteFile { @@ -72,15 +76,11 @@ pub(crate) struct FileManager { fds_iter: RangeInclusive, } -impl Default for FileManager { - fn default() -> Self { - Self { - root_path: Default::default(), - open_files: Default::default(), - dir_streams: Default::default(), - getdents_streams: Default::default(), - fds_iter: (0..=u64::MAX), - } +impl Drop for FileManager { + fn drop(&mut self) { + let descriptors = + self.open_files.len() + self.dir_streams.len() + self.getdents_streams.len(); + OPEN_FD_COUNT.fetch_sub(descriptors as i64, std::sync::atomic::Ordering::Relaxed); } } @@ -147,8 +147,11 @@ pub fn resolve_path + std::fmt::Debug, R: AsRef + std::fmt: impl FileManager { /// Executes the request and returns the response. - #[tracing::instrument(level = "trace", skip(self))] - pub fn handle_message(&mut self, request: FileRequest) -> Result> { + #[tracing::instrument(level = Level::TRACE, skip(self), ret, err(level = Level::DEBUG))] + pub(crate) fn handle_message( + &mut self, + request: FileRequest, + ) -> AgentResult> { Ok(match request { FileRequest::Open(OpenFileRequest { path, open_options }) => { // TODO: maybe not agent error on this? 
@@ -202,10 +205,7 @@ impl FileManager { let write_result = self.write_limited(remote_fd, start_from, write_bytes); Some(FileResponse::WriteLimited(write_result)) } - FileRequest::Close(CloseFileRequest { fd }) => { - self.close(fd); - None - } + FileRequest::Close(CloseFileRequest { fd }) => self.close(fd), FileRequest::Access(AccessFileRequest { pathname, mode }) => { let pathname = pathname .strip_prefix("/") @@ -223,8 +223,12 @@ impl FileManager { Some(FileResponse::Xstat(xstat_result)) } FileRequest::XstatFs(XstatFsRequest { fd }) => { - let xstat_result = self.xstatfs(fd); - Some(FileResponse::XstatFs(xstat_result)) + let xstatfs_result = self.xstatfs(fd); + Some(FileResponse::XstatFs(xstatfs_result)) + } + FileRequest::StatFs(StatFsRequest { path }) => { + let statfs_result = self.statfs(path); + Some(FileResponse::XstatFs(statfs_result)) } // dir operations @@ -240,10 +244,7 @@ impl FileManager { let read_dir_result = self.read_dir_batch(remote_fd, amount); Some(FileResponse::ReadDirBatch(read_dir_result)) } - FileRequest::CloseDir(CloseDirRequest { remote_fd }) => { - self.close_dir(remote_fd); - None - } + FileRequest::CloseDir(CloseDirRequest { remote_fd }) => self.close_dir(remote_fd), FileRequest::GetDEnts64(GetDEnts64Request { remote_fd, buffer_size, @@ -258,21 +259,35 @@ impl FileManager { pathname, mode, }) => Some(FileResponse::MakeDir(self.mkdirat(dirfd, &pathname, mode))), + FileRequest::RemoveDir(RemoveDirRequest { pathname }) => { + Some(FileResponse::RemoveDir(self.rmdir(&pathname))) + } + FileRequest::Unlink(UnlinkRequest { pathname }) => { + Some(FileResponse::Unlink(self.unlink(&pathname))) + } + FileRequest::UnlinkAt(UnlinkAtRequest { + dirfd, + pathname, + flags, + }) => Some(FileResponse::Unlink(self.unlinkat(dirfd, &pathname, flags))), }) } - #[tracing::instrument(level = "trace")] + #[tracing::instrument(level = Level::TRACE, ret)] pub fn new(pid: Option) -> Self { let root_path = get_root_path_from_optional_pid(pid); trace!("Agent root 
path >> {root_path:?}"); + Self { - open_files: HashMap::new(), root_path, - ..Default::default() + open_files: Default::default(), + dir_streams: Default::default(), + getdents_streams: Default::default(), + fds_iter: (0..=u64::MAX), } } - #[tracing::instrument(level = "trace", skip(self))] + #[tracing::instrument(level = Level::TRACE, skip(self), ret, err(level = Level::DEBUG))] fn open( &mut self, path: PathBuf, @@ -294,12 +309,14 @@ impl FileManager { RemoteFile::File(file) }; - self.open_files.insert(fd, remote_file); + if self.open_files.insert(fd, remote_file).is_none() { + OPEN_FD_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } Ok(OpenFileResponse { fd }) } - #[tracing::instrument(level = "trace", skip(self))] + #[tracing::instrument(level = Level::TRACE, skip(self), ret, err(level = Level::DEBUG))] fn open_relative( &mut self, relative_fd: u64, @@ -328,7 +345,9 @@ impl FileManager { RemoteFile::File(file) }; - self.open_files.insert(fd, remote_file); + if self.open_files.insert(fd, remote_file).is_none() { + OPEN_FD_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } Ok(OpenFileResponse { fd }) } else { @@ -520,6 +539,55 @@ impl FileManager { } } + #[tracing::instrument(level = Level::TRACE, skip(self))] + pub(crate) fn rmdir(&mut self, path: &Path) -> RemoteResult<()> { + let path = resolve_path(path, &self.root_path)?; + + std::fs::remove_dir(path.as_path()).map_err(ResponseError::from) + } + + #[tracing::instrument(level = Level::TRACE, skip(self))] + pub(crate) fn unlink(&mut self, path: &Path) -> RemoteResult<()> { + let path = resolve_path(path, &self.root_path)?; + + nix::unistd::unlink(path.as_path()) + .map_err(|error| ResponseError::from(std::io::Error::from_raw_os_error(error as i32))) + } + + #[tracing::instrument(level = Level::TRACE, skip(self))] + pub(crate) fn unlinkat( + &mut self, + dirfd: Option, + path: &Path, + flags: u32, + ) -> RemoteResult<()> { + let path = match dirfd { + Some(dirfd) => { + let relative_dir 
= self
+                    .open_files
+                    .get(&dirfd)
+                    .ok_or(ResponseError::NotFound(dirfd))?;
+
+                if let RemoteFile::Directory(relative_dir) = relative_dir {
+                    relative_dir.join(path)
+                } else {
+                    return Err(ResponseError::NotDirectory(dirfd));
+                }
+            }
+            None => resolve_path(path, &self.root_path)?,
+        };
+
+        let flags = match flags {
+            0 => UnlinkatFlags::RemoveDir,
+            _ => UnlinkatFlags::NoRemoveDir,
+        };
+
+        let fd: Option<RawFd> = dirfd.map(|fd| fd as RawFd);
+
+        nix::unistd::unlinkat(fd, path.as_path(), flags)
+            .map_err(|error| ResponseError::from(std::io::Error::from_raw_os_error(error as i32)))
+    }
+
     pub(crate) fn seek(&mut self, fd: u64, seek_from: SeekFrom) -> RemoteResult<SeekFileResponse> {
         trace!(
             "FileManager::seek -> fd {:#?} | seek_from {:#?}",
@@ -572,20 +640,36 @@ impl FileManager {
         })
     }
 
-    pub(crate) fn close(&mut self, fd: u64) {
-        trace!("FileManager::close -> fd {:#?}", fd,);
-
+    /// Always returns `None`, since we don't return any [`FileResponse`] back to mirrord
+    /// on `close` of an fd.
+    #[tracing::instrument(level = Level::TRACE, skip(self))]
+    pub(crate) fn close(&mut self, fd: u64) -> Option<FileResponse> {
         if self.open_files.remove(&fd).is_none() {
-            error!("FileManager::close -> fd {:#?} not found", fd);
+            error!(fd, "fd not found!");
+        } else {
+            OPEN_FD_COUNT.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
         }
-    }
 
-    pub(crate) fn close_dir(&mut self, fd: u64) {
-        trace!("FileManager::close_dir -> fd {:#?}", fd,);
+        None
+    }
 
-        if self.dir_streams.remove(&fd).is_none() && self.getdents_streams.remove(&fd).is_none() {
+    /// Always returns `None`, since we don't return any [`FileResponse`] back to mirrord
+    /// on `close_dir` of an fd.
+    #[tracing::instrument(level = Level::TRACE, skip(self))]
+    pub(crate) fn close_dir(&mut self, fd: u64) -> Option<FileResponse> {
+        let closed_dir_stream = self.dir_streams.remove(&fd);
+        let closed_getdents_stream = self.getdents_streams.remove(&fd);
+
+        if closed_dir_stream.is_some() && closed_getdents_stream.is_some() {
+            // Closed `dirstream` and `dentsstream`
+            OPEN_FD_COUNT.fetch_sub(2, std::sync::atomic::Ordering::Relaxed);
+        } else if closed_dir_stream.is_some() || closed_getdents_stream.is_some() {
+            OPEN_FD_COUNT.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
+        } else {
             error!("FileManager::close_dir -> fd {:#?} not found", fd);
         }
+
+        None
     }
 
     pub(crate) fn access(
@@ -690,6 +774,18 @@ impl FileManager {
     }
 
     #[tracing::instrument(level = "trace", skip(self))]
+    pub(crate) fn statfs(&mut self, path: PathBuf) -> RemoteResult<XstatFsResponse> {
+        let path = resolve_path(path, &self.root_path)?;
+
+        let statfs = nix::sys::statfs::statfs(&path)
+            .map_err(|err| std::io::Error::from_raw_os_error(err as i32))?;
+
+        Ok(XstatFsResponse {
+            metadata: statfs.into(),
+        })
+    }
+
+    #[tracing::instrument(level = Level::TRACE, skip(self), err(level = Level::DEBUG))]
     pub(crate) fn fdopen_dir(&mut self, fd: u64) -> RemoteResult<OpenDirResponse> {
         let path = match self
             .open_files
@@ -706,7 +802,10 @@ impl FileManager {
             .ok_or_else(|| ResponseError::IdsExhausted("fdopen_dir".to_string()))?;
 
         let dir_stream = path.read_dir()?.enumerate();
-        self.dir_streams.insert(fd, dir_stream);
+
+        if self.dir_streams.insert(fd, dir_stream).is_none() {
+            OPEN_FD_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+        }
 
         Ok(OpenDirResponse { fd })
     }
@@ -755,7 +854,7 @@ impl FileManager {
     /// The possible remote errors are:
     /// [`ResponseError::NotFound`] if there is not such fd here.
     /// [`ResponseError::NotDirectory`] if the fd points to a file with a non-directory file type.
- #[tracing::instrument(level = "trace", skip(self))] + #[tracing::instrument(level = Level::TRACE, skip(self))] pub(crate) fn get_or_create_getdents64_stream( &mut self, fd: u64, @@ -768,6 +867,7 @@ impl FileManager { let current_and_parent = Self::get_current_and_parent_entries(dir); let stream = GetDEnts64Stream::new(dir.read_dir()?, current_and_parent).peekable(); + OPEN_FD_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed); Ok(e.insert(stream)) } }, diff --git a/mirrord/agent/src/main.rs b/mirrord/agent/src/main.rs index 305ec50e0ed..e9f3e107907 100644 --- a/mirrord/agent/src/main.rs +++ b/mirrord/agent/src/main.rs @@ -22,6 +22,7 @@ mod env; mod error; mod file; mod http; +mod metrics; mod namespace; mod outgoing; mod runtime; @@ -31,7 +32,8 @@ mod util; mod vpn; mod watched_task; +#[cfg(target_os = "linux")] #[tokio::main(flavor = "current_thread")] -async fn main() -> crate::error::Result<()> { +async fn main() -> crate::error::AgentResult<()> { crate::entrypoint::main().await } diff --git a/mirrord/agent/src/metrics.rs b/mirrord/agent/src/metrics.rs new file mode 100644 index 00000000000..59fefbc2305 --- /dev/null +++ b/mirrord/agent/src/metrics.rs @@ -0,0 +1,366 @@ +use std::{ + net::SocketAddr, + sync::{atomic::AtomicI64, Arc}, +}; + +use axum::{extract::State, routing::get, Router}; +use http::StatusCode; +use prometheus::{proto::MetricFamily, IntGauge, Registry}; +use tokio::net::TcpListener; +use tokio_util::sync::CancellationToken; +use tracing::Level; + +use crate::error::AgentError; + +/// Incremented whenever we get a new client in `ClientConnectionHandler`, and decremented +/// when this client is dropped. +pub(crate) static CLIENT_COUNT: AtomicI64 = AtomicI64::new(0); + +/// Incremented whenever we handle a new `DnsCommand`, and decremented after the result of +/// `do_lookup` has been sent back through the response channel. 
+pub(crate) static DNS_REQUEST_COUNT: AtomicI64 = AtomicI64::new(0);
+
+/// Incremented and decremented in _open-ish_/_close-ish_ file operations in `FileManager`.
+/// Also gets decremented when `FileManager` is dropped.
+pub(crate) static OPEN_FD_COUNT: AtomicI64 = AtomicI64::new(0);
+
+/// Follows the amount of subscribed ports in `update_packet_filter`. We don't really
+/// increment/decrement this one, and mostly `set` it to the latest amount of ports, zeroing it when
+/// the `TcpConnectionSniffer` gets dropped.
+pub(crate) static MIRROR_PORT_SUBSCRIPTION: AtomicI64 = AtomicI64::new(0);
+
+pub(crate) static MIRROR_CONNECTION_SUBSCRIPTION: AtomicI64 = AtomicI64::new(0);
+
+pub(crate) static STEAL_FILTERED_PORT_SUBSCRIPTION: AtomicI64 = AtomicI64::new(0);
+
+pub(crate) static STEAL_UNFILTERED_PORT_SUBSCRIPTION: AtomicI64 = AtomicI64::new(0);
+
+pub(crate) static STEAL_FILTERED_CONNECTION_SUBSCRIPTION: AtomicI64 = AtomicI64::new(0);
+
+pub(crate) static STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION: AtomicI64 = AtomicI64::new(0);
+
+pub(crate) static HTTP_REQUEST_IN_PROGRESS_COUNT: AtomicI64 = AtomicI64::new(0);
+
+pub(crate) static TCP_OUTGOING_CONNECTION: AtomicI64 = AtomicI64::new(0);
+
+pub(crate) static UDP_OUTGOING_CONNECTION: AtomicI64 = AtomicI64::new(0);
+
+/// The state with all the metrics [`IntGauge`]s and the prometheus [`Registry`] where we keep them.
+///
+/// **Do not** modify the gauges directly!
+///
+/// Instead rely on [`Metrics::gather_metrics`], as we actually use a bunch of [`AtomicI64`]s to
+/// keep track of the values; they are the ones being (de|in)cremented. These gauges are just set
+/// when it's time to send them via [`get_metrics`].
+#[derive(Debug)] +struct Metrics { + registry: Registry, + client_count: IntGauge, + dns_request_count: IntGauge, + open_fd_count: IntGauge, + mirror_port_subscription: IntGauge, + mirror_connection_subscription: IntGauge, + steal_filtered_port_subscription: IntGauge, + steal_unfiltered_port_subscription: IntGauge, + steal_filtered_connection_subscription: IntGauge, + steal_unfiltered_connection_subscription: IntGauge, + http_request_in_progress_count: IntGauge, + tcp_outgoing_connection: IntGauge, + udp_outgoing_connection: IntGauge, +} + +impl Metrics { + /// Creates a [`Registry`] to ... register our [`IntGauge`]s. + fn new() -> Self { + use prometheus::Opts; + + let registry = Registry::new(); + + let client_count = { + let opts = Opts::new( + "mirrord_agent_client_count", + "amount of connected clients to this mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let dns_request_count = { + let opts = Opts::new( + "mirrord_agent_dns_request_count", + "amount of in-progress dns requests in the mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let open_fd_count = { + let opts = Opts::new( + "mirrord_agent_open_fd_count", + "amount of open file descriptors in mirrord-agent file manager", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let mirror_port_subscription = { + let opts = Opts::new( + "mirrord_agent_mirror_port_subscription_count", + "amount of mirror port subscriptions in mirror-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let mirror_connection_subscription = { + let opts = Opts::new( + "mirrord_agent_mirror_connection_subscription_count", + "amount of connections in mirror mode in mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let steal_filtered_port_subscription = { + let opts = Opts::new( + "mirrord_agent_steal_filtered_port_subscription_count", + "amount 
of filtered steal port subscriptions in mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let steal_unfiltered_port_subscription = { + let opts = Opts::new( + "mirrord_agent_steal_unfiltered_port_subscription_count", + "amount of unfiltered steal port subscriptions in mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let steal_filtered_connection_subscription = { + let opts = Opts::new( + "mirrord_agent_steal_connection_subscription_count", + "amount of filtered connections in steal mode in mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let steal_unfiltered_connection_subscription = { + let opts = Opts::new( + "mirrord_agent_steal_unfiltered_connection_subscription_count", + "amount of unfiltered connections in steal mode in mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let http_request_in_progress_count = { + let opts = Opts::new( + "mirrord_agent_http_request_in_progress_count", + "amount of in-progress http requests in the mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let tcp_outgoing_connection = { + let opts = Opts::new( + "mirrord_agent_tcp_outgoing_connection_count", + "amount of tcp outgoing connections in mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + let udp_outgoing_connection = { + let opts = Opts::new( + "mirrord_agent_udp_outgoing_connection_count", + "amount of udp outgoing connections in mirrord-agent", + ); + IntGauge::with_opts(opts).expect("Valid at initialization!") + }; + + registry + .register(Box::new(client_count.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(dns_request_count.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(open_fd_count.clone())) + .expect("Register must be 
valid at initialization!"); + registry + .register(Box::new(mirror_port_subscription.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(mirror_connection_subscription.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(steal_filtered_port_subscription.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(steal_unfiltered_port_subscription.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(steal_filtered_connection_subscription.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(steal_unfiltered_connection_subscription.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(http_request_in_progress_count.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(tcp_outgoing_connection.clone())) + .expect("Register must be valid at initialization!"); + registry + .register(Box::new(udp_outgoing_connection.clone())) + .expect("Register must be valid at initialization!"); + + Self { + registry, + client_count, + dns_request_count, + open_fd_count, + mirror_port_subscription, + mirror_connection_subscription, + steal_filtered_port_subscription, + steal_unfiltered_port_subscription, + steal_filtered_connection_subscription, + steal_unfiltered_connection_subscription, + http_request_in_progress_count, + tcp_outgoing_connection, + udp_outgoing_connection, + } + } + + /// Calls [`IntGauge::set`] on every [`IntGauge`] of `Self`, setting it to the value of + /// the corresponding [`AtomicI64`] global (the uppercase named version of the gauge). + /// + /// Returns the list of [`MetricFamily`] registered in our [`Metrics::registry`], ready to be + /// encoded and sent to prometheus. 
+    fn gather_metrics(&self) -> Vec<MetricFamily> {
+        use std::sync::atomic::Ordering;
+
+        let Self {
+            registry,
+            client_count,
+            dns_request_count,
+            open_fd_count,
+            mirror_port_subscription,
+            mirror_connection_subscription,
+            steal_filtered_port_subscription,
+            steal_unfiltered_port_subscription,
+            steal_filtered_connection_subscription,
+            steal_unfiltered_connection_subscription,
+            http_request_in_progress_count,
+            tcp_outgoing_connection,
+            udp_outgoing_connection,
+        } = self;
+
+        client_count.set(CLIENT_COUNT.load(Ordering::Relaxed));
+        dns_request_count.set(DNS_REQUEST_COUNT.load(Ordering::Relaxed));
+        open_fd_count.set(OPEN_FD_COUNT.load(Ordering::Relaxed));
+        mirror_port_subscription.set(MIRROR_PORT_SUBSCRIPTION.load(Ordering::Relaxed));
+        mirror_connection_subscription.set(MIRROR_CONNECTION_SUBSCRIPTION.load(Ordering::Relaxed));
+        steal_filtered_port_subscription
+            .set(STEAL_FILTERED_PORT_SUBSCRIPTION.load(Ordering::Relaxed));
+        steal_unfiltered_port_subscription
+            .set(STEAL_UNFILTERED_PORT_SUBSCRIPTION.load(Ordering::Relaxed));
+        steal_filtered_connection_subscription
+            .set(STEAL_FILTERED_CONNECTION_SUBSCRIPTION.load(Ordering::Relaxed));
+        steal_unfiltered_connection_subscription
+            .set(STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.load(Ordering::Relaxed));
+        http_request_in_progress_count.set(HTTP_REQUEST_IN_PROGRESS_COUNT.load(Ordering::Relaxed));
+        tcp_outgoing_connection.set(TCP_OUTGOING_CONNECTION.load(Ordering::Relaxed));
+        udp_outgoing_connection.set(UDP_OUTGOING_CONNECTION.load(Ordering::Relaxed));
+
+        registry.gather()
+    }
+}
+
+/// `GET /metrics`
+///
+/// Prepares all the metrics with [`Metrics::gather_metrics`], and responds to the prometheus
+/// request.
+#[tracing::instrument(level = Level::TRACE, ret)] +async fn get_metrics(State(state): State>) -> (StatusCode, String) { + use prometheus::TextEncoder; + + let metric_families = state.gather_metrics(); + match TextEncoder.encode_to_string(&metric_families) { + Ok(response) => (StatusCode::OK, response), + Err(fail) => { + tracing::error!(?fail, "Failed GET /metrics"); + (StatusCode::INTERNAL_SERVER_ERROR, fail.to_string()) + } + } +} + +/// Starts the mirrord-agent prometheus metrics service. +/// +/// You can get the metrics from `GET address/metrics`. +/// +/// - `address`: comes from a mirrord-agent config. +#[tracing::instrument(level = Level::TRACE, skip_all, ret ,err)] +pub(crate) async fn start_metrics( + address: SocketAddr, + cancellation_token: CancellationToken, +) -> Result<(), axum::BoxError> { + let metrics_state = Arc::new(Metrics::new()); + + let app = Router::new() + .route("/metrics", get(get_metrics)) + .with_state(metrics_state); + + let listener = TcpListener::bind(address) + .await + .map_err(AgentError::from) + .inspect_err(|fail| { + tracing::error!(?fail, "Failed to bind TCP socket for metrics server") + })?; + + let cancel_on_error = cancellation_token.clone(); + axum::serve(listener, app) + .with_graceful_shutdown(async move { cancellation_token.cancelled().await }) + .await + .inspect_err(|fail| { + tracing::error!(%fail, "Could not start agent metrics server!"); + cancel_on_error.cancel(); + })?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use std::{sync::atomic::Ordering, time::Duration}; + + use tokio_util::sync::CancellationToken; + + use super::OPEN_FD_COUNT; + use crate::metrics::start_metrics; + + #[tokio::test] + async fn test_metrics() { + let metrics_address = "127.0.0.1:9000".parse().unwrap(); + let cancellation_token = CancellationToken::new(); + + let metrics_cancellation = cancellation_token.child_token(); + tokio::spawn(async move { + start_metrics(metrics_address, metrics_cancellation) + .await + .unwrap() + }); + + 
OPEN_FD_COUNT.fetch_add(1, Ordering::Relaxed); + + // Give the server some time to start. + tokio::time::sleep(Duration::from_secs(1)).await; + + let get_all_metrics = reqwest::get("http://127.0.0.1:9000/metrics") + .await + .unwrap() + .error_for_status() + .unwrap() + .text() + .await + .unwrap(); + + assert!(get_all_metrics.contains("mirrord_agent_open_fd_count 1")); + + cancellation_token.drop_guard(); + } +} diff --git a/mirrord/agent/src/outgoing.rs b/mirrord/agent/src/outgoing.rs index 13e3a9e1e06..96a063d7a05 100644 --- a/mirrord/agent/src/outgoing.rs +++ b/mirrord/agent/src/outgoing.rs @@ -18,7 +18,8 @@ use tokio_util::io::ReaderStream; use tracing::Level; use crate::{ - error::Result, + error::AgentResult, + metrics::TCP_OUTGOING_CONNECTION, util::run_thread_in_namespace, watched_task::{TaskStatus, WatchedTask}, }; @@ -81,7 +82,7 @@ impl TcpOutgoingApi { /// Sends the [`LayerTcpOutgoing`] message to the background task. #[tracing::instrument(level = Level::TRACE, skip(self), err)] - pub(crate) async fn send_to_task(&mut self, message: LayerTcpOutgoing) -> Result<()> { + pub(crate) async fn send_to_task(&mut self, message: LayerTcpOutgoing) -> AgentResult<()> { if self.layer_tx.send(message).await.is_ok() { Ok(()) } else { @@ -91,7 +92,7 @@ impl TcpOutgoingApi { /// Receives a [`DaemonTcpOutgoing`] message from the background task. 
     #[tracing::instrument(level = Level::TRACE, skip(self), err)]
-    pub(crate) async fn recv_from_task(&mut self) -> Result<DaemonTcpOutgoing> {
+    pub(crate) async fn recv_from_task(&mut self) -> AgentResult<DaemonTcpOutgoing> {
         match self.daemon_rx.recv().await {
             Some(msg) => Ok(msg),
             None => Err(self.task_status.unwrap_err().await),
@@ -112,6 +113,13 @@ struct TcpOutgoingTask {
     daemon_tx: Sender<DaemonTcpOutgoing>,
 }
 
+impl Drop for TcpOutgoingTask {
+    fn drop(&mut self) {
+        let connections = self.readers.keys().chain(self.writers.keys()).count();
+        TCP_OUTGOING_CONNECTION.fetch_sub(connections as i64, std::sync::atomic::Ordering::Relaxed);
+    }
+}
+
 impl fmt::Debug for TcpOutgoingTask {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         f.debug_struct("TcpOutgoingTask")
@@ -152,7 +160,7 @@ impl TcpOutgoingTask {
     /// Runs this task as long as the channels connecting it with [`TcpOutgoingApi`] are open.
     /// This routine never fails and returns [`Result`] only due to [`WatchedTask`] constraints.
     #[tracing::instrument(level = Level::TRACE, skip(self))]
-    async fn run(mut self) -> Result<()> {
+    async fn run(mut self) -> AgentResult<()> {
         loop {
             let channel_closed = select!
{ biased; @@ -216,6 +224,7 @@ impl TcpOutgoingTask { self.readers.remove(&connection_id); self.writers.remove(&connection_id); + TCP_OUTGOING_CONNECTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); let daemon_message = DaemonTcpOutgoing::Close(connection_id); self.daemon_tx.send(daemon_message).await?; @@ -246,6 +255,8 @@ impl TcpOutgoingTask { "Layer connection is shut down as well, sending close message.", ); + TCP_OUTGOING_CONNECTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + self.daemon_tx .send(DaemonTcpOutgoing::Close(connection_id)) .await?; @@ -287,6 +298,7 @@ impl TcpOutgoingTask { connection_id, ReaderStream::with_capacity(read_half, Self::READ_BUFFER_SIZE), ); + TCP_OUTGOING_CONNECTION.fetch_add(1, std::sync::atomic::Ordering::Relaxed); Ok(DaemonConnect { connection_id, @@ -299,9 +311,12 @@ impl TcpOutgoingTask { result = ?daemon_connect, "Connection attempt finished.", ); + self.daemon_tx .send(DaemonTcpOutgoing::Connect(daemon_connect)) - .await + .await?; + + Ok(()) } // This message handles two cases: @@ -341,9 +356,14 @@ impl TcpOutgoingTask { connection_id, "Peer connection is shut down as well, sending close message to the client.", ); + TCP_OUTGOING_CONNECTION + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + self.daemon_tx .send(DaemonTcpOutgoing::Close(connection_id)) - .await + .await?; + + Ok(()) } } @@ -352,6 +372,7 @@ impl TcpOutgoingTask { Err(error) => { self.writers.remove(&connection_id); self.readers.remove(&connection_id); + TCP_OUTGOING_CONNECTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); tracing::trace!( connection_id, @@ -360,7 +381,9 @@ impl TcpOutgoingTask { ); self.daemon_tx .send(DaemonTcpOutgoing::Close(connection_id)) - .await + .await?; + + Ok(()) } } } @@ -370,6 +393,7 @@ impl TcpOutgoingTask { LayerTcpOutgoing::Close(LayerClose { connection_id }) => { self.writers.remove(&connection_id); self.readers.remove(&connection_id); + TCP_OUTGOING_CONNECTION.fetch_sub(1, 
std::sync::atomic::Ordering::Relaxed);
 
                 Ok(())
             }
diff --git a/mirrord/agent/src/outgoing/udp.rs b/mirrord/agent/src/outgoing/udp.rs
index b6baa5e537e..0dab137a92b 100644
--- a/mirrord/agent/src/outgoing/udp.rs
+++ b/mirrord/agent/src/outgoing/udp.rs
@@ -1,10 +1,11 @@
+use core::fmt;
 use std::{
     collections::HashMap,
     net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
     thread,
 };
 
-use bytes::BytesMut;
+use bytes::{Bytes, BytesMut};
 use futures::{
     prelude::*,
     stream::{SplitSink, SplitStream},
@@ -15,21 +16,262 @@ use mirrord_protocol::{
 };
 use streammap_ext::StreamMap;
 use tokio::{
+    io,
     net::UdpSocket,
     select,
-    sync::mpsc::{self, Receiver, Sender},
+    sync::mpsc::{self, error::SendError, Receiver, Sender},
 };
 use tokio_util::{codec::BytesCodec, udp::UdpFramed};
-use tracing::{debug, trace, warn};
+use tracing::Level;
 
 use crate::{
-    error::Result,
+    error::AgentResult,
+    metrics::UDP_OUTGOING_CONNECTION,
     util::run_thread_in_namespace,
     watched_task::{TaskStatus, WatchedTask},
 };
 
-type Layer = LayerUdpOutgoing;
-type Daemon = DaemonUdpOutgoing;
+/// Task that handles [`LayerUdpOutgoing`] and [`DaemonUdpOutgoing`] messages.
+///
+/// We start these tasks from the [`UdpOutgoingApi`] as a [`WatchedTask`].
+struct UdpOutgoingTask {
+    next_connection_id: ConnectionId,
+    /// Writing halves of peer connections made on layer's requests.
+    #[allow(clippy::type_complexity)]
+    writers: HashMap<
+        ConnectionId,
+        (
+            SplitSink<UdpFramed<BytesCodec>, (BytesMut, SocketAddr)>,
+            SocketAddr,
+        ),
+    >,
+    /// Reading halves of peer connections made on layer's requests.
+    readers: StreamMap<ConnectionId, SplitStream<UdpFramed<BytesCodec>>>,
+    /// Optional pid of agent's target. Used in `SocketStream::connect`.
+    pid: Option<u64>,
+    layer_rx: Receiver<LayerUdpOutgoing>,
+    daemon_tx: Sender<DaemonUdpOutgoing>,
+}
+
+impl Drop for UdpOutgoingTask {
+    fn drop(&mut self) {
+        let connections = self.readers.keys().chain(self.writers.keys()).count();
+        UDP_OUTGOING_CONNECTION.fetch_sub(connections as i64, std::sync::atomic::Ordering::Relaxed);
+    }
+}
+
+impl fmt::Debug for UdpOutgoingTask {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("UdpOutgoingTask")
+            .field("next_connection_id", &self.next_connection_id)
+            .field("writers", &self.writers.len())
+            .field("readers", &self.readers.len())
+            .field("pid", &self.pid)
+            .finish()
+    }
+}
+
+impl UdpOutgoingTask {
+    fn new(
+        pid: Option<u64>,
+        layer_rx: Receiver<LayerUdpOutgoing>,
+        daemon_tx: Sender<DaemonUdpOutgoing>,
+    ) -> Self {
+        Self {
+            next_connection_id: 0,
+            writers: Default::default(),
+            readers: Default::default(),
+            pid,
+            layer_rx,
+            daemon_tx,
+        }
+    }
+
+    /// Runs this task as long as the channels connecting it with [`UdpOutgoingApi`] are open.
+    /// This routine never fails and returns [`AgentResult`] only due to [`WatchedTask`]
+    /// constraints.
+    #[tracing::instrument(level = Level::TRACE, skip(self))]
+    pub(super) async fn run(mut self) -> AgentResult<()> {
+        loop {
+            let channel_closed = select! {
+                biased;
+
+                message = self.layer_rx.recv() => match message {
+                    // We have a message from the layer to be handled.
+                    Some(message) => {
+                        self.handle_layer_msg(message).await.is_err()
+                    },
+                    // Our channel with the layer is closed, this task is no longer needed.
+                    None => true,
+                },
+
+                // We have data coming from one of our peers.
+                Some((connection_id, remote_read)) = self.readers.next() => {
+                    self.handle_connection_read(connection_id, remote_read.transpose().map(|remote| remote.map(|(read, _)| read.into()))).await.is_err()
+                },
+            };
+
+            if channel_closed {
+                tracing::trace!("Client channel closed, exiting");
+                break Ok(());
+            }
+        }
+    }
+
+    /// Returns [`Err`] only when the client has disconnected.
+    #[tracing::instrument(
+        level = Level::TRACE,
+        skip(read),
+        fields(read = ?read.as_ref().map(|data| data.as_ref().map(Bytes::len).unwrap_or_default())),
+        err(level = Level::TRACE)
+    )]
+    async fn handle_connection_read(
+        &mut self,
+        connection_id: ConnectionId,
+        read: io::Result<Option<Bytes>>,
+    ) -> Result<(), SendError<DaemonUdpOutgoing>> {
+        match read {
+            Ok(Some(read)) => {
+                let message = DaemonUdpOutgoing::Read(Ok(DaemonRead {
+                    connection_id,
+                    bytes: read.to_vec(),
+                }));
+
+                self.daemon_tx.send(message).await?
+            }
+            // An error occurred when reading from a peer connection.
+            // We remove both io halves and inform the layer that the connection is closed.
+            // We remove the reader, because otherwise the `StreamMap` will produce an extra `None`
+            // item from the related stream.
+            Err(error) => {
+                tracing::trace!(
+                    ?error,
+                    connection_id,
+                    "Reading from peer connection failed, sending close message.",
+                );
+
+                self.readers.remove(&connection_id);
+                self.writers.remove(&connection_id);
+                UDP_OUTGOING_CONNECTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
+
+                let daemon_message = DaemonUdpOutgoing::Close(connection_id);
+                self.daemon_tx.send(daemon_message).await?;
+            }
+            Ok(None) => {
+                self.writers.remove(&connection_id);
+                self.readers.remove(&connection_id);
+                UDP_OUTGOING_CONNECTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
+
+                let daemon_message = DaemonUdpOutgoing::Close(connection_id);
+                self.daemon_tx.send(daemon_message).await?;
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Returns [`Err`] only when the client has disconnected.
+    #[allow(clippy::type_complexity)]
+    #[tracing::instrument(level = Level::TRACE, ret)]
+    async fn handle_layer_msg(
+        &mut self,
+        message: LayerUdpOutgoing,
+    ) -> Result<(), SendError<DaemonUdpOutgoing>> {
+        match message {
+            // [user] -> [layer] -> [agent] -> [layer]
+            // `user` is asking us to connect to some remote host.
+            LayerUdpOutgoing::Connect(LayerConnect { remote_address }) => {
+                let daemon_connect =
+                    connect(remote_address.clone())
+                        .await
+                        .and_then(|mirror_socket| {
+                            let connection_id = self.next_connection_id;
+                            self.next_connection_id += 1;
+
+                            let peer_address = mirror_socket.peer_addr()?;
+                            let local_address = mirror_socket.local_addr()?;
+                            let local_address = SocketAddress::Ip(local_address);
+
+                            let framed = UdpFramed::new(mirror_socket, BytesCodec::new());
+
+                            let (sink, stream): (
+                                SplitSink<UdpFramed<BytesCodec>, (BytesMut, SocketAddr)>,
+                                SplitStream<UdpFramed<BytesCodec>>,
+                            ) = framed.split();
+
+                            self.writers.insert(connection_id, (sink, peer_address));
+                            self.readers.insert(connection_id, stream);
+                            UDP_OUTGOING_CONNECTION
+                                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+
+                            Ok(DaemonConnect {
+                                connection_id,
+                                remote_address,
+                                local_address,
+                            })
+                        });
+
+                tracing::trace!(
+                    result = ?daemon_connect,
+                    "Connection attempt finished.",
+                );
+
+                self.daemon_tx
+                    .send(DaemonUdpOutgoing::Connect(daemon_connect))
+                    .await?;
+
+                Ok(())
+            }
+            // [user] -> [layer] -> [agent] -> [remote]
+            // `user` wrote some message to the remote host.
+            LayerUdpOutgoing::Write(LayerWrite {
+                connection_id,
+                bytes,
+            }) => {
+                let write_result = match self
+                    .writers
+                    .get_mut(&connection_id)
+                    .ok_or(ResponseError::NotFound(connection_id))
+                {
+                    Ok((mirror, remote_address)) => mirror
+                        .send((BytesMut::from(bytes.as_slice()), *remote_address))
+                        .await
+                        .map_err(ResponseError::from),
+                    Err(fail) => Err(fail),
+                };
+
+                match write_result {
+                    Ok(()) => Ok(()),
+                    Err(error) => {
+                        self.writers.remove(&connection_id);
+                        self.readers.remove(&connection_id);
+                        UDP_OUTGOING_CONNECTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
+
+                        tracing::trace!(
+                            connection_id,
+                            ?error,
+                            "Failed to handle layer write, sending close message to the client.",
+                        );
+
+                        let daemon_message = DaemonUdpOutgoing::Close(connection_id);
+                        self.daemon_tx.send(daemon_message).await?;
+
+                        Ok(())
+                    }
+                }
+            }
+            // [layer] -> [agent]
+            // `layer` closed their interceptor stream.
+            LayerUdpOutgoing::Close(LayerClose { ref connection_id }) => {
+                self.writers.remove(connection_id);
+                self.readers.remove(connection_id);
+                UDP_OUTGOING_CONNECTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
+
+                Ok(())
+            }
+        }
+    }
+}
 
 /// Handles (briefly) the `UdpOutgoingRequest` and `UdpOutgoingResponse` messages, mostly the
 /// passing of these messages to the `interceptor_task` thread.
@@ -41,10 +283,10 @@ pub(crate) struct UdpOutgoingApi {
     task_status: TaskStatus,
 
     /// Sends the `Layer` message to the `interceptor_task`.
-    layer_tx: Sender<Layer>,
+    layer_tx: Sender<LayerUdpOutgoing>,
 
     /// Reads the `Daemon` message from the `interceptor_task`.
-    daemon_rx: Receiver<Daemon>,
+    daemon_rx: Receiver<DaemonUdpOutgoing>,
 }
 
 /// Performs an [`UdpSocket::connect`] that handles 3 situations:
@@ -55,8 +297,9 @@ pub(crate) struct UdpOutgoingApi {
 ///    read access to `/etc/resolv.conf`, otherwise they'll be getting a mismatched connection;
 /// 3. User is trying to use `sendto` and `recvfrom`, we use the same hack as in DNS to fake a
 ///    connection.
-#[tracing::instrument(level = "trace", ret)]
-async fn connect(remote_address: SocketAddr) -> Result<UdpSocket, ResponseError> {
+#[tracing::instrument(level = Level::TRACE, ret, err(level = Level::DEBUG))]
+async fn connect(remote_address: SocketAddress) -> Result<UdpSocket, ResponseError> {
+    let remote_address = remote_address.try_into()?;
     let mirror_address = match remote_address {
         std::net::SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
         std::net::SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
@@ -75,8 +318,10 @@ impl UdpOutgoingApi {
         let (layer_tx, layer_rx) = mpsc::channel(1000);
         let (daemon_tx, daemon_rx) = mpsc::channel(1000);
 
-        let watched_task =
-            WatchedTask::new(Self::TASK_NAME, Self::interceptor_task(layer_rx, daemon_tx));
+        let watched_task = WatchedTask::new(
+            Self::TASK_NAME,
+            UdpOutgoingTask::new(pid, layer_rx, daemon_tx).run(),
+        );
         let task_status = watched_task.status();
         let task = run_thread_in_namespace(
@@ -94,150 +339,9 @@ impl UdpOutgoingApi {
         }
     }
 
-    /// The [`UdpOutgoingApi`] task.
-    ///
-    /// Receives [`LayerUdpOutgoing`] messages and replies with [`DaemonUdpOutgoing`].
-    #[allow(clippy::type_complexity)]
-    async fn interceptor_task(
-        mut layer_rx: Receiver<Layer>,
-        daemon_tx: Sender<Daemon>,
-    ) -> Result<()> {
-        let mut connection_ids = 0..=ConnectionId::MAX;
-
-        // TODO: Right now we're manually keeping these 2 maps in sync (aviram suggested using
-        // `Weak` for `writers`).
-        let mut writers: HashMap<
-            ConnectionId,
-            (
-                SplitSink<UdpFramed<BytesCodec>, (BytesMut, SocketAddr)>,
-                SocketAddr,
-            ),
-        > = HashMap::default();
-
-        let mut readers: StreamMap<ConnectionId, SplitStream<UdpFramed<BytesCodec>>> =
-            StreamMap::default();
-
-        loop {
-            select! {
-                biased;
-
-                // [layer] -> [agent]
-                Some(layer_message) = layer_rx.recv() => {
-                    trace!("udp: interceptor_task -> layer_message {:?}", layer_message);
-                    match layer_message {
-                        // [user] -> [layer] -> [agent] -> [layer]
-                        // `user` is asking us to connect to some remote host.
- LayerUdpOutgoing::Connect(LayerConnect { remote_address }) => { - let daemon_connect = connect(remote_address.clone().try_into()?) - .await - .and_then(|mirror_socket| { - let connection_id = connection_ids - .next() - .ok_or_else(|| ResponseError::IdsExhausted("connect".into()))?; - - debug!("interceptor_task -> mirror_socket {:#?}", mirror_socket); - let peer_address = mirror_socket.peer_addr()?; - let local_address = mirror_socket.local_addr()?; - let local_address = SocketAddress::Ip(local_address); - let framed = UdpFramed::new(mirror_socket, BytesCodec::new()); - debug!("interceptor_task -> framed {:#?}", framed); - let (sink, stream): ( - SplitSink, (BytesMut, SocketAddr)>, - SplitStream>, - ) = framed.split(); - - writers.insert(connection_id, (sink, peer_address)); - readers.insert(connection_id, stream); - - Ok(DaemonConnect { - connection_id, - remote_address, - local_address - }) - }); - - let daemon_message = DaemonUdpOutgoing::Connect(daemon_connect); - debug!("interceptor_task -> daemon_message {:#?}", daemon_message); - daemon_tx.send(daemon_message).await? - } - // [user] -> [layer] -> [agent] -> [remote] - // `user` wrote some message to the remote host. - LayerUdpOutgoing::Write(LayerWrite { - connection_id, - bytes, - }) => { - let daemon_write = match writers - .get_mut(&connection_id) - .ok_or(ResponseError::NotFound(connection_id)) - { - Ok((mirror, remote_address)) => mirror - .send((BytesMut::from(bytes.as_slice()), *remote_address)) - .await - .map_err(ResponseError::from), - Err(fail) => Err(fail), - }; - - if let Err(fail) = daemon_write { - warn!("LayerUdpOutgoing::Write -> Failed with {:#?}", fail); - writers.remove(&connection_id); - readers.remove(&connection_id); - - let daemon_message = DaemonUdpOutgoing::Close(connection_id); - daemon_tx.send(daemon_message).await? - } - } - // [layer] -> [agent] - // `layer` closed their interceptor stream. 
- LayerUdpOutgoing::Close(LayerClose { ref connection_id }) => { - writers.remove(connection_id); - readers.remove(connection_id); - } - } - } - - // [remote] -> [agent] -> [layer] -> [user] - // Read the data from one of the connected remote hosts, and forward the result back - // to the `user`. - Some((connection_id, remote_read)) = readers.next() => { - trace!("interceptor_task -> read connection_id {:#?}", connection_id); - - match remote_read { - Some(read) => { - let daemon_read = read - .map_err(ResponseError::from) - .map(|(bytes, _)| DaemonRead { connection_id, bytes: bytes.to_vec() }); - - let daemon_message = DaemonUdpOutgoing::Read(daemon_read); - daemon_tx.send(daemon_message).await? - } - None => { - trace!("interceptor_task -> close connection {:#?}", connection_id); - writers.remove(&connection_id); - readers.remove(&connection_id); - - let daemon_message = DaemonUdpOutgoing::Close(connection_id); - daemon_tx.send(daemon_message).await? - } - } - } - else => { - // We have no more data coming from any of the remote hosts. - warn!("interceptor_task -> no messages left"); - break; - } - } - } - - Ok(()) - } - /// Sends a `UdpOutgoingRequest` to the `interceptor_task`. - pub(crate) async fn layer_message(&mut self, message: LayerUdpOutgoing) -> Result<()> { - trace!( - "UdpOutgoingApi::layer_message -> layer_message {:#?}", - message - ); - + #[tracing::instrument(level = Level::TRACE, skip(self), err)] + pub(crate) async fn send_to_task(&mut self, message: LayerUdpOutgoing) -> AgentResult<()> { if self.layer_tx.send(message).await.is_ok() { Ok(()) } else { @@ -246,7 +350,7 @@ impl UdpOutgoingApi { } /// Receives a `UdpOutgoingResponse` from the `interceptor_task`. 
-    pub(crate) async fn daemon_message(&mut self) -> Result<DaemonUdpOutgoing> {
+    pub(crate) async fn recv_from_task(&mut self) -> AgentResult<DaemonUdpOutgoing> {
         match self.daemon_rx.recv().await {
             Some(msg) => Ok(msg),
             None => Err(self.task_status.unwrap_err().await),
diff --git a/mirrord/agent/src/sniffer.rs b/mirrord/agent/src/sniffer.rs
index b1c232eae6d..94c40e0ce67 100644
--- a/mirrord/agent/src/sniffer.rs
+++ b/mirrord/agent/src/sniffer.rs
@@ -24,8 +24,9 @@ use self::{
     tcp_capture::RawSocketTcpCapture,
 };
 use crate::{
-    error::AgentError,
+    error::AgentResult,
     http::HttpVersion,
+    metrics::{MIRROR_CONNECTION_SUBSCRIPTION, MIRROR_PORT_SUBSCRIPTION},
     util::{ChannelClosedFuture, ClientId, Subscriptions},
 };
@@ -138,7 +139,14 @@ pub(crate) struct TcpConnectionSniffer {
     sessions: TCPSessionMap,
     client_txs: HashMap<ClientId, Sender<SniffedConnection>>,
-    clients_closed: FuturesUnordered<ChannelClosedFuture<SniffedConnection>>,
+    clients_closed: FuturesUnordered<ChannelClosedFuture>,
+}
+
+impl Drop for TcpConnectionSniffer {
+    fn drop(&mut self) {
+        MIRROR_PORT_SUBSCRIPTION.store(0, std::sync::atomic::Ordering::Relaxed);
+        MIRROR_CONNECTION_SUBSCRIPTION.store(0, std::sync::atomic::Ordering::Relaxed);
+    }
 }
 
 impl fmt::Debug for TcpConnectionSniffer {
@@ -163,7 +171,7 @@ impl TcpConnectionSniffer {
         command_rx: Receiver<SnifferCommand>,
         network_interface: Option<String>,
         is_mesh: bool,
-    ) -> Result<Self> {
+    ) -> AgentResult<Self> {
         let tcp_capture = RawSocketTcpCapture::new(network_interface, is_mesh).await?;
 
         Ok(Self {
@@ -190,7 +198,7 @@ where
     /// Runs the sniffer loop, capturing packets.
     #[tracing::instrument(level = Level::DEBUG, skip(cancel_token), err)]
-    pub async fn start(mut self, cancel_token: CancellationToken) -> Result<(), AgentError> {
+    pub async fn start(mut self, cancel_token: CancellationToken) -> AgentResult<()> {
         loop {
             select! {
                 command = self.command_rx.recv() => {
@@ -232,7 +240,7 @@ where
     /// Removes the client with `client_id`, and also unsubscribes its port.
     /// Adjusts BPF filter if needed.
#[tracing::instrument(level = Level::TRACE, err)] - fn handle_client_closed(&mut self, client_id: ClientId) -> Result<(), AgentError> { + fn handle_client_closed(&mut self, client_id: ClientId) -> AgentResult<()> { self.client_txs.remove(&client_id); if self.port_subscriptions.remove_client(client_id) { @@ -245,8 +253,9 @@ where /// Updates BPF filter used by [`Self::tcp_capture`] to match state of /// [`Self::port_subscriptions`]. #[tracing::instrument(level = Level::TRACE, err)] - fn update_packet_filter(&mut self) -> Result<(), AgentError> { + fn update_packet_filter(&mut self) -> AgentResult<()> { let ports = self.port_subscriptions.get_subscribed_topics(); + MIRROR_PORT_SUBSCRIPTION.store(ports.len() as i64, std::sync::atomic::Ordering::Relaxed); let filter = if ports.is_empty() { tracing::trace!("No ports subscribed, setting dummy bpf"); @@ -261,7 +270,7 @@ where } #[tracing::instrument(level = Level::TRACE, err)] - fn handle_command(&mut self, command: SnifferCommand) -> Result<(), AgentError> { + fn handle_command(&mut self, command: SnifferCommand) -> AgentResult<()> { match command { SnifferCommand { client_id, @@ -325,7 +334,7 @@ where &mut self, identifier: TcpSessionIdentifier, tcp_packet: TcpPacketData, - ) -> Result<(), AgentError> { + ) -> AgentResult<()> { let data_tx = match self.sessions.entry(identifier) { Entry::Occupied(e) => e, Entry::Vacant(e) => { @@ -394,6 +403,7 @@ where } } + MIRROR_CONNECTION_SUBSCRIPTION.fetch_add(1, std::sync::atomic::Ordering::Relaxed); e.insert_entry(data_tx) } }; @@ -422,7 +432,7 @@ mod test { atomic::{AtomicUsize, Ordering}, Arc, }, - time::Duration, + time::{Duration, Instant}, }; use api::TcpSnifferApi; @@ -430,6 +440,7 @@ mod test { tcp::{DaemonTcp, LayerTcp, NewTcpConnection, TcpClose, TcpData}, ConnectionId, LogLevel, }; + use rstest::rstest; use tcp_capture::test::TcpPacketsChannel; use tokio::sync::mpsc; @@ -448,6 +459,7 @@ mod test { async fn get_api(&mut self) -> TcpSnifferApi { let client_id = 
self.next_client_id; self.next_client_id += 1; + TcpSnifferApi::new(client_id, self.command_tx.clone(), self.task_status.clone()) .await .unwrap() @@ -845,4 +857,45 @@ mod test { }), ); } + + /// Verifies that [`TcpConnectionSniffer`] reacts to [`TcpSnifferApi`] being dropped + /// and clears the packet filter. + #[rstest] + #[timeout(Duration::from_secs(5))] + #[tokio::test] + async fn cleanup_on_client_closed() { + let mut setup = TestSnifferSetup::new(); + + let mut api = setup.get_api().await; + + api.handle_client_message(LayerTcp::PortSubscribe(80)) + .await + .unwrap(); + assert_eq!( + api.recv().await.unwrap(), + (DaemonTcp::SubscribeResult(Ok(80)), None), + ); + assert_eq!(setup.times_filter_changed(), 1); + + std::mem::drop(api); + let dropped_at = Instant::now(); + + loop { + match setup.times_filter_changed() { + 1 => { + println!( + "filter still not changed {}ms after client closed", + dropped_at.elapsed().as_millis() + ); + tokio::time::sleep(Duration::from_millis(20)).await; + } + + 2 => { + break; + } + + other => panic!("unexpected times filter changed {other}"), + } + } + } } diff --git a/mirrord/agent/src/sniffer/api.rs b/mirrord/agent/src/sniffer/api.rs index 31ec4107f97..08874e93124 100644 --- a/mirrord/agent/src/sniffer/api.rs +++ b/mirrord/agent/src/sniffer/api.rs @@ -14,12 +14,17 @@ use tokio_stream::{ StreamMap, StreamNotifyClose, }; -use super::messages::{SniffedConnection, SnifferCommand, SnifferCommandInner}; +use super::{ + messages::{SniffedConnection, SnifferCommand, SnifferCommandInner}, + AgentResult, +}; use crate::{error::AgentError, util::ClientId, watched_task::TaskStatus}; /// Interface used by clients to interact with the /// [`TcpConnectionSniffer`](super::TcpConnectionSniffer). Multiple instances of this struct operate /// on a single sniffer instance. +/// +/// Enabled by the `mirror` feature for incoming traffic. pub(crate) struct TcpSnifferApi { /// Id of the client using this struct. 
     client_id: ClientId,
@@ -55,7 +60,7 @@ impl TcpSnifferApi {
         client_id: ClientId,
         sniffer_sender: Sender<SnifferCommand>,
         mut task_status: TaskStatus,
-    ) -> Result<Self, AgentError> {
+    ) -> AgentResult<Self> {
         let (sender, receiver) = mpsc::channel(Self::CONNECTION_CHANNEL_SIZE);
 
         let command = SnifferCommand {
@@ -79,7 +84,7 @@ impl TcpSnifferApi {
 
     /// Send the given command to the connected
     /// [`TcpConnectionSniffer`](super::TcpConnectionSniffer).
-    async fn send_command(&mut self, command: SnifferCommandInner) -> Result<(), AgentError> {
+    async fn send_command(&mut self, command: SnifferCommandInner) -> AgentResult<()> {
         let command = SnifferCommand {
             client_id: self.client_id,
             command,
@@ -94,7 +99,7 @@ impl TcpSnifferApi {
 
     /// Return the next message from the connected
     /// [`TcpConnectionSniffer`](super::TcpConnectionSniffer).
-    pub async fn recv(&mut self) -> Result<(DaemonTcp, Option<LogMessage>), AgentError> {
+    pub async fn recv(&mut self) -> AgentResult<(DaemonTcp, Option<LogMessage>)> {
         tokio::select! {
             conn = self.receiver.recv() => match conn {
                 Some(conn) => {
@@ -158,27 +163,26 @@ impl TcpSnifferApi {
         }
     }
 
-    /// Tansform the given message into a [`SnifferCommand`] and pass it to the connected
+    /// Transforms a [`LayerTcp`] message into a [`SnifferCommand`] and passes it to the connected
     /// [`TcpConnectionSniffer`](super::TcpConnectionSniffer).
-    pub async fn handle_client_message(&mut self, message: LayerTcp) -> Result<(), AgentError> {
+    pub async fn handle_client_message(&mut self, message: LayerTcp) -> AgentResult<()> {
         match message {
             LayerTcp::PortSubscribe(port) => {
                 let (tx, rx) = oneshot::channel();
                 self.send_command(SnifferCommandInner::Subscribe(port, tx))
                     .await?;
                 self.subscriptions_in_progress.push(rx);
-
                 Ok(())
             }
 
             LayerTcp::PortUnsubscribe(port) => {
                 self.send_command(SnifferCommandInner::UnsubscribePort(port))
-                    .await
+                    .await?;
+
                 Ok(())
             }
 
             LayerTcp::ConnectionUnsubscribe(connection_id) => {
                 self.connections.remove(&connection_id);
-
                 Ok(())
             }
         }
     }
diff --git a/mirrord/agent/src/sniffer/tcp_capture.rs b/mirrord/agent/src/sniffer/tcp_capture.rs
index 1d8031d08b3..dc8fb2bba04 100644
--- a/mirrord/agent/src/sniffer/tcp_capture.rs
+++ b/mirrord/agent/src/sniffer/tcp_capture.rs
@@ -12,7 +12,7 @@ use rawsocket::{filter::SocketFilterProgram, RawCapture};
 use tokio::net::UdpSocket;
 use tracing::Level;
 
-use super::{TcpPacketData, TcpSessionIdentifier};
+use super::{AgentResult, TcpPacketData, TcpSessionIdentifier};
 use crate::error::AgentError;
 
 /// Trait for structs that are able to sniff incoming Ethernet packets and filter TCP packets.
@@ -36,7 +36,7 @@ impl RawSocketTcpCapture {
     ///
     /// Returned instance initially uses a BPF filter that drops every packet.
     #[tracing::instrument(level = Level::DEBUG, err)]
-    pub async fn new(network_interface: Option<String>, is_mesh: bool) -> Result<Self> {
+    pub async fn new(network_interface: Option<String>, is_mesh: bool) -> AgentResult<Self> {
         // Priority is whatever the user set as an option to mirrord, then we check if we're in a
         // mesh to use `lo` interface, otherwise we try to get the appropriate interface.
         let interface = match network_interface.or_else(|| is_mesh.then(|| "lo".to_string())) {
diff --git a/mirrord/agent/src/steal.rs b/mirrord/agent/src/steal.rs
index 399c0597d4e..a425748a0d8 100644
--- a/mirrord/agent/src/steal.rs
+++ b/mirrord/agent/src/steal.rs
@@ -1,5 +1,5 @@
 use mirrord_protocol::{
-    tcp::{DaemonTcp, HttpResponseFallback, StealType, TcpData},
+    tcp::{DaemonTcp, StealType, TcpData},
     ConnectionId, Port,
 };
 use tokio::sync::mpsc::Sender;
@@ -17,6 +17,8 @@ mod subscriptions;
 pub(crate) use api::TcpStealerApi;
 pub(crate) use connection::TcpConnectionStealer;
 
+use self::http::HttpResponseFallback;
+
 /// Commands from the agent that are passed down to the stealer worker, through [`TcpStealerApi`].
 ///
 /// These are the operations that the agent receives from the layer to make the _steal_ feature
diff --git a/mirrord/agent/src/steal/api.rs b/mirrord/agent/src/steal/api.rs
index a6ec1d8d1f7..2fc5733f8fa 100644
--- a/mirrord/agent/src/steal/api.rs
+++ b/mirrord/agent/src/steal/api.rs
@@ -1,24 +1,23 @@
-use std::collections::HashMap;
+use std::{collections::HashMap, convert::Infallible};
 
 use bytes::Bytes;
 use hyper::body::Frame;
 use mirrord_protocol::{
-    tcp::{
-        ChunkedResponse, DaemonTcp, HttpResponse, HttpResponseFallback, InternalHttpResponse,
-        LayerTcpSteal, ReceiverStreamBody, TcpData,
-    },
+    tcp::{ChunkedResponse, DaemonTcp, HttpResponse, InternalHttpResponse, LayerTcpSteal, TcpData},
     RequestId,
 };
 use tokio::sync::mpsc::{self, Receiver, Sender};
 use tokio_stream::wrappers::ReceiverStream;
+use tracing::Level;
 
-use super::*;
+use super::{http::ReceiverStreamBody, *};
 use crate::{
-    error::{AgentError, Result},
-    util::ClientId,
+    error::AgentResult, metrics::HTTP_REQUEST_IN_PROGRESS_COUNT, util::ClientId,
     watched_task::TaskStatus,
 };
 
+type ResponseBodyTx = Sender<Result<Frame<Bytes>, Infallible>>;
+
 /// Bridges the communication between the agent and the [`TcpConnectionStealer`] task.
 /// There is an API instance for each connected layer ("client").
All API instances send commands
/// on the same stealer command channel, where the layer-independent stealer listens to them.
@@ -40,20 +39,34 @@ pub(crate) struct TcpStealerApi {
     /// View on the stealer task's status.
     task_status: TaskStatus,
 
-    response_body_txs: HashMap<(ConnectionId, RequestId), Sender<Result<Frame<Bytes>, Infallible>>>,
+    /// [`Sender`]s that allow us to provide body [`Frame`]s of responses to filtered HTTP
+    /// requests.
+    ///
+    /// With [`LayerTcpSteal::HttpResponseChunked`], response bodies come from the client
+    /// in a series of [`ChunkedResponse::Body`] messages.
+    ///
+    /// Thus, we use [`ReceiverStreamBody`] for [`Response`](hyper::Response)'s body type and
+    /// pipe the [`Frame`]s through an [`mpsc::channel`].
+    response_body_txs: HashMap<(ConnectionId, RequestId), ResponseBodyTx>,
+}
+
+impl Drop for TcpStealerApi {
+    fn drop(&mut self) {
+        HTTP_REQUEST_IN_PROGRESS_COUNT.store(0, std::sync::atomic::Ordering::Relaxed);
+    }
 }
 
 impl TcpStealerApi {
     /// Initializes a [`TcpStealerApi`] and sends a message to [`TcpConnectionStealer`] signaling
     /// that we have a new client.
-    #[tracing::instrument(level = "trace")]
+    #[tracing::instrument(level = Level::TRACE, err)]
     pub(crate) async fn new(
         client_id: ClientId,
        command_tx: Sender<StealerCommand>,
         task_status: TaskStatus,
         channel_size: usize,
         protocol_version: semver::Version,
-    ) -> Result<Self> {
+    ) -> AgentResult<Self> {
         let (daemon_tx, daemon_rx) = mpsc::channel(channel_size);
 
         command_tx
@@ -73,7 +86,7 @@ impl TcpStealerApi {
     }
 
     /// Send `command` to stealer, with the client id of the client that is using this API instance.
-    async fn send_command(&mut self, command: Command) -> Result<()> {
+    async fn send_command(&mut self, command: Command) -> AgentResult<()> {
         let command = StealerCommand {
             client_id: self.client_id,
             command,
@@ -91,12 +104,16 @@ impl TcpStealerApi {
     ///
     /// Called in the `ClientConnectionHandler`.
     #[tracing::instrument(level = "trace", skip(self))]
-    pub(crate) async fn recv(&mut self) -> Result<DaemonTcp> {
+    pub(crate) async fn recv(&mut self) -> AgentResult<DaemonTcp> {
         match self.daemon_rx.recv().await {
             Some(msg) => {
                 if let DaemonTcp::Close(close) = &msg {
                     self.response_body_txs
                         .retain(|(key_id, _), _| *key_id != close.connection_id);
+                    HTTP_REQUEST_IN_PROGRESS_COUNT.store(
+                        self.response_body_txs.len() as i64,
+                        std::sync::atomic::Ordering::Relaxed,
+                    );
                 }
                 Ok(msg)
             }
@@ -108,7 +125,7 @@ impl TcpStealerApi {
     /// agent, to an internal stealer command [`Command::PortSubscribe`].
     ///
     /// The actual handling of this message is done in [`TcpConnectionStealer`].
-    pub(crate) async fn port_subscribe(&mut self, port_steal: StealType) -> Result<(), AgentError> {
+    pub(crate) async fn port_subscribe(&mut self, port_steal: StealType) -> AgentResult<()> {
         self.send_command(Command::PortSubscribe(port_steal)).await
     }
 
@@ -116,7 +133,7 @@ impl TcpStealerApi {
     /// agent, to an internal stealer command [`Command::PortUnsubscribe`].
     ///
     /// The actual handling of this message is done in [`TcpConnectionStealer`].
-    pub(crate) async fn port_unsubscribe(&mut self, port: Port) -> Result<(), AgentError> {
+    pub(crate) async fn port_unsubscribe(&mut self, port: Port) -> AgentResult<()> {
         self.send_command(Command::PortUnsubscribe(port)).await
     }
 
@@ -127,7 +144,7 @@ impl TcpStealerApi {
     pub(crate) async fn connection_unsubscribe(
         &mut self,
         connection_id: ConnectionId,
-    ) -> Result<(), AgentError> {
+    ) -> AgentResult<()> {
         self.send_command(Command::ConnectionUnsubscribe(connection_id))
             .await
     }
 
@@ -136,7 +153,7 @@ impl TcpStealerApi {
     /// agent, to an internal stealer command [`Command::ResponseData`].
     ///
     /// The actual handling of this message is done in [`TcpConnectionStealer`].
- pub(crate) async fn client_data(&mut self, tcp_data: TcpData) -> Result<(), AgentError> { + pub(crate) async fn client_data(&mut self, tcp_data: TcpData) -> AgentResult<()> { self.send_command(Command::ResponseData(tcp_data)).await } @@ -147,24 +164,32 @@ impl TcpStealerApi { pub(crate) async fn http_response( &mut self, response: HttpResponseFallback, - ) -> Result<(), AgentError> { + ) -> AgentResult<()> { self.send_command(Command::HttpResponse(response)).await } pub(crate) async fn switch_protocol_version( &mut self, version: semver::Version, - ) -> Result<(), AgentError> { + ) -> AgentResult<()> { self.send_command(Command::SwitchProtocolVersion(version)) .await } - pub(crate) async fn handle_client_message(&mut self, message: LayerTcpSteal) -> Result<()> { + pub(crate) async fn handle_client_message( + &mut self, + message: LayerTcpSteal, + ) -> AgentResult<()> { match message { LayerTcpSteal::PortSubscribe(port_steal) => self.port_subscribe(port_steal).await, LayerTcpSteal::ConnectionUnsubscribe(connection_id) => { self.response_body_txs .retain(|(key_id, _), _| *key_id != connection_id); + HTTP_REQUEST_IN_PROGRESS_COUNT.store( + self.response_body_txs.len() as i64, + std::sync::atomic::Ordering::Relaxed, + ); + self.connection_unsubscribe(connection_id).await } LayerTcpSteal::PortUnsubscribe(port) => self.port_unsubscribe(port).await, @@ -195,13 +220,21 @@ impl TcpStealerApi { let key = (response.connection_id, response.request_id); self.response_body_txs.insert(key, tx.clone()); + HTTP_REQUEST_IN_PROGRESS_COUNT.store( + self.response_body_txs.len() as i64, + std::sync::atomic::Ordering::Relaxed, + ); - self.http_response(HttpResponseFallback::Streamed(http_response, None)) + self.http_response(HttpResponseFallback::Streamed(http_response)) .await?; for frame in response.internal_response.body { if let Err(err) = tx.send(Ok(frame.into())).await { self.response_body_txs.remove(&key); + HTTP_REQUEST_IN_PROGRESS_COUNT.store( + self.response_body_txs.len() as 
i64, + std::sync::atomic::Ordering::Relaxed, + ); tracing::trace!(?err, "error while sending streaming response frame"); } } @@ -224,12 +257,20 @@ impl TcpStealerApi { } if send_err || body.is_last { self.response_body_txs.remove(key); + HTTP_REQUEST_IN_PROGRESS_COUNT.store( + self.response_body_txs.len() as i64, + std::sync::atomic::Ordering::Relaxed, + ); }; Ok(()) } ChunkedResponse::Error(err) => { self.response_body_txs .remove(&(err.connection_id, err.request_id)); + HTTP_REQUEST_IN_PROGRESS_COUNT.store( + self.response_body_txs.len() as i64, + std::sync::atomic::Ordering::Relaxed, + ); tracing::trace!(?err, "ChunkedResponse error received"); Ok(()) } diff --git a/mirrord/agent/src/steal/connection.rs b/mirrord/agent/src/steal/connection.rs index 463c61f88d0..f6b4a9f2b7b 100644 --- a/mirrord/agent/src/steal/connection.rs +++ b/mirrord/agent/src/steal/connection.rs @@ -1,6 +1,6 @@ use std::{ collections::{HashMap, HashSet}, - net::{IpAddr, Ipv4Addr, SocketAddr}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, }; use fancy_regex::Regex; @@ -12,12 +12,11 @@ use hyper::{ http::{header::UPGRADE, request::Parts}, }; use mirrord_protocol::{ - body_chunks::{BodyExt as _, Frames}, + batched_body::{BatchedBody, Frames}, tcp::{ ChunkedHttpBody, ChunkedHttpError, ChunkedRequest, DaemonTcp, HttpRequest, - HttpResponseFallback, InternalHttpBody, InternalHttpBodyFrame, InternalHttpRequest, - StealType, TcpClose, TcpData, HTTP_CHUNKED_REQUEST_VERSION, HTTP_FILTERED_UPGRADE_VERSION, - HTTP_FRAMED_VERSION, + InternalHttpBody, InternalHttpBodyFrame, InternalHttpRequest, StealType, TcpClose, TcpData, + HTTP_CHUNKED_REQUEST_VERSION, HTTP_FILTERED_UPGRADE_VERSION, HTTP_FRAMED_VERSION, }, ConnectionId, Port, RemoteError::{BadHttpFilterExRegex, BadHttpFilterRegex}, @@ -29,10 +28,12 @@ use tokio::{ sync::mpsc::{Receiver, Sender}, }; use tokio_util::sync::CancellationToken; -use tracing::warn; +use tracing::{warn, Level}; +use super::{http::HttpResponseFallback, 
subscriptions::PortRedirector};
 use crate::{
-    error::{AgentError, Result},
+    error::{AgentError, AgentResult},
+    metrics::HTTP_REQUEST_IN_PROGRESS_COUNT,
     steal::{
         connections::{
             ConnectionMessageIn, ConnectionMessageOut, StolenConnection, StolenConnections,
@@ -55,6 +56,22 @@ struct MatchedHttpRequest {
 }
 
 impl MatchedHttpRequest {
+    fn new(
+        connection_id: ConnectionId,
+        port: Port,
+        request_id: RequestId,
+        request: Request<Incoming>,
+    ) -> Self {
+        HTTP_REQUEST_IN_PROGRESS_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+
+        Self {
+            connection_id,
+            port,
+            request_id,
+            request,
+        }
+    }
+
     async fn into_serializable(self) -> Result<HttpRequest<InternalHttpBody>, hyper::Error> {
         let (
             Parts {
@@ -173,7 +190,7 @@ impl Client {
             },
             mut body,
         ) = request.request.into_parts();
-        match body.next_frames(true).await {
+        match body.ready_frames() {
             Err(..) => return,
             // We don't check is_last here since loop will finish when body.next_frames()
             // returns None
@@ -181,7 +198,7 @@
                 let frames = frames
                     .into_iter()
                     .map(InternalHttpBodyFrame::try_from)
-                    .filter_map(Result::ok)
+                    .filter_map(AgentResult::ok)
                     .collect();
                 let message = DaemonTcp::HttpRequestChunked(ChunkedRequest::Start(HttpRequest {
@@ -205,12 +222,12 @@
         }
 
         loop {
-            match body.next_frames(false).await {
+            match body.next_frames().await {
                 Ok(Frames { frames, is_last }) => {
                     let frames = frames
                         .into_iter()
                         .map(InternalHttpBodyFrame::try_from)
-                        .filter_map(Result::ok)
+                        .filter_map(AgentResult::ok)
                         .collect();
                     let message = DaemonTcp::HttpRequestChunked(ChunkedRequest::Body(
                         ChunkedHttpBody {
@@ -273,9 +290,11 @@ struct TcpStealerConfig {
 /// Meant to be run (see [`TcpConnectionStealer::start`]) in a separate thread while the agent
 /// lives. When handling port subscription requests, this struct manipulates iptables, so it should
 /// run in the same network namespace as the agent's target.
-pub(crate) struct TcpConnectionStealer {
+///
+/// Enabled by the `steal` feature for incoming traffic.
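The gauge updates sprinkled through this diff all follow one pattern: `HTTP_REQUEST_IN_PROGRESS_COUNT` is a process-wide atomic that is incremented when a request is matched and overwritten with `response_body_txs.len()` after every map mutation. A stdlib-only sketch of that pattern (the static and function names here are illustrative, not the agent's actual metrics module):

```rust
use std::sync::atomic::{AtomicI64, Ordering};

// Illustrative stand-in for the agent's HTTP_REQUEST_IN_PROGRESS_COUNT gauge
// (the real static lives in the agent's metrics module).
static IN_PROGRESS: AtomicI64 = AtomicI64::new(0);

// Mirrors MatchedHttpRequest::new: bump the gauge once per matched request.
fn request_started() -> i64 {
    IN_PROGRESS.fetch_add(1, Ordering::Relaxed) + 1
}

// Mirrors the calls after every response_body_txs mutation: rather than pairing
// each insert with a decrement, the gauge is overwritten with the map's size.
fn sync_gauge(map_len: usize) -> i64 {
    IN_PROGRESS.store(map_len as i64, Ordering::Relaxed);
    IN_PROGRESS.load(Ordering::Relaxed)
}

fn main() {
    request_started();
    request_started();
    // One entry was removed from the map; the gauge tracks the new size.
    assert_eq!(sync_gauge(1), 1);
}
```

Resyncing from the map's length makes the gauge self-correcting: a missed decrement on one path is repaired by the next mutation.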
+pub(crate) struct TcpConnectionStealer<Redirector: PortRedirector> {
     /// For managing active subscriptions and port redirections.
-    port_subscriptions: PortSubscriptions<IpTablesRedirector>,
+    port_subscriptions: PortSubscriptions<Redirector>,
 
     /// For receiving commands.
     /// The other end of this channel belongs to [`TcpStealerApi`](super::api::TcpStealerApi).
@@ -285,37 +304,62 @@ pub(crate) struct TcpConnectionStealer {
     clients: HashMap<ClientId, Client>,
 
     /// [`Future`](std::future::Future)s that resolve when stealer clients close.
-    clients_closed: FuturesUnordered>,
+    clients_closed: FuturesUnordered<ChannelClosedFuture>,
 
     /// Set of active connections stolen by [`Self::port_subscriptions`].
     connections: StolenConnections,
+
+    /// When set, the stealer will use IPv6 if needed.
+    support_ipv6: bool,
 }
 
-impl TcpConnectionStealer {
+impl TcpConnectionStealer<IpTablesRedirector> {
     pub const TASK_NAME: &'static str = "Stealer";
 
     /// Initializes a new [`TcpConnectionStealer`], but doesn't start the actual work.
     /// You need to call [`TcpConnectionStealer::start`] to do so.
-    #[tracing::instrument(level = "trace")]
-    pub(crate) async fn new(command_rx: Receiver<StealerCommand>) -> Result<Self> {
+    #[tracing::instrument(level = Level::TRACE, err)]
+    pub(crate) async fn new(
+        command_rx: Receiver<StealerCommand>,
+        support_ipv6: bool,
+    ) -> AgentResult<Self> {
         let config = envy::prefixed("MIRRORD_AGENT_")
             .from_env::<TcpStealerConfig>()
             .unwrap_or_default();
 
-        let port_subscriptions = {
-            let redirector =
-                IpTablesRedirector::new(config.stealer_flush_connections, config.pod_ips).await?;
+        let redirector = IpTablesRedirector::new(
+            config.stealer_flush_connections,
+            config.pod_ips,
+            support_ipv6,
+        )
+        .await?;
 
-            PortSubscriptions::new(redirector, 4)
-        };
+        Ok(Self::with_redirector(command_rx, support_ipv6, redirector))
+    }
+}
 
-        Ok(Self {
-            port_subscriptions,
+impl<Redirector> TcpConnectionStealer<Redirector>
+where
+    Redirector: PortRedirector,
+    Redirector::Error: std::error::Error + Into<AgentError>,
+    AgentError: From<Redirector::Error>,
+{
+    /// Creates a new stealer.
+    ///
+    /// The given [`PortRedirector`] will be used to capture incoming connections.
+ pub(crate) fn with_redirector( + command_rx: Receiver, + support_ipv6: bool, + redirector: Redirector, + ) -> Self { + Self { + port_subscriptions: PortSubscriptions::new(redirector, 4), command_rx, clients: HashMap::with_capacity(8), clients_closed: Default::default(), connections: StolenConnections::with_capacity(8), - }) + support_ipv6, + } } /// Runs the tcp traffic stealer loop. @@ -330,10 +374,7 @@ impl TcpConnectionStealer { /// /// 4. Handling the cancellation of the whole stealer thread (given `cancellation_token`). #[tracing::instrument(level = "trace", skip(self))] - pub(crate) async fn start( - mut self, - cancellation_token: CancellationToken, - ) -> Result<(), AgentError> { + pub(crate) async fn start(mut self, cancellation_token: CancellationToken) -> AgentResult<()> { loop { tokio::select! { command = self.command_rx.recv() => { @@ -351,10 +392,12 @@ impl TcpConnectionStealer { }, accept = self.port_subscriptions.next_connection() => match accept { - Ok((stream, peer)) => self.incoming_connection(stream, peer).await?, + Ok((stream, peer)) => { + self.incoming_connection(stream, peer).await?; + } Err(error) => { tracing::error!(?error, "Failed to accept a stolen connection"); - break Err(error); + break Err(error.into()); } }, @@ -369,11 +412,20 @@ impl TcpConnectionStealer { /// Handles a new remote connection that was stolen by [`Self::port_subscriptions`]. #[tracing::instrument(level = "trace", skip(self))] - async fn incoming_connection(&mut self, stream: TcpStream, peer: SocketAddr) -> Result<()> { + async fn incoming_connection( + &mut self, + stream: TcpStream, + peer: SocketAddr, + ) -> AgentResult<()> { let mut real_address = orig_dst::orig_dst_addr(&stream)?; + let localhost = if self.support_ipv6 && real_address.is_ipv6() { + IpAddr::V6(Ipv6Addr::LOCALHOST) + } else { + IpAddr::V4(Ipv4Addr::LOCALHOST) + }; // If we use the original IP we would go through prerouting and hit a loop. // localhost should always work. 
- real_address.set_ip(IpAddr::V4(Ipv4Addr::LOCALHOST)); + real_address.set_ip(localhost); let Some(port_subscription) = self.port_subscriptions.get(real_address.port()).cloned() else { @@ -400,10 +452,7 @@ impl TcpConnectionStealer { /// Handles an update from one of the connections in [`Self::connections`]. #[tracing::instrument(level = "trace", skip(self))] - async fn handle_connection_update( - &mut self, - update: ConnectionMessageOut, - ) -> Result<(), AgentError> { + async fn handle_connection_update(&mut self, update: ConnectionMessageOut) -> AgentResult<()> { match update { ConnectionMessageOut::Closed { connection_id, @@ -510,12 +559,7 @@ impl TcpConnectionStealer { return Ok(()); } - let matched_request = MatchedHttpRequest { - connection_id, - request, - request_id: id, - port, - }; + let matched_request = MatchedHttpRequest::new(connection_id, port, id, request); if !client.send_request_async(matched_request) { self.connections @@ -534,11 +578,18 @@ impl TcpConnectionStealer { Ok(()) } - /// Helper function to handle [`Command::PortSubscribe`] messages. + /// Helper function to handle [`Command::PortSubscribe`] messages for the `TcpStealer`. /// - /// Inserts a subscription into [`Self::port_subscriptions`]. - #[tracing::instrument(level = "trace", skip(self))] - async fn port_subscribe(&mut self, client_id: ClientId, port_steal: StealType) -> Result<()> { + /// Checks if [`StealType`] is a valid [`HttpFilter`], then inserts a subscription into + /// [`Self::port_subscriptions`]. + /// + /// - Returns: `true` if this is an HTTP filtered subscription. 
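The loopback rewrite in `incoming_connection` above distills to a small decision: connecting back to the original destination IP would traverse PREROUTING again and loop, so the stealer dials loopback instead, choosing `::1` only when IPv6 support is on and the stolen address is itself IPv6. A self-contained sketch of that selection (function name is illustrative):

```rust
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

// Distills the loopback rewrite in incoming_connection: pick ::1 only when
// IPv6 support is enabled and the stolen address is IPv6; otherwise fall
// back to 127.0.0.1, which always avoids the PREROUTING loop.
fn loopback_for(real_address: SocketAddr, support_ipv6: bool) -> IpAddr {
    if support_ipv6 && real_address.is_ipv6() {
        IpAddr::V6(Ipv6Addr::LOCALHOST)
    } else {
        IpAddr::V4(Ipv4Addr::LOCALHOST)
    }
}

fn main() {
    let v6: SocketAddr = "[2001:db8::1]:80".parse().unwrap();
    let v4: SocketAddr = "192.0.2.7:80".parse().unwrap();
    assert_eq!(loopback_for(v6, true), IpAddr::V6(Ipv6Addr::LOCALHOST));
    // Without IPv6 support the stealer still falls back to 127.0.0.1.
    assert_eq!(loopback_for(v6, false), IpAddr::V4(Ipv4Addr::LOCALHOST));
    assert_eq!(loopback_for(v4, true), IpAddr::V4(Ipv4Addr::LOCALHOST));
}
```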
+    #[tracing::instrument(level = Level::TRACE, skip(self), err)]
+    async fn port_subscribe(
+        &mut self,
+        client_id: ClientId,
+        port_steal: StealType,
+    ) -> AgentResult<bool> {
         let spec = match port_steal {
             StealType::All(port) => Ok((port, None)),
             StealType::FilteredHttp(port, filter) => Regex::new(&format!("(?i){filter}"))
@@ -549,6 +600,11 @@
                 .map_err(|err| BadHttpFilterExRegex(filter, err.to_string())),
         };
 
+        let filtered = spec
+            .as_ref()
+            .map(|(_, filter)| filter.is_some())
+            .unwrap_or_default();
+
         let res = match spec {
             Ok((port, filter)) => self.port_subscriptions.add(client_id, port, filter).await?,
             Err(e) => Err(e.into()),
@@ -557,18 +613,18 @@
         let client = self.clients.get(&client_id).expect("client not found");
         let _ = client.tx.send(DaemonTcp::SubscribeResult(res)).await;
 
-        Ok(())
+        Ok(filtered)
     }
 
     /// Removes the client with `client_id` from our list of clients (layers), and also removes
     /// their subscriptions from [`Self::port_subscriptions`] and all their open
     /// connections.
#[tracing::instrument(level = "trace", skip(self))] - async fn close_client(&mut self, client_id: ClientId) -> Result<(), AgentError> { + async fn close_client(&mut self, client_id: ClientId) -> AgentResult<()> { self.port_subscriptions.remove_all(client_id).await?; let client = self.clients.remove(&client_id).expect("client not found"); - for connection in client.subscribed_connections.into_iter() { + for connection in client.subscribed_connections { self.connections .send(connection, ConnectionMessageIn::Unsubscribed { client_id }) .await; @@ -583,49 +639,27 @@ impl TcpConnectionStealer { async fn send_http_response(&mut self, client_id: ClientId, response: HttpResponseFallback) { let connection_id = response.connection_id(); let request_id = response.request_id(); - - match response.into_hyper::() { - Ok(response) => { - self.connections - .send( - connection_id, - ConnectionMessageIn::Response { - client_id, - request_id, - response, - }, - ) - .await; - } - Err(error) => { - tracing::warn!( - ?error, - connection_id, - request_id, + self.connections + .send( + connection_id, + ConnectionMessageIn::Response { client_id, - "Failed to transform client message into a hyper response", - ); - - self.connections - .send( - connection_id, - ConnectionMessageIn::ResponseFailed { - client_id, - request_id, - }, - ) - .await; - } - } + request_id, + response: response.into_hyper::(), + }, + ) + .await; } /// Handles [`Command`]s that were received by [`TcpConnectionStealer::command_rx`]. 
- #[tracing::instrument(level = "trace", skip(self))] - async fn handle_command(&mut self, command: StealerCommand) -> Result<(), AgentError> { + #[tracing::instrument(level = Level::TRACE, skip(self), err)] + async fn handle_command(&mut self, command: StealerCommand) -> AgentResult<()> { let StealerCommand { client_id, command } = command; match command { Command::NewClient(daemon_tx, protocol_version) => { + self.clients_closed + .push(ChannelClosedFuture::new(daemon_tx.clone(), client_id)); self.clients.insert( client_id, Client { @@ -652,7 +686,7 @@ impl TcpConnectionStealer { } Command::PortSubscribe(port_steal) => { - self.port_subscribe(client_id, port_steal).await? + self.port_subscribe(client_id, port_steal).await?; } Command::PortUnsubscribe(port) => { @@ -690,7 +724,7 @@ impl TcpConnectionStealer { #[cfg(test)] mod test { - use std::net::SocketAddr; + use std::{net::SocketAddr, time::Duration}; use bytes::Bytes; use futures::{future::BoxFuture, FutureExt}; @@ -701,18 +735,75 @@ mod test { service::Service, }; use hyper_util::rt::TokioIo; - use mirrord_protocol::tcp::{ChunkedRequest, DaemonTcp, InternalHttpBodyFrame}; + use mirrord_protocol::{ + tcp::{ChunkedRequest, DaemonTcp, Filter, HttpFilter, InternalHttpBodyFrame, StealType}, + Port, + }; use rstest::rstest; use tokio::{ net::{TcpListener, TcpStream}, sync::{ mpsc::{self, Receiver, Sender}, - oneshot, + oneshot, watch, }, }; use tokio_stream::wrappers::ReceiverStream; + use tokio_util::sync::CancellationToken; + + use super::AgentError; + use crate::{ + steal::{ + connection::{Client, MatchedHttpRequest}, + subscriptions::PortRedirector, + TcpConnectionStealer, TcpStealerApi, + }, + watched_task::TaskStatus, + }; + + /// Notification about a requested redirection operation. + /// + /// Produced by [`NotifyingRedirector`]. 
+ #[derive(Debug, PartialEq, Eq)] + enum RedirectNotification { + Added(Port), + Removed(Port), + Cleanup, + } + + /// Test [`PortRedirector`] that never fails and notifies about requested operations using an + /// [`mpsc::channel`]. + struct NotifyingRedirector(Sender); + + #[async_trait::async_trait] + impl PortRedirector for NotifyingRedirector { + type Error = AgentError; + + async fn add_redirection(&mut self, port: Port) -> Result<(), Self::Error> { + self.0 + .send(RedirectNotification::Added(port)) + .await + .unwrap(); + Ok(()) + } + + async fn remove_redirection(&mut self, port: Port) -> Result<(), Self::Error> { + self.0 + .send(RedirectNotification::Removed(port)) + .await + .unwrap(); + Ok(()) + } + + async fn cleanup(&mut self) -> Result<(), Self::Error> { + self.0.send(RedirectNotification::Cleanup).await.unwrap(); + Ok(()) + } + + async fn next_connection(&mut self) -> Result<(TcpStream, SocketAddr), Self::Error> { + std::future::pending().await + } + } - use crate::steal::connection::{Client, MatchedHttpRequest}; async fn prepare_dummy_service() -> ( SocketAddr, Receiver<(Request, oneshot::Sender>>)>, @@ -889,4 +980,51 @@ mod test { let _ = response_tx.send(Response::new(Empty::default())); } + + /// Verifies that [`TcpConnectionStealer`] removes client's port subscriptions + /// when client's [`TcpStealerApi`] is dropped. 
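Both `clients_closed` and the new `cleanup_on_client_closed` test hinge on detecting that a client's channel endpoint was dropped. With std channels the same signal is visible as a failed send; a minimal stand-in for the idea (not the agent's `ChannelClosedFuture`, which is tokio-based and future-shaped):

```rust
use std::sync::mpsc;

// Stdlib analogue of the "client closed" signal: once the Receiver half of a
// channel is dropped, sends fail, which is how a holder of the Sender can tell
// the peer is gone and run cleanup (the agent awaits a future for this instead).
fn client_alive<T>(tx: &mpsc::Sender<T>, probe: T) -> bool {
    tx.send(probe).is_ok()
}

fn main() {
    let (tx, rx) = mpsc::channel::<u32>();
    assert!(client_alive(&tx, 1));
    drop(rx); // the client's end goes away
    assert!(!client_alive(&tx, 2)); // subscription cleanup would run here
}
```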
+ #[rstest] + #[timeout(Duration::from_secs(5))] + #[tokio::test] + async fn cleanup_on_client_closed() { + let (command_tx, command_rx) = mpsc::channel(8); + let (redirect_tx, mut redirect_rx) = mpsc::channel(2); + let stealer = TcpConnectionStealer::with_redirector( + command_rx, + false, + NotifyingRedirector(redirect_tx), + ); + + tokio::spawn(stealer.start(CancellationToken::new())); + + let (_dummy_tx, dummy_rx) = watch::channel(None); + let task_status = TaskStatus::dummy(TcpConnectionStealer::TASK_NAME, dummy_rx); + let mut api = TcpStealerApi::new( + 0, + command_tx.clone(), + task_status, + 8, + mirrord_protocol::VERSION.clone(), + ) + .await + .unwrap(); + + api.port_subscribe(StealType::FilteredHttpEx( + 80, + HttpFilter::Header(Filter::new("user: test".into()).unwrap()), + )) + .await + .unwrap(); + + let response = api.recv().await.unwrap(); + assert_eq!(response, DaemonTcp::SubscribeResult(Ok(80))); + + let notification = redirect_rx.recv().await.unwrap(); + assert_eq!(notification, RedirectNotification::Added(80)); + + std::mem::drop(api); + + let notification = redirect_rx.recv().await.unwrap(); + assert_eq!(notification, RedirectNotification::Removed(80)); + } } diff --git a/mirrord/agent/src/steal/connections.rs b/mirrord/agent/src/steal/connections.rs index bc47eb80e4b..8e969b0a868 100644 --- a/mirrord/agent/src/steal/connections.rs +++ b/mirrord/agent/src/steal/connections.rs @@ -11,10 +11,14 @@ use tokio::{ sync::mpsc::{self, error::SendError, Receiver, Sender}, task::JoinSet, }; +use tracing::Level; use self::{filtered::DynamicBody, unfiltered::UnfilteredStealTask}; use super::{http::DefaultReversibleStream, subscriptions::PortSubscription}; -use crate::{http::HttpVersion, steal::connections::filtered::FilteredStealTask, util::ClientId}; +use crate::{ + http::HttpVersion, metrics::STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION, + steal::connections::filtered::FilteredStealTask, util::ClientId, +}; mod filtered; mod unfiltered; @@ -287,7 +291,7 @@ 
impl StolenConnections { /// Adds the given [`StolenConnection`] to this set. Spawns a new [`tokio::task`] that will /// manage it. - #[tracing::instrument(level = "trace", name = "manage_stolen_connection", skip(self))] + #[tracing::instrument(level = Level::TRACE, name = "manage_stolen_connection", skip(self))] pub fn manage(&mut self, connection: StolenConnection) { let connection_id = self.next_connection_id; self.next_connection_id += 1; @@ -458,13 +462,9 @@ impl ConnectionTask { }) .await?; - let task = UnfilteredStealTask { - connection_id: self.connection_id, - client_id, - stream: self.connection.stream, - }; - - task.run(self.tx, &mut self.rx).await + UnfilteredStealTask::new(self.connection_id, client_id, self.connection.stream) + .run(self.tx, &mut self.rx) + .await } PortSubscription::Filtered(filters) => { diff --git a/mirrord/agent/src/steal/connections/filtered.rs b/mirrord/agent/src/steal/connections/filtered.rs index b30e48a5757..ecc9f0064ad 100644 --- a/mirrord/agent/src/steal/connections/filtered.rs +++ b/mirrord/agent/src/steal/connections/filtered.rs @@ -1,5 +1,6 @@ use std::{ - collections::HashMap, future::Future, marker::PhantomData, net::SocketAddr, pin::Pin, sync::Arc, + collections::HashMap, future::Future, marker::PhantomData, net::SocketAddr, ops::Not, pin::Pin, + sync::Arc, }; use bytes::Bytes; @@ -28,9 +29,13 @@ use tokio::{ use tokio_util::sync::{CancellationToken, DropGuard}; use tracing::Level; -use super::{ConnectionMessageIn, ConnectionMessageOut, ConnectionTaskError}; +use super::{ + ConnectionMessageIn, ConnectionMessageOut, ConnectionTaskError, + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION, +}; use crate::{ http::HttpVersion, + metrics::STEAL_FILTERED_CONNECTION_SUBSCRIPTION, steal::{connections::unfiltered::UnfilteredStealTask, http::HttpFilter}, util::ClientId, }; @@ -368,6 +373,18 @@ pub struct FilteredStealTask { /// For safely downcasting the IO stream after an HTTP upgrade. See [`Upgraded::downcast`]. 
_io_type: PhantomData<fn() -> T>,
+
+    /// Helps us figure out whether we should update some metrics in the `Drop` implementation.
+    metrics_updated: bool,
+}
+
+impl<T> Drop for FilteredStealTask<T> {
+    fn drop(&mut self) {
+        if self.metrics_updated.not() {
+            STEAL_FILTERED_CONNECTION_SUBSCRIPTION
+                .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
+        }
+    }
 }
 
 impl<T> FilteredStealTask<T>
@@ -443,6 +460,8 @@ where
             }
         };
 
+        STEAL_FILTERED_CONNECTION_SUBSCRIPTION.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+
         Self {
             connection_id,
             original_destination,
@@ -453,6 +472,7 @@ where
             blocked_requests: Default::default(),
             next_request_id: Default::default(),
             _io_type: Default::default(),
+            metrics_updated: false,
         }
     }
@@ -638,6 +658,8 @@ where
             queued_raw_data.remove(&client_id);
             self.subscribed.insert(client_id, false);
             self.blocked_requests.retain(|key, _| key.0 != client_id);
+
+            STEAL_FILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
         },
     },
@@ -646,7 +668,10 @@ where
 
                 // No more requests from the `FilteringService`.
                 // HTTP connection is closed and possibly upgraded.
- None => break, + None => { + STEAL_FILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + break + } } } } @@ -788,15 +813,18 @@ where ) -> Result<(), ConnectionTaskError> { let res = self.run_until_http_ends(tx.clone(), rx).await; + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + self.metrics_updated = true; + let res = match res { Ok(data) => self.run_after_http_ends(data, tx.clone(), rx).await, Err(e) => Err(e), }; - for (client_id, subscribed) in self.subscribed { - if subscribed { + for (client_id, subscribed) in self.subscribed.iter() { + if *subscribed { tx.send(ConnectionMessageOut::Closed { - client_id, + client_id: *client_id, connection_id: self.connection_id, }) .await?; diff --git a/mirrord/agent/src/steal/connections/unfiltered.rs b/mirrord/agent/src/steal/connections/unfiltered.rs index 5b6676094c3..ec54691315e 100644 --- a/mirrord/agent/src/steal/connections/unfiltered.rs +++ b/mirrord/agent/src/steal/connections/unfiltered.rs @@ -7,7 +7,10 @@ use tokio::{ sync::mpsc::{Receiver, Sender}, }; -use super::{ConnectionMessageIn, ConnectionMessageOut, ConnectionTaskError}; +use super::{ + ConnectionMessageIn, ConnectionMessageOut, ConnectionTaskError, + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION, +}; use crate::util::ClientId; /// Manages an unfiltered stolen connection. @@ -19,7 +22,23 @@ pub struct UnfilteredStealTask { pub stream: T, } +impl Drop for UnfilteredStealTask { + fn drop(&mut self) { + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + } +} + impl UnfilteredStealTask { + pub(crate) fn new(connection_id: ConnectionId, client_id: ClientId, stream: T) -> Self { + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + Self { + connection_id, + client_id, + stream, + } + } + /// Runs this task until the managed connection is closed. 
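The `new()`/`Drop` pairing added to both steal tasks amounts to a gauge guard: increment on construction, decrement once on teardown. A compact RAII sketch of the idea (illustrative; the diff itself combines a `Drop` impl with explicit decrements and a `metrics_updated` flag to avoid double counting):

```rust
use std::sync::atomic::{AtomicI64, Ordering};

// Illustrative stand-in for STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.
static SUBSCRIPTIONS: AtomicI64 = AtomicI64::new(0);

fn gauge() -> i64 {
    SUBSCRIPTIONS.load(Ordering::Relaxed)
}

// RAII guard: increment on construction, decrement exactly once in Drop, so
// the gauge stays correct on every exit path, including early returns.
struct TaskGuard;

impl TaskGuard {
    fn new() -> Self {
        SUBSCRIPTIONS.fetch_add(1, Ordering::Relaxed);
        TaskGuard
    }
}

impl Drop for TaskGuard {
    fn drop(&mut self) {
        SUBSCRIPTIONS.fetch_sub(1, Ordering::Relaxed);
    }
}

fn main() {
    {
        let _a = TaskGuard::new();
        let _b = TaskGuard::new();
        assert_eq!(gauge(), 2);
    }
    // Both guards dropped: the gauge is balanced again.
    assert_eq!(gauge(), 0);
}
```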
/// /// # Note @@ -40,6 +59,8 @@ impl UnfilteredStealTask { read = self.stream.read_buf(&mut buf), if !reading_closed => match read { Ok(..) => { if buf.is_empty() { + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + tracing::trace!( client_id = self.client_id, connection_id = self.connection_id, @@ -63,6 +84,8 @@ impl UnfilteredStealTask { Err(e) if e.kind() == ErrorKind::WouldBlock => {} Err(e) => { + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + tx.send(ConnectionMessageOut::Closed { client_id: self.client_id, connection_id: self.connection_id @@ -85,6 +108,8 @@ impl UnfilteredStealTask { ConnectionMessageIn::Raw { data, .. } => { let res = if data.is_empty() { + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + tracing::trace!( client_id = self.client_id, connection_id = self.connection_id, @@ -97,6 +122,8 @@ impl UnfilteredStealTask { }; if let Err(e) = res { + STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + tx.send(ConnectionMessageOut::Closed { client_id: self.client_id, connection_id: self.connection_id @@ -115,6 +142,8 @@ impl UnfilteredStealTask { }, ConnectionMessageIn::Unsubscribed { .. 
} => {
+                    STEAL_UNFILTERED_CONNECTION_SUBSCRIPTION.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
+
                     return Ok(());
                 }
             }
diff --git a/mirrord/agent/src/steal/http.rs b/mirrord/agent/src/steal/http.rs
index 159d9c9aac8..cad0308bc96 100644
--- a/mirrord/agent/src/steal/http.rs
+++ b/mirrord/agent/src/steal/http.rs
@@ -3,11 +3,12 @@ use crate::http::HttpVersion;
 
 mod filter;
+mod response_fallback;
 mod reversible_stream;
 
-pub use filter::HttpFilter;
-
-pub(crate) use self::reversible_stream::ReversibleStream;
+pub(crate) use filter::HttpFilter;
+pub(crate) use response_fallback::{HttpResponseFallback, ReceiverStreamBody};
+pub(crate) use reversible_stream::ReversibleStream;
 
 /// Handy alias due to [`ReversibleStream`] being generic, avoiding value mismatches.
 pub(crate) type DefaultReversibleStream = ReversibleStream<{ HttpVersion::MINIMAL_HEADER_SIZE }>;
diff --git a/mirrord/agent/src/steal/http/response_fallback.rs b/mirrord/agent/src/steal/http/response_fallback.rs
new file mode 100644
index 00000000000..2124ec41a57
--- /dev/null
+++ b/mirrord/agent/src/steal/http/response_fallback.rs
@@ -0,0 +1,58 @@
+use std::convert::Infallible;
+
+use bytes::Bytes;
+use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody};
+use hyper::{body::Frame, Response};
+use mirrord_protocol::{
+    tcp::{HttpResponse, InternalHttpBody},
+    ConnectionId, RequestId,
+};
+use tokio_stream::wrappers::ReceiverStream;
+
+pub type ReceiverStreamBody = StreamBody<ReceiverStream<Result<Frame<Bytes>, Infallible>>>;
+
+#[derive(Debug)]
+pub enum HttpResponseFallback {
+    Framed(HttpResponse<InternalHttpBody>),
+    Fallback(HttpResponse<Vec<u8>>),
+    Streamed(HttpResponse<ReceiverStreamBody>),
+}
+
+impl HttpResponseFallback {
+    pub fn connection_id(&self) -> ConnectionId {
+        match self {
+            HttpResponseFallback::Framed(req) => req.connection_id,
+            HttpResponseFallback::Fallback(req) => req.connection_id,
+            HttpResponseFallback::Streamed(req) => req.connection_id,
+        }
+    }
+
+    pub fn request_id(&self) -> RequestId {
+        match self {
+            HttpResponseFallback::Framed(req) =>
req.request_id, + HttpResponseFallback::Fallback(req) => req.request_id, + HttpResponseFallback::Streamed(req) => req.request_id, + } + } + + pub fn into_hyper(self) -> Response> { + match self { + HttpResponseFallback::Framed(req) => req + .internal_response + .map_body(|body| body.map_err(|_| unreachable!()).boxed()) + .into(), + HttpResponseFallback::Fallback(req) => req + .internal_response + .map_body(|body| { + Full::new(Bytes::from_owner(body)) + .map_err(|_| unreachable!()) + .boxed() + }) + .into(), + HttpResponseFallback::Streamed(req) => req + .internal_response + .map_body(|body| body.map_err(|_| unreachable!()).boxed()) + .into(), + } + } +} diff --git a/mirrord/agent/src/steal/ip_tables.rs b/mirrord/agent/src/steal/ip_tables.rs index 5583b485000..9c175220767 100644 --- a/mirrord/agent/src/steal/ip_tables.rs +++ b/mirrord/agent/src/steal/ip_tables.rs @@ -9,7 +9,7 @@ use rand::distributions::{Alphanumeric, DistString}; use tracing::warn; use crate::{ - error::{AgentError, Result}, + error::{AgentError, AgentResult}, steal::ip_tables::{ flush_connections::FlushConnections, mesh::{istio::AmbientRedirect, MeshRedirect, MeshVendorExt}, @@ -84,13 +84,13 @@ pub(crate) trait IPTables { where Self: Sized; - fn create_chain(&self, name: &str) -> Result<()>; - fn remove_chain(&self, name: &str) -> Result<()>; + fn create_chain(&self, name: &str) -> AgentResult<()>; + fn remove_chain(&self, name: &str) -> AgentResult<()>; - fn add_rule(&self, chain: &str, rule: &str) -> Result<()>; - fn insert_rule(&self, chain: &str, rule: &str, index: i32) -> Result<()>; - fn list_rules(&self, chain: &str) -> Result>; - fn remove_rule(&self, chain: &str, rule: &str) -> Result<()>; + fn add_rule(&self, chain: &str, rule: &str) -> AgentResult<()>; + fn insert_rule(&self, chain: &str, rule: &str, index: i32) -> AgentResult<()>; + fn list_rules(&self, chain: &str) -> AgentResult>; + fn remove_rule(&self, chain: &str, rule: &str) -> AgentResult<()>; } #[derive(Clone)] @@ -111,6 
+111,18 @@ pub fn new_iptables() -> iptables::IPTables { .expect("IPTables initialization may not fail!") } +/// wrapper around iptables::new that uses nft or legacy based on env +pub fn new_ip6tables() -> iptables::IPTables { + if let Ok(val) = std::env::var("MIRRORD_AGENT_NFTABLES") + && val.to_lowercase() == "true" + { + iptables::new_with_cmd("/usr/sbin/ip6tables-nft") + } else { + iptables::new_with_cmd("/usr/sbin/ip6tables-legacy") + } + .expect("IPTables initialization may not fail!") +} + impl Debug for IPTablesWrapper { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("IPTablesWrapper") @@ -140,8 +152,13 @@ impl IPTables for IPTablesWrapper { } } - #[tracing::instrument(level = "trace")] - fn create_chain(&self, name: &str) -> Result<()> { + #[tracing::instrument( + level = tracing::Level::TRACE, + skip(self), + ret, + fields(table_name=%self.table_name + ))] + fn create_chain(&self, name: &str) -> AgentResult<()> { self.tables .new_chain(self.table_name, name) .map_err(|e| AgentError::IPTablesError(e.to_string()))?; @@ -153,7 +170,7 @@ impl IPTables for IPTablesWrapper { } #[tracing::instrument(level = "trace")] - fn remove_chain(&self, name: &str) -> Result<()> { + fn remove_chain(&self, name: &str) -> AgentResult<()> { self.tables .flush_chain(self.table_name, name) .map_err(|e| AgentError::IPTablesError(e.to_string()))?; @@ -165,28 +182,28 @@ impl IPTables for IPTablesWrapper { } #[tracing::instrument(level = "trace", ret)] - fn add_rule(&self, chain: &str, rule: &str) -> Result<()> { + fn add_rule(&self, chain: &str, rule: &str) -> AgentResult<()> { self.tables .append(self.table_name, chain, rule) .map_err(|e| AgentError::IPTablesError(e.to_string())) } #[tracing::instrument(level = "trace", ret)] - fn insert_rule(&self, chain: &str, rule: &str, index: i32) -> Result<()> { + fn insert_rule(&self, chain: &str, rule: &str, index: i32) -> AgentResult<()> { self.tables .insert(self.table_name, chain, rule, index) 
.map_err(|e| AgentError::IPTablesError(e.to_string())) } #[tracing::instrument(level = "trace")] - fn list_rules(&self, chain: &str) -> Result> { + fn list_rules(&self, chain: &str) -> AgentResult> { self.tables .list(self.table_name, chain) .map_err(|e| AgentError::IPTablesError(e.to_string())) } #[tracing::instrument(level = "trace")] - fn remove_rule(&self, chain: &str, rule: &str) -> Result<()> { + fn remove_rule(&self, chain: &str, rule: &str) -> AgentResult<()> { self.tables .delete(self.table_name, chain, rule) .map_err(|e| AgentError::IPTablesError(e.to_string())) @@ -220,7 +237,8 @@ where ipt: IPT, flush_connections: bool, pod_ips: Option<&str>, - ) -> Result { + ipv6: bool, + ) -> AgentResult { let ipt = Arc::new(ipt); let mut redirect = if let Some(vendor) = MeshVendor::detect(ipt.as_ref())? { @@ -231,6 +249,7 @@ where _ => Redirects::Mesh(MeshRedirect::create(ipt.clone(), vendor, pod_ips)?), } } else { + tracing::trace!(ipv6 = ipv6, "creating standard redirect"); match StandardRedirect::create(ipt.clone(), pod_ips) { Err(err) => { warn!("Unable to create StandardRedirect chain: {err}"); @@ -251,7 +270,7 @@ where Ok(Self { redirect }) } - pub(crate) async fn load(ipt: IPT, flush_connections: bool) -> Result { + pub(crate) async fn load(ipt: IPT, flush_connections: bool) -> AgentResult { let ipt = Arc::new(ipt); let mut redirect = if let Some(vendor) = MeshVendor::detect(ipt.as_ref())? { @@ -280,12 +299,12 @@ where /// Adds the redirect rule to iptables. /// /// Used to redirect packets when mirrord incoming feature is set to `steal`. 
- #[tracing::instrument(level = "trace", skip(self))] + #[tracing::instrument(level = tracing::Level::DEBUG, skip(self))] pub(super) async fn add_redirect( &self, redirected_port: Port, target_port: Port, - ) -> Result<()> { + ) -> AgentResult<()> { self.redirect .add_redirect(redirected_port, target_port) .await @@ -300,13 +319,13 @@ where &self, redirected_port: Port, target_port: Port, - ) -> Result<()> { + ) -> AgentResult<()> { self.redirect .remove_redirect(redirected_port, target_port) .await } - pub(crate) async fn cleanup(&self) -> Result<()> { + pub(crate) async fn cleanup(&self) -> AgentResult<()> { self.redirect.unmount_entrypoint().await } } @@ -408,7 +427,7 @@ mod tests { .times(1) .returning(|_| Ok(())); - let ipt = SafeIpTables::create(mock, false, None) + let ipt = SafeIpTables::create(mock, false, None, false) .await .expect("Create Failed"); @@ -541,7 +560,7 @@ mod tests { .times(1) .returning(|_| Ok(())); - let ipt = SafeIpTables::create(mock, false, None) + let ipt = SafeIpTables::create(mock, false, None, false) .await .expect("Create Failed"); diff --git a/mirrord/agent/src/steal/ip_tables/chain.rs b/mirrord/agent/src/steal/ip_tables/chain.rs index c5bc6d65404..c1c34715c85 100644 --- a/mirrord/agent/src/steal/ip_tables/chain.rs +++ b/mirrord/agent/src/steal/ip_tables/chain.rs @@ -4,7 +4,7 @@ use std::sync::{ }; use crate::{ - error::{AgentError, Result}, + error::{AgentError, AgentResult}, steal::ip_tables::IPTables, }; @@ -19,7 +19,7 @@ impl IPTableChain where IPT: IPTables, { - pub fn create(inner: Arc, chain_name: String) -> Result { + pub fn create(inner: Arc, chain_name: String) -> AgentResult { inner.create_chain(&chain_name)?; // Start with 1 because the chain will allways have atleast `-A ` as a rule @@ -32,7 +32,7 @@ where }) } - pub fn load(inner: Arc, chain_name: String) -> Result { + pub fn load(inner: Arc, chain_name: String) -> AgentResult { let existing_rules = inner.list_rules(&chain_name)?.len(); if existing_rules == 0 { @@ 
-59,7 +59,7 @@ where &self.inner } - pub fn add_rule(&self, rule: &str) -> Result { + pub fn add_rule(&self, rule: &str) -> AgentResult { self.inner .insert_rule( &self.chain_name, @@ -72,7 +72,7 @@ where }) } - pub fn remove_rule(&self, rule: &str) -> Result<()> { + pub fn remove_rule(&self, rule: &str) -> AgentResult<()> { self.inner.remove_rule(&self.chain_name, rule)?; self.chain_size.fetch_sub(1, Ordering::Relaxed); diff --git a/mirrord/agent/src/steal/ip_tables/flush_connections.rs b/mirrord/agent/src/steal/ip_tables/flush_connections.rs index 6675a40651f..c0f19c20b8d 100644 --- a/mirrord/agent/src/steal/ip_tables/flush_connections.rs +++ b/mirrord/agent/src/steal/ip_tables/flush_connections.rs @@ -13,7 +13,7 @@ use tokio::process::Command; use tracing::warn; use crate::{ - error::Result, + error::AgentResult, steal::ip_tables::{chain::IPTableChain, redirect::Redirect, IPTables, IPTABLE_INPUT}, }; @@ -33,7 +33,7 @@ where const ENTRYPOINT: &'static str = "INPUT"; #[tracing::instrument(level = "trace", skip(ipt, inner))] - pub fn create(ipt: Arc, inner: Box) -> Result { + pub fn create(ipt: Arc, inner: Box) -> AgentResult { let managed = IPTableChain::create(ipt.with_table("filter").into(), IPTABLE_INPUT.to_string())?; @@ -48,7 +48,7 @@ where } #[tracing::instrument(level = "trace", skip(ipt, inner))] - pub fn load(ipt: Arc, inner: Box) -> Result { + pub fn load(ipt: Arc, inner: Box) -> AgentResult { let managed = IPTableChain::load(ipt.with_table("filter").into(), IPTABLE_INPUT.to_string())?; @@ -63,7 +63,7 @@ where T: Redirect + Send + Sync, { #[tracing::instrument(level = "trace", skip(self), ret)] - async fn mount_entrypoint(&self) -> Result<()> { + async fn mount_entrypoint(&self) -> AgentResult<()> { self.inner.mount_entrypoint().await?; self.managed.inner().add_rule( @@ -75,7 +75,7 @@ where } #[tracing::instrument(level = "trace", skip(self), ret)] - async fn unmount_entrypoint(&self) -> Result<()> { + async fn unmount_entrypoint(&self) -> 
AgentResult<()> { self.inner.unmount_entrypoint().await?; self.managed.inner().remove_rule( @@ -87,7 +87,7 @@ where } #[tracing::instrument(level = "trace", skip(self), ret)] - async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { self.inner .add_redirect(redirected_port, target_port) .await?; @@ -115,7 +115,7 @@ where } #[tracing::instrument(level = "trace", skip(self), ret)] - async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { self.inner .remove_redirect(redirected_port, target_port) .await?; diff --git a/mirrord/agent/src/steal/ip_tables/mesh.rs b/mirrord/agent/src/steal/ip_tables/mesh.rs index 88fdff5d0b1..1a3e5acbe62 100644 --- a/mirrord/agent/src/steal/ip_tables/mesh.rs +++ b/mirrord/agent/src/steal/ip_tables/mesh.rs @@ -5,7 +5,7 @@ use fancy_regex::Regex; use mirrord_protocol::{MeshVendor, Port}; use crate::{ - error::Result, + error::AgentResult, steal::ip_tables::{ output::OutputRedirect, prerouting::PreroutingRedirect, redirect::Redirect, IPTables, IPTABLE_MESH, @@ -29,7 +29,7 @@ impl MeshRedirect where IPT: IPTables, { - pub fn create(ipt: Arc, vendor: MeshVendor, pod_ips: Option<&str>) -> Result { + pub fn create(ipt: Arc, vendor: MeshVendor, pod_ips: Option<&str>) -> AgentResult { let prerouting = PreroutingRedirect::create(ipt.clone())?; for port in Self::get_skip_ports(&ipt, &vendor)? 
{ @@ -45,7 +45,7 @@ where }) } - pub fn load(ipt: Arc, vendor: MeshVendor) -> Result { + pub fn load(ipt: Arc, vendor: MeshVendor) -> AgentResult { let prerouting = PreroutingRedirect::load(ipt.clone())?; let output = OutputRedirect::load(ipt, IPTABLE_MESH.to_string())?; @@ -56,7 +56,7 @@ where }) } - fn get_skip_ports(ipt: &IPT, vendor: &MeshVendor) -> Result> { + fn get_skip_ports(ipt: &IPT, vendor: &MeshVendor) -> AgentResult> { let chain_name = vendor.input_chain(); let lookup_regex = if let Some(regex) = vendor.skip_ports_regex() { regex @@ -86,21 +86,21 @@ impl Redirect for MeshRedirect where IPT: IPTables + Send + Sync, { - async fn mount_entrypoint(&self) -> Result<()> { + async fn mount_entrypoint(&self) -> AgentResult<()> { self.prerouting.mount_entrypoint().await?; self.output.mount_entrypoint().await?; Ok(()) } - async fn unmount_entrypoint(&self) -> Result<()> { + async fn unmount_entrypoint(&self) -> AgentResult<()> { self.prerouting.unmount_entrypoint().await?; self.output.unmount_entrypoint().await?; Ok(()) } - async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { if self.vendor != MeshVendor::IstioCni { self.prerouting .add_redirect(redirected_port, target_port) @@ -113,7 +113,7 @@ where Ok(()) } - async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { if self.vendor != MeshVendor::IstioCni { self.prerouting .remove_redirect(redirected_port, target_port) @@ -129,13 +129,13 @@ where /// Extends the [`MeshVendor`] type with methods that are only relevant for the agent. 
pub(super) trait MeshVendorExt: Sized { - fn detect(ipt: &IPT) -> Result>; + fn detect(ipt: &IPT) -> AgentResult>; fn input_chain(&self) -> &str; fn skip_ports_regex(&self) -> Option<&Regex>; } impl MeshVendorExt for MeshVendor { - fn detect(ipt: &IPT) -> Result> { + fn detect(ipt: &IPT) -> AgentResult> { if let Ok(val) = std::env::var("MIRRORD_AGENT_ISTIO_CNI") && val.to_lowercase() == "true" { diff --git a/mirrord/agent/src/steal/ip_tables/mesh/istio.rs b/mirrord/agent/src/steal/ip_tables/mesh/istio.rs index cd3d4b06fa9..01e513a6bf9 100644 --- a/mirrord/agent/src/steal/ip_tables/mesh/istio.rs +++ b/mirrord/agent/src/steal/ip_tables/mesh/istio.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use mirrord_protocol::Port; use crate::{ - error::Result, + error::AgentResult, steal::ip_tables::{ output::OutputRedirect, prerouting::PreroutingRedirect, redirect::Redirect, IPTables, IPTABLE_IPV4_ROUTE_LOCALNET_ORIGINAL, IPTABLE_MESH, @@ -20,14 +20,14 @@ impl AmbientRedirect where IPT: IPTables, { - pub fn create(ipt: Arc, pod_ips: Option<&str>) -> Result { + pub fn create(ipt: Arc, pod_ips: Option<&str>) -> AgentResult { let prerouting = PreroutingRedirect::create(ipt.clone())?; let output = OutputRedirect::create(ipt, IPTABLE_MESH.to_string(), pod_ips)?; Ok(AmbientRedirect { prerouting, output }) } - pub fn load(ipt: Arc) -> Result { + pub fn load(ipt: Arc) -> AgentResult { let prerouting = PreroutingRedirect::load(ipt.clone())?; let output = OutputRedirect::load(ipt, IPTABLE_MESH.to_string())?; @@ -40,7 +40,7 @@ impl Redirect for AmbientRedirect where IPT: IPTables + Send + Sync, { - async fn mount_entrypoint(&self) -> Result<()> { + async fn mount_entrypoint(&self) -> AgentResult<()> { tokio::fs::write("/proc/sys/net/ipv4/conf/all/route_localnet", "1".as_bytes()).await?; self.prerouting.mount_entrypoint().await?; @@ -49,7 +49,7 @@ where Ok(()) } - async fn unmount_entrypoint(&self) -> Result<()> { + async fn unmount_entrypoint(&self) -> AgentResult<()> { 
self.prerouting.unmount_entrypoint().await?; self.output.unmount_entrypoint().await?; @@ -62,7 +62,7 @@ where Ok(()) } - async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { self.prerouting .add_redirect(redirected_port, target_port) .await?; @@ -73,7 +73,7 @@ where Ok(()) } - async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { self.prerouting .remove_redirect(redirected_port, target_port) .await?; diff --git a/mirrord/agent/src/steal/ip_tables/output.rs b/mirrord/agent/src/steal/ip_tables/output.rs index 944bc26f95b..9eebad0c9ae 100644 --- a/mirrord/agent/src/steal/ip_tables/output.rs +++ b/mirrord/agent/src/steal/ip_tables/output.rs @@ -6,7 +6,7 @@ use nix::unistd::getgid; use tracing::warn; use crate::{ - error::Result, + error::AgentResult, steal::ip_tables::{chain::IPTableChain, IPTables, Redirect}, }; @@ -20,8 +20,11 @@ where { const ENTRYPOINT: &'static str = "OUTPUT"; - pub fn create(ipt: Arc, chain_name: String, pod_ips: Option<&str>) -> Result { - let managed = IPTableChain::create(ipt, chain_name)?; + #[tracing::instrument(level = tracing::Level::TRACE, skip(ipt), err)] + pub fn create(ipt: Arc, chain_name: String, pod_ips: Option<&str>) -> AgentResult { + let managed = IPTableChain::create(ipt, chain_name.clone()).inspect_err( + |e| tracing::error!(%e, "Could not create iptables chain \"{chain_name}\"."), + )?; let exclude_source_ips = pod_ips .map(|pod_ips| format!("! 
-s {pod_ips}")) @@ -39,7 +42,7 @@ where Ok(OutputRedirect { managed }) } - pub fn load(ipt: Arc, chain_name: String) -> Result { + pub fn load(ipt: Arc, chain_name: String) -> AgentResult { let managed = IPTableChain::load(ipt, chain_name)?; Ok(OutputRedirect { managed }) @@ -53,7 +56,7 @@ impl Redirect for OutputRedirect where IPT: IPTables + Send + Sync, { - async fn mount_entrypoint(&self) -> Result<()> { + async fn mount_entrypoint(&self) -> AgentResult<()> { if USE_INSERT { self.managed.inner().insert_rule( Self::ENTRYPOINT, @@ -70,7 +73,7 @@ where Ok(()) } - async fn unmount_entrypoint(&self) -> Result<()> { + async fn unmount_entrypoint(&self) -> AgentResult<()> { self.managed.inner().remove_rule( Self::ENTRYPOINT, &format!("-j {}", self.managed.chain_name()), @@ -79,7 +82,7 @@ where Ok(()) } - async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { let redirect_rule = format!( "-o lo -m tcp -p tcp --dport {redirected_port} -j REDIRECT --to-ports {target_port}" ); @@ -89,7 +92,7 @@ where Ok(()) } - async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { let redirect_rule = format!( "-o lo -m tcp -p tcp --dport {redirected_port} -j REDIRECT --to-ports {target_port}" ); diff --git a/mirrord/agent/src/steal/ip_tables/prerouting.rs b/mirrord/agent/src/steal/ip_tables/prerouting.rs index 486b0ca1b51..29d5de06103 100644 --- a/mirrord/agent/src/steal/ip_tables/prerouting.rs +++ b/mirrord/agent/src/steal/ip_tables/prerouting.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use mirrord_protocol::Port; use crate::{ - error::Result, + error::AgentResult, steal::ip_tables::{chain::IPTableChain, IPTables, Redirect, IPTABLE_PREROUTING}, }; @@ -18,13 +18,13 @@ where { const ENTRYPOINT: &'static str = "PREROUTING"; - pub 
fn create(ipt: Arc) -> Result { + pub fn create(ipt: Arc) -> AgentResult { let managed = IPTableChain::create(ipt, IPTABLE_PREROUTING.to_string())?; Ok(PreroutingRedirect { managed }) } - pub fn load(ipt: Arc) -> Result { + pub fn load(ipt: Arc) -> AgentResult { let managed = IPTableChain::load(ipt, IPTABLE_PREROUTING.to_string())?; Ok(PreroutingRedirect { managed }) @@ -36,7 +36,7 @@ impl Redirect for PreroutingRedirect where IPT: IPTables + Send + Sync, { - async fn mount_entrypoint(&self) -> Result<()> { + async fn mount_entrypoint(&self) -> AgentResult<()> { self.managed.inner().add_rule( Self::ENTRYPOINT, &format!("-j {}", self.managed.chain_name()), @@ -45,7 +45,7 @@ where Ok(()) } - async fn unmount_entrypoint(&self) -> Result<()> { + async fn unmount_entrypoint(&self) -> AgentResult<()> { self.managed.inner().remove_rule( Self::ENTRYPOINT, &format!("-j {}", self.managed.chain_name()), @@ -54,7 +54,7 @@ where Ok(()) } - async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { let redirect_rule = format!("-m tcp -p tcp --dport {redirected_port} -j REDIRECT --to-ports {target_port}"); @@ -63,7 +63,7 @@ where Ok(()) } - async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { let redirect_rule = format!("-m tcp -p tcp --dport {redirected_port} -j REDIRECT --to-ports {target_port}"); diff --git a/mirrord/agent/src/steal/ip_tables/redirect.rs b/mirrord/agent/src/steal/ip_tables/redirect.rs index d18aeb1d7ea..fe52d90fc1e 100644 --- a/mirrord/agent/src/steal/ip_tables/redirect.rs +++ b/mirrord/agent/src/steal/ip_tables/redirect.rs @@ -2,17 +2,17 @@ use async_trait::async_trait; use enum_dispatch::enum_dispatch; use mirrord_protocol::Port; -use crate::error::Result; +use crate::error::AgentResult; #[async_trait] 
#[enum_dispatch] pub(crate) trait Redirect { - async fn mount_entrypoint(&self) -> Result<()>; + async fn mount_entrypoint(&self) -> AgentResult<()>; - async fn unmount_entrypoint(&self) -> Result<()>; + async fn unmount_entrypoint(&self) -> AgentResult<()>; /// Create port redirection - async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()>; + async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()>; /// Remove port redirection - async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()>; + async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()>; } diff --git a/mirrord/agent/src/steal/ip_tables/standard.rs b/mirrord/agent/src/steal/ip_tables/standard.rs index 3302b05c02e..47b9bf0c0af 100644 --- a/mirrord/agent/src/steal/ip_tables/standard.rs +++ b/mirrord/agent/src/steal/ip_tables/standard.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use mirrord_protocol::Port; use crate::{ - error::Result, + error::AgentResult, steal::ip_tables::{ output::OutputRedirect, prerouting::PreroutingRedirect, IPTables, Redirect, IPTABLE_STANDARD, @@ -20,14 +20,14 @@ impl StandardRedirect where IPT: IPTables, { - pub fn create(ipt: Arc, pod_ips: Option<&str>) -> Result { + pub fn create(ipt: Arc, pod_ips: Option<&str>) -> AgentResult { let prerouting = PreroutingRedirect::create(ipt.clone())?; let output = OutputRedirect::create(ipt, IPTABLE_STANDARD.to_string(), pod_ips)?; Ok(StandardRedirect { prerouting, output }) } - pub fn load(ipt: Arc) -> Result { + pub fn load(ipt: Arc) -> AgentResult { let prerouting = PreroutingRedirect::load(ipt.clone())?; let output = OutputRedirect::load(ipt, IPTABLE_STANDARD.to_string())?; @@ -42,21 +42,21 @@ impl Redirect for StandardRedirect where IPT: IPTables + Send + Sync, { - async fn mount_entrypoint(&self) -> Result<()> { + async fn mount_entrypoint(&self) -> AgentResult<()> { self.prerouting.mount_entrypoint().await?; 
self.output.mount_entrypoint().await?; Ok(()) } - async fn unmount_entrypoint(&self) -> Result<()> { + async fn unmount_entrypoint(&self) -> AgentResult<()> { self.prerouting.unmount_entrypoint().await?; self.output.unmount_entrypoint().await?; Ok(()) } - async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn add_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { self.prerouting .add_redirect(redirected_port, target_port) .await?; @@ -67,7 +67,7 @@ where Ok(()) } - async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> Result<()> { + async fn remove_redirect(&self, redirected_port: Port, target_port: Port) -> AgentResult<()> { self.prerouting .remove_redirect(redirected_port, target_port) .await?; diff --git a/mirrord/agent/src/steal/subscriptions.rs b/mirrord/agent/src/steal/subscriptions.rs index 0ff0e1fa8ea..901ecd725ef 100644 --- a/mirrord/agent/src/steal/subscriptions.rs +++ b/mirrord/agent/src/steal/subscriptions.rs @@ -1,18 +1,26 @@ use std::{ collections::{hash_map::Entry, HashMap}, - net::{Ipv4Addr, SocketAddr}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + ops::Not, sync::Arc, }; use dashmap::{mapref::entry::Entry as DashMapEntry, DashMap}; use mirrord_protocol::{Port, RemoteResult, ResponseError}; -use tokio::net::{TcpListener, TcpStream}; +use tokio::{ + net::{TcpListener, TcpStream}, + select, +}; use super::{ http::HttpFilter, - ip_tables::{new_iptables, IPTablesWrapper, SafeIpTables}, + ip_tables::{new_ip6tables, new_iptables, IPTablesWrapper, SafeIpTables}, +}; +use crate::{ + error::{AgentError, AgentResult}, + metrics::{STEAL_FILTERED_PORT_SUBSCRIPTION, STEAL_UNFILTERED_PORT_SUBSCRIPTION}, + util::ClientId, }; -use crate::{error::AgentError, util::ClientId}; /// For stealing incoming TCP connections. 
#[async_trait::async_trait] @@ -47,19 +55,82 @@ pub trait PortRedirector { async fn next_connection(&mut self) -> Result<(TcpStream, SocketAddr), Self::Error>; } -/// Implementation of [`PortRedirector`] that manipulates iptables to steal connections by -/// redirecting TCP packets to inner [`TcpListener`]. -pub(crate) struct IpTablesRedirector { +/// A TCP listener, together with an iptables wrapper to set rules that send traffic to the +/// listener. +pub(crate) struct IptablesListener { /// For altering iptables rules. iptables: Option>, - /// Whether exisiting connections should be flushed when adding new redirects. - /// Port of [`IpTablesRedirector::listener`]. + /// Port of [`listener`](Self::listener). redirect_to: Port, /// Listener to which redirect all connections. listener: TcpListener, - + /// Optional comma-separated list of IPs of the pod, originating in the pod's `Status.PodIps` pod_ips: Option, + /// Whether existing connections should be flushed when adding new redirects.
+ flush_connections: bool, + /// Is this for connections incoming over IPv6 + ipv6: bool, +} + +#[async_trait::async_trait] +impl PortRedirector for IptablesListener { + type Error = AgentError; + + #[tracing::instrument(skip(self), err, ret, level=tracing::Level::DEBUG, fields(self.ipv6 = %self.ipv6))] + async fn add_redirection(&mut self, from: Port) -> Result<(), Self::Error> { + let iptables = if let Some(iptables) = self.iptables.as_ref() { + iptables + } else { + let safe = crate::steal::ip_tables::SafeIpTables::create( + if self.ipv6 { + new_ip6tables() + } else { + new_iptables() + } + .into(), + self.flush_connections, + self.pod_ips.as_deref(), + self.ipv6, + ) + .await?; + self.iptables.insert(safe) + }; + iptables.add_redirect(from, self.redirect_to).await + } + + async fn remove_redirection(&mut self, from: Port) -> Result<(), Self::Error> { + if let Some(iptables) = self.iptables.as_ref() { + iptables.remove_redirect(from, self.redirect_to).await?; + } + + Ok(()) + } + + async fn cleanup(&mut self) -> Result<(), Self::Error> { + if let Some(iptables) = self.iptables.take() { + iptables.cleanup().await?; + } + + Ok(()) + } + + async fn next_connection(&mut self) -> Result<(TcpStream, SocketAddr), Self::Error> { + self.listener.accept().await.map_err(Into::into) + } +} + +/// Implementation of [`PortRedirector`] that manipulates iptables to steal connections by +/// redirecting TCP packets to inner [`TcpListener`]. +/// +/// Holds TCP listeners + iptables, for redirecting IPv4 and/or IPv6 connections. +pub(crate) enum IpTablesRedirector { + Ipv4Only(IptablesListener), + /// Could be used if IPv6 support is enabled, and we cannot bind an IPv4 address. + Ipv6Only(IptablesListener), + Dual { + ipv4_listener: IptablesListener, + ipv6_listener: IptablesListener, + }, } impl IpTablesRedirector { @@ -67,28 +138,116 @@ impl IpTablesRedirector { /// [`Ipv4Addr::UNSPECIFIED`] address and a random port. 
This listener will be used to accept /// redirected connections. /// + /// If `support_ipv6` is set, will also listen on IPv6, and a fail to listen over IPv4 will be + /// accepted. + /// /// # Note /// /// Does not yet alter iptables. /// /// # Params /// - /// * `flush_connections` - whether exisitng connections should be flushed when adding new + /// * `flush_connections` - whether existing connections should be flushed when adding new /// redirects pub(crate) async fn new( flush_connections: bool, pod_ips: Option, - ) -> Result { - let listener = TcpListener::bind((Ipv4Addr::UNSPECIFIED, 0)).await?; - let redirect_to = listener.local_addr()?.port(); - - Ok(Self { - iptables: None, - flush_connections, - redirect_to, - listener, - pod_ips, - }) + support_ipv6: bool, + ) -> AgentResult { + let (pod_ips4, pod_ips6) = pod_ips.map_or_else( + || (None, None), + |ips| { + // TODO: probably nicer to split at the client and avoid the conversion to and back + // from a string. + let (ip4s, ip6s): (Vec<_>, Vec<_>) = ips.split(',').partition(|ip_str| { + ip_str + .parse::() + .inspect_err(|e| tracing::warn!(%e, "failed to parse pod IP {ip_str}")) + .as_ref() + .map(IpAddr::is_ipv4) + .unwrap_or_default() + }); + // Convert to options, `None` if vector is empty. + ( + ip4s.is_empty().not().then(|| ip4s.join(",")), + ip6s.is_empty().not().then(|| ip6s.join(",")), + ) + }, + ); + tracing::debug!("pod IPv4 addresses: {pod_ips4:?}, pod IPv6 addresses: {pod_ips6:?}"); + + tracing::debug!("Creating IPv4 iptables redirection listener"); + let listener4 = TcpListener::bind((Ipv4Addr::UNSPECIFIED, 0)).await + .inspect_err( + |err| tracing::debug!(%err, "Could not bind IPv4, continuing with IPv6 only."), + ) + .ok() + .and_then(|listener| { + let redirect_to = listener + .local_addr() + .inspect_err( + |err| tracing::debug!(%err, "Get IPv4 listener address, continuing with IPv6 only."), + ) + .ok()? 
+ .port(); + Some(IptablesListener { + iptables: None, + redirect_to, + listener, + pod_ips: pod_ips4, + flush_connections, + ipv6: false, + }) + }); + tracing::debug!("Creating IPv6 iptables redirection listener"); + let listener6 = if support_ipv6 { + TcpListener::bind((Ipv6Addr::UNSPECIFIED, 0)).await + .inspect_err( + |err| tracing::debug!(%err, "Could not bind IPv6, continuing with IPv4 only."), + ) + .ok() + .and_then(|listener| { + let redirect_to = listener + .local_addr() + .inspect_err( + |err| tracing::debug!(%err, "Get IPv6 listener address, continuing with IPv4 only."), + ) + .ok()? + .port(); + Some(IptablesListener { + iptables: None, + redirect_to, + listener, + pod_ips: pod_ips6, + flush_connections, + ipv6: true, + }) + }) + } else { + None + }; + match (listener4, listener6) { + (None, None) => Err(AgentError::CannotListenForStolenConnections), + (Some(ipv4_listener), None) => Ok(Self::Ipv4Only(ipv4_listener)), + (None, Some(ipv6_listener)) => Ok(Self::Ipv6Only(ipv6_listener)), + (Some(ipv4_listener), Some(ipv6_listener)) => Ok(Self::Dual { + ipv4_listener, + ipv6_listener, + }), + } + } + + pub(crate) fn get_listeners_mut( + &mut self, + ) -> (Option<&mut IptablesListener>, Option<&mut IptablesListener>) { + match self { + IpTablesRedirector::Ipv4Only(ipv4_listener) => (Some(ipv4_listener), None), + IpTablesRedirector::Ipv6Only(ipv6_listener) => (None, Some(ipv6_listener)), + IpTablesRedirector::Dual { + ipv4_listener, + ipv6_listener, + } => (Some(ipv4_listener), Some(ipv6_listener)), + } } } @@ -97,41 +256,53 @@ impl PortRedirector for IpTablesRedirector { type Error = AgentError; async fn add_redirection(&mut self, from: Port) -> Result<(), Self::Error> { - let iptables = match self.iptables.as_ref() { - Some(iptables) => iptables, - None => { - let iptables = new_iptables(); - let safe = SafeIpTables::create( - iptables.into(), - self.flush_connections, - self.pod_ips.as_deref(), - ) - .await?; - self.iptables.insert(safe) - } - }; - - 
iptables.add_redirect(from, self.redirect_to).await + let (ipv4_listener, ipv6_listener) = self.get_listeners_mut(); + if let Some(ip4_listener) = ipv4_listener { + tracing::debug!("Adding IPv4 redirection from port {from}"); + ip4_listener.add_redirection(from).await?; + } + if let Some(ip6_listener) = ipv6_listener { + tracing::debug!("Adding IPv6 redirection from port {from}"); + ip6_listener.add_redirection(from).await?; + } + Ok(()) } async fn remove_redirection(&mut self, from: Port) -> Result<(), Self::Error> { - if let Some(iptables) = self.iptables.as_ref() { - iptables.remove_redirect(from, self.redirect_to).await?; + let (ipv4_listener, ipv6_listener) = self.get_listeners_mut(); + if let Some(ip4_listener) = ipv4_listener { + ip4_listener.remove_redirection(from).await?; + } + if let Some(ip6_listener) = ipv6_listener { + ip6_listener.remove_redirection(from).await?; } - Ok(()) } async fn cleanup(&mut self) -> Result<(), Self::Error> { - if let Some(iptables) = self.iptables.take() { - iptables.cleanup().await?; + let (ipv4_listener, ipv6_listener) = self.get_listeners_mut(); + if let Some(ip4_listener) = ipv4_listener { + ip4_listener.cleanup().await?; + } + if let Some(ip6_listener) = ipv6_listener { + ip6_listener.cleanup().await?; } - Ok(()) } async fn next_connection(&mut self) -> Result<(TcpStream, SocketAddr), Self::Error> { - self.listener.accept().await.map_err(Into::into) + match self { + Self::Dual { + ipv4_listener, + ipv6_listener, + } => { + select! 
{ + con = ipv4_listener.next_connection() => con, + con = ipv6_listener.next_connection() => con, + } + } + Self::Ipv4Only(listener) | Self::Ipv6Only(listener) => listener.next_connection().await, + } } } @@ -143,6 +314,13 @@ pub struct PortSubscriptions { subscriptions: HashMap, } +impl Drop for PortSubscriptions { + fn drop(&mut self) { + STEAL_FILTERED_PORT_SUBSCRIPTION.store(0, std::sync::atomic::Ordering::Relaxed); + STEAL_UNFILTERED_PORT_SUBSCRIPTION.store(0, std::sync::atomic::Ordering::Relaxed); + } +} + impl PortSubscriptions { /// Create an empty instance of this struct. /// @@ -184,7 +362,16 @@ impl PortSubscriptions { ) -> Result, R::Error> { let add_redirect = match self.subscriptions.entry(port) { Entry::Occupied(mut e) => { + let filtered = filter.is_some(); if e.get_mut().try_extend(client_id, filter) { + if filtered { + STEAL_FILTERED_PORT_SUBSCRIPTION + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } else { + STEAL_UNFILTERED_PORT_SUBSCRIPTION + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + Ok(false) } else { Err(ResponseError::PortAlreadyStolen(port)) @@ -192,6 +379,14 @@ impl PortSubscriptions { } Entry::Vacant(e) => { + if filter.is_some() { + STEAL_FILTERED_PORT_SUBSCRIPTION + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } else { + STEAL_UNFILTERED_PORT_SUBSCRIPTION + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + e.insert(PortSubscription::new(client_id, filter)); Ok(true) } @@ -228,11 +423,17 @@ impl PortSubscriptions { let remove_redirect = match e.get_mut() { PortSubscription::Unfiltered(subscribed_client) if *subscribed_client == client_id => { e.remove(); + STEAL_UNFILTERED_PORT_SUBSCRIPTION + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + true } PortSubscription::Unfiltered(..) 
=> false, PortSubscription::Filtered(filters) => { - filters.remove(&client_id); + if filters.remove(&client_id).is_some() { + STEAL_FILTERED_PORT_SUBSCRIPTION + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + } if filters.is_empty() { e.remove(); diff --git a/mirrord/agent/src/util.rs b/mirrord/agent/src/util.rs index 9dcbc6cd892..c5a002979e9 100644 --- a/mirrord/agent/src/util.rs +++ b/mirrord/agent/src/util.rs @@ -8,11 +8,12 @@ use std::{ thread::JoinHandle, }; +use futures::{future::BoxFuture, FutureExt}; use tokio::sync::mpsc; use tracing::error; use crate::{ - error::AgentError, + error::AgentResult, namespace::{set_namespace, NamespaceType}, }; @@ -151,7 +152,7 @@ where /// Many of the agent's TCP/UDP connections require that they're made from the `pid`'s namespace to /// work. #[tracing::instrument(level = "trace")] -pub(crate) fn enter_namespace(pid: Option, namespace: &str) -> Result<(), AgentError> { +pub(crate) fn enter_namespace(pid: Option, namespace: &str) -> AgentResult<()> { if let Some(pid) = pid { Ok(set_namespace(pid, NamespaceType::Net).inspect_err(|fail| { error!("Failed setting pid {pid:#?} namespace {namespace:#?} with {fail:#?}") @@ -162,27 +163,25 @@ pub(crate) fn enter_namespace(pid: Option, namespace: &str) -> Result<(), A } /// [`Future`] that resolves to [`ClientId`] when the client drops their [`mpsc::Receiver`]. 
-pub(crate) struct ChannelClosedFuture { - tx: mpsc::Sender, - client_id: ClientId, -} +pub(crate) struct ChannelClosedFuture(BoxFuture<'static, ClientId>); + +impl ChannelClosedFuture { + pub(crate) fn new(tx: mpsc::Sender, client_id: ClientId) -> Self { + let future = async move { + tx.closed().await; + client_id + } + .boxed(); -impl ChannelClosedFuture { - pub(crate) fn new(tx: mpsc::Sender, client_id: ClientId) -> Self { - Self { tx, client_id } + Self(future) + } } -impl Future for ChannelClosedFuture { +impl Future for ChannelClosedFuture { type Output = ClientId; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let client_id = self.client_id; - - let future = std::pin::pin!(self.get_mut().tx.closed()); - std::task::ready!(future.poll(cx)); - - Poll::Ready(client_id) + self.get_mut().0.as_mut().poll(cx) } } @@ -264,3 +263,52 @@ mod subscription_tests { assert_eq!(subscriptions.get_subscribed_topics(), Vec::::new()); } } + +#[cfg(test)] +mod channel_closed_tests { + use std::time::Duration; + + use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; + use rstest::rstest; + + use super::*; + + /// Verifies that [`ChannelClosedFuture`] resolves when the related [`mpsc::Receiver`] is + /// dropped. + #[rstest] + #[timeout(Duration::from_secs(5))] + #[tokio::test] + async fn channel_closed_resolves() { + let (tx, rx) = mpsc::channel::<()>(1); + let future = ChannelClosedFuture::new(tx, 0); + std::mem::drop(rx); + assert_eq!(future.await, 0); + } + + /// Verifies that [`ChannelClosedFuture`] works fine when used in [`FuturesUnordered`]. + /// + /// The future used to hold the [`mpsc::Sender`] and poll [`mpsc::Sender::closed`] in its + /// [`Future::poll`] implementation. This worked fine when the future was used in a simple way + /// ([`channel_closed_resolves`] test was passing). + /// + /// However, [`FuturesUnordered::next`] was hanging forever due to [`mpsc::Sender::closed`] + /// implementation details.
+ /// + /// New implementation of [`ChannelClosedFuture`] uses a [`BoxFuture`] internally, which works + /// fine. + #[rstest] + #[timeout(Duration::from_secs(5))] + #[tokio::test] + async fn channel_closed_works_in_futures_unordered() { + let mut unordered: FuturesUnordered = FuturesUnordered::new(); + + let (tx, rx) = mpsc::channel::<()>(1); + let future = ChannelClosedFuture::new(tx, 0); + + unordered.push(future); + + assert!(unordered.next().now_or_never().is_none()); + std::mem::drop(rx); + assert_eq!(unordered.next().await.unwrap(), 0); + } +} diff --git a/mirrord/agent/src/vpn.rs b/mirrord/agent/src/vpn.rs index dd8c3a5133f..d7d30d5ca6f 100644 --- a/mirrord/agent/src/vpn.rs +++ b/mirrord/agent/src/vpn.rs @@ -17,7 +17,7 @@ use tokio::{ }; use crate::{ - error::{AgentError, Result}, + error::{AgentError, AgentResult}, util::run_thread_in_namespace, watched_task::{TaskStatus, WatchedTask}, }; @@ -75,7 +75,7 @@ impl VpnApi { /// Sends the [`ClientVpn`] message to the background task. #[tracing::instrument(level = "trace", skip(self))] - pub(crate) async fn layer_message(&mut self, message: ClientVpn) -> Result<()> { + pub(crate) async fn layer_message(&mut self, message: ClientVpn) -> AgentResult<()> { if self.layer_tx.send(message).await.is_ok() { Ok(()) } else { @@ -84,7 +84,7 @@ impl VpnApi { } /// Receives a [`ServerVpn`] message from the background task. 
- pub(crate) async fn daemon_message(&mut self) -> Result { + pub(crate) async fn daemon_message(&mut self) -> AgentResult { match self.daemon_rx.recv().await { Some(msg) => Ok(msg), None => Err(self.task_status.unwrap_err().await), @@ -121,7 +121,7 @@ impl AsyncRawSocket { } } -async fn create_raw_socket() -> Result { +async fn create_raw_socket() -> AgentResult { let index = nix::net::if_::if_nametoindex("eth0") .map_err(|err| AgentError::VpnError(err.to_string()))?; @@ -139,7 +139,7 @@ async fn create_raw_socket() -> Result { } #[tracing::instrument(level = "debug", ret)] -async fn resolve_interface() -> Result<(IpAddr, IpAddr, IpAddr)> { +async fn resolve_interface() -> AgentResult<(IpAddr, IpAddr, IpAddr)> { // Connect to a remote address so we can later get the default network interface. let temporary_socket = UdpSocket::bind("0.0.0.0:0").await?; temporary_socket.connect("8.8.8.8:53").await?; @@ -209,7 +209,7 @@ impl fmt::Debug for VpnTask { } } -fn interface_index_to_sock_addr(index: i32) -> Result { +fn interface_index_to_sock_addr(index: i32) -> AgentResult { let mut addr_storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; let len = std::mem::size_of::() as libc::socklen_t; let macs = procfs::net::arp().map_err(|err| AgentError::VpnError(err.to_string()))?; @@ -245,7 +245,7 @@ impl VpnTask { } #[allow(clippy::indexing_slicing)] - async fn run(mut self) -> Result<()> { + async fn run(mut self) -> AgentResult<()> { // so host won't respond with RST to our packets. // TODO: need to do it for UDP as well to avoid ICMP unreachable. let output = std::process::Command::new("iptables") @@ -318,7 +318,7 @@ impl VpnTask { &mut self, message: ClientVpn, network_configuration: &NetworkConfiguration, - ) -> Result<()> { + ) -> AgentResult<()> { match message { // We make connection to the requested address, split the stream into halves with // `io::split`, and put them into respective maps. 
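The `resolve_interface` hunk in `vpn.rs` above relies on a side effect of UDP `connect`: no packet is sent, but the kernel resolves the default route, so `local_addr` then reveals the source IP that outbound traffic would use. A minimal std-only sketch of the same trick (`default_local_ip` is a hypothetical helper name; like the diff, it targets `8.8.8.8:53`, and it fails cleanly when no default route exists):

```rust
use std::io;
use std::net::{IpAddr, UdpSocket};

/// Discover the local IP the OS would pick for outbound traffic, without
/// sending any packets: UDP `connect` only records the default destination
/// and forces route selection.
fn default_local_ip() -> io::Result<IpAddr> {
    let socket = UdpSocket::bind("0.0.0.0:0")?;
    // Same target the diff uses; any routable address works.
    socket.connect("8.8.8.8:53")?;
    Ok(socket.local_addr()?.ip())
}

fn main() {
    match default_local_ip() {
        Ok(ip) => println!("default-route local IP: {ip}"),
        Err(err) => println!("no default route: {err}"),
    }
}
```

The diff's async version does the same with `tokio::net::UdpSocket`, then feeds the discovered interface into the raw-socket setup.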
diff --git a/mirrord/agent/src/watched_task.rs b/mirrord/agent/src/watched_task.rs index 0212f279163..2e7370b262c 100644 --- a/mirrord/agent/src/watched_task.rs +++ b/mirrord/agent/src/watched_task.rs @@ -2,7 +2,7 @@ use std::future::Future; use tokio::sync::watch::{self, Receiver, Sender}; -use crate::error::AgentError; +use crate::error::{AgentError, AgentResult}; /// A shared clonable view on a background task's status. #[derive(Debug, Clone)] @@ -83,7 +83,7 @@ impl WatchedTask { impl WatchedTask where - T: Future>, + T: Future>, { /// Execute the wrapped task. /// Store its result in the inner [`TaskStatus`]. @@ -94,9 +94,21 @@ where } #[cfg(test)] -mod test { +pub(crate) mod test { use super::*; + impl TaskStatus { + pub fn dummy( + task_name: &'static str, + result_rx: Receiver>>, + ) -> Self { + Self { + task_name, + result_rx, + } + } + } + #[tokio::test] async fn simple_successful() { let task = WatchedTask::new("task", async move { Ok(()) }); diff --git a/mirrord/cli/Cargo.toml b/mirrord/cli/Cargo.toml index 420ec6e97f1..aea0b77d7cc 100644 --- a/mirrord/cli/Cargo.toml +++ b/mirrord/cli/Cargo.toml @@ -63,11 +63,12 @@ tempfile.workspace = true rcgen.workspace = true rustls-pemfile.workspace = true tokio-rustls.workspace = true -tokio-stream = { workspace = true, features = ["net"] } +tokio-stream = { workspace = true, features = ["io-util", "net"] } regex.workspace = true mid = "3.0.0" rand.workspace = true + [target.'cfg(target_os = "macos")'.dependencies] mirrord-sip = { path = "../sip" } diff --git a/mirrord/cli/Dockerfile b/mirrord/cli/Dockerfile index 81e37142055..144f4a0e162 100644 --- a/mirrord/cli/Dockerfile +++ b/mirrord/cli/Dockerfile @@ -8,7 +8,8 @@ RUN ./platform.sh # this takes around 1 minute since libgit2 is slow https://github.com/rust-lang/cargo/issues/9167 ENV CARGO_NET_GIT_FETCH_WITH_CLI=true -RUN cargo install cargo-chef +# cargo-chef 0.1.69 breaks the build +RUN cargo install cargo-chef@0.1.68 FROM chef AS planner diff --git 
a/mirrord/cli/src/config.rs b/mirrord/cli/src/config.rs index 1570cfbca35..96548235ad4 100644 --- a/mirrord/cli/src/config.rs +++ b/mirrord/cli/src/config.rs @@ -712,11 +712,18 @@ pub(super) struct ListTargetArgs { #[arg(short = 'n', long = "namespace")] pub namespace: Option<String>, - /// Specify config file to use + /// Specify config file to use. #[arg(short = 'f', long, value_hint = ValueHint::FilePath)] pub config_file: Option<PathBuf>, } +impl ListTargetArgs { + /// Controls the output of `mirrord ls`. + /// If set to `true`, the command outputs a JSON object that contains more data. + /// Otherwise, it outputs a plain array of target paths. + pub(super) const RICH_OUTPUT_ENV: &str = "MIRRORD_LS_RICH_OUTPUT"; +} + #[derive(Args, Debug)] pub(super) struct ExtensionExecArgs { /// Specify config file to use @@ -834,6 +841,13 @@ pub struct RuntimeArgs { /// Supported command for using mirrord with container runtimes. #[derive(Subcommand, Debug, Clone)] pub(super) enum ContainerRuntimeCommand { + /// Execute a `<runtime> create` command with mirrord loaded. (not supported with ) + #[command(hide = true)] + Create { + /// Arguments that will be propagated to the underlying `<runtime> create` command. + #[arg(allow_hyphen_values = true, trailing_var_arg = true)] + runtime_args: Vec<String>, + }, /// Execute a `<runtime> run` command with mirrord loaded. Run { /// Arguments that will be propagated to the underlying `<runtime> run` command.
@@ -843,14 +857,17 @@ pub(super) enum ContainerRuntimeCommand { } impl ContainerRuntimeCommand { - pub fn run>(runtime_args: impl IntoIterator) -> Self { - ContainerRuntimeCommand::Run { + pub fn create>(runtime_args: impl IntoIterator) -> Self { + ContainerRuntimeCommand::Create { runtime_args: runtime_args.into_iter().map(T::into).collect(), } } pub fn has_publish(&self) -> bool { - let ContainerRuntimeCommand::Run { runtime_args } = self; + let runtime_args = match self { + ContainerRuntimeCommand::Run { runtime_args } => runtime_args, + _ => return false, + }; let mut hit_trailing_token = false; @@ -860,6 +877,15 @@ impl ContainerRuntimeCommand { !hit_trailing_token && matches!(runtime_arg.as_str(), "-p" | "--publish") }) } + + pub fn into_parts(self) -> (Vec, Vec) { + match self { + ContainerRuntimeCommand::Create { runtime_args } => { + (vec!["create".to_owned()], runtime_args) + } + ContainerRuntimeCommand::Run { runtime_args } => (vec!["run".to_owned()], runtime_args), + } + } } #[derive(Args, Debug)] @@ -947,7 +973,9 @@ mod tests { assert_eq!(runtime_args.runtime, ContainerRuntime::Podman); - let ContainerRuntimeCommand::Run { runtime_args } = runtime_args.command; + let ContainerRuntimeCommand::Run { runtime_args } = runtime_args.command else { + panic!("expected run command"); + }; assert_eq!(runtime_args, vec!["-it", "--rm", "debian"]); } @@ -965,7 +993,9 @@ mod tests { assert_eq!(runtime_args.runtime, ContainerRuntime::Podman); - let ContainerRuntimeCommand::Run { runtime_args } = runtime_args.command; + let ContainerRuntimeCommand::Run { runtime_args } = runtime_args.command else { + panic!("expected run command"); + }; assert_eq!(runtime_args, vec!["-it", "--rm", "debian"]); } diff --git a/mirrord/cli/src/container.rs b/mirrord/cli/src/container.rs index afffad9831d..a2659190e0d 100644 --- a/mirrord/cli/src/container.rs +++ b/mirrord/cli/src/container.rs @@ -2,7 +2,6 @@ use std::{ collections::HashMap, io::Write, net::SocketAddr, - ops::Not, 
path::{Path, PathBuf}, process::Stdio, time::Duration, @@ -15,7 +14,6 @@ use mirrord_config::{ external_proxy::{MIRRORD_EXTERNAL_TLS_CERTIFICATE_ENV, MIRRORD_EXTERNAL_TLS_KEY_ENV}, internal_proxy::{ MIRRORD_INTPROXY_CLIENT_TLS_CERTIFICATE_ENV, MIRRORD_INTPROXY_CLIENT_TLS_KEY_ENV, - MIRRORD_INTPROXY_CONTAINER_MODE_ENV, }, LayerConfig, MIRRORD_CONFIG_FILE_ENV, }; @@ -28,20 +26,22 @@ use tokio::{ }; use tracing::Level; use crate::{ - config::{ContainerRuntime, ContainerRuntimeCommand, ExecParams, RuntimeArgs}, + config::{ContainerRuntime, ExecParams, RuntimeArgs}, connection::AGENT_CONNECT_INFO_ENV_KEY, - container::command_builder::RuntimeCommandBuilder, + container::{command_builder::RuntimeCommandBuilder, sidecar::Sidecar}, error::{CliError, CliResult, ContainerError}, execution::{ MirrordExecution, LINUX_INJECTION_ENV_VAR, MIRRORD_CONNECT_TCP_ENV, MIRRORD_EXECUTION_KIND_ENV, }, + logging::pipe_intproxy_sidecar_logs, util::MIRRORD_CONSOLE_ADDR_ENV, }; static CONTAINER_EXECUTION_KIND: ExecutionKind = ExecutionKind::Container; mod command_builder; +mod sidecar; /// Format [`Command`] to look like the executed command (currently without env because we don't /// use it in these scenarios) @@ -65,10 +65,8 @@ async fn exec_and_get_first_line(command: &mut Command) -> Result .spawn() .map_err(ContainerError::UnableToExecuteCommand)?; - tracing::warn!(?child, "spawned watch for child"); - let stdout = child.stdout.take().expect("stdout should be piped"); - let stderr = child.stderr.take().expect("stdout should be piped"); + let stderr = child.stderr.take().expect("stderr should be piped"); let result = tokio::time::timeout(Duration::from_secs(30), async { BufReader::new(stdout) @@ -151,115 +149,6 @@ fn create_self_signed_certificate( Ok((certificate, private_key)) } -/// Create a "sidecar" container that is running `mirrord intproxy` that connects to `mirrord -/// extproxy` running on user machine to be used by execution container (via mounting on same -/// network)
-#[tracing::instrument(level = Level::TRACE, ret)] -async fn create_sidecar_intproxy( - config: &LayerConfig, - base_command: &RuntimeCommandBuilder, - connection_info: Vec<(&str, &str)>, -) -> Result<(String, SocketAddr), ContainerError> { - let mut sidecar_command = base_command.clone(); - - sidecar_command.add_env(MIRRORD_INTPROXY_CONTAINER_MODE_ENV, "true"); - sidecar_command.add_envs(connection_info); - - let cleanup = config.container.cli_prevent_cleanup.not().then_some("--rm"); - - let sidecar_container_command = ContainerRuntimeCommand::run( - config - .container - .cli_extra_args - .iter() - .map(String::as_str) - .chain(cleanup) - .chain(["-d", &config.container.cli_image, "mirrord", "intproxy"]), - ); - - let (runtime_binary, sidecar_args) = sidecar_command - .with_command(sidecar_container_command) - .into_command_args(); - - let mut sidecar_container_spawn = Command::new(&runtime_binary); - - sidecar_container_spawn.args(sidecar_args); - - let sidecar_container_id = exec_and_get_first_line(&mut sidecar_container_spawn) - .await? - .ok_or_else(|| { - ContainerError::UnsuccesfulCommandOutput( - format_command(&sidecar_container_spawn), - "stdout and stderr were empty".to_owned(), - ) - })?; - - // For Docker runtime sometimes the sidecar doesn't start so we double check. 
- // See [#2927](https://github.com/metalbear-co/mirrord/issues/2927) - if matches!(base_command.runtime(), ContainerRuntime::Docker) { - let mut container_inspect_command = Command::new(&runtime_binary); - container_inspect_command - .args(["inspect", &sidecar_container_id]) - .stdout(Stdio::piped()); - - let container_inspect_output = container_inspect_command.output().await.map_err(|err| { - ContainerError::UnsuccesfulCommandOutput( - format_command(&container_inspect_command), - err.to_string(), - ) - })?; - - let (container_inspection,) = - serde_json::from_slice::<(serde_json::Value,)>(&container_inspect_output.stdout) - .unwrap_or_default(); - - let container_status = container_inspection - .get("State") - .and_then(|inspect| inspect.get("Status")); - - if container_status - .map(|status| status == "created") - .unwrap_or(false) - { - let mut container_start_command = Command::new(&runtime_binary); - - container_start_command - .args(["start", &sidecar_container_id]) - .stdin(Stdio::null()) - .stdout(Stdio::null()) - .stderr(Stdio::null()); - - let _ = container_start_command.status().await.map_err(|err| { - ContainerError::UnsuccesfulCommandOutput( - format_command(&container_start_command), - err.to_string(), - ) - })?; - } - } - - // After spawning sidecar with -d flag it prints container_id, now we need the address of - // intproxy running in sidecar to be used by mirrord-layer in execution container - let intproxy_address: SocketAddr = { - let mut attach_command = Command::new(&runtime_binary); - attach_command.args(["logs", "-f", &sidecar_container_id]); - - match exec_and_get_first_line(&mut attach_command).await? 
{ - Some(line) => line - .parse() - .map_err(ContainerError::UnableParseProxySocketAddr)?, - None => { - return Err(ContainerError::UnsuccesfulCommandOutput( - format_command(&attach_command), - "stdout and stderr were empty".into(), - )) - } - } - }; - - Ok((sidecar_container_id, intproxy_address)) -} - type TlsGuard = (NamedTempFile, NamedTempFile); fn prepare_tls_certs_for_container( @@ -315,34 +204,15 @@ fn prepare_tls_certs_for_container( Ok((internal_proxy_tls_guards, external_proxy_tls_guards)) } -/// Main entry point for the `mirrord container` command. -/// This spawns: "agent" - "external proxy" - "intproxy sidecar" - "execution container" -pub(crate) async fn container_command( - runtime_args: RuntimeArgs, - exec_params: ExecParams, +/// Load [`LayerConfig`] from env and create [`AnalyticsReporter`] whilst reporting any warnings. +fn create_config_and_analytics( + progress: &mut P, watch: drain::Watch, -) -> CliResult { - let mut progress = ProgressTracker::from_env("mirrord container"); - - if runtime_args.command.has_publish() { - progress.warning("mirrord container may have problems with \"-p\" directly container in command, please add to \"contanier.cli_extra_args\" in config if you are planning to publish ports"); - } - - progress.warning("mirrord container is currently an unstable feature"); - - for (name, value) in exec_params.as_env_vars()? { - std::env::set_var(name, value); - } - - std::env::set_var( - MIRRORD_EXECUTION_KIND_ENV, - (CONTAINER_EXECUTION_KIND as u32).to_string(), - ); - - let (mut config, mut context) = LayerConfig::from_env_with_warnings()?; +) -> CliResult<(LayerConfig, AnalyticsReporter)> { + let (config, mut context) = LayerConfig::from_env_with_warnings()?; // Initialize only error analytics, extproxy will be the full AnalyticsReporter. 
- let mut analytics = + let analytics = AnalyticsReporter::only_error(config.telemetry, CONTAINER_EXECUTION_KIND, watch); config.verify(&mut context)?; @@ -350,16 +220,22 @@ pub(crate) async fn container_command( progress.warning(warning); } - let (_internal_proxy_tls_guards, _external_proxy_tls_guards) = - prepare_tls_certs_for_container(&mut config)?; - - let composed_config_file = create_composed_config(&config)?; - std::env::set_var(MIRRORD_CONFIG_FILE_ENV, composed_config_file.path()); + Ok((config, analytics)) +} +/// Create [`RuntimeCommandBuilder`] with the corresponding [`Sidecar`] connected to +/// [`MirrordExecution`] as extproxy. +async fn create_runtime_command_with_sidecar( + analytics: &mut AnalyticsReporter, + progress: &mut P, + config: &LayerConfig, + composed_config_path: &Path, + runtime: ContainerRuntime, +) -> CliResult<(RuntimeCommandBuilder, Sidecar, MirrordExecution)> { let mut sub_progress = progress.subtask("preparing to launch process"); let execution_info = - MirrordExecution::start_external(&config, &mut sub_progress, &mut analytics).await?; + MirrordExecution::start_external(config, &mut sub_progress, analytics).await?; let mut connection_info = Vec::new(); let mut execution_info_env_without_connection_info = Vec::new(); @@ -374,7 +250,7 @@ pub(crate) async fn container_command( sub_progress.success(None); - let mut runtime_command = RuntimeCommandBuilder::new(runtime_args.runtime); + let mut runtime_command = RuntimeCommandBuilder::new(runtime); if let Ok(console_addr) = std::env::var(MIRRORD_CONSOLE_ADDR_ENV) { if console_addr @@ -398,8 +274,7 @@ pub(crate) async fn container_command( ); runtime_command.add_env(MIRRORD_CONFIG_FILE_ENV, "/tmp/mirrord-config.json"); - runtime_command - .add_volume::(composed_config_file.path(), "/tmp/mirrord-config.json"); + runtime_command.add_volume::(composed_config_path, "/tmp/mirrord-config.json"); let mut load_env_and_mount_pem = |env: &str, path: &Path| { let container_path = 
format!("/tmp/{}.pem", env.to_lowercase()); @@ -426,11 +301,58 @@ pub(crate) async fn container_command( runtime_command.add_envs(execution_info_env_without_connection_info); - let (sidecar_container_id, sidecar_intproxy_address) = - create_sidecar_intproxy(&config, &runtime_command, connection_info).await?; + let sidecar = Sidecar::create_intproxy(config, &runtime_command, connection_info).await?; + + runtime_command.add_network(sidecar.as_network()); + runtime_command.add_volumes_from(&sidecar.container_id); + + Ok((runtime_command, sidecar, execution_info)) +} + +/// Main entry point for the `mirrord container` command. +/// This spawns: "agent" - "external proxy" - "intproxy sidecar" - "execution container" +pub(crate) async fn container_command( + runtime_args: RuntimeArgs, + exec_params: ExecParams, + watch: drain::Watch, +) -> CliResult { + let mut progress = ProgressTracker::from_env("mirrord container"); + + if runtime_args.command.has_publish() { + progress.warning("mirrord container may have problems with \"-p\" when used as part of the container run command, please add the publish arguments to \"container.cli_extra_args\" in the config if you are planning to publish ports"); + } + + progress.warning("mirrord container is currently an unstable feature"); + + for (name, value) in exec_params.as_env_vars()?
{ + std::env::set_var(name, value); + } + + std::env::set_var( + MIRRORD_EXECUTION_KIND_ENV, + (CONTAINER_EXECUTION_KIND as u32).to_string(), + ); - runtime_command.add_network(format!("container:{sidecar_container_id}")); - runtime_command.add_volumes_from(sidecar_container_id); + // LayerConfig must be created after setting relevant env vars + let (mut config, mut analytics) = create_config_and_analytics(&mut progress, watch)?; + + let (_internal_proxy_tls_guards, _external_proxy_tls_guards) = + prepare_tls_certs_for_container(&mut config)?; + + let composed_config_file = create_composed_config(&config)?; + std::env::set_var(MIRRORD_CONFIG_FILE_ENV, composed_config_file.path()); + + let (mut runtime_command, sidecar, _execution_info) = create_runtime_command_with_sidecar( + &mut analytics, + &mut progress, + &config, + composed_config_file.path(), + runtime_args.runtime, + ) + .await?; + + let (sidecar_intproxy_address, sidecar_intproxy_logs) = sidecar.start().await?; + tokio::spawn(pipe_intproxy_sidecar_logs(&config, sidecar_intproxy_logs).await?); runtime_command.add_env(LINUX_INJECTION_ENV_VAR, config.container.cli_image_lib_path); runtime_command.add_env( @@ -505,15 +427,9 @@ pub(crate) async fn container_ext_command( std::env::set_var("MIRRORD_IMPERSONATED_TARGET", target.clone()); env.insert("MIRRORD_IMPERSONATED_TARGET".into(), target.to_string()); } - let (mut config, mut context) = LayerConfig::from_env_with_warnings()?; - - // Initialize only error analytics, extproxy will be the full AnalyticsReporter. 
- let mut analytics = AnalyticsReporter::only_error(config.telemetry, Default::default(), watch); - config.verify(&mut context)?; - for warning in context.get_warnings() { - progress.warning(warning); - } + // LayerConfig must be created after setting relevant env vars + let (mut config, mut analytics) = create_config_and_analytics(&mut progress, watch)?; let (_internal_proxy_tls_guards, _external_proxy_tls_guards) = prepare_tls_certs_for_container(&mut config)?; @@ -521,86 +437,22 @@ pub(crate) async fn container_ext_command( let composed_config_file = create_composed_config(&config)?; std::env::set_var(MIRRORD_CONFIG_FILE_ENV, composed_config_file.path()); - let mut sub_progress = progress.subtask("preparing to launch process"); - - let execution_info = - MirrordExecution::start_external(&config, &mut sub_progress, &mut analytics).await?; - - let mut connection_info = Vec::new(); - let mut execution_info_env_without_connection_info = Vec::new(); - - for (key, value) in &execution_info.environment { - if key == MIRRORD_CONNECT_TCP_ENV || key == AGENT_CONNECT_INFO_ENV_KEY { - connection_info.push((key.as_str(), value.as_str())); - } else { - execution_info_env_without_connection_info.push((key.as_str(), value.as_str())) - } - } - - sub_progress.success(None); - let container_runtime = std::env::var("MIRRORD_CONTAINER_USE_RUNTIME") .ok() .and_then(|value| ContainerRuntime::from_str(&value, true).ok()) .unwrap_or(ContainerRuntime::Docker); - let mut runtime_command = RuntimeCommandBuilder::new(container_runtime); - - if let Ok(console_addr) = std::env::var(MIRRORD_CONSOLE_ADDR_ENV) { - if console_addr - .parse() - .map(|addr: SocketAddr| !addr.ip().is_loopback()) - .unwrap_or_default() - { - runtime_command.add_env(MIRRORD_CONSOLE_ADDR_ENV, console_addr); - } else { - tracing::warn!( - ?console_addr, - "{MIRRORD_CONSOLE_ADDR_ENV} needs to be a non loopback address when used with containers" - ); - } - } - - runtime_command.add_env(MIRRORD_PROGRESS_ENV, "off"); - 
runtime_command.add_env( - MIRRORD_EXECUTION_KIND_ENV, - (CONTAINER_EXECUTION_KIND as u32).to_string(), - ); - - runtime_command.add_env(MIRRORD_CONFIG_FILE_ENV, "/tmp/mirrord-config.json"); - runtime_command - .add_volume::(composed_config_file.path(), "/tmp/mirrord-config.json"); - - let mut load_env_and_mount_pem = |env: &str, path: &Path| { - let container_path = format!("/tmp/{}.pem", env.to_lowercase()); - - runtime_command.add_env(env, &container_path); - runtime_command.add_volume::(path, container_path); - }; - - if let Some(path) = config.internal_proxy.client_tls_certificate.as_ref() { - load_env_and_mount_pem(MIRRORD_INTPROXY_CLIENT_TLS_CERTIFICATE_ENV, path) - } - - if let Some(path) = config.internal_proxy.client_tls_key.as_ref() { - load_env_and_mount_pem(MIRRORD_INTPROXY_CLIENT_TLS_KEY_ENV, path) - } - - if let Some(path) = config.external_proxy.tls_certificate.as_ref() { - load_env_and_mount_pem(MIRRORD_EXTERNAL_TLS_CERTIFICATE_ENV, path) - } - - if let Some(path) = config.external_proxy.tls_key.as_ref() { - load_env_and_mount_pem(MIRRORD_EXTERNAL_TLS_KEY_ENV, path) - } - - runtime_command.add_envs(execution_info_env_without_connection_info); - - let (sidecar_container_id, sidecar_intproxy_address) = - create_sidecar_intproxy(&config, &runtime_command, connection_info).await?; - - runtime_command.add_network(format!("container:{sidecar_container_id}")); - runtime_command.add_volumes_from(sidecar_container_id); + let (mut runtime_command, sidecar, execution_info) = create_runtime_command_with_sidecar( + &mut analytics, + &mut progress, + &config, + composed_config_file.path(), + container_runtime, + ) + .await?; + + let (sidecar_intproxy_address, sidecar_intproxy_logs) = sidecar.start().await?; + tokio::spawn(pipe_intproxy_sidecar_logs(&config, sidecar_intproxy_logs).await?); runtime_command.add_env(LINUX_INJECTION_ENV_VAR, config.container.cli_image_lib_path); runtime_command.add_env( diff --git a/mirrord/cli/src/container/command_builder.rs 
b/mirrord/cli/src/container/command_builder.rs index 20e0fac883b..c5bf3a6ae39 100644 --- a/mirrord/cli/src/container/command_builder.rs +++ b/mirrord/cli/src/container/command_builder.rs @@ -17,10 +17,6 @@ pub struct RuntimeCommandBuilder { } impl RuntimeCommandBuilder { - pub(super) fn runtime(&self) -> &ContainerRuntime { - &self.runtime - } - fn push_arg(&mut self, value: V) where V: Into, @@ -152,13 +148,12 @@ impl RuntimeCommandBuilder { step, } = self; - let (runtime_command, runtime_args) = match step.command { - ContainerRuntimeCommand::Run { runtime_args } => ("run".to_owned(), runtime_args), - }; + let (runtime_command, runtime_args) = step.command.into_parts(); ( runtime.to_string(), - std::iter::once(runtime_command) + runtime_command + .into_iter() .chain(extra_args) .chain(runtime_args), ) diff --git a/mirrord/cli/src/container/sidecar.rs b/mirrord/cli/src/container/sidecar.rs new file mode 100644 index 00000000000..28641070058 --- /dev/null +++ b/mirrord/cli/src/container/sidecar.rs @@ -0,0 +1,126 @@ +use std::{net::SocketAddr, ops::Not, process::Stdio, time::Duration}; + +use mirrord_config::{internal_proxy::MIRRORD_INTPROXY_CONTAINER_MODE_ENV, LayerConfig}; +use tokio::{ + io::{AsyncBufReadExt, BufReader}, + process::{ChildStderr, ChildStdout, Command}, +}; +use tokio_stream::{wrappers::LinesStream, StreamExt}; +use tracing::Level; + +use crate::{ + config::ContainerRuntimeCommand, + container::{command_builder::RuntimeCommandBuilder, exec_and_get_first_line, format_command}, + error::ContainerError, +}; + +#[derive(Debug)] +pub(crate) struct Sidecar { + pub container_id: String, + pub runtime_binary: String, +} + +impl Sidecar { + /// Creates a "sidecar" container that runs `mirrord intproxy`, connecting to the `mirrord + /// extproxy` that runs on the user's machine. The execution container then uses this proxy + /// by sharing the sidecar's network. + #[tracing::instrument(level = Level::TRACE)] + pub async fn create_intproxy( + config: &LayerConfig,
base_command: &RuntimeCommandBuilder, + connection_info: Vec<(&str, &str)>, + ) -> Result { + let mut sidecar_command = base_command.clone(); + + sidecar_command.add_env(MIRRORD_INTPROXY_CONTAINER_MODE_ENV, "true"); + sidecar_command.add_envs(connection_info); + + let cleanup = config.container.cli_prevent_cleanup.not().then_some("--rm"); + + let sidecar_container_command = ContainerRuntimeCommand::create( + config + .container + .cli_extra_args + .iter() + .map(String::as_str) + .chain(cleanup) + .chain([&config.container.cli_image, "mirrord", "intproxy"]), + ); + + let (runtime_binary, sidecar_args) = sidecar_command + .with_command(sidecar_container_command) + .into_command_args(); + + let mut sidecar_container_spawn = Command::new(&runtime_binary); + sidecar_container_spawn.args(sidecar_args); + + let container_id = exec_and_get_first_line(&mut sidecar_container_spawn) + .await? + .ok_or_else(|| { + ContainerError::UnsuccesfulCommandOutput( + format_command(&sidecar_container_spawn), + "stdout and stderr were empty".to_owned(), + ) + })?; + + Ok(Sidecar { + container_id, + runtime_binary, + }) + } + + pub fn as_network(&self) -> String { + let Sidecar { container_id, .. 
} = self; + format!("container:{container_id}") + } + + #[tracing::instrument(level = Level::TRACE)] + pub async fn start(&self) -> Result<(SocketAddr, SidecarLogs), ContainerError> { + let mut command = Command::new(&self.runtime_binary); + command.args(["start", "--attach", &self.container_id]); + + let mut child = command + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .map_err(ContainerError::UnableToExecuteCommand)?; + + let mut stdout = + BufReader::new(child.stdout.take().expect("stdout should be piped")).lines(); + let stderr = BufReader::new(child.stderr.take().expect("stderr should be piped")).lines(); + + let first_line = tokio::time::timeout(Duration::from_secs(30), async { + stdout.next_line().await.map_err(|error| { + ContainerError::UnableReadCommandStdout(format_command(&command), error) + }) + }) + .await + .map_err(|_| { + ContainerError::UnsuccesfulCommandOutput( + format_command(&command), + "timeout reached for reading first line".into(), + ) + })?? 
+ .ok_or_else(|| { + ContainerError::UnsuccesfulCommandOutput( + format_command(&command), + "unexpected EOF".into(), + ) + })?; + + let internal_proxy_addr: SocketAddr = first_line + .parse() + .map_err(ContainerError::UnableParseProxySocketAddr)?; + + Ok(( + internal_proxy_addr, + LinesStream::new(stdout).merge(LinesStream::new(stderr)), + )) + } +} + +type SidecarLogs = tokio_stream::adapters::Merge< + LinesStream>, + LinesStream>, +>; diff --git a/mirrord/cli/src/execution.rs b/mirrord/cli/src/execution.rs index 7629b466b20..611c45172ce 100644 --- a/mirrord/cli/src/execution.rs +++ b/mirrord/cli/src/execution.rs @@ -7,7 +7,7 @@ use std::{ use mirrord_analytics::{AnalyticsError, AnalyticsReporter, Reporter}; use mirrord_config::{ config::ConfigError, feature::env::mapper::EnvVarsRemapper, - internal_proxy::MIRRORD_INTPROXY_CONNECT_TCP_ENV, LayerConfig, + internal_proxy::MIRRORD_INTPROXY_CONNECT_TCP_ENV, LayerConfig, MIRRORD_RESOLVED_CONFIG_ENV, }; use mirrord_intproxy::agent_conn::AgentConnectInfo; use mirrord_operator::client::OperatorSession; @@ -27,7 +27,7 @@ use tokio::{ sync::mpsc::{self, UnboundedReceiver}, }; use tokio_util::sync::CancellationToken; -use tracing::{debug, error, trace, warn, Level}; +use tracing::{debug, error, info, trace, warn, Level}; #[cfg(all(target_os = "macos", target_arch = "aarch64"))] use crate::extract::extract_arm64; @@ -177,7 +177,7 @@ impl MirrordExecution { /// [`tokio::time::sleep`] or [`tokio::task::yield_now`] after calling this function. #[tracing::instrument(level = Level::TRACE, skip_all)] pub(crate) async fn start

( - config: &LayerConfig, + config: &mut LayerConfig, // We only need the executable on macos, for SIP handling. #[cfg(target_os = "macos")] executable: Option<&str>, progress: &mut P, @@ -269,7 +269,7 @@ impl MirrordExecution { let stdout = proxy_process.stdout.take().expect("stdout was piped"); - let address: SocketAddr = BufReader::new(stdout) + let intproxy_address: SocketAddr = BufReader::new(stdout) .lines() .next_line() .await @@ -288,11 +288,10 @@ impl MirrordExecution { )) })?; - // Provide details for layer to connect to agent via internal proxy - env_vars.insert( - MIRRORD_CONNECT_TCP_ENV.to_string(), - format!("127.0.0.1:{}", address.port()), - ); + config.connect_tcp.replace(intproxy_address.to_string()); + config.update_env_var()?; + let config_as_env = config.to_env_var()?; + env_vars.insert(MIRRORD_RESOLVED_CONFIG_ENV.to_string(), config_as_env); // Fixes // by disabling the fork safety check in the Objective-C runtime. @@ -552,6 +551,7 @@ impl MirrordExecution { match msg.level { LogLevel::Error => error!("Agent log: {}", msg.message), LogLevel::Warn => warn!("Agent log: {}", msg.message), + LogLevel::Info => info!("Agent log: {}", msg.message), } continue; diff --git a/mirrord/cli/src/extension.rs b/mirrord/cli/src/extension.rs index f8d4a0da482..f46e91e5092 100644 --- a/mirrord/cli/src/extension.rs +++ b/mirrord/cli/src/extension.rs @@ -4,13 +4,13 @@ use mirrord_analytics::{AnalyticsError, AnalyticsReporter, Reporter}; use mirrord_config::{LayerConfig, MIRRORD_CONFIG_FILE_ENV}; use mirrord_progress::{JsonProgress, Progress, ProgressTracker}; -use crate::{config::ExtensionExecArgs, error::CliError, execution::MirrordExecution, CliResult}; +use crate::{config::ExtensionExecArgs, execution::MirrordExecution, CliResult}; /// Actually facilitate execution after all preparations were complete async fn mirrord_exec

( #[cfg(target_os = "macos")] executable: Option<&str>, env: HashMap, - config: LayerConfig, + mut config: LayerConfig, mut progress: P, analytics: &mut AnalyticsReporter, ) -> CliResult<()> @@ -21,9 +21,9 @@ where // or run tasks before actually launching. #[cfg(target_os = "macos")] let mut execution_info = - MirrordExecution::start(&config, executable, &mut progress, analytics).await?; + MirrordExecution::start(&mut config, executable, &mut progress, analytics).await?; #[cfg(not(target_os = "macos"))] - let mut execution_info = MirrordExecution::start(&config, &mut progress, analytics).await?; + let mut execution_info = MirrordExecution::start(&mut config, &mut progress, analytics).await?; // We don't execute so set envs aren't passed, so we need to add config file and target to // env. @@ -40,23 +40,15 @@ where pub(crate) async fn extension_exec(args: ExtensionExecArgs, watch: drain::Watch) -> CliResult<()> { let progress = ProgressTracker::try_from_env("mirrord preparing to launch") .unwrap_or_else(|| JsonProgress::new("mirrord preparing to launch").into()); - let mut env: HashMap = HashMap::new(); + // Set environment required for `LayerConfig::from_env_with_warnings`. if let Some(config_file) = args.config_file.as_ref() { - // Set canoncialized path to config file, in case forks/children are in different - // working directories. 
- let full_path = std::fs::canonicalize(config_file) - .map_err(|e| CliError::CanonicalizeConfigPathFailed(config_file.into(), e))?; - std::env::set_var(MIRRORD_CONFIG_FILE_ENV, full_path.clone()); - env.insert( - MIRRORD_CONFIG_FILE_ENV.into(), - full_path.to_string_lossy().into(), - ); + std::env::set_var(MIRRORD_CONFIG_FILE_ENV, config_file); } if let Some(target) = args.target.as_ref() { std::env::set_var("MIRRORD_IMPERSONATED_TARGET", target.clone()); - env.insert("MIRRORD_IMPERSONATED_TARGET".into(), target.to_string()); } + let (config, mut context) = LayerConfig::from_env_with_warnings()?; let mut analytics = AnalyticsReporter::only_error(config.telemetry, Default::default(), watch); @@ -69,14 +61,14 @@ pub(crate) async fn extension_exec(args: ExtensionExecArgs, watch: drain::Watch) #[cfg(target_os = "macos")] let execution_result = mirrord_exec( args.executable.as_deref(), - env, + Default::default(), config, progress, &mut analytics, ) .await; #[cfg(not(target_os = "macos"))] - let execution_result = mirrord_exec(env, config, progress, &mut analytics).await; + let execution_result = mirrord_exec(Default::default(), config, progress, &mut analytics).await; if execution_result.is_err() && !analytics.has_error() { analytics.set_error(AnalyticsError::Unknown); diff --git a/mirrord/cli/src/external_proxy.rs b/mirrord/cli/src/external_proxy.rs index b728416b83a..34bf46b82b8 100644 --- a/mirrord/cli/src/external_proxy.rs +++ b/mirrord/cli/src/external_proxy.rs @@ -20,7 +20,7 @@ //! 
``` use std::{ - fs::{File, OpenOptions}, + fs::File, io, io::BufReader, net::{Ipv4Addr, SocketAddr}, @@ -41,13 +41,13 @@ use tokio::net::{TcpListener, TcpStream}; use tokio_rustls::server::TlsStream; use tokio_util::{either::Either, sync::CancellationToken}; use tracing::Level; -use tracing_subscriber::EnvFilter; use crate::{ connection::AGENT_CONNECT_INFO_ENV_KEY, error::{CliResult, ExternalProxyError}, execution::MIRRORD_EXECUTION_KIND_ENV, internal_proxy::connect_and_ping, + logging::init_extproxy_tracing_registry, util::{create_listen_socket, detach_io}, }; @@ -60,30 +60,11 @@ fn print_addr(listener: &TcpListener) -> io::Result<()> { } pub async fn proxy(listen_port: u16, watch: drain::Watch) -> CliResult<()> { - let config = LayerConfig::from_env()?; + let config = LayerConfig::recalculate_from_env()?; + init_extproxy_tracing_registry(&config)?; tracing::info!(?config, "external_proxy starting"); - if let Some(log_destination) = config.external_proxy.log_destination.as_ref() { - let output_file = OpenOptions::new() - .create(true) - .append(true) - .open(log_destination) - .map_err(|e| ExternalProxyError::OpenLogFile(log_destination.clone(), e))?; - - let tracing_registry = tracing_subscriber::fmt() - .with_writer(output_file) - .with_ansi(false); - - if let Some(log_level) = config.external_proxy.log_level.as_ref() { - tracing_registry - .with_env_filter(EnvFilter::builder().parse_lossy(log_level)) - .init(); - } else { - tracing_registry.init(); - } - } - let agent_connect_info = std::env::var(AGENT_CONNECT_INFO_ENV_KEY) .ok() .map(|var| { diff --git a/mirrord/cli/src/internal_proxy.rs b/mirrord/cli/src/internal_proxy.rs index 1c8763a0d92..e8210962e84 100644 --- a/mirrord/cli/src/internal_proxy.rs +++ b/mirrord/cli/src/internal_proxy.rs @@ -11,12 +11,9 @@ //! or let the [`OperatorApi`](mirrord_operator::client::OperatorApi) handle the connection. 
use std::{ - env, - fs::OpenOptions, - io, + env, io, net::{Ipv4Addr, SocketAddr}, - path::PathBuf, - time::{Duration, SystemTime}, + time::Duration, }; use mirrord_analytics::{AnalyticsReporter, CollectAnalytics, Reporter}; @@ -28,15 +25,14 @@ use mirrord_intproxy::{ }; use mirrord_protocol::{ClientMessage, DaemonMessage, LogLevel, LogMessage}; use nix::sys::resource::{setrlimit, Resource}; -use rand::{distributions::Alphanumeric, Rng}; use tokio::net::TcpListener; use tracing::{warn, Level}; -use tracing_subscriber::EnvFilter; use crate::{ connection::AGENT_CONNECT_INFO_ENV_KEY, error::{CliResult, InternalProxyError}, execution::MIRRORD_EXECUTION_KIND_ENV, + logging::init_intproxy_tracing_registry, util::{create_listen_socket, detach_io}, }; @@ -54,46 +50,11 @@ pub(crate) async fn proxy( listen_port: u16, watch: drain::Watch, ) -> CliResult<(), InternalProxyError> { - let config = LayerConfig::from_env()?; + let config = LayerConfig::recalculate_from_env()?; + init_intproxy_tracing_registry(&config)?; tracing::info!(?config, "internal_proxy starting"); - // Setting up default logging for intproxy. 
- let log_destination = config - .internal_proxy - .log_destination - .as_ref() - .map(PathBuf::from) - .unwrap_or_else(|| { - let random_name: String = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(7) - .map(char::from) - .collect(); - let timestamp = SystemTime::UNIX_EPOCH.elapsed().unwrap().as_secs(); - - PathBuf::from(format!( - "/tmp/mirrord-intproxy-{timestamp}-{random_name}.log" - )) - }); - - let output_file = OpenOptions::new() - .create(true) - .append(true) - .open(&log_destination) - .map_err(|fail| { - InternalProxyError::OpenLogFile(log_destination.to_string_lossy().to_string(), fail) - })?; - - let log_level = config.internal_proxy.log_level.as_deref().unwrap_or("info"); - - tracing_subscriber::fmt() - .with_writer(output_file) - .with_ansi(false) - .with_env_filter(EnvFilter::builder().parse_lossy(log_level)) - .pretty() - .init(); - // According to https://wilsonmar.github.io/maximum-limits/ this is the limit on macOS // so we assume Linux can be higher and set to that. if let Err(error) = setrlimit(Resource::RLIMIT_NOFILE, 12288, 12288) { diff --git a/mirrord/cli/src/list.rs b/mirrord/cli/src/list.rs new file mode 100644 index 00000000000..2b8059bad62 --- /dev/null +++ b/mirrord/cli/src/list.rs @@ -0,0 +1,215 @@ +use std::sync::LazyLock; + +use futures::TryStreamExt; +use k8s_openapi::api::core::v1::Namespace; +use kube::Client; +use mirrord_analytics::NullReporter; +use mirrord_config::{ + config::{ConfigContext, MirrordConfig}, + LayerConfig, LayerFileConfig, +}; +use mirrord_kube::{ + api::kubernetes::{create_kube_config, seeker::KubeResourceSeeker}, + error::KubeApiError, +}; +use mirrord_operator::client::OperatorApi; +use semver::VersionReq; +use serde::{ser::SerializeSeq, Serialize, Serializer}; + +use crate::{util, CliError, CliResult, Format, ListTargetArgs}; + +/// A mirrord target found in the cluster. +#[derive(Serialize)] +struct FoundTarget { + /// E.g `pod/my-pod-1234/container/my-container`. 
+ path: String, + + /// Whether this target is currently available. + /// + /// # Note + /// + /// Right now this is always true. Some preliminary checks are done in the + /// [`KubeResourceSeeker`] and results come filtered. + /// + /// This field is here for forward compatibility, because in the future we might want to return + /// unavailable targets as well (along with some validation error message) to improve UX. + available: bool, +} + +/// Result of mirrord targets lookup in the cluster. +#[derive(Serialize)] +struct FoundTargets { + /// In order: + /// 1. deployments + /// 2. rollouts + /// 3. statefulsets + /// 4. cronjobs + /// 5. jobs + /// 6. pods + targets: Vec<FoundTarget>, + + /// Current lookup namespace. + /// + /// Taken from [`LayerConfig::target`], defaults to [`Client`]'s default namespace. + current_namespace: String, + + /// Available lookup namespaces. + namespaces: Vec<String>, +} + +impl FoundTargets { + /// Performs a lookup of mirrord targets in the cluster. + /// + /// Unless the operator is explicitly disabled, attempts to connect with it. + /// Operator lookup affects returned results (e.g some targets are only available via the + /// operator). + /// + /// If `fetch_namespaces` is set, returned [`FoundTargets`] will contain info about namespaces + /// available in the cluster. + async fn resolve(config: LayerConfig, fetch_namespaces: bool) -> CliResult<Self> { + let client = create_kube_config( + config.accept_invalid_certificates, + config.kubeconfig.clone(), + config.kube_context.clone(), + ) + .await + .and_then(|config| Client::try_from(config).map_err(From::from)) + .map_err(|error| { + CliError::friendlier_error_or_else(error, CliError::CreateKubeApiFailed) + })?; + + let mut reporter = NullReporter::default(); + let operator_api = if config.operator != Some(false) + && let Some(api) = OperatorApi::try_new(&config, &mut reporter).await?
+ { + let api = api.prepare_client_cert(&mut reporter).await; + + api.inspect_cert_error( + |error| tracing::error!(%error, "failed to prepare client certificate"), + ); + + Some(api) + } else { + None + }; + + let seeker = KubeResourceSeeker { + client: &client, + namespace: config.target.namespace.as_deref(), + }; + let paths = match operator_api { + None if config.operator == Some(true) => Err(CliError::OperatorNotInstalled), + + Some(api) + if ALL_TARGETS_SUPPORTED_OPERATOR_VERSION + .matches(&api.operator().spec.operator_version) => + { + seeker.all().await.map_err(|error| { + CliError::friendlier_error_or_else(error, CliError::ListTargetsFailed) + }) + } + + _ => seeker.all_open_source().await.map_err(|error| { + CliError::friendlier_error_or_else(error, CliError::ListTargetsFailed) + }), + }?; + + let targets = paths + .into_iter() + .map(|path| FoundTarget { + path, + available: true, + }) + .collect(); + let current_namespace = config + .target + .namespace + .as_deref() + .unwrap_or(client.default_namespace()) + .to_owned(); + + let namespaces = if fetch_namespaces { + seeker + .list_all_clusterwide::<Namespace>(None) + .try_filter_map(|namespace| std::future::ready(Ok(namespace.metadata.name))) + .try_collect::<Vec<_>>() + .await + .map_err(KubeApiError::KubeError) + .map_err(|error| { + CliError::friendlier_error_or_else(error, CliError::ListTargetsFailed) + })? + } else { + Default::default() + }; + + Ok(Self { + targets, + current_namespace, + namespaces, + }) + } +} + +/// Thin wrapper over [`FoundTargets`] that implements [`Serialize`]. +/// Its serialized format is a sequence of available target paths. +/// +/// Used to print available targets when the plugin/extension does not support the full format +/// (backward compatibility).
+struct FoundTargetsList<'a>(&'a FoundTargets); + +impl Serialize for FoundTargetsList<'_> { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + let count = self.0.targets.iter().filter(|t| t.available).count(); + let mut list = serializer.serialize_seq(Some(count))?; + + for target in self.0.targets.iter().filter(|t| t.available) { + list.serialize_element(&target.path)?; + } + + list.end() + } +} + +/// Controls whether we support listing all targets or just the open source ones. +static ALL_TARGETS_SUPPORTED_OPERATOR_VERSION: LazyLock<VersionReq> = + LazyLock::new(|| ">=3.84.0".parse().expect("version should be valid")); + +/// Fetches mirrord targets from the cluster and prints output to stdout. +/// +/// When `rich_output` is set, targets info is printed as a JSON object containing extra data. +/// Otherwise, targets are printed as a plain JSON array of strings (backward compatibility). +pub(super) async fn print_targets(args: ListTargetArgs, rich_output: bool) -> CliResult<()> { + let mut layer_config = if let Some(config) = &args.config_file { + let mut cfg_context = ConfigContext::default(); + LayerFileConfig::from_path(config)?.generate_config(&mut cfg_context)? + } else { + LayerConfig::from_env()?
+ }; + + if let Some(namespace) = args.namespace { + layer_config.target.namespace.replace(namespace); + }; + + if !layer_config.use_proxy { + util::remove_proxy_env(); + } + + let targets = FoundTargets::resolve(layer_config, rich_output).await?; + + match args.output { + Format::Json => { + let serialized = if rich_output { + serde_json::to_string(&targets).unwrap() + } else { + serde_json::to_string(&FoundTargetsList(&targets)).unwrap() + }; + + println!("{serialized}"); + } + } + + Ok(()) +} diff --git a/mirrord/cli/src/logging.rs b/mirrord/cli/src/logging.rs new file mode 100644 index 00000000000..8b9fba787ba --- /dev/null +++ b/mirrord/cli/src/logging.rs @@ -0,0 +1,192 @@ +use std::{ + fs::OpenOptions, + future::Future, + path::{Path, PathBuf}, + time::SystemTime, +}; + +use futures::StreamExt; +use mirrord_config::LayerConfig; +use rand::{distributions::Alphanumeric, Rng}; +use tokio::io::AsyncWriteExt; +use tokio_stream::Stream; +use tracing_subscriber::{prelude::*, EnvFilter}; + +use crate::{ + config::Commands, + error::{CliError, ExternalProxyError, InternalProxyError}, +}; + +// only ls and ext commands need the errors in json format +// error logs are disabled for extensions +fn init_ext_error_handler(commands: &Commands) -> bool { + match commands { + Commands::ListTargets(_) | Commands::ExtensionExec(_) => { + let _ = miette::set_hook(Box::new(|_| Box::new(miette::JSONReportHandler::new()))); + + true + } + _ => false, + } +} + +pub async fn init_tracing_registry( + command: &Commands, + watch: drain::Watch, +) -> Result<(), CliError> { + if let Ok(console_addr) = std::env::var("MIRRORD_CONSOLE_ADDR") { + mirrord_console::init_async_logger(&console_addr, watch.clone(), 124).await?; + + return Ok(()); + } + + if matches!( + command, + Commands::InternalProxy { .. } | Commands::ExternalProxy { .. 
} + ) { + return Ok(()); + } + + // There are situations where even if running "ext" commands that shouldn't log, we want those + // to log to be able to debug issues. + let force_log = std::env::var("MIRRORD_FORCE_LOG") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(false); + + if force_log || init_ext_error_handler(command) { + tracing_subscriber::registry() + .with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr)) + .with(tracing_subscriber::EnvFilter::from_default_env()) + .init(); + } + + Ok(()) +} + +fn default_logfile_path(prefix: &str) -> PathBuf { + let random_name: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(7) + .map(char::from) + .collect(); + let timestamp = SystemTime::UNIX_EPOCH.elapsed().unwrap().as_secs(); + + PathBuf::from(format!("/tmp/{prefix}-{timestamp}-{random_name}.log")) +} + +fn init_proxy_tracing_registry( + log_destination: &Path, + log_level: Option<&str>, +) -> std::io::Result<()> { + if std::env::var("MIRRORD_CONSOLE_ADDR").is_ok() { + return Ok(()); + } + + let output_file = OpenOptions::new() + .create(true) + .append(true) + .open(log_destination)?; + + let env_filter = log_level + .map(|log_level| EnvFilter::builder().parse_lossy(log_level)) + .unwrap_or_else(EnvFilter::from_default_env); + + tracing_subscriber::fmt() + .with_writer(output_file) + .with_ansi(false) + .with_env_filter(env_filter) + .pretty() + .init(); + + Ok(()) +} + +pub fn init_intproxy_tracing_registry(config: &LayerConfig) -> Result<(), InternalProxyError> { + if !config.internal_proxy.container_mode { + // Setting up default logging for intproxy. 
+ let log_destination = config + .internal_proxy + .log_destination + .as_ref() + .map(PathBuf::from) + .unwrap_or_else(|| default_logfile_path("mirrord-intproxy")); + + init_proxy_tracing_registry(&log_destination, config.internal_proxy.log_level.as_deref()) + .map_err(|fail| { + InternalProxyError::OpenLogFile(log_destination.to_string_lossy().to_string(), fail) + }) + } else { + let env_filter = config + .internal_proxy + .log_level + .as_ref() + .map(|log_level| EnvFilter::builder().parse_lossy(log_level)) + .unwrap_or_else(EnvFilter::from_default_env); + + tracing_subscriber::fmt() + .with_writer(std::io::stderr) + .with_ansi(false) + .with_env_filter(env_filter) + .pretty() + .init(); + + Ok(()) + } +} + +pub fn init_extproxy_tracing_registry(config: &LayerConfig) -> Result<(), ExternalProxyError> { + // Setting up default logging for extproxy. + let log_destination = config + .external_proxy + .log_destination + .as_ref() + .map(PathBuf::from) + .unwrap_or_else(|| default_logfile_path("mirrord-extproxy")); + + init_proxy_tracing_registry(&log_destination, config.external_proxy.log_level.as_deref()) + .map_err(|fail| { + ExternalProxyError::OpenLogFile(log_destination.to_string_lossy().to_string(), fail) + }) +} + +pub async fn pipe_intproxy_sidecar_logs<'s, S>( + config: &LayerConfig, + stream: S, +) -> Result<impl Future<Output = ()> + 's, InternalProxyError> +where + S: Stream<Item = std::io::Result<String>> + 's, +{ + let log_destination = config + .internal_proxy + .log_destination + .as_ref() + .map(PathBuf::from) + .unwrap_or_else(|| default_logfile_path("mirrord-intproxy")); + + let mut output_file = tokio::fs::OpenOptions::new() + .create(true) + .append(true) + .open(&log_destination) + .await + .map_err(|fail| { + InternalProxyError::OpenLogFile(log_destination.to_string_lossy().to_string(), fail) + })?; + + Ok(async move { + let mut stream = std::pin::pin!(stream); + + while let Some(line) = stream.next().await { + let result: std::io::Result<_> = try { + output_file.write_all(line?.as_bytes()).await?;
+ output_file.write_u8(b'\n').await?; + + output_file.flush().await?; + }; + + if let Err(error) = result { + tracing::error!(?error, "unable to pipe logs from intproxy"); + } + } + }) +} diff --git a/mirrord/cli/src/main.rs b/mirrord/cli/src/main.rs index 27923cab4fd..2ca498c83fc 100644 --- a/mirrord/cli/src/main.rs +++ b/mirrord/cli/src/main.rs @@ -5,7 +5,7 @@ use std::{ collections::HashMap, env::vars, ffi::CString, net::SocketAddr, os::unix::ffi::OsStrExt, - sync::LazyLock, time::Duration, + time::Duration, }; use clap::{CommandFactory, Parser}; @@ -17,13 +17,10 @@ use diagnose::diagnose_command; use execution::MirrordExecution; use extension::extension_exec; use extract::extract_library; -use kube::Client; -use miette::JSONReportHandler; use mirrord_analytics::{ - AnalyticsError, AnalyticsReporter, CollectAnalytics, ExecutionKind, NullReporter, Reporter, + AnalyticsError, AnalyticsReporter, CollectAnalytics, ExecutionKind, Reporter, }; use mirrord_config::{ - config::{ConfigContext, MirrordConfig}, feature::{ fs::FsModeConfig, network::{ @@ -34,18 +31,14 @@ use mirrord_config::{ LayerConfig, LayerFileConfig, MIRRORD_CONFIG_FILE_ENV, }; use mirrord_intproxy::agent_conn::{AgentConnection, AgentConnectionError}; -use mirrord_kube::api::kubernetes::{create_kube_config, seeker::KubeResourceSeeker}; -use mirrord_operator::client::OperatorApi; use mirrord_progress::{messages::EXEC_CONTAINER_BINARY, Progress, ProgressTracker}; #[cfg(all(target_os = "macos", target_arch = "aarch64"))] use nix::errno::Errno; use operator::operator_command; use port_forward::{PortForwardError, PortForwarder, ReversePortForwarder}; use regex::Regex; -use semver::{Version, VersionReq}; -use serde_json::json; +use semver::Version; use tracing::{error, info, warn}; -use tracing_subscriber::{fmt, prelude::*, registry, EnvFilter}; use which::which; mod config; @@ -58,8 +51,10 @@ mod extension; mod external_proxy; mod extract; mod internal_proxy; +mod list; +mod logging; mod operator; -pub mod 
port_forward; +mod port_forward; mod teams; mod util; mod verify_config; @@ -68,14 +63,8 @@ mod vpn; pub(crate) use error::{CliError, CliResult}; use verify_config::verify_config; -use crate::util::remove_proxy_env; - -/// Controls whether we support listing all targets or just the open source ones. -static ALL_TARGETS_SUPPORTED_OPERATOR_VERSION: LazyLock<VersionReq> = - LazyLock::new(|| ">=3.84.0".parse().expect("verion should be valid")); - async fn exec_process<P>

( - config: LayerConfig, + mut config: LayerConfig, args: &ExecArgs, progress: &P, analytics: &mut AnalyticsReporter, @@ -86,10 +75,15 @@ where let mut sub_progress = progress.subtask("preparing to launch process"); #[cfg(target_os = "macos")] - let execution_info = - MirrordExecution::start(&config, Some(&args.binary), &mut sub_progress, analytics).await?; + let execution_info = MirrordExecution::start( + &mut config, + Some(&args.binary), + &mut sub_progress, + analytics, + ) + .await?; #[cfg(not(target_os = "macos"))] - let execution_info = MirrordExecution::start(&config, &mut sub_progress, analytics).await?; + let execution_info = MirrordExecution::start(&mut config, &mut sub_progress, analytics).await?; // This is not being yielded, as this is not proper async, something along those lines. // We need an `await` somewhere in this function to drive our socket IO that happens @@ -363,6 +357,7 @@ async fn exec(args: &ExecArgs, watch: drain::Watch) -> CliResult<()> { std::env::set_var(name, value); } + // LayerConfig must be created after setting relevant env vars let (config, mut context) = LayerConfig::from_env_with_warnings()?; let mut analytics = AnalyticsReporter::only_error(config.telemetry, Default::default(), watch); @@ -382,100 +377,6 @@ async fn exec(args: &ExecArgs, watch: drain::Watch) -> CliResult<()> { execution_result } -/// Lists targets based on whether or not the operator has been enabled in `layer_config`. -/// If the operator is enabled (and we can reach it), then we list [`KubeResourceSeeker::all`] -/// targets, otherwise we list [`KubeResourceSeeker::all_open_source`] only. 
-async fn list_targets(layer_config: &LayerConfig, args: &ListTargetArgs) -> CliResult<Vec<String>> { - let client = create_kube_config( - layer_config.accept_invalid_certificates, - layer_config.kubeconfig.clone(), - layer_config.kube_context.clone(), - ) - .await - .and_then(|config| Client::try_from(config).map_err(From::from)) - .map_err(|error| CliError::friendlier_error_or_else(error, CliError::CreateKubeApiFailed))?; - - let namespace = args - .namespace - .as_deref() - .or(layer_config.target.namespace.as_deref()); - - let seeker = KubeResourceSeeker { - client: &client, - namespace, - }; - - let mut reporter = NullReporter::default(); - - let operator_api = if layer_config.operator != Some(false) - && let Some(api) = OperatorApi::try_new(layer_config, &mut reporter).await? - { - let api = api.prepare_client_cert(&mut reporter).await; - - api.inspect_cert_error( - |error| tracing::error!(%error, "failed to prepare client certificate"), - ); - - Some(api) - } else { - None - }; - - match operator_api { - None if layer_config.operator == Some(true) => Err(CliError::OperatorNotInstalled), - Some(api) - if ALL_TARGETS_SUPPORTED_OPERATOR_VERSION - .matches(&api.operator().spec.operator_version) => - { - seeker.all().await.map_err(|error| { - CliError::friendlier_error_or_else(error, CliError::ListTargetsFailed) - }) - } - _ => seeker.all_open_source().await.map_err(|error| { - CliError::friendlier_error_or_else(error, CliError::ListTargetsFailed) - }), - } -} - -/// Lists all possible target paths. -/// Tries to use operator if available, otherwise falls back to k8s API (if operator isn't -/// explicitly true).
Example: -/// ``` -/// [ -/// "pod/metalbear-deployment-85c754c75f-982p5", -/// "pod/nginx-deployment-66b6c48dd5-dc9wk", -/// "pod/py-serv-deployment-5c57fbdc98-pdbn4/container/py-serv", -/// "deployment/nginx-deployment" -/// "deployment/nginx-deployment/container/nginx" -/// "rollout/nginx-rollout" -/// "statefulset/nginx-statefulset" -/// "statefulset/nginx-statefulset/container/nginx" -/// ] -/// ``` -async fn print_targets(args: &ListTargetArgs) -> CliResult<()> { - let mut layer_config = if let Some(config) = &args.config_file { - let mut cfg_context = ConfigContext::default(); - LayerFileConfig::from_path(config)?.generate_config(&mut cfg_context)? - } else { - LayerConfig::from_env()? - }; - - if let Some(namespace) = &args.namespace { - layer_config.target.namespace = Some(namespace.clone()); - }; - - if !layer_config.use_proxy { - remove_proxy_env(); - } - - // The targets come sorted in the following order: - // `deployments - rollouts - statefulsets - cronjobs - jobs - pods` - let targets = list_targets(&layer_config, args).await?; - let json_obj = json!(targets); - println!("{json_obj}"); - Ok(()) -} - async fn port_forward(args: &PortForwardArgs, watch: drain::Watch) -> CliResult<()> { fn hash_port_mappings( args: &PortForwardArgs, @@ -577,6 +478,7 @@ async fn port_forward(args: &PortForwardArgs, watch: drain::Watch) -> CliResult< std::env::set_var("MIRRORD_CONFIG_FILE", config_file); } + // LayerConfig must be created after setting relevant env vars let (config, mut context) = LayerConfig::from_env_with_warnings()?; let mut analytics = AnalyticsReporter::new(config.telemetry, ExecutionKind::PortForward, watch); @@ -650,21 +552,8 @@ fn main() -> miette::Result<()> { let (signal, watch) = drain::channel(); - // There are situations where even if running "ext" commands that shouldn't log, we want those - // to log to be able to debug issues. 
- let force_log = std::env::var("MIRRORD_FORCE_LOG") - .map(|s| s.parse().unwrap_or(false)) - .unwrap_or(false); - let res: CliResult<(), CliError> = rt.block_on(async move { - if let Ok(console_addr) = std::env::var("MIRRORD_CONSOLE_ADDR") { - mirrord_console::init_async_logger(&console_addr, watch.clone(), 124).await?; - } else if force_log || !init_ext_error_handler(&cli.commands) { - registry() - .with(fmt::layer().with_writer(std::io::stderr)) - .with(EnvFilter::from_default_env()) - .init(); - } + logging::init_tracing_registry(&cli.commands, watch.clone()).await?; match cli.commands { Commands::Exec(args) => exec(&args, watch).await?, @@ -675,7 +564,14 @@ fn main() -> miette::Result<()> { false, )?; } - Commands::ListTargets(args) => print_targets(&args).await?, + Commands::ListTargets(args) => { + let rich_output = std::env::var(ListTargetArgs::RICH_OUTPUT_ENV) + .ok() + .and_then(|value| value.parse::<bool>().ok()) + .unwrap_or_default(); + + list::print_targets(*args, rich_output).await? + } Commands::Operator(args) => operator_command(*args).await?, Commands::ExtensionExec(args) => { extension_exec(*args, watch).await?; @@ -720,19 +616,6 @@ fn main() -> miette::Result<()> { res.map_err(Into::into) } -// only ls and ext commands need the errors in json format -// error logs are disabled for extensions -fn init_ext_error_handler(commands: &Commands) -> bool { - match commands { - Commands::ListTargets(_) | Commands::ExtensionExec(_) => { - let _ = miette::set_hook(Box::new(|_| Box::new(JSONReportHandler::new()))); - true - } - Commands::InternalProxy { ..
} => true, - _ => false, - } -} - async fn prompt_outdated_version(progress: &ProgressTracker) { let mut progress = progress.subtask("version check"); let check_version: bool = std::env::var("MIRRORD_CHECK_VERSION") diff --git a/mirrord/cli/src/port_forward.rs b/mirrord/cli/src/port_forward.rs index ba59679d945..220dc1c337e 100644 --- a/mirrord/cli/src/port_forward.rs +++ b/mirrord/cli/src/port_forward.rs @@ -1,6 +1,6 @@ use std::{ collections::{HashMap, HashSet, VecDeque}, - net::{IpAddr, SocketAddr}, + net::{IpAddr, Ipv4Addr, SocketAddr}, time::{Duration, Instant}, }; @@ -11,12 +11,8 @@ use mirrord_config::feature::network::incoming::{ }; use mirrord_intproxy::{ background_tasks::{BackgroundTasks, TaskError, TaskSender, TaskUpdate}, - error::IntProxyError, - main_tasks::{MainTaskId, ProxyMessage, ToLayer}, - proxies::incoming::{ - port_subscription_ext::PortSubscriptionExt, IncomingProxy, IncomingProxyError, - IncomingProxyMessage, - }, + main_tasks::{ProxyMessage, ToLayer}, + proxies::incoming::{IncomingProxy, IncomingProxyError, IncomingProxyMessage}, }; use mirrord_intproxy_protocol::{ IncomingRequest, IncomingResponse, LayerId, PortSubscribe, PortSubscription, @@ -28,9 +24,8 @@ use mirrord_protocol::{ tcp::{DaemonTcpOutgoing, LayerTcpOutgoing}, LayerClose, LayerConnect, LayerWrite, SocketAddress, }, - tcp::{Filter, HttpFilter, LayerTcp, LayerTcpSteal, StealType}, - ClientMessage, ConnectionId, DaemonMessage, LogLevel, Port, ResponseError, - CLIENT_READY_FOR_LOGS, + tcp::{Filter, HttpFilter, StealType}, + ClientMessage, ConnectionId, DaemonMessage, LogLevel, Port, CLIENT_READY_FOR_LOGS, }; use thiserror::Error; use tokio::{ @@ -316,6 +311,7 @@ impl PortForwarder { DaemonMessage::LogMessage(log_message) => match log_message.level { LogLevel::Warn => tracing::warn!("agent log: {}", log_message.message), LogLevel::Error => tracing::error!("agent log: {}", log_message.message), + LogLevel::Info => tracing::info!("agent log: {}", log_message.message), }, 
DaemonMessage::Close(error) => { return Err(PortForwardError::AgentError(error)); @@ -427,44 +423,57 @@ impl PortForwarder { } pub struct ReversePortForwarder { - /// details for traffic mirroring or stealing - incoming_mode: IncomingMode, /// communicates with the agent (only TCP supported). agent_connection: AgentConnection, - /// associates destination ports with local ports. - mappings: HashMap<RemotePort, u16>, - /// background task (uses IncomingProxy to communicate with layer) - background_tasks: BackgroundTasks<MainTaskId, ProxyMessage, IntProxyError>, + /// background task (uses [`IncomingProxy`] to communicate with layer) + background_tasks: BackgroundTasks<(), ProxyMessage, IncomingProxyError>, /// incoming proxy background task tx incoming_proxy: TaskSender<IncomingProxy>, - - /// true if Ping has been sent to agent. + /// `true` if [`ClientMessage::Ping`] has been sent to agent and we're waiting for the + /// [`DaemonMessage::Pong`] waiting_for_pong: bool, ping_pong_timeout: Instant, } impl ReversePortForwarder { pub(crate) async fn new( - agent_connection: AgentConnection, + mut agent_connection: AgentConnection, mappings: HashMap<RemotePort, u16>, network_config: IncomingConfig, ) -> Result<Self, PortForwardError> { - // setup IncomingProxy - let mut background_tasks: BackgroundTasks<MainTaskId, ProxyMessage, IntProxyError> = + let mut background_tasks: BackgroundTasks<(), ProxyMessage, IncomingProxyError> = Default::default(); - let incoming = - background_tasks.register(IncomingProxy::default(), MainTaskId::IncomingProxy, 512); - // construct IncomingMode from config file + let incoming = background_tasks.register(IncomingProxy::default(), (), 512); + + agent_connection + .sender + .send(ClientMessage::SwitchProtocolVersion( + mirrord_protocol::VERSION.clone(), + )) + .await?; + let protocol_version = match agent_connection.receiver.recv().await { + Some(DaemonMessage::SwitchProtocolVersionResponse(version)) => version, + _ => return Err(PortForwardError::AgentConnectionFailed), + }; + + if CLIENT_READY_FOR_LOGS.matches(&protocol_version) { + agent_connection + .sender +
.send(ClientMessage::ReadyForLogs) + .await?; + } + + incoming + .send(IncomingProxyMessage::AgentProtocolVersion(protocol_version)) + .await; + let incoming_mode = IncomingMode::new(&network_config); for (i, (&remote, &local)) in mappings.iter().enumerate() { - // send subscription to incoming proxy let subscription = incoming_mode.subscription(remote); let message_id = i as u64; let layer_id = LayerId(1); let req = IncomingRequest::PortSubscribe(PortSubscribe { - listening_on: format!("127.0.0.1:{local}") - .parse() - .expect("Error parsing socket address"), + listening_on: SocketAddr::new(Ipv4Addr::LOCALHOST.into(), local), subscription, }); incoming @@ -475,9 +484,7 @@ impl ReversePortForwarder { } Ok(Self { - incoming_mode, agent_connection, - mappings, background_tasks, incoming_proxy: incoming, waiting_for_pong: false, @@ -486,31 +493,6 @@ impl ReversePortForwarder { } pub(crate) async fn run(&mut self) -> Result<(), PortForwardError> { - // setup agent connection - self.agent_connection - .sender - .send(ClientMessage::SwitchProtocolVersion( - mirrord_protocol::VERSION.clone(), - )) - .await?; - match self.agent_connection.receiver.recv().await { - Some(DaemonMessage::SwitchProtocolVersionResponse(version)) - if CLIENT_READY_FOR_LOGS.matches(&version) => - { - self.agent_connection - .sender - .send(ClientMessage::ReadyForLogs) - .await?; - } - _ => return Err(PortForwardError::AgentConnectionFailed), - } - - for remote_port in self.mappings.keys() { - let subscription = self.incoming_mode.subscription(*remote_port); - let msg = subscription.agent_subscribe(); - self.agent_connection.sender.send(msg).await? - } - loop { select! { _ = tokio::time::sleep_until(self.ping_pong_timeout.into()) => { @@ -531,8 +513,8 @@ impl ReversePortForwarder { }, }, - Some((task_id, update)) = self.background_tasks.next() => { - self.handle_msg_from_local(task_id, update).await? + Some((_, update)) = self.background_tasks.next() => { + self.handle_msg_from_local(update).await? 
}, } } @@ -557,6 +539,7 @@ impl ReversePortForwarder { DaemonMessage::LogMessage(log_message) => match log_message.level { LogLevel::Warn => tracing::warn!("agent log: {}", log_message.message), LogLevel::Error => tracing::error!("agent log: {}", log_message.message), + LogLevel::Info => tracing::info!("agent log: {}", log_message.message), }, DaemonMessage::Close(error) => { return Err(PortForwardError::AgentError(error)); @@ -564,8 +547,8 @@ impl ReversePortForwarder { DaemonMessage::Pong if self.waiting_for_pong => { self.waiting_for_pong = false; } + // Includes unexpected DaemonMessage::Pong other => { - // includes unexepcted DaemonMessage::Pong return Err(PortForwardError::AgentError(format!( "unexpected message from agent: {other:?}" ))); @@ -578,20 +561,11 @@ impl ReversePortForwarder { #[tracing::instrument(level = Level::TRACE, skip(self), err)] async fn handle_msg_from_local( &mut self, - task_id: MainTaskId, - update: TaskUpdate, + update: TaskUpdate, ) -> Result<(), PortForwardError> { - match (task_id, update) { - (MainTaskId::IncomingProxy, TaskUpdate::Message(message)) => match message { + match update { + TaskUpdate::Message(message) => match message { ProxyMessage::ToAgent(message) => { - if matches!( - message, - ClientMessage::TcpSteal(LayerTcpSteal::PortSubscribe(_)) - | ClientMessage::Tcp(LayerTcp::PortSubscribe(_)) - ) { - // suppress additional subscription requests - return Ok(()); - } self.agent_connection.sender.send(message).await?; } ProxyMessage::ToLayer(ToLayer { @@ -599,9 +573,7 @@ impl ReversePortForwarder { .. 
}) => { if let Err(error) = res { - return Err(PortForwardError::from(IntProxyError::from( - IncomingProxyError::SubscriptionFailed(error), - ))); + return Err(IncomingProxyError::SubscriptionFailed(error).into()); } } other => { @@ -610,21 +582,20 @@ impl ReversePortForwarder { ) } }, - (MainTaskId::IncomingProxy, TaskUpdate::Finished(result)) => match result { + + TaskUpdate::Finished(result) => match result { Ok(()) => { - tracing::error!("incoming proxy task finished unexpectedly"); - return Err(IntProxyError::TaskExit(task_id).into()); + unreachable!( + "IncomingProxy should not finish, task sender is alive in this struct" + ); } Err(TaskError::Error(e)) => { - tracing::error!("incoming proxy task failed: {e}"); return Err(e.into()); } Err(TaskError::Panic) => { - tracing::error!("incoming proxy task panicked"); - return Err(IntProxyError::TaskPanic(task_id).into()); + return Err(PortForwardError::IncomingProxyPanicked); } }, - _ => unreachable!("other task types are never used in port forwarding"), } Ok(()) @@ -969,41 +940,27 @@ impl IncomingMode { #[derive(Debug, Error)] pub enum PortForwardError { - // setup errors - #[error("wrong combination of arguments used: {0}")] - ArgsError(String), - #[error("multiple port forwarding mappings found for local address `{0}`")] PortMapSetupError(SocketAddr), #[error("multiple port forwarding mappings found for desination port `{0:?}`")] ReversePortMapSetupError(RemotePort), - #[error("no port forwarding mappings were provided")] - NoMappingsError(), - - // running errors #[error("agent closed connection with error: `{0}`")] AgentError(String), #[error("connection with the agent failed")] AgentConnectionFailed, - #[error("error from Incoming Proxy task")] - IncomingProxyError(IntProxyError), + #[error("error from the IncomingProxy task: {0}")] + IncomingProxyError(#[from] IncomingProxyError), - #[error("failed to send Ping to agent: `{0}`")] - PingError(String), + #[error("IncomingProxy task panicked")] + 
IncomingProxyPanicked, #[error("TcpListener operation failed with error: `{0}`")] TcpListenerError(std::io::Error), - #[error("TcpStream operation failed with error: `{0}`")] - TcpStreamError(std::io::Error), - - #[error("no destination address found for local address `{0}`")] - SocketMappingNotFound(SocketAddr), - #[error("no task for socket {0} ready to receive connection ID: `{1}`")] ReadyTaskNotFound(SocketAddr, ConnectionId), @@ -1012,9 +969,6 @@ pub enum PortForwardError { #[error("failed to establish connection with remote process: `{0}`")] ConnectionError(String), - - #[error("failed to subscribe to remote port: `{0}`")] - SubscriptionError(ResponseError), } impl From> for PortForwardError { @@ -1023,12 +977,6 @@ impl From> for PortForwardError { } } -impl From for PortForwardError { - fn from(value: IntProxyError) -> Self { - Self::IncomingProxyError(value) - } -} - #[cfg(test)] mod test { use std::{ @@ -1044,9 +992,9 @@ mod test { DaemonConnect, DaemonRead, LayerConnect, LayerWrite, SocketAddress, }, tcp::{ - DaemonTcp, Filter, HttpRequest, HttpResponse, InternalHttpRequest, - InternalHttpResponse, LayerTcp, LayerTcpSteal, NewTcpConnection, StealType, TcpClose, - TcpData, + DaemonTcp, Filter, HttpRequest, HttpResponse, InternalHttpBody, InternalHttpBodyFrame, + InternalHttpRequest, InternalHttpResponse, LayerTcp, LayerTcpSteal, NewTcpConnection, + StealType, TcpClose, TcpData, }, ClientMessage, DaemonMessage, }; @@ -1064,90 +1012,142 @@ mod test { RemoteAddr, }; + /// Connects [`ReversePortForwarder`] with test code with [`ClientMessage`] and + /// [`DaemonMessage`] channels. Runs a background [`tokio::task`] that auto responds to + /// standard [`mirrord_protocol`] messages (e.g [`ClientMessage::Ping`]). 
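The `TestAgentConnection` harness introduced below centralizes protocol keepalive handling so each test asserts only on meaningful traffic. The same pattern can be sketched with std channels and a thread; the `Client`/`Daemon` enums here are simplified stand-ins for the real `ClientMessage`/`DaemonMessage` types, and the real harness uses `tokio` mpsc channels and an async task instead:

```rust
use std::sync::mpsc;
use std::thread;

// Simplified stand-ins for the protocol message enums.
#[derive(Debug, PartialEq)]
enum Client {
    Ping,
    ReadyForLogs,
    Data(u32),
}

#[derive(Debug, PartialEq)]
enum Daemon {
    Pong,
}

/// Answers keepalive traffic itself, forwards everything else to the test.
fn spawn_auto_responder(
    rx: mpsc::Receiver<Client>,
    to_test: mpsc::Sender<Client>,
    to_forwarder: mpsc::Sender<Daemon>,
) -> thread::JoinHandle<()> {
    thread::spawn(move || {
        for message in rx {
            match message {
                // Answered here, so tests never have to interleave Pong
                // replies with their assertions.
                Client::Ping => to_forwarder.send(Daemon::Pong).unwrap(),
                // Acknowledged silently, like `ClientMessage::ReadyForLogs`.
                Client::ReadyForLogs => {}
                // Everything substantive reaches the test's receiver.
                other => to_test.send(other).unwrap(),
            }
        }
    })
}
```

With this split, the test-side `recv()` only ever yields substantive messages, which is what lets the rewritten tests in this diff drop the repeated "if it was a Ping, receive again" boilerplate.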
+ struct TestAgentConnection { + daemon_msg_tx: mpsc::Sender<DaemonMessage>, + client_msg_rx: mpsc::Receiver<ClientMessage>, + } + + impl TestAgentConnection { + fn new() -> (Self, AgentConnection) { + let (daemon_to_forwarder, daemon_from_forwarder) = mpsc::channel::<DaemonMessage>(8); + let (client_task_to_test, client_task_from_test) = mpsc::channel::<ClientMessage>(8); + let (client_forwarder_to_task, client_task_from_forwarder) = + mpsc::channel::<ClientMessage>(8); + + tokio::spawn(Self::auto_responder( + client_task_from_forwarder, + client_task_to_test, + daemon_to_forwarder.clone(), + )); + + ( + Self { + daemon_msg_tx: daemon_to_forwarder, + client_msg_rx: client_task_from_test, + }, + AgentConnection { + sender: client_forwarder_to_task, + receiver: daemon_from_forwarder, + }, + ) + } + + /// Sends the [`DaemonMessage`] to the [`ReversePortForwarder`]. + async fn send(&self, message: DaemonMessage) { + self.daemon_msg_tx.send(message).await.unwrap(); + } + + /// Receives a [`ClientMessage`] from the [`ReversePortForwarder`]. + /// + /// Some standard messages are handled internally and are never returned: + /// 1. [`ClientMessage::Ping`] + /// 2. [`ClientMessage::SwitchProtocolVersion`] + /// 3.
[`ClientMessage::ReadyForLogs`] + async fn recv(&mut self) -> ClientMessage { + self.client_msg_rx.recv().await.unwrap() + } + + async fn auto_responder( + mut rx: mpsc::Receiver<ClientMessage>, + tx_to_test_code: mpsc::Sender<ClientMessage>, + tx_to_port_forwarder: mpsc::Sender<DaemonMessage>, + ) { + loop { + let Some(message) = rx.recv().await else { + break; + }; + + match message { + ClientMessage::Ping => { + tx_to_port_forwarder + .send(DaemonMessage::Pong) + .await + .unwrap(); + } + ClientMessage::ReadyForLogs => {} + ClientMessage::SwitchProtocolVersion(version) => { + tx_to_port_forwarder + .send(DaemonMessage::SwitchProtocolVersionResponse( + std::cmp::min(&version, &*mirrord_protocol::VERSION).clone(), + )) + .await + .unwrap(); + } + other => tx_to_test_code.send(other).await.unwrap(), + } + } + } + } + #[tokio::test] async fn single_port_forwarding() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let local_destination = listener.local_addr().unwrap(); drop(listener); - let (daemon_msg_tx, daemon_msg_rx) = mpsc::channel::<DaemonMessage>(12); - let (client_msg_tx, mut client_msg_rx) = mpsc::channel::<ClientMessage>(12); + let (mut test_connection, agent_connection) = TestAgentConnection::new(); - let agent_connection = AgentConnection { - sender: client_msg_tx, - receiver: daemon_msg_rx, - }; - let remote_destination = (RemoteAddr::Ip("152.37.40.40".parse().unwrap()), 3038); + let remote_ip = "152.37.40.40".parse::<Ipv4Addr>().unwrap(); + let remote_destination = (RemoteAddr::Ip(remote_ip), 3038); let mappings = HashMap::from([(local_destination, remote_destination.clone())]); - tokio::spawn(async move { - let mut port_forwarder = PortForwarder::new(agent_connection, mappings) - .await - .unwrap(); - port_forwarder.run().await.unwrap() - }); - - // expect handshake procedure - let expected = Some(ClientMessage::SwitchProtocolVersion( - mirrord_protocol::VERSION.clone(), - )); - assert_eq!(client_msg_rx.recv().await, expected); - daemon_msg_tx - .send(DaemonMessage::SwitchProtocolVersionResponse( - 
mirrord_protocol::VERSION.clone(), - )) + // Prepare listeners before sending work to the background task. + let mut port_forwarder = PortForwarder::new(agent_connection, mappings) .await .unwrap(); - let expected = Some(ClientMessage::ReadyForLogs); - assert_eq!(client_msg_rx.recv().await, expected); + tokio::spawn(async move { port_forwarder.run().await.unwrap() }); - // send data to socket + // Connect to the PortForwarder's listener and send some data to trigger a remote + // connection request. let mut stream = TcpStream::connect(local_destination).await.unwrap(); stream.write_all(b"data-my-beloved").await.unwrap(); - // expect Connect on client_msg_rx - let remote_address = SocketAddress::Ip("152.37.40.40:3038".parse().unwrap()); + // Expect a connection request + let remote_address = SocketAddress::Ip(SocketAddr::new(remote_ip.into(), 3038)); let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Connect(LayerConnect { remote_address: remote_address.clone(), })); - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; - assert_eq!(message, expected); + assert_eq!(test_connection.recv().await, expected); // reply with successful on daemon_msg_tx - daemon_msg_tx + test_connection .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Connect(Ok( DaemonConnect { connection_id: 1, - remote_address: remote_address.clone(), - local_address: remote_address, + remote_address, + local_address: "1.2.3.4:2137".parse::<SocketAddr>().unwrap().into(), }, )))) - .await - .unwrap(); + .await; let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Write(LayerWrite { connection_id: 1, bytes: b"data-my-beloved".to_vec(), })); - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; - assert_eq!(message, expected); + assert_eq!(test_connection.recv().await, expected); // send
response data from agent on daemon_msg_tx - daemon_msg_tx + test_connection .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Read(Ok( DaemonRead { connection_id: 1, bytes: b"reply-my-beloved".to_vec(), }, )))) - .await - .unwrap(); + .await; // check data arrives at local let mut buf = [0; 16]; @@ -1167,38 +1167,17 @@ mod test { let local_destination_2 = listener.local_addr().unwrap(); drop(listener); - let (daemon_msg_tx, daemon_msg_rx) = mpsc::channel::(12); - let (client_msg_tx, mut client_msg_rx) = mpsc::channel::(12); - - let agent_connection = AgentConnection { - sender: client_msg_tx, - receiver: daemon_msg_rx, - }; + let (mut test_connection, agent_connection) = TestAgentConnection::new(); let mappings = HashMap::from([ (local_destination_1, remote_destination_1.clone()), (local_destination_2, remote_destination_2.clone()), ]); - tokio::spawn(async move { - let mut port_forwarder = PortForwarder::new(agent_connection, mappings) - .await - .unwrap(); - port_forwarder.run().await.unwrap() - }); - - // expect handshake procedure - let expected = Some(ClientMessage::SwitchProtocolVersion( - mirrord_protocol::VERSION.clone(), - )); - assert_eq!(client_msg_rx.recv().await, expected); - daemon_msg_tx - .send(DaemonMessage::SwitchProtocolVersionResponse( - mirrord_protocol::VERSION.clone(), - )) + // Prepare listeners before sending work to the background task. 
+ let mut port_forwarder = PortForwarder::new(agent_connection, mappings) .await .unwrap(); - let expected = Some(ClientMessage::ReadyForLogs); - assert_eq!(client_msg_rx.recv().await, expected); + tokio::spawn(async move { port_forwarder.run().await.unwrap() }); // send data to first socket let mut stream_1 = TcpStream::connect(local_destination_1).await.unwrap(); @@ -1214,11 +1193,7 @@ mod test { let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Connect(LayerConnect { remote_address: remote_address_1.clone(), })); - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; - assert_eq!(message, expected); + assert_eq!(test_connection.recv().await, expected); // send data to second socket let mut stream_2 = TcpStream::connect(local_destination_2).await.unwrap(); @@ -1232,14 +1207,10 @@ mod test { let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Connect(LayerConnect { remote_address: remote_address_2.clone(), })); - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; - assert_eq!(message, expected); + assert_eq!(test_connection.recv().await, expected); // reply with successful on each daemon_msg_tx - daemon_msg_tx + test_connection .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Connect(Ok( DaemonConnect { connection_id: 1, @@ -1247,9 +1218,8 @@ mod test { local_address: remote_address_1, }, )))) - .await - .unwrap(); - daemon_msg_tx + .await; + test_connection .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Connect(Ok( DaemonConnect { connection_id: 2, @@ -1257,49 +1227,38 @@ mod test { local_address: remote_address_2, }, )))) - .await - .unwrap(); + .await; // expect data to be received let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Write(LayerWrite { connection_id: 1, bytes: b"data-from-1".to_vec(), })); - let 
message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; - assert_eq!(message, expected); + assert_eq!(test_connection.recv().await, expected); let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Write(LayerWrite { connection_id: 2, bytes: b"data-from-2".to_vec(), })); - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; - assert_eq!(message, expected); + assert_eq!(test_connection.recv().await, expected); // send each data response from agent on daemon_msg_tx - daemon_msg_tx + test_connection .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Read(Ok( DaemonRead { connection_id: 1, bytes: b"reply-to-1".to_vec(), }, )))) - .await - .unwrap(); - daemon_msg_tx + .await; + test_connection .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Read(Ok( DaemonRead { connection_id: 2, bytes: b"reply-to-2".to_vec(), }, )))) - .await - .unwrap(); + .await; // check data arrives at each local addr let mut buf = [0; 10]; @@ -1317,54 +1276,33 @@ mod test { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let local_destination = listener.local_addr().unwrap(); - let (daemon_msg_tx, daemon_msg_rx) = mpsc::channel::(12); - let (client_msg_tx, mut client_msg_rx) = mpsc::channel::(12); - - let agent_connection = AgentConnection { - sender: client_msg_tx, - receiver: daemon_msg_rx, - }; let remote_address = IpAddr::from("152.37.40.40".parse::().unwrap()); let destination_port = 3038; let mappings = HashMap::from([(destination_port, local_destination.port())]); let network_config = IncomingConfig::default(); + let (mut test_connection, agent_connection) = TestAgentConnection::new(); + tokio::spawn(async move { - let mut port_forwarder = - ReversePortForwarder::new(agent_connection, mappings, network_config) - .await - .unwrap(); - 
port_forwarder.run().await.unwrap() + ReversePortForwarder::new(agent_connection, mappings, network_config) + .await + .unwrap() + .run() + .await + .unwrap() }); - // expect handshake procedure - let expected = Some(ClientMessage::SwitchProtocolVersion( - mirrord_protocol::VERSION.clone(), - )); - assert_eq!(client_msg_rx.recv().await, expected); - daemon_msg_tx - .send(DaemonMessage::SwitchProtocolVersionResponse( - mirrord_protocol::VERSION.clone(), - )) - .await - .unwrap(); - let expected = Some(ClientMessage::ReadyForLogs); - assert_eq!(client_msg_rx.recv().await, expected); - // expect port subscription for remote port and send subscribe result - let expected = Some(ClientMessage::Tcp(LayerTcp::PortSubscribe( - destination_port, - ))); - assert_eq!(client_msg_rx.recv().await, expected); - daemon_msg_tx + let expected = ClientMessage::Tcp(LayerTcp::PortSubscribe(destination_port)); + assert_eq!(test_connection.recv().await, expected); + test_connection .send(DaemonMessage::Tcp(DaemonTcp::SubscribeResult(Ok( destination_port, )))) - .await - .unwrap(); + .await; // send new connection from agent and some data - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::NewConnection( NewTcpConnection { connection_id: 1, @@ -1374,17 +1312,15 @@ mod test { local_address: local_destination.ip(), }, ))) - .await - .unwrap(); + .await; let mut stream = listener.accept().await.unwrap().0; - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::Data(TcpData { connection_id: 1, bytes: b"data-my-beloved".to_vec(), }))) - .await - .unwrap(); + .await; // check data arrives at local let mut buf = [0; 15]; @@ -1392,12 +1328,11 @@ mod test { assert_eq!(buf, b"data-my-beloved".as_ref()); // ensure graceful behaviour on close - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::Close(TcpClose { connection_id: 1, }))) - .await - .unwrap(); + .await; } #[rstest] @@ -1407,13 +1342,6 @@ mod test { let listener = 
TcpListener::bind("127.0.0.1:0").await.unwrap(); let local_destination = listener.local_addr().unwrap(); - let (daemon_msg_tx, daemon_msg_rx) = mpsc::channel::(12); - let (client_msg_tx, mut client_msg_rx) = mpsc::channel::(12); - - let agent_connection = AgentConnection { - sender: client_msg_tx, - receiver: daemon_msg_rx, - }; let remote_address = IpAddr::from("152.37.40.40".parse::().unwrap()); let destination_port = 3038; let mappings = HashMap::from([(destination_port, local_destination.port())]); @@ -1422,62 +1350,47 @@ mod test { ..Default::default() }; + let (mut test_connection, agent_connection) = TestAgentConnection::new(); tokio::spawn(async move { - let mut port_forwarder = - ReversePortForwarder::new(agent_connection, mappings, network_config) - .await - .unwrap(); - port_forwarder.run().await.unwrap() + ReversePortForwarder::new(agent_connection, mappings, network_config) + .await + .unwrap() + .run() + .await + .unwrap() }); - // expect handshake procedure - let expected = Some(ClientMessage::SwitchProtocolVersion( - mirrord_protocol::VERSION.clone(), - )); - assert_eq!(client_msg_rx.recv().await, expected); - daemon_msg_tx - .send(DaemonMessage::SwitchProtocolVersionResponse( - mirrord_protocol::VERSION.clone(), - )) - .await - .unwrap(); - let expected = Some(ClientMessage::ReadyForLogs); - assert_eq!(client_msg_rx.recv().await, expected); - // expect port subscription for remote port and send subscribe result - let expected = Some(ClientMessage::TcpSteal(LayerTcpSteal::PortSubscribe( - StealType::All(destination_port), + let expected = ClientMessage::TcpSteal(LayerTcpSteal::PortSubscribe(StealType::All( + destination_port, ))); - assert_eq!(client_msg_rx.recv().await, expected); - daemon_msg_tx - .send(DaemonMessage::Tcp(DaemonTcp::SubscribeResult(Ok( + assert_eq!(test_connection.recv().await, expected); + test_connection + .send(DaemonMessage::TcpSteal(DaemonTcp::SubscribeResult(Ok( destination_port, )))) - .await - .unwrap(); + .await; // send 
new connection from agent and some data - daemon_msg_tx - .send(DaemonMessage::Tcp(DaemonTcp::NewConnection( + test_connection + .send(DaemonMessage::TcpSteal(DaemonTcp::NewConnection( NewTcpConnection { connection_id: 1, remote_address, destination_port, - source_port: local_destination.port(), - local_address: local_destination.ip(), + source_port: 2137, + local_address: "1.2.3.4".parse().unwrap(), }, ))) - .await - .unwrap(); + .await; let mut stream = listener.accept().await.unwrap().0; - daemon_msg_tx + test_connection .send(DaemonMessage::TcpSteal(DaemonTcp::Data(TcpData { connection_id: 1, bytes: b"data-my-beloved".to_vec(), }))) - .await - .unwrap(); + .await; // check data arrives at local let mut buf = [0; 15]; @@ -1486,12 +1399,8 @@ mod test { // check for response from local stream.write_all(b"reply-my-beloved").await.unwrap(); - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; assert_eq!( - message, + test_connection.recv().await, ClientMessage::TcpSteal(LayerTcpSteal::Data(TcpData { connection_id: 1, bytes: b"reply-my-beloved".to_vec() @@ -1499,12 +1408,11 @@ mod test { ); // ensure graceful behaviour on close - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::Close(TcpClose { connection_id: 1, }))) - .await - .unwrap(); + .await; } #[rstest] @@ -1517,13 +1425,6 @@ mod test { let local_destination_1 = listener_1.local_addr().unwrap(); let local_destination_2 = listener_2.local_addr().unwrap(); - let (daemon_msg_tx, daemon_msg_rx) = mpsc::channel::(12); - let (client_msg_tx, mut client_msg_rx) = mpsc::channel::(12); - - let agent_connection = AgentConnection { - sender: client_msg_tx, - receiver: daemon_msg_rx, - }; let remote_address = IpAddr::from("152.37.40.40".parse::().unwrap()); let destination_port_1 = 3038; let destination_port_2 = 4048; @@ -1533,6 +1434,7 @@ mod test { ]); let network_config = 
IncomingConfig::default(); + let (mut test_connection, agent_connection) = TestAgentConnection::new(); tokio::spawn(async move { let mut port_forwarder = ReversePortForwarder::new(agent_connection, mappings, network_config) @@ -1541,48 +1443,29 @@ mod test { port_forwarder.run().await.unwrap() }); - // expect handshake procedure - let expected = Some(ClientMessage::SwitchProtocolVersion( - mirrord_protocol::VERSION.clone(), - )); - assert_eq!(client_msg_rx.recv().await, expected); - daemon_msg_tx - .send(DaemonMessage::SwitchProtocolVersionResponse( - mirrord_protocol::VERSION.clone(), - )) - .await - .unwrap(); - let expected = Some(ClientMessage::ReadyForLogs); - assert_eq!(client_msg_rx.recv().await, expected); - // expect port subscription for each remote port and send subscribe result // matches! used because order may be random for _ in 0..2 { - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; + let message = test_connection.recv().await; assert!( matches!(message, ClientMessage::Tcp(LayerTcp::PortSubscribe(_))), "expected ClientMessage::Tcp(LayerTcp::PortSubscribe(_), received {message:?}" ); } - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::SubscribeResult(Ok( destination_port_1, )))) - .await - .unwrap(); - daemon_msg_tx + .await; + test_connection .send(DaemonMessage::Tcp(DaemonTcp::SubscribeResult(Ok( destination_port_2, )))) - .await - .unwrap(); + .await; // send new connections from agent and some data - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::NewConnection( NewTcpConnection { connection_id: 1, @@ -1592,11 +1475,10 @@ mod test { local_address: local_destination_1.ip(), }, ))) - .await - .unwrap(); + .await; let mut stream_1 = listener_1.accept().await.unwrap().0; - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::NewConnection( NewTcpConnection { connection_id: 2, @@ 
-1606,25 +1488,22 @@ mod test { local_address: local_destination_2.ip(), }, ))) - .await - .unwrap(); + .await; let mut stream_2 = listener_2.accept().await.unwrap().0; - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::Data(TcpData { connection_id: 1, bytes: b"connection-1-my-beloved".to_vec(), }))) - .await - .unwrap(); + .await; - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::Data(TcpData { connection_id: 2, bytes: b"connection-2-my-beloved".to_vec(), }))) - .await - .unwrap(); + .await; // check data arrives at local let mut buf = [0; 23]; @@ -1636,19 +1515,17 @@ mod test { assert_eq!(buf, b"connection-2-my-beloved".as_ref()); // ensure graceful behaviour on close - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::Close(TcpClose { connection_id: 1, }))) - .await - .unwrap(); + .await; - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::Close(TcpClose { connection_id: 2, }))) - .await - .unwrap(); + .await; } #[rstest] @@ -1660,14 +1537,6 @@ mod test { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let local_destination = listener.local_addr().unwrap(); - let (daemon_msg_tx, daemon_msg_rx) = mpsc::channel::(12); - let (client_msg_tx, mut client_msg_rx) = mpsc::channel::(12); - - let agent_connection = AgentConnection { - sender: client_msg_tx, - receiver: daemon_msg_rx, - }; - let remote_address = IpAddr::from("152.37.40.40".parse::().unwrap()); let destination_port = 8080; let mappings = HashMap::from([(destination_port, local_destination.port())]); let mut network_config = IncomingConfig { @@ -1676,6 +1545,8 @@ mod test { }; network_config.http_filter.header_filter = Some("header: value".to_string()); + let (mut test_connection, agent_connection) = TestAgentConnection::new(); + tokio::spawn(async move { let mut port_forwarder = ReversePortForwarder::new(agent_connection, mappings, network_config) @@ -1684,27 +1555,8 @@ mod test { port_forwarder.run().await.unwrap() 
}); - // expect handshake procedure - let expected = Some(ClientMessage::SwitchProtocolVersion( - mirrord_protocol::VERSION.clone(), - )); - assert_eq!(client_msg_rx.recv().await, expected); - daemon_msg_tx - .send(DaemonMessage::SwitchProtocolVersionResponse( - mirrord_protocol::VERSION.clone(), - )) - .await - .unwrap(); - let expected = Some(ClientMessage::ReadyForLogs); - assert_eq!(client_msg_rx.recv().await, expected); - - // expect port subscription for remote port and send subscribe result - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; assert_eq!( - message, + test_connection.recv().await, ClientMessage::TcpSteal(LayerTcpSteal::PortSubscribe(StealType::FilteredHttpEx( destination_port, mirrord_protocol::tcp::HttpFilter::Header( @@ -1712,27 +1564,11 @@ mod test { ) ),)) ); - daemon_msg_tx - .send(DaemonMessage::Tcp(DaemonTcp::SubscribeResult(Ok( + test_connection + .send(DaemonMessage::TcpSteal(DaemonTcp::SubscribeResult(Ok( destination_port, )))) - .await - .unwrap(); - - // send new connection from agent and some data - daemon_msg_tx - .send(DaemonMessage::TcpSteal(DaemonTcp::NewConnection( - NewTcpConnection { - connection_id: 1, - remote_address, - destination_port, - source_port: local_destination.port(), - local_address: local_destination.ip(), - }, - ))) - .await - .unwrap(); - let mut stream = listener.accept().await.unwrap().0; + .await; // send data from agent with correct header let mut headers = HeaderMap::new(); @@ -1744,23 +1580,22 @@ mod test { version: Version::HTTP_11, body: vec![], }; - daemon_msg_tx + test_connection .send(DaemonMessage::TcpSteal(DaemonTcp::HttpRequest( HttpRequest { internal_request, - connection_id: 1, - request_id: 1, - port: local_destination.port(), + connection_id: 0, + request_id: 0, + port: destination_port, }, ))) - .await - .unwrap(); + .await; + let mut stream = listener.accept().await.unwrap().0; 
// check data is read from stream let mut buf = [0; 15]; assert_eq!(buf, [0; 15]); stream.read_exact(&mut buf).await.unwrap(); - assert_ne!(buf, [0; 15]); // check for response from local stream @@ -1771,31 +1606,30 @@ mod test { let mut headers = HeaderMap::new(); headers.insert("content-length", "3".parse().unwrap()); let internal_response = InternalHttpResponse { - status: StatusCode::from_u16(200).unwrap(), + status: StatusCode::OK, version: Version::HTTP_11, headers, - body: b"yay".to_vec(), + body: InternalHttpBody( + [InternalHttpBodyFrame::Data(b"yay".into())] + .into_iter() + .collect(), + ), }; let expected_response = - ClientMessage::TcpSteal(LayerTcpSteal::HttpResponse(HttpResponse { - connection_id: 1, - request_id: 1, - port: local_destination.port(), + ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseFramed(HttpResponse { + connection_id: 0, + request_id: 0, + port: destination_port, internal_response, })); - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; - assert_eq!(message, expected_response); + assert_eq!(test_connection.recv().await, expected_response); // ensure graceful behaviour on close - daemon_msg_tx - .send(DaemonMessage::Tcp(DaemonTcp::Close(TcpClose { - connection_id: 1, + test_connection + .send(DaemonMessage::TcpSteal(DaemonTcp::Close(TcpClose { + connection_id: 0, }))) - .await - .unwrap(); + .await; } } diff --git a/mirrord/config/Cargo.toml b/mirrord/config/Cargo.toml index e17985aa0cd..a2fd265c274 100644 --- a/mirrord/config/Cargo.toml +++ b/mirrord/config/Cargo.toml @@ -17,8 +17,8 @@ edition.workspace = true workspace = true [dependencies] -mirrord-config-derive = { path = "./derive"} -mirrord-analytics = { path = "../analytics"} +mirrord-config-derive = { path = "./derive" } +mirrord-analytics = { path = "../analytics" } serde.workspace = true serde_json.workspace = true @@ -27,13 +27,14 @@ tracing.workspace = true 
serde_yaml.workspace = true toml = "0.8" schemars.workspace = true -bimap = "0.6" +bimap = { version = "0.6" } nom = "7.1" ipnet.workspace = true bitflags = "2" k8s-openapi = { workspace = true, features = ["schemars", "earliest"] } tera = "1" fancy-regex.workspace = true +base64.workspace = true [dev-dependencies] rstest.workspace = true diff --git a/mirrord/config/configuration.md b/mirrord/config/configuration.md index 93dab79c3c0..0d1af19401d 100644 --- a/mirrord/config/configuration.md +++ b/mirrord/config/configuration.md @@ -68,7 +68,8 @@ configuration file containing all fields. "communication_timeout": 30, "startup_timeout": 360, "network_interface": "eth0", - "flush_connections": true + "flush_connections": true, + "metrics": "0.0.0.0:9000" }, "feature": { "env": { @@ -166,7 +167,11 @@ Allows setting up custom annotations for the agent Job and Pod. ```json { - "annotations": { "cats.io/inject": "enabled" } + "annotations": { + "cats.io/inject": "enabled", + "prometheus.io/scrape": "true", + "prometheus.io/port": "9000" + } } ``` @@ -299,6 +304,19 @@ with `RUST_LOG`. } ``` +### agent.metrics {#agent-metrics} + +Enables Prometheus metrics for the agent pod. + +You might need to add annotations to the agent pod depending on how Prometheus is +configured to scrape for metrics. + +```json +{ + "metrics": "0.0.0.0:9000" +} +``` + ### agent.namespace {#agent-namespace} Namespace where the agent shall live.
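Since the new `agent.metrics` setting holds a bind address, a quick way to sanity-check a value before putting it in the config is to parse it the same way the config field does. This is a minimal sketch, assuming the field is backed by a plain `SocketAddr` (as the `agent.rs` change in this diff suggests); `parse_metrics_addr` is illustrative, not part of mirrord's API:

```rust
use std::net::SocketAddr;

// Illustrative validation for an `agent.metrics` value. Anything that
// `SocketAddr` rejects (missing port, bare port, hostname) would be
// rejected by config parsing of a socket-address field too.
fn parse_metrics_addr(raw: &str) -> Result<SocketAddr, String> {
    raw.parse::<SocketAddr>()
        .map_err(|error| format!("invalid metrics address `{raw}`: {error}"))
}
```

Note that `"0.0.0.0:9000"` parses, while `"0.0.0.0"` or `"9000"` alone does not, which is why the documented example includes both the interface and the port.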
@@ -729,9 +747,11 @@ Example: Will do the next replacements for environment variables that match: -`CONNECTION_TIMEOUT: 500` => `CONNECTION_TIMEOUT: 10000` -`LOG_FILE_VERBOSITY: info` => `LOG_FILE_VERBOSITY: debug` -`DATA_1234: common-value` => `DATA_1234: magic-value` +* `CONNECTION_TIMEOUT: 500` => `CONNECTION_TIMEOUT: 10000` + +* `LOG_FILE_VERBOSITY: info` => `LOG_FILE_VERBOSITY: debug` + +* `DATA_1234: common-value` => `DATA_1234: magic-value` ### feature.env.override {#feature-env-override} @@ -1266,6 +1286,10 @@ List of ports to mirror/steal traffic from. Other ports will remain local. Mutually exclusive with [`feature.network.incoming.ignore_ports`](#feature-network-ignore_ports). +### feature.network.ipv6 {#feature-network-ipv6} + +Enable IPv6 support. Turn this on if your application listens for incoming traffic over IPv6. + ### feature.network.outgoing {#feature-network-outgoing} Tunnel outgoing network operations through mirrord. diff --git a/mirrord/config/src/agent.rs b/mirrord/config/src/agent.rs index b82c45a7177..9600edfcd4d 100644 --- a/mirrord/config/src/agent.rs +++ b/mirrord/config/src/agent.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, fmt, path::Path}; +use std::{collections::HashMap, fmt, net::SocketAddr, path::Path}; use k8s_openapi::api::core::v1::{ResourceRequirements, Toleration}; use mirrord_analytics::CollectAnalytics; @@ -67,7 +67,7 @@ impl fmt::Display for LinuxCapability { /// } /// } /// ``` -#[derive(MirrordConfig, Clone, Debug, Serialize)] +#[derive(MirrordConfig, Clone, Debug, Serialize, Deserialize, PartialEq)] #[config(map_to = "AgentFileConfig", derive = "JsonSchema")] #[cfg_attr(test, config(derive = "PartialEq"))] pub struct AgentConfig { @@ -322,7 +322,11 @@ pub struct AgentConfig { /// /// ```json /// { - /// "annotations": { "cats.io/inject": "enabled" } + /// "annotations": { + /// "cats.io/inject": "enabled", + /// "prometheus.io/scrape": "true", + /// "prometheus.io/port": "9000" + /// } /// } /// ``` pub annotations:
Option<HashMap<String, String>>, @@ -350,11 +354,25 @@ pub struct AgentConfig { /// ``` pub service_account: Option<String>, + /// ### agent.metrics {#agent-metrics} + /// + /// Enables Prometheus metrics for the agent pod. + /// + /// You might need to add annotations to the agent pod depending on how Prometheus is + /// configured to scrape for metrics. + /// + /// ```json + /// { + /// "metrics": "0.0.0.0:9000" + /// } + /// ``` + pub metrics: Option<SocketAddr>, + /// /// Create an agent that returns an error after accepting the first client. For testing /// purposes. Only supported with job agents (not with ephemeral agents). #[cfg(all(debug_assertions, not(test)))] // not(test) so that it's not included in the schema json. - #[serde(skip_serializing)] + #[serde(skip)] #[config(env = "MIRRORD_AGENT_TEST_ERROR", default = false, unstable)] pub test_error: bool, } @@ -477,7 +495,7 @@ impl AgentFileConfig { } } -#[derive(MirrordConfig, Default, PartialEq, Eq, Clone, Debug, Serialize)] +#[derive(MirrordConfig, Default, PartialEq, Eq, Clone, Debug, Serialize, Deserialize)] #[config(derive = "JsonSchema")] #[cfg_attr(test, config(derive = "PartialEq, Eq"))] pub struct AgentDnsConfig { diff --git a/mirrord/config/src/config.rs b/mirrord/config/src/config.rs index c0176c6f66b..9fa4bdd168a 100644 --- a/mirrord/config/src/config.rs +++ b/mirrord/config/src/config.rs @@ -86,6 +86,12 @@ pub enum ConfigError { value: String, fail: Box, }, + + #[error("mirrord-config: decoding resolved config from env var failed with `{0}`")] + EnvVarDecodeError(String), + + #[error("mirrord-config: encoding resolved config failed with `{0}`")] + EnvVarEncodeError(String), } impl From for ConfigError { diff --git a/mirrord/config/src/config/from_env.rs b/mirrord/config/src/config/from_env.rs index 9770456721a..0f87ef59034 100644 --- a/mirrord/config/src/config/from_env.rs +++ b/mirrord/config/src/config/from_env.rs @@ -20,6 +20,10 @@ where { type Value = T; + /// Returns: + /// - `None` if there is no env var with that name.
+ /// - `Some(Err(ConfigError::InvalidValue{...}))` if the value of the env var cannot be parsed. + /// - `Some(Ok(...))` if the env var exists and was parsed successfully. fn source_value(self, _context: &mut ConfigContext) -> Option<Result<Self::Value>> { std::env::var(self.0).ok().map(|var| { var.parse::<T>() diff --git a/mirrord/config/src/container.rs b/mirrord/config/src/container.rs index 5f9e16f7447..c48fec4a34d 100644 --- a/mirrord/config/src/container.rs +++ b/mirrord/config/src/container.rs @@ -2,7 +2,7 @@ use std::path::PathBuf; use mirrord_config_derive::MirrordConfig; use schemars::JsonSchema; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use crate::config::source::MirrordConfigSource; @@ -12,7 +12,7 @@ static DEFAULT_CLI_IMAGE: &str = concat!( ); /// Unstable: `mirrord container` command specific config. -#[derive(MirrordConfig, Clone, Debug, Serialize)] +#[derive(MirrordConfig, Clone, Debug, Serialize, Deserialize, PartialEq)] #[config(map_to = "ContainerFileConfig", derive = "JsonSchema")] #[cfg_attr(test, config(derive = "PartialEq"))] pub struct ContainerConfig { diff --git a/mirrord/config/src/experimental.rs b/mirrord/config/src/experimental.rs index 1fd9213a762..29af8e2f145 100644 --- a/mirrord/config/src/experimental.rs +++ b/mirrord/config/src/experimental.rs @@ -1,13 +1,13 @@ use mirrord_analytics::CollectAnalytics; use mirrord_config_derive::MirrordConfig; use schemars::JsonSchema; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use crate::config::source::MirrordConfigSource; /// mirrord Experimental features. /// This shouldn't be used unless someone from MetalBear/mirrord tells you to.
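The three-way contract documented above for `FromEnv::source_value` (`None` when the variable is unset, `Some(Err(_))` when it is set but unparsable, `Some(Ok(_))` otherwise) can be sketched as a standalone function. This is a simplified illustration, not mirrord's implementation: the error type is a plain `String` standing in for `ConfigError::InvalidValue`, and there is no `ConfigContext`:

```rust
use std::str::FromStr;

// Standalone sketch of the documented `FromEnv` behaviour: the outer
// `Option` answers "is the env var set?", the inner `Result` answers
// "did its value parse?".
fn source_value<T: FromStr>(name: &str) -> Option<Result<T, String>> {
    std::env::var(name).ok().map(|raw| {
        raw.parse::<T>()
            .map_err(|_| format!("env var `{name}` holds an unparsable value: `{raw}`"))
    })
}
```

Keeping the two failure modes distinct matters for layered config sources: an unset variable lets the next source (file, default) fill the value in, while a set-but-invalid variable should surface an error instead of being silently ignored.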
-#[derive(MirrordConfig, Clone, Debug, Serialize)] +#[derive(MirrordConfig, Clone, Debug, Serialize, Deserialize, PartialEq)] #[config(map_to = "ExperimentalFileConfig", derive = "JsonSchema")] #[cfg_attr(test, config(derive = "PartialEq, Eq"))] pub struct ExperimentalConfig { @@ -68,7 +68,7 @@ pub struct ExperimentalConfig { /// Setting to 0 disables file buffering. /// /// - #[config(default = 0)] + #[config(default = 128000)] pub readonly_file_buffer: u64, } diff --git a/mirrord/config/src/external_proxy.rs b/mirrord/config/src/external_proxy.rs index 2ce284b6b7b..0f1c80f5e0a 100644 --- a/mirrord/config/src/external_proxy.rs +++ b/mirrord/config/src/external_proxy.rs @@ -2,7 +2,7 @@ use std::path::PathBuf; use mirrord_config_derive::MirrordConfig; use schemars::JsonSchema; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use crate::config::source::MirrordConfigSource; @@ -23,7 +23,7 @@ pub static MIRRORD_EXTERNAL_TLS_KEY_ENV: &str = "MIRRORD_EXTERNAL_TLS_KEY"; /// } /// } /// ``` -#[derive(MirrordConfig, Clone, Debug, Serialize)] +#[derive(MirrordConfig, Clone, Debug, Serialize, Deserialize, PartialEq)] #[config(map_to = "ExternalProxyFileConfig", derive = "JsonSchema")] #[cfg_attr(test, config(derive = "PartialEq"))] pub struct ExternalProxyConfig { diff --git a/mirrord/config/src/feature.rs b/mirrord/config/src/feature.rs index 317d590e0c8..41fecca7002 100644 --- a/mirrord/config/src/feature.rs +++ b/mirrord/config/src/feature.rs @@ -1,7 +1,7 @@ use mirrord_analytics::CollectAnalytics; use mirrord_config_derive::MirrordConfig; use schemars::JsonSchema; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use self::{copy_target::CopyTargetConfig, env::EnvConfig, fs::FsConfig, network::NetworkConfig}; use crate::{config::source::MirrordConfigSource, feature::split_queues::SplitQueuesConfig}; @@ -64,7 +64,7 @@ pub mod split_queues; /// } /// } /// ``` -#[derive(MirrordConfig, Clone, Debug, Serialize)] +#[derive(MirrordConfig, Clone, Debug, 
Serialize, Deserialize, PartialEq)] #[config(map_to = "FeatureFileConfig", derive = "JsonSchema")] #[cfg_attr(test, config(derive = "PartialEq, Eq"))] pub struct FeatureConfig { diff --git a/mirrord/config/src/feature/copy_target.rs b/mirrord/config/src/feature/copy_target.rs index c5ed004ebb4..8a72326d52a 100644 --- a/mirrord/config/src/feature/copy_target.rs +++ b/mirrord/config/src/feature/copy_target.rs @@ -75,7 +75,7 @@ impl FromMirrordConfig for CopyTargetConfig { /// } /// } /// ``` -#[derive(Clone, Debug, Serialize)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub struct CopyTargetConfig { pub enabled: bool, diff --git a/mirrord/config/src/feature/env.rs b/mirrord/config/src/feature/env.rs index 339b0c16e78..1763a960b22 100644 --- a/mirrord/config/src/feature/env.rs +++ b/mirrord/config/src/feature/env.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, path::PathBuf}; use mirrord_analytics::CollectAnalytics; use mirrord_config_derive::MirrordConfig; use schemars::JsonSchema; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use crate::{ config::{from_env::FromEnv, source::MirrordConfigSource, ConfigContext, Result}, @@ -47,7 +47,7 @@ pub const MIRRORD_OVERRIDE_ENV_FILE_ENV: &str = "MIRRORD_OVERRIDE_ENV_VARS_FILE" /// } /// } /// ``` -#[derive(MirrordConfig, Clone, Debug, Serialize)] +#[derive(MirrordConfig, Clone, Debug, Serialize, Deserialize, PartialEq)] #[config(map_to = "EnvFileConfig", derive = "JsonSchema")] #[cfg_attr(test, config(derive = "PartialEq, Eq"))] pub struct EnvConfig { @@ -134,9 +134,11 @@ pub struct EnvConfig { /// /// Will do the next replacements for environment variables that match: /// - /// `CONNECTION_TIMEOUT: 500` => `CONNECTION_TIMEOUT: 10000` - /// `LOG_FILE_VERBOSITY: info` => `LOG_FILE_VERBOSITY: debug` - /// `DATA_1234: common-value` => `DATA_1234: magic-value` + /// * `CONNECTION_TIMEOUT: 500` => `CONNECTION_TIMEOUT: 10000` + /// + /// * `LOG_FILE_VERBOSITY: info` => `LOG_FILE_VERBOSITY: debug` + 
/// + /// * `DATA_1234: common-value` => `DATA_1234: magic-value` pub mapping: Option>, } diff --git a/mirrord/config/src/feature/fs/advanced.rs b/mirrord/config/src/feature/fs/advanced.rs index da1df66c00f..40c0e21c926 100644 --- a/mirrord/config/src/feature/fs/advanced.rs +++ b/mirrord/config/src/feature/fs/advanced.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use mirrord_analytics::{AnalyticValue, CollectAnalytics}; use mirrord_config_derive::MirrordConfig; use schemars::JsonSchema; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use super::{FsModeConfig, FsUserConfig}; use crate::{ @@ -80,7 +80,7 @@ use crate::{ /// } /// } /// ``` -#[derive(MirrordConfig, Default, Clone, PartialEq, Eq, Debug, Serialize)] +#[derive(MirrordConfig, Default, Clone, PartialEq, Eq, Debug, Serialize, Deserialize)] #[config( map_to = "AdvancedFsUserConfig", derive = "PartialEq,Eq,JsonSchema", diff --git a/mirrord/config/src/feature/network.rs b/mirrord/config/src/feature/network.rs index 2f2b7901aee..976adf2b814 100644 --- a/mirrord/config/src/feature/network.rs +++ b/mirrord/config/src/feature/network.rs @@ -2,14 +2,16 @@ use dns::{DnsConfig, DnsFileConfig}; use mirrord_analytics::CollectAnalytics; use mirrord_config_derive::MirrordConfig; use schemars::JsonSchema; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use self::{incoming::*, outgoing::*}; use crate::{ - config::{ConfigContext, ConfigError}, + config::{from_env::FromEnv, source::MirrordConfigSource, ConfigContext, ConfigError}, util::MirrordToggleableConfig, }; +const IPV6_ENV_VAR: &str = "MIRRORD_ENABLE_IPV6"; + pub mod dns; pub mod filter; pub mod incoming; @@ -52,7 +54,7 @@ pub mod outgoing; /// } /// } /// ``` -#[derive(MirrordConfig, Default, PartialEq, Eq, Clone, Debug, Serialize)] +#[derive(MirrordConfig, Default, PartialEq, Eq, Clone, Debug, Serialize, Deserialize)] #[config(map_to = "NetworkFileConfig", derive = "JsonSchema")] #[cfg_attr(test, config(derive = "PartialEq, Eq"))] pub 
struct NetworkConfig { @@ -67,14 +69,26 @@ pub struct NetworkConfig { /// ### feature.network.dns {#feature-network-dns} #[config(toggleable, nested)] pub dns: DnsConfig, + + /// ### feature.network.ipv6 {#feature-network-ipv6} + /// + /// Enable IPv6 support. Turn on if your application listens to incoming traffic over IPv6. + #[config(env = IPV6_ENV_VAR, default = false)] + pub ipv6: bool, } impl MirrordToggleableConfig for NetworkFileConfig { fn disabled_config(context: &mut ConfigContext) -> Result { + let ipv6 = FromEnv::new(IPV6_ENV_VAR) + .source_value(context) + .transpose()? + .unwrap_or_default(); + Ok(NetworkConfig { incoming: IncomingFileConfig::disabled_config(context)?, dns: DnsFileConfig::disabled_config(context)?, outgoing: OutgoingFileConfig::disabled_config(context)?, + ipv6, }) } } @@ -84,6 +98,7 @@ impl CollectAnalytics for &NetworkConfig { analytics.add("incoming", &self.incoming); analytics.add("outgoing", &self.outgoing); analytics.add("dns", &self.dns); + analytics.add("ipv6", self.ipv6); } } diff --git a/mirrord/config/src/feature/network/dns.rs b/mirrord/config/src/feature/network/dns.rs index fd0df6411b6..9371a2ee1e8 100644 --- a/mirrord/config/src/feature/network/dns.rs +++ b/mirrord/config/src/feature/network/dns.rs @@ -87,7 +87,7 @@ pub enum DnsFilterConfig { /// `read_only: ["/etc/resolv.conf"]`. /// - DNS filter currently works only with frameworks that use `getaddrinfo`/`gethostbyname` /// functions.
-#[derive(MirrordConfig, Default, PartialEq, Eq, Clone, Debug, Serialize)] +#[derive(MirrordConfig, Default, PartialEq, Eq, Clone, Debug, Serialize, Deserialize)] #[config(map_to = "DnsFileConfig", derive = "JsonSchema")] #[cfg_attr(test, config(derive = "PartialEq, Eq"))] pub struct DnsConfig { diff --git a/mirrord/config/src/feature/network/incoming.rs b/mirrord/config/src/feature/network/incoming.rs index d56199e003e..857fb0179d3 100644 --- a/mirrord/config/src/feature/network/incoming.rs +++ b/mirrord/config/src/feature/network/incoming.rs @@ -58,7 +58,7 @@ use http_filter::*; /// }, /// "port_mapping": [[ 7777, 8888 ]], /// "ignore_localhost": false, -/// "ignore_ports": [9999, 10000] +/// "ignore_ports": [9999, 10000], /// "listen_ports": [[80, 8111]] /// } /// } @@ -96,9 +96,7 @@ impl MirrordConfig for IncomingFileConfig { .unwrap_or_default(), http_filter: HttpFilterFileConfig::default().generate_config(context)?, on_concurrent_steal: FromEnv::new("MIRRORD_OPERATOR_ON_CONCURRENT_STEAL") - .layer(|layer| { - Unstable::new("IncomingFileConfig", "on_concurrent_steal", layer) - }) + .layer(|layer| Unstable::new("incoming", "on_concurrent_steal", layer)) .source_value(context) .transpose()? .unwrap_or_default(), @@ -129,9 +127,7 @@ impl MirrordConfig for IncomingFileConfig { .unwrap_or_default(), on_concurrent_steal: FromEnv::new("MIRRORD_OPERATOR_ON_CONCURRENT_STEAL") .or(advanced.on_concurrent_steal) - .layer(|layer| { - Unstable::new("IncomingFileConfig", "on_concurrent_steal", layer) - }) + .layer(|layer| Unstable::new("incoming", "on_concurrent_steal", layer)) .source_value(context) .transpose()? 
.unwrap_or_default(), @@ -149,7 +145,7 @@ impl MirrordToggleableConfig for IncomingFileConfig { .unwrap_or_else(|| Ok(IncomingMode::Off))?; let on_concurrent_steal = FromEnv::new("MIRRORD_OPERATOR_ON_CONCURRENT_STEAL") - .layer(|layer| Unstable::new("IncomingFileConfig", "on_concurrent_steal", layer)) + .layer(|layer| Unstable::new("incoming", "on_concurrent_steal", layer)) .source_value(context) .transpose()? .unwrap_or_default(); @@ -314,6 +310,7 @@ fn serialize_bi_map(map: &BiMap, serializer: S) -> Result(deserializer: D) -> Result, D::Error> +where + D: de::Deserializer<'de>, +{ + // NB: this deserialises the BiMap from a vec + let vec: Vec<(u16, u16)> = Vec::deserialize(deserializer)?; + + let mut elements = BiMap::new(); + vec.iter().for_each(|(key, value)| { + elements.insert(*key, *value); + }); + Ok(elements) +} + /// Controls the incoming TCP traffic feature. /// /// See the incoming [reference](https://mirrord.dev/docs/reference/traffic/#incoming) for more @@ -391,7 +402,7 @@ where /// } /// } /// ``` -#[derive(Default, PartialEq, Eq, Clone, Debug, Serialize)] +#[derive(Default, PartialEq, Eq, Clone, Debug, Serialize, Deserialize)] pub struct IncomingConfig { /// #### feature.network.incoming.port_mapping {#feature-network-incoming-port_mapping} /// @@ -400,7 +411,10 @@ pub struct IncomingConfig { /// This is useful when you want to mirror/steal a port to a different port on the remote /// machine. For example, your local process listens on port `9333` and the container listens /// on port `80`. 
You'd use `[[9333, 80]]` - #[serde(serialize_with = "serialize_bi_map")] + #[serde( + serialize_with = "serialize_bi_map", + deserialize_with = "deserialize_bi_map" + )] pub port_mapping: BiMap, /// #### feature.network.incoming.ignore_localhost {#feature-network-incoming-ignore_localhost} @@ -438,7 +452,10 @@ pub struct IncomingConfig { /// you probably can't listen on `80` without sudo, so you can use `[[80, 4480]]` /// then access it on `4480` while getting traffic from remote `80`. /// The value of `port_mapping` doesn't affect this. - #[serde(serialize_with = "serialize_bi_map")] + #[serde( + serialize_with = "serialize_bi_map", + deserialize_with = "deserialize_bi_map" + )] pub listen_ports: BiMap, /// #### feature.network.incoming.on_concurrent_steal {#feature-network-incoming-on_concurrent_steal} diff --git a/mirrord/config/src/feature/network/incoming/http_filter.rs b/mirrord/config/src/feature/network/incoming/http_filter.rs index 343e850b75d..c68e288fa58 100644 --- a/mirrord/config/src/feature/network/incoming/http_filter.rs +++ b/mirrord/config/src/feature/network/incoming/http_filter.rs @@ -79,7 +79,7 @@ use crate::{ /// { "header": "^x-debug-session: 121212$" } /// ] ///} -#[derive(MirrordConfig, Default, PartialEq, Eq, Clone, Debug, Serialize)] +#[derive(MirrordConfig, Default, PartialEq, Eq, Clone, Debug, Serialize, Deserialize)] #[config(map_to = "HttpFilterFileConfig", derive = "JsonSchema")] #[cfg_attr(test, config(derive = "PartialEq, Eq"))] pub struct HttpFilterConfig { diff --git a/mirrord/config/src/feature/network/outgoing.rs b/mirrord/config/src/feature/network/outgoing.rs index 8ff2a84ce21..c0878a412a9 100644 --- a/mirrord/config/src/feature/network/outgoing.rs +++ b/mirrord/config/src/feature/network/outgoing.rs @@ -89,7 +89,7 @@ pub enum OutgoingFilterConfig { /// } /// } /// ``` -#[derive(MirrordConfig, Default, PartialEq, Eq, Clone, Debug, Serialize)] +#[derive(MirrordConfig, Default, PartialEq, Eq, Clone, Debug, Serialize, 
Deserialize)] #[config(map_to = "OutgoingFileConfig", derive = "JsonSchema")] #[cfg_attr(test, config(derive = "PartialEq, Eq"))] pub struct OutgoingConfig { diff --git a/mirrord/config/src/internal_proxy.rs b/mirrord/config/src/internal_proxy.rs index dd28f2f4d9e..ba5b9483479 100644 --- a/mirrord/config/src/internal_proxy.rs +++ b/mirrord/config/src/internal_proxy.rs @@ -2,7 +2,7 @@ use std::{net::SocketAddr, path::PathBuf}; use mirrord_config_derive::MirrordConfig; use schemars::JsonSchema; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use crate::config::source::MirrordConfigSource; @@ -26,7 +26,7 @@ pub static MIRRORD_INTPROXY_CLIENT_TLS_KEY_ENV: &str = "MIRRORD_INTPROXY_CLIENT_ /// } /// } /// ``` -#[derive(MirrordConfig, Clone, Debug, Serialize)] +#[derive(MirrordConfig, Clone, Debug, Serialize, Deserialize, PartialEq)] #[config(map_to = "InternalProxyFileConfig", derive = "JsonSchema")] #[cfg_attr(test, config(derive = "PartialEq"))] pub struct InternalProxyConfig { diff --git a/mirrord/config/src/lib.rs b/mirrord/config/src/lib.rs index 8ffd1d48c29..7ca3e00dc39 100644 --- a/mirrord/config/src/lib.rs +++ b/mirrord/config/src/lib.rs @@ -25,13 +25,14 @@ use std::{ path::Path, }; +use base64::prelude::*; use config::{ConfigContext, ConfigError, MirrordConfig}; use experimental::ExperimentalConfig; use feature::{env::mapper::EnvVarsRemapper, network::outgoing::OutgoingFilterConfig}; use mirrord_analytics::CollectAnalytics; use mirrord_config_derive::MirrordConfig; use schemars::JsonSchema; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use target::Target; use tera::Tera; use tracing::warn; @@ -45,6 +46,9 @@ use crate::{ /// Env variable to load config from file (json, yaml and toml supported). pub static MIRRORD_CONFIG_FILE_ENV: &str = "MIRRORD_CONFIG_FILE"; +/// Env variable to load config from an already resolved base64 encoding. 
+pub static MIRRORD_RESOLVED_CONFIG_ENV: &str = "MIRRORD_RESOLVED_CONFIG"; + /// mirrord allows for a high degree of customization when it comes to which features you want to /// enable, and how they should function. /// @@ -115,7 +119,8 @@ pub static MIRRORD_CONFIG_FILE_ENV: &str = "MIRRORD_CONFIG_FILE"; /// "communication_timeout": 30, /// "startup_timeout": 360, /// "network_interface": "eth0", -/// "flush_connections": true +/// "flush_connections": true, +/// "metrics": "0.0.0.0:9000" /// }, /// "feature": { /// "env": { @@ -174,7 +179,7 @@ pub static MIRRORD_CONFIG_FILE_ENV: &str = "MIRRORD_CONFIG_FILE"; /// ``` /// /// # Options {#root-options} -#[derive(MirrordConfig, Clone, Debug, Serialize)] +#[derive(MirrordConfig, Clone, Debug, Serialize, Deserialize, PartialEq)] #[config(map_to = "LayerFileConfig", derive = "JsonSchema")] #[cfg_attr(test, config(derive = "PartialEq"))] pub struct LayerConfig { @@ -327,15 +332,55 @@ pub struct LayerConfig { } impl LayerConfig { + /// Given an encoded complete config from the [`MIRRORD_RESOLVED_CONFIG_ENV`] + /// env var, attempt to decode it into [`LayerConfig`]. + /// Intended to avoid re-resolving the config in every process mirrord is loaded into. + fn from_env_var(encoded_value: String) -> Result { + let decoded = BASE64_STANDARD + .decode(encoded_value) + .map_err(|error| ConfigError::EnvVarDecodeError(error.to_string()))?; + let serialized = std::str::from_utf8(&decoded) + .map_err(|error| ConfigError::EnvVarDecodeError(error.to_string()))?; + Ok(serde_json::from_str::(serialized)?) + } + + /// Given a [`LayerConfig`], serialise it and convert to base64 so it can be + /// set into [`MIRRORD_RESOLVED_CONFIG_ENV`].
+ pub fn to_env_var(&self) -> Result { + let serialized = serde_json::to_string(self) + .map_err(|error| ConfigError::EnvVarEncodeError(error.to_string()))?; + Ok(BASE64_STANDARD.encode(serialized)) + } + + /// Encode this config with [`Self::to_env_var`] and set it into + /// [`MIRRORD_RESOLVED_CONFIG_ENV`]. Must be used when updating [`LayerConfig`] after + /// creation in order for the config in env to reflect the change. + pub fn update_env_var(&self) -> Result<(), ConfigError> { + std::env::set_var(MIRRORD_RESOLVED_CONFIG_ENV, self.to_env_var()?); + Ok(()) + } + /// Generate a config from the environment variables and/or a config file. /// On success, returns the config and a vec of warnings. /// To be used from CLI to verify config and print warnings pub fn from_env_with_warnings() -> Result<(Self, ConfigContext), ConfigError> { let mut cfg_context = ConfigContext::default(); - if let Ok(path) = std::env::var(MIRRORD_CONFIG_FILE_ENV) { - LayerFileConfig::from_path(path)?.generate_config(&mut cfg_context) - } else { - LayerFileConfig::default().generate_config(&mut cfg_context) + + match std::env::var(MIRRORD_RESOLVED_CONFIG_ENV) { + Ok(value) if !value.is_empty() => LayerConfig::from_env_var(value), + _ => { + // the resolved config is not present in env, so resolve it and then set into env + // var + let config = if let Ok(path) = std::env::var(MIRRORD_CONFIG_FILE_ENV) { + LayerFileConfig::from_path(path)?.generate_config(&mut cfg_context) + } else { + LayerFileConfig::default().generate_config(&mut cfg_context) + }?; + + // serialise the config and encode as base64 + config.update_env_var()?; + Ok(config) + } } .map(|config| (config, cfg_context)) } @@ -347,6 +392,17 @@ impl LayerConfig { Self::from_env_with_warnings().map(|(config, _)| config) } + /// forcefully recalculate the config using [`Self::from_env_with_warnings()`] + pub fn recalculate_from_env_with_warnings() -> Result<(Self, ConfigContext), ConfigError> { + 
std::env::remove_var(MIRRORD_RESOLVED_CONFIG_ENV); + Self::from_env_with_warnings() + } + + /// forcefully recalculate the config using [`Self::from_env_with_warnings()`] without warnings + pub fn recalculate_from_env() -> Result { + Self::recalculate_from_env_with_warnings().map(|(config, _)| config) + } + /// Verify that there are no conflicting settings. /// /// We don't call it from `from_env` since we want to verify it only once (from cli) @@ -878,6 +934,7 @@ mod tests { udp: Some(false), ..Default::default() })), + ipv6: None, })), copy_target: None, hostname: None, @@ -985,4 +1042,73 @@ mod tests { assert_eq!(existing_content.replace("\r\n", "\n"), compare_content); } + + /// related to issue #2936: https://github.com/metalbear-co/mirrord/issues/2936 + /// checks that resolved config written to [`MIRRORD_RESOLVED_CONFIG_ENV`] can be + /// transformed back into a [`LayerConfig`] + #[test] + fn encode_and_decode_default_config() { + let mut cfg_context = ConfigContext::default(); + let resolved_config = LayerFileConfig::default() + .generate_config(&mut cfg_context) + .expect("Default config should be generated from default 'LayerFileConfig'"); + + let encoded = resolved_config.to_env_var().unwrap(); + let decoded = LayerConfig::from_env_var(encoded).unwrap(); + + assert_eq!(decoded, resolved_config); + } + + #[test] + fn encode_and_decode_advanced_config() { + let mut cfg_context = ConfigContext::default(); + + // this config includes template variables, so it needs to be rendered first + let mut template_engine = Tera::default(); + template_engine + .add_raw_template("main", get_advanced_config().as_str()) + .unwrap(); + let rendered = template_engine + .render("main", &tera::Context::new()) + .expect("Tera should render JSON config file contents"); + let resolved_config = ConfigType::Json + .parse(rendered.as_str()) + .generate_config(&mut cfg_context) + .expect("Layer config should be generated from JSON config file contents"); + + let encoded = 
resolved_config.to_env_var().unwrap(); + let decoded = LayerConfig::from_env_var(encoded).unwrap(); + + assert_eq!(decoded, resolved_config); + } + + fn get_advanced_config() -> String { + r#" + { + "accept_invalid_certificates": false, + "target": { + "path": "pod/test-service-abcdefg-abcd", + "namespace": "default" + }, + "feature": { + "env": true, + "fs": "write", + "network": { + "dns": false, + "incoming": { + "mode": "steal", + "http_filter": { + "header_filter": "x-intercept: {{ get_env(name="USER") }}" + } + }, + "outgoing": { + "tcp": true, + "udp": false + } + } + } + } + "# + .to_string() + } } diff --git a/mirrord/intproxy/Cargo.toml b/mirrord/intproxy/Cargo.toml index 57f7fdedab0..0086e9a9f3d 100644 --- a/mirrord/intproxy/Cargo.toml +++ b/mirrord/intproxy/Cargo.toml @@ -33,12 +33,9 @@ tokio.workspace = true tracing.workspace = true tokio-stream.workspace = true hyper = { workspace = true, features = ["client", "http1", "http2"] } -# For checking the `RST_STREAM` error from HTTP2 stealer + filter. -h2 = "0.4" hyper-util.workspace = true http-body-util.workspace = true bytes.workspace = true -futures.workspace = true rand.workspace = true tokio-retry = "0.3" tokio-rustls.workspace = true @@ -47,5 +44,4 @@ rustls-pemfile.workspace = true exponential-backoff = "2" [dev-dependencies] -reqwest.workspace = true rstest.workspace = true diff --git a/mirrord/intproxy/protocol/src/lib.rs b/mirrord/intproxy/protocol/src/lib.rs index 623020718b6..7648f3d6cf6 100644 --- a/mirrord/intproxy/protocol/src/lib.rs +++ b/mirrord/intproxy/protocol/src/lib.rs @@ -10,7 +10,7 @@ use std::{ use bincode::{Decode, Encode}; use mirrord_protocol::{ - dns::{GetAddrInfoRequest, GetAddrInfoResponse}, + dns::{GetAddrInfoRequestV2, GetAddrInfoResponse}, file::*, outgoing::SocketAddress, tcp::StealType, @@ -44,7 +44,7 @@ pub enum LayerToProxyMessage { /// A file operation request. File(FileRequest), /// A DNS request. 
- GetAddrInfo(GetAddrInfoRequest), + GetAddrInfo(GetAddrInfoRequestV2), /// A request to initiate a new outgoing connection. OutgoingConnect(OutgoingConnectRequest), /// Requests related to incoming connections. @@ -210,7 +210,7 @@ pub enum ProxyToLayerMessage { NewSession(LayerId), /// A response to layer's [`FileRequest`]. File(FileResponse), - /// A response to layer's [`GetAddrInfoRequest`]. + /// A response to layer's [`GetAddrInfoRequestV2`]. GetAddrInfo(GetAddrInfoResponse), /// A response to layer's [`OutgoingConnectRequest`]. OutgoingConnect(RemoteResult), @@ -324,6 +324,27 @@ impl_request!( res_path = ProxyToLayerMessage::File => FileResponse::MakeDir, ); +impl_request!( + req = RemoveDirRequest, + res = RemoteResult<()>, + req_path = LayerToProxyMessage::File => FileRequest::RemoveDir, + res_path = ProxyToLayerMessage::File => FileResponse::RemoveDir, +); + +impl_request!( + req = UnlinkRequest, + res = RemoteResult<()>, + req_path = LayerToProxyMessage::File => FileRequest::Unlink, + res_path = ProxyToLayerMessage::File => FileResponse::Unlink, +); + +impl_request!( + req = UnlinkAtRequest, + res = RemoteResult<()>, + req_path = LayerToProxyMessage::File => FileRequest::UnlinkAt, + res_path = ProxyToLayerMessage::File => FileResponse::Unlink, +); + impl_request!( req = SeekFileRequest, res = RemoteResult, @@ -366,6 +387,13 @@ impl_request!( res_path = ProxyToLayerMessage::File => FileResponse::XstatFs, ); +impl_request!( + req = StatFsRequest, + res = RemoteResult, + req_path = LayerToProxyMessage::File => FileRequest::StatFs, + res_path = ProxyToLayerMessage::File => FileResponse::XstatFs, +); + impl_request!( req = FdOpenDirRequest, res = RemoteResult, @@ -407,7 +435,7 @@ impl_request!( ); impl_request!( - req = GetAddrInfoRequest, + req = GetAddrInfoRequestV2, res = GetAddrInfoResponse, req_path = LayerToProxyMessage::GetAddrInfo, res_path = ProxyToLayerMessage::GetAddrInfo, diff --git a/mirrord/intproxy/src/background_tasks.rs 
b/mirrord/intproxy/src/background_tasks.rs index 1e974f8d236..315bbc7af15 100644 --- a/mirrord/intproxy/src/background_tasks.rs +++ b/mirrord/intproxy/src/background_tasks.rs @@ -3,11 +3,12 @@ //! The proxy utilizes multiple background tasks to split the code into more self-contained parts. //! Structs in this module aim to ease managing their state. //! -//! Each background task implement the [`BackgroundTask`] trait, which specifies its properties and +//! Each background task implements the [`BackgroundTask`] trait, which specifies its properties and //! allows for managing groups of related tasks with one [`BackgroundTasks`] instance. use std::{collections::HashMap, fmt, future::Future, hash::Hash, ops::ControlFlow}; +use thiserror::Error; use tokio::{ sync::mpsc::{self, Receiver, Sender}, task::JoinHandle, @@ -42,6 +43,67 @@ impl MessageBus { { unsafe { &mut *(self as *mut MessageBus as *mut MessageBus) } } + + /// Returns a [`Closed`] instance for this [`MessageBus`]. + pub(crate) fn closed(&self) -> Closed { + Closed(self.tx.clone()) + } +} + +/// A helper struct bound to some [`MessageBus`] instance. +/// +/// Used in [`BackgroundTask`]s to `.await` on [`Future`]s without lingering after their +/// [`MessageBus`] is closed. +/// +/// Its lifetime does not depend on the origin [`MessageBus`] and it does not hold any references +/// to it, so that you can use it **and** the [`MessageBus`] at the same time. +/// +/// # Usage example +/// +/// ```ignore +/// use std::convert::Infallible; +/// +/// use mirrord_intproxy::background_tasks::{BackgroundTask, Closed, MessageBus}; +/// +/// struct ExampleTask; +/// +/// impl ExampleTask { +/// /// Thanks to the usage of [`Closed`] in [`Self::run`], +/// /// this function can freely resolve [`Future`]s and use the [`MessageBus`]. +/// /// When the [`MessageBus`] is closed, the whole task will exit. 
+/// /// +/// /// To achieve the same without [`Closed`], you'd need to wrap each +/// /// [`Future`] resolution with [`tokio::select`]. +/// async fn do_work(&self, message_bus: &mut MessageBus) {} +/// } +/// +/// impl BackgroundTask for ExampleTask { +/// type MessageIn = Infallible; +/// type MessageOut = Infallible; +/// type Error = Infallible; +/// +/// async fn run(self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { +/// let closed: Closed = message_bus.closed(); +/// closed.cancel_on_close(self.do_work(message_bus)).await; +/// Ok(()) +/// } +/// } +/// ``` +pub(crate) struct Closed(Sender); + +impl Closed { + /// Resolves the given [`Future`], unless the origin [`MessageBus`] closes first. + /// + /// # Returns + /// + /// * [`Some`] holding the future output - if the future resolved first + /// * [`None`] - if the [`MessageBus`] closed first + pub(crate) async fn cancel_on_close(&self, future: F) -> Option { + tokio::select! { + _ = self.0.closed() => None, + output = future => Some(output) + } + } } /// Common trait for all background tasks in the internal proxy. @@ -197,6 +259,10 @@ where self.register(RestartableBackgroundTaskWrapper { task }, id, channel_size) } + pub fn tasks_ids(&self) -> impl Iterator { + self.handles.keys() + } + pub async fn kill_task(&mut self, id: Id) { self.streams.remove(&id); let Some(task) = self.handles.remove(&id) else { @@ -254,12 +320,14 @@ where } /// An error that can occur when executing a [`BackgroundTask`]. -#[derive(Debug)] +#[derive(Debug, Error)] #[cfg_attr(test, derive(PartialEq, Eq))] pub enum TaskError { /// An internal task error. + #[error(transparent)] Error(Err), /// A panic. 
+ #[error("task panicked")] Panic, } diff --git a/mirrord/intproxy/src/lib.rs b/mirrord/intproxy/src/lib.rs index c1b7ab1ee7f..e75583e2ab8 100644 --- a/mirrord/intproxy/src/lib.rs +++ b/mirrord/intproxy/src/lib.rs @@ -327,6 +327,13 @@ impl IntProxy { .send(FilesProxyMessage::ProtocolVersion(protocol_version.clone())) .await; + self.task_txs + .simple + .send(SimpleProxyMessage::ProtocolVersion( + protocol_version.clone(), + )) + .await; + self.task_txs .incoming .send(IncomingProxyMessage::AgentProtocolVersion(protocol_version)) @@ -335,6 +342,7 @@ impl IntProxy { DaemonMessage::LogMessage(log) => match log.level { LogLevel::Error => tracing::error!("agent log: {}", log.message), LogLevel::Warn => tracing::warn!("agent log: {}", log.message), + LogLevel::Info => tracing::info!("agent log: {}", log.message), }, DaemonMessage::GetEnvVarsResponse(res) => { self.task_txs diff --git a/mirrord/intproxy/src/proxies/files.rs b/mirrord/intproxy/src/proxies/files.rs index e83d2f2b764..0d24d9d41c5 100644 --- a/mirrord/intproxy/src/proxies/files.rs +++ b/mirrord/intproxy/src/proxies/files.rs @@ -6,7 +6,7 @@ use mirrord_protocol::{ file::{ CloseDirRequest, CloseFileRequest, DirEntryInternal, ReadDirBatchRequest, ReadDirResponse, ReadFileResponse, ReadLimitedFileRequest, SeekFromInternal, MKDIR_VERSION, - READDIR_BATCH_VERSION, READLINK_VERSION, + READDIR_BATCH_VERSION, READLINK_VERSION, RMDIR_VERSION, STATFS_VERSION, }, ClientMessage, DaemonMessage, ErrorKindInternal, FileRequest, FileResponse, RemoteIOError, ResponseError, @@ -253,6 +253,37 @@ impl FilesProxy { self.protocol_version.replace(version); } + /// Checks if the mirrord protocol version supports this [`FileRequest`]. + fn is_request_supported(&self, request: &FileRequest) -> Result<(), FileResponse> { + let protocol_version = self.protocol_version.as_ref(); + + match request { + FileRequest::ReadLink(..) 
+ if protocol_version.is_none_or(|version| !READLINK_VERSION.matches(version)) => + { + Err(FileResponse::ReadLink(Err(ResponseError::NotImplemented))) + } + FileRequest::MakeDir(..) | FileRequest::MakeDirAt(..) + if protocol_version.is_none_or(|version| !MKDIR_VERSION.matches(version)) => + { + Err(FileResponse::MakeDir(Err(ResponseError::NotImplemented))) + } + FileRequest::RemoveDir(..) | FileRequest::Unlink(..) | FileRequest::UnlinkAt(..) + if protocol_version + .is_none_or(|version: &Version| !RMDIR_VERSION.matches(version)) => + { + Err(FileResponse::RemoveDir(Err(ResponseError::NotImplemented))) + } + FileRequest::StatFs(..) + if protocol_version + .is_none_or(|version: &Version| !STATFS_VERSION.matches(version)) => + { + Err(FileResponse::XstatFs(Err(ResponseError::NotImplemented))) + } + _ => Ok(()), + } + } + // #[tracing::instrument(level = Level::TRACE, skip(message_bus))] async fn file_request( &mut self, @@ -261,6 +292,18 @@ impl FilesProxy { message_id: MessageId, message_bus: &mut MessageBus, ) { + // Not supported in old `mirrord-protocol` versions. + if let Err(response) = self.is_request_supported(&request) { + message_bus + .send(ToLayer { + message_id, + layer_id, + message: ProxyToLayerMessage::File(response), + }) + .await; + return; + } + match request { // Should trigger remote close only when the fd is closed in all layer instances. FileRequest::Close(close) => { @@ -454,31 +497,6 @@ impl FilesProxy { } }, - // Not supported in old `mirrord-protocol` versions. - req @ FileRequest::ReadLink(..) 
=> { - let supported = self - .protocol_version - .as_ref() - .is_some_and(|version| READLINK_VERSION.matches(version)); - - if supported { - self.request_queue.push_back(message_id, layer_id); - message_bus - .send(ProxyMessage::ToAgent(ClientMessage::FileRequest(req))) - .await; - } else { - message_bus - .send(ToLayer { - message_id, - message: ProxyToLayerMessage::File(FileResponse::ReadLink(Err( - ResponseError::NotImplemented, - ))), - layer_id, - }) - .await; - } - } - // Should only be sent from intproxy, not from the layer. FileRequest::ReadDirBatch(..) => { unreachable!("ReadDirBatch request is never sent from the layer"); @@ -522,30 +540,6 @@ impl FilesProxy { .await; } - FileRequest::MakeDir(_) | FileRequest::MakeDirAt(_) => { - let supported = self - .protocol_version - .as_ref() - .is_some_and(|version| MKDIR_VERSION.matches(version)); - - if supported { - self.request_queue.push_back(message_id, layer_id); - message_bus - .send(ProxyMessage::ToAgent(ClientMessage::FileRequest(request))) - .await; - } else { - let file_response = FileResponse::MakeDir(Err(ResponseError::NotImplemented)); - - message_bus - .send(ToLayer { - message_id, - message: ProxyToLayerMessage::File(file_response), - layer_id, - }) - .await; - } - } - // Doesn't require any special logic. other => { self.request_queue.push_back(message_id, layer_id); diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index 2ef25baf665..3c57b630b19 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -1,107 +1,63 @@ //! Handles the logic of the `incoming` feature. - -use std::{ - collections::{hash_map::Entry, HashMap}, - fmt, io, - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, -}; - -use bytes::Bytes; -use futures::StreamExt; -use http::RETRY_ON_RESET_ATTEMPTS; -use http_body_util::StreamBody; -use hyper::body::Frame; +//! +//! +//! Background tasks: +//! 1. 
TcpProxyTask - handles a single intercepted TCP connection: attempts to connect to the
+//!    user application a couple of times, then proxies raw data in both directions.
+//! 2. HttpGatewayTask - handles a single HTTP request stolen with a filter: sends it to the
+//!    user application and streams the response back to the agent.
+
+use std::{collections::HashMap, io, net::SocketAddr};
+
+use bound_socket::BoundTcpSocket;
+use http::{ClientStore, ResponseMode, StreamingBody};
+use http_gateway::HttpGatewayTask;
+use metadata_store::MetadataStore;
 use mirrord_intproxy_protocol::{
     ConnMetadataRequest, ConnMetadataResponse, IncomingRequest, IncomingResponse, LayerId,
-    MessageId, PortSubscribe, PortSubscription, PortUnsubscribe, ProxyToLayerMessage,
+    MessageId, PortSubscription, ProxyToLayerMessage,
 };
 use mirrord_protocol::{
-    body_chunks::BodyExt,
     tcp::{
-        ChunkedHttpBody, ChunkedHttpError, ChunkedRequest, ChunkedResponse, DaemonTcp, HttpRequest,
-        HttpRequestFallback, HttpResponse, HttpResponseFallback, InternalHttpBodyFrame,
-        InternalHttpRequest, InternalHttpResponse, LayerTcpSteal, NewTcpConnection,
-        ReceiverStreamBody, StreamingBody, TcpData,
+        ChunkedHttpBody, ChunkedHttpError, ChunkedRequest, DaemonTcp, HttpRequest,
+        InternalHttpBodyFrame, LayerTcp, LayerTcpSteal, NewTcpConnection, StealType, TcpData,
     },
     ClientMessage, ConnectionId, RequestId, ResponseError,
 };
+use tasks::{HttpGatewayId, HttpOut, InProxyTask, InProxyTaskError, InProxyTaskMessage};
+use tcp_proxy::{LocalTcpConnection, TcpProxyTask};
 use thiserror::Error;
-use tokio::{
-    net::TcpSocket,
-    sync::mpsc::{self, Sender},
-};
-use tokio_stream::{wrappers::ReceiverStream, StreamMap, StreamNotifyClose};
-use tracing::{debug, Level};
+use tokio::sync::mpsc;
+use tracing::Level;
 
-use self::{
-    interceptor::{Interceptor, InterceptorError, MessageOut},
-    port_subscription_ext::PortSubscriptionExt,
-    subscriptions::SubscriptionsManager,
-};
+use self::subscriptions::SubscriptionsManager;
 use crate::{
-    background_tasks::{BackgroundTask, BackgroundTasks, MessageBus, TaskSender, TaskUpdate},
+    background_tasks::{
+        BackgroundTask,
BackgroundTasks, MessageBus, TaskError, TaskSender, TaskUpdate, + }, main_tasks::{LayerClosed, LayerForked, ToLayer}, ProxyMessage, }; +mod bound_socket; mod http; -mod interceptor; -pub mod port_subscription_ext; +mod http_gateway; +mod metadata_store; +mod port_subscription_ext; mod subscriptions; - -/// Creates and binds a new [`TcpSocket`]. -/// The socket has the same IP version and address as the given `addr`. -/// -/// # Exception -/// -/// If the given `addr` is unspecified, this function binds to localhost. -#[tracing::instrument(level = Level::TRACE, ret, err)] -fn bind_similar(addr: SocketAddr) -> io::Result { - match addr.ip() { - IpAddr::V4(Ipv4Addr::UNSPECIFIED) => { - let socket = TcpSocket::new_v4()?; - socket.bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))?; - Ok(socket) - } - IpAddr::V6(Ipv6Addr::UNSPECIFIED) => { - let socket = TcpSocket::new_v6()?; - socket.bind(SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 0))?; - Ok(socket) - } - addr @ IpAddr::V4(..) => { - let socket = TcpSocket::new_v4()?; - socket.bind(SocketAddr::new(addr, 0))?; - Ok(socket) - } - addr @ IpAddr::V6(..) => { - let socket = TcpSocket::new_v6()?; - socket.bind(SocketAddr::new(addr, 0))?; - Ok(socket) - } - } -} - -/// Id of a single [`Interceptor`] task. Used to manage interceptor tasks with the -/// [`BackgroundTasks`] struct. -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -pub struct InterceptorId(pub ConnectionId); - -impl fmt::Display for InterceptorId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "incoming interceptor {}", self.0,) - } -} +mod tasks; +mod tcp_proxy; /// Errors that can occur when handling the `incoming` feature. 
#[derive(Error, Debug)] pub enum IncomingProxyError { - #[error(transparent)] - Io(#[from] io::Error), + #[error("failed to prepare a TCP socket: {0}")] + SocketSetupFailed(#[source] io::Error), #[error("subscribing port failed: {0}")] - SubscriptionFailed(ResponseError), + SubscriptionFailed(#[source] ResponseError), } /// Messages consumed by [`IncomingProxy`] running as a [`BackgroundTask`]. +#[derive(Debug)] pub enum IncomingProxyMessage { LayerRequest(MessageId, LayerId, IncomingRequest), LayerForked(LayerForked), @@ -113,314 +69,485 @@ pub enum IncomingProxyMessage { ConnectionRefresh, } -/// Handle for an [`Interceptor`]. -struct InterceptorHandle { - /// A channel for sending messages to the [`Interceptor`] task. - tx: TaskSender, - /// Port subscription that the intercepted connection belongs to. - subscription: PortSubscription, -} - -/// Store for mapping [`Interceptor`] socket addresses to addresses of the original peers. -#[derive(Default)] -struct MetadataStore { - prepared_responses: HashMap, - expected_requests: HashMap, -} - -impl MetadataStore { - fn get(&mut self, req: ConnMetadataRequest) -> ConnMetadataResponse { - self.prepared_responses - .remove(&req) - .unwrap_or_else(|| ConnMetadataResponse { - remote_source: req.peer_address, - local_address: req.listener_address.ip(), - }) - } - - fn expect(&mut self, req: ConnMetadataRequest, from: InterceptorId, res: ConnMetadataResponse) { - self.expected_requests.insert(from, req.clone()); - self.prepared_responses.insert(req, res); - } - - fn no_longer_expect(&mut self, from: InterceptorId) { - let Some(req) = self.expected_requests.remove(&from) else { - return; - }; - self.prepared_responses.remove(&req); - } +/// Handle to a running [`HttpGatewayTask`]. +struct HttpGatewayHandle { + /// Only keeps the [`HttpGatewayTask`] alive. + _tx: TaskSender, + /// For sending request body [`Frame`](hyper::body::Frame)s. + /// + /// [`None`] if all frames were already sent. 
+    body_tx: Option<mpsc::Sender<InternalHttpBodyFrame>>,
 }
 
 /// Handles logic and state of the `incoming` feature.
 /// Run as a [`BackgroundTask`].
 ///
-/// Handles port subscriptions state of the connected layers. Utilizes multiple background tasks
-/// ([`Interceptor`]s) to handle incoming connections. Each connection is managed by a single
-/// [`Interceptor`], that establishes a TCP connection with the user application's port and proxies
-/// data.
+/// Handles the port subscriptions state of the connected layers.
+/// Utilizes multiple background tasks ([`TcpProxyTask`]s and [`HttpGatewayTask`]s) to handle
+/// incoming connections and requests.
+///
+/// # Connections mirrored or stolen without a filter
+///
+/// Each such connection exists in two places:
+///
+/// 1. Here, between the intproxy and the user application. Managed by a single [`TcpProxyTask`].
+/// 2. In the cluster, between the agent and the original TCP client.
 ///
-/// Incoming connections are created by the agent either explicitly ([`NewTcpConnection`] message)
-/// or implicitly ([`HttpRequest`]).
+/// We are notified about such connections with the [`NewTcpConnection`] message.
+///
+/// The local connection lives until the agent or the user application closes it, or a local IO
+/// error occurs. When we want to close this connection, we simply drop the [`TcpProxyTask`]'s
+/// [`TaskSender`]. When a local IO error occurs, the [`TcpProxyTask`] finishes with an
+/// [`InProxyTaskError`].
+///
+/// # Requests stolen with a filter
+///
+/// In the cluster, we have a real persistent connection between the agent and the original HTTP
+/// client. From this connection, intproxy receives a subset of requests.
+///
+/// Locally, we don't have a concept of a filtered connection.
+/// Each request is handled independently by a single [`HttpGatewayTask`].
+/// Also:
+/// 1. Local HTTP connections are reused when possible.
+/// 2. Unless the error is fatal, each request is retried a couple of times.
+/// 3. 
We never send [`LayerTcpSteal::ConnectionUnsubscribe`] (due to requests being handled
+///    independently). If a request fails locally, we send a
+///    [`StatusCode::BAD_GATEWAY`](hyper::http::StatusCode::BAD_GATEWAY) response.
+///
+/// We are notified about stolen requests with [`HttpRequest`] messages.
+///
+/// A request can be cancelled only when one of the following happens:
+/// 1. The agent closes the remote connection to which this request belongs.
+/// 2. The agent informs us that it failed to read the request body ([`ChunkedRequest::Error`]).
+///
+/// When we want to cancel the request, we drop the [`HttpGatewayTask`]'s [`TaskSender`].
+///
+/// # HTTP upgrades
+///
+/// An HTTP request stolen with a filter can result in an HTTP upgrade.
+/// When this happens, the TCP connection is recovered and passed to a new [`TcpProxyTask`].
+/// The TCP connection is then treated as stolen without a filter.
 #[derive(Default)]
 pub struct IncomingProxy {
     /// Active port subscriptions for all layers.
     subscriptions: SubscriptionsManager,
-    /// [`TaskSender`]s for active [`Interceptor`]s.
-    interceptors: HashMap<InterceptorId, InterceptorHandle>,
-    /// For receiving updates from [`Interceptor`]s.
-    background_tasks: BackgroundTasks<InterceptorId, MessageOut, InterceptorError>,
     /// For managing intercepted connections metadata.
     metadata_store: MetadataStore,
-    /// For managing streamed [`DaemonTcp::HttpRequestChunked`] request channels.
-    request_body_txs: HashMap<(ConnectionId, RequestId), Sender<InternalHttpBodyFrame>>,
-    /// For managing streamed [`LayerTcpSteal::HttpResponseChunked`] response streams.
-    response_body_rxs: StreamMap<(ConnectionId, RequestId), StreamNotifyClose<ReceiverStreamBody>>,
-    /// Version of [`mirrord_protocol`] negotiated with the agent.
-    agent_protocol_version: Option<semver::Version>,
+    /// What HTTP response flavor we produce.
+    response_mode: ResponseMode,
+    /// Cache for [`LocalHttpClient`](http::LocalHttpClient)s.
+    client_store: ClientStore,
+    /// Each mirrored remote connection is mapped to a [`TcpProxyTask`] in mirror mode.
+ /// + /// Each entry here maps to a connection that is in progress both locally and remotely. + mirror_tcp_proxies: HashMap>, + /// Each remote connection stolen without a filter is mapped to a [`TcpProxyTask`] in steal + /// mode. + /// + /// Each entry here maps to a connection that is in progress both locally and remotely. + steal_tcp_proxies: HashMap>, + /// Each remote HTTP request stolen with a filter is mapped to a [`HttpGatewayTask`]. + /// + /// Each entry here maps to a request that is in progress both locally and remotely. + http_gateways: HashMap>, + /// Running [`BackgroundTask`]s utilized by this proxy. + tasks: BackgroundTasks, } impl IncomingProxy { - /// Used when registering new `RawInterceptor` and `HttpInterceptor` tasks in the - /// [`BackgroundTasks`] struct. - // TODO: Update outdated documentation. RawInterceptor, HttpInterceptor do not exist + /// Used when registering new tasks in the internal [`BackgroundTasks`] instance. const CHANNEL_SIZE: usize = 512; - /// Tries to register the new subscription in the [`SubscriptionsManager`]. + /// Starts a new [`HttpGatewayTask`] to handle the given request. + /// + /// If we don't have a [`PortSubscription`] for the port, the task is not started. + /// Instead, we respond immediately to the agent. #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] - async fn handle_port_subscribe( + async fn start_http_gateway( &mut self, - message_id: MessageId, - layer_id: LayerId, - subscribe: PortSubscribe, - message_bus: &mut MessageBus, + request: HttpRequest, + body_tx: Option>, + message_bus: &MessageBus, ) { - let msg = self - .subscriptions - .layer_subscribed(layer_id, message_id, subscribe); + let subscription = self.subscriptions.get(request.port).filter(|subscription| { + matches!( + subscription.subscription, + PortSubscription::Steal( + StealType::FilteredHttp(..) | StealType::FilteredHttpEx(..) 
+ ) + ) + }); + let Some(subscription) = subscription else { + tracing::debug!( + ?request, + "Received a new HTTP request within a stale port subscription, \ + sending an unsubscribe request or an error response." + ); + + let no_other_requests = self + .http_gateways + .get(&request.connection_id) + .map(|gateways| gateways.is_empty()) + .unwrap_or(true); + if no_other_requests { + message_bus + .send(ClientMessage::TcpSteal( + LayerTcpSteal::ConnectionUnsubscribe(request.connection_id), + )) + .await; + } else { + let response = http::mirrord_error_response( + "port no longer subscribed with an HTTP filter", + request.version(), + request.connection_id, + request.request_id, + request.port, + ); + message_bus + .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponse( + response, + ))) + .await; + } - if let Some(msg) = msg { - message_bus.send(msg).await; - } + return; + }; + + let connection_id = request.connection_id; + let request_id = request.request_id; + let id = HttpGatewayId { + connection_id, + request_id, + port: request.port, + version: request.version(), + }; + let tx = self.tasks.register( + HttpGatewayTask::new( + request, + self.client_store.clone(), + self.response_mode, + subscription.listening_on, + ), + InProxyTask::HttpGateway(id), + Self::CHANNEL_SIZE, + ); + self.http_gateways + .entry(connection_id) + .or_default() + .insert(request_id, HttpGatewayHandle { _tx: tx, body_tx }); } - /// Tries to unregister the subscription from the [`SubscriptionsManager`]. + /// Handles [`NewTcpConnection`] message from the agent, starting a new [`TcpProxyTask`]. + /// + /// If we don't have a [`PortSubscription`] for the port, the task is not started. + /// Instead, we respond immediately to the agent. 
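The subscription check in `handle_new_connection` pairs each incoming connection with a subscription of the matching mode. As an isolated, stdlib-only sketch of that decision (the enum variants here are simplified stand-ins, not mirrord's real `PortSubscription`/`StealType` types):

```rust
// Simplified stand-ins for mirrord's subscription types (illustration only).
#[derive(Debug, Clone, PartialEq)]
enum Subscription {
    Mirror,
    StealAll,
    StealFilteredHttp,
}

/// The two possible outcomes for a new remote connection:
/// start a local proxy task, or tell the agent to unsubscribe.
#[derive(Debug, PartialEq)]
enum Decision {
    StartProxy,
    Unsubscribe,
}

fn decide(subscription: Option<&Subscription>, is_steal: bool) -> Decision {
    match subscription {
        // A mirrored connection is only valid on the mirror channel.
        Some(Subscription::Mirror) if !is_steal => Decision::StartProxy,
        // An unfiltered steal subscription is only valid on the steal channel.
        Some(Subscription::StealAll) if is_steal => Decision::StartProxy,
        // Filtered subscriptions receive individual HTTP requests, not whole
        // connections; anything else means the subscription is stale.
        _ => Decision::Unsubscribe,
    }
}
```

Returning `Unsubscribe` mirrors what the proxy does for stale subscriptions: the agent is told to drop the connection instead of the proxy accepting traffic it has nowhere to send.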
#[tracing::instrument(level = Level::TRACE, skip(self, message_bus))]
-    async fn handle_port_unsubscribe(
+    async fn handle_new_connection(
         &mut self,
-        layer_id: LayerId,
-        request: PortUnsubscribe,
+        connection: NewTcpConnection,
+        is_steal: bool,
         message_bus: &mut MessageBus,
-    ) {
-        let msg = self.subscriptions.layer_unsubscribed(layer_id, request);
+    ) -> Result<(), IncomingProxyError> {
+        let NewTcpConnection {
+            connection_id,
+            remote_address,
+            destination_port,
+            source_port,
+            local_address,
+        } = connection;
+
+        let subscription = self
+            .subscriptions
+            .get(destination_port)
+            .filter(|subscription| match &subscription.subscription {
+                PortSubscription::Mirror(..) if !is_steal => true,
+                PortSubscription::Steal(StealType::All(..)) if is_steal => true,
+                _ => false,
+            });
+        let Some(subscription) = subscription else {
+            tracing::debug!(
+                port = destination_port,
+                connection_id,
+                "Received a new connection within a stale port subscription, sending an unsubscribe request.",
+            );
+
+            // The unsubscribe must go through the same channel the connection arrived on:
+            // `TcpSteal` for stolen connections, plain `Tcp` for mirrored ones.
+            let message = if is_steal {
+                ClientMessage::TcpSteal(LayerTcpSteal::ConnectionUnsubscribe(connection_id))
+            } else {
+                ClientMessage::Tcp(LayerTcp::ConnectionUnsubscribe(connection_id))
+            };
+            message_bus.send(message).await;
+
+            return Ok(());
+        };
 
-        if let Some(msg) = msg {
-            message_bus.send(msg).await;
+        let socket = BoundTcpSocket::bind_specified_or_localhost(subscription.listening_on.ip())
+            .map_err(IncomingProxyError::SocketSetupFailed)?;
+
+        self.metadata_store.expect(
+            ConnMetadataRequest {
+                listener_address: subscription.listening_on,
+                peer_address: socket
+                    .local_addr()
+                    .map_err(IncomingProxyError::SocketSetupFailed)?,
+            },
+            connection_id,
+            ConnMetadataResponse {
+                remote_source: SocketAddr::new(remote_address, source_port),
+                local_address,
+            },
+        );
+
+        let id = if is_steal {
+            InProxyTask::StealTcpProxy(connection_id)
+        } else {
+            InProxyTask::MirrorTcpProxy(connection_id)
+        };
+        let tx = self.tasks.register(
+            TcpProxyTask::new(
LocalTcpConnection::FromTheStart { + socket, + peer: subscription.listening_on, + }, + !is_steal, + ), + id, + Self::CHANNEL_SIZE, + ); + + if is_steal { + self.steal_tcp_proxies.insert(connection_id, tx); + } else { + self.mirror_tcp_proxies.insert(connection_id, tx); } + + Ok(()) } - /// Retrieves or creates an [`Interceptor`] for the given [`HttpRequestFallback`]. - /// The request may or may not belong to an existing connection (when stealing with an http - /// filter, connections are created implicitly). - #[tracing::instrument(level = Level::TRACE, skip(self))] - fn get_interceptor_for_http_request( + /// Handles [`ChunkedRequest`] message from the agent. + async fn handle_chunked_request( &mut self, - request: &HttpRequestFallback, - ) -> Result>, IncomingProxyError> { - let id: InterceptorId = InterceptorId(request.connection_id()); - - let interceptor = match self.interceptors.entry(id) { - Entry::Occupied(e) => e.into_mut(), - - Entry::Vacant(e) => { - let Some(subscription) = self.subscriptions.get(request.port()) else { - tracing::trace!( - "received a new connection for port {} that is no longer mirrored", - request.port(), + request: ChunkedRequest, + message_bus: &mut MessageBus, + ) { + match request { + ChunkedRequest::Start(request) => { + let (body_tx, body_rx) = mpsc::channel(128); + let request = request.map_body(|frames| StreamingBody::new(body_rx, frames)); + self.start_http_gateway(request, Some(body_tx), message_bus) + .await; + } + + ChunkedRequest::Body(ChunkedHttpBody { + frames, + is_last, + connection_id, + request_id, + }) => { + let gateway = self + .http_gateways + .get_mut(&connection_id) + .and_then(|gateways| gateways.get_mut(&request_id)); + let Some(gateway) = gateway else { + tracing::debug!( + connection_id, + request_id, + frames = ?frames, + last_body_chunk = is_last, + "Received a body chunk for a request that is no longer alive locally" ); - return Ok(None); + return; }; - let interceptor_socket = 
bind_similar(subscription.listening_on)?; + let Some(tx) = gateway.body_tx.as_ref() else { + tracing::debug!( + connection_id, + request_id, + frames = ?frames, + last_body_chunk = is_last, + "Received a body chunk for a request with a closed body" + ); + + return; + }; - let interceptor = self.background_tasks.register( - Interceptor::new( - interceptor_socket, - subscription.listening_on, - self.agent_protocol_version.clone(), - ), - id, - Self::CHANNEL_SIZE, - ); + for frame in frames { + if let Err(err) = tx.send(frame).await { + tracing::debug!( + frame = ?err.0, + connection_id, + request_id, + "Failed to send an HTTP request body frame to the HttpGatewayTask, channel is closed" + ); + break; + } + } - e.insert(InterceptorHandle { - tx: interceptor, - subscription: subscription.subscription.clone(), - }) + if is_last { + gateway.body_tx = None; + } } - }; - Ok(Some(&interceptor.tx)) + ChunkedRequest::Error(ChunkedHttpError { + connection_id, + request_id, + }) => { + tracing::debug!( + connection_id, + request_id, + "Received an error in an HTTP request body", + ); + + if let Some(gateways) = self.http_gateways.get_mut(&connection_id) { + gateways.remove(&request_id); + }; + } + } } /// Handles all agent messages. 
- #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] async fn handle_agent_message( &mut self, message: DaemonTcp, + is_steal: bool, message_bus: &mut MessageBus, ) -> Result<(), IncomingProxyError> { match message { DaemonTcp::Close(close) => { - self.interceptors - .remove(&InterceptorId(close.connection_id)); - self.request_body_txs - .retain(|(connection_id, _), _| *connection_id != close.connection_id); - let keys: Vec<(ConnectionId, RequestId)> = self - .response_body_rxs - .keys() - .filter(|key| key.0 == close.connection_id) - .cloned() - .collect(); - for key in keys.iter() { - self.response_body_rxs.remove(key); + if is_steal { + self.steal_tcp_proxies.remove(&close.connection_id); + self.http_gateways.remove(&close.connection_id); + } else { + self.mirror_tcp_proxies.remove(&close.connection_id); } } + DaemonTcp::Data(data) => { - if let Some(interceptor) = self.interceptors.get(&InterceptorId(data.connection_id)) - { - interceptor.tx.send(data.bytes).await; + let tx = if is_steal { + self.steal_tcp_proxies.get(&data.connection_id) + } else { + self.mirror_tcp_proxies.get(&data.connection_id) + }; + + if let Some(tx) = tx { + tx.send(data.bytes).await; } else { - tracing::trace!( - "received new data for connection {} that is already closed", - data.connection_id + tracing::debug!( + connection_id = data.connection_id, + bytes = data.bytes.len(), + "Received new data for a connection that does not belong to any TcpProxy task", ); } } - DaemonTcp::HttpRequest(req) => { - let req = HttpRequestFallback::Fallback(req); - let interceptor = self.get_interceptor_for_http_request(&req)?; - if let Some(interceptor) = interceptor { - interceptor.send(req).await; - } + + DaemonTcp::HttpRequest(request) => { + self.start_http_gateway(request.map_body(From::from), None, message_bus) + .await; + } + + DaemonTcp::HttpRequestFramed(request) => { + self.start_http_gateway(request.map_body(From::from), None, message_bus) + .await; } - 
DaemonTcp::HttpRequestFramed(req) => { - let req = HttpRequestFallback::Framed(req); - let interceptor = self.get_interceptor_for_http_request(&req)?; - if let Some(interceptor) = interceptor { - interceptor.send(req).await; + + DaemonTcp::HttpRequestChunked(request) => { + self.handle_chunked_request(request, message_bus).await; + } + + DaemonTcp::NewConnection(connection) => { + self.handle_new_connection(connection, is_steal, message_bus) + .await?; + } + + DaemonTcp::SubscribeResult(result) => { + let msgs = self.subscriptions.agent_responded(result)?; + + for msg in msgs { + message_bus.send(msg).await; } } - DaemonTcp::HttpRequestChunked(req) => { - match req { - ChunkedRequest::Start(req) => { - let (tx, rx) = mpsc::channel::(128); - let http_stream = StreamingBody::new(rx); - let http_req = HttpRequest { - internal_request: InternalHttpRequest { - method: req.internal_request.method, - uri: req.internal_request.uri, - headers: req.internal_request.headers, - version: req.internal_request.version, - body: http_stream, - }, - connection_id: req.connection_id, - request_id: req.request_id, - port: req.port, - }; - let key = (http_req.connection_id, http_req.request_id); - - self.request_body_txs.insert(key, tx.clone()); - - let http_req = HttpRequestFallback::Streamed { - request: http_req, - retries: 0, - }; - let interceptor = self.get_interceptor_for_http_request(&http_req)?; - if let Some(interceptor) = interceptor { - interceptor.send(http_req).await; - } + } - for frame in req.internal_request.body { - if let Err(err) = tx.send(frame).await { - self.request_body_txs.remove(&key); - tracing::trace!(?err, "error while sending"); - } - } - } - ChunkedRequest::Body(body) => { - let key = &(body.connection_id, body.request_id); - let mut send_err = false; - if let Some(tx) = self.request_body_txs.get(key) { - for frame in body.frames { - if let Err(err) = tx.send(frame).await { - send_err = true; - tracing::trace!(?err, "error while sending"); - } - } - } - 
if send_err || body.is_last { - self.request_body_txs.remove(key); - } + Ok(()) + } + + /// Handles all messages from this task's [`MessageBus`]. + #[tracing::instrument(level = Level::TRACE, skip(self, message_bus), err)] + async fn handle_message( + &mut self, + message: IncomingProxyMessage, + message_bus: &mut MessageBus, + ) -> Result<(), IncomingProxyError> { + match message { + IncomingProxyMessage::LayerRequest(message_id, layer_id, req) => match req { + IncomingRequest::PortSubscribe(subscribe) => { + let msg = self + .subscriptions + .layer_subscribed(layer_id, message_id, subscribe); + + if let Some(msg) = msg { + message_bus.send(msg).await; } - ChunkedRequest::Error(err) => { - self.request_body_txs - .remove(&(err.connection_id, err.request_id)); - tracing::trace!(?err, "ChunkedRequest error received"); + } + IncomingRequest::PortUnsubscribe(unsubscribe) => { + let msg = self.subscriptions.layer_unsubscribed(layer_id, unsubscribe); + + if let Some(msg) = msg { + message_bus.send(msg).await; } - }; + } + IncomingRequest::ConnMetadata(req) => { + let res = self.metadata_store.get(req); + message_bus + .send(ToLayer { + message_id, + layer_id, + message: ProxyToLayerMessage::Incoming(IncomingResponse::ConnMetadata( + res, + )), + }) + .await; + } + }, + + IncomingProxyMessage::AgentMirror(msg) => { + self.handle_agent_message(msg, false, message_bus).await?; } - DaemonTcp::NewConnection(NewTcpConnection { - connection_id, - remote_address, - destination_port, - source_port, - local_address, - }) => { - let Some(subscription) = self.subscriptions.get(destination_port) else { - tracing::trace!("received a new connection for port {destination_port} that is no longer mirrored"); - return Ok(()); - }; - let interceptor_socket = bind_similar(subscription.listening_on)?; + IncomingProxyMessage::AgentSteal(msg) => { + self.handle_agent_message(msg, true, message_bus).await?; + } - let id = InterceptorId(connection_id); + IncomingProxyMessage::LayerClosed(msg) => 
{ + let msgs = self.subscriptions.layer_closed(msg.id); - self.metadata_store.expect( - ConnMetadataRequest { - listener_address: subscription.listening_on, - peer_address: interceptor_socket.local_addr()?, - }, - id, - ConnMetadataResponse { - remote_source: SocketAddr::new(remote_address, source_port), - local_address, - }, - ); + for msg in msgs { + message_bus.send(msg).await; + } + } - let interceptor = self.background_tasks.register( - Interceptor::new( - interceptor_socket, - subscription.listening_on, - self.agent_protocol_version.clone(), - ), - id, - Self::CHANNEL_SIZE, - ); + IncomingProxyMessage::LayerForked(msg) => { + self.subscriptions.layer_forked(msg.parent, msg.child); + } - self.interceptors.insert( - id, - InterceptorHandle { - tx: interceptor, - subscription: subscription.subscription.clone(), - }, - ); + IncomingProxyMessage::AgentProtocolVersion(version) => { + self.response_mode = ResponseMode::from(&version); } - DaemonTcp::SubscribeResult(result) => { - let msgs = self.subscriptions.agent_responded(result)?; - for msg in msgs { - message_bus.send(msg).await; + IncomingProxyMessage::ConnectionRefresh => { + let running_task_ids = self.tasks.tasks_ids().cloned().collect::>(); + + for task in running_task_ids { + self.tasks.kill_task(task).await; + } + + for subscription in self.subscriptions.iter_mut() { + tracing::debug!(?subscription, "resubscribing"); + + for message in subscription.resubscribe() { + message_bus.send(ProxyMessage::ToAgent(message)).await + } } } } @@ -428,246 +555,192 @@ impl IncomingProxy { Ok(()) } - fn handle_layer_fork(&mut self, msg: LayerForked) { - let LayerForked { child, parent } = msg; - self.subscriptions.layer_forked(parent, child); - } + /// Handles all updates from [`TcpProxyTask`]s. 
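When a `TcpProxyTask` finishes, the handler removes it from whichever map owns it and unsubscribes the remote connection only if it was still tracked, so an already-closed connection is never unsubscribed twice. That remove-then-notify idempotency can be sketched in isolation (stand-in types, not the real mirrord API):

```rust
use std::collections::HashMap;

/// Stand-in for the unsubscribe message sent back to the agent.
#[derive(Debug, PartialEq)]
enum Unsubscribe {
    Mirror(u64),
    Steal(u64),
}

/// Removes a finished proxy from the right map.
/// Returns an unsubscribe message only if the connection was still tracked,
/// making repeated cleanup of the same connection a no-op.
fn cleanup_finished_proxy(
    mirror: &mut HashMap<u64, ()>,
    steal: &mut HashMap<u64, ()>,
    connection_id: u64,
    is_steal: bool,
) -> Option<Unsubscribe> {
    if is_steal {
        steal
            .remove(&connection_id)
            .map(|_| Unsubscribe::Steal(connection_id))
    } else {
        mirror
            .remove(&connection_id)
            .map(|_| Unsubscribe::Mirror(connection_id))
    }
}
```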
+ #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] + async fn handle_tcp_proxy_update( + &mut self, + connection_id: ConnectionId, + is_steal: bool, + update: TaskUpdate, + message_bus: &mut MessageBus, + ) { + match update { + TaskUpdate::Finished(result) => { + match result { + Err(TaskError::Error(error)) => { + tracing::warn!(connection_id, %error, is_steal, "TcpProxyTask failed"); + } + Err(TaskError::Panic) => { + tracing::error!(connection_id, is_steal, "TcpProxyTask task panicked"); + } + Ok(()) => {} + }; - async fn handle_layer_close(&mut self, msg: LayerClosed, message_bus: &MessageBus) { - let msgs = self.subscriptions.layer_closed(msg.id); + self.metadata_store.no_longer_expect(connection_id); - for msg in msgs { - message_bus.send(msg).await; + if is_steal { + if self.steal_tcp_proxies.remove(&connection_id).is_some() { + message_bus + .send(ClientMessage::TcpSteal( + LayerTcpSteal::ConnectionUnsubscribe(connection_id), + )) + .await; + } + } else if self.mirror_tcp_proxies.remove(&connection_id).is_some() { + message_bus + .send(ClientMessage::Tcp(LayerTcp::ConnectionUnsubscribe( + connection_id, + ))) + .await; + } + } + + TaskUpdate::Message(..) if !is_steal => { + unreachable!("TcpProxyTask does not produce messages in mirror mode") + } + + TaskUpdate::Message(InProxyTaskMessage::Tcp(bytes)) => { + if self.steal_tcp_proxies.contains_key(&connection_id) { + message_bus + .send(ClientMessage::TcpSteal(LayerTcpSteal::Data(TcpData { + connection_id, + bytes, + }))) + .await; + } + } + + TaskUpdate::Message(InProxyTaskMessage::Http(..)) => { + unreachable!("TcpProxyTask does not produce HTTP messages") + } } } - fn get_subscription(&self, interceptor_id: InterceptorId) -> Option<&PortSubscription> { - self.interceptors - .get(&interceptor_id) - .map(|handle| &handle.subscription) - } -} + /// Handles all updates from [`HttpGatewayTask`]s. 
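When a gateway task panics, the proxy falls back to an error response rather than dropping the stolen request silently. `http::mirrord_error_response` is internal to mirrord and returns a protocol-level response struct; purely as an illustration of the shape of such a fallback, a raw HTTP/1.1 502 could be built like this (hypothetical helper, not part of the codebase):

```rust
/// Builds a raw HTTP/1.1 502 Bad Gateway response with a short plain-text body.
/// Illustration only: mirrord's real fallback goes through `mirrord-protocol`
/// types, not raw bytes.
fn bad_gateway(reason: &str) -> String {
    let body = format!("mirrord: {reason}\n");
    format!(
        "HTTP/1.1 502 Bad Gateway\r\ncontent-length: {}\r\ncontent-type: text/plain\r\n\r\n{}",
        body.len(),
        body
    )
}
```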
+ #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] + async fn handle_http_gateway_update( + &mut self, + id: HttpGatewayId, + update: TaskUpdate, + message_bus: &mut MessageBus, + ) { + match update { + TaskUpdate::Finished(result) => { + let respond_on_panic = self + .http_gateways + .get_mut(&id.connection_id) + .and_then(|gateways| gateways.remove(&id.request_id)) + .is_some(); + + match result { + Ok(()) => {} + Err(TaskError::Error( + InProxyTaskError::IoError(..) | InProxyTaskError::UpgradeError(..), + )) => unreachable!("HttpGatewayTask does not return any errors"), + Err(TaskError::Panic) => { + tracing::error!( + connection_id = id.connection_id, + request_id = id.request_id, + "HttpGatewayTask panicked", + ); + + if respond_on_panic { + let response = http::mirrord_error_response( + "HTTP gateway task panicked", + id.version, + id.connection_id, + id.request_id, + id.port, + ); + message_bus + .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponse( + response, + ))) + .await; + } + } + } + } -impl BackgroundTask for IncomingProxy { - type Error = IncomingProxyError; - type MessageIn = IncomingProxyMessage; - type MessageOut = ProxyMessage; + TaskUpdate::Message(InProxyTaskMessage::Http(message)) => { + let exists = self + .http_gateways + .get(&id.connection_id) + .and_then(|gateways| gateways.get(&id.request_id)) + .is_some(); + if !exists { + return; + } - #[tracing::instrument(level = Level::TRACE, skip_all, err)] - async fn run(&mut self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { - loop { - tokio::select! 
{ - Some(((connection_id, request_id), stream_item)) = self.response_body_rxs.next() => match stream_item { - Some(Ok(frame)) => { - let int_frame = InternalHttpBodyFrame::from(frame); - let res = ChunkedResponse::Body(ChunkedHttpBody { - frames: vec![int_frame], - is_last: false, - connection_id, - request_id, - }); + match message { + HttpOut::ResponseBasic(response) => { message_bus - .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( - res, + .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponse( + response, ))) - .await; - }, - Some(Err(error)) => { - debug!(%error, "Error while reading streamed response body"); - let res = ChunkedResponse::Error(ChunkedHttpError {connection_id, request_id}); + .await + } + HttpOut::ResponseFramed(response) => { message_bus - .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( - res, + .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseFramed( + response, ))) - .await; - self.response_body_rxs.remove(&(connection_id, request_id)); - }, - None => { - let res = ChunkedResponse::Body(ChunkedHttpBody { - frames: vec![], - is_last: true, - connection_id, - request_id, - }); + .await + } + HttpOut::ResponseChunked(response) => { message_bus .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( - res, + response, ))) .await; - self.response_body_rxs.remove(&(connection_id, request_id)); } - }, + HttpOut::Upgraded(on_upgrade) => { + let proxy = self.tasks.register( + TcpProxyTask::new(LocalTcpConnection::AfterUpgrade(on_upgrade), false), + InProxyTask::StealTcpProxy(id.connection_id), + Self::CHANNEL_SIZE, + ); + self.steal_tcp_proxies.insert(id.connection_id, proxy); + } + } + } + TaskUpdate::Message(InProxyTaskMessage::Tcp(..)) => { + unreachable!("HttpGatewayTask does not produce TCP messages") + } + } + } +} + +impl BackgroundTask for IncomingProxy { + type Error = IncomingProxyError; + type MessageIn = IncomingProxyMessage; + type MessageOut = ProxyMessage; + + 
#[tracing::instrument(level = Level::TRACE, name = "incoming_proxy_main_loop", skip_all, err)] + async fn run(&mut self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { + loop { + tokio::select! { msg = message_bus.recv() => match msg { None => { tracing::trace!("message bus closed, exiting"); break Ok(()); }, - Some(IncomingProxyMessage::LayerRequest(message_id, layer_id, req)) => match req { - IncomingRequest::PortSubscribe(subscribe) => self.handle_port_subscribe(message_id, layer_id, subscribe, message_bus).await, - IncomingRequest::PortUnsubscribe(unsubscribe) => self.handle_port_unsubscribe(layer_id, unsubscribe, message_bus).await, - IncomingRequest::ConnMetadata(req) => { - let res = self.metadata_store.get(req); - message_bus.send(ToLayer { message_id, layer_id, message: ProxyToLayerMessage::Incoming(IncomingResponse::ConnMetadata(res)) }).await; - } - }, - Some(IncomingProxyMessage::AgentMirror(msg)) => { - self.handle_agent_message(msg, message_bus).await?; + Some(message) => self.handle_message(message, message_bus).await?, + }, + + Some((id, update)) = self.tasks.next() => match id { + InProxyTask::MirrorTcpProxy(connection_id) => { + self.handle_tcp_proxy_update(connection_id, false, update, message_bus).await; } - Some(IncomingProxyMessage::AgentSteal(msg)) => { - self.handle_agent_message(msg, message_bus).await?; + InProxyTask::StealTcpProxy(connection_id) => { + self.handle_tcp_proxy_update(connection_id, true, update, message_bus).await; } - Some(IncomingProxyMessage::LayerClosed(msg)) => self.handle_layer_close(msg, message_bus).await, - Some(IncomingProxyMessage::LayerForked(msg)) => self.handle_layer_fork(msg), - Some(IncomingProxyMessage::AgentProtocolVersion(version)) => { - self.agent_protocol_version.replace(version); + InProxyTask::HttpGateway(id) => { + self.handle_http_gateway_update(id, update, message_bus).await; } - Some(IncomingProxyMessage::ConnectionRefresh) => { - self.request_body_txs.clear(); - 
self.response_body_rxs.clear(); - - for (interceptor_id, _) in self.interceptors.drain() { - self.background_tasks.kill_task(interceptor_id).await; - } - - for subscription in self.subscriptions.iter_mut() { - tracing::debug!(?subscription, "resubscribing"); - - for message in subscription.resubscribe() { - message_bus.send(ProxyMessage::ToAgent(message)).await - } - } - } - }, - - Some(task_update) = self.background_tasks.next() => match task_update { - (id, TaskUpdate::Finished(res)) => { - tracing::trace!("{id} finished: {res:?}"); - - self.metadata_store.no_longer_expect(id); - - let msg = self.get_subscription(id).map(|s| s.wrap_agent_unsubscribe_connection(id.0)); - if let Some(msg) = msg { - message_bus.send(msg).await; - } - - self.request_body_txs.retain(|(connection_id, _), _| *connection_id != id.0); - }, - - (id, TaskUpdate::Message(msg)) => { - let Some(PortSubscription::Steal(_)) = self.get_subscription(id) else { - continue; - }; - let msg = match msg { - MessageOut::Raw(bytes) => { - ClientMessage::TcpSteal(LayerTcpSteal::Data(TcpData { - connection_id: id.0, - bytes, - })) - }, - MessageOut::Http(HttpResponseFallback::Fallback(res)) => { - ClientMessage::TcpSteal(LayerTcpSteal::HttpResponse(res)) - }, - MessageOut::Http(HttpResponseFallback::Framed(res)) => { - ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseFramed(res)) - }, - MessageOut::Http(HttpResponseFallback::Streamed(response, request)) => { - match self.streamed_http_response(response, request).await { - Some(response) => response, - None => continue, - } - } - }; - message_bus.send(msg).await; - }, }, } } } } - -impl IncomingProxy { - /// Sends back the streamed http response to the agent. - /// - /// If we cannot get the next frame of the streamed body, then we retry the whole - /// process, by sending the original `request` again through the http `interceptor` to - /// our hyper handler. 
- #[allow(clippy::type_complexity)] - #[tracing::instrument(level = Level::TRACE, skip(self), ret)] - async fn streamed_http_response( - &mut self, - mut response: HttpResponse, hyper::Error>>>>, - request: Option, - ) -> Option { - let mut body = vec![]; - let key = (response.connection_id, response.request_id); - - match response - .internal_response - .body - .next_frames(true) - .await - .map_err(InterceptorError::from) - { - Ok(frames) => { - frames - .frames - .into_iter() - .map(From::from) - .for_each(|frame| body.push(frame)); - - self.response_body_rxs - .insert(key, StreamNotifyClose::new(response.internal_response.body)); - - let internal_response = InternalHttpResponse { - status: response.internal_response.status, - version: response.internal_response.version, - headers: response.internal_response.headers, - body, - }; - let response = ChunkedResponse::Start(HttpResponse { - port: response.port, - connection_id: response.connection_id, - request_id: response.request_id, - internal_response, - }); - Some(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( - response, - ))) - } - // Retry on known errors. - Err(error @ InterceptorError::Reset) - | Err(error @ InterceptorError::ConnectionClosedTooSoon(..)) - | Err(error @ InterceptorError::IncompleteMessage(..)) => { - tracing::warn!(%error, ?request, "Failed to read first frames of streaming HTTP response"); - - let interceptor = self - .interceptors - .get(&InterceptorId(response.connection_id))?; - - if let Some(HttpRequestFallback::Streamed { request, retries }) = request - && retries < RETRY_ON_RESET_ATTEMPTS - { - tracing::trace!( - ?request, - ?retries, - "`RST_STREAM` from hyper, retrying the request." 
-                    );
-                    interceptor
-                        .tx
-                        .send(HttpRequestFallback::Streamed {
-                            request,
-                            retries: retries + 1,
-                        })
-                        .await;
-                }
-
-                None
-            }
-            Err(fail) => {
-                tracing::warn!(?fail, "Something went wrong, skipping this response!");
-                None
-            }
-        }
-    }
-}
diff --git a/mirrord/intproxy/src/proxies/incoming/bound_socket.rs b/mirrord/intproxy/src/proxies/incoming/bound_socket.rs
new file mode 100644
index 00000000000..1c6cbef385a
--- /dev/null
+++ b/mirrord/intproxy/src/proxies/incoming/bound_socket.rs
@@ -0,0 +1,46 @@
+use std::{
+    fmt, io,
+    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
+};
+
+use tokio::net::{TcpSocket, TcpStream};
+use tracing::Level;
+
+/// A TCP socket that is already bound.
+///
+/// Provides a nicer [`fmt::Debug`] implementation than [`TcpSocket`].
+pub struct BoundTcpSocket(TcpSocket);
+
+impl BoundTcpSocket {
+    /// Opens a new TCP socket and binds it to the given IP address and a random port.
+    /// If the given IP address is not specified, binds the socket to localhost instead.
+    #[tracing::instrument(level = Level::TRACE, ret, err)]
+    pub fn bind_specified_or_localhost(ip: IpAddr) -> io::Result<Self> {
+        let (socket, ip) = match ip {
+            IpAddr::V4(Ipv4Addr::UNSPECIFIED) => (TcpSocket::new_v4()?, Ipv4Addr::LOCALHOST.into()),
+            IpAddr::V6(Ipv6Addr::UNSPECIFIED) => (TcpSocket::new_v6()?, Ipv6Addr::LOCALHOST.into()),
+            addr @ IpAddr::V4(..) => (TcpSocket::new_v4()?, addr),
+            addr @ IpAddr::V6(..) => (TcpSocket::new_v6()?, addr),
+        };
+
+        socket.bind(SocketAddr::new(ip, 0))?;
+
+        Ok(Self(socket))
+    }
+
+    /// Returns the address to which this socket is bound.
+    pub fn local_addr(&self) -> io::Result<SocketAddr> {
+        self.0.local_addr()
+    }
+
+    /// Makes a connection to the given peer.
+    pub async fn connect(self, peer: SocketAddr) -> io::Result<TcpStream> {
+        self.0.connect(peer).await
+    }
+}
+
+impl fmt::Debug for BoundTcpSocket {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.0.local_addr().fmt(f)
+    }
+}
diff --git a/mirrord/intproxy/src/proxies/incoming/http.rs b/mirrord/intproxy/src/proxies/incoming/http.rs
index d122aab53c5..a871cebc2c5 100644
--- a/mirrord/intproxy/src/proxies/incoming/http.rs
+++ b/mirrord/intproxy/src/proxies/incoming/http.rs
@@ -1,86 +1,257 @@
-use std::convert::Infallible;
+use std::{fmt, io, net::SocketAddr, ops::Not};
 
-use bytes::Bytes;
-use http_body_util::combinators::BoxBody;
 use hyper::{
     body::Incoming,
     client::conn::{http1, http2},
-    Response, Version,
+    Request, Response, StatusCode, Version,
 };
 use hyper_util::rt::{TokioExecutor, TokioIo};
-use mirrord_protocol::tcp::HttpRequestFallback;
+use mirrord_protocol::{
+    tcp::{HttpRequest, HttpResponse, InternalHttpResponse},
+    ConnectionId, Port, RequestId,
+};
+use thiserror::Error;
 use tokio::net::TcpStream;
 use tracing::Level;
 
-use super::interceptor::{InterceptorError, InterceptorResult};
+mod client_store;
+mod response_mode;
+mod streaming_body;
 
-pub(super) const RETRY_ON_RESET_ATTEMPTS: u32 = 10;
+pub use client_store::ClientStore;
+pub use response_mode::ResponseMode;
+pub use streaming_body::StreamingBody;
 
-/// Handles the differences between hyper's HTTP/1 and HTTP/2 connections.
-pub enum HttpSender {
-    V1(http1::SendRequest<BoxBody<Bytes, Infallible>>),
-    V2(http2::SendRequest<BoxBody<Bytes, Infallible>>),
+/// An HTTP client used to pass requests to the user application.
+pub struct LocalHttpClient {
+    /// Established HTTP connection with the user application.
+    sender: HttpSender,
+    /// Address of the user application's HTTP server.
+    local_server_address: SocketAddr,
+    /// Address of this client's TCP socket.
+    address: SocketAddr,
 }
 
-/// Consumes the given [`TcpStream`] and performs an HTTP handshake, turning it into an HTTP
-/// connection.
-///
-/// # Returns
-///
-/// [`HttpSender`] that can be used to send HTTP requests to the peer.
-#[tracing::instrument(level = Level::TRACE, skip(target_stream), err(level = Level::WARN))]
-pub async fn handshake(
-    version: Version,
-    target_stream: TcpStream,
-) -> InterceptorResult<HttpSender> {
-    match version {
-        Version::HTTP_2 => {
-            let (sender, connection) =
-                http2::handshake(TokioExecutor::default(), TokioIo::new(target_stream)).await?;
-            tokio::spawn(connection);
-
-            Ok(HttpSender::V2(sender))
+impl LocalHttpClient {
+    /// Makes an HTTP connection with the given server and creates a new client.
+    #[tracing::instrument(level = Level::TRACE, err(level = Level::WARN), ret)]
+    pub async fn new(
+        local_server_address: SocketAddr,
+        version: Version,
+    ) -> Result<Self, LocalHttpError> {
+        let stream = TcpStream::connect(local_server_address)
+            .await
+            .map_err(LocalHttpError::ConnectTcpFailed)?;
+        let local_server_address = stream
+            .peer_addr()
+            .map_err(LocalHttpError::SocketSetupFailed)?;
+        let address = stream
+            .local_addr()
+            .map_err(LocalHttpError::SocketSetupFailed)?;
+        let sender = HttpSender::handshake(version, stream).await?;
+
+        Ok(Self {
+            sender,
+            local_server_address,
+            address,
+        })
+    }
+
+    /// Sends the given `request` to the user application's HTTP server.
+    #[tracing::instrument(level = Level::DEBUG, err(level = Level::WARN), ret)]
+    pub async fn send_request(
+        &mut self,
+        request: HttpRequest<StreamingBody>,
+    ) -> Result<Response<Incoming>, LocalHttpError> {
+        self.sender.send_request(request).await
+    }
+
+    /// Returns the address of the local server to which this client is connected.
+ pub fn local_server_address(&self) -> SocketAddr { + self.local_server_address + } + + pub fn handles_version(&self, version: Version) -> bool { + match (&self.sender, version) { + (_, Version::HTTP_3) => false, + (HttpSender::V2(..), Version::HTTP_2) => true, + (HttpSender::V1(..), _) => true, + (HttpSender::V2(..), _) => false, } + } +} - Version::HTTP_3 => Err(InterceptorError::UnsupportedHttpVersion(version)), +impl fmt::Debug for LocalHttpClient { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("LocalHttpClient") + .field("local_server_address", &self.local_server_address) + .field("address", &self.address) + .field("is_http_1", &matches!(self.sender, HttpSender::V1(..))) + .finish() + } +} + +/// Errors that can occur when sending an HTTP request to the user application. +#[derive(Error, Debug)] +pub enum LocalHttpError { + #[error("failed to make an HTTP handshake with the local application's HTTP server: {0}")] + HandshakeFailed(#[source] hyper::Error), - _http_v1 => { - let (sender, connection) = http1::handshake(TokioIo::new(target_stream)).await?; + #[error("{0:?} is not supported in the local HTTP proxy")] + UnsupportedHttpVersion(Version), - tokio::spawn(connection.with_upgrades()); + #[error("failed to send the request to the local application's HTTP server: {0}")] + SendFailed(#[source] hyper::Error), - Ok(HttpSender::V1(sender)) + #[error("failed to prepare a local TCP socket: {0}")] + SocketSetupFailed(#[source] io::Error), + + #[error("failed to make a TCP connection with the local application's HTTP server: {0}")] + ConnectTcpFailed(#[source] io::Error), + + #[error("failed to read the body of the local application's HTTP server response: {0}")] + ReadBodyFailed(#[source] hyper::Error), +} + +impl LocalHttpError { + /// Checks if we can retry sending the request, given that the previous attempt resulted in this + /// error. + pub fn can_retry(&self) -> bool { + match self { + Self::SocketSetupFailed(..) 
| Self::UnsupportedHttpVersion(..) => false,
+            Self::ConnectTcpFailed(..) => true,
+            Self::HandshakeFailed(err) | Self::SendFailed(err) | Self::ReadBodyFailed(err) => (err
+                .is_parse()
+                || err.is_parse_status()
+                || err.is_parse_too_large()
+                || err.is_user())
+            .not(),
         }
     }
 }
 
+/// Produces a mirrord-specific [`StatusCode::BAD_GATEWAY`] response.
+pub fn mirrord_error_response<M: fmt::Display>(
+    message: M,
+    version: Version,
+    connection_id: ConnectionId,
+    request_id: RequestId,
+    port: Port,
+) -> HttpResponse<Vec<u8>> {
+    HttpResponse {
+        connection_id,
+        port,
+        request_id,
+        internal_response: InternalHttpResponse {
+            status: StatusCode::BAD_GATEWAY,
+            version,
+            headers: Default::default(),
+            body: format!("mirrord: {message}\n").into_bytes(),
+        },
+    }
+}
+
+/// Holds either [`http1::SendRequest`] or [`http2::SendRequest`] and exposes a unified interface.
+enum HttpSender {
+    V1(http1::SendRequest<StreamingBody>),
+    V2(http2::SendRequest<StreamingBody>),
+}
+
 impl HttpSender {
-    #[tracing::instrument(level = Level::TRACE, skip(self), err(level = Level::WARN))]
-    pub async fn send(
+    /// Performs an HTTP handshake over the given [`TcpStream`].
+    async fn handshake(version: Version, target_stream: TcpStream) -> Result<Self, LocalHttpError> {
+        let local_addr = target_stream
+            .local_addr()
+            .map_err(LocalHttpError::SocketSetupFailed)?;
+        let peer_addr = target_stream
+            .peer_addr()
+            .map_err(LocalHttpError::SocketSetupFailed)?;
+
+        match version {
+            Version::HTTP_2 => {
+                let (sender, connection) =
+                    http2::handshake(TokioExecutor::default(), TokioIo::new(target_stream))
+                        .await
+                        .map_err(LocalHttpError::HandshakeFailed)?;
+
+                tokio::spawn(async move {
+                    match connection.await {
+                        Ok(()) => {
+                            tracing::trace!(%local_addr, %peer_addr, "HTTP connection with the local application finished");
+                        }
+                        Err(error) => {
+                            tracing::warn!(%error, %local_addr, %peer_addr, "HTTP connection with the local application failed");
+                        }
+                    }
+                });
+
+                Ok(HttpSender::V2(sender))
+            }
+
+            Version::HTTP_3 => Err(LocalHttpError::UnsupportedHttpVersion(version)),
+
+            _http_v1 => {
+                let (sender, connection) = http1::handshake(TokioIo::new(target_stream))
+                    .await
+                    .map_err(LocalHttpError::HandshakeFailed)?;
+
+                tokio::spawn(async move {
+                    match connection.with_upgrades().await {
+                        Ok(()) => {
+                            tracing::trace!(%local_addr, %peer_addr, "HTTP connection with the local application finished");
+                        }
+                        Err(error) => {
+                            tracing::warn!(%error, %local_addr, %peer_addr, "HTTP connection with the local application failed");
+                        }
+                    }
+                });
+
+                Ok(HttpSender::V1(sender))
+            }
+        }
+    }
+
+    /// Tries to send the given [`HttpRequest`] to the server.
+    async fn send_request(
         &mut self,
-        req: HttpRequestFallback,
-    ) -> InterceptorResult<Response<Incoming>, InterceptorError> {
+        request: HttpRequest<StreamingBody>,
+    ) -> Result<Response<Incoming>, LocalHttpError> {
         match self {
             Self::V1(sender) => {
                 // Solves a "connection was not ready" client error.
// https://rust-lang.github.io/wg-async/vision/submitted_stories/status_quo/barbara_tries_unix_socket.html#the-single-magical-line - sender.ready().await?; + sender.ready().await.map_err(LocalHttpError::SendFailed)?; + sender - .send_request(req.into_hyper()) + .send_request(request.internal_request.into()) .await - .map_err(Into::into) + .map_err(LocalHttpError::SendFailed) } Self::V2(sender) => { - let mut req = req.into_hyper(); + let mut hyper_request: Request<_> = request.internal_request.into(); + // fixes https://github.com/metalbear-co/mirrord/issues/2497 // inspired by https://github.com/linkerd/linkerd2-proxy/blob/c5d9f1c1e7b7dddd9d75c0d1a0dca68188f38f34/linkerd/proxy/http/src/h2.rs#L175 - if req.uri().authority().is_none() { - *req.version_mut() = hyper::http::Version::HTTP_11; + if hyper_request.uri().authority().is_none() + && hyper_request.version() != Version::HTTP_11 + { + tracing::trace!( + original_version = ?hyper_request.version(), + "Request URI has no authority, changing HTTP version to {:?}", + Version::HTTP_11, + ); + + *hyper_request.version_mut() = Version::HTTP_11; } + // Solves a "connection was not ready" client error. 
 // https://rust-lang.github.io/wg-async/vision/submitted_stories/status_quo/barbara_tries_unix_socket.html#the-single-magical-line
-                sender.ready().await?;
-                sender.send_request(req).await.map_err(Into::into)
+                sender.ready().await.map_err(LocalHttpError::SendFailed)?;
+
+                sender
+                    .send_request(hyper_request)
+                    .await
+                    .map_err(LocalHttpError::SendFailed)
             }
         }
     }
 }
diff --git a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs
new file mode 100644
index 00000000000..69ee3ce512c
--- /dev/null
+++ b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs
@@ -0,0 +1,232 @@
+use std::{
+    cmp, fmt,
+    net::SocketAddr,
+    sync::{Arc, Mutex},
+    time::Duration,
+};
+
+use hyper::Version;
+use tokio::{
+    sync::Notify,
+    time::{self, Instant},
+};
+use tracing::Level;
+
+use super::{LocalHttpClient, LocalHttpError};
+
+/// An idle [`LocalHttpClient`] cached in the [`ClientStore`].
+struct IdleLocalClient {
+    client: LocalHttpClient,
+    last_used: Instant,
+}
+
+impl fmt::Debug for IdleLocalClient {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("IdleLocalClient")
+            .field("client", &self.client)
+            .field("idle_for_s", &self.last_used.elapsed().as_secs_f32())
+            .finish()
+    }
+}
+
+/// Cache for unused [`LocalHttpClient`]s.
+///
+/// [`LocalHttpClient`]s that have not been used for some time are dropped in the background by a
+/// dedicated [`tokio::task`]. This timeout defaults to [`Self::IDLE_CLIENT_DEFAULT_TIMEOUT`].
+#[derive(Clone)]
+pub struct ClientStore {
+    clients: Arc<Mutex<Vec<IdleLocalClient>>>,
+    /// Used to notify other tasks when there is a new client in the store.
+    ///
+    /// Make sure to only call [`Notify::notify_waiters`] and [`Notify::notified`] when holding a
+    /// lock on [`Self::clients`]. Otherwise you'll have a race condition.
+    notify: Arc<Notify>,
+}
+
+impl Default for ClientStore {
+    fn default() -> Self {
+        Self::new_with_timeout(Self::IDLE_CLIENT_DEFAULT_TIMEOUT)
+    }
+}
+
+impl ClientStore {
+    pub const IDLE_CLIENT_DEFAULT_TIMEOUT: Duration = Duration::from_secs(3);
+
+    /// Creates a new store.
+    ///
+    /// The store will keep unused clients alive for at least the given time.
+    pub fn new_with_timeout(timeout: Duration) -> Self {
+        let store = Self {
+            clients: Default::default(),
+            notify: Default::default(),
+        };
+
+        tokio::spawn(cleanup_task(store.clone(), timeout));
+
+        store
+    }
+
+    /// Reuses or creates a new [`LocalHttpClient`].
+    #[tracing::instrument(level = Level::TRACE, skip(self), ret, err(level = Level::WARN))]
+    pub async fn get(
+        &self,
+        server_addr: SocketAddr,
+        version: Version,
+    ) -> Result<LocalHttpClient, LocalHttpError> {
+        let ready = {
+            let mut guard = self
+                .clients
+                .lock()
+                .expect("ClientStore mutex is poisoned, this is a bug");
+            let position = guard.iter().position(|idle| {
+                idle.client.handles_version(version)
+                    && idle.client.local_server_address() == server_addr
+            });
+            position.map(|position| guard.swap_remove(position))
+        };
+
+        if let Some(ready) = ready {
+            tracing::trace!(?ready, "Reused an idle client");
+            return Ok(ready.client);
+        }
+
+        let connect_task = tokio::spawn(LocalHttpClient::new(server_addr, version));
+
+        tokio::select! {
+            result = connect_task => result.expect("this task should not panic"),
+            ready = self.wait_for_ready(server_addr, version) => {
+                tracing::trace!(?ready, "Reused an idle client");
+                Ok(ready)
+            },
+        }
+    }
+
+    /// Stores an unused [`LocalHttpClient`], so that it can be reused later.
+ #[tracing::instrument(level = Level::TRACE, skip(self))] + pub fn push_idle(&self, client: LocalHttpClient) { + let mut guard = self + .clients + .lock() + .expect("ClientStore mutex is poisoned, this is a bug"); + guard.push(IdleLocalClient { + client, + last_used: Instant::now(), + }); + self.notify.notify_waiters(); + } + + /// Waits until there is a ready unused client. + async fn wait_for_ready(&self, server_addr: SocketAddr, version: Version) -> LocalHttpClient { + loop { + let notified = { + let mut guard = self + .clients + .lock() + .expect("ClientStore mutex is poisoned, this is a bug"); + let position = guard.iter().position(|idle| { + idle.client.handles_version(version) + && idle.client.local_server_address() == server_addr + }); + + match position { + Some(position) => return guard.swap_remove(position).client, + None => self.notify.notified(), + } + }; + + notified.await; + } + } +} + +/// Cleans up stale [`LocalHttpClient`]s from the [`ClientStore`]. +async fn cleanup_task(store: ClientStore, idle_client_timeout: Duration) { + let clients = Arc::downgrade(&store.clients); + let notify = store.notify.clone(); + std::mem::drop(store); + + loop { + let Some(clients) = clients.upgrade() else { + // Failed `upgrade` means that all `ClientStore` instances were dropped. + // This task is no longer needed. + break; + }; + + let now = Instant::now(); + let mut min_last_used = None; + let notified = { + let Ok(mut guard) = clients.lock() else { + tracing::error!("ClientStore mutex is poisoned, this is a bug"); + return; + }; + + guard.retain(|client| { + if client.last_used + idle_client_timeout > now { + // We determine how long to sleep before cleaning the store again. + min_last_used = min_last_used + .map(|previous| cmp::min(previous, client.last_used)) + .or(Some(client.last_used)); + + true + } else { + // We drop the idle clients that have gone beyond the timeout. 
+                    tracing::trace!(?client, "Dropping an idle client");
+                    false
+                }
+            });
+
+            // Acquire [`Notified`] while still holding the lock.
+            // Prevents missed updates.
+            notify.notified()
+        };
+
+        if let Some(min_last_used) = min_last_used {
+            time::sleep_until(min_last_used + idle_client_timeout).await;
+        } else {
+            notified.await;
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::{convert::Infallible, time::Duration};
+
+    use bytes::Bytes;
+    use http_body_util::Empty;
+    use hyper::{
+        body::Incoming, server::conn::http1, service::service_fn, Request, Response, Version,
+    };
+    use hyper_util::rt::TokioIo;
+    use tokio::{net::TcpListener, time};
+
+    use super::ClientStore;
+
+    /// Verifies that [`ClientStore`] cleans up unused connections.
+    #[tokio::test]
+    async fn cleans_up_unused_connections() {
+        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let addr = listener.local_addr().unwrap();
+
+        tokio::spawn(async move {
+            let service = service_fn(|_req: Request<Incoming>| {
+                std::future::ready(Ok::<_, Infallible>(Response::new(Empty::<Bytes>::new())))
+            });
+
+            let (connection, _) = listener.accept().await.unwrap();
+            std::mem::drop(listener);
+            http1::Builder::new()
+                .serve_connection(TokioIo::new(connection), service)
+                .await
+                .unwrap()
+        });
+
+        let client_store = ClientStore::new_with_timeout(Duration::from_millis(10));
+        let client = client_store.get(addr, Version::HTTP_11).await.unwrap();
+        client_store.push_idle(client);
+
+        time::sleep(Duration::from_millis(100)).await;
+
+        assert!(client_store.clients.lock().unwrap().is_empty());
+    }
+}
diff --git a/mirrord/intproxy/src/proxies/incoming/http/response_mode.rs b/mirrord/intproxy/src/proxies/incoming/http/response_mode.rs
new file mode 100644
index 00000000000..c6f4eb5a583
--- /dev/null
+++ b/mirrord/intproxy/src/proxies/incoming/http/response_mode.rs
@@ -0,0 +1,31 @@
+use mirrord_protocol::tcp::{HTTP_CHUNKED_RESPONSE_VERSION, HTTP_FRAMED_VERSION};
+
+/// Determines how
[`IncomingProxy`](crate::proxies::incoming::IncomingProxy) should send HTTP +/// responses. +#[derive(Debug, Clone, Copy, Default)] +pub enum ResponseMode { + /// Agent supports + /// [`LayerTcpSteal::HttpResponseChunked`](mirrord_protocol::tcp::LayerTcpSteal::HttpResponseChunked) + /// and the previous variants. + Chunked, + /// Agent supports + /// [`LayerTcpSteal::HttpResponseFramed`](mirrord_protocol::tcp::LayerTcpSteal::HttpResponseFramed) + /// and the previous variant. + Framed, + /// Agent supports only + /// [`LayerTcpSteal::HttpResponse`](mirrord_protocol::tcp::LayerTcpSteal::HttpResponse) + #[default] + Basic, +} + +impl From<&semver::Version> for ResponseMode { + fn from(value: &semver::Version) -> Self { + if HTTP_CHUNKED_RESPONSE_VERSION.matches(value) { + Self::Chunked + } else if HTTP_FRAMED_VERSION.matches(value) { + Self::Framed + } else { + Self::Basic + } + } +} diff --git a/mirrord/intproxy/src/proxies/incoming/http/streaming_body.rs b/mirrord/intproxy/src/proxies/incoming/http/streaming_body.rs new file mode 100644 index 00000000000..06a614f1ba4 --- /dev/null +++ b/mirrord/intproxy/src/proxies/incoming/http/streaming_body.rs @@ -0,0 +1,140 @@ +use std::{ + convert::Infallible, + fmt, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use hyper::body::{Body, Frame}; +use mirrord_protocol::tcp::{InternalHttpBody, InternalHttpBodyFrame}; +use tokio::sync::mpsc::{self, Receiver}; + +/// Cheaply cloneable [`Body`] implementation that reads [`Frame`]s from an [`mpsc::channel`]. +/// +/// # Clone behavior +/// +/// All instances acquired via [`Clone`] share the [`mpsc::Receiver`] and a vector of previously +/// read frames. Each instance maintains its own position in the shared vector, and a new clone +/// starts at 0. +/// +/// When polled with [`Body::poll_frame`], an instance tries to return a cached frame. 
+///
+/// Thanks to this, each clone returns all frames from the start when polled with
+/// [`Body::poll_frame`], as you'd expect from a cloneable [`Body`] implementation.
+pub struct StreamingBody {
+    /// Shared with instances acquired via [`Clone`].
+    ///
+    /// Allows the clones to access previously fetched [`Frame`]s.
+    shared_state: Arc<Mutex<(Receiver<InternalHttpBodyFrame>, Vec<InternalHttpBodyFrame>)>>,
+    /// Index of the next frame to return from the buffer, not shared with other instances acquired
+    /// via [`Clone`].
+    ///
+    /// If outside of the buffer, we need to poll the stream to get the next frame.
+    idx: usize,
+}
+
+impl StreamingBody {
+    /// Creates a new instance of this [`Body`].
+    ///
+    /// It will first read all frames from the vector given as `first_frames`.
+    /// Following frames will be fetched from the given `rx`.
+    pub fn new(
+        rx: Receiver<InternalHttpBodyFrame>,
+        first_frames: Vec<InternalHttpBodyFrame>,
+    ) -> Self {
+        Self {
+            shared_state: Arc::new(Mutex::new((rx, first_frames))),
+            idx: 0,
+        }
+    }
+}
+
+impl Clone for StreamingBody {
+    fn clone(&self) -> Self {
+        Self {
+            shared_state: self.shared_state.clone(),
+            // Setting idx to 0 in order to replay the previous frames.
+            idx: 0,
+        }
+    }
+}
+
+impl Body for StreamingBody {
+    type Data = Bytes;
+
+    type Error = Infallible;
+
+    fn poll_frame(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
+        let this = self.get_mut();
+        let mut guard = this.shared_state.lock().unwrap();
+
+        if let Some(frame) = guard.1.get(this.idx) {
+            this.idx += 1;
+            return Poll::Ready(Some(Ok(frame.clone().into())));
+        }
+
+        match std::task::ready!(guard.0.poll_recv(cx)) {
+            None => Poll::Ready(None),
+            Some(frame) => {
+                guard.1.push(frame.clone());
+                this.idx += 1;
+                Poll::Ready(Some(Ok(frame.into())))
+            }
+        }
+    }
+}
+
+impl Default for StreamingBody {
+    fn default() -> Self {
+        let (_, dummy_rx) = mpsc::channel(1); // `mpsc::channel` panics on capacity 0
+        Self {
+            shared_state: Arc::new(Mutex::new((dummy_rx, Default::default()))),
+            idx: 0,
+        }
+    }
+}
+
+impl From<Vec<u8>> for StreamingBody {
+    fn from(value: Vec<u8>) -> Self {
+        let (_, dummy_rx) = mpsc::channel(1); // `mpsc::channel` panics on capacity 0
+        let frames = vec![InternalHttpBodyFrame::Data(value)];
+        Self::new(dummy_rx, frames)
+    }
+}
+
+impl From<InternalHttpBody> for StreamingBody {
+    fn from(value: InternalHttpBody) -> Self {
+        let (_, dummy_rx) = mpsc::channel(1); // `mpsc::channel` panics on capacity 0
+        Self::new(dummy_rx, value.0.into_iter().collect())
+    }
+}
+
+impl From<Receiver<InternalHttpBodyFrame>> for StreamingBody {
+    fn from(value: Receiver<InternalHttpBodyFrame>) -> Self {
+        Self::new(value, Default::default())
+    }
+}
+
+impl fmt::Debug for StreamingBody {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let mut s = f.debug_struct("StreamingBody");
+        s.field("idx", &self.idx);
+
+        match self.shared_state.try_lock() {
+            Ok(guard) => {
+                s.field("frame_rx_closed", &guard.0.is_closed());
+                s.field("cached_frames", &guard.1);
+            }
+            Err(error) => {
+                s.field("lock_error", &error);
+            }
+        }
+
+        s.finish()
+    }
+}
diff --git a/mirrord/intproxy/src/proxies/incoming/http_gateway.rs b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs
new file mode 100644 index
00000000000..16f1b166820
--- /dev/null
+++ b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs
@@ -0,0 +1,943 @@
+use std::{
+    collections::VecDeque,
+    convert::Infallible,
+    error::Error,
+    fmt,
+    net::SocketAddr,
+    ops::ControlFlow,
+    time::{Duration, Instant},
+};
+
+use exponential_backoff::Backoff;
+use http_body_util::BodyExt;
+use hyper::{body::Incoming, http::response::Parts, StatusCode};
+use mirrord_protocol::{
+    batched_body::BatchedBody,
+    tcp::{
+        ChunkedHttpBody, ChunkedHttpError, ChunkedResponse, HttpRequest, HttpResponse,
+        InternalHttpBody, InternalHttpBodyFrame, InternalHttpResponse,
+    },
+};
+use tokio::time;
+use tracing::Level;
+
+use super::{
+    http::{mirrord_error_response, ClientStore, LocalHttpError, ResponseMode, StreamingBody},
+    tasks::{HttpOut, InProxyTaskMessage},
+};
+use crate::background_tasks::{BackgroundTask, MessageBus};
+
+/// [`BackgroundTask`] used by the [`IncomingProxy`](super::IncomingProxy).
+///
+/// Responsible for delivering a single HTTP request to the user application.
+///
+/// Exits immediately when its [`TaskSender`](crate::background_tasks::TaskSender) is dropped.
+pub struct HttpGatewayTask {
+    /// Request to deliver.
+    request: HttpRequest<StreamingBody>,
+    /// Shared cache of [`LocalHttpClient`](super::http::LocalHttpClient)s.
+    client_store: ClientStore,
+    /// Determines response variant.
+    response_mode: ResponseMode,
+    /// Address of the HTTP server in the user application.
+    server_addr: SocketAddr,
+}
+
+impl fmt::Debug for HttpGatewayTask {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("HttpGatewayTask")
+            .field("request", &self.request)
+            .field("response_mode", &self.response_mode)
+            .field("server_addr", &self.server_addr)
+            .finish()
+    }
+}
+
+impl HttpGatewayTask {
+    /// Creates a new gateway task.
+    pub fn new(
+        request: HttpRequest<StreamingBody>,
+        client_store: ClientStore,
+        response_mode: ResponseMode,
+        server_addr: SocketAddr,
+    ) -> Self {
+        Self {
+            request,
+            client_store,
+            response_mode,
+            server_addr,
+        }
+    }
+
+    /// Handles the response if we operate in [`ResponseMode::Chunked`].
+    ///
+    /// # Returns
+    ///
+    /// * An error if we failed before sending the [`ChunkedResponse::Start`] message through the
+    ///   [`MessageBus`] (we can still retry the request)
+    /// * [`ControlFlow::Break`] if we failed after sending the [`ChunkedResponse::Start`] message
+    /// * [`ControlFlow::Continue`] if we succeeded
+    async fn handle_response_chunked(
+        &self,
+        parts: Parts,
+        mut body: Incoming,
+        message_bus: &mut MessageBus<Self>,
+    ) -> Result<ControlFlow<()>, LocalHttpError> {
+        let frames = body
+            .ready_frames()
+            .map_err(LocalHttpError::ReadBodyFailed)?;
+
+        if frames.is_last {
+            let ready_frames = frames
+                .frames
+                .into_iter()
+                .map(InternalHttpBodyFrame::from)
+                .collect::<VecDeque<_>>();
+
+            tracing::trace!(
+                ?ready_frames,
+                "All response body frames were instantly ready, sending full response"
+            );
+            let response = HttpResponse {
+                port: self.request.port,
+                connection_id: self.request.connection_id,
+                request_id: self.request.request_id,
+                internal_response: InternalHttpResponse {
+                    status: parts.status,
+                    version: parts.version,
+                    headers: parts.headers,
+                    body: InternalHttpBody(ready_frames),
+                },
+            };
+            message_bus.send(HttpOut::ResponseFramed(response)).await;
+
+            return Ok(ControlFlow::Continue(()));
+        }
+
+        let ready_frames = frames
+            .frames
+            .into_iter()
+            .map(InternalHttpBodyFrame::from)
+            .collect::<Vec<_>>();
+        tracing::trace!(
+            ?ready_frames,
+            "Some response body frames were instantly ready, \
+            but response body may not be finished yet"
+        );
+
+        let response = HttpResponse {
+            port: self.request.port,
+            connection_id: self.request.connection_id,
+            request_id: self.request.request_id,
+            internal_response: InternalHttpResponse {
+                status: parts.status,
+                version: parts.version,
+                headers: parts.headers,
+                body: ready_frames,
+            },
+        };
+        message_bus
+            .send(HttpOut::ResponseChunked(ChunkedResponse::Start(response)))
+            .await;
+
+        loop {
+            let start = Instant::now();
+            match body.next_frames().await {
+                Ok(frames) => {
+                    let is_last = frames.is_last;
+                    let frames = frames
+                        .frames
+                        .into_iter()
+                        .map(InternalHttpBodyFrame::from)
+                        .collect::<Vec<_>>();
+                    tracing::trace!(
+                        ?frames,
+                        is_last,
+                        elapsed_ms = start.elapsed().as_millis(),
+                        "Received a next batch of response body frames",
+                    );
+
+                    message_bus
+                        .send(HttpOut::ResponseChunked(ChunkedResponse::Body(
+                            ChunkedHttpBody {
+                                frames,
+                                is_last,
+                                connection_id: self.request.connection_id,
+                                request_id: self.request.request_id,
+                            },
+                        )))
+                        .await;
+
+                    if is_last {
+                        break;
+                    }
+                }
+
+                // Do not return any error here, as it would later be transformed into an error
+                // response. We already sent the request head to the agent.
+                Err(error) => {
+                    tracing::warn!(
+                        error = ?ErrorWithSources(&error),
+                        elapsed_ms = start.elapsed().as_millis(),
+                        gateway = ?self,
+                        "Failed to read next response body frames",
+                    );
+
+                    message_bus
+                        .send(HttpOut::ResponseChunked(ChunkedResponse::Error(
+                            ChunkedHttpError {
+                                connection_id: self.request.connection_id,
+                                request_id: self.request.request_id,
+                            },
+                        )))
+                        .await;
+
+                    return Ok(ControlFlow::Break(()));
+                }
+            }
+        }
+
+        Ok(ControlFlow::Continue(()))
+    }
+
+    /// Makes an attempt to send the request and read the whole response.
+    ///
+    /// [`Err`] is handled in the caller and, if we run out of send attempts, converted to an error
+    /// response. Because of this, this function should not return any error that happened after
+    /// sending [`ChunkedResponse::Start`]. The agent would get a duplicated response.
+    #[tracing::instrument(level = Level::TRACE, skip_all, err(level = Level::WARN))]
+    async fn send_attempt(&self, message_bus: &mut MessageBus<Self>) -> Result<(), LocalHttpError> {
+        let mut client = self
+            .client_store
+            .get(self.server_addr, self.request.version())
+            .await?;
+        let mut response = client.send_request(self.request.clone()).await?;
+        let on_upgrade = (response.status() == StatusCode::SWITCHING_PROTOCOLS).then(|| {
+            tracing::trace!("Detected an HTTP upgrade");
+            hyper::upgrade::on(&mut response)
+        });
+        let (parts, body) = response.into_parts();
+
+        let flow = match self.response_mode {
+            ResponseMode::Basic => {
+                let start = Instant::now();
+                let body: Vec<u8> = body
+                    .collect()
+                    .await
+                    .map_err(LocalHttpError::ReadBodyFailed)?
+                    .to_bytes()
+                    .into();
+                tracing::trace!(
+                    body_len = body.len(),
+                    elapsed_ms = start.elapsed().as_millis(),
+                    "Collected the whole response body",
+                );
+
+                let response = HttpResponse {
+                    port: self.request.port,
+                    connection_id: self.request.connection_id,
+                    request_id: self.request.request_id,
+                    internal_response: InternalHttpResponse {
+                        status: parts.status,
+                        version: parts.version,
+                        headers: parts.headers,
+                        body,
+                    },
+                };
+                message_bus.send(HttpOut::ResponseBasic(response)).await;
+
+                ControlFlow::Continue(())
+            }
+            ResponseMode::Framed => {
+                let start = Instant::now();
+                let body = InternalHttpBody::from_body(body)
+                    .await
+                    .map_err(LocalHttpError::ReadBodyFailed)?;
+                tracing::trace!(
+                    ?body,
+                    elapsed_ms = start.elapsed().as_millis(),
+                    "Collected the whole response body",
+                );
+
+                let response = HttpResponse {
+                    port: self.request.port,
+                    connection_id: self.request.connection_id,
+                    request_id: self.request.request_id,
+                    internal_response: InternalHttpResponse {
+                        status: parts.status,
+                        version: parts.version,
+                        headers: parts.headers,
+                        body,
+                    },
+                };
+                message_bus.send(HttpOut::ResponseFramed(response)).await;
+
+                ControlFlow::Continue(())
+            }
+            ResponseMode::Chunked => {
+                self.handle_response_chunked(parts, body, message_bus)
+                    .await?
+            }
+        };
+
+        if flow.is_break() {
+            return Ok(());
+        }
+
+        if let Some(on_upgrade) = on_upgrade {
+            message_bus.send(HttpOut::Upgraded(on_upgrade)).await;
+        } else {
+            // If there was no upgrade and no error, the client can be reused.
+            self.client_store.push_idle(client);
+        }
+
+        Ok(())
+    }
+}
+
+impl BackgroundTask for HttpGatewayTask {
+    type Error = Infallible;
+    type MessageIn = Infallible;
+    type MessageOut = InProxyTaskMessage;
+
+    #[tracing::instrument(level = Level::TRACE, name = "http_gateway_task_main_loop", skip(message_bus))]
+    async fn run(&mut self, message_bus: &mut MessageBus<Self>) -> Result<(), Self::Error> {
+        let mut backoffs =
+            Backoff::new(10, Duration::from_millis(50), Duration::from_millis(500)).into_iter();
+        let guard = message_bus.closed();
+
+        let mut attempt = 0;
+        let error = loop {
+            attempt += 1;
+            tracing::trace!(attempt, "Starting send attempt");
+            match guard.cancel_on_close(self.send_attempt(message_bus)).await {
+                None | Some(Ok(())) => return Ok(()),
+                Some(Err(error)) => {
+                    let backoff = error
+                        .can_retry()
+                        .then(|| backoffs.next())
+                        .flatten()
+                        .flatten();
+                    let Some(backoff) = backoff else {
+                        tracing::warn!(
+                            gateway = ?self,
+                            failed_attempts = attempt,
+                            error = ?ErrorWithSources(&error),
+                            "Failed to send an HTTP request",
+                        );
+
+                        break error;
+                    };
+
+                    tracing::trace!(
+                        backoff_ms = backoff.as_millis(),
+                        failed_attempts = attempt,
+                        error = ?ErrorWithSources(&error),
+                        "Trying again after backoff",
+                    );
+
+                    if guard.cancel_on_close(time::sleep(backoff)).await.is_none() {
+                        return Ok(());
+                    }
+                }
+            }
+        };
+
+        let response = mirrord_error_response(
+            error,
+            self.request.version(),
+            self.request.connection_id,
+            self.request.request_id,
+            self.request.port,
+        );
+        message_bus.send(HttpOut::ResponseBasic(response)).await;
+
+        Ok(())
+    }
+}
+
+/// Helper struct for tracing an [`Error`] along with all its sources,
+/// down to the root cause.
+///
+/// Might help when inspecting [`hyper`] errors.
+struct ErrorWithSources<'a>(&'a dyn Error);
+
+impl fmt::Debug for ErrorWithSources<'_> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let mut list = f.debug_list();
+        list.entry(&self.0);
+
+        let mut source = self.0.source();
+        while let Some(error) = source {
+            list.entry(&error);
+            source = error.source();
+        }
+
+        list.finish()
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::{io, sync::Arc};
+
+    use bytes::Bytes;
+    use http_body_util::{Empty, StreamBody};
+    use hyper::{
+        body::{Frame, Incoming},
+        header::{self, HeaderValue, CONNECTION, UPGRADE},
+        server::conn::http1,
+        service::service_fn,
+        upgrade::Upgraded,
+        Method, Request, Response, StatusCode, Version,
+    };
+    use hyper_util::rt::TokioIo;
+    use mirrord_protocol::tcp::{HttpRequest, InternalHttpRequest};
+    use rstest::rstest;
+    use tokio::{
+        io::{AsyncReadExt, AsyncWriteExt},
+        net::TcpListener,
+        sync::{mpsc, watch, Semaphore},
+        task,
+    };
+    use tokio_stream::wrappers::ReceiverStream;
+
+    use super::*;
+    use crate::{
+        background_tasks::{BackgroundTasks, TaskUpdate},
+        proxies::incoming::{
+            tcp_proxy::{LocalTcpConnection, TcpProxyTask},
+            InProxyTaskError,
+        },
+    };
+
+    /// Binary protocol over TCP.
+    /// Server first sends bytes [`INITIAL_MESSAGE`], then echoes back all received data.
+    const TEST_PROTO: &str = "dummyecho";
+
+    const INITIAL_MESSAGE: &[u8] = &[0x4a, 0x50, 0x32, 0x47, 0x4d, 0x44];
+
+    /// Handles requests upgrading to the [`TEST_PROTO`] protocol.
+    async fn upgrade_req_handler(
+        mut req: Request<Incoming>,
+    ) -> hyper::Result<Response<Empty<Bytes>>> {
+        async fn dummy_echo(upgraded: Upgraded) -> io::Result<()> {
+            let mut upgraded = TokioIo::new(upgraded);
+            let mut buf = [0_u8; 64];
+
+            upgraded.write_all(INITIAL_MESSAGE).await?;
+
+            loop {
+                let bytes_read = upgraded.read(&mut buf[..]).await?;
+                if bytes_read == 0 {
+                    break;
+                }
+
+                let echo_back = buf.get(0..bytes_read).unwrap();
+                upgraded.write_all(echo_back).await?;
+            }
+
+            Ok(())
+        }
+
+        let mut res = Response::new(Empty::new());
+
+        let contains_expected_upgrade = req
+            .headers()
+            .get(UPGRADE)
+            .filter(|proto| *proto == TEST_PROTO)
+            .is_some();
+        if !contains_expected_upgrade {
+            *res.status_mut() = StatusCode::BAD_REQUEST;
+            return Ok(res);
+        }
+
+        task::spawn(async move {
+            match hyper::upgrade::on(&mut req).await {
+                Ok(upgraded) => {
+                    if let Err(e) = dummy_echo(upgraded).await {
+                        eprintln!("server foobar io error: {}", e)
+                    };
+                }
+                Err(e) => eprintln!("upgrade error: {}", e),
+            }
+        });
+
+        *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
+        res.headers_mut()
+            .insert(UPGRADE, HeaderValue::from_static(TEST_PROTO));
+        res.headers_mut()
+            .insert(CONNECTION, HeaderValue::from_static("upgrade"));
+        Ok(res)
+    }
+
+    /// Runs a [`hyper`] server that accepts only requests upgrading to the [`TEST_PROTO`] protocol.
+    async fn dummy_echo_server(listener: TcpListener, mut shutdown: watch::Receiver<bool>) {
+        loop {
+            tokio::select! {
+                res = listener.accept() => {
+                    let (stream, _) = res.expect("dummy echo server failed to accept connection");
+
+                    let mut shutdown = shutdown.clone();
+
+                    task::spawn(async move {
+                        let conn = http1::Builder::new().serve_connection(TokioIo::new(stream), service_fn(upgrade_req_handler));
+                        let mut conn = conn.with_upgrades();
+                        let mut conn = Pin::new(&mut conn);
+
+                        tokio::select!
+                        {
+                            res = &mut conn => {
+                                res.expect("dummy echo server failed to serve connection");
+                            }
+
+                            _ = shutdown.changed() => {
+                                conn.graceful_shutdown();
+                            }
+                        }
+                    });
+                }
+
+                _ = shutdown.changed() => break,
+            }
+        }
+    }
+
+    /// Verifies that [`HttpGatewayTask`] and [`TcpProxyTask`] together correctly handle HTTP
+    /// upgrades.
+    #[tokio::test]
+    async fn handles_http_upgrades() {
+        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let local_destination = listener.local_addr().unwrap();
+
+        let (shutdown_tx, shutdown_rx) = watch::channel(false);
+        let server_task = task::spawn(dummy_echo_server(listener, shutdown_rx));
+
+        let mut tasks: BackgroundTasks<u64, InProxyTaskMessage, InProxyTaskError> =
+            Default::default();
+        let _gateway = {
+            let request = HttpRequest {
+                connection_id: 0,
+                request_id: 0,
+                port: 80,
+                internal_request: InternalHttpRequest {
+                    method: Method::GET,
+                    uri: "dummyecho://www.mirrord.dev/".parse().unwrap(),
+                    headers: [
+                        (CONNECTION, HeaderValue::from_static("upgrade")),
+                        (UPGRADE, HeaderValue::from_static(TEST_PROTO)),
+                    ]
+                    .into_iter()
+                    .collect(),
+                    version: Version::HTTP_11,
+                    body: Default::default(),
+                },
+            };
+            let gateway = HttpGatewayTask::new(
+                request,
+                Default::default(),
+                ResponseMode::Basic,
+                local_destination,
+            );
+            tasks.register(gateway, 0, 8)
+        };
+
+        let message = tasks
+            .next()
+            .await
+            .expect("no task result")
+            .1
+            .unwrap_message();
+        match message {
+            InProxyTaskMessage::Http(HttpOut::ResponseBasic(res)) => {
+                assert_eq!(
+                    res.internal_response.status,
+                    StatusCode::SWITCHING_PROTOCOLS
+                );
+                println!("Received response from the gateway: {res:?}");
+                assert!(res
+                    .internal_response
+                    .headers
+                    .get(CONNECTION)
+                    .filter(|v| *v == "upgrade")
+                    .is_some());
+                assert!(res
+                    .internal_response
+                    .headers
+                    .get(UPGRADE)
+                    .filter(|v| *v == TEST_PROTO)
+                    .is_some());
+            }
+            other => panic!("unexpected task update: {other:?}"),
+        }
+
+        let message = tasks
+            .next()
+            .await
+            .expect("no task result")
+            .1
+            .unwrap_message();
let on_upgrade = match message { + InProxyTaskMessage::Http(HttpOut::Upgraded(on_upgrade)) => on_upgrade, + other => panic!("unexpected task update: {other:?}"), + }; + let update = tasks.next().await.expect("no task result").1; + match update { + TaskUpdate::Finished(Ok(())) => {} + other => panic!("unexpected task update: {other:?}"), + } + + let proxy = tasks.register( + TcpProxyTask::new(LocalTcpConnection::AfterUpgrade(on_upgrade), false), + 1, + 8, + ); + + proxy.send(b"test test test".to_vec()).await; + + let message = tasks + .next() + .await + .expect("no task result") + .1 + .unwrap_message(); + match message { + InProxyTaskMessage::Tcp(bytes) => { + assert_eq!(bytes, INITIAL_MESSAGE); + } + _ => panic!("unexpected task update: {update:?}"), + } + + let message = tasks + .next() + .await + .expect("no task result") + .1 + .unwrap_message(); + match message { + InProxyTaskMessage::Tcp(bytes) => { + assert_eq!(bytes, b"test test test"); + } + _ => panic!("unexpected task update: {update:?}"), + } + + let _ = shutdown_tx.send(true); + server_task.await.expect("dummy echo server panicked"); + } + + /// Verifies that [`HttpGatewayTask`] produces correct variant of the [`HttpResponse`]. + /// + /// Verifies that body of + /// [`LayerTcpSteal::HttpResponseChunked`](mirrord_protocol::tcp::LayerTcpSteal::HttpResponseChunked) + /// is streamed. 
+ #[rstest] + #[case::basic(ResponseMode::Basic)] + #[case::framed(ResponseMode::Framed)] + #[case::chunked(ResponseMode::Chunked)] + #[tokio::test] + async fn produces_correct_response_variant(#[case] response_mode: ResponseMode) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let semaphore: Arc = Arc::new(Semaphore::const_new(0)); + let semaphore_clone = semaphore.clone(); + + let conn_task = tokio::spawn(async move { + let service = service_fn(|_req: Request| { + let semaphore = semaphore_clone.clone(); + async move { + let (frame_tx, frame_rx) = mpsc::channel::>>(1); + + tokio::spawn(async move { + for _ in 0..2 { + semaphore.acquire().await.unwrap().forget(); + let _ = frame_tx + .send(Ok(Frame::data(Bytes::from_static(b"hello\n")))) + .await; + } + }); + + let body = StreamBody::new(ReceiverStream::new(frame_rx)); + let mut response = Response::new(body); + response + .headers_mut() + .insert(header::CONTENT_LENGTH, HeaderValue::from_static("12")); + + Ok::<_, Infallible>(response) + } + }); + + let (connection, _) = listener.accept().await.unwrap(); + http1::Builder::new() + .serve_connection(TokioIo::new(connection), service) + .await + .unwrap() + }); + + let request = HttpRequest { + connection_id: 0, + request_id: 0, + port: 80, + internal_request: InternalHttpRequest { + method: Method::GET, + uri: "/".parse().unwrap(), + headers: Default::default(), + version: Version::HTTP_11, + body: StreamingBody::from(vec![]), + }, + }; + + let mut tasks: BackgroundTasks<(), InProxyTaskMessage, Infallible> = Default::default(); + let _gateway = tasks.register( + HttpGatewayTask::new(request, ClientStore::default(), response_mode, addr), + (), + 8, + ); + + match response_mode { + ResponseMode::Basic => { + semaphore.add_permits(2); + match tasks.next().await.unwrap().1.unwrap_message() { + InProxyTaskMessage::Http(HttpOut::ResponseBasic(response)) => { + assert_eq!(response.internal_response.body, 
b"hello\nhello\n"); + } + other => panic!("unexpected task message: {other:?}"), + } + } + + ResponseMode::Framed => { + semaphore.add_permits(2); + match tasks.next().await.unwrap().1.unwrap_message() { + InProxyTaskMessage::Http(HttpOut::ResponseFramed(response)) => { + let mut collected = vec![]; + for frame in response.internal_response.body.0 { + match frame { + InternalHttpBodyFrame::Data(data) => collected.extend(data), + InternalHttpBodyFrame::Trailers(trailers) => { + panic!("unexpected trailing headers: {trailers:?}"); + } + } + } + + assert_eq!(collected, b"hello\nhello\n"); + } + other => panic!("unexpected task message: {other:?}"), + } + } + + ResponseMode::Chunked => { + match tasks.next().await.unwrap().1.unwrap_message() { + InProxyTaskMessage::Http(HttpOut::ResponseChunked(ChunkedResponse::Start( + response, + ))) => { + assert!(response.internal_response.body.is_empty()); + } + other => panic!("unexpected task message: {other:?}"), + } + + semaphore.add_permits(1); + match tasks.next().await.unwrap().1.unwrap_message() { + InProxyTaskMessage::Http(HttpOut::ResponseChunked(ChunkedResponse::Body( + body, + ))) => { + assert_eq!( + body.frames, + vec![InternalHttpBodyFrame::Data(b"hello\n".into())], + ); + assert!(!body.is_last); + } + other => panic!("unexpected task message: {other:?}"), + } + + semaphore.add_permits(1); + match tasks.next().await.unwrap().1.unwrap_message() { + InProxyTaskMessage::Http(HttpOut::ResponseChunked(ChunkedResponse::Body( + body, + ))) => { + assert_eq!( + body.frames, + vec![InternalHttpBodyFrame::Data(b"hello\n".into())], + ); + assert!(body.is_last); + } + other => panic!("unexpected task message: {other:?}"), + } + } + } + + match tasks.next().await.unwrap().1 { + TaskUpdate::Finished(Ok(())) => {} + other => panic!("unexpected task update: {other:?}"), + } + + conn_task.await.unwrap(); + } + + /// Verifies that [`HttpGateway`] sends request body frames to the server as soon as they are + /// available. 
+ #[tokio::test] + async fn streams_request_body_frames() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let semaphore: Arc = Arc::new(Semaphore::const_new(0)); + let semaphore_clone = semaphore.clone(); + + let conn_task = tokio::spawn(async move { + let service = service_fn(|mut req: Request| { + let semaphore = semaphore_clone.clone(); + async move { + for _ in 0..2 { + semaphore.add_permits(1); + let frame = req + .body_mut() + .frame() + .await + .unwrap() + .unwrap() + .into_data() + .unwrap(); + assert_eq!(frame, "hello\n"); + } + + Ok::<_, Infallible>(Response::new(Empty::::new())) + } + }); + + let (connection, _) = listener.accept().await.unwrap(); + http1::Builder::new() + .serve_connection(TokioIo::new(connection), service) + .await + .unwrap() + }); + + let (frame_tx, frame_rx) = mpsc::channel(1); + let body = StreamingBody::new(frame_rx, vec![]); + let mut request = HttpRequest { + connection_id: 0, + request_id: 0, + port: 80, + internal_request: InternalHttpRequest { + method: Method::GET, + uri: "/".parse().unwrap(), + headers: Default::default(), + version: Version::HTTP_11, + body, + }, + }; + request + .internal_request + .headers + .insert(header::CONTENT_LENGTH, HeaderValue::from_static("12")); + + let mut tasks: BackgroundTasks<(), InProxyTaskMessage, Infallible> = Default::default(); + let client_store = ClientStore::default(); + let _gateway = tasks.register( + HttpGatewayTask::new(request, client_store.clone(), ResponseMode::Basic, addr), + (), + 8, + ); + + for _ in 0..2 { + semaphore.acquire().await.unwrap().forget(); + frame_tx + .send(InternalHttpBodyFrame::Data(b"hello\n".into())) + .await + .unwrap(); + } + std::mem::drop(frame_tx); + + match tasks.next().await.unwrap().1.unwrap_message() { + InProxyTaskMessage::Http(HttpOut::ResponseBasic(response)) => { + assert_eq!(response.internal_response.status, StatusCode::OK); + } + other => panic!("unexpected message: 
{other:?}"), + } + + match tasks.next().await.unwrap().1 { + TaskUpdate::Finished(Ok(())) => {} + other => panic!("unexpected task update: {other:?}"), + } + + conn_task.await.unwrap(); + } + + /// Verifies that [`HttpGateway`] reuses already established HTTP connections. + #[tokio::test] + async fn reuses_client_connections() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let service = service_fn(|_req: Request| { + std::future::ready(Ok::<_, Infallible>(Response::new(Empty::::new()))) + }); + + let (connection, _) = listener.accept().await.unwrap(); + std::mem::drop(listener); + http1::Builder::new() + .serve_connection(TokioIo::new(connection), service) + .await + .unwrap() + }); + + let mut request = HttpRequest { + connection_id: 0, + request_id: 0, + port: 80, + internal_request: InternalHttpRequest { + method: Method::GET, + uri: "/".parse().unwrap(), + headers: Default::default(), + version: Version::HTTP_11, + body: Default::default(), + }, + }; + request + .internal_request + .headers + .insert(header::CONNECTION, HeaderValue::from_static("keep-alive")); + + let mut tasks: BackgroundTasks = Default::default(); + let client_store = ClientStore::new_with_timeout(Duration::from_secs(1337 * 21 * 37)); + let _gateway_1 = tasks.register( + HttpGatewayTask::new( + request.clone(), + client_store.clone(), + ResponseMode::Basic, + addr, + ), + 0, + 8, + ); + let _gateway_2 = tasks.register( + HttpGatewayTask::new( + request.clone(), + client_store.clone(), + ResponseMode::Basic, + addr, + ), + 1, + 8, + ); + + let mut finished = 0; + let mut responses = 0; + + while finished < 2 && responses < 2 { + match tasks.next().await.unwrap() { + (id, TaskUpdate::Finished(Ok(()))) => { + println!("gateway {id} finished"); + finished += 1; + } + ( + id, + TaskUpdate::Message(InProxyTaskMessage::Http(HttpOut::ResponseBasic(response))), + ) => { + println!("gateway {id} returned a 
response"); + assert_eq!(response.internal_response.status, StatusCode::OK); + responses += 1; + } + other => panic!("unexpected task update: {other:?}"), + } + } + } +} diff --git a/mirrord/intproxy/src/proxies/incoming/interceptor.rs b/mirrord/intproxy/src/proxies/incoming/interceptor.rs deleted file mode 100644 index 31acfe82f21..00000000000 --- a/mirrord/intproxy/src/proxies/incoming/interceptor.rs +++ /dev/null @@ -1,916 +0,0 @@ -//! [`BackgroundTask`] used by [`Incoming`](super::IncomingProxy) to manage a single -//! intercepted connection. - -use std::{ - error::Error, - io::{self, ErrorKind}, - net::SocketAddr, - time::Duration, -}; - -use bytes::BytesMut; -use exponential_backoff::Backoff; -use hyper::{upgrade::OnUpgrade, StatusCode, Version}; -use hyper_util::rt::TokioIo; -use mirrord_protocol::tcp::{ - HttpRequestFallback, HttpResponse, HttpResponseFallback, InternalHttpBody, ReceiverStreamBody, - HTTP_CHUNKED_RESPONSE_VERSION, -}; -use thiserror::Error; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::{TcpSocket, TcpStream}, - time::{self, sleep}, -}; -use tracing::Level; - -use super::http::HttpSender; -use crate::{ - background_tasks::{BackgroundTask, MessageBus}, - proxies::incoming::http::RETRY_ON_RESET_ATTEMPTS, -}; - -/// Messages consumed by the [`Interceptor`] when it runs as a [`BackgroundTask`]. -pub enum MessageIn { - /// Request to be sent to the user application. - Http(HttpRequestFallback), - /// Data to be sent to the user application. - Raw(Vec), -} - -/// Messages produced by the [`Interceptor`] when it runs as a [`BackgroundTask`]. -#[derive(Debug)] -pub enum MessageOut { - /// Response received from the user application. - Http(HttpResponseFallback), - /// Data received from the user application. 
- Raw(Vec), -} - -impl From for MessageIn { - fn from(value: HttpRequestFallback) -> Self { - Self::Http(value) - } -} - -impl From> for MessageIn { - fn from(value: Vec) -> Self { - Self::Raw(value) - } -} - -/// Errors that can occur when [`Interceptor`] runs as a [`BackgroundTask`]. -#[derive(Error, Debug)] -pub enum InterceptorError { - /// IO failed. - #[error("io failed: {0}")] - Io(#[from] io::Error), - /// Hyper failed. - #[error("hyper failed: {0}")] - Hyper(hyper::Error), - /// The layer closed connection too soon to send a request. - #[error("connection closed too soon")] - ConnectionClosedTooSoon(HttpRequestFallback), - - #[error("incomplete message")] - IncompleteMessage(HttpRequestFallback), - - /// Received a request with an unsupported HTTP version. - #[error("{0:?} is not supported")] - UnsupportedHttpVersion(Version), - /// Occurs when [`Interceptor`] receives [`MessageIn::Raw`], but it acts as an HTTP gateway and - /// there was no HTTP upgrade. - #[error("received raw bytes, but expected an HTTP request")] - UnexpectedRawData, - /// Occurs when [`Interceptor`] receives [`MessageIn::Http`], but it acts as a TCP proxy. - #[error("received an HTTP request, but expected raw bytes")] - UnexpectedHttpRequest, - - /// We dig into the [`hyper::Error`] to try and see if it's an [`h2::Error`], checking - /// for [`h2::Error::is_reset`]. - /// - /// [`hyper::Error`] mentions that `source` is not a guaranteed thing we can check for, - /// so if you see any weird behavior, check that the [`h2`] crate is in sync with - /// whatever hyper changed (for errors). - #[error("HTTP2 `RST_STREAM` received")] - Reset, - - /// We have reached the max number of attempts that we can retry our http connection, - /// due to a `RST_STREAM`, or when the connection has been closed too soon. 
- #[error("HTTP2 reached the maximum amount of retries!")] - MaxRetries, -} - -impl From for InterceptorError { - fn from(hyper_fail: hyper::Error) -> Self { - if hyper_fail - .source() - .and_then(|source| source.downcast_ref::()) - .is_some_and(h2::Error::is_reset) - { - Self::Reset - } else { - Self::Hyper(hyper_fail) - } - } -} - -pub type InterceptorResult = core::result::Result; - -/// Manages a single intercepted connection. -/// Multiple instances are run as [`BackgroundTask`]s by one [`IncomingProxy`](super::IncomingProxy) -/// to manage individual connections. -/// -/// This interceptor can proxy both raw TCP data and HTTP messages in the same TCP connection. -/// When it receives [`MessageIn::Raw`], it starts acting as a simple proxy. -/// When it received [`MessageIn::Http`], it starts acting as an HTTP gateway. -pub struct Interceptor { - /// Socket that should be used to make the first connection (should already be bound). - socket: Option, - /// Address of user app's listener. - peer: SocketAddr, - /// Version of [`mirrord_protocol`] negotiated with the agent. - agent_protocol_version: Option, -} - -impl Interceptor { - /// Creates a new instance. When run, this instance will use the given `socket` (must be already - /// bound) to communicate with the given `peer`. - /// - /// # Note - /// - /// The socket can be replaced when retrying HTTP requests. 
- pub fn new( - socket: TcpSocket, - peer: SocketAddr, - agent_protocol_version: Option, - ) -> Self { - Self { - socket: Some(socket), - peer, - agent_protocol_version, - } - } -} - -impl BackgroundTask for Interceptor { - type Error = InterceptorError; - type MessageIn = MessageIn; - type MessageOut = MessageOut; - - #[tracing::instrument(level = Level::TRACE, skip_all, err)] - async fn run( - &mut self, - message_bus: &mut MessageBus, - ) -> InterceptorResult<(), Self::Error> { - let Some(socket) = self.socket.take() else { - return Ok(()); - }; - - let mut stream = socket.connect(self.peer).await?; - - // First, we determine whether this is a raw TCP connection or an HTTP connection. - // If we receive an HTTP request from our parent task, this must be an HTTP connection. - // If we receive raw bytes or our peer starts sending some data, this must be raw TCP. - let request = tokio::select! { - message = message_bus.recv() => match message { - Some(MessageIn::Raw(data)) => { - if data.is_empty() { - tracing::trace!("incoming interceptor -> agent shutdown, shutting down connection with layer"); - stream.shutdown().await?; - } else { - stream.write_all(&data).await?; - } - - return RawConnection { stream }.run(message_bus).await; - } - Some(MessageIn::Http(request)) => request, - None => return Ok(()), - }, - - result = stream.readable() => { - result?; - return RawConnection { stream }.run(message_bus).await; - } - }; - - let sender = super::http::handshake(request.version(), stream).await?; - let mut http_conn = HttpConnection { - sender, - peer: self.peer, - agent_protocol_version: self.agent_protocol_version.clone(), - }; - let (response, on_upgrade) = http_conn.send(request).await.inspect_err(|fail| { - tracing::error!(?fail, "Failed getting a filtered http response!") - })?; - message_bus.send(MessageOut::Http(response)).await; - - let raw = if let Some(on_upgrade) = on_upgrade { - let upgraded = on_upgrade.await?; - let parts = upgraded - .downcast::>() - 
.expect("IO type is known"); - if !parts.read_buf.is_empty() { - message_bus - .send(MessageOut::Raw(parts.read_buf.into())) - .await; - } - - Some(RawConnection { - stream: parts.io.into_inner(), - }) - } else { - http_conn.run(message_bus).await? - }; - - if let Some(raw) = raw { - raw.run(message_bus).await - } else { - Ok(()) - } - } -} - -/// Utilized by the [`Interceptor`] when it acts as an HTTP gateway. -/// See [`HttpConnection::run`] for usage. -struct HttpConnection { - /// Server address saved to allow for reconnecting in case a retry is required. - peer: SocketAddr, - /// Handle to the HTTP connection between the [`Interceptor`] the server. - sender: HttpSender, - /// Version of [`mirrord_protocol`] negotiated with the agent. - /// Determines which variant of [`LayerTcpSteal`](mirrord_protocol::tcp::LayerTcpSteal) - /// we use when sending HTTP responses. - agent_protocol_version: Option, -} - -impl HttpConnection { - /// Returns whether the agent supports - /// [`LayerTcpSteal::HttpResponseChunked`](mirrord_protocol::tcp::LayerTcpSteal::HttpResponseChunked). - pub fn agent_supports_streaming_response(&self) -> bool { - self.agent_protocol_version - .as_ref() - .map(|version| HTTP_CHUNKED_RESPONSE_VERSION.matches(version)) - .unwrap_or(false) - } - - /// Handles the result of sending an HTTP request. - /// Returns an [`HttpResponseFallback`] to be returned to the client or an [`InterceptorError`]. - /// - /// See [`HttpResponseFallback::response_from_request`] for notes on picking the correct - /// [`HttpResponseFallback`] variant. - #[tracing::instrument(level = Level::TRACE, skip(self, response), err(level = Level::WARN))] - async fn handle_response( - &self, - request: HttpRequestFallback, - response: InterceptorResult>, - ) -> InterceptorResult<(HttpResponseFallback, Option)> { - match response { - Err(InterceptorError::Hyper(e)) if e.is_closed() => { - tracing::warn!( - "Sending request to local application failed with: {e:?}. 
\ - Seems like the local application closed the connection too early, so \ - creating a new connection and trying again." - ); - tracing::trace!("The request to be retried: {request:?}."); - - Err(InterceptorError::ConnectionClosedTooSoon(request)) - } - Err(InterceptorError::Hyper(e)) if e.is_parse() => { - tracing::warn!( - "Could not parse HTTP response to filtered HTTP request, got error: {e:?}." - ); - let body_message = format!( - "mirrord: could not parse HTTP response from local application - {e:?}" - ); - Ok(( - HttpResponseFallback::response_from_request( - request, - StatusCode::BAD_GATEWAY, - &body_message, - self.agent_protocol_version.as_ref(), - ), - None, - )) - } - Err(InterceptorError::Hyper(e)) if e.is_incomplete_message() => { - tracing::warn!( - "Sending request to local application failed with: {e:?}. \ - Connection closed before the message could complete!" - ); - tracing::trace!( - ?request, - "Retrying the request, see \ - [https://github.com/hyperium/hyper/issues/2136] for more info." - ); - - Err(InterceptorError::IncompleteMessage(request)) - } - - Err(fail) => { - tracing::warn!(?fail, "Request to local application failed!"); - let body_message = format!( - "mirrord tried to forward the request to the local application and got {fail:?}" - ); - Ok(( - HttpResponseFallback::response_from_request( - request, - StatusCode::BAD_GATEWAY, - &body_message, - self.agent_protocol_version.as_ref(), - ), - None, - )) - } - - Ok(mut res) => { - let upgrade = if res.status() == StatusCode::SWITCHING_PROTOCOLS { - Some(hyper::upgrade::on(&mut res)) - } else { - None - }; - - let result = match &request { - HttpRequestFallback::Framed(..) => { - HttpResponse::::from_hyper_response( - res, - self.peer.port(), - request.connection_id(), - request.request_id(), - ) - .await - .map(HttpResponseFallback::Framed) - } - HttpRequestFallback::Fallback(..) 
=> { - HttpResponse::>::from_hyper_response( - res, - self.peer.port(), - request.connection_id(), - request.request_id(), - ) - .await - .map(HttpResponseFallback::Fallback) - } - HttpRequestFallback::Streamed { .. } - if self.agent_supports_streaming_response() => - { - HttpResponse::::from_hyper_response( - res, - self.peer.port(), - request.connection_id(), - request.request_id(), - ) - .await - .map(|response| { - HttpResponseFallback::Streamed(response, Some(request.clone())) - }) - } - HttpRequestFallback::Streamed { .. } => { - HttpResponse::::from_hyper_response( - res, - self.peer.port(), - request.connection_id(), - request.request_id(), - ) - .await - .map(HttpResponseFallback::Framed) - } - }; - - Ok(result.map(|response| (response, upgrade))?) - } - } - } - - /// Sends the given [`HttpRequestFallback`] to the server. - /// - /// If we get a `RST_STREAM` error from the server, or the connection was closed too - /// soon starts a new connection and retries using a [`Backoff`] until we reach - /// [`RETRY_ON_RESET_ATTEMPTS`]. - /// - /// Returns [`HttpResponseFallback`] from the server. - #[tracing::instrument(level = Level::TRACE, skip(self), ret, err)] - async fn send( - &mut self, - request: HttpRequestFallback, - ) -> InterceptorResult<(HttpResponseFallback, Option)> { - let min = Duration::from_millis(10); - let max = Duration::from_millis(250); - - let mut backoffs = Backoff::new(RETRY_ON_RESET_ATTEMPTS, min, max) - .into_iter() - .flatten(); - - // Retry to handle this request a few times. 
- loop { - let response = self.sender.send(request.clone()).await; - - match self.handle_response(request.clone(), response).await { - Ok(response) => return Ok(response), - - Err(error @ InterceptorError::Reset) - | Err(error @ InterceptorError::ConnectionClosedTooSoon(_)) - | Err(error @ InterceptorError::IncompleteMessage(_)) => { - tracing::warn!( - ?request, - %error, - "Either the connection closed, the message is incomplete, \ - or we got a reset, retrying!" - ); - - let Some(backoff) = backoffs.next() else { - break; - }; - - sleep(backoff).await; - - // Create a new connection for the next attempt. - let socket = super::bind_similar(self.peer)?; - let stream = socket.connect(self.peer).await?; - let new_sender = super::http::handshake(request.version(), stream).await?; - self.sender = new_sender; - } - - Err(fail) => return Err(fail), - } - } - - Err(InterceptorError::MaxRetries) - } - - /// Proxies HTTP messages until an HTTP upgrade happens or the [`MessageBus`] closes. - /// Support retries (with reconnecting to the HTTP server). - /// - /// When an HTTP upgrade happens, the underlying [`TcpStream`] is reclaimed, wrapped - /// in a [`RawConnection`] and returned. When [`MessageBus`] closes, [`None`] is returned. - #[tracing::instrument(level = Level::TRACE, skip_all, ret, err)] - async fn run( - mut self, - message_bus: &mut MessageBus, - ) -> InterceptorResult> { - let upgrade = loop { - let Some(msg) = message_bus.recv().await else { - return Ok(None); - }; - - match msg { - MessageIn::Raw(..) => { - // We should not receive any raw data from the agent before sending a - //`101 SWITCHING PROTOCOLS` response. 
- return Err(InterceptorError::UnexpectedRawData); - } - - MessageIn::Http(req) => { - let (res, on_upgrade) = self.send(req).await.inspect_err(|fail| { - tracing::error!(?fail, "Failed getting a filtered http response!") - })?; - tracing::debug!("{} has upgrade: {}", res.request_id(), on_upgrade.is_some()); - message_bus.send(MessageOut::Http(res)).await; - - if let Some(on_upgrade) = on_upgrade { - break on_upgrade.await?; - } - } - } - }; - - let parts = upgrade - .downcast::>() - .expect("IO type is known"); - let stream = parts.io.into_inner(); - let read_buf = parts.read_buf; - - if !read_buf.is_empty() { - message_bus.send(MessageOut::Raw(read_buf.into())).await; - } - - Ok(Some(RawConnection { stream })) - } -} - -/// Utilized by the [`Interceptor`] when it acts as a TCP proxy. -/// See [`RawConnection::run`] for usage. -#[derive(Debug)] -struct RawConnection { - /// Connection between the [`Interceptor`] and the server. - stream: TcpStream, -} - -impl RawConnection { - /// Proxies raw TCP data until the [`MessageBus`] closes. - /// - /// # Notes - /// - /// 1. When the peer shuts down writing, a single 0-sized read is sent through the - /// [`MessageBus`]. This is to notify the agent about the shutdown condition. - /// - /// 2. A 0-sized read received from the [`MessageBus`] is treated as a shutdown on the agent - /// side. Connection with the peer is shut down as well. - /// - /// 3. This implementation exits only when an error is encountered or the [`MessageBus`] is - /// closed. - async fn run(mut self, message_bus: &mut MessageBus) -> InterceptorResult<()> { - let mut buf = BytesMut::with_capacity(64 * 1024); - let mut reading_closed = false; - let mut remote_closed = false; - - loop { - tokio::select! { - biased; - - res = self.stream.read_buf(&mut buf), if !reading_closed => match res { - Err(e) if e.kind() == ErrorKind::WouldBlock => {}, - Err(e) => break Err(e.into()), - Ok(..) 
=> { - if buf.is_empty() { - tracing::trace!("incoming interceptor -> layer shutdown, sending a 0-sized read to inform the agent"); - reading_closed = true; - } - message_bus.send(MessageOut::Raw(buf.to_vec())).await; - buf.clear(); - } - }, - - msg = message_bus.recv(), if !remote_closed => match msg { - None => { - tracing::trace!("incoming interceptor -> message bus closed, waiting 1 second before exiting"); - remote_closed = true; - }, - Some(MessageIn::Raw(data)) => { - if data.is_empty() { - tracing::trace!("incoming interceptor -> agent shutdown, shutting down connection with layer"); - self.stream.shutdown().await?; - } else { - self.stream.write_all(&data).await?; - } - }, - Some(MessageIn::Http(..)) => break Err(InterceptorError::UnexpectedHttpRequest), - }, - - _ = time::sleep(Duration::from_secs(1)), if remote_closed => { - tracing::trace!("incoming interceptor -> layer silent for 1 second and message bus is closed, exiting"); - - break Ok(()); - }, - } - } - } -} - -#[cfg(test)] -mod test { - use std::{ - convert::Infallible, - sync::{Arc, Mutex}, - }; - - use bytes::Bytes; - use futures::future::FutureExt; - use http_body_util::{BodyExt, Empty, Full}; - use hyper::{ - body::Incoming, - header::{HeaderValue, CONNECTION, UPGRADE}, - server::conn::http1, - service::service_fn, - upgrade::Upgraded, - Method, Request, Response, - }; - use hyper_util::rt::{TokioExecutor, TokioIo}; - use mirrord_protocol::tcp::{HttpRequest, InternalHttpRequest, StreamingBody}; - use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::TcpListener, - sync::{watch, Notify}, - task, - }; - - use super::*; - use crate::background_tasks::{BackgroundTasks, TaskUpdate}; - - /// Binary protocol over TCP. - /// Server first sends bytes [`INITIAL_MESSAGE`], then echoes back all received data. - const TEST_PROTO: &str = "dummyecho"; - - const INITIAL_MESSAGE: &[u8] = &[0x4a, 0x50, 0x32, 0x47, 0x4d, 0x44]; - - /// Handles requests upgrading to the [`TEST_PROTO`] protocol. 
- async fn upgrade_req_handler( - mut req: Request, - ) -> hyper::Result>> { - async fn dummy_echo(upgraded: Upgraded) -> io::Result<()> { - let mut upgraded = TokioIo::new(upgraded); - let mut buf = [0_u8; 64]; - - upgraded.write_all(INITIAL_MESSAGE).await?; - - loop { - let bytes_read = upgraded.read(&mut buf[..]).await?; - if bytes_read == 0 { - break; - } - - let echo_back = buf.get(0..bytes_read).unwrap(); - upgraded.write_all(echo_back).await?; - } - - Ok(()) - } - - let mut res = Response::new(Empty::new()); - - let contains_expected_upgrade = req - .headers() - .get(UPGRADE) - .filter(|proto| *proto == TEST_PROTO) - .is_some(); - if !contains_expected_upgrade { - *res.status_mut() = StatusCode::BAD_REQUEST; - return Ok(res); - } - - task::spawn(async move { - match hyper::upgrade::on(&mut req).await { - Ok(upgraded) => { - if let Err(e) = dummy_echo(upgraded).await { - eprintln!("server foobar io error: {}", e) - }; - } - Err(e) => eprintln!("upgrade error: {}", e), - } - }); - - *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS; - res.headers_mut() - .insert(UPGRADE, HeaderValue::from_static(TEST_PROTO)); - res.headers_mut() - .insert(CONNECTION, HeaderValue::from_static("upgrade")); - Ok(res) - } - - /// Runs a [`hyper`] server that accepts only requests upgrading to the [`TEST_PROTO`] protocol. - async fn dummy_echo_server(listener: TcpListener, mut shutdown: watch::Receiver) { - loop { - tokio::select! { - res = listener.accept() => { - let (stream, _) = res.expect("dummy echo server failed to accept connection"); - - let mut shutdown = shutdown.clone(); - - task::spawn(async move { - let conn = http1::Builder::new().serve_connection(TokioIo::new(stream), service_fn(upgrade_req_handler)); - let mut conn = conn.with_upgrades(); - let mut conn = Pin::new(&mut conn); - - tokio::select! 
{ - res = &mut conn => { - res.expect("dummy echo server failed to serve connection"); - } - - _ = shutdown.changed() => { - conn.graceful_shutdown(); - } - } - }); - } - - _ = shutdown.changed() => break, - } - } - } - - #[tokio::test] - async fn upgrade_test() { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let local_destination = listener.local_addr().unwrap(); - - let (shutdown_tx, shutdown_rx) = watch::channel(false); - let server_task = task::spawn(dummy_echo_server(listener, shutdown_rx)); - - let mut tasks: BackgroundTasks<(), MessageOut, InterceptorError> = Default::default(); - let interceptor = { - let socket = TcpSocket::new_v4().unwrap(); - socket.bind("127.0.0.1:0".parse().unwrap()).unwrap(); - tasks.register( - Interceptor::new( - socket, - local_destination, - Some(mirrord_protocol::VERSION.clone()), - ), - (), - 8, - ) - }; - - interceptor - .send(HttpRequestFallback::Fallback(HttpRequest { - connection_id: 0, - request_id: 0, - port: 80, - internal_request: InternalHttpRequest { - method: Method::GET, - uri: "dummyecho://www.mirrord.dev/".parse().unwrap(), - headers: [ - (CONNECTION, HeaderValue::from_static("upgrade")), - (UPGRADE, HeaderValue::from_static(TEST_PROTO)), - ] - .into_iter() - .collect(), - version: Version::HTTP_11, - body: Default::default(), - }, - })) - .await; - - let (_, update) = tasks.next().await.expect("no task result"); - match update { - TaskUpdate::Message(MessageOut::Http(res)) => { - let res = res - .into_hyper::() - .expect("failed to convert into hyper response"); - assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS); - println!("{:?}", res.headers()); - assert!(res - .headers() - .get(CONNECTION) - .filter(|v| *v == "upgrade") - .is_some()); - assert!(res - .headers() - .get(UPGRADE) - .filter(|v| *v == TEST_PROTO) - .is_some()); - } - _ => panic!("unexpected task update: {update:?}"), - } - - interceptor.send(b"test test test".to_vec()).await; - - let (_, update) = 
tasks.next().await.expect("no task result"); - match update { - TaskUpdate::Message(MessageOut::Raw(bytes)) => { - assert_eq!(bytes, INITIAL_MESSAGE); - } - _ => panic!("unexpected task update: {update:?}"), - } - - let (_, update) = tasks.next().await.expect("no task result"); - match update { - TaskUpdate::Message(MessageOut::Raw(bytes)) => { - assert_eq!(bytes, b"test test test"); - } - _ => panic!("unexpected task update: {update:?}"), - } - - let _ = shutdown_tx.send(true); - server_task.await.expect("dummy echo server panicked"); - } - - /// Ensure that [`HttpRequestFallback::Streamed`] are received frame by frame - #[tokio::test] - async fn receive_request_as_frames() { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let local_destination = listener.local_addr().unwrap(); - - let mut tasks: BackgroundTasks<(), MessageOut, InterceptorError> = Default::default(); - let socket = TcpSocket::new_v4().unwrap(); - socket.bind("127.0.0.1:0".parse().unwrap()).unwrap(); - let interceptor = Interceptor::new( - socket, - local_destination, - Some(mirrord_protocol::VERSION.clone()), - ); - let sender = tasks.register(interceptor, (), 8); - - let (tx, rx) = tokio::sync::mpsc::channel(12); - sender - .send(MessageIn::Http(HttpRequestFallback::Streamed { - request: HttpRequest { - internal_request: InternalHttpRequest { - method: Method::POST, - uri: "/".parse().unwrap(), - headers: Default::default(), - version: Version::HTTP_11, - body: StreamingBody::new(rx), - }, - connection_id: 1, - request_id: 2, - port: 3, - }, - retries: 0, - })) - .await; - let (connection, _peer_addr) = listener.accept().await.unwrap(); - - let tx = Mutex::new(Some(tx)); - let notifier = Arc::new(Notify::default()); - let finished = notifier.notified(); - - let service = service_fn(|mut req: Request| { - let tx = tx.lock().unwrap().take().unwrap(); - let notifier = notifier.clone(); - async move { - let x = req.body_mut().frame().now_or_never(); - assert!(x.is_none()); - 
tx.send(mirrord_protocol::tcp::InternalHttpBodyFrame::Data( - b"string".to_vec(), - )) - .await - .unwrap(); - let x = req - .body_mut() - .frame() - .await - .unwrap() - .unwrap() - .into_data() - .unwrap(); - assert_eq!(x, b"string".to_vec()); - let x = req.body_mut().frame().now_or_never(); - assert!(x.is_none()); - - tx.send(mirrord_protocol::tcp::InternalHttpBodyFrame::Data( - b"another_string".to_vec(), - )) - .await - .unwrap(); - let x = req - .body_mut() - .frame() - .await - .unwrap() - .unwrap() - .into_data() - .unwrap(); - assert_eq!(x, b"another_string".to_vec()); - - drop(tx); - let x = req.body_mut().frame().await; - assert!(x.is_none()); - - notifier.notify_waiters(); - Ok::<_, hyper::Error>(Response::new(Empty::::new())) - } - }); - let conn = http1::Builder::new().serve_connection(TokioIo::new(connection), service); - - tokio::select! { - result = conn => { - result.unwrap() - } - _ = finished => { - - } - } - } - - /// Checks that [`hyper`] and [`h2`] crate versions are in sync with each other. - /// - /// As we use `source.downcast_ref::` to drill down on [`h2`] errors from - /// [`hyper`], we need these two crates to stay in sync, otherwise we could always - /// fail some of our checks that rely on this `downcast` working. - /// - /// Even though we're using [`h2::Error::is_reset`] in intproxy, this test can be - /// for any error, and thus here we do it for [`h2::Error::is_go_away`] which is - /// easier to trigger. 
- #[tokio::test] - async fn hyper_and_h2_versions_in_sync() { - let notify = Arc::new(Notify::new()); - let wait_notify = notify.clone(); - - tokio::spawn(async move { - let listener = TcpListener::bind("127.0.0.1:6666").await.unwrap(); - - notify.notify_waiters(); - let (io, _) = listener.accept().await.unwrap(); - - if let Err(fail) = hyper::server::conn::http2::Builder::new(TokioExecutor::default()) - .serve_connection( - TokioIo::new(io), - service_fn(|_| async move { - Ok::<_, Infallible>(Response::new(Full::new(Bytes::from("Heresy!")))) - }), - ) - .await - { - assert!(fail - .source() - .and_then(|source| source.downcast_ref::()) - .is_some_and(h2::Error::is_go_away)); - } else { - panic!( - r"The request is supposed to fail with `GO_AWAY`! - Something is wrong if it didn't! - - >> If you're seeing this error, the cause is likely that `hyper` and `h2` - versions are out of sync, and we can't have that due to our use of - `downcast_ref` on some `h2` errors!" - ); - } - }); - - // Wait for the listener to be ready for our connection. - wait_notify.notified().await; - - assert!(reqwest::get("https://127.0.0.1:6666").await.is_err()); - } -} diff --git a/mirrord/intproxy/src/proxies/incoming/metadata_store.rs b/mirrord/intproxy/src/proxies/incoming/metadata_store.rs new file mode 100644 index 00000000000..fc15d7e9d4b --- /dev/null +++ b/mirrord/intproxy/src/proxies/incoming/metadata_store.rs @@ -0,0 +1,48 @@ +use std::collections::HashMap; + +use mirrord_intproxy_protocol::{ConnMetadataRequest, ConnMetadataResponse}; +use mirrord_protocol::ConnectionId; + +/// Maps local socket address pairs to remote. +/// +/// Allows for extracting the original socket addresses of peers of a remote connection. +#[derive(Default)] +pub struct MetadataStore { + prepared_responses: HashMap, + expected_requests: HashMap, +} + +impl MetadataStore { + /// Retrieves remote addresses for the given pair of local addresses. 
+ /// + /// If the mapping is not found, returns the local addresses unchanged. + pub fn get(&mut self, req: ConnMetadataRequest) -> ConnMetadataResponse { + self.prepared_responses + .remove(&req) + .unwrap_or_else(|| ConnMetadataResponse { + remote_source: req.peer_address, + local_address: req.listener_address.ip(), + }) + } + + /// Adds a new `req`->`res` mapping to this struct. + /// + /// Marks that the mapping is related to the remote connection with the given id. + pub fn expect( + &mut self, + req: ConnMetadataRequest, + connection: ConnectionId, + res: ConnMetadataResponse, + ) { + self.expected_requests.insert(connection, req.clone()); + self.prepared_responses.insert(req, res); + } + + /// Clears mapping related to the remote connection with the given id. + pub fn no_longer_expect(&mut self, connection: ConnectionId) { + let Some(req) = self.expected_requests.remove(&connection) else { + return; + }; + self.prepared_responses.remove(&req); + } +} diff --git a/mirrord/intproxy/src/proxies/incoming/port_subscription_ext.rs b/mirrord/intproxy/src/proxies/incoming/port_subscription_ext.rs index e928be69ace..58d92011076 100644 --- a/mirrord/intproxy/src/proxies/incoming/port_subscription_ext.rs +++ b/mirrord/intproxy/src/proxies/incoming/port_subscription_ext.rs @@ -3,7 +3,7 @@ use mirrord_intproxy_protocol::PortSubscription; use mirrord_protocol::{ tcp::{LayerTcp, LayerTcpSteal, StealType}, - ClientMessage, ConnectionId, Port, + ClientMessage, Port, }; /// Retrieves subscribed port from the given [`StealType`]. @@ -26,9 +26,6 @@ pub trait PortSubscriptionExt { /// Returns an unsubscribe request to be sent to the agent. fn wrap_agent_unsubscribe(&self) -> ClientMessage; - - /// Returns an unsubscribe connection request to be sent to the agent. 
- fn wrap_agent_unsubscribe_connection(&self, connection_id: ConnectionId) -> ClientMessage; } impl PortSubscriptionExt for PortSubscription { @@ -58,14 +55,4 @@ impl PortSubscriptionExt for PortSubscription { } } } - - /// [`LayerTcp::ConnectionUnsubscribe`] or [`LayerTcpSteal::ConnectionUnsubscribe`]. - fn wrap_agent_unsubscribe_connection(&self, connection_id: ConnectionId) -> ClientMessage { - match self { - Self::Mirror(..) => ClientMessage::Tcp(LayerTcp::ConnectionUnsubscribe(connection_id)), - Self::Steal(..) => { - ClientMessage::TcpSteal(LayerTcpSteal::ConnectionUnsubscribe(connection_id)) - } - } - } } diff --git a/mirrord/intproxy/src/proxies/incoming/subscriptions.rs b/mirrord/intproxy/src/proxies/incoming/subscriptions.rs index eb2e0a718c5..be9dcbdb315 100644 --- a/mirrord/intproxy/src/proxies/incoming/subscriptions.rs +++ b/mirrord/intproxy/src/proxies/incoming/subscriptions.rs @@ -251,6 +251,7 @@ impl SubscriptionsManager { Ok(subscription.confirm()) } + Err(ResponseError::PortAlreadyStolen(port)) => { let Some(subscription) = self.subscriptions.remove(&port) else { return Ok(vec![]); @@ -264,23 +265,30 @@ impl SubscriptionsManager { } } } + Err( - ref response_err @ ResponseError::Forbidden { - blocked_action: BlockedAction::Steal(ref steal_type), - .. + ref response_error @ ResponseError::Forbidden { + ref blocked_action, .. 
}, - ) => { - tracing::warn!("Port subscribe blocked by policy: {response_err}"); - let Some(subscription) = self.subscriptions.remove(&steal_type.get_port()) else { + tracing::warn!(%response_error, "Port subscribe blocked by policy"); + + let port = match blocked_action { + BlockedAction::Steal(steal_type) => steal_type.get_port(), + BlockedAction::Mirror(port) => *port, + }; + let Some(subscription) = self.subscriptions.remove(&port) else { return Ok(vec![]); }; + subscription - .reject(response_err.clone()) - .map_err(|sub|{ - tracing::error!("Subscription {sub:?} was confirmed before, then requested again and blocked by a policy."); - IncomingProxyError::SubscriptionFailed(response_err.clone()) + .reject(response_error.clone()) + .map_err(|subscription|{ + tracing::error!(?subscription, "Subscription was confirmed before, then requested again and blocked by a policy."); + IncomingProxyError::SubscriptionFailed(response_error.clone()) }) } + Err(err) => Err(IncomingProxyError::SubscriptionFailed(err)), } } diff --git a/mirrord/intproxy/src/proxies/incoming/tasks.rs b/mirrord/intproxy/src/proxies/incoming/tasks.rs new file mode 100644 index 00000000000..49a315636c5 --- /dev/null +++ b/mirrord/intproxy/src/proxies/incoming/tasks.rs @@ -0,0 +1,113 @@ +use std::{convert::Infallible, fmt, io}; + +use hyper::{upgrade::OnUpgrade, Version}; +use mirrord_protocol::{ + tcp::{ChunkedResponse, HttpResponse, InternalHttpBody}, + ConnectionId, Port, RequestId, +}; +use thiserror::Error; + +/// Messages produced by the [`BackgroundTask`](crate::background_tasks::BackgroundTask)s used in +/// the [`IncomingProxy`](super::IncomingProxy). +pub enum InProxyTaskMessage { + /// Produced by the [`TcpProxyTask`](super::tcp_proxy::TcpProxyTask) in steal mode. + Tcp( + /// Data received from the local application. + Vec<u8>, + ), + /// Produced by the [`HttpGatewayTask`](super::http_gateway::HttpGatewayTask). + Http( + /// HTTP-specific message.
+ HttpOut, + ), +} + +impl fmt::Debug for InProxyTaskMessage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Tcp(data) => f + .debug_tuple("Tcp") + .field(&format_args!("{} bytes", data.len())) + .finish(), + Self::Http(msg) => f.debug_tuple("Http").field(msg).finish(), + } + } +} + +/// Messages produced by the [`HttpGatewayTask`](super::http_gateway::HttpGatewayTask). +#[derive(Debug)] +pub enum HttpOut { + /// Response from the local application's HTTP server. + ResponseBasic(HttpResponse<Vec<u8>>), + /// Response from the local application's HTTP server. + ResponseFramed(HttpResponse<InternalHttpBody>), + /// Response from the local application's HTTP server. + ResponseChunked(ChunkedResponse), + /// Upgraded HTTP connection, to be handled as a remote connection stolen without any filter. + Upgraded(OnUpgrade), +} + +impl From<Vec<u8>> for InProxyTaskMessage { + fn from(value: Vec<u8>) -> Self { + Self::Tcp(value) + } +} + +impl From<HttpOut> for InProxyTaskMessage { + fn from(value: HttpOut) -> Self { + Self::Http(value) + } +} + +/// Errors that can occur in the [`BackgroundTask`](crate::background_tasks::BackgroundTask)s used +/// in the [`IncomingProxy`](super::IncomingProxy). +/// +/// All of these can occur only in the [`TcpProxyTask`](super::tcp_proxy::TcpProxyTask) +/// and mean that the local connection is irreversibly broken. +/// The [`HttpGatewayTask`](super::http_gateway::HttpGatewayTask) produces no errors +/// and instead responds with an error HTTP response to the agent. +/// +/// However, due to [`BackgroundTasks`](crate::background_tasks::BackgroundTasks) +/// type constraints, we need a common error type. +/// Thus, this type implements [`From<Infallible>`].
+#[derive(Error, Debug)] +pub enum InProxyTaskError { + #[error("io failed: {0}")] + IoError(#[from] io::Error), + #[error("local HTTP upgrade failed: {0}")] + UpgradeError(#[source] hyper::Error), +} + +impl From<Infallible> for InProxyTaskError { + fn from(_: Infallible) -> Self { + unreachable!() + } +} + +/// Types of [`BackgroundTask`](crate::background_tasks::BackgroundTask)s used in the +/// [`IncomingProxy`](super::IncomingProxy). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum InProxyTask { + /// [`TcpProxyTask`](super::tcp_proxy::TcpProxyTask) handling a mirrored connection. + MirrorTcpProxy(ConnectionId), + /// [`TcpProxyTask`](super::tcp_proxy::TcpProxyTask) handling a stolen connection. + StealTcpProxy(ConnectionId), + /// [`HttpGatewayTask`](super::http_gateway::HttpGatewayTask) handling a stolen HTTP request. + HttpGateway(HttpGatewayId), +} + +/// Identifies a [`HttpGatewayTask`](super::http_gateway::HttpGatewayTask). +/// +/// ([`ConnectionId`], [`RequestId`]) would suffice, but storing extra data allows us to produce an +/// error response in case the task somehow panics. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct HttpGatewayId { + /// Id of the remote connection. + pub connection_id: ConnectionId, + /// Id of the stolen request. + pub request_id: RequestId, + /// Remote port from which the request was stolen. + pub port: Port, + /// HTTP version of the stolen request.
+ pub version: Version, +} diff --git a/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs b/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs new file mode 100644 index 00000000000..16330782a56 --- /dev/null +++ b/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs @@ -0,0 +1,202 @@ +use std::{io::ErrorKind, net::SocketAddr, time::Duration}; + +use bytes::BytesMut; + +use hyper::upgrade::OnUpgrade; +use hyper_util::rt::TokioIo; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpStream, + time, +}; +use tracing::Level; + +use super::{ + bound_socket::BoundTcpSocket, + tasks::{InProxyTaskError, InProxyTaskMessage}, +}; +use crate::background_tasks::{BackgroundTask, MessageBus}; + +/// Local TCP connection between the [`TcpProxyTask`] and the user application. +#[derive(Debug)] +pub enum LocalTcpConnection { + /// Not yet established. Should be made by the [`TcpProxyTask`] from the given + /// [`BoundTcpSocket`]. + FromTheStart { + socket: BoundTcpSocket, + peer: SocketAddr, + }, + /// Upgraded HTTP connection from a previously stolen HTTP request. + AfterUpgrade(OnUpgrade), +} + +/// [`BackgroundTask`] of [`IncomingProxy`](super::IncomingProxy) that handles a remote +/// stolen/mirrored TCP connection. +/// +/// In steal mode, exits immediately when its [`TaskSender`](crate::background_tasks::TaskSender) +/// is dropped. +/// +/// In mirror mode, when its [`TaskSender`](crate::background_tasks::TaskSender) is dropped, +/// this proxy keeps reading data from the user application and exits after +/// [`Self::MIRROR_MODE_LINGER_TIMEOUT`] of silence. +#[derive(Debug)] +pub struct TcpProxyTask { + /// The local connection between this task and the user application. + connection: Option<LocalTcpConnection>, + /// Whether this task should silently discard data coming from the user application. + /// + /// The data is discarded only when the remote connection is mirrored.
+ discard_data: bool, +} + +impl TcpProxyTask { + /// Mirror mode only: how long do we wait before exiting after the [`MessageBus`] is closed + /// and user application doesn't send any data. + pub const MIRROR_MODE_LINGER_TIMEOUT: Duration = Duration::from_secs(1); + + /// Creates a new task. + /// + /// * This task will talk with the user application using the given [`LocalTcpConnection`]. + /// * If `discard_data` is set, this task will silently discard all data coming from the user + /// application. + pub fn new(connection: LocalTcpConnection, discard_data: bool) -> Self { + Self { + connection: Some(connection), + discard_data, + } + } +} + +impl BackgroundTask for TcpProxyTask { + type Error = InProxyTaskError; + type MessageIn = Vec; + type MessageOut = InProxyTaskMessage; + + #[tracing::instrument(level = Level::TRACE, name = "tcp_proxy_task_main_loop", skip(message_bus), err(level = Level::WARN))] + async fn run(&mut self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { + let mut stream = match self + .connection + .take() + .expect("task should have a valid connection before run") + { + LocalTcpConnection::FromTheStart { socket, peer } => { + let Some(stream) = message_bus + .closed() + .cancel_on_close(socket.connect(peer)) + .await + else { + return Ok(()); + }; + + stream? + } + + LocalTcpConnection::AfterUpgrade(on_upgrade) => { + let upgraded = on_upgrade.await.map_err(InProxyTaskError::UpgradeError)?; + let parts = upgraded + .downcast::>() + .expect("IO type is known"); + let stream = parts.io.into_inner(); + let read_buf = parts.read_buf; + + if !self.discard_data && !read_buf.is_empty() { + // We don't send empty data, + // because the agent recognizes it as a shutdown from the user application. 
+ message_bus.send(Vec::from(read_buf)).await; + } + + stream + } + }; + + let peer_addr = stream.peer_addr()?; + let self_addr = stream.local_addr()?; + + let mut buf = BytesMut::with_capacity(64 * 1024); + let mut reading_closed = false; + let mut is_lingering = false; + + loop { + tokio::select! { + res = stream.read_buf(&mut buf), if !reading_closed => match res { + Err(e) if e.kind() == ErrorKind::WouldBlock => {}, + Err(e) => break Err(e.into()), + Ok(..) => { + if buf.is_empty() { + reading_closed = true; + + tracing::trace!( + peer_addr = %peer_addr, + self_addr = %self_addr, + "The user application shut down its side of the connection", + ) + } else { + tracing::trace!( + data_len = buf.len(), + peer_addr = %peer_addr, + self_addr = %self_addr, + "Received some data from the user application", + ); + } + + if !self.discard_data { + message_bus.send(buf.to_vec()).await; + } + + buf.clear(); + } + }, + + msg = message_bus.recv(), if !is_lingering => match msg { + None if self.discard_data => { + tracing::trace!( + peer_addr = %peer_addr, + self_addr = %self_addr, + "Message bus closed, waiting until the connection is silent", + ); + + is_lingering = true; + } + None => { + tracing::trace!( + peer_addr = %peer_addr, + self_addr = %self_addr, + "Message bus closed, exiting", + ); + + break Ok(()); + } + Some(data) => { + if data.is_empty() { + tracing::trace!( + peer_addr = %peer_addr, + self_addr = %self_addr, + "The agent shut down its side of the connection", + ); + + stream.shutdown().await?; + } else { + tracing::trace!( + data_len = data.len(), + peer_addr = %peer_addr, + self_addr = %self_addr, + "Received some data from the agent", + ); + + stream.write_all(&data).await?; + } + }, + }, + + _ = time::sleep(Self::MIRROR_MODE_LINGER_TIMEOUT), if is_lingering => { + tracing::trace!( + peer_addr = %peer_addr, + self_addr = %self_addr, + "Message bus is closed and the connection is silent, exiting", + ); + + break Ok(()); + } + } + } + } +} diff --git 
a/mirrord/intproxy/src/proxies/simple.rs b/mirrord/intproxy/src/proxies/simple.rs index 5b1433114c2..6efa5416865 100644 --- a/mirrord/intproxy/src/proxies/simple.rs +++ b/mirrord/intproxy/src/proxies/simple.rs @@ -5,9 +5,10 @@ use std::collections::HashMap; use mirrord_intproxy_protocol::{LayerId, MessageId, ProxyToLayerMessage}; use mirrord_protocol::{ - dns::{GetAddrInfoRequest, GetAddrInfoResponse}, + dns::{AddressFamily, GetAddrInfoRequestV2, GetAddrInfoResponse, ADDRINFO_V2_VERSION}, ClientMessage, DaemonMessage, GetEnvVarsRequest, RemoteResult, }; +use semver::Version; use thiserror::Error; use crate::{ @@ -20,10 +21,12 @@ use crate::{ #[derive(Debug)] pub enum SimpleProxyMessage { - AddrInfoReq(MessageId, LayerId, GetAddrInfoRequest), + AddrInfoReq(MessageId, LayerId, GetAddrInfoRequestV2), AddrInfoRes(GetAddrInfoResponse), GetEnvReq(MessageId, LayerId, GetEnvVarsRequest), GetEnvRes(RemoteResult>), + /// Protocol version was negotiated with the agent. + ProtocolVersion(Version), } #[derive(Error, Debug)] @@ -34,10 +37,27 @@ pub struct SimpleProxyError(#[from] UnexpectedAgentMessage); /// Run as a [`BackgroundTask`]. #[derive(Default)] pub struct SimpleProxy { - /// For [`GetAddrInfoRequest`]s. + /// For [`GetAddrInfoRequestV2`]s. addr_info_reqs: RequestQueue, /// For [`GetEnvVarsRequest`]s. get_env_reqs: RequestQueue, + /// [`mirrord_protocol`] version negotiated with the agent. + /// Determines whether we can use `GetAddrInfoRequestV2`. + protocol_version: Option, +} + +impl SimpleProxy { + #[tracing::instrument(skip(self), level = tracing::Level::TRACE)] + fn set_protocol_version(&mut self, version: Version) { + self.protocol_version.replace(version); + } + + /// Returns whether [`mirrord_protocol`] version allows for a V2 addrinfo request. 
+ fn addr_info_v2(&self) -> bool { + self.protocol_version + .as_ref() + .is_some_and(|version| ADDRINFO_V2_VERSION.matches(version)) + } } impl BackgroundTask for SimpleProxy { @@ -52,9 +72,23 @@ impl BackgroundTask for SimpleProxy { match msg { SimpleProxyMessage::AddrInfoReq(message_id, session_id, req) => { self.addr_info_reqs.push_back(message_id, session_id); - message_bus - .send(ClientMessage::GetAddrInfoRequest(req)) - .await; + if self.addr_info_v2() { + message_bus + .send(ClientMessage::GetAddrInfoRequestV2(req)) + .await; + } else { + if matches!(req.family, AddressFamily::Ipv6Only) { + tracing::warn!( + "The agent version you're using does not support DNS \ + queries for IPv6 addresses. This version will only fetch IPv4 \ + addresses. Please update to a newer agent image for better IPv6 \ + support." + ) + } + message_bus + .send(ClientMessage::GetAddrInfoRequest(req.into())) + .await; + } } SimpleProxyMessage::AddrInfoRes(res) => { let (message_id, layer_id) = @@ -88,6 +122,7 @@ impl BackgroundTask for SimpleProxy { }) .await } + SimpleProxyMessage::ProtocolVersion(version) => self.set_protocol_version(version), } } diff --git a/mirrord/kube/src/api/container.rs b/mirrord/kube/src/api/container.rs index b87a088a412..a651dc13458 100644 --- a/mirrord/kube/src/api/container.rs +++ b/mirrord/kube/src/api/container.rs @@ -44,10 +44,16 @@ pub struct ContainerParams { /// the agent container.
pub tls_cert: Option, pub pod_ips: Option, + /// Support IPv6-only clusters + pub support_ipv6: bool, } impl ContainerParams { - pub fn new(tls_cert: Option, pod_ips: Option) -> ContainerParams { + pub fn new( + tls_cert: Option, + pod_ips: Option, + support_ipv6: bool, + ) -> ContainerParams { let port: u16 = rand::thread_rng().gen_range(30000..=65535); let gid: u16 = rand::thread_rng().gen_range(3000..u16::MAX); @@ -64,6 +70,7 @@ impl ContainerParams { port, tls_cert, pod_ips, + support_ipv6, } } } diff --git a/mirrord/kube/src/api/container/job.rs b/mirrord/kube/src/api/container/job.rs index 907aefffeeb..d9958e6620b 100644 --- a/mirrord/kube/src/api/container/job.rs +++ b/mirrord/kube/src/api/container/job.rs @@ -241,12 +241,14 @@ mod test { fn targetless() -> Result<(), Box> { let mut config_context = ConfigContext::default(); let agent = AgentFileConfig::default().generate_config(&mut config_context)?; + let support_ipv6 = false; let params = ContainerParams { name: "foobar".to_string(), port: 3000, gid: 13, tls_cert: None, pod_ips: None, + support_ipv6, }; let update = JobVariant::new(&agent, ¶ms).as_update(); @@ -298,7 +300,8 @@ mod test { { "name": "RUST_LOG", "value": agent.log_level }, { "name": "MIRRORD_AGENT_STEALER_FLUSH_CONNECTIONS", "value": agent.flush_connections.to_string() }, { "name": "MIRRORD_AGENT_NFTABLES", "value": agent.nftables.to_string() }, - { "name": "MIRRORD_AGENT_JSON_LOG", "value": Some(agent.json_log.to_string()) } + { "name": "MIRRORD_AGENT_JSON_LOG", "value": Some(agent.json_log.to_string()) }, + { "name": "MIRRORD_AGENT_SUPPORT_IPV6", "value": Some(support_ipv6.to_string()) } ], "resources": // Add requests to avoid getting defaulted https://github.com/metalbear-co/mirrord/issues/579 @@ -330,12 +333,14 @@ mod test { fn targeted() -> Result<(), Box> { let mut config_context = ConfigContext::default(); let agent = AgentFileConfig::default().generate_config(&mut config_context)?; + let support_ipv6 = false; let params = 
ContainerParams { name: "foobar".to_string(), port: 3000, gid: 13, tls_cert: None, pod_ips: None, + support_ipv6, }; let update = JobTargetedVariant::new( @@ -432,7 +437,8 @@ mod test { { "name": "RUST_LOG", "value": agent.log_level }, { "name": "MIRRORD_AGENT_STEALER_FLUSH_CONNECTIONS", "value": agent.flush_connections.to_string() }, { "name": "MIRRORD_AGENT_NFTABLES", "value": agent.nftables.to_string() }, - { "name": "MIRRORD_AGENT_JSON_LOG", "value": Some(agent.json_log.to_string()) } + { "name": "MIRRORD_AGENT_JSON_LOG", "value": Some(agent.json_log.to_string()) }, + { "name": "MIRRORD_AGENT_SUPPORT_IPV6", "value": Some(support_ipv6.to_string()) } ], "resources": // Add requests to avoid getting defaulted https://github.com/metalbear-co/mirrord/issues/579 { diff --git a/mirrord/kube/src/api/container/util.rs b/mirrord/kube/src/api/container/util.rs index 77f917378ce..d40949d268c 100644 --- a/mirrord/kube/src/api/container/util.rs +++ b/mirrord/kube/src/api/container/util.rs @@ -4,7 +4,9 @@ use futures::{AsyncBufReadExt, TryStreamExt}; use k8s_openapi::api::core::v1::{EnvVar, Pod, Toleration}; use kube::{api::LogParams, Api}; use mirrord_config::agent::{AgentConfig, LinuxCapability}; -use mirrord_protocol::{AGENT_NETWORK_INTERFACE_ENV, AGENT_OPERATOR_CERT_ENV}; +use mirrord_protocol::{ + AGENT_IPV6_ENV, AGENT_METRICS_ENV, AGENT_NETWORK_INTERFACE_ENV, AGENT_OPERATOR_CERT_ENV, +}; use regex::Regex; use tracing::warn; @@ -59,6 +61,9 @@ pub(super) fn agent_env(agent: &AgentConfig, params: &&ContainerParams) -> Vec Vec, + support_ipv6: bool, ) -> Result<(ContainerParams, Option), KubeApiError> { let runtime_data = match target.path.as_ref().unwrap_or(&Target::Targetless) { Target::Targetless => None, @@ -187,7 +188,7 @@ impl KubernetesAPI { .join(",") }); - let params = ContainerParams::new(tls_cert, pod_ips); + let params = ContainerParams::new(tls_cert, pod_ips, support_ipv6); Ok((params, runtime_data)) } @@ -209,7 +210,12 @@ impl KubernetesAPI { where P: Progress 
+ Send + Sync, { - let (params, runtime_data) = self.create_agent_params(target, tls_cert).await?; + let support_ipv6 = config + .map(|layer_conf| layer_conf.feature.network.ipv6) + .unwrap_or_default(); + let (params, runtime_data) = self + .create_agent_params(target, tls_cert, support_ipv6) + .await?; if let Some(RuntimeData { guessed_container: true, container_name, diff --git a/mirrord/kube/src/api/kubernetes/seeker.rs b/mirrord/kube/src/api/kubernetes/seeker.rs index b9429610f8e..e5208fcd8b5 100644 --- a/mirrord/kube/src/api/kubernetes/seeker.rs +++ b/mirrord/kube/src/api/kubernetes/seeker.rs @@ -1,6 +1,5 @@ use std::fmt; -use async_stream::stream; use futures::{stream, Stream, StreamExt, TryStreamExt}; use k8s_openapi::{ api::{ @@ -8,17 +7,17 @@ use k8s_openapi::{ batch::v1::{CronJob, Job}, core::v1::Pod, }, - Metadata, NamespaceResourceScope, + ClusterResourceScope, Metadata, NamespaceResourceScope, }; -use kube::{api::ListParams, Resource}; -use serde::de; +use kube::{api::ListParams, Api, Resource}; +use serde::de::{self, DeserializeOwned}; use crate::{ api::{ container::SKIP_NAMES, kubernetes::{get_k8s_resource_api, rollout::Rollout}, }, - error::Result, + error::{KubeApiError, Result}, }; pub struct KubeResourceSeeker<'a> { @@ -95,7 +94,7 @@ impl KubeResourceSeeker<'_> { Some((name, containers)) } - self.list_resource::(Some("status.phase=Running")) + self.list_all_namespaced(Some("status.phase=Running")) .try_filter(|pod| std::future::ready(check_pod_status(pod))) .try_filter_map(|pod| std::future::ready(Ok(create_pod_container_map(pod)))) .map_ok(|(pod, containers)| { @@ -111,6 +110,7 @@ impl KubeResourceSeeker<'_> { .try_flatten() .try_collect() .await + .map_err(KubeApiError::KubeError) } /// The list of deployments that have at least 1 `Replicas` and a deployment name. 
@@ -123,7 +123,7 @@ impl KubeResourceSeeker<'_> { .unwrap_or(false) } - self.list_resource::(None) + self.list_all_namespaced::(None) .filter(|response| std::future::ready(response.is_ok())) .try_filter(|deployment| std::future::ready(check_deployment_replicas(deployment))) .try_filter_map(|deployment| { @@ -134,60 +134,114 @@ impl KubeResourceSeeker<'_> { }) .try_collect() .await + .map_err(From::from) } - /// Helper to get the list of a resource type ([`Pod`], [`Deployment`], [`Rollout`], [`Job`], - /// [`CronJob`], [`StatefulSet`], or whatever satisfies `R`) through the kube api. - fn list_resource<'s, R>( - &self, - field_selector: Option<&'s str>, - ) -> impl Stream> + 's + async fn simple_list_resource<'s, R>(&self, prefix: &'s str) -> Result> where - R: Clone + fmt::Debug + for<'de> de::Deserialize<'de> + 's, - R: Resource, + R: 'static + + Clone + + fmt::Debug + + for<'de> de::Deserialize<'de> + + Resource + + Metadata + + Send, { - let Self { client, namespace } = self; - let resource_api = get_k8s_resource_api::(client, *namespace); + self.list_all_namespaced::(None) + .filter(|response| std::future::ready(response.is_ok())) + .try_filter_map(|rollout| { + std::future::ready(Ok(rollout + .meta() + .name + .as_ref() + .map(|name| format!("{prefix}/{name}")))) + }) + .try_collect() + .await + .map_err(From::from) + } + + /// Prepares [`ListParams`] that: + /// 1. Excludes our own resources + /// 2. Adds a limit for item count in a response + fn make_list_params(field_selector: Option<&str>) -> ListParams { + ListParams { + label_selector: Some("app!=mirrord,!operator.metalbear.co/owner".to_string()), + field_selector: field_selector.map(ToString::to_string), + limit: Some(500), + ..Default::default() + } + } - stream! 
{ - let mut params = ListParams { - label_selector: Some("app!=mirrord,!operator.metalbear.co/owner".to_string()), - field_selector: field_selector.map(ToString::to_string), - limit: Some(500), - ..Default::default() - }; + /// Returns a [`Stream`] of all objects in this [`KubeResourceSeeker`]'s namespace. + /// + /// 1. `field_selector` can be used for filtering. + /// 2. Our own resources are excluded. + pub fn list_all_namespaced( + &self, + field_selector: Option<&str>, + ) -> impl 'static + Stream> + Send + where + R: 'static + + Resource + + fmt::Debug + + Clone + + DeserializeOwned + + Send, + { + let api = get_k8s_resource_api(self.client, self.namespace); + let mut params = Self::make_list_params(field_selector); + async_stream::stream! { loop { - let resource = resource_api.list(¶ms).await?; + let response = api.list(¶ms).await?; - for resource in resource.items { + for resource in response.items { yield Ok(resource); } - if let Some(continue_token) = resource.metadata.continue_ && !continue_token.is_empty() { - params = params.continue_token(&continue_token); - } else { + let continue_token = response.metadata.continue_.unwrap_or_default(); + if continue_token.is_empty() { break; } + params.continue_token.replace(continue_token); } } } - async fn simple_list_resource<'s, R>(&self, prefix: &'s str) -> Result> + /// Returns a [`Stream`] of all objects in the cluster. + /// + /// 1. `field_selector` can be used for filtering. + /// 2. Our own resources are excluded. 
+ pub fn list_all_clusterwide( + &self, + field_selector: Option<&str>, + ) -> impl 'static + Stream> + Send where - R: Clone + fmt::Debug + for<'de> de::Deserialize<'de>, - R: Resource + Metadata, + R: 'static + + Resource + + fmt::Debug + + Clone + + DeserializeOwned + + Send, { - self.list_resource::(None) - .filter(|response| std::future::ready(response.is_ok())) - .try_filter_map(|rollout| { - std::future::ready(Ok(rollout - .meta() - .name - .as_ref() - .map(|name| format!("{prefix}/{name}")))) - }) - .try_collect() - .await + let api = Api::all(self.client.clone()); + let mut params = Self::make_list_params(field_selector); + + async_stream::stream! { + loop { + let response = api.list(¶ms).await?; + + for resource in response.items { + yield Ok(resource); + } + + let continue_token = response.metadata.continue_.unwrap_or_default(); + if continue_token.is_empty() { + break; + } + params.continue_token.replace(continue_token); + } + } } } diff --git a/mirrord/kube/src/api/runtime.rs b/mirrord/kube/src/api/runtime.rs index a77afd1b5b5..a431a3b4c03 100644 --- a/mirrord/kube/src/api/runtime.rs +++ b/mirrord/kube/src/api/runtime.rs @@ -3,7 +3,7 @@ use std::{ collections::BTreeMap, convert::Infallible, fmt::{self, Display, Formatter}, - net::Ipv4Addr, + net::IpAddr, ops::FromResidual, str::FromStr, }; @@ -71,7 +71,7 @@ impl Display for ContainerRuntime { #[derive(Debug)] pub struct RuntimeData { pub pod_name: String, - pub pod_ips: Vec, + pub pod_ips: Vec, pub pod_namespace: Option, pub node_name: String, pub container_id: String, @@ -128,9 +128,9 @@ impl RuntimeData { .filter_map(|pod_ip| { pod_ip .ip - .parse::() + .parse::() .inspect_err(|e| { - tracing::warn!("failed to parse pod IP {ip}: {e:?}", ip = pod_ip.ip); + tracing::warn!("failed to parse pod IP {ip}: {e:#?}", ip = pod_ip.ip); }) .ok() }) diff --git a/mirrord/layer/src/detour.rs b/mirrord/layer/src/detour.rs index a89e79ce0f6..92a014d20dd 100644 --- a/mirrord/layer/src/detour.rs +++ 
b/mirrord/layer/src/detour.rs @@ -215,6 +215,10 @@ pub(crate) enum Bypass { /// Useful for operations that are version gated, and we want to bypass when the protocol /// doesn't support them. NotImplemented, + + /// File `open` (any `open`-ish operation) was forced to be local, instead of remote, most + /// likely due to an operator fs policy. + OpenLocal, } impl Bypass { diff --git a/mirrord/layer/src/error.rs b/mirrord/layer/src/error.rs index ea0cbabe3c8..da4797a1916 100644 --- a/mirrord/layer/src/error.rs +++ b/mirrord/layer/src/error.rs @@ -248,6 +248,7 @@ impl From for i64 { HookError::BincodeEncode(_) => libc::EINVAL, HookError::ResponseError(response_fail) => match response_fail { ResponseError::IdsExhausted(_) => libc::ENOMEM, + ResponseError::OpenLocal => libc::ENOENT, ResponseError::NotFound(_) => libc::ENOENT, ResponseError::NotDirectory(_) => libc::ENOTDIR, ResponseError::NotFile(_) => libc::EISDIR, diff --git a/mirrord/layer/src/file/hooks.rs b/mirrord/layer/src/file/hooks.rs index 4de1e73577f..3fcfc3c1280 100644 --- a/mirrord/layer/src/file/hooks.rs +++ b/mirrord/layer/src/file/hooks.rs @@ -904,8 +904,9 @@ unsafe extern "C" fn fstatat_detour( }) } +/// Hook for `libc::fstatfs`. #[hook_guard_fn] -unsafe extern "C" fn fstatfs_detour(fd: c_int, out_stat: *mut statfs) -> c_int { +pub(crate) unsafe extern "C" fn fstatfs_detour(fd: c_int, out_stat: *mut statfs) -> c_int { if out_stat.is_null() { return HookError::BadPointer.into(); } @@ -919,6 +920,25 @@ unsafe extern "C" fn fstatfs_detour(fd: c_int, out_stat: *mut statfs) -> c_int { .unwrap_or_bypass_with(|_| FN_FSTATFS(fd, out_stat)) } +/// Hook for `libc::statfs`. 
+#[hook_guard_fn] +pub(crate) unsafe extern "C" fn statfs_detour( + raw_path: *const c_char, + out_stat: *mut statfs, +) -> c_int { + if out_stat.is_null() { + return HookError::BadPointer.into(); + } + + crate::file::ops::statfs(raw_path.checked_into()) + .map(|res| { + let res = res.metadata; + fill_statfs(out_stat, &res); + 0 + }) + .unwrap_or_bypass_with(|_| FN_STATFS(raw_path, out_stat)) +} + unsafe fn realpath_logic( source_path: *const c_char, output_path: *mut c_char, @@ -1088,6 +1108,43 @@ pub(crate) unsafe extern "C" fn mkdirat_detour( }) } +/// Hook for `libc::rmdir`. +#[hook_guard_fn] +pub(crate) unsafe extern "C" fn rmdir_detour(pathname: *const c_char) -> c_int { + rmdir(pathname.checked_into()) + .map(|()| 0) + .unwrap_or_bypass_with(|bypass| { + let raw_path = update_ptr_from_bypass(pathname, &bypass); + FN_RMDIR(raw_path) + }) +} + +/// Hook for `libc::unlink`. +#[hook_guard_fn] +pub(crate) unsafe extern "C" fn unlink_detour(pathname: *const c_char) -> c_int { + unlink(pathname.checked_into()) + .map(|()| 0) + .unwrap_or_bypass_with(|bypass| { + let raw_path = update_ptr_from_bypass(pathname, &bypass); + FN_UNLINK(raw_path) + }) +} + +/// Hook for `libc::unlinkat`. +#[hook_guard_fn] +pub(crate) unsafe extern "C" fn unlinkat_detour( + dirfd: c_int, + pathname: *const c_char, + flags: u32, +) -> c_int { + unlinkat(dirfd, pathname.checked_into(), flags) + .map(|()| 0) + .unwrap_or_bypass_with(|bypass| { + let raw_path = update_ptr_from_bypass(pathname, &bypass); + FN_UNLINKAT(dirfd, raw_path, flags) + }) +} + /// Convenience function to setup file hooks (`x_detour`) with `frida_gum`. 
pub(crate) unsafe fn enable_file_hooks(hook_manager: &mut HookManager) { replace!(hook_manager, "open", open_detour, FnOpen, FN_OPEN); @@ -1163,7 +1220,6 @@ pub(crate) unsafe fn enable_file_hooks(hook_manager: &mut HookManager) { ); replace!(hook_manager, "mkdir", mkdir_detour, FnMkdir, FN_MKDIR); - replace!( hook_manager, "mkdirat", @@ -1172,6 +1228,17 @@ pub(crate) unsafe fn enable_file_hooks(hook_manager: &mut HookManager) { FN_MKDIRAT ); + replace!(hook_manager, "rmdir", rmdir_detour, FnRmdir, FN_RMDIR); + + replace!(hook_manager, "unlink", unlink_detour, FnUnlink, FN_UNLINK); + replace!( + hook_manager, + "unlinkat", + unlinkat_detour, + FnUnlinkat, + FN_UNLINKAT + ); + replace!(hook_manager, "lseek", lseek_detour, FnLseek, FN_LSEEK); replace!(hook_manager, "write", write_detour, FnWrite, FN_WRITE); @@ -1286,6 +1353,8 @@ pub(crate) unsafe fn enable_file_hooks(hook_manager: &mut HookManager) { FnFstatfs, FN_FSTATFS ); + replace!(hook_manager, "statfs", statfs_detour, FnStatfs, FN_STATFS); + replace!( hook_manager, "fdopendir", @@ -1368,6 +1437,13 @@ pub(crate) unsafe fn enable_file_hooks(hook_manager: &mut HookManager) { FnFstatfs, FN_FSTATFS ); + replace!( + hook_manager, + "statfs$INODE64", + statfs_detour, + FnStatfs, + FN_STATFS + ); replace!( hook_manager, "fdopendir$INODE64", diff --git a/mirrord/layer/src/file/ops.rs b/mirrord/layer/src/file/ops.rs index 580f8eacc8a..8ca2401101f 100644 --- a/mirrord/layer/src/file/ops.rs +++ b/mirrord/layer/src/file/ops.rs @@ -4,17 +4,20 @@ use std::{env, ffi::CString, io::SeekFrom, os::unix::io::RawFd, path::PathBuf}; #[cfg(target_os = "linux")] use libc::{c_char, statx, statx_timestamp}; -use libc::{c_int, iovec, unlink, AT_FDCWD}; +use libc::{c_int, iovec, AT_FDCWD}; use mirrord_protocol::{ file::{ MakeDirAtRequest, MakeDirRequest, OpenFileRequest, OpenFileResponse, OpenOptionsInternal, - ReadFileResponse, ReadLinkFileRequest, ReadLinkFileResponse, SeekFileResponse, - WriteFileResponse, XstatFsResponse, XstatResponse, 
+ ReadFileResponse, ReadLinkFileRequest, ReadLinkFileResponse, RemoveDirRequest, + SeekFileResponse, StatFsRequest, UnlinkAtRequest, WriteFileResponse, XstatFsResponse, + XstatResponse, }, ResponseError, }; use rand::distributions::{Alphanumeric, DistString}; -use tracing::{error, trace, Level}; +#[cfg(debug_assertions)] +use tracing::Level; +use tracing::{error, trace}; use super::{hooks::FN_OPEN, open_dirs::OPEN_DIRS, *}; #[cfg(target_os = "linux")] @@ -157,7 +160,7 @@ fn create_local_fake_file(remote_fd: u64) -> Detour { close_remote_file_on_failure(remote_fd)?; Detour::Error(HookError::LocalFileCreation(remote_fd, error.0)) } else { - unsafe { unlink(file_path_ptr) }; + unsafe { libc::unlink(file_path_ptr) }; Detour::Success(local_file_fd) } } @@ -206,7 +209,12 @@ pub(crate) fn open(path: Detour, open_options: OpenOptionsInternal) -> ensure_not_ignored!(path, open_options.is_write()); - let OpenFileResponse { fd: remote_fd } = RemoteFile::remote_open(path.clone(), open_options)?; + let OpenFileResponse { fd: remote_fd } = RemoteFile::remote_open(path.clone(), open_options) + .or_else(|fail| match fail { + // The operator has a policy that matches this `path` as local-only. + HookError::ResponseError(ResponseError::OpenLocal) => Detour::Bypass(Bypass::OpenLocal), + other => Detour::Error(other), + })?; // TODO: Need a way to say "open a directory", right now `is_dir` always returns false. // This requires having a fake directory name (`/fake`, for example), instead of just converting @@ -387,6 +395,69 @@ pub(crate) fn mkdirat(dirfd: RawFd, pathname: Detour, mode: u32) -> Det } } +#[mirrord_layer_macro::instrument(level = Level::TRACE, ret)] +pub(crate) fn rmdir(pathname: Detour) -> Detour<()> { + let pathname = pathname?; + + check_relative_paths!(pathname); + + let path = remap_path!(pathname); + + ensure_not_ignored!(path, false); + + let rmdir = RemoveDirRequest { pathname: path }; + + // `NotImplemented` error here means that the protocol doesn't support it. 
+ match common::make_proxy_request_with_response(rmdir)? { + Ok(response) => Detour::Success(response), + Err(ResponseError::NotImplemented) => Detour::Bypass(Bypass::NotImplemented), + Err(fail) => Detour::Error(fail.into()), + } +} + +#[mirrord_layer_macro::instrument(level = Level::TRACE, ret)] +pub(crate) fn unlink(pathname: Detour) -> Detour<()> { + let pathname = pathname?; + + check_relative_paths!(pathname); + + let path = remap_path!(pathname); + + ensure_not_ignored!(path, false); + + let unlink = RemoveDirRequest { pathname: path }; + + // `NotImplemented` error here means that the protocol doesn't support it. + match common::make_proxy_request_with_response(unlink)? { + Ok(response) => Detour::Success(response), + Err(ResponseError::NotImplemented) => Detour::Bypass(Bypass::NotImplemented), + Err(fail) => Detour::Error(fail.into()), + } +} + +#[mirrord_layer_macro::instrument(level = Level::TRACE, ret)] +pub(crate) fn unlinkat(dirfd: RawFd, pathname: Detour, flags: u32) -> Detour<()> { + let pathname = pathname?; + + let optional_dirfd = match pathname.is_absolute() { + true => None, + false => Some(get_remote_fd(dirfd)?), + }; + + let unlink = UnlinkAtRequest { + dirfd: optional_dirfd, + pathname: pathname.clone(), + flags, + }; + + // `NotImplemented` error here means that the protocol doesn't support it. + match common::make_proxy_request_with_response(unlink)? 
{ + Ok(response) => Detour::Success(response), + Err(ResponseError::NotImplemented) => Detour::Bypass(Bypass::NotImplemented), + Err(fail) => Detour::Error(fail.into()), + } +} + pub(crate) fn pwrite(local_fd: RawFd, buffer: &[u8], offset: u64) -> Detour { let remote_fd = get_remote_fd(local_fd)?; trace!("pwrite: local_fd {local_fd}"); @@ -666,6 +737,16 @@ pub(crate) fn xstatfs(fd: RawFd) -> Detour { Detour::Success(response) } +#[mirrord_layer_macro::instrument(level = "trace")] +pub(crate) fn statfs(path: Detour) -> Detour { + let path = path?; + let lstatfs = StatFsRequest { path }; + + let response = common::make_proxy_request_with_response(lstatfs)??; + + Detour::Success(response) +} + #[cfg(target_os = "linux")] #[mirrord_layer_macro::instrument(level = "trace")] pub(crate) fn getdents64(fd: RawFd, buffer_size: u64) -> Detour { diff --git a/mirrord/layer/src/go/linux_x64.rs b/mirrord/layer/src/go/linux_x64.rs index 5acaceb13b2..622a24383d7 100644 --- a/mirrord/layer/src/go/linux_x64.rs +++ b/mirrord/layer/src/go/linux_x64.rs @@ -340,10 +340,17 @@ unsafe extern "C" fn c_abi_syscall_handler( faccessat_detour(param1 as _, param2 as _, param3 as _, 0) as i64 } libc::SYS_fstat => fstat_detour(param1 as _, param2 as _) as i64, + libc::SYS_statfs => statfs_detour(param1 as _, param2 as _) as i64, + libc::SYS_fstatfs => fstatfs_detour(param1 as _, param2 as _) as i64, libc::SYS_getdents64 => getdents64_detour(param1 as _, param2 as _, param3 as _) as i64, #[cfg(all(target_os = "linux", not(target_arch = "aarch64")))] libc::SYS_mkdir => mkdir_detour(param1 as _, param2 as _) as i64, libc::SYS_mkdirat => mkdirat_detour(param1 as _, param2 as _, param3 as _) as i64, + #[cfg(all(target_os = "linux", not(target_arch = "aarch64")))] + libc::SYS_rmdir => rmdir_detour(param1 as _) as i64, + #[cfg(all(target_os = "linux", not(target_arch = "aarch64")))] + libc::SYS_unlink => unlink_detour(param1 as _) as i64, + libc::SYS_unlinkat => unlinkat_detour(param1 as _, param2 as _, 
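The `rmdir`/`unlink`/`unlinkat` ops above share one fallback shape: a `NotImplemented` response from an agent whose protocol predates the request becomes a bypass (so the original libc call runs locally), while any other error propagates. A stripped-down sketch of that match, with hypothetical enum shapes rather than the real `mirrord_protocol` types:

```rust
#[derive(Debug, PartialEq)]
enum ResponseError {
    NotImplemented,
    NotFound,
}

#[derive(Debug, PartialEq)]
enum Detour<T> {
    Success(T),
    Bypass,
    Error(ResponseError),
}

// Old agents answer version-gated requests with `NotImplemented`;
// that is a signal to fall back, not a session-fatal failure.
fn handle(resp: Result<(), ResponseError>) -> Detour<()> {
    match resp {
        Ok(val) => Detour::Success(val),
        Err(ResponseError::NotImplemented) => Detour::Bypass,
        Err(fail) => Detour::Error(fail),
    }
}

fn main() {
    assert_eq!(handle(Ok(())), Detour::Success(()));
    assert_eq!(handle(Err(ResponseError::NotImplemented)), Detour::Bypass);
    assert_eq!(
        handle(Err(ResponseError::NotFound)),
        Detour::Error(ResponseError::NotFound)
    );
}
```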
param3 as _) as i64, _ => { let (Ok(result) | Err(result)) = syscalls::syscall!( syscalls::Sysno::from(syscall as i32), diff --git a/mirrord/layer/src/go/mod.rs b/mirrord/layer/src/go/mod.rs index 003eed8692c..6a28c3ebfd9 100644 --- a/mirrord/layer/src/go/mod.rs +++ b/mirrord/layer/src/go/mod.rs @@ -101,6 +101,8 @@ unsafe extern "C" fn c_abi_syscall6_handler( .into() } libc::SYS_fstat => fstat_detour(param1 as _, param2 as _) as i64, + libc::SYS_statfs => statfs_detour(param1 as _, param2 as _) as i64, + libc::SYS_fstatfs => fstatfs_detour(param1 as _, param2 as _) as i64, libc::SYS_fsync => fsync_detour(param1 as _) as i64, libc::SYS_fdatasync => fsync_detour(param1 as _) as i64, libc::SYS_openat => { @@ -113,6 +115,11 @@ unsafe extern "C" fn c_abi_syscall6_handler( #[cfg(all(target_os = "linux", not(target_arch = "aarch64")))] libc::SYS_mkdir => mkdir_detour(param1 as _, param2 as _) as i64, libc::SYS_mkdirat => mkdirat_detour(param1 as _, param2 as _, param3 as _) as i64, + #[cfg(all(target_os = "linux", not(target_arch = "aarch64")))] + libc::SYS_rmdir => rmdir_detour(param1 as _) as i64, + #[cfg(all(target_os = "linux", not(target_arch = "aarch64")))] + libc::SYS_unlink => unlink_detour(param1 as _) as i64, + libc::SYS_unlinkat => unlinkat_detour(param1 as _, param2 as _, param3 as _) as i64, _ => { let (Ok(result) | Err(result)) = syscalls::syscall!( syscalls::Sysno::from(syscall as i32), diff --git a/mirrord/layer/src/socket.rs b/mirrord/layer/src/socket.rs index 79c197d408d..5c2d03f8ace 100644 --- a/mirrord/layer/src/socket.rs +++ b/mirrord/layer/src/socket.rs @@ -439,10 +439,22 @@ impl ProtocolAndAddressFilterExt for ProtocolAndAddressFilter { return Ok(false); } + let family = if address.is_ipv4() { + libc::AF_INET + } else { + libc::AF_INET6 + }; + + let addr_protocol = if matches!(protocol, NetProtocol::Stream) { + libc::SOCK_STREAM + } else { + libc::SOCK_DGRAM + }; + match &self.address { AddressFilter::Name(name, port) => { let resolved_ips = if 
crate::setup().remote_dns_enabled() && !force_local_dns { - match remote_getaddrinfo(name.to_string()) { + match remote_getaddrinfo(name.to_string(), *port, 0, family, 0, addr_protocol) { Ok(res) => res.into_iter().map(|(_, ip)| ip).collect(), Err(HookError::ResponseError(ResponseError::DnsLookup( DnsLookupError { diff --git a/mirrord/layer/src/socket/ops.rs b/mirrord/layer/src/socket/ops.rs index 543a512629c..a8a6e49e79a 100644 --- a/mirrord/layer/src/socket/ops.rs +++ b/mirrord/layer/src/socket/ops.rs @@ -4,6 +4,7 @@ use std::{ collections::HashMap, io, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpStream}, + ops::Not, os::{ fd::{BorrowedFd, FromRawFd, IntoRawFd}, unix::io::RawFd, @@ -21,7 +22,7 @@ use mirrord_intproxy_protocol::{ OutgoingConnectResponse, PortSubscribe, }; use mirrord_protocol::{ - dns::{GetAddrInfoRequest, LookupRecord}, + dns::{AddressFamily, GetAddrInfoRequestV2, LookupRecord, SockType}, file::{OpenFileResponse, OpenOptionsInternal, ReadFileResponse}, }; use nix::sys::socket::{sockopt, SockaddrIn, SockaddrIn6, SockaddrLike, SockaddrStorage}; @@ -129,7 +130,7 @@ pub(super) fn socket(domain: c_int, type_: c_int, protocol: c_int) -> Detour bool { /// If the socket is not found in [`SOCKETS`], bypass. /// Otherwise, if it's not an ignored port, bind (possibly with a fallback to random port) and /// update socket state in [`SOCKETS`]. If it's an ignored port, remove the socket from [`SOCKETS`]. 
-#[mirrord_layer_macro::instrument(level = Level::TRACE, fields(pid = std::process::id()), ret, skip(raw_address))] +#[mirrord_layer_macro::instrument(level = Level::TRACE, fields(pid = std::process::id()), ret, skip(raw_address) +)] pub(super) fn bind( sockfd: c_int, raw_address: *const sockaddr, @@ -323,9 +325,9 @@ pub(super) fn bind( } }) } - .ok() - .and_then(|(_, address)| address.as_socket()) - .bypass(Bypass::AddressConversion)?; + .ok() + .and_then(|(_, address)| address.as_socket()) + .bypass(Bypass::AddressConversion)?; Arc::get_mut(&mut socket).unwrap().state = SocketState::Bound(Bound { requested_address, @@ -889,8 +891,33 @@ pub(super) fn dup(fd: c_int, dup_fd: i32) -> Result<(), /// /// This function updates the mapping in [`REMOTE_DNS_REVERSE_MAPPING`]. #[mirrord_layer_macro::instrument(level = Level::TRACE, ret, err)] -pub(super) fn remote_getaddrinfo(node: String) -> HookResult> { - let addr_info_list = common::make_proxy_request_with_response(GetAddrInfoRequest { node })?.0?; +pub(super) fn remote_getaddrinfo( + node: String, + service_port: u16, + flags: c_int, + family: c_int, + socktype: c_int, + protocol: c_int, +) -> HookResult> { + let family = match family { + libc::AF_INET => AddressFamily::Ipv4Only, + libc::AF_INET6 => AddressFamily::Ipv6Only, + _ => AddressFamily::Both, + }; + let socktype = match socktype { + libc::SOCK_STREAM => SockType::Stream, + libc::SOCK_DGRAM => SockType::Dgram, + _ => SockType::Any, + }; + let addr_info_list = common::make_proxy_request_with_response(GetAddrInfoRequestV2 { + node, + service_port, + flags, + family, + socktype, + protocol, + })? + .0?; let mut remote_dns_reverse_mapping = REMOTE_DNS_REVERSE_MAPPING.lock()?; addr_info_list.iter().for_each(|lookup| { @@ -945,29 +972,41 @@ pub(super) fn getaddrinfo( Bypass::CStrConversion })? + // TODO: according to the man page, service could also be a service name, it doesn't have to + // be a port number. 
.and_then(|service| service.parse::().ok()) .unwrap_or(0); - crate::setup().dns_selector().check_query(&node, service)?; + let setup = crate::setup(); + setup.dns_selector().check_query(&node, service)?; + let ipv6_enabled = setup.layer_config().feature.network.ipv6; let raw_hints = raw_hints .cloned() .unwrap_or_else(|| unsafe { mem::zeroed() }); - // TODO(alex): Use more fields from `raw_hints` to respect the user's `getaddrinfo` call. let libc::addrinfo { + ai_family, ai_socktype, ai_protocol, + ai_flags, .. } = raw_hints; // Some apps (gRPC on Python) use `::` to listen on all interfaces, and usually that just means - // resolve on unspecified. So we just return that in IpV4 because we don't support ipv6. - let resolved_addr = if node == "::" { + // resolve on unspecified. So we just return that in IPv4, if IPv6 support is disabled. + let resolved_addr = if ipv6_enabled.not() && (node == "::") { // name is "" because that's what happens in real flow. vec![("".to_string(), IpAddr::V4(Ipv4Addr::UNSPECIFIED))] } else { - remote_getaddrinfo(node.clone())? + remote_getaddrinfo( + node.clone(), + service, + ai_flags, + ai_family, + ai_socktype, + ai_protocol, + )? }; let mut managed_addr_info = MANAGED_ADDRINFO.lock()?; @@ -1066,7 +1105,7 @@ pub(super) fn gethostbyname(raw_name: Option<&CStr>) -> Detour<*mut hostent> { crate::setup().dns_selector().check_query(&name, 0)?; - let hosts_and_ips = remote_getaddrinfo(name.clone())?; + let hosts_and_ips = remote_getaddrinfo(name.clone(), 0, 0, 0, 0, 0)?; // We could `unwrap` here, as this would have failed on the previous conversion. 
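The `remote_getaddrinfo` change above maps raw libc hint values onto protocol enums, defaulting anything unrecognized to a permissive variant. A standalone sketch of the family mapping — the constants below are the Linux `AF_*` values, defined locally so the example needs no external crates (the real code uses `libc::AF_INET` / `libc::AF_INET6`):

```rust
// AF_* values as defined on Linux; standalone constants so the sketch
// is dependency-free (the real code uses `libc::AF_INET` etc.).
const AF_INET: i32 = 2;
const AF_INET6: i32 = 10;

#[derive(Debug, PartialEq)]
enum AddressFamily {
    Ipv4Only,
    Ipv6Only,
    Both,
}

// Unknown hint values collapse to the permissive `Both`, so an odd
// `ai_family` never fails the lookup outright.
fn to_family(raw: i32) -> AddressFamily {
    match raw {
        AF_INET => AddressFamily::Ipv4Only,
        AF_INET6 => AddressFamily::Ipv6Only,
        _ => AddressFamily::Both,
    }
}

fn main() {
    assert_eq!(to_family(AF_INET), AddressFamily::Ipv4Only);
    assert_eq!(to_family(AF_INET6), AddressFamily::Ipv6Only);
    assert_eq!(to_family(0), AddressFamily::Both);
}
```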
let host_name = CString::new(name)?; diff --git a/mirrord/layer/tests/apps/fileops/go/main.go b/mirrord/layer/tests/apps/fileops/go/main.go index 5973e013d5e..69db336fb3e 100644 --- a/mirrord/layer/tests/apps/fileops/go/main.go +++ b/mirrord/layer/tests/apps/fileops/go/main.go @@ -7,10 +7,21 @@ import ( func main() { tempFile := "/tmp/test_file.txt" - syscall.Open(tempFile, syscall.O_CREAT|syscall.O_WRONLY, 0644) + fd, _ := syscall.Open(tempFile, syscall.O_CREAT|syscall.O_WRONLY, 0644) var stat syscall.Stat_t err := syscall.Stat(tempFile, &stat) if err != nil { panic(err) } + + var statfs syscall.Statfs_t + err = syscall.Statfs(tempFile, &statfs) + if err != nil { + panic(err) + } + + err = syscall.Fstatfs(fd, &statfs) + if err != nil { + panic(err) + } } diff --git a/mirrord/layer/tests/apps/mkdir/mkdir.c b/mirrord/layer/tests/apps/mkdir/mkdir.c index 4253b6c089d..8631ae7a26a 100644 --- a/mirrord/layer/tests/apps/mkdir/mkdir.c +++ b/mirrord/layer/tests/apps/mkdir/mkdir.c @@ -1,6 +1,5 @@ #include #include -#include #include #include diff --git a/mirrord/layer/tests/apps/rmdir/rmdir.c b/mirrord/layer/tests/apps/rmdir/rmdir.c new file mode 100644 index 00000000000..f839e5256e7 --- /dev/null +++ b/mirrord/layer/tests/apps/rmdir/rmdir.c @@ -0,0 +1,20 @@ +#include +#include +#include +#include + +/// Test `rmdir`. +/// +/// Creates a folder and then removes it. 
+/// +int main() +{ + char *test_dir = "/test_dir"; + int mkdir_result = mkdir(test_dir, 0777); + assert(mkdir_result == 0); + + int rmdir_result = rmdir(test_dir); + assert(rmdir_result == 0); + + return 0; +} diff --git a/mirrord/layer/tests/apps/statfs_fstatfs/statfs_fstatfs.c b/mirrord/layer/tests/apps/statfs_fstatfs/statfs_fstatfs.c new file mode 100644 index 00000000000..6c474cce4e3 --- /dev/null +++ b/mirrord/layer/tests/apps/statfs_fstatfs/statfs_fstatfs.c @@ -0,0 +1,52 @@ +#include +#include +#include +#include +#include +#include +#include + +#if defined(__APPLE__) && defined(__MACH__) +#include +#include +#else +#include +#endif + +/// Test `statfs / fstatfs`. +/// +/// Gets information about a mounted filesystem +/// +int main() +{ + char *tmp_test_path = "/statfs_fstatfs_test_path"; + mkdir(tmp_test_path, 0777); + + // statfs + struct statfs statfs_buf; + if (statfs(tmp_test_path, &statfs_buf) == -1) + { + perror("statfs failed"); + return EXIT_FAILURE; + } + + // fstatfs + int fd = open(tmp_test_path, O_RDONLY); + + if (fd == -1) + { + perror("Error opening tmp_test_path"); + return 1; + } + + struct statfs fstatfs_buf; + if (fstatfs(fd, &fstatfs_buf) == -1) + { + perror("fstatfs failed"); + close(fd); + return EXIT_FAILURE; + } + + close(fd); + return 0; +} diff --git a/mirrord/layer/tests/common/mod.rs b/mirrord/layer/tests/common/mod.rs index 22945790e4b..a4652433768 100644 --- a/mirrord/layer/tests/common/mod.rs +++ b/mirrord/layer/tests/common/mod.rs @@ -17,7 +17,7 @@ use mirrord_intproxy::{agent_conn::AgentConnection, IntProxy}; use mirrord_protocol::{ file::{ AccessFileRequest, AccessFileResponse, OpenFileRequest, OpenOptionsInternal, - ReadFileRequest, SeekFromInternal, XstatRequest, XstatResponse, + ReadFileRequest, SeekFromInternal, XstatFsResponse, XstatRequest, XstatResponse, }, tcp::{DaemonTcp, LayerTcp, NewTcpConnection, TcpClose, TcpData}, ClientMessage, DaemonCodec, DaemonMessage, FileRequest, FileResponse, @@ -44,7 +44,7 @@ pub const 
RUST_OUTGOING_LOCAL: &str = "4.4.4.4:4444"; /// /// We take advantage of how Rust's thread naming scheme for tests to create the log files, /// and if we have no thread name, then we just write the logs to `stderr`. -pub fn init_tracing() -> Result> { +pub fn init_tracing() -> DefaultGuard { let subscriber = tracing_subscriber::fmt() .with_env_filter(EnvFilter::new("mirrord=trace")) .without_time() @@ -61,7 +61,7 @@ pub fn init_tracing() -> Result> { .map(|name| name.replace(':', "_")) { Some(test_name) => { - let mut logs_file = PathBuf::from_str("/tmp/intproxy_logs")?; + let mut logs_file = PathBuf::from("/tmp/intproxy_logs"); #[cfg(target_os = "macos")] logs_file.push("macos"); @@ -71,26 +71,28 @@ pub fn init_tracing() -> Result> { let _ = std::fs::create_dir_all(&logs_file).ok(); logs_file.push(&test_name); - match File::create(logs_file) { + match File::create(&logs_file) { + // Writes the logs to the file. Ok(file) => { + println!("Created intproxy log file: {}", logs_file.display()); let subscriber = subscriber.with_writer(Arc::new(file)).finish(); - - // Writes the logs to a file. - Ok(tracing::subscriber::set_default(subscriber)) + tracing::subscriber::set_default(subscriber) } - Err(_) => { + // File creation failure makes the output go to `stderr`. + Err(error) => { + println!("Failed to create intproxy log file at {}: {error}. Intproxy logs will be flushed to stderr", logs_file.display()); let subscriber = subscriber.with_writer(io::stderr).finish(); - - // File creation failure makes the output go to `stderr`. - Ok(tracing::subscriber::set_default(subscriber)) + tracing::subscriber::set_default(subscriber) } } } + // No thread name makes the output go to `stderr`. None => { + println!( + "Failed to obtain current thread name, intproxy logs will be flushed to stderr" + ); let subscriber = subscriber.with_writer(io::stderr).finish(); - - // No thread name makes the output go to `stderr`. 
- Ok(tracing::subscriber::set_default(subscriber)) + tracing::subscriber::set_default(subscriber) } } } @@ -487,6 +489,67 @@ impl TestIntProxy { .unwrap(); } + /// Makes a [`FileRequest::StatFs`] and answers it. + pub async fn expect_statfs(&mut self, expected_path: &str) { + // Expecting `statfs` call with path. + assert_matches!( + self.recv().await, + ClientMessage::FileRequest(FileRequest::StatFs( + mirrord_protocol::file::StatFsRequest { path } + )) if path.to_str().unwrap() == expected_path + ); + + // Answer `statfs`. + self.codec + .send(DaemonMessage::File(FileResponse::XstatFs(Ok( + XstatFsResponse { + metadata: Default::default(), + }, + )))) + .await + .unwrap(); + } + + /// Makes a [`FileRequest::XstatFs`] and answers it. + pub async fn expect_fstatfs(&mut self, expected_fd: u64) { + // Expecting `fstatfs` call with fd. + assert_matches!( + self.recv().await, + ClientMessage::FileRequest(FileRequest::XstatFs( + mirrord_protocol::file::XstatFsRequest { fd } + )) if expected_fd == fd + ); + + // Answer `fstatfs`. + self.codec + .send(DaemonMessage::File(FileResponse::XstatFs(Ok( + XstatFsResponse { + metadata: Default::default(), + }, + )))) + .await + .unwrap(); + } + + /// Makes a [`FileRequest::RemoveDir`] and answers it. + pub async fn expect_remove_dir(&mut self, expected_dir_name: &str) { + // Expecting `rmdir` call with path. + assert_matches!( + self.recv().await, + ClientMessage::FileRequest(FileRequest::RemoveDir( + mirrord_protocol::file::RemoveDirRequest { pathname } + )) if pathname.to_str().unwrap() == expected_dir_name + ); + + // Answer `rmdir`. + self.codec + .send(DaemonMessage::File( + mirrord_protocol::FileResponse::RemoveDir(Ok(())), + )) + .await + .unwrap(); + } + /// Verify that the passed message (not the next message from self.codec!) is a file read. /// Return buffer size.
pub async fn expect_message_file_read(message: ClientMessage, expected_fd: u64) -> u64 { @@ -763,6 +826,8 @@ pub enum Application { Fork, ReadLink, MakeDir, + StatfsFstatfs, + RemoveDir, OpenFile, CIssue2055, CIssue2178, @@ -782,9 +847,9 @@ pub enum Application { /// Mode to use when opening the file, accepted as `-m` param. mode: u32, }, - // For running applications with the executable and arguments determined at runtime. + /// For running applications with the executable and arguments determined at runtime. DynamicApp(String, Vec), - // Go app that only checks whether Linux pidfd syscalls are supported. + /// Go app that only checks whether Linux pidfd syscalls are supported. Go23Issue2988, } @@ -819,6 +884,8 @@ impl Application { Application::Fork => String::from("tests/apps/fork/out.c_test_app"), Application::ReadLink => String::from("tests/apps/readlink/out.c_test_app"), Application::MakeDir => String::from("tests/apps/mkdir/out.c_test_app"), + Application::StatfsFstatfs => String::from("tests/apps/statfs_fstatfs/out.c_test_app"), + Application::RemoveDir => String::from("tests/apps/rmdir/out.c_test_app"), Application::Realpath => String::from("tests/apps/realpath/out.c_test_app"), Application::NodeHTTP | Application::NodeIssue2283 | Application::NodeIssue2807 => { String::from("node") @@ -1057,6 +1124,8 @@ impl Application { | Application::Fork | Application::ReadLink | Application::MakeDir + | Application::StatfsFstatfs + | Application::RemoveDir | Application::Realpath | Application::RustFileOps | Application::RustIssue1123 @@ -1135,6 +1204,8 @@ impl Application { | Application::Fork | Application::ReadLink | Application::MakeDir + | Application::StatfsFstatfs + | Application::RemoveDir | Application::Realpath | Application::Go21Issue834 | Application::Go22Issue834 diff --git a/mirrord/layer/tests/dns_resolve.rs b/mirrord/layer/tests/dns_resolve.rs index 508f546ec03..9085b708d0a 100644 --- a/mirrord/layer/tests/dns_resolve.rs +++ 
b/mirrord/layer/tests/dns_resolve.rs @@ -8,7 +8,7 @@ use rstest::rstest; mod common; pub use common::*; use mirrord_protocol::{ - dns::{DnsLookup, GetAddrInfoRequest, GetAddrInfoResponse, LookupRecord}, + dns::{DnsLookup, GetAddrInfoRequestV2, GetAddrInfoResponse, LookupRecord}, ClientMessage, DaemonMessage, DnsLookupError, ResolveErrorKindInternal::NoRecordsFound, }; @@ -25,7 +25,7 @@ async fn test_dns_resolve( .await; let msg = intproxy.recv().await; - let ClientMessage::GetAddrInfoRequest(GetAddrInfoRequest { node }) = msg else { + let ClientMessage::GetAddrInfoRequestV2(GetAddrInfoRequestV2 { node, .. }) = msg else { panic!("Invalid message received from layer: {msg:?}"); }; @@ -39,7 +39,7 @@ async fn test_dns_resolve( .await; let msg = intproxy.recv().await; - let ClientMessage::GetAddrInfoRequest(GetAddrInfoRequest { node: _ }) = msg else { + let ClientMessage::GetAddrInfoRequestV2(GetAddrInfoRequestV2 { .. }) = msg else { panic!("Invalid message received from layer: {msg:?}"); }; diff --git a/mirrord/layer/tests/fileops.rs b/mirrord/layer/tests/fileops.rs index 5a9ee96f726..de26b318f40 100644 --- a/mirrord/layer/tests/fileops.rs +++ b/mirrord/layer/tests/fileops.rs @@ -44,7 +44,7 @@ async fn self_open( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let (mut test_process, mut intproxy) = application .start_process_with_layer(dylib_path, vec![], None) @@ -65,7 +65,7 @@ async fn self_open( #[tokio::test] #[timeout(Duration::from_secs(20))] async fn read_from_mirrord_bin(dylib_path: &Path) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let contents = "please don't flake"; let temp_dir = env::temp_dir(); @@ -108,7 +108,7 @@ async fn read_from_mirrord_bin(dylib_path: &Path) { #[tokio::test] #[timeout(Duration::from_secs(60))] async fn pwrite(#[values(Application::RustFileOps)] application: Application, dylib_path: &Path) { - let _tracing = 
init_tracing().unwrap(); + let _tracing = init_tracing(); // add rw override for the specific path let (mut test_process, mut intproxy) = application @@ -228,7 +228,7 @@ async fn node_close( #[values(Application::NodeFileOps)] application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let (mut test_process, mut intproxy) = application .start_process_with_layer( @@ -295,7 +295,7 @@ async fn go_stat( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); // add rw override for the specific path let (mut test_process, mut intproxy) = application @@ -345,6 +345,9 @@ async fn go_stat( )))) .await; + intproxy.expect_statfs("/tmp/test_file.txt").await; + intproxy.expect_fstatfs(fd).await; + test_process.wait_assert_success().await; test_process.assert_no_error_in_stderr().await; } @@ -358,7 +361,7 @@ async fn go_dir( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let (mut test_process, mut intproxy) = application .start_process_with_layer( @@ -478,7 +481,7 @@ async fn go_dir_on_linux( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let (mut test_process, mut intproxy) = application .start_process_with_layer( @@ -575,7 +578,7 @@ async fn go_dir_bypass( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let tmp_dir = temp_dir().join("go_dir_bypass_test"); std::fs::create_dir_all(tmp_dir.clone()).unwrap(); @@ -616,7 +619,7 @@ async fn read_go( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let (mut test_process, mut intproxy) = application .start_process_with_layer(dylib_path, vec![("MIRRORD_FILE_MODE", "read")], None) @@ -658,7 +661,7 @@ async 
fn write_go( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let (mut test_process, mut layer_connection) = application .start_process_with_layer(dylib_path, get_rw_test_file_env_vars(), None) @@ -687,7 +690,7 @@ async fn lseek_go( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let (mut test_process, mut intproxy) = application .start_process_with_layer(dylib_path, get_rw_test_file_env_vars(), None) @@ -722,7 +725,7 @@ async fn faccessat_go( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let (mut test_process, mut intproxy) = application .start_process_with_layer(dylib_path, get_rw_test_file_env_vars(), None) diff --git a/mirrord/layer/tests/http_mirroring.rs b/mirrord/layer/tests/http_mirroring.rs index e37d433a0e3..96f2230ed18 100644 --- a/mirrord/layer/tests/http_mirroring.rs +++ b/mirrord/layer/tests/http_mirroring.rs @@ -30,6 +30,8 @@ async fn mirroring_with_http( dylib_path: &Path, config_dir: &Path, ) { + let _guard = init_tracing(); + let (mut test_process, mut intproxy) = application .start_process_with_layer_and_port( dylib_path, diff --git a/mirrord/layer/tests/issue2055.rs b/mirrord/layer/tests/issue2055.rs index c34d5e13f25..d24b71fadce 100644 --- a/mirrord/layer/tests/issue2055.rs +++ b/mirrord/layer/tests/issue2055.rs @@ -2,7 +2,7 @@ use std::{net::IpAddr, path::Path, time::Duration}; use mirrord_protocol::{ - dns::{DnsLookup, GetAddrInfoRequest, GetAddrInfoResponse, LookupRecord}, + dns::{DnsLookup, GetAddrInfoRequestV2, GetAddrInfoResponse, LookupRecord}, ClientMessage, DaemonMessage, DnsLookupError, ResolveErrorKindInternal::NoRecordsFound, ResponseError, @@ -23,10 +23,10 @@ async fn issue_2055(dylib_path: &Path) { .start_process_with_layer(dylib_path, vec![("MIRRORD_REMOTE_DNS", "true")], None) .await; - 
println!("Application started, waiting for `GetAddrInfoRequest`."); + println!("Application started, waiting for `GetAddrInfoRequestV2`."); let msg = intproxy.recv().await; - let ClientMessage::GetAddrInfoRequest(GetAddrInfoRequest { node }) = msg else { + let ClientMessage::GetAddrInfoRequestV2(GetAddrInfoRequestV2 { node, .. }) = msg else { panic!("Invalid message received from layer: {msg:?}"); }; @@ -40,7 +40,7 @@ async fn issue_2055(dylib_path: &Path) { .await; let msg = intproxy.recv().await; - let ClientMessage::GetAddrInfoRequest(GetAddrInfoRequest { node: _ }) = msg else { + let ClientMessage::GetAddrInfoRequestV2(GetAddrInfoRequestV2 { .. }) = msg else { panic!("Invalid message received from layer: {msg:?}"); }; diff --git a/mirrord/layer/tests/issue2283.rs b/mirrord/layer/tests/issue2283.rs index f3d7a2a9b2b..6987bbdffbf 100644 --- a/mirrord/layer/tests/issue2283.rs +++ b/mirrord/layer/tests/issue2283.rs @@ -3,7 +3,7 @@ use std::{assert_matches::assert_matches, net::SocketAddr, path::Path, time::Duration}; use mirrord_protocol::{ - dns::{DnsLookup, GetAddrInfoRequest, GetAddrInfoResponse, LookupRecord}, + dns::{DnsLookup, GetAddrInfoRequestV2, GetAddrInfoResponse, LookupRecord}, outgoing::{ tcp::{DaemonTcpOutgoing, LayerTcpOutgoing}, DaemonConnect, DaemonRead, LayerConnect, SocketAddress, @@ -48,7 +48,7 @@ async fn test_issue2283( } let message = intproxy.recv().await; - assert_matches!(message, ClientMessage::GetAddrInfoRequest(GetAddrInfoRequest { node }) if node == "test-server"); + assert_matches!(message, ClientMessage::GetAddrInfoRequestV2(GetAddrInfoRequestV2 { node, .. 
}) if node == "test-server"); let address = "1.2.3.4:80".parse::<SocketAddr>().unwrap(); diff --git a/mirrord/layer/tests/mkdir.rs b/mirrord/layer/tests/mkdir.rs index 5f0fe3301b3..76fa3e8936b 100644 --- a/mirrord/layer/tests/mkdir.rs +++ b/mirrord/layer/tests/mkdir.rs @@ -17,10 +17,10 @@ async fn mkdir(dylib_path: &Path) { .start_process_with_layer(dylib_path, Default::default(), None) .await; - println!("waiting for file request."); + println!("waiting for MakeDirRequest."); intproxy.expect_make_dir("/mkdir_test_path", 0o777).await; - println!("waiting for file request."); + println!("waiting for MakeDirRequest."); intproxy.expect_make_dir("/mkdirat_test_path", 0o777).await; assert_eq!(intproxy.try_recv().await, None); diff --git a/mirrord/layer/tests/rmdir.rs b/mirrord/layer/tests/rmdir.rs new file mode 100644 index 00000000000..910e3d18ff9 --- /dev/null +++ b/mirrord/layer/tests/rmdir.rs @@ -0,0 +1,31 @@ +#![feature(assert_matches)] +use std::{path::Path, time::Duration}; + +use rstest::rstest; + +mod common; +pub use common::*; + +/// Test for the [`libc::rmdir`] function.
+#[rstest] +#[tokio::test] +#[timeout(Duration::from_secs(60))] +async fn rmdir(dylib_path: &Path) { + let application = Application::RemoveDir; + + let (mut test_process, mut intproxy) = application + .start_process_with_layer(dylib_path, Default::default(), None) + .await; + + println!("waiting for MakeDirRequest."); + intproxy.expect_make_dir("/test_dir", 0o777).await; + + println!("waiting for RemoveDirRequest."); + intproxy.expect_remove_dir("/test_dir").await; + + assert_eq!(intproxy.try_recv().await, None); + + test_process.wait_assert_success().await; + test_process.assert_no_error_in_stderr().await; + test_process.assert_no_error_in_stdout().await; +} diff --git a/mirrord/layer/tests/statfs_fstatfs.rs b/mirrord/layer/tests/statfs_fstatfs.rs new file mode 100644 index 00000000000..38f48c8495f --- /dev/null +++ b/mirrord/layer/tests/statfs_fstatfs.rs @@ -0,0 +1,45 @@ +#![feature(assert_matches)] +use std::{path::Path, time::Duration}; + +use rstest::rstest; + +mod common; +pub use common::*; + +/// Test for the [`libc::statfs`] and [`libc::fstatfs`] functions. 
+#[rstest] +#[tokio::test] +#[timeout(Duration::from_secs(60))] +async fn statfs_fstatfs(dylib_path: &Path) { + let application = Application::StatfsFstatfs; + + let (mut test_process, mut intproxy) = application + .start_process_with_layer(dylib_path, Default::default(), None) + .await; + + println!("waiting for file request (mkdir)."); + intproxy + .expect_make_dir("/statfs_fstatfs_test_path", 0o777) + .await; + + println!("waiting for file request (statfs)."); + intproxy.expect_statfs("/statfs_fstatfs_test_path").await; + + println!("waiting for file request (open)."); + let fd: u64 = 1; + intproxy + .expect_file_open_for_reading("/statfs_fstatfs_test_path", fd) + .await; + + println!("waiting for file request (fstatfs)."); + intproxy.expect_fstatfs(fd).await; + + println!("waiting for file request (close)."); + intproxy.expect_file_close(fd).await; + + assert_eq!(intproxy.try_recv().await, None); + + test_process.wait_assert_success().await; + test_process.assert_no_error_in_stderr().await; + test_process.assert_no_error_in_stdout().await; +} diff --git a/mirrord/operator/src/crd/policy.rs b/mirrord/operator/src/crd/policy.rs index 1ad9447d1e8..cf712606d3c 100644 --- a/mirrord/operator/src/crd/policy.rs +++ b/mirrord/operator/src/crd/policy.rs @@ -58,6 +58,11 @@ pub struct MirrordPolicySpec { /// target. #[serde(default)] pub env: EnvPolicy, + + /// Overrides fs ops behaviour, granting control over them to the operator policy, instead of + /// the user config. + #[serde(default)] + pub fs: FsPolicy, } /// Custom cluster-wide resource for policies that limit what mirrord features users can use. @@ -90,11 +95,16 @@ pub struct MirrordClusterPolicySpec { /// target. #[serde(default)] pub env: EnvPolicy, + + /// Overrides fs ops behaviour, granting control over them to the operator policy, instead of + /// the user config. + #[serde(default)] + pub fs: FsPolicy, } /// Policy for controlling environment variable access from mirrord instances.
#[derive(Clone, Default, Debug, Deserialize, Eq, PartialEq, Serialize, JsonSchema)] -#[serde(rename_all = "kebab-case")] +#[serde(rename_all = "camelCase")] pub struct EnvPolicy { /// List of environment variables that should be excluded when using mirrord. /// @@ -104,9 +114,42 @@ pub struct EnvPolicy { /// Variable names can be matched using `*` and `?` where `?` matches exactly one occurrence of /// any character and `*` matches arbitrary many (including zero) occurrences of any character, /// e.g. `DATABASE_*` will match `DATABASE_URL` and `DATABASE_PORT`. + #[serde(default)] pub exclude: HashSet<String>, } +/// File operations policy that mimics the mirrord fs config. +/// +/// Allows the operator control over remote file ops behaviour, overriding what the user has set in +/// their mirrord config file, if it matches something in one of the lists (regex sets) of this +/// struct. +/// +/// If the file path matches regexes in multiple sets, priority is as follows: +/// 1. `local` +/// 2. `notFound` +/// 3. `readOnly` +#[derive(Clone, Default, Debug, Deserialize, Eq, PartialEq, Serialize, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub struct FsPolicy { + /// Files that cannot be opened for writing. + /// + /// Opening the file for writing is rejected with an IO error. + #[serde(default)] + pub read_only: HashSet<String>, + + /// Files that cannot be opened remotely. + /// + /// Opening the file will be rejected and mirrord will open the file locally instead. + #[serde(default)] + pub local: HashSet<String>, + + /// Files that cannot be opened at all. + /// + /// Opening the file is rejected with an IO error.
+ #[serde(default)] + pub not_found: HashSet<String>, +} + #[test] fn check_one_api_group() { use kube::Resource; diff --git a/mirrord/protocol/Cargo.toml b/mirrord/protocol/Cargo.toml index 70f33186ba1..67eab572e62 100644 --- a/mirrord/protocol/Cargo.toml +++ b/mirrord/protocol/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mirrord-protocol" -version = "1.13.1" +version = "1.16.0" authors.workspace = true description.workspace = true documentation.workspace = true @@ -18,12 +18,13 @@ workspace = true [dependencies] actix-codec.workspace = true +bincode.workspace = true bytes.workspace = true thiserror.workspace = true +futures.workspace = true hickory-resolver.workspace = true hickory-proto.workspace = true serde.workspace = true -bincode.workspace = true tracing.workspace = true hyper = { workspace = true, features = ["client"] } http-serde = "2" @@ -31,8 +32,6 @@ http-body-util = { workspace = true } fancy-regex = { workspace = true } socket2.workspace = true semver = { workspace = true, features = ["serde"] } -tokio-stream.workspace = true -tokio.workspace = true mirrord-macros = { path = "../macros" } diff --git a/mirrord/protocol/src/batched_body.rs b/mirrord/protocol/src/batched_body.rs new file mode 100644 index 00000000000..9f5780cf495 --- /dev/null +++ b/mirrord/protocol/src/batched_body.rs @@ -0,0 +1,86 @@ +use std::future::Future; + +use futures::FutureExt; +use http_body_util::BodyExt; +use hyper::body::{Body, Frame}; + +/// Utility extension trait for [`Body`]. +/// +/// Contains methods that allow for reading [`Frame`]s in batches. +pub trait BatchedBody: Body { + /// Reads all [`Frame`]s that are available without blocking. + fn ready_frames(&mut self) -> Result<Frames<Self::Data>, Self::Error>; + + /// Waits for the next [`Frame`] then reads all [`Frame`]s that are available without blocking.
+ fn next_frames(&mut self) -> impl Future<Output = Result<Frames<Self::Data>, Self::Error>>; +} + +impl<B> BatchedBody for B +where + B: Body + Unpin, +{ + fn ready_frames(&mut self) -> Result<Frames<Self::Data>, Self::Error> { + let mut frames = Frames { + frames: vec![], + is_last: false, + }; + extend_with_ready(self, &mut frames)?; + Ok(frames) + } + + async fn next_frames(&mut self) -> Result<Frames<Self::Data>, Self::Error> { + let mut frames = Frames { + frames: vec![], + is_last: false, + }; + + match self.frame().await { + None => { + frames.is_last = true; + return Ok(frames); + } + Some(result) => { + frames.frames.push(result?); + } + } + + extend_with_ready(self, &mut frames)?; + + Ok(frames) + } +} + +/// Extends the given [`Frames`] instance with [`Frame`]s that are available without blocking. +fn extend_with_ready<B: Body + Unpin>( + body: &mut B, + frames: &mut Frames<B::Data>, +) -> Result<(), B::Error> { + loop { + match body.frame().now_or_never() { + None => { + frames.is_last = false; + break; + } + Some(None) => { + frames.is_last = true; + break; + } + Some(Some(result)) => { + frames.frames.push(result?); + frames.is_last = false; + } + } + } + + Ok(()) +} + +/// A batch of body [`Frame`]s. +/// +/// `D` parameter determines [`Body::Data`] type. +pub struct Frames<D> { + /// A batch of consecutive [`Frame`]s. + pub frames: Vec<Frame<D>>, + /// Whether the [`Body`] has finished and this is the last batch.
+ pub is_last: bool, +} diff --git a/mirrord/protocol/src/body_chunks.rs b/mirrord/protocol/src/body_chunks.rs deleted file mode 100644 index 19a42c78ae4..00000000000 --- a/mirrord/protocol/src/body_chunks.rs +++ /dev/null @@ -1,71 +0,0 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -use bytes::Bytes; -use hyper::body::{Body, Frame}; - -pub trait BodyExt<B> { - fn next_frames(&mut self, no_wait: bool) -> FramesFut<'_, B>; -} - -impl<B> BodyExt<B> for B -where - B: Body, -{ - fn next_frames(&mut self, no_wait: bool) -> FramesFut<'_, B> { - FramesFut { - body: self, - no_wait, - } - } -} - -pub struct FramesFut<'a, B> { - body: &'a mut B, - no_wait: bool, -} - -impl<B> Future for FramesFut<'_, B> -where - B: Body<Data = Bytes> + Unpin, -{ - type Output = hyper::Result<Frames>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let mut frames = vec![]; - - loop { - let result = match Pin::new(&mut self.as_mut().body).poll_frame(cx) { - Poll::Ready(Some(Err(error))) => Poll::Ready(Err(error)), - Poll::Ready(Some(Ok(frame))) => { - frames.push(frame); - continue; - } - Poll::Ready(None) => Poll::Ready(Ok(Frames { - frames, - is_last: true, - })), - Poll::Pending => { - if frames.is_empty() && !self.no_wait { - Poll::Pending - } else { - Poll::Ready(Ok(Frames { - frames, - is_last: false, - })) - } - } - }; - - break result; - } - } -} - -pub struct Frames { - pub frames: Vec<Frame<Bytes>>, - pub is_last: bool, -} diff --git a/mirrord/protocol/src/codec.rs b/mirrord/protocol/src/codec.rs index 4071018fe6a..8e41d9acab1 100644 --- a/mirrord/protocol/src/codec.rs +++ b/mirrord/protocol/src/codec.rs @@ -12,7 +12,7 @@ use mirrord_macros::protocol_break; use semver::VersionReq; use crate::{ - dns::{GetAddrInfoRequest, GetAddrInfoResponse}, + dns::{GetAddrInfoRequest, GetAddrInfoRequestV2, GetAddrInfoResponse}, file::*, outgoing::{ tcp::{DaemonTcpOutgoing, LayerTcpOutgoing}, @@ -24,10 +24,16 @@ use crate::{ ResponseError, }; +/// Minimal mirrord-protocol version that
allows [`LogLevel::Info`]. +pub static INFO_LOG_VERSION: LazyLock<VersionReq> = + LazyLock::new(|| ">=1.13.4".parse().expect("Bad Identifier")); + #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone, Copy)] pub enum LogLevel { Warn, Error, + /// Supported from [`INFO_LOG_VERSION`]. + Info, } #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] @@ -85,6 +91,10 @@ pub enum FileRequest { ReadDirBatch(ReadDirBatchRequest), MakeDir(MakeDirRequest), MakeDirAt(MakeDirAtRequest), + RemoveDir(RemoveDirRequest), + Unlink(UnlinkRequest), + UnlinkAt(UnlinkAtRequest), + StatFs(StatFsRequest), } /// Minimal mirrord-protocol version that allows `ClientMessage::ReadyForLogs` message. @@ -95,9 +105,27 @@ pub static CLIENT_READY_FOR_LOGS: LazyLock<VersionReq> = #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub enum ClientMessage { Close, + /// TCP sniffer message. + /// + /// These are the messages used by the `mirror` feature, and handled by the + /// `TcpSnifferApi` in the agent. Tcp(LayerTcp), + + /// TCP stealer message. + /// + /// These are the messages used by the `steal` feature, and handled by the `TcpStealerApi` in + /// the agent. TcpSteal(LayerTcpSteal), + /// TCP outgoing message. + /// + /// These are the messages used by the `outgoing` feature (tcp), and handled by the + /// `TcpOutgoingApi` in the agent. TcpOutgoing(LayerTcpOutgoing), + + /// UDP outgoing message. + /// + /// These are the messages used by the `outgoing` feature (udp), and handled by the + /// `UdpOutgoingApi` in the agent. UdpOutgoing(LayerUdpOutgoing), FileRequest(FileRequest), GetEnvVarsRequest(GetEnvVarsRequest), @@ -108,6 +136,7 @@ pub enum ClientMessage { SwitchProtocolVersion(#[bincode(with_serde)] semver::Version), ReadyForLogs, Vpn(ClientVpn), + GetAddrInfoRequestV2(GetAddrInfoRequestV2), } /// Type alias for `Result`s that should be returned from mirrord-agent to mirrord-layer.
@@ -130,6 +159,8 @@ pub enum FileResponse { ReadLink(RemoteResult), ReadDirBatch(RemoteResult), MakeDir(RemoteResult<()>), + RemoveDir(RemoteResult<()>), + Unlink(RemoteResult<()>), } /// `-agent` --> `-layer` messages. diff --git a/mirrord/protocol/src/dns.rs b/mirrord/protocol/src/dns.rs index 855f52b8af5..5958376cd94 100644 --- a/mirrord/protocol/src/dns.rs +++ b/mirrord/protocol/src/dns.rs @@ -1,12 +1,17 @@ extern crate alloc; use core::ops::Deref; -use std::net::IpAddr; +use std::{net::IpAddr, sync::LazyLock}; use bincode::{Decode, Encode}; use hickory_resolver::{lookup_ip::LookupIp, proto::rr::resource::RecordParts}; +use semver::VersionReq; use crate::RemoteResult; +/// Minimal mirrord-protocol version that allows [`GetAddrInfoRequestV2`]. +pub static ADDRINFO_V2_VERSION: LazyLock<VersionReq> = + LazyLock::new(|| ">=1.15.0".parse().expect("Bad Identifier")); + #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub struct LookupRecord { pub name: String, @@ -73,3 +78,89 @@ impl Deref for GetAddrInfoResponse { pub struct GetAddrInfoRequest { pub node: String, } + +/// For when the new request is not supported, and we have to fall back to the old version. +impl From<GetAddrInfoRequestV2> for GetAddrInfoRequest { + fn from(value: GetAddrInfoRequestV2) -> Self { + Self { node: value.node } + } +} + +#[derive( + serde::Serialize, serde::Deserialize, Encode, Decode, Debug, PartialEq, Eq, Copy, Clone, +)] +pub enum AddressFamily { + Ipv4Only, + Ipv6Only, + Both, + Any, + /// If we add a variant and a new client sends an old agent the new variant, the agent will see + /// this variant. + #[serde(other, skip_serializing)] + UnknownAddressFamilyFromNewerClient, +} + +#[derive(thiserror::Error, Debug)] +pub enum AddressFamilyError { + #[error( + "The agent received a GetAddrInfoRequestV2 with an address family that is not yet known \ + to this version of the agent."
+ )] + UnsupportedFamily, +} + +impl TryFrom<AddressFamily> for hickory_resolver::config::LookupIpStrategy { + type Error = AddressFamilyError; + + fn try_from(value: AddressFamily) -> Result<Self, Self::Error> { + match value { + AddressFamily::Ipv4Only => Ok(Self::Ipv4Only), + AddressFamily::Ipv6Only => Ok(Self::Ipv6Only), + AddressFamily::Both => Ok(Self::Ipv4AndIpv6), + AddressFamily::Any => Ok(Self::Ipv4thenIpv6), + AddressFamily::UnknownAddressFamilyFromNewerClient => { + Err(AddressFamilyError::UnsupportedFamily) + } + } + } +} + +#[derive(serde::Serialize, serde::Deserialize, Encode, Decode, Debug, PartialEq, Eq, Clone)] +pub enum SockType { + Stream, + Dgram, + Any, + /// If we add a variant and a new client sends an old agent the new variant, the agent will see + /// this variant. + #[serde(other, skip_serializing)] + UnknownSockTypeFromNewerClient, +} + +/// Newer, advanced version of [`GetAddrInfoRequest`]. +#[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] +pub struct GetAddrInfoRequestV2 { + pub node: String, + /// Currently not respected by the agent, there for future use. + pub service_port: u16, + pub family: AddressFamily, + pub socktype: SockType, + /// Including these fields so we can use them in the future without introducing a new request + /// type. But note that the constants are different on macOS and Linux so they should be + /// converted to the Linux values first (on the client, because the agent does not know the + /// client is macOS).
+ pub flags: i32, + pub protocol: i32, +} + +impl From<GetAddrInfoRequest> for GetAddrInfoRequestV2 { + fn from(value: GetAddrInfoRequest) -> Self { + Self { + node: value.node, + service_port: 0, + flags: 0, + family: AddressFamily::Ipv4Only, + socktype: SockType::Any, + protocol: 0, + } + } +} diff --git a/mirrord/protocol/src/error.rs b/mirrord/protocol/src/error.rs index 20ac38a149d..67197d76843 100644 --- a/mirrord/protocol/src/error.rs +++ b/mirrord/protocol/src/error.rs @@ -44,7 +44,7 @@ pub enum ResponseError { #[error("Remote operation expected fd `{0}` to be a file, but it's a directory!")] NotFile(u64), - #[error("IO failed for remote operation with `{0}!")] + #[error("IO failed for remote operation: `{0}!")] RemoteIO(RemoteIOError), #[error(transparent)] @@ -67,6 +67,9 @@ pub enum ResponseError { #[error("Failed stripping path with `{0}`!")] StripPrefix(String), + + #[error("File has to be opened locally!")] + OpenLocal, } impl From for ResponseError { @@ -153,13 +156,26 @@ impl From for RemoteError { /// Our internal version of Rust's `std::io::Error` that can be passed between mirrord-layer and /// mirrord-agent. -#[derive(Encode, Decode, Debug, PartialEq, Clone, Eq, Error)] -#[error("Failed performing `getaddrinfo` with {raw_os_error:?} and kind {kind:?}!")] +/// +/// ### `Display` +/// +/// We manually implement `Display` as this error is mostly seen from a [`ResponseError`] context. +#[derive(Encode, Decode, Debug, PartialEq, Clone, Eq)] pub struct RemoteIOError { + pub raw_os_error: Option<i32>, + pub kind: ErrorKindInternal, +} + +impl core::fmt::Display for RemoteIOError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self.kind)?; + if let Some(code) = self.raw_os_error { + write!(f, " (error code {code})")?; + } + Ok(()) + } +} + /// Our internal version of Rust's `std::io::Error` that can be passed between mirrord-layer and /// mirrord-agent.
/// diff --git a/mirrord/protocol/src/file.rs b/mirrord/protocol/src/file.rs index b2e8e3773f8..4aa25069bb3 100644 --- a/mirrord/protocol/src/file.rs +++ b/mirrord/protocol/src/file.rs @@ -22,9 +22,21 @@ pub static READLINK_VERSION: LazyLock<VersionReq> = pub static READDIR_BATCH_VERSION: LazyLock<VersionReq> = LazyLock::new(|| ">=1.9.0".parse().expect("Bad Identifier")); +/// Minimal mirrord-protocol version that allows [`MakeDirRequest`] and [`MakeDirAtRequest`]. pub static MKDIR_VERSION: LazyLock<VersionReq> = LazyLock::new(|| ">=1.13.0".parse().expect("Bad Identifier")); +/// Minimal mirrord-protocol version that allows [`RemoveDirRequest`], [`UnlinkRequest`] and +/// [`UnlinkAtRequest`]. +pub static RMDIR_VERSION: LazyLock<VersionReq> = + LazyLock::new(|| ">=1.14.0".parse().expect("Bad Identifier")); + +pub static OPEN_LOCAL_VERSION: LazyLock<VersionReq> = + LazyLock::new(|| ">=1.13.3".parse().expect("Bad Identifier")); + +pub static STATFS_VERSION: LazyLock<VersionReq> = + LazyLock::new(|| ">=1.16.0".parse().expect("Bad Identifier")); + /// Internal version of Metadata across operating system (macOS, Linux) /// Only mutual attributes #[derive(Encode, Decode, Debug, PartialEq, Clone, Copy, Eq, Default)] @@ -282,6 +294,23 @@ pub struct MakeDirAtRequest { pub mode: u32, } +#[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] +pub struct RemoveDirRequest { + pub pathname: PathBuf, +} + +#[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] +pub struct UnlinkRequest { + pub pathname: PathBuf, +} + +#[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] +pub struct UnlinkAtRequest { + pub dirfd: Option<u64>, + pub pathname: PathBuf, + pub flags: u32, +} + #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub struct ReadLimitedFileRequest { pub remote_fd: u64, @@ -387,6 +416,11 @@ pub struct XstatFsRequest { pub fd: u64, } +#[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] +pub struct StatFsRequest { + pub path: PathBuf, +} + #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub struct XstatResponse { pub metadata:
MetadataInternal, diff --git a/mirrord/protocol/src/lib.rs b/mirrord/protocol/src/lib.rs index ed719c5a9a4..f1a3cc1e5cc 100644 --- a/mirrord/protocol/src/lib.rs +++ b/mirrord/protocol/src/lib.rs @@ -3,7 +3,7 @@ #![warn(clippy::indexing_slicing)] #![deny(unused_crate_dependencies)] -pub mod body_chunks; +pub mod batched_body; pub mod codec; pub mod dns; pub mod error; @@ -111,3 +111,7 @@ impl FromStr for MeshVendor { pub const AGENT_OPERATOR_CERT_ENV: &str = "MIRRORD_AGENT_OPERATOR_CERT"; pub const AGENT_NETWORK_INTERFACE_ENV: &str = "MIRRORD_AGENT_INTERFACE"; + +pub const AGENT_METRICS_ENV: &str = "MIRRORD_AGENT_METRICS"; + +pub const AGENT_IPV6_ENV: &str = "MIRRORD_AGENT_SUPPORT_IPV6"; diff --git a/mirrord/protocol/src/outgoing/tcp.rs b/mirrord/protocol/src/outgoing/tcp.rs index e38fa0c44d0..877e0d2f6c0 100644 --- a/mirrord/protocol/src/outgoing/tcp.rs +++ b/mirrord/protocol/src/outgoing/tcp.rs @@ -3,14 +3,43 @@ use crate::RemoteResult; #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub enum LayerTcpOutgoing { + /// User is interested in connecting via tcp to some remote address, specified in + /// [`LayerConnect`]. + /// + /// The layer will get a mirrord managed address that it'll `connect` to, meanwhile + /// in the agent we `connect` to the actual remote address. Connect(LayerConnect), + + /// Write data to the remote address the agent is `connect`ed to. + /// + /// There's no `Read` message, as we're calling `read` in the agent, and we send + /// a [`DaemonTcpOutgoing::Read`] message in case we get some data from this connection. Write(LayerWrite), + + /// The layer closed the connection, this message syncs up the agent, closing it + /// over there as well. + /// + /// Connections in the agent may be closed in other ways, such as when an error happens + /// when reading or writing. Which means that this message is not the only way of + /// closing outgoing tcp connections. 
     Close(LayerClose),
 }
 
 #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)]
 pub enum DaemonTcpOutgoing {
+    /// The agent attempted a connection to the remote address specified by
+    /// [`LayerTcpOutgoing::Connect`], and it might've been successful or not.
     Connect(RemoteResult<DaemonConnect>),
+
+    /// Read data from the connection.
+    ///
+    /// There's no `Write` message, as `write`s come from the user (layer). The agent sending
+    /// a `write` to the layer like this would make no sense, since it could just `write` it
+    /// to the remote connection itself.
     Read(RemoteResult<DaemonRead>),
+
+    /// Tell the layer that this connection has been `close`d, either by a request from
+    /// the user with [`LayerTcpOutgoing::Close`], or from some error in the agent when
+    /// writing or reading from the connection.
     Close(ConnectionId),
 }
diff --git a/mirrord/protocol/src/outgoing/udp.rs b/mirrord/protocol/src/outgoing/udp.rs
index 02b4d97f830..f58378beeea 100644
--- a/mirrord/protocol/src/outgoing/udp.rs
+++ b/mirrord/protocol/src/outgoing/udp.rs
@@ -3,14 +3,50 @@ use crate::RemoteResult;
 
 #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)]
 pub enum LayerUdpOutgoing {
+    /// User is interested in connecting via udp to some remote address, specified in
+    /// [`LayerConnect`].
+    ///
+    /// The layer will get a mirrord managed address that it'll `connect` to, meanwhile
+    /// in the agent we `connect` to the actual remote address.
+    ///
+    /// Saying that we have an _udp connection_ is a bit weird, considering it's a
+    /// _connectionless_ protocol, but in mirrord we use a _fakeish_ connection mechanism
+    /// when dealing with outgoing udp traffic.
     Connect(LayerConnect),
+
+    /// Write data to the remote address the agent is `connect`ed to.
+    ///
+    /// There's no `Read` message, as we're calling `read` in the agent, and we send
+    /// a [`DaemonUdpOutgoing::Read`] message in case we get some data from this connection.
     Write(LayerWrite),
+
+    /// The layer closed the connection, this message syncs up the agent, closing it
+    /// over there as well.
+    ///
+    /// Connections in the agent may be closed in other ways, such as when an error happens
+    /// when reading or writing. Which means that this message is not the only way of
+    /// closing outgoing udp connections.
     Close(LayerClose),
 }
 
 #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)]
 pub enum DaemonUdpOutgoing {
+    /// The agent attempted a connection to the remote address specified by
+    /// [`LayerUdpOutgoing::Connect`], and it might've been successful or not.
+    ///
+    /// See the docs for [`LayerUdpOutgoing::Connect`] for a bit more information on the
+    /// weird idea of `connect` and udp in mirrord.
     Connect(RemoteResult<DaemonConnect>),
+
+    /// Read data from the connection.
+    ///
+    /// There's no `Write` message, as `write`s come from the user (layer). The agent sending
+    /// a `write` to the layer like this would make no sense, since it could just `write` it
+    /// to the remote connection itself.
     Read(RemoteResult<DaemonRead>),
+
+    /// Tell the layer that this connection has been `close`d, either by a request from
+    /// the user with [`LayerUdpOutgoing::Close`], or from some error in the agent when
+    /// writing or reading from the connection.
     Close(ConnectionId),
 }
diff --git a/mirrord/protocol/src/tcp.rs b/mirrord/protocol/src/tcp.rs
index 023369129ad..e98077a62ec 100644
--- a/mirrord/protocol/src/tcp.rs
+++ b/mirrord/protocol/src/tcp.rs
@@ -5,27 +5,22 @@ use std::{
     fmt,
     net::IpAddr,
     pin::Pin,
-    sync::{Arc, LazyLock, Mutex},
+    sync::LazyLock,
     task::{Context, Poll},
 };
 
 use bincode::{Decode, Encode};
 use bytes::Bytes;
-use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody};
+use http_body_util::BodyExt;
 use hyper::{
-    body::{Body, Frame, Incoming},
-    http,
-    http::response::Parts,
+    body::{Body, Frame},
     HeaderMap, Method, Request, Response, StatusCode, Uri, Version,
 };
 use mirrord_macros::protocol_break;
 use semver::VersionReq;
 use serde::{Deserialize, Serialize};
-use tokio::sync::mpsc::Receiver;
-use tokio_stream::wrappers::ReceiverStream;
-use tracing::{error, Level};
 
-use crate::{body_chunks::BodyExt as _, ConnectionId, Port, RemoteResult, RequestId};
+use crate::{ConnectionId, Port, RemoteResult, RequestId};
 
 #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)]
 pub struct NewTcpConnection {
@@ -57,14 +52,31 @@ pub struct TcpClose {
 }
 
 /// Messages related to Tcp handler from client.
+///
+/// Part of the `mirror` feature.
 #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)]
 pub enum LayerTcp {
+    /// User is interested in mirroring traffic on this `Port`, so add it to the list of
+    /// ports that the sniffer is filtering.
     PortSubscribe(Port),
+
+    /// User is not interested in the connection with `ConnectionId` anymore.
+    ///
+    /// This means that their app has closed the connection they were `listen`ning on.
+    ///
+    /// There is no `ConnectionSubscribe` counter-part of this variant, the subscription
+    /// happens when the sniffer receives an (agent) internal `SniffedConnection`.
     ConnectionUnsubscribe(ConnectionId),
+
+    /// Removes this `Port` from the sniffer's filter, the traffic won't be cloned to mirrord
+    /// anymore.
     PortUnsubscribe(Port),
 }
 
 /// Messages related to Tcp handler from server.
+///
+/// They are the same for both `steal` and `mirror` modes, even though their layer
+/// counterparts ([`LayerTcpSteal`] and [`LayerTcp`]) are different.
 #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)]
 pub enum DaemonTcp {
     NewConnection(NewTcpConnection),
@@ -120,7 +132,7 @@ pub struct Filter(String);
 impl Filter {
     pub fn new(filter_str: String) -> Result<Self, Box<fancy_regex::Error>> {
         let _ = fancy_regex::Regex::new(&filter_str).inspect_err(|fail| {
-            error!(
+            tracing::error!(
                 r"
                 Something went wrong while creating a regex for [{filter_str:#?}]!
@@ -219,10 +231,38 @@ impl StealType {
 }
 
 /// Messages related to Steal Tcp handler from client.
+///
+/// `PortSubscribe`, `PortUnsubscribe`, and `ConnectionUnsubscribe` variants are similar
+/// to what you'll find in the [`LayerTcp`], but they're handled by different tasks in
+/// the agent.
+///
+/// Stolen traffic might have an additional overhead when compared to mirrored traffic, as
+/// we have an intermediate HTTP server to handle filtering (based on HTTP headers, etc).
 #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)]
 pub enum LayerTcpSteal {
+    /// User is interested in stealing traffic on this `Port`, so add it to the list of
+    /// ports that the stealer is filtering.
+    ///
+    /// The `TcpConnectionStealer` supports an [`HttpFilter`] granting the ability to steal
+    /// only traffic that matches the user configured filter. It's also possible to just steal
+    /// all traffic (which we refer as `Unfiltered`). For more info see [`StealType`].
+    ///
+    /// This variant is somewhat related to [`LayerTcpSteal::ConnectionUnsubscribe`], since
+    /// we don't have a `ConnectionSubscribe` message anywhere, instead what we do is: when
+    /// a new connection comes in one of the ports we are subscribed to, we consider it a
+    /// connection subscription (so this mechanism represents the **non-existing**
+    /// `ConnectionSubscribe` variant).
     PortSubscribe(StealType),
+
+    /// User has stopped stealing from this connection with [`ConnectionId`].
+    ///
+    /// We do **not** have a `ConnectionSubscribe` variant/message. What happens instead is that we
+    /// call a _connection subscription_ the act of `accept`ing a new connection on one of the
+    /// ports we are subscribed to. See the [`LayerTcpSteal::PortSubscribe`] for more info.
     ConnectionUnsubscribe(ConnectionId),
+
+    /// Removes this `Port` from the stealer's filter, the traffic won't be stolen by mirrord
+    /// anymore.
     PortUnsubscribe(Port),
     Data(TcpData),
     HttpResponse(HttpResponse<Vec<u8>>),
@@ -239,7 +279,7 @@ pub enum ChunkedResponse {
 
 /// (De-)Serializable HTTP request.
 #[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Clone)]
-pub struct InternalHttpRequest<Body> {
+pub struct InternalHttpRequest<B> {
     #[serde(with = "http_serde::method")]
     pub method: Method,
 
@@ -252,60 +292,34 @@ pub struct InternalHttpRequest<B> {
     #[serde(with = "http_serde::version")]
     pub version: Version,
 
-    pub body: Body,
+    pub body: B,
 }
 
-impl<E> From<InternalHttpRequest<InternalHttpBody>> for Request<BoxBody<Bytes, E>>
-where
-    E: From<hyper::Error>,
-{
-    fn from(value: InternalHttpRequest<InternalHttpBody>) -> Self {
+impl<B> InternalHttpRequest<B> {
+    pub fn map_body<T, F>(self, cb: F) -> InternalHttpRequest<T>
+    where
+        F: FnOnce(B) -> T,
+    {
         let InternalHttpRequest {
+            version,
+            headers,
             method,
             uri,
-            headers,
-            version,
             body,
-        } = value;
-        let mut request = Request::new(BoxBody::new(body.map_err(|e| e.into())));
-        *request.method_mut() = method;
-        *request.uri_mut() = uri;
-        *request.version_mut() = version;
-        *request.headers_mut() = headers;
-
-        request
-    }
-}
+        } = self;
 
-impl<E> From<InternalHttpRequest<Vec<u8>>> for Request<BoxBody<Bytes, E>>
-where
-    E: From<hyper::Error>,
-{
-    fn from(value: InternalHttpRequest<Vec<u8>>) -> Self {
-        let InternalHttpRequest {
+        InternalHttpRequest {
+            version,
+            headers,
             method,
             uri,
-            headers,
-            version,
-            body,
-        } = value;
-        let mut request = Request::new(BoxBody::new(
-            Full::new(Bytes::from(body)).map_err(|e| e.into()),
-        ));
-        *request.method_mut() = method;
-        *request.uri_mut() = uri;
-        *request.version_mut() = version;
-        *request.headers_mut() = headers;
-
-        request
+            body: cb(body),
+        }
     }
 }
 
-impl<E> From<InternalHttpRequest<StreamingBody>> for Request<BoxBody<Bytes, E>>
-where
-    E: From<hyper::Error>,
-{
-    fn from(value: InternalHttpRequest<StreamingBody>) -> Self {
+impl<B> From<InternalHttpRequest<B>> for Request<B> {
+    fn from(value: InternalHttpRequest<B>) -> Self {
         let InternalHttpRequest {
             method,
             uri,
@@ -313,7 +327,8 @@ where
             version,
             body,
         } = value;
-        let mut request = Request::new(BoxBody::new(body.map_err(|e| e.into())));
+
+        let mut request = Request::new(body);
         *request.method_mut() = method;
         *request.uri_mut() = uri;
         *request.version_mut() = version;
@@ -323,118 +338,6 @@ where
     }
 }
 
-#[derive(Clone, Debug)]
-pub enum HttpRequestFallback {
-    Framed(HttpRequest<InternalHttpBody>),
-    Fallback(HttpRequest<Vec<u8>>),
-    Streamed {
-        request: HttpRequest<StreamingBody>,
-        retries: u32,
-    },
-}
-
-#[derive(Debug)]
-pub struct StreamingBody {
-    /// Shared with instances acquired via [`Clone`].
-    /// Allows the clones to receive a copy of the data.
-    origin: Arc<Mutex<(Receiver<Frame<Bytes>>, Vec<Frame<Bytes>>)>>,
-    /// Index of the next frame to return from the buffer.
-    /// If outside of the buffer, we need to poll the stream to get the next frame.
-    /// Local state of this instance, zeroed when cloning.
-    idx: usize,
-}
-
-impl StreamingBody {
-    pub fn new(rx: Receiver<Frame<Bytes>>) -> Self {
-        Self {
-            origin: Arc::new(Mutex::new((rx, vec![]))),
-            idx: 0,
-        }
-    }
-}
-
-impl Clone for StreamingBody {
-    fn clone(&self) -> Self {
-        Self {
-            origin: self.origin.clone(),
-            idx: 0,
-        }
-    }
-}
-
-impl Body for StreamingBody {
-    type Data = Bytes;
-
-    type Error = Infallible;
-
-    fn poll_frame(
-        self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
-        let this = self.get_mut();
-        let mut guard = this.origin.lock().unwrap();
-
-        if let Some(frame) = guard.1.get(this.idx) {
-            this.idx += 1;
-            return Poll::Ready(Some(Ok(frame.clone().into())));
-        }
-
-        match std::task::ready!(guard.0.poll_recv(cx)) {
-            None => Poll::Ready(None),
-            Some(frame) => {
-                guard.1.push(frame.clone());
-                this.idx += 1;
-                Poll::Ready(Some(Ok(frame.into())))
-            }
-        }
-    }
-}
-
-impl HttpRequestFallback {
-    pub fn connection_id(&self) -> ConnectionId {
-        match self {
-            HttpRequestFallback::Framed(req) => req.connection_id,
-            HttpRequestFallback::Fallback(req) => req.connection_id,
-            HttpRequestFallback::Streamed { request: req, .. } => req.connection_id,
-        }
-    }
-
-    pub fn port(&self) -> Port {
-        match self {
-            HttpRequestFallback::Framed(req) => req.port,
-            HttpRequestFallback::Fallback(req) => req.port,
-            HttpRequestFallback::Streamed { request: req, .. } => req.port,
-        }
-    }
-
-    pub fn request_id(&self) -> RequestId {
-        match self {
-            HttpRequestFallback::Framed(req) => req.request_id,
-            HttpRequestFallback::Fallback(req) => req.request_id,
-            HttpRequestFallback::Streamed { request: req, .. } => req.request_id,
-        }
-    }
-
-    pub fn version(&self) -> Version {
-        match self {
-            HttpRequestFallback::Framed(req) => req.version(),
-            HttpRequestFallback::Fallback(req) => req.version(),
-            HttpRequestFallback::Streamed { request: req, .. } => req.version(),
-        }
-    }
-
-    pub fn into_hyper<E>(self) -> Request<BoxBody<Bytes, E>>
-    where
-        E: From<hyper::Error>,
-    {
-        match self {
-            HttpRequestFallback::Framed(req) => req.internal_request.into(),
-            HttpRequestFallback::Fallback(req) => req.internal_request.into(),
-            HttpRequestFallback::Streamed { request: req, .. } => req.internal_request.into(),
-        }
-    }
-}
-
 /// Minimal mirrord-protocol version that allows [`DaemonTcp::HttpRequestFramed`] and
 /// [`LayerTcpSteal::HttpResponseFramed`].
 pub static HTTP_FRAMED_VERSION: LazyLock<VersionReq> =
@@ -478,6 +381,24 @@ impl<B> HttpRequest<B> {
     pub fn version(&self) -> Version {
         self.internal_request.version
     }
+
+    pub fn map_body<T, F>(self, map: F) -> HttpRequest<T>
+    where
+        F: FnOnce(B) -> T,
+    {
+        HttpRequest {
+            connection_id: self.connection_id,
+            request_id: self.request_id,
+            port: self.port,
+            internal_request: InternalHttpRequest {
+                method: self.internal_request.method,
+                uri: self.internal_request.uri,
+                headers: self.internal_request.headers,
+                version: self.internal_request.version,
+                body: map(self.internal_request.body),
+            },
+        }
+    }
 }
 
 /// (De-)Serializable HTTP response.
@@ -516,16 +437,28 @@ impl<B> InternalHttpResponse<B> {
     }
 }
 
-#[derive(Serialize, Deserialize, Debug, Default, PartialEq, Eq, Clone)]
-pub struct InternalHttpBody(VecDeque<InternalHttpBodyFrame>);
+impl<B> From<InternalHttpResponse<B>> for Response<B> {
+    fn from(value: InternalHttpResponse<B>) -> Self {
+        let InternalHttpResponse {
+            status,
+            version,
+            headers,
+            body,
+        } = value;
 
-impl InternalHttpBody {
-    pub fn from_bytes(bytes: &[u8]) -> Self {
-        InternalHttpBody(VecDeque::from([InternalHttpBodyFrame::Data(
-            bytes.to_vec(),
-        )]))
+        let mut response = Response::new(body);
+        *response.status_mut() = status;
+        *response.version_mut() = version;
+        *response.headers_mut() = headers;
+
+        response
     }
+}
+
+#[derive(Serialize, Deserialize, Debug, Default, PartialEq, Eq, Clone)]
+pub struct InternalHttpBody(pub VecDeque<InternalHttpBodyFrame>);
 
+impl InternalHttpBody {
     pub async fn from_body<B>(mut body: B) -> Result<Self, B::Error>
     where
         B: Body<Data = Bytes> + Unpin,
@@ -565,15 +498,11 @@ pub enum InternalHttpBodyFrame {
 
 impl From<Frame<Bytes>> for InternalHttpBodyFrame {
     fn from(frame: Frame<Bytes>) -> Self {
-        if frame.is_data() {
-            InternalHttpBodyFrame::Data(frame.into_data().expect("Malfromed data frame").to_vec())
-        } else if frame.is_trailers() {
-            InternalHttpBodyFrame::Trailers(
-                frame.into_trailers().expect("Malfromed trailers frame"),
-            )
-        } else {
-            panic!("Malfromed frame type")
-        }
+        frame
+            .into_data()
+            .map(|bytes| Self::Data(bytes.into()))
+            .or_else(|frame| frame.into_trailers().map(Self::Trailers))
+            .expect("malformed frame type")
     }
 }
 
@@ -591,463 +520,28 @@ impl fmt::Debug for InternalHttpBodyFrame {
     }
 }
 
-pub type ReceiverStreamBody = StreamBody<ReceiverStream<Result<Frame<Bytes>, hyper::Error>>>;
-
-#[derive(Debug)]
-pub enum HttpResponseFallback {
-    Framed(HttpResponse<InternalHttpBody>),
-    Fallback(HttpResponse<Vec<u8>>),
-
-    /// Holds the [`HttpResponse`] that we're supposed to send back to the agent.
-    ///
-    /// It also holds the original http request [`HttpRequestFallback`], so we can retry
-    /// if our hyper server sent us a
-    /// [`RST_STREAM`](https://docs.rs/h2/latest/h2/struct.Error.html#method.is_reset).
-    Streamed(
-        HttpResponse<ReceiverStreamBody>,
-        Option<HttpRequestFallback>,
-    ),
-}
-
-impl HttpResponseFallback {
-    pub fn connection_id(&self) -> ConnectionId {
-        match self {
-            HttpResponseFallback::Framed(req) => req.connection_id,
-            HttpResponseFallback::Fallback(req) => req.connection_id,
-            HttpResponseFallback::Streamed(req, _) => req.connection_id,
-        }
-    }
-
-    pub fn request_id(&self) -> RequestId {
-        match self {
-            HttpResponseFallback::Framed(req) => req.request_id,
-            HttpResponseFallback::Fallback(req) => req.request_id,
-            HttpResponseFallback::Streamed(req, _) => req.request_id,
-        }
-    }
-
-    #[tracing::instrument(level = Level::TRACE, err(level = Level::WARN))]
-    pub fn into_hyper<E>(self) -> Result<Response<BoxBody<Bytes, E>>, http::Error>
-    where
-        E: From<hyper::Error>,
-    {
-        match self {
-            HttpResponseFallback::Framed(req) => req.internal_response.try_into(),
-            HttpResponseFallback::Fallback(req) => req.internal_response.try_into(),
-            HttpResponseFallback::Streamed(req, _) => req.internal_response.try_into(),
-        }
-    }
-
-    /// Produces an [`HttpResponseFallback`] to the given [`HttpRequestFallback`].
-    ///
-    /// # Note on picking response variant
-    ///
-    /// Variant of returned [`HttpResponseFallback`] is picked based on the variant of given
-    /// [`HttpRequestFallback`] and agent protocol version. We need to consider both due
-    /// to:
-    /// 1. Old agent versions always responding with client's `mirrord_protocol` version to
-    ///    [`ClientMessage::SwitchProtocolVersion`](super::ClientMessage::SwitchProtocolVersion),
-    /// 2. [`LayerTcpSteal::HttpResponseChunked`] being introduced after
-    ///    [`DaemonTcp::HttpRequestChunked`].
-    pub fn response_from_request(
-        request: HttpRequestFallback,
-        status: StatusCode,
-        message: &str,
-        agent_protocol_version: Option<&semver::Version>,
-    ) -> Self {
-        let agent_supports_streaming_response = agent_protocol_version
-            .map(|version| HTTP_CHUNKED_RESPONSE_VERSION.matches(version))
-            .unwrap_or(false);
-
-        match request.clone() {
-            // We received `DaemonTcp::HttpRequestFramed` from the agent,
-            // so we know it supports `LayerTcpSteal::HttpResponseFramed` (both were introduced in
-            // the same `mirrord_protocol` version).
-            HttpRequestFallback::Framed(request) => HttpResponseFallback::Framed(
-                HttpResponse::<InternalHttpBody>::response_from_request(request, status, message),
-            ),
-
-            // We received `DaemonTcp::HttpRequest` from the agent, so we assume it only supports
-            // `LayerTcpSteal::HttpResponse`.
-            HttpRequestFallback::Fallback(request) => HttpResponseFallback::Fallback(
-                HttpResponse::<Vec<u8>>::response_from_request(request, status, message),
-            ),
-
-            // We received `DaemonTcp::HttpRequestChunked` and the agent supports
-            // `LayerTcpSteal::HttpResponseChunked`.
-            HttpRequestFallback::Streamed {
-                request: streamed_request,
-                ..
-            } if agent_supports_streaming_response => HttpResponseFallback::Streamed(
-                HttpResponse::<ReceiverStreamBody>::response_from_request(
-                    streamed_request,
-                    status,
-                    message,
-                ),
-                Some(request),
-            ),
-
-            // We received `DaemonTcp::HttpRequestChunked` from the agent,
-            // but the agent does not support `LayerTcpSteal::HttpResponseChunked`.
-            // However, it must support the older `LayerTcpSteal::HttpResponseFramed`
-            // variant (was introduced before `DaemonTcp::HttpRequestChunked`).
-            HttpRequestFallback::Streamed { request, .. } => HttpResponseFallback::Framed(
-                HttpResponse::<InternalHttpBody>::response_from_request(request, status, message),
-            ),
-        }
-    }
-}
-
 #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)]
-#[bincode(bounds = "for<'de> Body: Serialize + Deserialize<'de>")]
-pub struct HttpResponse<Body> {
+#[bincode(bounds = "for<'de> B: Serialize + Deserialize<'de>")]
+pub struct HttpResponse<B> {
     /// This is used to make sure the response is sent in its turn, after responses to all earlier
     /// requests were already sent.
     pub port: Port,
     pub connection_id: ConnectionId,
     pub request_id: RequestId,
     #[bincode(with_serde)]
-    pub internal_response: InternalHttpResponse<Body>,
-}
-
-impl HttpResponse<InternalHttpBody> {
-    /// We cannot implement this with the [`From`] trait as it doesn't support `async` conversions,
-    /// and we also need some extra parameters.
-    ///
-    /// So this is our alternative implementation to `From<Response<Incoming>>`.
-    #[tracing::instrument(level = Level::TRACE, err(level = Level::WARN))]
-    pub async fn from_hyper_response(
-        response: Response<Incoming>,
-        port: Port,
-        connection_id: ConnectionId,
-        request_id: RequestId,
-    ) -> Result<HttpResponse<InternalHttpBody>, hyper::Error> {
-        let (
-            Parts {
-                status,
-                version,
-                headers,
-                ..
-            },
-            body,
-        ) = response.into_parts();
-
-        let body = InternalHttpBody::from_body(body).await?;
-
-        let internal_response = InternalHttpResponse {
-            status,
-            headers,
-            version,
-            body,
-        };
-
-        Ok(HttpResponse {
-            request_id,
-            port,
-            connection_id,
-            internal_response,
-        })
-    }
-
-    pub fn response_from_request(
-        request: HttpRequest<InternalHttpBody>,
-        status: StatusCode,
-        message: &str,
-    ) -> Self {
-        let HttpRequest {
-            internal_request: InternalHttpRequest { version, .. },
-            connection_id,
-            request_id,
-            port,
-        } = request;
-
-        let body = InternalHttpBody::from_bytes(
-            format!(
-                "{} {}\n{}\n",
-                status.as_str(),
-                status.canonical_reason().unwrap_or_default(),
-                message
-            )
-            .as_bytes(),
-        );
-
-        Self {
-            port,
-            connection_id,
-            request_id,
-            internal_response: InternalHttpResponse {
-                status,
-                version,
-                headers: Default::default(),
-                body,
-            },
-        }
-    }
-
-    pub fn empty_response_from_request(
-        request: HttpRequest<InternalHttpBody>,
-        status: StatusCode,
-    ) -> Self {
-        let HttpRequest {
-            internal_request: InternalHttpRequest { version, .. },
-            connection_id,
-            request_id,
-            port,
-        } = request;
-
-        Self {
-            port,
-            connection_id,
-            request_id,
-            internal_response: InternalHttpResponse {
-                status,
-                version,
-                headers: Default::default(),
-                body: Default::default(),
-            },
-        }
-    }
+    pub internal_response: InternalHttpResponse<B>,
 }
 
-impl HttpResponse<Vec<u8>> {
-    /// We cannot implement this with the [`From`] trait as it doesn't support `async` conversions,
-    /// and we also need some extra parameters.
-    ///
-    /// So this is our alternative implementation to `From<Response<Incoming>>`.
-    #[tracing::instrument(level = Level::TRACE, err(level = Level::WARN))]
-    pub async fn from_hyper_response(
-        response: Response<Incoming>,
-        port: Port,
-        connection_id: ConnectionId,
-        request_id: RequestId,
-    ) -> Result<HttpResponse<Vec<u8>>, hyper::Error> {
-        let (
-            Parts {
-                status,
-                version,
-                headers,
-                ..
-            },
-            body,
-        ) = response.into_parts();
-
-        let body = body.collect().await?.to_bytes().to_vec();
-
-        let internal_response = InternalHttpResponse {
-            status,
-            headers,
-            version,
-            body,
-        };
-
-        Ok(HttpResponse {
-            request_id,
-            port,
-            connection_id,
-            internal_response,
-        })
-    }
-
-    pub fn response_from_request(
-        request: HttpRequest<Vec<u8>>,
-        status: StatusCode,
-        message: &str,
-    ) -> Self {
-        let HttpRequest {
-            internal_request: InternalHttpRequest { version, .. },
-            connection_id,
-            request_id,
-            port,
-        } = request;
-
-        let body = format!(
-            "{} {}\n{}\n",
-            status.as_str(),
-            status.canonical_reason().unwrap_or_default(),
-            message
-        )
-        .into_bytes();
-
-        Self {
-            port,
-            connection_id,
-            request_id,
-            internal_response: InternalHttpResponse {
-                status,
-                version,
-                headers: Default::default(),
-                body,
-            },
-        }
-    }
-
-    pub fn empty_response_from_request(request: HttpRequest<Vec<u8>>, status: StatusCode) -> Self {
-        let HttpRequest {
-            internal_request: InternalHttpRequest { version, .. },
-            connection_id,
-            request_id,
-            port,
-        } = request;
-
-        Self {
-            port,
-            connection_id,
-            request_id,
-            internal_response: InternalHttpResponse {
-                status,
-                version,
-                headers: Default::default(),
-                body: Default::default(),
-            },
-        }
-    }
-}
-
-impl HttpResponse<ReceiverStreamBody> {
-    #[tracing::instrument(level = Level::TRACE, err(level = Level::WARN))]
-    pub async fn from_hyper_response(
-        response: Response<Incoming>,
-        port: Port,
-        connection_id: ConnectionId,
-        request_id: RequestId,
-    ) -> Result<HttpResponse<ReceiverStreamBody>, hyper::Error> {
-        let (
-            Parts {
-                status,
-                version,
-                headers,
-                ..
-            },
-            mut body,
-        ) = response.into_parts();
-
-        let frames = body.next_frames(true).await?;
-        let (tx, rx) = tokio::sync::mpsc::channel(frames.frames.len().max(12));
-        for frame in frames.frames {
-            tx.try_send(Ok(frame))
-                .expect("Channel is open, capacity sufficient")
-        }
-        if !frames.is_last {
-            tokio::spawn(async move {
-                while let Some(frame) = body.frame().await {
-                    if tx.send(frame).await.is_err() {
-                        return;
-                    }
-                }
-            });
-        };
-
-        let body = StreamBody::new(ReceiverStream::from(rx));
-
-        let internal_response = InternalHttpResponse {
-            status,
-            headers,
-            version,
-            body,
-        };
-
-        Ok(HttpResponse {
-            request_id,
-            port,
-            connection_id,
-            internal_response,
-        })
-    }
-
-    #[tracing::instrument(level = Level::TRACE, ret)]
-    pub fn response_from_request(
-        request: HttpRequest<StreamingBody>,
-        status: StatusCode,
-        message: &str,
-    ) -> Self {
-        let HttpRequest {
-            internal_request: InternalHttpRequest { version, .. },
-            connection_id,
-            request_id,
-            port,
-        } = request;
-
-        let (tx, rx) = tokio::sync::mpsc::channel(1);
-        let frame = Frame::data(Bytes::copy_from_slice(message.as_bytes()));
-        tx.try_send(Ok(frame))
-            .expect("channel is open, capacity is sufficient");
-        let body = StreamBody::new(ReceiverStream::new(rx));
-
-        Self {
-            port,
-            connection_id,
-            request_id,
-            internal_response: InternalHttpResponse {
-                status,
-                version,
-                headers: Default::default(),
-                body,
-            },
-        }
-    }
-}
-
-impl<E> TryFrom<InternalHttpResponse<InternalHttpBody>> for Response<BoxBody<Bytes, E>> {
-    type Error = http::Error;
-
-    fn try_from(value: InternalHttpResponse<InternalHttpBody>) -> Result<Self, Self::Error> {
-        let InternalHttpResponse {
-            status,
-            version,
-            headers,
-            body,
-        } = value;
-
-        let mut builder = Response::builder().status(status).version(version);
-        if let Some(h) = builder.headers_mut() {
-            *h = headers;
-        }
-
-        builder.body(BoxBody::new(body.map_err(|_| unreachable!())))
-    }
-}
-
-impl<E> TryFrom<InternalHttpResponse<Vec<u8>>> for Response<BoxBody<Bytes, E>> {
-    type Error = http::Error;
-
-    fn try_from(value: InternalHttpResponse<Vec<u8>>) -> Result<Self, Self::Error> {
-        let InternalHttpResponse {
-            status,
-            version,
-            headers,
-            body,
-        } = value;
-
-        let mut builder = Response::builder().status(status).version(version);
-        if let Some(h) = builder.headers_mut() {
-            *h = headers;
-        }
-
-        builder.body(BoxBody::new(
-            Full::new(Bytes::from(body)).map_err(|_| unreachable!()),
-        ))
-    }
-}
-
-impl<E> TryFrom<InternalHttpResponse<ReceiverStreamBody>> for Response<BoxBody<Bytes, E>>
-where
-    E: From<hyper::Error>,
-{
-    type Error = http::Error;
-
-    fn try_from(value: InternalHttpResponse<ReceiverStreamBody>) -> Result<Self, Self::Error> {
-        let InternalHttpResponse {
-            status,
-            version,
-            headers,
-            body,
-        } = value;
-
-        let mut builder = Response::builder().status(status).version(version);
-        if let Some(h) = builder.headers_mut() {
-            *h = headers;
+impl<B> HttpResponse<B> {
+    pub fn map_body<T, F>(self, cb: F) -> HttpResponse<T>
+    where
+        F: FnOnce(B) -> T,
+    {
+        HttpResponse {
+            connection_id: self.connection_id,
+            request_id: self.request_id,
+            port: self.port,
+            internal_response: self.internal_response.map_body(cb),
         }
-
-        builder.body(BoxBody::new(body.map_err(|e| e.into())))
     }
 }
diff --git a/mirrord/vpn/src/agent.rs b/mirrord/vpn/src/agent.rs
index 5b7c6cb5f9e..3f145c421a0 100644
--- a/mirrord/vpn/src/agent.rs
+++ b/mirrord/vpn/src/agent.rs
@@ -136,6 +136,9 @@ impl Stream for VpnAgent {
                 LogLevel::Warn => {
                     tracing::warn!(message = %message.message, "agent sent warn message")
                 }
+                LogLevel::Info => {
+                    tracing::info!(message = %message.message, "agent sent info message")
+                }
             }
 
             self.poll_next(cx)
diff --git a/tests/go-e2e-dir/main.go b/tests/go-e2e-dir/main.go
index b608f01b53b..4a505a4dcc8 100644
--- a/tests/go-e2e-dir/main.go
+++ b/tests/go-e2e-dir/main.go
@@ -15,13 +15,20 @@ func main() {
 		os.Exit(-1)
 	}
 	fmt.Printf("DirEntries: %s\n", dir)
+
 	// `os.ReadDir` does not include `.` and `..`.
-	if len(dir) != 2 {
+	if len(dir) < 2 {
 		os.Exit(-1)
 	}
-	// `os.ReadDir` sorts the result by file name.
-	if dir[0].Name() != "app.py" || dir[1].Name() != "test.txt" {
-		os.Exit(-1)
+
+	// Iterate over the files in this dir, exiting if it's not an expected file name.
+	for i := 0; i < len(dir); i++ {
+		dirName := dir[i].Name()
+
+		if dirName != "app.py" && dirName != "test.txt" && dirName != "file.local" && dirName != "file.not-found" && dirName != "file.read-only" && dirName != "file.read-write" {
+			os.Exit(-1)
+		}
+	}
 
 	err = os.Mkdir("/app/test_mkdir", 0755)
@@ -30,6 +37,12 @@ func main() {
 		os.Exit(-1)
 	}
 
+	err = os.Remove("/app/test_mkdir")
+	if err != nil {
+		fmt.Printf("Rmdir error: %s\n", err)
+		os.Exit(-1)
+	}
+
 	// let close requests be sent for test
 	time.Sleep(1 * time.Second)
 	os.Exit(0)
diff --git a/tests/ipv6-app.yaml b/tests/ipv6-app.yaml
new file mode 100644
index 00000000000..9a044cbde4d
--- /dev/null
+++ b/tests/ipv6-app.yaml
@@ -0,0 +1,49 @@
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: py-serv-deployment
+  labels:
+    app: py-serv
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: py-serv
+  template:
+    metadata:
+      labels:
+        app: py-serv
+    spec:
+      containers:
+        - name: py-serv
+          image: ghcr.io/metalbear-co/mirrord-pytest:latest
+          ports:
+            - containerPort: 80
+          env:
+            - name: MIRRORD_FAKE_VAR_FIRST
+              value: mirrord.is.running
+            - name: MIRRORD_FAKE_VAR_SECOND
+              value: "7777"
+            - name: HOST
+              value: "::"
+
+---
+apiVersion: v1
+kind: Service
+metadata:
+  labels:
+    app: py-serv
+  name: py-serv
+spec:
+  ipFamilyPolicy: SingleStack
+  ipFamilies:
+    - IPv6
+  ports:
+    - port: 80
+      protocol: TCP
+      targetPort: 80
+      nodePort: 30000
+  selector:
+    app: py-serv
+  sessionAffinity: None
+  type: NodePort
diff --git a/tests/kind-cluster-ipv6-config.yaml b/tests/kind-cluster-ipv6-config.yaml
new file mode 100644
index 00000000000..29898284d86
--- /dev/null
+++ b/tests/kind-cluster-ipv6-config.yaml
@@ -0,0 +1,9 @@
+kind: Cluster
+apiVersion: kind.x-k8s.io/v1alpha4
+networking:
+  ipFamily: ipv6
+  apiServerAddress: 127.0.0.1
+containerdConfigPatches:
+  - |-
+    [plugins."io.containerd.grpc.v1.cri".registry]
+      config_path = "/etc/containerd/certs.d"
diff --git a/tests/node-e2e/fspolicy/test_operator_fs_policy.mjs b/tests/node-e2e/fspolicy/test_operator_fs_policy.mjs
new file mode 100644
index 00000000000..af84f17ee50
--- /dev/null
+++ b/tests/node-e2e/fspolicy/test_operator_fs_policy.mjs
@@ -0,0 +1,19 @@
+import fs from 'fs';
+
+function test_open(path, mode) {
+  fs.open(path, mode, (fail, fd) => {
+    if (fd) {
+      console.log(`SUCCESS ${mode} ${path} ${fd}`);
+    }
+
+    if (fail) {
+      console.log(`FAIL ${mode} ${path} ${fail}`);
+    }
+  });
+}
+
+test_open("/app/file.local", "r");
+test_open("/app/file.not-found", "r");
+test_open("/app/file.read-only", "r");
+test_open("/app/file.read-only", "r+");
+test_open("/app/file.read-write", "r+");
diff --git a/tests/python-e2e/files_ro.py b/tests/python-e2e/files_ro.py
index ed99eab5d7f..c6e4e8d5631 100644
--- a/tests/python-e2e/files_ro.py
+++ b/tests/python-e2e/files_ro.py
@@ -3,7 +3,7 @@
 import uuid
 import unittest
 
-TEXT = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."
+TEXT = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.\n"
 
 class FileOpsTest(unittest.TestCase):
@@ -22,4 +22,4 @@ def test_read_only(self):
 
 
 if __name__ == "__main__":
-    unittest.main()
\ No newline at end of file
+    unittest.main()
diff --git a/tests/python-e2e/ops.py b/tests/python-e2e/ops.py
index 8e83271628f..36c7ba5fb8c 100644
--- a/tests/python-e2e/ops.py
+++ b/tests/python-e2e/ops.py
@@ -2,7 +2,7 @@
 import uuid
 import unittest
 
-TEXT = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."
+TEXT = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.\n"
 
 class FileOpsTest(unittest.TestCase):
@@ -87,7 +87,29 @@ def test_mkdir_errors(self):
         os.mkdir("test_mkdir_error_already_exists", dir_fd=dir)
 
         os.close(dir)
+
+    def test_statfs_and_fstatvfs_sucess(self):
+        """
+        Test statfs / fstatfs
+        """
+        file_path, _ = self._create_new_tmp_file()
+        statvfs_result = os.statvfs(file_path)
+        self.assertIsNotNone(statvfs_result)
+
+        fd = os.open(file_path, os.O_RDONLY)
+        fstatvfs_result = os.fstatvfs(fd)
+        self.assertIsNotNone(fstatvfs_result)
+
+    def test_rmdir(self):
+        """
+        Creates a new directory in "/tmp" and removes it using rmdir.
+ """ + os.mkdir("/tmp/test_rmdir") + self.assertTrue(os.path.isdir("/tmp/test_rmdir")) + os.rmdir("/tmp/test_rmdir") + self.assertFalse(os.path.isdir("/tmp/test_rmdir")) + def _create_new_tmp_file(self): """ Creates a new file in /tmp and returns the path and name of the file. @@ -97,6 +119,5 @@ def _create_new_tmp_file(self): w_file.write(TEXT) return file_path, file_name - if __name__ == "__main__": unittest.main() diff --git a/tests/src/env.rs b/tests/src/env.rs index 7e192ad9f94..8f4b4a8e919 100644 --- a/tests/src/env.rs +++ b/tests/src/env.rs @@ -36,7 +36,7 @@ mod env_tests { let service = service.await; let mut process = run_exec_with_target( application.command(), - &service.target, + &service.pod_container_target(), None, application.mirrord_args(), None, diff --git a/tests/src/file_ops.rs b/tests/src/file_ops.rs index 826b6b64ec9..4b67092b974 100644 --- a/tests/src/file_ops.rs +++ b/tests/src/file_ops.rs @@ -33,7 +33,7 @@ mod file_ops_tests { let env = vec![("MIRRORD_FILE_READ_WRITE_PATTERN", "/tmp/**")]; let mut process = run_exec_with_target( command, - &service.target, + &service.pod_container_target(), Some(&service.namespace), Some(args), Some(env), @@ -63,7 +63,7 @@ mod file_ops_tests { let mut process = run_exec_with_target( python_command, - &service.target, + &service.pod_container_target(), Some(&service.namespace), Some(args), Some(env), @@ -97,7 +97,7 @@ mod file_ops_tests { let mut process = run_exec_with_target( python_command, - &service.target, + &service.pod_container_target(), Some(&service.namespace), None, None, @@ -117,8 +117,14 @@ mod file_ops_tests { pub async fn bash_file_exists(#[future] service: KubeService) { let service = service.await; let bash_command = vec!["bash", "bash-e2e/file.sh", "exists"]; - let mut process = - run_exec_with_target(bash_command, &service.target, None, None, None).await; + let mut process = run_exec_with_target( + bash_command, + &service.pod_container_target(), + None, + None, + None, + ) + .await; let 
res = process.wait().await; assert!(res.success()); @@ -135,8 +141,14 @@ mod file_ops_tests { pub async fn bash_file_read(#[future] service: KubeService) { let service = service.await; let bash_command = vec!["bash", "bash-e2e/file.sh", "read"]; - let mut process = - run_exec_with_target(bash_command, &service.target, None, None, None).await; + let mut process = run_exec_with_target( + bash_command, + &service.pod_container_target(), + None, + None, + None, + ) + .await; let res = process.wait().await; assert!(res.success()); @@ -151,8 +163,14 @@ mod file_ops_tests { let service = service.await; let bash_command = vec!["bash", "bash-e2e/file.sh", "write"]; let args = vec!["--rw"]; - let mut process = - run_exec_with_target(bash_command, &service.target, None, Some(args), None).await; + let mut process = run_exec_with_target( + bash_command, + &service.pod_container_target(), + None, + Some(args), + None, + ) + .await; let res = process.wait().await; assert!(res.success()); @@ -183,7 +201,7 @@ mod file_ops_tests { let mut process = run_exec_with_target( command, - &service.target, + &service.pod_container_target(), Some(&service.namespace), Some(args), None, diff --git a/tests/src/http.rs b/tests/src/http.rs index 388bc6d2c89..e9e23911c3f 100644 --- a/tests/src/http.rs +++ b/tests/src/http.rs @@ -41,7 +41,12 @@ mod http_tests { let kube_client = kube_client.await; let url = get_service_url(kube_client.clone(), &service).await; let mut process = application - .run(&service.target, Some(&service.namespace), None, None) + .run( + &service.pod_container_target(), + Some(&service.namespace), + None, + None, + ) .await; process .wait_for_line(Duration::from_secs(120), "daemon subscribed") @@ -80,7 +85,12 @@ mod http_tests { let kube_client = kube_client.await; let url = get_service_url(kube_client.clone(), &service).await; let mut process = application - .run(&service.target, Some(&service.namespace), None, None) + .run( + &service.pod_container_target(), + 
Some(&service.namespace), + None, + None, + ) .await; process .wait_for_line(Duration::from_secs(300), "daemon subscribed") diff --git a/tests/src/issue1317.rs b/tests/src/issue1317.rs index 54b70cde154..28992e0fda2 100644 --- a/tests/src/issue1317.rs +++ b/tests/src/issue1317.rs @@ -81,7 +81,14 @@ mod issue1317_tests { .to_string_lossy() .to_string(); let executable = vec![app_path.as_str()]; - let mut process = run_exec_with_target(executable, &service.target, None, None, None).await; + let mut process = run_exec_with_target( + executable, + &service.pod_container_target(), + None, + None, + None, + ) + .await; process .wait_for_line(Duration::from_secs(120), "daemon subscribed") diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 98d801f5d23..6e9663bc138 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -1,4 +1,5 @@ #![feature(stmt_expr_attributes)] +#![feature(ip)] #![warn(clippy::indexing_slicing)] #[cfg(feature = "cli")] diff --git a/tests/src/operator/concurrent_steal.rs b/tests/src/operator/concurrent_steal.rs index f7fcb44a952..8a142c9075e 100644 --- a/tests/src/operator/concurrent_steal.rs +++ b/tests/src/operator/concurrent_steal.rs @@ -31,7 +31,7 @@ pub async fn two_clients_steal_same_target( let mut client_a = application .run( - &service.target, + &service.pod_container_target(), Some(&service.namespace), Some(flags.clone()), None, @@ -43,7 +43,12 @@ pub async fn two_clients_steal_same_target( .await; let mut client_b = application - .run(&service.target, Some(&service.namespace), Some(flags), None) + .run( + &service.pod_container_target(), + Some(&service.namespace), + Some(flags), + None, + ) .await; let res = client_b.child.wait().await.unwrap(); @@ -84,7 +89,7 @@ pub async fn two_clients_steal_same_target_pod_deployment( let mut client_a = application .run( - &service.target, + &service.pod_container_target(), Some(&service.namespace), Some(flags.clone()), None, @@ -148,7 +153,7 @@ pub async fn two_clients_steal_with_http_filter( let mut 
client_a = application .run( - &service.target, + &service.pod_container_target(), Some(&service.namespace), Some(flags.clone()), Some(vec![("MIRRORD_CONFIG_FILE", config_path.to_str().unwrap())]), @@ -164,7 +169,7 @@ pub async fn two_clients_steal_with_http_filter( let mut client_b = application .run( - &service.target, + &service.pod_container_target(), Some(&service.namespace), Some(flags), Some(vec![("MIRRORD_CONFIG_FILE", config_path.to_str().unwrap())]), diff --git a/tests/src/operator/policies.rs b/tests/src/operator/policies.rs index 61a17d90af5..2a0a8465bbe 100644 --- a/tests/src/operator/policies.rs +++ b/tests/src/operator/policies.rs @@ -18,6 +18,8 @@ use crate::utils::{ config_dir, kube_client, service, Application, KubeService, ResourceGuard, TestProcess, }; +mod fs; + /// Guard that deletes a mirrord policy when dropped. struct PolicyGuard { _inner: ResourceGuard, @@ -31,28 +33,18 @@ impl PolicyGuard { ) -> Self { let policy_api: Api = Api::namespaced(kube_client.clone(), namespace); PolicyGuard { - _inner: ResourceGuard::create( - policy_api, - policy.metadata.name.clone().unwrap(), - policy, - true, - ) - .await - .expect("Could not create policy in E2E test."), + _inner: ResourceGuard::create(policy_api, policy, true) + .await + .expect("Could not create policy in E2E test."), } } pub async fn clusterwide(kube_client: kube::Client, policy: &MirrordClusterPolicy) -> Self { let policy_api: Api = Api::all(kube_client.clone()); PolicyGuard { - _inner: ResourceGuard::create( - policy_api, - policy.metadata.name.clone().unwrap(), - policy, - true, - ) - .await - .expect("Could not create policy in E2E test."), + _inner: ResourceGuard::create(policy_api, policy, true) + .await + .expect("Could not create policy in E2E test."), } } } @@ -128,6 +120,7 @@ fn block_steal_without_qualifiers() -> PolicyTestCase { selector: None, block: vec![BlockedFeature::Steal], env: Default::default(), + fs: Default::default(), }, ), service_b_can_steal: No, @@ -147,6 
+140,7 @@ fn block_steal_with_path_pattern() -> PolicyTestCase { selector: None, block: vec![BlockedFeature::Steal], env: Default::default(), + fs: Default::default(), }, ), service_b_can_steal: EvenWithoutFilter, @@ -166,6 +160,7 @@ fn block_unfiltered_steal_with_path_pattern() -> PolicyTestCase { selector: None, block: vec![BlockedFeature::StealWithoutFilter], env: Default::default(), + fs: Default::default(), }, ), service_b_can_steal: EvenWithoutFilter, @@ -185,6 +180,7 @@ fn block_unfiltered_steal_with_deployment_path_pattern() -> PolicyTestCase { selector: None, block: vec![BlockedFeature::StealWithoutFilter], env: Default::default(), + fs: Default::default(), }, ), service_a_can_steal: OnlyWithFilter, @@ -210,6 +206,7 @@ fn block_steal_with_label_selector() -> PolicyTestCase { }), block: vec![BlockedFeature::Steal], env: Default::default(), + fs: Default::default(), }, ), service_b_can_steal: EvenWithoutFilter, @@ -236,6 +233,7 @@ fn block_steal_with_unmatching_policy() -> PolicyTestCase { }), block: vec![BlockedFeature::Steal], env: Default::default(), + fs: Default::default(), }, ), service_b_can_steal: EvenWithoutFilter, @@ -274,7 +272,7 @@ async fn run_mirrord_and_verify_steal_result( let target = if target_deployment { format!("deploy/{}", kube_service.name) } else { - kube_service.target.clone() + kube_service.pod_container_target() }; let test_proc = application @@ -293,7 +291,7 @@ async fn run_mirrord_and_verify_steal_result( let test_proc = application .run( - &kube_service.target, + &kube_service.pod_container_target(), Some(&kube_service.namespace), Some(vec!["--config-file", config_path.to_str().unwrap()]), None, @@ -347,7 +345,7 @@ async fn run_mirrord_and_verify_mirror_result(kube_service: &KubeService, expect let test_proc = application .run( - &kube_service.target, + &kube_service.pod_container_target(), Some(&kube_service.namespace), Some(vec!["--fs-mode=local"]), None, @@ -377,6 +375,7 @@ pub async fn 
create_cluster_policy_and_try_to_mirror( selector: None, block: vec![BlockedFeature::Mirror], env: Default::default(), + fs: Default::default(), }, ), ) diff --git a/tests/src/operator/policies/fs.rs b/tests/src/operator/policies/fs.rs new file mode 100644 index 00000000000..79a1b7e7202 --- /dev/null +++ b/tests/src/operator/policies/fs.rs @@ -0,0 +1,92 @@ +use std::{collections::HashSet, time::Duration}; + +use mirrord_operator::crd::policy::{FsPolicy, MirrordPolicy, MirrordPolicySpec}; +use rstest::{fixture, rstest}; + +use crate::{ + operator::policies::PolicyGuard, + utils::{kube_client, service, Application, KubeService}, +}; + +#[fixture] +async fn fs_service(#[future] kube_client: kube::Client) -> KubeService { + let namespace = format!("e2e-tests-fs-policies-{}", crate::utils::random_string()); + + service( + &namespace, + "NodePort", + "ghcr.io/metalbear-co/mirrord-pytest:latest", + "fs-policy-e2e-test-service", + false, + kube_client, + ) + .await +} + +#[rstest] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[timeout(Duration::from_secs(60))] +pub async fn create_namespaced_fs_policy_and_try_file_open( + #[future] fs_service: KubeService, + #[future] kube_client: kube::Client, +) { + let kube_client = kube_client.await; + let service = fs_service.await; + + // Create policy, delete it when test exits. 
+ let _policy_guard = PolicyGuard::namespaced( + kube_client, + &MirrordPolicy::new( + "e2e-test-fs-policy-with-path-pattern", + MirrordPolicySpec { + target_path: Some("*fs-policy-e2e-test-*".into()), + selector: None, + block: Default::default(), + env: Default::default(), + fs: FsPolicy { + read_only: HashSet::from_iter(vec!["file\\.read-only".to_string()]), + local: HashSet::from_iter(vec!["file\\.local".to_string()]), + not_found: HashSet::from_iter(vec!["file\\.not-found".to_string()]), + }, + }, + ), + &service.namespace, + ) + .await; + + let application = Application::NodeFsPolicy; + println!("Running mirrord {application:?} against {}", &service.name); + + let mut test_process = application + .run( + &service.pod_container_target(), + Some(&service.namespace), + Some(vec!["--fs-mode=write"]), + None, + ) + .await; + + test_process.wait_assert_success().await; + + let stdout = test_process.get_stdout().await; + + let reading_local_failed = stdout.contains("FAIL r /app/file.local"); + let reading_not_found_failed = stdout.contains("FAIL r /app/file.not-found"); + let reading_read_only_succeeded = stdout.contains("SUCCESS r /app/file.read-only"); + let writing_read_only_failed = stdout.contains("FAIL r+ /app/file.read-only"); + let writing_read_write_succeeded = stdout.contains("SUCCESS r+ /app/file.read-write"); + + assert!( + reading_local_failed + && reading_not_found_failed + && reading_read_only_succeeded + && writing_read_only_failed + && writing_read_write_succeeded, + "some file operations did not finish as expected:\n + \treading_local_failed={reading_local_failed}\n + \treading_not_found_failed={reading_not_found_failed}\n + \treading_read_only_succeeded={reading_read_only_succeeded} \n + \twriting_read_only_failed={writing_read_only_failed}\n + \twriting_read_write_succeeded={writing_read_write_succeeded}", + ) +} diff --git a/tests/src/traffic.rs b/tests/src/traffic.rs index 66d1df00238..c511eb12e10 100644 --- a/tests/src/traffic.rs +++ 
b/tests/src/traffic.rs @@ -16,8 +16,8 @@ mod traffic_tests { use tokio::{fs::File, io::AsyncWriteExt}; use crate::utils::{ - config_dir, hostname_service, kube_client, run_exec_with_target, service, - udp_logger_service, KubeService, CONTAINER_NAME, + config_dir, hostname_service, ipv6::ipv6_service, kube_client, run_exec_with_target, + service, udp_logger_service, Application, KubeService, CONTAINER_NAME, }; #[cfg_attr(not(feature = "job"), ignore)] @@ -30,8 +30,14 @@ mod traffic_tests { "node", "node-e2e/remote_dns/test_remote_dns_enabled_works.mjs", ]; - let mut process = - run_exec_with_target(node_command, &service.target, None, None, None).await; + let mut process = run_exec_with_target( + node_command, + &service.pod_container_target(), + None, + None, + None, + ) + .await; let res = process.wait().await; assert!(res.success()); @@ -47,8 +53,14 @@ mod traffic_tests { "node", "node-e2e/remote_dns/test_remote_dns_lookup_google.mjs", ]; - let mut process = - run_exec_with_target(node_command, &service.target, None, None, None).await; + let mut process = run_exec_with_target( + node_command, + &service.pod_container_target(), + None, + None, + None, + ) + .await; let res = process.wait().await; assert!(res.success()); @@ -66,8 +78,14 @@ mod traffic_tests { "node", "node-e2e/outgoing/test_outgoing_traffic_single_request.mjs", ]; - let mut process = - run_exec_with_target(node_command, &service.target, None, None, None).await; + let mut process = run_exec_with_target( + node_command, + &service.pod_container_target(), + None, + None, + None, + ) + .await; let res = process.wait().await; assert!(res.success()); @@ -83,11 +101,54 @@ mod traffic_tests { "node", "node-e2e/outgoing/test_outgoing_traffic_single_request_ipv6.mjs", ]; - let mut process = - run_exec_with_target(node_command, &service.target, None, None, None).await; + let mut process = run_exec_with_target( + node_command, + &service.pod_container_target(), + None, + None, + None, + ) + .await; + + let res 
= process.wait().await; + assert!(res.success()); + } + + #[rstest] + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[ignore] + pub async fn outgoing_traffic_single_request_ipv6_enabled(#[future] ipv6_service: KubeService) { + let service = ipv6_service.await; + let node_command = vec![ + "node", + "node-e2e/outgoing/test_outgoing_traffic_single_request_ipv6.mjs", + ]; + let mut process = run_exec_with_target( + node_command, + &service.pod_container_target(), + None, + None, + Some(vec![("MIRRORD_ENABLE_IPV6", "true")]), + ) + .await; + + let res = process.wait().await; + assert!(res.success()); + } + #[rstest] + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[timeout(Duration::from_secs(30))] + #[ignore] + pub async fn connect_to_kubernetes_api_service_over_ipv6() { + let app = Application::CurlToKubeApi; + let mut process = app + .run_targetless(None, None, Some(vec![("MIRRORD_ENABLE_IPV6", "true")])) + .await; let res = process.wait().await; assert!(res.success()); + let stdout = process.get_stdout().await; + assert!(stdout.contains(r#""apiVersion": "v1""#)) } #[cfg_attr(not(feature = "job"), ignore)] @@ -102,7 +163,7 @@ mod traffic_tests { let mirrord_args = vec!["--no-outgoing"]; let mut process = run_exec_with_target( node_command, - &service.target, + &service.pod_container_target(), None, Some(mirrord_args), None, @@ -122,8 +183,14 @@ mod traffic_tests { "node", "node-e2e/outgoing/test_outgoing_traffic_make_request_after_listen.mjs", ]; - let mut process = - run_exec_with_target(node_command, &service.target, None, None, None).await; + let mut process = run_exec_with_target( + node_command, + &service.pod_container_target(), + None, + None, + None, + ) + .await; let res = process.wait().await; assert!(res.success()); } @@ -137,8 +204,14 @@ mod traffic_tests { "node", "node-e2e/outgoing/test_outgoing_traffic_make_request_localhost.mjs", ]; - let mut process = - run_exec_with_target(node_command, &service.target, None, None, 
None).await; + let mut process = run_exec_with_target( + node_command, + &service.pod_container_target(), + None, + None, + None, + ) + .await; let res = process.wait().await; assert!(res.success()); } @@ -185,7 +258,7 @@ mod traffic_tests { let mirrord_no_outgoing = vec!["--no-outgoing"]; let mut process = run_exec_with_target( node_command.clone(), - &target_service.target, + &target_service.pod_container_target(), Some(&target_service.namespace), Some(mirrord_no_outgoing), None, @@ -193,18 +266,13 @@ mod traffic_tests { .await; let res = process.wait().await; assert!(res.success()); // The test does not fail, because UDP does not report dropped datagrams. - let stripped_target = internal_service - .target - .split('/') - .nth(1) - .expect("malformed target"); - let logs = pod_api.logs(stripped_target, &lp).await; + let logs = pod_api.logs(&internal_service.pod_name, &lp).await; assert_eq!(logs.unwrap(), ""); // Assert that the target service did not get the message. // Run mirrord with outgoing enabled. let mut process = run_exec_with_target( node_command, - &target_service.target, + &target_service.pod_container_target(), Some(&target_service.namespace), None, None, @@ -217,7 +285,7 @@ mod traffic_tests { lp.follow = true; // Follow log stream. let mut log_lines = pod_api - .log_stream(stripped_target, &lp) + .log_stream(&internal_service.pod_name, &lp) .await .unwrap() .lines(); @@ -278,7 +346,7 @@ mod traffic_tests { // If this verification fails, the test itself is invalid. let mut process = run_exec_with_target( node_command.clone(), - &target_service.target, + &target_service.pod_container_target(), Some(&target_service.namespace), None, Some(vec![("MIRRORD_CONFIG_FILE", config_path.to_str().unwrap())]), @@ -286,12 +354,7 @@ mod traffic_tests { .await; let res = process.wait().await; assert!(!res.success()); // Should fail because local process cannot reach service. 
- let stripped_target = internal_service - .target - .split('/') - .nth(1) - .expect("malformed target"); - let logs = pod_api.logs(stripped_target, &lp).await; + let logs = pod_api.logs(&internal_service.pod_name, &lp).await; assert_eq!(logs.unwrap(), ""); // Create remote filter file with service name so we can test DNS outgoing filter. @@ -326,7 +389,7 @@ mod traffic_tests { // Run mirrord with outgoing enabled. let mut process = run_exec_with_target( node_command, - &target_service.target, + &target_service.pod_container_target(), Some(&target_service.namespace), None, Some(vec![("MIRRORD_CONFIG_FILE", config_path.to_str().unwrap())]), @@ -339,7 +402,7 @@ mod traffic_tests { lp.follow = true; // Follow log stream. let mut log_lines = pod_api - .log_stream(stripped_target, &lp) + .log_stream(&internal_service.pod_name, &lp) .await .unwrap() .lines(); @@ -376,7 +439,7 @@ mod traffic_tests { let mirrord_args = vec!["--no-outgoing"]; let mut process = run_exec_with_target( node_command, - &service.target, + &service.pod_container_target(), None, Some(mirrord_args), None, @@ -395,7 +458,8 @@ mod traffic_tests { pub async fn test_go(service: impl Future, command: Vec<&str>) { let service = service.await; - let mut process = run_exec_with_target(command, &service.target, None, None, None).await; + let mut process = + run_exec_with_target(command, &service.pod_container_target(), None, None, None).await; let res = process.wait().await; assert!(res.success()); } @@ -458,8 +522,14 @@ mod traffic_tests { pub async fn listen_localhost(#[future] service: KubeService) { let service = service.await; let node_command = vec!["node", "node-e2e/listen/test_listen_localhost.mjs"]; - let mut process = - run_exec_with_target(node_command, &service.target, None, None, None).await; + let mut process = run_exec_with_target( + node_command, + &service.pod_container_target(), + None, + None, + None, + ) + .await; let res = process.wait().await; assert!(res.success()); } @@ -471,8 +541,14 
@@ mod traffic_tests { pub async fn gethostname_remote_result(#[future] hostname_service: KubeService) { let service = hostname_service.await; let node_command = vec!["python3", "-u", "python-e2e/hostname.py"]; - let mut process = - run_exec_with_target(node_command, &service.target, None, None, None).await; + let mut process = run_exec_with_target( + node_command, + &service.pod_container_target(), + None, + None, + None, + ) + .await; let res = process.wait().await; assert!(res.success()); @@ -511,7 +587,9 @@ mod traffic_tests { "MIRRORD_OUTGOING_REMOTE_UNIX_STREAMS", "/app/unix-socket-server.sock", )]); - let mut process = run_exec_with_target(executable, &service.target, None, None, env).await; + let mut process = + run_exec_with_target(executable, &service.pod_container_target(), None, None, env) + .await; let res = process.wait().await; // The test application panics if it does not successfully connect to the socket, send data, @@ -534,7 +612,14 @@ mod traffic_tests { .to_string(); let executable = vec![app_path.as_str()]; - let mut process = run_exec_with_target(executable, &service.target, None, None, None).await; + let mut process = run_exec_with_target( + executable, + &service.pod_container_target(), + None, + None, + None, + ) + .await; let res = process.wait().await; // The test application panics if it does not successfully connect to the socket, send data, @@ -551,8 +636,14 @@ mod traffic_tests { "node", "node-e2e/outgoing/test_outgoing_traffic_many_requests.mjs", ]; - let mut process = - run_exec_with_target(node_command, &service.target, None, None, None).await; + let mut process = run_exec_with_target( + node_command, + &service.pod_container_target(), + None, + None, + None, + ) + .await; let res = process.child.wait().await.unwrap(); assert!(res.success()); @@ -571,7 +662,7 @@ mod traffic_tests { let mirrord_args = vec!["--no-outgoing"]; let mut process = run_exec_with_target( node_command, - &service.target, + &service.pod_container_target(), 
None, Some(mirrord_args), None, diff --git a/tests/src/traffic/steal.rs b/tests/src/traffic/steal.rs index 518aa0bc13e..4b9fcbb2ddd 100644 --- a/tests/src/traffic/steal.rs +++ b/tests/src/traffic/steal.rs @@ -1,31 +1,30 @@ -#![allow(dead_code, unused)] #[cfg(test)] mod steal_tests { - use std::{ - io::{BufRead, BufReader, Read, Write}, - net::{SocketAddr, TcpStream}, - path::Path, - time::Duration, - }; + use std::{net::SocketAddr, path::Path, time::Duration}; use futures_util::{SinkExt, StreamExt}; - use kube::Client; + use k8s_openapi::api::core::v1::Pod; + use kube::{Api, Client}; use reqwest::{header::HeaderMap, Url}; use rstest::*; - use tokio::time::sleep; + use tokio::{ + io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}, + net::TcpStream, + time::sleep, + }; use tokio_tungstenite::{ connect_async, tungstenite::{client::IntoClientRequest, Message}, }; use crate::utils::{ - config_dir, get_service_host_and_port, get_service_url, http2_service, kube_client, - send_request, send_requests, service, tcp_echo_service, websocket_service, Application, - KubeService, + config_dir, get_service_host_and_port, get_service_url, http2_service, + ipv6::{ipv6_service, portforward_http_requests}, + kube_client, send_request, send_requests, service, tcp_echo_service, websocket_service, + Application, KubeService, }; #[cfg_attr(not(any(feature = "ephemeral", feature = "job")), ignore)] - #[cfg(target_os = "linux")] #[rstest] #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[timeout(Duration::from_secs(240))] @@ -49,7 +48,12 @@ mod steal_tests { } let mut process = application - .run(&service.target, Some(&service.namespace), Some(flags), None) + .run( + &service.pod_container_target(), + Some(&service.namespace), + Some(flags), + None, + ) .await; process @@ -63,6 +67,47 @@ mod steal_tests { application.assert(&process).await; } + #[ignore] // Needs special cluster setup, so ignore by default. 
+ #[rstest] + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[timeout(Duration::from_secs(240))] + async fn steal_http_ipv6_traffic( + #[future] ipv6_service: KubeService, + #[future] kube_client: Client, + ) { + let application = Application::PythonFastApiHTTPIPv6; + let service = ipv6_service.await; + let kube_client = kube_client.await; + + let mut flags = vec!["--steal"]; + + if cfg!(feature = "ephemeral") { + flags.extend(["-e"].into_iter()); + } + + let mut process = application + .run( + &service.pod_container_target(), + Some(&service.namespace), + Some(flags), + Some(vec![("MIRRORD_ENABLE_IPV6", "true")]), + ) + .await; + + process + .wait_for_line(Duration::from_secs(40), "daemon subscribed") + .await; + + let api = Api::::namespaced(kube_client.clone(), &service.namespace); + portforward_http_requests(&api, service).await; + + tokio::time::timeout(Duration::from_secs(40), process.wait()) + .await + .unwrap(); + + application.assert(&process).await; + } + #[cfg_attr(not(any(feature = "ephemeral", feature = "job")), ignore)] #[cfg(target_os = "linux")] #[rstest] @@ -89,7 +134,7 @@ mod steal_tests { let mut process = application .run( - &service.target, + &service.pod_container_target(), Some(&service.namespace), Some(flags), Some(vec![("MIRRORD_AGENT_STEALER_FLUSH_CONNECTIONS", "true")]), @@ -125,7 +170,12 @@ mod steal_tests { } let mut process = application - .run(&service.target, Some(&service.namespace), Some(flags), None) + .run( + &service.pod_container_target(), + Some(&service.namespace), + Some(flags), + None, + ) .await; // Verify that we hooked the socket operations and the agent started stealing. 
@@ -208,7 +258,7 @@ mod steal_tests { let mut process = application .run( - &service.target, + &service.pod_container_target(), Some(&service.namespace), Some(flags), Some(vec![ @@ -226,7 +276,7 @@ mod steal_tests { .wait_for_line(Duration::from_secs(40), "daemon subscribed") .await; - let mut tcp_stream = TcpStream::connect((addr, port as u16)).unwrap(); + let mut tcp_stream = TcpStream::connect((addr, port as u16)).await.unwrap(); // Wait for the test app to close the socket and tell us about it. process @@ -235,10 +285,10 @@ mod steal_tests { const DATA: &[u8; 16] = b"upper me please\n"; - tcp_stream.write_all(DATA).unwrap(); + tcp_stream.write_all(DATA).await.unwrap(); let mut response = [0u8; DATA.len()]; - tcp_stream.read_exact(&mut response).unwrap(); + tcp_stream.read_exact(&mut response).await.unwrap(); process .write_to_stdin(b"Hey test app, please stop running and just exit successfuly.\n") @@ -288,7 +338,7 @@ mod steal_tests { let mut client = application .run( - &service.target, + &service.pod_container_target(), Some(&service.namespace), Some(flags), Some(vec![("MIRRORD_HTTP_HEADER_FILTER", "x-filter: yes")]), @@ -329,7 +379,7 @@ mod steal_tests { let mut client = application .run( - &service.target, + &service.pod_container_target(), Some(&service.namespace), None, Some(vec![("MIRRORD_CONFIG_FILE", config_path.to_str().unwrap())]), @@ -370,7 +420,7 @@ mod steal_tests { let mut client = application .run( - &service.target, + &service.pod_container_target(), Some(&service.namespace), None, Some(vec![("MIRRORD_CONFIG_FILE", config_path.to_str().unwrap())]), @@ -423,7 +473,7 @@ mod steal_tests { let mut mirrored_process = application .run( - &service.target, + &service.pod_container_target(), Some(&service.namespace), Some(flags), Some(vec![("MIRRORD_HTTP_HEADER_FILTER", "x-filter: yes")]), @@ -494,7 +544,7 @@ mod steal_tests { let mut mirrorded_process = application .run( - &service.target, + &service.pod_container_target(), Some(&service.namespace), 
Some(flags), Some(vec![("MIRRORD_HTTP_HEADER_FILTER", "x-filter: yes")]), @@ -559,7 +609,7 @@ mod steal_tests { let mut mirrorded_process = application .run( - &service.target, + &service.pod_container_target(), Some(&service.namespace), Some(flags), Some(vec![("MIRRORD_HTTP_HEADER_FILTER", "x-filter: yes")]), @@ -571,11 +621,11 @@ mod steal_tests { .await; let addr = SocketAddr::new(host.trim().parse().unwrap(), port as u16); - let mut stream = TcpStream::connect(addr).unwrap(); - stream.write_all(tcp_data.as_bytes()).unwrap(); + let mut stream = TcpStream::connect(addr).await.unwrap(); + stream.write_all(tcp_data.as_bytes()).await.unwrap(); let mut reader = BufReader::new(stream); let mut buf = String::new(); - reader.read_line(&mut buf).unwrap(); + reader.read_line(&mut buf).await.unwrap(); println!("Got response: {buf}"); // replace "remote: " with empty string, since the response can be split into frames // and we just need assert the final response @@ -629,7 +679,7 @@ mod steal_tests { let mut mirrorded_process = application .run( - &service.target, + &service.pod_container_target(), Some(&service.namespace), Some(flags), Some(vec![("MIRRORD_HTTP_HEADER_FILTER", "x-filter: yes")]), @@ -708,7 +758,7 @@ mod steal_tests { let mut mirrorded_process = application .run( - &service.target, + &service.pod_container_target(), Some(&service.namespace), Some(vec!["--steal"]), Some(vec![("MIRRORD_HTTP_HEADER_FILTER", "x-filter: yes")]), diff --git a/tests/src/utils.rs b/tests/src/utils.rs index bf807337a09..7fa4b7c68e7 100644 --- a/tests/src/utils.rs +++ b/tests/src/utils.rs @@ -4,8 +4,9 @@ use std::{ collections::HashMap, fmt::Debug, - net::Ipv4Addr, + net::IpAddr, ops::Not, + os::unix::process::ExitStatusExt, path::PathBuf, process::{ExitStatus, Stdio}, sync::{Arc, Once}, @@ -24,7 +25,7 @@ use kube::{ api::{DeleteParams, ListParams, PostParams, WatchParams}, core::WatchEvent, runtime::wait::{await_condition, conditions::is_pod_running}, - Api, Client, Config, Error, + 
Api, Client, Config, Error, Resource, }; use rand::{distributions::Alphanumeric, Rng}; use reqwest::{RequestBuilder, StatusCode}; @@ -39,6 +40,7 @@ use tokio::{ task::JoinHandle, }; +pub(crate) mod ipv6; pub mod sqs_resources; const TEXT: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."; @@ -90,16 +92,23 @@ fn format_time() -> String { pub enum Application { PythonFlaskHTTP, PythonFastApiHTTP, + PythonFastApiHTTPIPv6, NodeHTTP, NodeHTTP2, Go21HTTP, Go22HTTP, Go23HTTP, CurlToKubeApi, + CurlToKubeApiOverIpv6, PythonCloseSocket, PythonCloseSocketKeepConnection, RustWebsockets, RustSqs, + /// Tries to open files in the remote target, but these operations should succeed or fail based + /// on mirrord `FsPolicy`. 
+ /// + /// - `node-e2e/fspolicy/test_operator_fs_policy.mjs` + NodeFsPolicy, } #[derive(Debug)] @@ -148,68 +157,111 @@ impl TestProcess { pub async fn assert_log_level(&self, stderr: bool, level: &str) { if stderr { - assert!(!self.stderr_data.read().await.contains(level)); + assert!( + self.stderr_data.read().await.contains(level).not(), + "application stderr should not contain `{level}`" + ); } else { - assert!(!self.stdout_data.read().await.contains(level)); + assert!( + self.stdout_data.read().await.contains(level).not(), + "application stdout should not contain `{level}`" + ); } } pub async fn assert_python_fileops_stderr(&self) { - assert!(!self.stderr_data.read().await.contains("FAILED")); + assert!( + self.stderr_data.read().await.contains("FAILED").not(), + "application stderr should not contain `FAILED`" + ); } pub async fn wait_assert_success(&mut self) { let output = self.wait().await; - assert!(output.success()); + assert!( + output.success(), + "application unexpectedly failed: exit code {:?}, signal code {:?}", + output.code(), + output.signal(), + ); } pub async fn wait_assert_fail(&mut self) { let output = self.wait().await; - assert!(!output.success()); + assert!( + output.success().not(), + "application unexpectedly succeeded: exit code {:?}, signal code {:?}", + output.code(), + output.signal() + ); } pub async fn assert_stdout_contains(&self, string: &str) { - assert!(self.get_stdout().await.contains(string)); + assert!( + self.get_stdout().await.contains(string), + "application stdout should contain `{string}`", + ); } pub async fn assert_stdout_doesnt_contain(&self, string: &str) { - assert!(!self.get_stdout().await.contains(string)); + assert!( + self.get_stdout().await.contains(string).not(), + "application stdout should not contain `{string}`", + ); } pub async fn assert_stderr_contains(&self, string: &str) { - assert!(self.get_stderr().await.contains(string)); + assert!( + self.get_stderr().await.contains(string), + "application stderr 
should contain `{string}`", + ); } pub async fn assert_stderr_doesnt_contain(&self, string: &str) { - assert!(!self.get_stderr().await.contains(string)); + assert!( + self.get_stderr().await.contains(string).not(), + "application stderr should not contain `{string}`", + ); } pub async fn assert_no_error_in_stdout(&self) { - assert!(!self - .error_capture - .is_match(&self.stdout_data.read().await) - .unwrap()); + assert!( + self.error_capture + .is_match(&self.stdout_data.read().await) + .unwrap() + .not(), + "application stdout contains an error" + ); } pub async fn assert_no_error_in_stderr(&self) { - assert!(!self - .error_capture - .is_match(&self.stderr_data.read().await) - .unwrap()); + assert!( + self.error_capture + .is_match(&self.stderr_data.read().await) + .unwrap() + .not(), + "application stderr contains an error" + ); } pub async fn assert_no_warn_in_stdout(&self) { - assert!(!self - .warn_capture - .is_match(&self.stdout_data.read().await) - .unwrap()); + assert!( + self.warn_capture + .is_match(&self.stdout_data.read().await) + .unwrap() + .not(), + "application stdout contains a warning" + ); } pub async fn assert_no_warn_in_stderr(&self) { - assert!(!self - .warn_capture - .is_match(&self.stderr_data.read().await) - .unwrap()); + assert!( + self.warn_capture + .is_match(&self.stderr_data.read().await) + .unwrap() + .not(), + "application stderr contains a warning" + ); } pub async fn wait_for_line(&self, timeout: Duration, line: &str) { @@ -394,6 +446,15 @@ impl Application { "app_fastapi:app", ] } + Application::PythonFastApiHTTPIPv6 => { + vec![ + "uvicorn", + "--port=80", + "--host=::", + "--app-dir=./python-e2e/", + "app_fastapi:app", + ] + } Application::PythonCloseSocket => { vec!["python3", "-u", "python-e2e/close_socket.py"] } @@ -408,12 +469,18 @@ impl Application { Application::NodeHTTP2 => { vec!["node", "node-e2e/http2/test_http2_traffic_steal.mjs"] } + Application::NodeFsPolicy => { + vec!["node", 
"node-e2e/fspolicy/test_operator_fs_policy.mjs"] + } Application::Go21HTTP => vec!["go-e2e/21.go_test_app"], Application::Go22HTTP => vec!["go-e2e/22.go_test_app"], Application::Go23HTTP => vec!["go-e2e/23.go_test_app"], Application::CurlToKubeApi => { vec!["curl", "https://kubernetes/api", "--insecure"] } + Application::CurlToKubeApiOverIpv6 => { + vec!["curl", "-6", "https://kubernetes/api", "--insecure"] + } Application::RustWebsockets => vec!["../target/debug/rust-websockets"], Application::RustSqs => vec!["../target/debug/rust-sqs-printer"], } @@ -439,7 +506,7 @@ impl Application { } pub async fn assert(&self, process: &TestProcess) { - if let Application::PythonFastApiHTTP = self { + if matches!(self, Self::PythonFastApiHTTP | Self::PythonFastApiHTTPIPv6) { process.assert_log_level(true, "ERROR").await; process.assert_log_level(false, "ERROR").await; process.assert_log_level(true, "CRITICAL").await; @@ -570,11 +637,17 @@ pub async fn run_exec( .into_iter() .chain(process_cmd.into_iter()) .collect(); + let agent_image_env = "MIRRORD_AGENT_IMAGE"; + let agent_image_from_devs_env = std::env::var(agent_image_env); // used by the CI, to load the image locally: // docker build -t test . -f mirrord/agent/Dockerfile // minikube load image test:latest let mut base_env = HashMap::new(); - base_env.insert("MIRRORD_AGENT_IMAGE", "test"); + base_env.insert( + agent_image_env, + // Let devs running the test specify an agent image per env var. + agent_image_from_devs_env.as_deref().unwrap_or("test"), + ); base_env.insert("MIRRORD_CHECK_VERSION", "false"); base_env.insert("MIRRORD_AGENT_RUST_LOG", "warn,mirrord=debug"); base_env.insert("MIRRORD_AGENT_COMMUNICATION_TIMEOUT", "180"); @@ -654,23 +727,27 @@ pub(crate) struct ResourceGuard { impl ResourceGuard { /// Create a kube resource and spawn a task to delete it when this guard is dropped. /// Return [`Error`] if creating the resource failed. 
- pub async fn create( + pub async fn create< + K: Resource<DynamicType = ()> + Debug + Clone + DeserializeOwned + Serialize + 'static, + >( api: Api<K>, - name: String, data: &K, delete_on_fail: bool, ) -> Result<Self, Error> { + let name = data.meta().name.clone().unwrap(); + println!("Creating {} `{name}`: {data:?}", K::kind(&())); api.create(&PostParams::default(), data).await?; + println!("Created {} `{name}`", K::kind(&())); let deleter = async move { - println!("Deleting resource `{name}`"); + println!("Deleting {} `{name}`", K::kind(&())); let delete_params = DeleteParams { grace_period_seconds: Some(0), ..Default::default() }; let res = api.delete(&name, &delete_params).await; if let Err(e) = res { - println!("Failed to delete resource `{name}`: {e:?}"); + println!("Failed to delete {} `{name}`: {e:?}", K::kind(&())); } }; @@ -720,15 +797,19 @@ impl Drop for ResourceGuard { pub struct KubeService { pub name: String, pub namespace: String, - pub target: String, guards: Vec<ResourceGuard>, namespace_guard: Option<ResourceGuard>, + pub pod_name: String, } impl KubeService { pub fn deployment_target(&self) -> String { format!("deployment/{}", self.name) } + + pub fn pod_container_target(&self) -> String { + format!("pod/{}/container/{CONTAINER_NAME}", self.pod_name) + } } impl Drop for KubeService { @@ -809,6 +890,17 @@ fn deployment_from_json(name: &str, image: &str, env: Value) -> Deployment { .expect("Failed creating `deployment` from json spec!") } +/// Change the `ipFamilies` and `ipFamilyPolicy` fields to make the service IPv6-only. +/// +/// # Panics +/// +/// Will panic if the given service does not have a spec.
+fn set_ipv6_only(service: &mut Service) { + let spec = service.spec.as_mut().unwrap(); + spec.ip_families = Some(vec!["IPv6".to_string()]); + spec.ip_family_policy = Some("SingleStack".to_string()); +} + fn service_from_json(name: &str, service_type: &str) -> Service { serde_json::from_value(json!({ "apiVersion": "v1", @@ -1057,6 +1149,7 @@ pub async fn service( randomize_name, kube_client.await, default_env(), + false, ) .await } @@ -1085,6 +1178,7 @@ pub async fn service_with_env( randomize_name, kube_client, env, + false, ) .await } @@ -1100,6 +1194,7 @@ pub async fn service_with_env( /// This behavior can be changed, see [`PRESERVE_FAILED_ENV_NAME`]. /// * `randomize_name` - whether a random suffix should be added to the end of the resource names /// * `env` - `Value`, should be `Value::Array` of kubernetes container env var definitions. +#[allow(clippy::too_many_arguments)] async fn internal_service( namespace: &str, service_type: &str, @@ -1108,6 +1203,7 @@ async fn internal_service( randomize_name: bool, kube_client: Client, env: Value, + ipv6_only: bool, ) -> KubeService { let delete_after_fail = std::env::var_os(PRESERVE_FAILED_ENV_NAME).is_none(); @@ -1133,7 +1229,7 @@ async fn internal_service( }; println!( - "{} creating service {name:?} in namespace {namespace:?}", + "{} creating service {name} in namespace {namespace}", format_time() ); @@ -1151,7 +1247,6 @@ async fn internal_service( // Create namespace and wrap it in ResourceGuard if it does not yet exist. 
let namespace_guard = ResourceGuard::create( namespace_api.clone(), - namespace.to_string(), &namespace_resource, delete_after_fail, ) .await @@ -1163,47 +1258,40 @@ async fn internal_service( // `Deployment` let deployment = deployment_from_json(&name, image, env); - let pod_guard = ResourceGuard::create( - deployment_api.clone(), - name.to_string(), - &deployment, - delete_after_fail, - ) - .await - .unwrap(); + let pod_guard = ResourceGuard::create(deployment_api.clone(), &deployment, delete_after_fail) + .await + .unwrap(); watch_resource_exists(&deployment_api, &name).await; // `Service` - let service = service_from_json(&name, service_type); - let service_guard = ResourceGuard::create( - service_api.clone(), - name.clone(), - &service, - delete_after_fail, - ) - .await - .unwrap(); + let mut service = service_from_json(&name, service_type); + if ipv6_only { + set_ipv6_only(&mut service); + } + let service_guard = ResourceGuard::create(service_api.clone(), &service, delete_after_fail) + .await + .unwrap(); watch_resource_exists(&service_api, "default").await; - let target = get_instance_name::<Pod>(kube_client.clone(), &name, namespace) + let pod_name = get_instance_name::<Pod>(kube_client.clone(), &name, namespace) .await .unwrap(); let pod_api: Api<Pod> = Api::namespaced(kube_client.clone(), namespace); - await_condition(pod_api, &target, is_pod_running()) + await_condition(pod_api, &pod_name, is_pod_running()) .await .unwrap(); println!( - "{:?} done creating service {name:?} in namespace {namespace:?}", - Utc::now() + "{} done creating service {name} in namespace {namespace}", + format_time(), ); KubeService { name, namespace: namespace.to_string(), - target: format!("pod/{target}/container/{CONTAINER_NAME}"), + pod_name, guards: vec![pod_guard, service_guard], namespace_guard, } @@ -1262,7 +1350,6 @@ pub async fn service_for_mirrord_ls( // Create namespace and wrap it in ResourceGuard if it does not yet exist.
let namespace_guard = ResourceGuard::create( namespace_api.clone(), - namespace.to_string(), &namespace_resource, delete_after_fail, ) @@ -1274,35 +1361,25 @@ pub async fn service_for_mirrord_ls( // `Deployment` let deployment = deployment_from_json(&name, image, default_env()); - let pod_guard = ResourceGuard::create( - deployment_api.clone(), - name.to_string(), - &deployment, - delete_after_fail, - ) - .await - .unwrap(); + let pod_guard = ResourceGuard::create(deployment_api.clone(), &deployment, delete_after_fail) + .await + .unwrap(); watch_resource_exists(&deployment_api, &name).await; // `Service` let service = service_from_json(&name, service_type); - let service_guard = ResourceGuard::create( - service_api.clone(), - name.clone(), - &service, - delete_after_fail, - ) - .await - .unwrap(); + let service_guard = ResourceGuard::create(service_api.clone(), &service, delete_after_fail) + .await + .unwrap(); watch_resource_exists(&service_api, "default").await; - let target = get_instance_name::<Pod>(kube_client.clone(), &name, namespace) + let pod_name = get_instance_name::<Pod>(kube_client.clone(), &name, namespace) .await .unwrap(); let pod_api: Api<Pod> = Api::namespaced(kube_client.clone(), namespace); - await_condition(pod_api, &target, is_pod_running()) + await_condition(pod_api, &pod_name, is_pod_running()) .await .unwrap(); @@ -1314,7 +1391,7 @@ pub async fn service_for_mirrord_ls( KubeService { name, namespace: namespace.to_string(), - target: format!("pod/{target}/container/{CONTAINER_NAME}"), + pod_name, guards: vec![pod_guard, service_guard], namespace_guard, } @@ -1384,7 +1461,6 @@ pub async fn service_for_mirrord_ls( // Create namespace and wrap it in ResourceGuard if it does not yet exist.
let namespace_guard = ResourceGuard::create( namespace_api.clone(), - namespace.to_string(), &namespace_resource, delete_after_fail, ) @@ -1396,67 +1472,47 @@ pub async fn service_for_mirrord_ls( // `Deployment` let deployment = deployment_from_json(&name, image, default_env()); - let pod_guard = ResourceGuard::create( - deployment_api.clone(), - name.to_string(), - &deployment, - delete_after_fail, - ) - .await - .unwrap(); + let pod_guard = ResourceGuard::create(deployment_api.clone(), &deployment, delete_after_fail) + .await + .unwrap(); watch_resource_exists(&deployment_api, &name).await; // `Service` let service = service_from_json(&name, service_type); - let service_guard = ResourceGuard::create( - service_api.clone(), - name.clone(), - &service, - delete_after_fail, - ) - .await - .unwrap(); + let service_guard = ResourceGuard::create(service_api.clone(), &service, delete_after_fail) + .await + .unwrap(); watch_resource_exists(&service_api, "default").await; // `StatefulSet` let stateful_set = stateful_set_from_json(&name, image); - let stateful_set_guard = ResourceGuard::create( - stateful_set_api.clone(), - name.to_string(), - &stateful_set, - delete_after_fail, - ) - .await - .unwrap(); + let stateful_set_guard = + ResourceGuard::create(stateful_set_api.clone(), &stateful_set, delete_after_fail) + .await + .unwrap(); watch_resource_exists(&stateful_set_api, &name).await; // `CronJob` let cron_job = cron_job_from_json(&name, image); - let cron_job_guard = ResourceGuard::create( - cron_job_api.clone(), - name.to_string(), - &cron_job, - delete_after_fail, - ) - .await - .unwrap(); + let cron_job_guard = ResourceGuard::create(cron_job_api.clone(), &cron_job, delete_after_fail) + .await + .unwrap(); watch_resource_exists(&cron_job_api, &name).await; // `Job` let job = job_from_json(&name, image); - let job_guard = - ResourceGuard::create(job_api.clone(), name.to_string(), &job, delete_after_fail) - .await - .unwrap(); + let job_guard = 
ResourceGuard::create(job_api.clone(), &job, delete_after_fail) .await .unwrap(); watch_resource_exists(&job_api, &name).await; - let target = get_instance_name::<Pod>(kube_client.clone(), &name, namespace) + let pod_name = get_instance_name::<Pod>(kube_client.clone(), &name, namespace) .await .unwrap(); let pod_api: Api<Pod> = Api::namespaced(kube_client.clone(), namespace); - await_condition(pod_api, &target, is_pod_running()) + await_condition(pod_api, &pod_name, is_pod_running()) .await .unwrap(); @@ -1468,7 +1524,6 @@ pub async fn service_for_mirrord_ls( KubeService { name, namespace: namespace.to_string(), - target: format!("pod/{target}/container/{CONTAINER_NAME}"), guards: vec![ pod_guard, service_guard, @@ -1477,6 +1532,7 @@ pub async fn service_for_mirrord_ls( job_guard, ], namespace_guard, + pod_name, } } @@ -1599,12 +1655,15 @@ async fn get_pod_or_node_host(kube_client: Client, name: &str, namespace: &str) .next() .and_then(|pod| pod.status) .and_then(|status| status.host_ip) - .and_then(|ip| { - ip.parse::<Ipv4Addr>() - .unwrap() - .is_private() - .not() - .then_some(ip) + .filter(|ip| { + // use this IP only if it's a public one. + match ip.parse::<IpAddr>().unwrap() { + IpAddr::V4(ip4) => ip4.is_private(), + IpAddr::V6(ip6) => { + ip6.is_unicast_link_local() || ip6.is_unique_local() || ip6.is_loopback() + } + } + .not() }) .unwrap_or_else(resolve_node_host) } diff --git a/tests/src/utils/ipv6.rs b/tests/src/utils/ipv6.rs new file mode 100644 index 00000000000..4b766e7f42d --- /dev/null +++ b/tests/src/utils/ipv6.rs @@ -0,0 +1,100 @@ +#![cfg(test)] + +use http_body_util::{BodyExt, Empty}; +use hyper::{ + client::{conn, conn::http1::SendRequest}, + Request, +}; +use k8s_openapi::api::core::v1::Pod; +use kube::{Api, Client}; +use rstest::fixture; + +use crate::utils::{internal_service, kube_client, KubeService}; + +/// Create a new [`KubeService`] and related Kubernetes resources.
The resources will be deleted +/// when the returned service is dropped, unless it is dropped during panic. +/// This behavior can be changed, see +/// [`PRESERVE_FAILED_ENV_NAME`](crate::utils::PRESERVE_FAILED_ENV_NAME). +/// +/// * `randomize_name` - whether a random suffix should be added to the end of the resource names +#[fixture] +pub async fn ipv6_service( + #[default("default")] namespace: &str, + #[default("NodePort")] service_type: &str, + #[default("ghcr.io/metalbear-co/mirrord-pytest:latest")] image: &str, + #[default("http-echo")] service_name: &str, + #[default(true)] randomize_name: bool, + #[future] kube_client: Client, +) -> KubeService { + internal_service( + namespace, + service_type, + image, + service_name, + randomize_name, + kube_client.await, + serde_json::json!([ + { + "name": "HOST", + "value": "::" + } + ]), + true, + ) + .await +} + +/// Send an HTTP request using the referenced `request_sender`, with the provided `method`, +/// then verify a success status code and a response body that matches the method name. +/// +/// # Panics +/// - If the request cannot be sent. +/// - If the response's code is not OK. +/// - If the response's body is not the method's name. +pub async fn send_request_with_method( + method: &str, + request_sender: &mut SendRequest<Empty<hyper::body::Bytes>>, +) { + let req = Request::builder() + .method(method) + .header("Host", "::") + .body(Empty::<hyper::body::Bytes>::new()) + .unwrap(); + + println!("Request: {:?}", req); + + let res = request_sender.send_request(req).await.unwrap(); + println!("Response: {:?}", res); + assert_eq!(res.status(), hyper::StatusCode::OK); + let bytes = res.collect().await.unwrap().to_bytes(); + let response_string = String::from_utf8(bytes.to_vec()).unwrap(); + assert_eq!(response_string, method); +} + +/// Create a portforward to the pod of the test service, and send HTTP requests over it.
+/// Sends four HTTP requests (GET, POST, PUT, DELETE) over the forwarded connection, verifying +/// for each one an OK status and a response body that matches the method name. +/// +/// # Panics +/// - If a request cannot be sent. +/// - If a response's code is not OK. +/// - If a response's body is not the method's name. +pub async fn portforward_http_requests(api: &Api<Pod>, service: KubeService) { + let mut portforwarder = api + .portforward(&service.pod_name, &[80]) + .await + .expect("Failed to start portforward to test pod"); + + let stream = portforwarder.take_stream(80).unwrap(); + let stream = hyper_util::rt::TokioIo::new(stream); + + let (mut request_sender, connection) = conn::http1::handshake(stream).await.unwrap(); + tokio::spawn(async move { + if let Err(err) = connection.await { + eprintln!("Error in connection from test function to deployed test app {err:#?}"); + } + }); + for method in ["GET", "POST", "PUT", "DELETE"] { + send_request_with_method(method, &mut request_sender).await; + } +}