Update Skypilot orchestrator settings and features #3612

Open · wants to merge 17 commits into base `develop`
10 changes: 7 additions & 3 deletions docs/book/component-guide/orchestrators/skypilot-vm.md
@@ -71,13 +71,17 @@ To use the SkyPilot VM Orchestrator, you need:

{% tabs %}
{% tab title="AWS" %}
We first need to install the SkyPilot integration for AWS and the AWS connectors extra, using the following two commands:
We first need to install the SkyPilot integration for AWS and the AWS connectors extra, using the following command:

```shell
pip install "zenml[connectors-aws]"
zenml integration install aws skypilot_aws
# Installs dependencies for the SkyPilot AWS integration, AWS Container Registry, and S3 Artifact Store
pip install "zenml[connectors-aws]" "skypilot[aws]~=0.9.2" "aws-profile-manager" boto3 fsspec
```

{% hint style="warning" %}
Please note that the ZenML AWS and SkyPilot integrations are currently pip-incompatible, so running `zenml integration install aws skypilot_aws` will not work. Instead, install the requirements of the AWS components, such as the container registry and artifact store, directly with pip to avoid installation problems.
{% endhint %}

To provision VMs on AWS, your VM Orchestrator stack component needs to be configured to authenticate with [AWS Service Connector](https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/aws-service-connector). To configure the AWS Service Connector, you need to register a new service connector configured with AWS credentials that have at least the minimum permissions required by SkyPilot as documented [here](https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/aws.html).

First, check that the AWS service connector type is available using the following command:
Expand Down
2 changes: 1 addition & 1 deletion scripts/install-zenml-dev.sh
@@ -36,7 +36,7 @@ install_integrations() {
# figure out the python version
python_version=$(python -c "import sys; print('.'.join(map(str, sys.version_info[:2])))")

ignore_integrations="feast label_studio bentoml seldon pycaret skypilot_aws skypilot_gcp skypilot_azure pigeon prodigy argilla"
ignore_integrations="feast label_studio bentoml seldon pycaret skypilot_aws skypilot_gcp skypilot_azure skypilot_kubernetes skypilot_lambda pigeon prodigy argilla"

# Ignore tensorflow and deepchecks only on Python 3.12
if [ "$python_version" = "3.12" ]; then
Expand Down
@@ -13,7 +13,7 @@
# permissions and limitations under the License.
"""Skypilot orchestrator base config and settings."""

from typing import Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

from pydantic import Field

@@ -67,6 +67,14 @@ class SkypilotBaseOrchestratorSettings(BaseSettings):
disk_size: the size of the OS disk in GiB.
disk_tier: the disk performance tier to use. If None, defaults to
``'medium'``.
ports: Ports to expose. Could be an integer, a range, or a list of
integers and ranges. All ports will be exposed to the public internet.
labels: Labels to apply to instances as key-value pairs. These are
mapped to cloud-specific implementations (instance tags in AWS,
instance labels in GCP, etc.)
any_of: List of candidate resources to try in order of preference based on
cost (determined by the optimizer).
ordered: List of candidate resources to try in the specified order.

cluster_name: name of the cluster to create/reuse. If None,
auto-generate a name.
@@ -88,6 +96,20 @@ class SkypilotBaseOrchestratorSettings(BaseSettings):
stream_logs: if True, show the logs in the terminal.
docker_run_args: Optional arguments to pass to the `docker run` command
running inside the VM.
workdir: Working directory to sync to the VM. Synced to ~/sky_workdir.
task_name: Task name used for display purposes.
num_nodes: Number of nodes to launch (including the head node).
file_mounts: File and storage mount configuration for the remote cluster.
envs: Environment variables for the task. Accessible in setup/run.
task_settings: Dictionary of arbitrary settings to pass to sky.Task().
This allows passing future parameters added by SkyPilot without
requiring updates to ZenML.
resources_settings: Dictionary of arbitrary settings to pass to
sky.Resources(). This allows passing future parameters added
by SkyPilot without requiring updates to ZenML.
launch_settings: Dictionary of arbitrary settings to pass to
sky.launch(). This allows passing future parameters added
by SkyPilot without requiring updates to ZenML.
"""

# Resources
@@ -98,29 +120,50 @@ class SkypilotBaseOrchestratorSettings(BaseSettings):
memory: Union[None, int, float, str] = Field(
default=None, union_mode="left_to_right"
)
accelerators: Union[None, str, Dict[str, int]] = Field(
accelerators: Union[None, str, Dict[str, int], List[str]] = Field(
default=None, union_mode="left_to_right"
)
accelerator_args: Optional[Dict[str, str]] = None
accelerator_args: Optional[Dict[str, Any]] = None
use_spot: Optional[bool] = None
job_recovery: Optional[str] = None
job_recovery: Union[None, str, Dict[str, Any]] = Field(
default=None, union_mode="left_to_right"
)
region: Optional[str] = None
zone: Optional[str] = None
image_id: Union[Dict[str, str], str, None] = Field(
default=None, union_mode="left_to_right"
)
disk_size: Optional[int] = None
disk_tier: Optional[Literal["high", "medium", "low"]] = None
disk_tier: Optional[Literal["high", "medium", "low", "ultra", "best"]] = (
None
)

# Run settings
cluster_name: Optional[str] = None
retry_until_up: bool = False
idle_minutes_to_autostop: Optional[int] = 30
down: bool = True
stream_logs: bool = True

docker_run_args: List[str] = []

# Additional SkyPilot features
ports: Union[None, int, str, List[Union[int, str]]] = Field(
default=None, union_mode="left_to_right"
)
labels: Optional[Dict[str, str]] = None
any_of: Optional[List[Dict[str, Any]]] = None
ordered: Optional[List[Dict[str, Any]]] = None
workdir: Optional[str] = None
task_name: Optional[str] = None
num_nodes: Optional[int] = None
file_mounts: Optional[Dict[str, Any]] = None
envs: Optional[Dict[str, str]] = None

# Future-proofing settings dictionaries
task_settings: Dict[str, Any] = {}
resources_settings: Dict[str, Any] = {}
launch_settings: Dict[str, Any] = {}


class SkypilotBaseOrchestratorConfig(
BaseOrchestratorConfig, SkypilotBaseOrchestratorSettings
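Taken together, the new fields plus the pass-through dictionaries cover most of the SkyPilot surface without requiring a ZenML release for every new SkyPilot parameter. A minimal usage sketch with illustrative values; the import path is assumed from ZenML's usual flavor module layout, and the `fast` key assumes an installed SkyPilot version whose `sky.launch()` supports it:

```python
from zenml.integrations.skypilot.flavors.skypilot_orchestrator_base_vm_config import (
    SkypilotBaseOrchestratorSettings,
)

settings = SkypilotBaseOrchestratorSettings(
    cpus="4+",
    memory="16+",
    ports=[8080, "10000-10010"],          # exposed to the public internet
    labels={"team": "ml", "env": "dev"},  # instance tags on AWS, labels on GCP
    num_nodes=2,                          # total nodes, head node included
    envs={"WANDB_MODE": "offline"},
    # Pass-through dict forwarded verbatim to sky.launch(); keys are
    # whatever the installed SkyPilot version accepts. task_settings and
    # resources_settings work the same way for sky.Task()/sky.Resources().
    launch_settings={"fast": True},  # assumed kwarg, see lead-in above
)
```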
@@ -14,7 +14,6 @@
"""Implementation of the Skypilot base VM orchestrator."""

import os
import re
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, cast
from uuid import uuid4
@@ -31,6 +30,14 @@
from zenml.integrations.skypilot.orchestrators.skypilot_orchestrator_entrypoint_configuration import (
SkypilotOrchestratorEntrypointConfiguration,
)
from zenml.integrations.skypilot.utils import (
create_docker_run_command,
prepare_docker_setup,
prepare_launch_kwargs,
prepare_resources_kwargs,
prepare_task_kwargs,
sanitize_cluster_name,
)
from zenml.logger import get_logger
from zenml.orchestrators import (
ContainerizedOrchestrator,
@@ -252,32 +259,24 @@ def prepare_or_run_pipeline(
entrypoint_str = " ".join(command)
arguments_str = " ".join(args)

task_envs = environment
docker_environment_str = " ".join(
f"-e {k}={v}" for k, v in environment.items()
)
custom_run_args = " ".join(settings.docker_run_args)
if custom_run_args:
custom_run_args += " "

instance_type = settings.instance_type or self.DEFAULT_INSTANCE_TYPE
task_envs = environment.copy()

# Set up credentials
self.setup_credentials()

# Guaranteed by stack validation
assert stack is not None and stack.container_registry is not None

if docker_creds := stack.container_registry.credentials:
docker_username, docker_password = docker_creds
setup = (
f"sudo docker login --username $DOCKER_USERNAME --password "
f"$DOCKER_PASSWORD {stack.container_registry.config.uri}"
)
task_envs["DOCKER_USERNAME"] = docker_username
task_envs["DOCKER_PASSWORD"] = docker_password
else:
setup = None
# Prepare Docker setup
setup, docker_creds_envs = prepare_docker_setup(
container_registry_uri=stack.container_registry.config.uri,
credentials=stack.container_registry.credentials,
use_sudo=True, # Base orchestrator uses sudo
)

# Update task_envs with Docker credentials
if docker_creds_envs:
task_envs.update(docker_creds_envs)
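For review context, a plausible reconstruction of `prepare_docker_setup`, inferred from this call site and the inline code it replaces; the actual helper lives in `zenml.integrations.skypilot.utils` and may differ:

```python
from typing import Dict, Optional, Tuple


def prepare_docker_setup(
    container_registry_uri: str,
    credentials: Optional[Tuple[str, str]] = None,
    use_sudo: bool = True,
) -> Tuple[Optional[str], Dict[str, str]]:
    """Build the docker login setup command and credential env vars."""
    if not credentials:
        return None, {}
    docker_username, docker_password = credentials
    docker = "sudo docker" if use_sudo else "docker"
    # Credentials travel as task env vars rather than being inlined
    # into the setup command itself.
    setup = (
        f"{docker} login --username $DOCKER_USERNAME "
        f"--password $DOCKER_PASSWORD {container_registry_uri}"
    )
    return setup, {
        "DOCKER_USERNAME": docker_username,
        "DOCKER_PASSWORD": docker_password,
    }
```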

# Run the entire pipeline

@@ -291,39 +290,43 @@ def prepare_or_run_pipeline(
down = False
idle_minutes_to_autostop = None
else:
run_command = f"sudo docker run --rm {custom_run_args}{docker_environment_str} {image} {entrypoint_str} {arguments_str}"
run_command = create_docker_run_command(
image=image,
entrypoint_str=entrypoint_str,
arguments_str=arguments_str,
environment=task_envs,
docker_run_args=settings.docker_run_args,
use_sudo=True, # Base orchestrator uses sudo
)
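For comparison with the removed inline f-string above, a sketch of what `create_docker_run_command` plausibly assembles (the real helper may differ):

```python
from typing import Dict, List


def create_docker_run_command(
    image: str,
    entrypoint_str: str,
    arguments_str: str,
    environment: Dict[str, str],
    docker_run_args: List[str],
    use_sudo: bool = True,
) -> str:
    """Rebuild the old `docker run` string from its constituent parts."""
    docker = "sudo docker" if use_sudo else "docker"
    custom_args = " ".join(docker_run_args)
    custom_args = f"{custom_args} " if custom_args else ""
    env_flags = " ".join(f"-e {k}={v}" for k, v in environment.items())
    return (
        f"{docker} run --rm {custom_args}{env_flags} "
        f"{image} {entrypoint_str} {arguments_str}"
    )
```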
down = settings.down
idle_minutes_to_autostop = settings.idle_minutes_to_autostop
task = sky.Task(
run=run_command,

# Create the Task with all parameters and task settings
task_kwargs = prepare_task_kwargs(
settings=settings,
run_command=run_command,
setup=setup,
envs=task_envs,
task_envs=task_envs,
task_name=f"{orchestrator_run_name}",
)

task = sky.Task(**task_kwargs)
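A sketch of what `prepare_task_kwargs` presumably collects, based on its call site and the documented semantics of `task_settings` (file mounts and other details omitted; the real helper may differ):

```python
from typing import Any, Dict, Optional


def prepare_task_kwargs(
    settings: "SkypilotBaseOrchestratorSettings",
    run_command: str,
    setup: Optional[str],
    task_envs: Dict[str, str],
    task_name: str,
) -> Dict[str, Any]:
    """Collect sky.Task() kwargs, with task_settings taking precedence."""
    task_kwargs: Dict[str, Any] = {
        "run": run_command,
        "setup": setup,
        "envs": task_envs,
        "name": settings.task_name or task_name,
        "workdir": settings.workdir,
        "num_nodes": settings.num_nodes,
    }
    # Drop unset values so SkyPilot defaults still apply, then merge the
    # arbitrary pass-through settings last so they win.
    task_kwargs = {k: v for k, v in task_kwargs.items() if v is not None}
    task_kwargs.update(settings.task_settings)
    return task_kwargs
```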
logger.debug(f"Running run: {run_command}")

task = task.set_resources(
sky.Resources(
cloud=self.cloud,
instance_type=instance_type,
cpus=settings.cpus,
memory=settings.memory,
accelerators=settings.accelerators,
accelerator_args=settings.accelerator_args,
use_spot=settings.use_spot,
job_recovery=settings.job_recovery,
region=settings.region,
zone=settings.zone,
image_id=image
if isinstance(self.cloud, sky.clouds.Kubernetes)
else settings.image_id,
disk_size=settings.disk_size,
disk_tier=settings.disk_tier,
)
# Set resources with all parameters and resource settings
resources_kwargs = prepare_resources_kwargs(
cloud=self.cloud,
settings=settings,
default_instance_type=self.DEFAULT_INSTANCE_TYPE,
kubernetes_image=image
if isinstance(self.cloud, sky.clouds.Kubernetes)
else None,
)
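Similarly, `prepare_resources_kwargs` plausibly mirrors the removed inline `sky.Resources(...)` call, extended with the new fields and the `resources_settings` pass-through (sketch only; `any_of`/`ordered` handling omitted):

```python
from typing import Any, Dict, Optional

import sky


def prepare_resources_kwargs(
    cloud: sky.clouds.Cloud,
    settings: "SkypilotBaseOrchestratorSettings",
    default_instance_type: Optional[str] = None,
    kubernetes_image: Optional[str] = None,
) -> Dict[str, Any]:
    """Collect sky.Resources() kwargs from settings plus pass-throughs."""
    resources_kwargs: Dict[str, Any] = {
        "cloud": cloud,
        "instance_type": settings.instance_type or default_instance_type,
        "cpus": settings.cpus,
        "memory": settings.memory,
        "accelerators": settings.accelerators,
        "accelerator_args": settings.accelerator_args,
        "use_spot": settings.use_spot,
        "job_recovery": settings.job_recovery,
        "region": settings.region,
        "zone": settings.zone,
        # On Kubernetes the pipeline image doubles as the VM image.
        "image_id": kubernetes_image or settings.image_id,
        "disk_size": settings.disk_size,
        "disk_tier": settings.disk_tier,
        "ports": settings.ports,
        "labels": settings.labels,
    }
    resources_kwargs = {
        k: v for k, v in resources_kwargs.items() if v is not None
    }
    resources_kwargs.update(settings.resources_settings)
    return resources_kwargs
```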
# Do not detach run if logs are being streamed
# Otherwise, the logs will not be streamed after the task is submitted
# Could also be a parameter in the settings to control this behavior
detach_run = not settings.stream_logs

task = task.set_resources(sky.Resources(**resources_kwargs))

# Use num_nodes from settings or default to 1
num_nodes = settings.num_nodes or 1

launch_new_cluster = True
if settings.cluster_name:
@@ -342,26 +345,43 @@
)
cluster_name = settings.cluster_name
else:
cluster_name = self.sanitize_cluster_name(
cluster_name = sanitize_cluster_name(
f"{orchestrator_run_name}"
)
logger.info(
f"No cluster name provided. Launching a new cluster with name {cluster_name}..."
)

if launch_new_cluster:
# Prepare launch parameters with additional launch settings
launch_kwargs = prepare_launch_kwargs(
settings=settings,
stream_logs=settings.stream_logs,
down=down,
idle_minutes_to_autostop=idle_minutes_to_autostop,
num_nodes=num_nodes,
)

sky.launch(
task,
cluster_name,
retry_until_up=settings.retry_until_up,
idle_minutes_to_autostop=idle_minutes_to_autostop,
down=down,
stream_logs=settings.stream_logs,
backend=None,
detach_setup=True,
detach_run=detach_run,
**launch_kwargs,
)
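And a sketch of `prepare_launch_kwargs`, presumed to follow the same pattern as the `exec_kwargs` block in the `else` branch below: merge `launch_settings` last, then strip `None` values:

```python
from typing import Any, Dict, Optional


def prepare_launch_kwargs(
    settings: "SkypilotBaseOrchestratorSettings",
    stream_logs: bool,
    down: bool,
    idle_minutes_to_autostop: Optional[int],
    num_nodes: int,
) -> Dict[str, Any]:
    """Collect sky.launch() kwargs, with launch_settings taking precedence."""
    # num_nodes is accepted for symmetry with the call site; how the real
    # helper uses it is not shown in this diff.
    launch_kwargs: Dict[str, Any] = {
        "retry_until_up": settings.retry_until_up,
        "idle_minutes_to_autostop": idle_minutes_to_autostop,
        "down": down,
        "stream_logs": stream_logs,
        "backend": None,
        "detach_setup": True,
        # Detach the run only when logs are not being streamed (see the
        # comment above the detach_run assignment).
        "detach_run": not stream_logs,
        **settings.launch_settings,
    }
    # Strip None values to avoid overriding SkyPilot defaults.
    return {k: v for k, v in launch_kwargs.items() if v is not None}
```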
else:
# Prepare exec parameters with additional launch settings
exec_kwargs = {
"down": down,
"stream_logs": settings.stream_logs,
"backend": None,
"detach_run": not settings.stream_logs, # detach_run is opposite of stream_logs
**settings.launch_settings, # Can reuse same settings for exec
}

# Remove None values to avoid overriding SkyPilot defaults
exec_kwargs = {
k: v for k, v in exec_kwargs.items() if v is not None
}

# Make sure the cluster is up -
# If the cluster is already up, this will not do anything
sky.start(
@@ -373,10 +393,7 @@
sky.exec(
task,
settings.cluster_name,
down=down,
stream_logs=settings.stream_logs,
backend=None,
detach_run=detach_run,
**exec_kwargs,
)

except Exception as e:
@@ -386,19 +403,3 @@
finally:
# Unset the service connector AWS profile ENV variable
self.prepare_environment_variable(set=False)

def sanitize_cluster_name(self, name: str) -> str:
"""Sanitize the value to be used in a cluster name.

Args:
name: Arbitrary input cluster name.

Returns:
Sanitized cluster name.
"""
name = re.sub(
r"[^a-z0-9-]", "-", name.lower()
) # replaces any character that is not a lowercase letter, digit, or hyphen with a hyphen
name = re.sub(r"^[-]+", "", name) # trim leading hyphens
name = re.sub(r"[-]+$", "", name) # trim trailing hyphens
return name
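Moving `sanitize_cluster_name` into `zenml.integrations.skypilot.utils` is a pure refactor; the regex behavior shown above is unchanged. For example:

```python
from zenml.integrations.skypilot.utils import sanitize_cluster_name

# Lowercase, replace disallowed characters with hyphens, trim hyphens.
assert sanitize_cluster_name("My_ZenML.Run--") == "my-zenml-run"
```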