Skip to content

Commit

Permalink
Rework
Browse files Browse the repository at this point in the history
  • Loading branch information
DmitryDodzin committed Jul 15, 2024
1 parent 4fe9012 commit 9be23ce
Show file tree
Hide file tree
Showing 14 changed files with 101 additions and 33 deletions.
1 change: 1 addition & 0 deletions changelog.d/2572.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update loopback detection to include pod ip's
3 changes: 3 additions & 0 deletions mirrord/agent/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ pub struct Args {
/// If not given, the agent will not use TLS.
#[arg(long, env = AGENT_OPERATOR_CERT_ENV)]
pub operator_tls_cert_pem: Option<String>,

#[arg(long)]
pub pod_ips: Option<String>,
}

#[derive(Clone, Debug, Default, Subcommand)]
Expand Down
16 changes: 9 additions & 7 deletions mirrord/agent/src/entrypoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -537,13 +537,15 @@ async fn start_agent(args: Args) -> Result<()> {
let cancellation_token = cancellation_token.clone();
let watched_task = WatchedTask::new(
TcpConnectionStealer::TASK_NAME,
TcpConnectionStealer::new(stealer_command_rx).and_then(|stealer| async move {
let res = stealer.start(cancellation_token).await;
if let Err(err) = res.as_ref() {
error!("Stealer failed: {err}");
}
res
}),
TcpConnectionStealer::new(stealer_command_rx, args.pod_ips).and_then(
|stealer| async move {
let res = stealer.start(cancellation_token).await;
if let Err(err) = res.as_ref() {
error!("Stealer failed: {err}");
}
res
},
),
);
let status = watched_task.status();
let task = run_thread_in_namespace(
Expand Down
7 changes: 5 additions & 2 deletions mirrord/agent/src/steal/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,13 +286,16 @@ impl TcpConnectionStealer {
/// Initializes a new [`TcpConnectionStealer`], but doesn't start the actual work.
/// You need to call [`TcpConnectionStealer::start`] to do so.
#[tracing::instrument(level = "trace")]
pub(crate) async fn new(command_rx: Receiver<StealerCommand>) -> Result<Self, AgentError> {
pub(crate) async fn new(
command_rx: Receiver<StealerCommand>,
pod_ips: Option<String>,
) -> Result<Self, AgentError> {
let port_subscriptions = {
let flush_connections = std::env::var("MIRRORD_AGENT_STEALER_FLUSH_CONNECTIONS")
.ok()
.and_then(|var| var.parse::<bool>().ok())
.unwrap_or_default();
let redirector = IpTablesRedirector::new(flush_connections).await?;
let redirector = IpTablesRedirector::new(flush_connections, pod_ips).await?;

PortSubscriptions::new(redirector, 4)
};
Expand Down
14 changes: 9 additions & 5 deletions mirrord/agent/src/steal/ip_tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,17 @@ impl<IPT> SafeIpTables<IPT>
where
IPT: IPTables + Send + Sync,
{
pub(super) async fn create(ipt: IPT, flush_connections: bool) -> Result<Self> {
pub(super) async fn create(
ipt: IPT,
flush_connections: bool,
pod_ips: Option<&str>,
) -> Result<Self> {
let ipt = Arc::new(ipt);

let mut redirect = if let Some(vendor) = MeshVendor::detect(ipt.as_ref())? {
Redirects::Mesh(MeshRedirect::create(ipt.clone(), vendor)?)
Redirects::Mesh(MeshRedirect::create(ipt.clone(), vendor, pod_ips)?)
} else {
match StandardRedirect::create(ipt.clone()) {
match StandardRedirect::create(ipt.clone(), pod_ips) {
Err(err) => {
warn!("Unable to create StandardRedirect chain: {err}");

Expand Down Expand Up @@ -416,7 +420,7 @@ mod tests {
.times(1)
.returning(|_| Ok(()));

let ipt = SafeIpTables::create(mock, false)
let ipt = SafeIpTables::create(mock, false, None)
.await
.expect("Create Failed");

Expand Down Expand Up @@ -549,7 +553,7 @@ mod tests {
.times(1)
.returning(|_| Ok(()));

let ipt = SafeIpTables::create(mock, false)
let ipt = SafeIpTables::create(mock, false, None)
.await
.expect("Create Failed");

Expand Down
4 changes: 2 additions & 2 deletions mirrord/agent/src/steal/ip_tables/mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ impl<IPT> MeshRedirect<IPT>
where
IPT: IPTables,
{
pub fn create(ipt: Arc<IPT>, vendor: MeshVendor) -> Result<Self> {
pub fn create(ipt: Arc<IPT>, vendor: MeshVendor, pod_ips: Option<&str>) -> Result<Self> {
let prerouteing = PreroutingRedirect::create(ipt.clone())?;

for port in Self::get_skip_ports(&ipt, &vendor)? {
prerouteing.add_rule(&format!("-m multiport -p tcp ! --dports {port} -j RETURN"))?;
}

let output = OutputRedirect::create(ipt, IPTABLE_MESH.to_string())?;
let output = OutputRedirect::create(ipt, IPTABLE_MESH.to_string(), pod_ips)?;

Ok(MeshRedirect {
prerouteing,
Expand Down
12 changes: 9 additions & 3 deletions mirrord/agent/src/steal/ip_tables/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,18 @@ where
{
const ENTRYPOINT: &'static str = "OUTPUT";

pub fn create(ipt: Arc<IPT>, chain_name: String) -> Result<Self> {
pub fn create(ipt: Arc<IPT>, chain_name: String, pod_ips: Option<&str>) -> Result<Self> {
let managed = IPTableChain::create(ipt, chain_name)?;

let exclude_source_ips = pod_ips
.map(|pod_ips| format!("! -s {pod_ips}"))
.unwrap_or_default();

let gid = getgid();
managed
.add_rule(&format!("-m owner --gid-owner {gid} -p tcp -j RETURN"))
.add_rule(&format!(
"-m owner --gid-owner {gid} -p tcp {exclude_source_ips} -j RETURN"
))
.inspect_err(|_| {
warn!("Unable to create iptable rule with \"--gid-owner {gid}\" filter")
})?;
Expand All @@ -34,7 +40,7 @@ where
}

pub fn load(ipt: Arc<IPT>, chain_name: String) -> Result<Self> {
let managed = IPTableChain::create(ipt, chain_name)?;
let managed = IPTableChain::load(ipt, chain_name)?;

Ok(OutputRedirect { managed })
}
Expand Down
4 changes: 2 additions & 2 deletions mirrord/agent/src/steal/ip_tables/standard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ impl<IPT> StandardRedirect<IPT>
where
IPT: IPTables,
{
pub fn create(ipt: Arc<IPT>) -> Result<Self> {
pub fn create(ipt: Arc<IPT>, pod_ips: Option<&str>) -> Result<Self> {
let prerouteing = PreroutingRedirect::create(ipt.clone())?;
let output = OutputRedirect::create(ipt, IPTABLE_STANDARD.to_string())?;
let output = OutputRedirect::create(ipt, IPTABLE_STANDARD.to_string(), pod_ips)?;

Ok(StandardRedirect {
prerouteing,
Expand Down
15 changes: 13 additions & 2 deletions mirrord/agent/src/steal/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ pub(crate) struct IpTablesRedirector {
redirect_to: Port,
/// Listener to which redirect all connections.
listener: TcpListener,

pod_ips: Option<String>,
}

impl IpTablesRedirector {
Expand All @@ -73,7 +75,10 @@ impl IpTablesRedirector {
///
/// * `flush_connections` - whether exisitng connections should be flushed when adding new
/// redirects
pub(crate) async fn new(flush_connections: bool) -> Result<Self, AgentError> {
pub(crate) async fn new(
flush_connections: bool,
pod_ips: Option<String>,
) -> Result<Self, AgentError> {
let listener = TcpListener::bind((Ipv4Addr::UNSPECIFIED, 0)).await?;
let redirect_to = listener.local_addr()?.port();

Expand All @@ -82,6 +87,7 @@ impl IpTablesRedirector {
flush_connections,
redirect_to,
listener,
pod_ips,
})
}
}
Expand All @@ -95,7 +101,12 @@ impl PortRedirector for IpTablesRedirector {
Some(iptables) => iptables,
None => {
let iptables = new_iptables();
let safe = SafeIpTables::create(iptables.into(), self.flush_connections).await?;
let safe = SafeIpTables::create(
iptables.into(),
self.flush_connections,
self.pod_ips.as_deref(),
)
.await?;
self.iptables.insert(safe)
}
};
Expand Down
12 changes: 4 additions & 8 deletions mirrord/kube/src/api/container.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ pub struct ContainerParams {
/// Value for [`AGENT_OPERATOR_CERT_ENV`](mirrord_protocol::AGENT_OPERATOR_CERT_ENV) set in
/// the agent container.
pub tls_cert: Option<String>,
pub pod_ips: Option<String>,
}

impl ContainerParams {
pub fn new() -> ContainerParams {
pub fn new(tls_cert: Option<String>, pod_ips: Option<String>) -> ContainerParams {
let port: u16 = rand::thread_rng().gen_range(30000..=65535);
let gid: u16 = rand::thread_rng().gen_range(3000..u16::MAX);

Expand All @@ -57,17 +58,12 @@ impl ContainerParams {
name,
gid,
port,
tls_cert: None,
tls_cert,
pod_ips,
}
}
}

impl Default for ContainerParams {
fn default() -> Self {
Self::new()
}
}

pub trait ContainerVariant {
type Update;

Expand Down
5 changes: 5 additions & 0 deletions mirrord/kube/src/api/container/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ pub(super) fn base_command_line(agent: &AgentConfig, params: &ContainerParams) -
command_line.push("--test-error".to_owned());
}

if let Some(pod_ips) = params.pod_ips.clone() {
command_line.push("--pod-ips".to_owned());
command_line.push(pod_ips);
}

command_line
}

Expand Down
7 changes: 5 additions & 2 deletions mirrord/kube/src/api/kubernetes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,11 @@ impl KubernetesAPI {
.into(),
};

let mut params = ContainerParams::new();
params.tls_cert = tls_cert;
let pod_ips = runtime_data
.as_ref()
.map(|runtime_data| runtime_data.pod_ips.join(","));

let params = ContainerParams::new(tls_cert, pod_ips);

Ok((params, runtime_data))
}
Expand Down
12 changes: 12 additions & 0 deletions mirrord/kube/src/api/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ impl Display for ContainerRuntime {
#[derive(Debug)]
pub struct RuntimeData {
pub pod_name: String,
pub pod_ips: Vec<String>,
pub pod_namespace: Option<String>,
pub node_name: String,
pub container_id: String,
Expand Down Expand Up @@ -109,6 +110,16 @@ impl RuntimeData {
.ok_or_else(|| KubeApiError::missing_field(pod, ".spec.nodeName"))?
.to_owned();

let pod_ips = pod
.status
.as_ref()
.and_then(|spec| spec.pod_ips.as_ref())
.ok_or_else(|| KubeApiError::missing_field(pod, ".status.podIPs"))?
.iter()
.filter_map(|pod_ip| pod_ip.ip.as_ref())
.cloned()
.collect();

let container_statuses = pod
.status
.as_ref()
Expand Down Expand Up @@ -155,6 +166,7 @@ impl RuntimeData {
};

Ok(RuntimeData {
pod_ips,
pod_name,
pod_namespace: pod.metadata.namespace.clone(),
node_name,
Expand Down
22 changes: 22 additions & 0 deletions tests/python-e2e/app_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import sys
import threading
import requests

log = logging.getLogger("werkzeug")
log.disabled = True
Expand All @@ -24,6 +25,27 @@ def kill_thread():
threading.Thread(target=kill_thread).start()


@app.route("/foobar", methods=["GET"])
def get_foobar():
print("GET: Request completed")
return "GET"

@app.route("/foobar", methods=["POST"])
def post_foobar():
print("POST: Request completed")

x = requests.get('http://10.1.62.94/foobar')
return x.text


@app.route("/foobar", methods=["PUT"])
def put_foobar():
print("PUT: Request completed")

x = requests.get('http://10.99.79.117/foobar')
return x.text


@app.route("/", methods=["GET"])
def get():
print("GET: Request completed")
Expand Down

0 comments on commit 9be23ce

Please sign in to comment.