Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

metadata for spot termination for k8s #2207

Merged
merged 8 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions metaflow/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
("argo-workflows", ".argo.argo_workflows_cli.cli"),
("card", ".cards.card_cli.cli"),
("tag", ".tag_cli.cli"),
("spot-metadata", ".kubernetes.spot_metadata_cli.cli"),
("logs", ".logs_cli.cli"),
]

Expand Down Expand Up @@ -104,6 +105,10 @@
"save_logs_periodically",
"..mflog.save_logs_periodically.SaveLogsPeriodicallySidecar",
),
(
"spot_termination_monitor",
".kubernetes.spot_monitor_sidecar.SpotTerminationMonitorSidecar",
),
("heartbeat", "metaflow.metadata_provider.heartbeat.MetadataHeartBeat"),
]

Expand Down
1 change: 1 addition & 0 deletions metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -1705,6 +1705,7 @@ def _container_templates(self):
},
**{
# Some optional values for bookkeeping
"METAFLOW_FLOW_FILENAME": os.path.basename(sys.argv[0]),
"METAFLOW_FLOW_NAME": self.flow.name,
"METAFLOW_STEP_NAME": node.name,
"METAFLOW_RUN_ID": run_id,
Expand Down
2 changes: 1 addition & 1 deletion metaflow/plugins/kubernetes/kubernetes_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def echo(msg, stream="stderr", job_id=None, **kwargs):
executable = ctx.obj.environment.executable(step_name, executable)

# Set environment
env = {}
env = {"METAFLOW_FLOW_FILENAME": os.path.basename(sys.argv[0])}
env_deco = [deco for deco in node.decorators if deco.name == "environment"]
if env_deco:
env = env_deco[0].attributes["vars"]
Expand Down
8 changes: 8 additions & 0 deletions metaflow/plugins/kubernetes/kubernetes_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,13 @@ def task_pre_step(
self._save_logs_sidecar = Sidecar("save_logs_periodically")
self._save_logs_sidecar.start()

# Start spot termination monitor sidecar.
current._update_env(
{"spot_termination_notice": "/tmp/spot_termination_notice"}
)
self._spot_monitor_sidecar = Sidecar("spot_termination_monitor")
self._spot_monitor_sidecar.start()

num_parallel = None
if hasattr(flow, "_parallel_ubf_iter"):
num_parallel = flow._parallel_ubf_iter.num_parallel
Expand Down Expand Up @@ -605,6 +612,7 @@ def task_finished(

try:
self._save_logs_sidecar.terminate()
self._spot_monitor_sidecar.terminate()
except:
# Best effort kill
pass
Expand Down
69 changes: 69 additions & 0 deletions metaflow/plugins/kubernetes/spot_metadata_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from metaflow._vendor import click
from datetime import datetime, timezone
from metaflow.tagging_util import validate_tags
from metaflow.metadata_provider import MetaDatum


@click.group()
def cli():
    # Root click group for this module; the real commands hang off the
    # `spot_metadata` sub-group. A comment (rather than a docstring) is used
    # on purpose so that the group's help text stays empty.
    pass


@cli.group(help="Commands related to spot metadata.")
def spot_metadata():
    """Namespace group for the spot-metadata subcommands (no behavior of its own)."""


@spot_metadata.command(help="Record spot termination metadata for a task.")
@click.option(
    "--run-id",
    required=True,
    help="Run ID for which metadata is to be recorded.",
)
@click.option(
    "--step-name",
    required=True,
    help="Step Name for which metadata is to be recorded.",
)
@click.option(
    "--task-id",
    required=True,
    help="Task ID for which metadata is to be recorded.",
)
@click.option(
    "--termination-notice-time",
    required=True,
    help="Spot termination notice time.",
)
@click.option(
    "--tag",
    "tags",
    multiple=True,
    required=False,
    default=None,
    help="List of tags.",
)
@click.pass_obj
def record(obj, run_id, step_name, task_id, termination_notice_time, tags=None):
    """Register spot termination metadata for a single task.

    Two entries are registered against ``run_id/step_name/task_id``:
    ``spot-termination-received-at`` (the current UTC time, formatted as
    ``%Y-%m-%dT%H:%M:%SZ``) and ``spot-termination-time`` (the value passed in
    ``termination_notice_time``). Both entries carry the same tag list.
    """
    validate_tags(tags)

    shared_tags = [] if not tags else list(tags)
    received_at = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

    # Each entry uses its field name as its metadata type as well.
    values_by_field = (
        ("spot-termination-received-at", received_at),
        ("spot-termination-time", termination_notice_time),
    )
    metadata = [
        MetaDatum(field=name, value=value, type=name, tags=shared_tags)
        for name, value in values_by_field
    ]

    obj.metadata.register_metadata(
        run_id=run_id,
        step_name=step_name,
        task_id=task_id,
        metadata=metadata,
    )
109 changes: 109 additions & 0 deletions metaflow/plugins/kubernetes/spot_monitor_sidecar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import os
import sys
import time
import signal
import requests
import subprocess
from multiprocessing import Process
from datetime import datetime, timezone
from metaflow.sidecar import MessageTypes


class SpotTerminationMonitorSidecar(object):
    """Sidecar that watches the EC2 metadata service for a spot termination notice.

    On construction the sidecar asks the instance metadata service whether the
    host's life cycle is ``spot``. If it is, a child process is started that
    polls ``METADATA_URL`` every ``POLL_INTERVAL`` seconds. Once a termination
    time is published the child writes it to ``TERMINATION_NOTICE_PATH``,
    records it as task metadata by invoking the flow's
    ``spot-metadata record`` command, and sends SIGTERM to its parent process.
    """

    EC2_TYPE_URL = "http://169.254.169.254/latest/meta-data/instance-life-cycle"
    METADATA_URL = "http://169.254.169.254/latest/meta-data/spot/termination-time"
    TOKEN_URL = "http://169.254.169.254/latest/api/token"
    POLL_INTERVAL = 5  # seconds
    # File the termination time is written to. Keep this in sync with the
    # path the kubernetes decorator publishes as `spot_termination_notice`.
    TERMINATION_NOTICE_PATH = "/tmp/spot_termination_notice"

    def __init__(self):
        self.is_alive = True
        self._process = None
        self._token = None
        self._token_expiry = 0

        # Only spend a process on polling when running on a spot instance.
        if self._is_aws_spot_instance():
            self._process = Process(target=self._monitor_loop)
            self._process.start()

    def process_message(self, msg):
        """Stop the monitor process when a SHUTDOWN message is received."""
        if msg.msg_type == MessageTypes.SHUTDOWN:
            self.is_alive = False
            if self._process:
                self._process.terminate()

    @classmethod
    def get_worker(cls):
        """Return the class that acts as the sidecar worker (this class)."""
        return cls

    def _get_imds_token(self):
        """Return a metadata-service session token, refreshing it when stale.

        A new token is requested once the locally tracked expiry is less than
        60 seconds away. If the refresh fails, the previously cached token
        (possibly ``None``) is returned unchanged.
        """
        current_time = time.time()
        if current_time >= self._token_expiry - 60:  # Refresh 60s before expiry
            try:
                response = requests.put(
                    url=self.TOKEN_URL,
                    headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"},
                    timeout=1,
                )
                if response.status_code == 200:
                    self._token = response.text
                    self._token_expiry = current_time + 240  # Slightly less than TTL
            except requests.exceptions.RequestException:
                # Best effort: fall back to whatever token we already hold.
                pass
        return self._token

    def _make_ec2_request(self, url, timeout):
        """GET ``url`` from the metadata service, attaching a token if we have one."""
        token = self._get_imds_token()
        headers = {"X-aws-ec2-metadata-token": token} if token else {}
        return requests.get(url=url, headers=headers, timeout=timeout)

    def _is_aws_spot_instance(self):
        """Return True only if the metadata service reports life cycle ``spot``."""
        try:
            response = self._make_ec2_request(url=self.EC2_TYPE_URL, timeout=1)
            return response.status_code == 200 and response.text == "spot"
        # Timeout is a subclass of RequestException, so one clause covers both.
        except requests.exceptions.RequestException:
            return False

    def _monitor_loop(self):
        """Poll for a termination notice; on notice, record it and SIGTERM the parent."""
        while self.is_alive:
            try:
                response = self._make_ec2_request(url=self.METADATA_URL, timeout=1)
                if response.status_code == 200:
                    try:
                        self._emit_termination_metadata(response.text)
                    finally:
                        # Signal the parent even if recording the notice
                        # failed; otherwise the task would never learn about
                        # the impending termination.
                        os.kill(os.getppid(), signal.SIGTERM)
                    break
            except requests.exceptions.RequestException:
                # Metadata service unreachable right now; retry on next poll.
                pass
            time.sleep(self.POLL_INTERVAL)

    def _emit_termination_metadata(self, termination_time):
        """Write the notice file and record the termination time as task metadata.

        The notice file is written first so that it exists even when the
        metadata cannot be recorded. Recording is skipped (with a printed
        message) when ``METAFLOW_FLOW_FILENAME`` is unset or ``MF_PATHSPEC``
        is unset or does not have exactly four ``/``-separated components.
        """
        try:
            with open(self.TERMINATION_NOTICE_PATH, "w") as fp:
                fp.write(termination_time)
        except OSError as e:
            print(f"Failed to write spot termination notice: {e}")

        flow_filename = os.getenv("METAFLOW_FLOW_FILENAME")
        pathspec = os.getenv("MF_PATHSPEC")
        parts = pathspec.split("/") if pathspec else []
        if not flow_filename or len(parts) != 4:
            print(
                "Failed to record spot termination metadata: "
                "METAFLOW_FLOW_FILENAME or MF_PATHSPEC is missing or malformed"
            )
            return

        _, run_id, step_name, task_id = parts
        retry_count = os.getenv("MF_ATTEMPT")

        command = [
            sys.executable,
            f"/metaflow/{flow_filename}",
            "spot-metadata",
            "record",
            "--run-id",
            run_id,
            "--step-name",
            step_name,
            "--task-id",
            task_id,
            "--termination-notice-time",
            termination_time,
            "--tag",
            "attempt_id:{}".format(retry_count),
        ]

        result = subprocess.run(command, capture_output=True, text=True)

        if result.returncode != 0:
            print(f"Failed to record spot termination metadata: {result.stderr}")
Loading