From c4e709a8fade7f7c762e0040f89a3540110cb4a5 Mon Sep 17 00:00:00 2001 From: AbdullahSharif Date: Thu, 9 Jan 2025 18:25:20 +0000 Subject: [PATCH 1/4] containerized the magemaker with all dependencies --- .dockerignore | 10 +++ ...k--opt-125m-202410251736-202501091758.yaml | 15 ++++ Dockerfile-server | 51 ++++++++---- magemaker/cli.py | 32 ++++++++ magemaker/docker/entrypoint.sh | 14 ++++ magemaker/huggingface/test_hf_hub_api.py | 42 ++++++++++ magemaker/sagemaker/test_create_model.py | 64 +++++++++++++++ magemaker/sagemaker/test_delete_model.py | 23 ++++++ magemaker/sagemaker/test_fine_tune_model.py | 80 +++++++++++++++++++ magemaker/sagemaker/test_query_endpoint.py | 34 ++++++++ pyproject.toml | 5 +- tests/test_cli.py | 0 12 files changed, 354 insertions(+), 16 deletions(-) create mode 100644 .dockerignore create mode 100644 .magemaker_configs/facebook--opt-125m-202410251736-202501091758.yaml create mode 100644 magemaker/cli.py create mode 100755 magemaker/docker/entrypoint.sh create mode 100644 magemaker/huggingface/test_hf_hub_api.py create mode 100644 magemaker/sagemaker/test_create_model.py create mode 100644 magemaker/sagemaker/test_delete_model.py create mode 100644 magemaker/sagemaker/test_fine_tune_model.py create mode 100644 magemaker/sagemaker/test_query_endpoint.py create mode 100644 tests/test_cli.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..0008d33 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,10 @@ +echo """ +venv/ +*.pyc +__pycache__ +.git +.env +*.egg-info +dist/ +build/ +""" > .dockerignore \ No newline at end of file diff --git a/.magemaker_configs/facebook--opt-125m-202410251736-202501091758.yaml b/.magemaker_configs/facebook--opt-125m-202410251736-202501091758.yaml new file mode 100644 index 0000000..77bc415 --- /dev/null +++ b/.magemaker_configs/facebook--opt-125m-202410251736-202501091758.yaml @@ -0,0 +1,15 @@ +deployment: !Deployment + destination: aws + endpoint_name: facebook--opt-125m-202410251736-202501091758 + instance_count: 1 + instance_type: ml.t2.medium + num_gpus: null + quantization: null +models: +- !Model + id: facebook/opt-125m + location: null + predict: null + source: huggingface + task: text-generation + version: null diff --git a/Dockerfile-server b/Dockerfile-server index 79bb9a6..5fde15f 100644 --- a/Dockerfile-server +++ b/Dockerfile-server @@ -1,22 +1,43 @@ -# Use an official Python runtime as a parent image -FROM python:3.12-slim +FROM python:3.11-slim -# Set environment variables -ENV PYTHONDONTWRITEBYTECODE 1 -ENV PYTHONUNBUFFERED 1 +# Install system dependencies +RUN apt-get update && apt-get install -y \ + git \ + curl \ + build-essential \ + unzip \ + nano \ + vim \ + && rm -rf /var/lib/apt/lists/* -# Set the working directory in the container +# Set working directory WORKDIR /app -# Copy the project files to the working directory -COPY . . +# Copy your magemaker package +COPY . /app/ -# Install PDM and use it to install dependencies -RUN pip install --no-cache-dir pdm \ - && pdm install --no-interactive +# Install package and dependencies +RUN pip install --no-cache-dir -e . -# Expose port 8000 to the outside world -EXPOSE 8000 +# Install AWS CLI +RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" \ + && unzip awscliv2.zip \ + && ./aws/install \ + && rm awscliv2.zip \ + && rm -rf aws -# Run uvicorn when the container launches -CMD ["pdm", "run", "uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"] +# Install Google Cloud SDK +RUN curl https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-458.0.0-linux-x86_64.tar.gz -o google-cloud-sdk.tar.gz \ + && tar -xf google-cloud-sdk.tar.gz \ + && ./google-cloud-sdk/install.sh --quiet \ + && rm google-cloud-sdk.tar.gz + +# Add Google Cloud SDK to PATH +ENV PATH $PATH:/app/google-cloud-sdk/bin + +# Copy and setup entrypoint +COPY magemaker/docker/entrypoint.sh /usr/local/bin/ +RUN chmod +x /usr/local/bin/entrypoint.sh + +ENTRYPOINT ["entrypoint.sh"] +CMD ["bash"] \ No newline at end of file diff --git a/magemaker/cli.py b/magemaker/cli.py new file mode 100644 index 0000000..578ee8e --- /dev/null +++ b/magemaker/cli.py @@ -0,0 +1,32 @@ +import click +import yaml +from rich.console import Console +from magemaker.schemas.deployment import Deployment +from magemaker.schemas.model import Model +from magemaker.sagemaker.create_model import deploy_huggingface_model_to_sagemaker +from magemaker.gcp.create_model import deploy_huggingface_model_to_vertexai + +console = Console() + +@click.group() +def main(): + """Magemaker CLI for model deployment""" + pass + +@main.command() +@click.option('--deploy', type=click.Path(exists=True), help='Path to deployment YAML file') +def deploy(deploy): + """Deploy a model using configuration from YAML file""" + with open(deploy, 'r') as f: + config = yaml.safe_load(f) + + deployment = Deployment(**config['deployment']) + model = Model(**config['model']) + + if deployment.destination == "sagemaker": + deploy_huggingface_model_to_sagemaker(deployment, model) + elif deployment.destination == "gcp": + deploy_huggingface_model_to_vertexai(deployment, model) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/magemaker/docker/entrypoint.sh b/magemaker/docker/entrypoint.sh new file mode 100755 index 0000000..b291f38 --- /dev/null +++ b/magemaker/docker/entrypoint.sh @@ -0,0 +1,14 @@ +#!/bin/bash +set -e + +# Setup AWS credentials if mounted +if [ -f "/root/.aws/credentials" ]; then + export AWS_SHARED_CREDENTIALS_FILE="/root/.aws/credentials" +fi + +# Setup GCP credentials if mounted +if [ -f "/root/.config/gcloud/application_default_credentials.json" ]; then + export GOOGLE_APPLICATION_CREDENTIALS="/root/.config/gcloud/application_default_credentials.json" +fi + +exec "$@" \ No newline at end of file diff --git a/magemaker/huggingface/test_hf_hub_api.py b/magemaker/huggingface/test_hf_hub_api.py new file mode 100644 index 0000000..35bebe7 --- /dev/null +++ b/magemaker/huggingface/test_hf_hub_api.py @@ -0,0 +1,42 @@ +import pytest +from unittest.mock import patch, MagicMock +from magemaker.schemas.model import Model +from magemaker.huggingface.hf_hub_api import get_hf_task + +def test_get_hf_task_successful(): + with patch('magemaker.huggingface.hf_hub_api.hf_api') as mock_hf_api: + mock_model_info = MagicMock() + mock_model_info.pipeline_tag = "text-classification" + mock_model_info.transformers_info = None + mock_hf_api.model_info.return_value = mock_model_info + + model = Model(id="test-model", source="huggingface") + task = get_hf_task(model) + + assert task == "text-classification" + +def test_get_hf_task_with_transformers_info(): + with patch('magemaker.huggingface.hf_hub_api.hf_api') as mock_hf_api: + mock_model_info = MagicMock() + mock_model_info.pipeline_tag = "old-task" + mock_model_info.transformers_info = MagicMock(pipeline_tag="text-classification") + mock_hf_api.model_info.return_value = mock_model_info + + model = Model(id="test-model", source="huggingface") + task = get_hf_task(model) + + assert task == "text-classification" + +def test_get_hf_task_exception(): + with patch('magemaker.huggingface.hf_hub_api.hf_api') as mock_hf_api, \ + patch('magemaker.huggingface.hf_hub_api.console') as mock_console, \ + patch('magemaker.huggingface.hf_hub_api.print_error') as mock_print_error: + + mock_hf_api.model_info.side_effect = Exception("API error") + + model = Model(id="test-model", source="huggingface") + task = get_hf_task(model) + + assert task is None + mock_console.print_exception.assert_called_once() + mock_print_error.assert_called_once() \ No newline at end of file diff --git a/magemaker/sagemaker/test_create_model.py b/magemaker/sagemaker/test_create_model.py new file mode 100644 index 0000000..15e5d32 --- /dev/null +++ b/magemaker/sagemaker/test_create_model.py @@ -0,0 +1,64 @@ +import pytest +import tempfile +import os +from unittest.mock import MagicMock, patch +from magemaker.sagemaker.create_model import ( + deploy_huggingface_model_to_sagemaker, + deploy_custom_huggingface_model, + create_and_deploy_jumpstart_model +) +from magemaker.schemas.deployment import Deployment +from magemaker.schemas.model import Model, ModelSource + +@pytest.fixture +def sample_huggingface_model(): + return Model(id="google-bert/bert-base-uncased", source=ModelSource.HuggingFace) + +@pytest.fixture +def sample_deployment(): + return Deployment(destination="aws", instance_type="ml.m5.xlarge", instance_count=1) + +@patch('magemaker.sagemaker.create_model.S3Uploader.upload') +@patch('magemaker.sagemaker.create_model.HuggingFaceModel') +def test_custom_model_deployment(mock_hf_model, mock_s3_upload, sample_deployment, tmp_path): + # Create a mock model file + test_model_file = tmp_path / "model.pt" + test_model_file.write_text("dummy model content") + + # Mock S3 upload and model deployment + mock_s3_upload.return_value = "s3://test-bucket/models/test-custom-model" + mock_predictor = MagicMock() + mock_predictor.endpoint_name = "test-endpoint-001" + mock_hf_model_return = mock_hf_model.return_value + mock_hf_model_return.deploy.return_value = mock_predictor + + custom_model = Model( + id="test-custom-model", + source=ModelSource.Custom, + location=str(test_model_file) + ) + + predictor = deploy_custom_huggingface_model(sample_deployment, custom_model) + + assert predictor.endpoint_name == "test-endpoint-001" + mock_s3_upload.assert_called_once() + mock_hf_model_return.deploy.assert_called_once() + +@patch('magemaker.sagemaker.create_model.JumpStartModel') +def test_jumpstart_model_deployment(mock_jumpstart_model, sample_deployment): + # Use a valid JumpStart model ID + jumpstart_model = Model( + id="jumpstart-dft-bert-base-uncased-text-classification", + source=ModelSource.Sagemaker + ) + + # Mock the JumpStart model deployment + mock_predictor = MagicMock() + mock_predictor.endpoint_name = "test-jumpstart-endpoint" + mock_jumpstart_model_return = mock_jumpstart_model.return_value + mock_jumpstart_model_return.deploy.return_value = mock_predictor + + predictor = create_and_deploy_jumpstart_model(sample_deployment, jumpstart_model) + + assert predictor.endpoint_name == "test-jumpstart-endpoint" + mock_jumpstart_model.assert_called_once() \ No newline at end of file diff --git a/magemaker/sagemaker/test_delete_model.py b/magemaker/sagemaker/test_delete_model.py new file mode 100644 index 0000000..de3ee38 --- /dev/null +++ b/magemaker/sagemaker/test_delete_model.py @@ -0,0 +1,23 @@ +import pytest +from unittest.mock import patch, MagicMock +from magemaker.sagemaker.delete_model import delete_sagemaker_model + +def test_delete_sagemaker_model(): + # Test deleting multiple endpoints + with patch('boto3.client') as mock_boto_client: + mock_sagemaker_client = MagicMock() + mock_boto_client.return_value = mock_sagemaker_client + + endpoints = ['endpoint1', 'endpoint2'] + delete_sagemaker_model(endpoints) + + # Check that delete_endpoint was called for each endpoint + assert mock_sagemaker_client.delete_endpoint.call_count == 2 + mock_sagemaker_client.delete_endpoint.assert_any_call(EndpointName='endpoint1') + mock_sagemaker_client.delete_endpoint.assert_any_call(EndpointName='endpoint2') + +def test_delete_empty_endpoints(): + # Test deleting with empty list + with patch('magemaker.sagemaker.delete_model.print_success') as mock_print_success: + delete_sagemaker_model([]) + mock_print_success.assert_called_once_with("No Endpoints to delete!") \ No newline at end of file diff --git a/magemaker/sagemaker/test_fine_tune_model.py b/magemaker/sagemaker/test_fine_tune_model.py new file mode 100644 index 0000000..0c9dd81 --- /dev/null +++ b/magemaker/sagemaker/test_fine_tune_model.py @@ -0,0 +1,80 @@ +import pytest +from unittest.mock import patch, MagicMock +import sys + +# Mock required modules before importing +sys.modules['datasets'] = MagicMock() +sys.modules['transformers'] = MagicMock() + +from magemaker.schemas.training import Training +from magemaker.schemas.model import Model, ModelSource +from magemaker.sagemaker.fine_tune_model import fine_tune_model + +@pytest.fixture +def sample_sagemaker_training(): + return Training( + destination="aws", + instance_type="ml.m5.xlarge", + instance_count=1, + output_path="s3://test-bucket/output", + training_input_path="s3://test-bucket/train" + ) + +@pytest.fixture +def sample_sagemaker_model(): + return Model( + id="jumpstart-dft-bert-base-uncased-text-classification", + source=ModelSource.Sagemaker, + version="1.0" + ) + +@patch('magemaker.sagemaker.fine_tune_model.sagemaker.hyperparameters.retrieve_default') +@patch('magemaker.sagemaker.fine_tune_model.train_model') +@patch('magemaker.sagemaker.fine_tune_model.JumpStartEstimator') +def test_fine_tune_sagemaker_model( + mock_jumpstart_estimator, + mock_train_model, + mock_retrieve_default, + sample_sagemaker_training, + sample_sagemaker_model +): + # Mock hyperparameters retrieval + mock_retrieve_default.return_value = {"param1": "value1"} + + # Setup mock estimator + mock_estimator = MagicMock() + mock_jumpstart_estimator.return_value = mock_estimator + + # Call fine_tune_model + fine_tune_model(sample_sagemaker_training, sample_sagemaker_model) + + # Verify method calls + mock_retrieve_default.assert_called_once() + mock_jumpstart_estimator.assert_called_once() + mock_train_model.assert_called_once() + +@patch('magemaker.sagemaker.fine_tune_model.train_model') +def test_fine_tune_unsupported_model_sources(mock_train_model): + # Test HuggingFace model source + huggingface_model = Model( + id="google-bert/bert-base-uncased", + source=ModelSource.HuggingFace + ) + training = Training( + destination="aws", + instance_type="ml.m5.xlarge", + instance_count=1, + training_input_path="s3://test-bucket/train" + ) + + with pytest.raises(NotImplementedError): + fine_tune_model(training, huggingface_model) + + # Test Custom model source + custom_model = Model( + id="custom-model", + source=ModelSource.Custom + ) + + with pytest.raises(NotImplementedError): + fine_tune_model(training, custom_model) \ No newline at end of file diff --git a/magemaker/sagemaker/test_query_endpoint.py b/magemaker/sagemaker/test_query_endpoint.py new file mode 100644 index 0000000..ad1f1bf --- /dev/null +++ b/magemaker/sagemaker/test_query_endpoint.py @@ -0,0 +1,34 @@ +import pytest +from unittest.mock import patch, MagicMock +import sys + +# Mock problematic imports +sys.modules['inquirer'] = MagicMock() +sys.modules['InquirerPy'] = MagicMock() + +from magemaker.schemas.query import Query +from magemaker.sagemaker.query_endpoint import make_query_request +from magemaker.schemas.deployment import Deployment +from magemaker.schemas.model import Model + +def test_make_query_request(): + with patch('magemaker.sagemaker.query_endpoint.is_sagemaker_model') as mock_is_sagemaker, \ + patch('magemaker.sagemaker.query_endpoint.query_sagemaker_endpoint') as mock_sagemaker_query, \ + patch('magemaker.sagemaker.query_endpoint.query_hugging_face_endpoint') as mock_hf_query: + + # Test Sagemaker model + mock_is_sagemaker.return_value = True + mock_sagemaker_query.return_value = "Sagemaker result" + + query = Query(query="Test query") + config = (MagicMock(), MagicMock()) + + result = make_query_request("test-endpoint", query, config) + assert result == "Sagemaker result" + + # Test HuggingFace model + mock_is_sagemaker.return_value = False + mock_hf_query.return_value = "HuggingFace result" + + result = make_query_request("test-endpoint", query, config) + assert result == "HuggingFace result" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 5a0b4a5..b5b841c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,8 @@ dependencies = [ 'azure-identity==1.19.0', 'azure-mgmt-resource==23.2.0', 'marshmallow==3.23.2', + 'click>=8.0.0', + 'docker>=6.1.0', ] requires-python = ">=3.11" readme = "README.md" @@ -34,6 +36,7 @@ license = {text = "MIT"} [project.scripts] magemaker = "magemaker.runner:runner" +magemaker-docker = "magemaker.docker.cli:main" [build-system] @@ -45,7 +48,7 @@ distribution = true package-dir = "magemaker" [tool.pdm.build] -includes = ["magemaker", "magemaker/scripts/preflight.sh", "magemaker/scripts/setup_role.sh"] # Include setup.sh in the package distribution +includes = ["magemaker", "magemaker/scripts/preflight.sh", "magemaker/scripts/setup_role.sh", "magemaker/docker/Dockerfile", "magemaker/docker/entrypoint.sh"] # Include setup.sh in the package distribution [tool.pdm.dev-dependencies] diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..e69de29 From 584b5dea1d38b9ec946233a7d26656534326581a Mon Sep 17 00:00:00 2001 From: HamzaAyoub033 Date: Thu, 16 Jan 2025 13:04:27 +0000 Subject: [PATCH 2/4] Stablize the Docker Image remove unnecessary files --- .gitignore | 4 +- Dockerfile-server | 12 ++ magemaker/cli.py | 32 ---- pyproject.toml | 1 - scripts/preflight.sh | 384 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 398 insertions(+), 35 deletions(-) delete mode 100644 magemaker/cli.py create mode 100755 scripts/preflight.sh diff --git a/.gitignore b/.gitignore index e3db042..b3fbd5a 100644 --- a/.gitignore +++ b/.gitignore @@ -204,5 +204,5 @@ cython_debug/ models/ configs/ .pdm-python -.magemaker_configs -.env.save \ No newline at end of file +.env.save +.magemaker_configs \ No newline at end of file diff --git a/Dockerfile-server b/Dockerfile-server index 5fde15f..6351096 100644 --- a/Dockerfile-server +++ b/Dockerfile-server @@ -8,6 +8,9 @@ RUN apt-get update && apt-get install -y \ unzip \ nano \ vim \ + gnupg \ + lsb-release \ + ca-certificates \ && rm -rf /var/lib/apt/lists/* # Set working directory @@ -32,6 +35,15 @@ RUN curl https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud && ./google-cloud-sdk/install.sh --quiet \ && rm google-cloud-sdk.tar.gz +# Install Azure CLI +RUN mkdir -p /etc/apt/keyrings && \ + curl -sLS https://packages.microsoft.com/keys/microsoft.asc | gpg --dearmor | tee /etc/apt/keyrings/microsoft.gpg > /dev/null && \ + chmod go+r /etc/apt/keyrings/microsoft.gpg && \ + echo "deb [arch=`dpkg --print-architecture` signed-by=/etc/apt/keyrings/microsoft.gpg] https://packages.microsoft.com/repos/azure-cli/ $(lsb_release -cs) main" | tee /etc/apt/sources.list.d/azure-cli.list && \ + apt-get update && \ + apt-get install -y azure-cli && \ + rm -rf /var/lib/apt/lists/* + # Add Google Cloud SDK to PATH ENV PATH $PATH:/app/google-cloud-sdk/bin diff --git a/magemaker/cli.py b/magemaker/cli.py deleted file mode 100644 index 578ee8e..0000000 --- a/magemaker/cli.py +++ /dev/null @@ -1,32 +0,0 @@ -import click -import yaml -from rich.console import Console -from magemaker.schemas.deployment import Deployment -from magemaker.schemas.model import Model -from magemaker.sagemaker.create_model import deploy_huggingface_model_to_sagemaker -from magemaker.gcp.create_model import deploy_huggingface_model_to_vertexai - -console = Console() - -@click.group() -def main(): - """Magemaker CLI for model deployment""" - pass - -@main.command() -@click.option('--deploy', type=click.Path(exists=True), help='Path to deployment YAML file') -def deploy(deploy): - """Deploy a model using configuration from YAML file""" - with open(deploy, 'r') as f: - config = yaml.safe_load(f) - - deployment = Deployment(**config['deployment']) - model = Model(**config['model']) - - if deployment.destination == "sagemaker": - deploy_huggingface_model_to_sagemaker(deployment, model) - elif deployment.destination == "gcp": - deploy_huggingface_model_to_vertexai(deployment, model) - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index b5b841c..34ab6c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,6 @@ license = {text = "MIT"} [project.scripts] magemaker = "magemaker.runner:runner" -magemaker-docker = "magemaker.docker.cli:main" [build-system] diff --git a/scripts/preflight.sh b/scripts/preflight.sh new file mode 100755 index 0000000..971c2f0 --- /dev/null +++ b/scripts/preflight.sh @@ -0,0 +1,384 @@ +#!/bin/sh + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# Get the directory where the script is located +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +# Logging functions +log_info() { + echo "[INFO] $1" +} + +log_debug() { + echo "[DEBUG] $1" +} + +log_error() { + echo "[ERROR] $1" >&2 +} + +# Configuration functions +configure_aws() { +echo "Configuring AWS..." +echo "you need to create an aws user with access to Sagemaker" +echo "if you don't know how to do that follow this doc https://docs.google.com/document/d/1NvA6uZmppsYzaOdkcgNTRl7Nb4LbpP9Koc4H_t5xNSg/edit?usp=sharing" + + +# green +if ! command -v aws &> /dev/null +then + OS="$(uname -s)" + case "${OS}" in + Linux*) + curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" + unzip awscliv2.zip + sudo ./aws/install + ;; + Darwin*) + curl "https://awscli.amazonaws.com/AWSCLIV2.pkg" -o "AWSCLIV2.pkg" + sudo installer -pkg AWSCLIV2.pkg -target / + ;; + *) + echo "Unsupported OS: ${OS}. See https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html" + exit 1 + ;; + esac +fi + +# echo green that press enter if you have already done this +echo -e "${GREEN}Press enter in the following configuration steps if you have already done this${NC}" + +aws configure set region us-east-1 && aws configure +touch .env + + +if ! grep -q "SAGEMAKER_ROLE" .env +then + # bash ./setup_role.sh + bash "$SCRIPT_DIR/setup_role.sh" +fi +} + + +# GCP + +configure_gcp() { + echo "Configuring GCP..." +echo "you need to create a GCP service account with access to GCS and vertex ai" +echo "if you don't know how to do that follow this doc https://docs.google.com/document/d/1NvA6uZmppsYzaOdkcgNTRl7Nb4LbpP9Koc4H_t5xNSg/edit?usp=sharing" + +if ! command -v gcloud &> /dev/null +then + echo "you need to install gcloud sdk for the terminal" + echo "https://cloud.google.com/sdk/docs/install" +fi + +# only run this if the credentials are not set + + +echo "Checking for gcloud installation..." + +# Check if gcloud is installed +if ! command -v gcloud &> /dev/null; then + echo -e "${RED}Error: gcloud CLI is not installed${NC}" + echo "Please install the Google Cloud SDK first" + exit 1 +fi + +echo "Checking for active gcloud accounts..." + +# Get list of active accounts +ACCOUNTS=$(gcloud auth list --filter=status:ACTIVE --format="value(account)" 2>/dev/null) + +# Check if command was successful +if [ $? -ne 0 ]; then + echo -e "${RED}Error: Failed to retrieve account information${NC}" + echo "Please check your gcloud installation" + exit 1 +fi + +# Check if any accounts are found +if [ -z "$ACCOUNTS" ]; then + echo -e "${YELLOW}No active gcloud accounts found${NC}" + # echo "To login, use: gcloud auth login" + gcloud auth login + exit 0 +fi + +# echo "Setting up application default credentials..." +# gcloud auth application-default login --no-launch-browser + +# if [ $? -ne 0 ]; then +# echo -e "${RED}Failed to set application default credentials${NC}" +# exit 1 +# fi + +# Get current project ID +if ! grep -q "PROJECT_ID" .env +then + PROJECT_ID=$(gcloud config get-value project 2>/dev/null) + if [ -n "$PROJECT_ID" ]; then + export PROJECT_ID="$PROJECT_ID" + echo "PROJECT_ID=$PROJECT_ID" >> .env + echo -e "${GREEN}Exported PROJECT_ID=${NC}${PROJECT_ID}" + else + echo -e "${YELLOW}No project currently set${NC}" + fi +fi + +if ! grep -q "GCLOUD_REGION" .env +then + CURRENT_REGION=$(gcloud config get-value compute/region 2>/dev/null) + if [ -n "$CURRENT_REGION" ]; then + echo "GCLOUD_REGION=$CURRENT_REGION" >> .env + export GCLOUD_REGION="$CURRENT_REGION" + echo -e "${GREEN}Exported GCLOUD_REGION=${NC}${CURRENT_REGION}" + else + echo -e "${YELLOW}No compute region currently set${NC}" + fi +fi +} + +# AZURE +configure_azure() { +echo "Configuring Azure..." +echo "Checking for Azure CLI installation..." +if ! command -v az &> /dev/null +then + echo "Azure CLI not found. Installing..." + OS="$(uname -s)" + case "${OS}" in + Linux*) + curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash + ;; + Darwin*) + brew update && brew install azure-cli + ;; + *) + echo "Unsupported OS: ${OS}. See https://docs.microsoft.com/en-us/cli/azure/install-azure-cli" + exit 1 + ;; + esac +fi + +# Check Azure login status +echo "Checking Azure login status..." +if ! az account show &> /dev/null; then + echo "Not logged into Azure. Please log in..." + az login + if [ $? -ne 0 ]; then + echo "Azure login failed. Please try again." + exit 1 + fi +fi + +# Get and set subscription +if ! grep -q "AZURE_SUBSCRIPTION_ID" .env; then + SUBSCRIPTION_ID=$(az account show --query id -o tsv) + if [ -n "$SUBSCRIPTION_ID" ]; then + echo "AZURE_SUBSCRIPTION_ID=$SUBSCRIPTION_ID" >> .env + export AZURE_SUBSCRIPTION_ID="$SUBSCRIPTION_ID" + echo "Exported AZURE_SUBSCRIPTION_ID=${SUBSCRIPTION_ID}" + else + echo "No Azure subscription found" + exit 1 + fi +fi + +# Get and set resource group +if ! grep -q "AZURE_RESOURCE_GROUP" .env; then + echo "Listing resource groups..." + az group list -o table + echo "Please enter the resource group name to use:" + read RESOURCE_GROUP + if [ -n "$RESOURCE_GROUP" ]; then + echo "AZURE_RESOURCE_GROUP=$RESOURCE_GROUP" >> .env + export AZURE_RESOURCE_GROUP="$RESOURCE_GROUP" + echo "Exported AZURE_RESOURCE_GROUP=${RESOURCE_GROUP}" + else + echo "No resource group specified" + exit 1 + fi +fi + +# Get and set region +if ! grep -q "AZURE_REGION" .env; then + CURRENT_REGION=$(az group show --name $AZURE_RESOURCE_GROUP --query location -o tsv) + if [ -n "$CURRENT_REGION" ]; then + echo "AZURE_REGION=$CURRENT_REGION" >> .env + export AZURE_REGION="$CURRENT_REGION" + echo "Exported AZURE_REGION=${CURRENT_REGION}" + else + echo "Available Azure regions:" + az account list-locations --query "[].{Region:name}" -o table + echo "Please enter the Azure region to use:" + read AZURE_REGION + echo "AZURE_REGION=$AZURE_REGION" >> .env + export AZURE_REGION="$AZURE_REGION" + echo "Exported AZURE_REGION=${AZURE_REGION}" + fi +fi + +# Check Azure ML workspace +echo "Checking Azure ML workspace..." +if ! grep -q "AZURE_WORKSPACE_NAME" .env; then + # List available workspaces + echo "Available Azure ML workspaces in resource group $AZURE_RESOURCE_GROUP:" + az ml workspace list --resource-group $AZURE_RESOURCE_GROUP -o table + + echo "Please enter the Azure ML workspace name to use:" + read WORKSPACE_NAME + + if [ -n "$WORKSPACE_NAME" ]; then + # Verify workspace exists + if az ml workspace show --name $WORKSPACE_NAME --resource-group $AZURE_RESOURCE_GROUP &> /dev/null; then + echo "AZURE_WORKSPACE_NAME=$WORKSPACE_NAME" >> .env + export AZURE_ML_WORKSPACE="$WORKSPACE_NAME" + echo "Exported AZURE_WORKSPACE_NAME=${WORKSPACE_NAME}" + else + echo "Workspace $WORKSPACE_NAME not found in resource group $AZURE_RESOURCE_GROUP" + exit 1 + fi + else + echo "No workspace specified" + exit 1 + fi +fi + +# Function to check and register Azure resource providers +check_and_register_providers() { + local providers=( + "Microsoft.MachineLearningServices" + "Microsoft.ContainerRegistry" + "Microsoft.KeyVault" + "Microsoft.Storage" + "Microsoft.Insights" + "Microsoft.ContainerService" + "Microsoft.PolicyInsights" + "Microsoft.Cdn" + ) + + echo "Checking Azure resource providers..." + for provider in "${providers[@]}"; do + echo "Checking registration status for: $provider" + + # Get the registration state + state=$(az provider show --namespace $provider --query registrationState -o tsv 2>/dev/null) + + if [ "$state" != "Registered" ]; then + echo "$provider is not registered. Registering now..." + az provider register --namespace $provider + + # Wait for registration to complete + echo "Waiting for $provider registration to complete..." + while true; do + state=$(az provider show --namespace $provider --query registrationState -o tsv) + if [ "$state" == "Registered" ]; then + echo "$provider registration completed" + break + fi + echo "Registration in progress... waiting 10 seconds" + sleep 10 + done + else + echo "$provider is already registered" + fi + done + + echo "All required resource providers are registered" +} + +# Add this line after the Azure login check +echo "Checking and registering required Azure resource providers..." +check_and_register_providers + + +# Verify all required Azure environment variables are set +echo "Verifying Azure environment variables..." +REQUIRED_VARS=("AZURE_SUBSCRIPTION_ID" "AZURE_RESOURCE_GROUP" "AZURE_REGION" "AZURE_WORKSPACE_NAME") +for var in "${REQUIRED_VARS[@]}"; do + if ! grep -q "$var" .env; then + echo "Missing required environment variable: $var" + exit 1 + fi +done + +echo "Azure environment setup completed successfully!" + +# touch .env +} +configure_all_providers() { + log_info "Performing comprehensive multi-cloud configuration..." + + + # Detailed configuration for each cloud + configure_aws + configure_gcp + configure_azure + log_info "Multi-cloud configuration completed successfully" +} + +# Argument parsing +CLOUD="" +while [ $# -gt 0 ]; do + case "$1" in + --cloud) + shift + CLOUD="$1" + break + ;; + --cloud=*) + CLOUD="${1#*=}" + break + ;; + esac + shift +done + +# log_debug "Raw arguments: $@" +# log_debug "Cloud argument received: '$CLOUD'" + +# Validate cloud argument +# validate_cloud_arg() { +# case "$1" in +# aws|gcp|azure) +# return 0 +# ;; +# *) +# log_error "Invalid cloud provider: '$1'" +# log_error "Supported providers: aws, gcp, azure" +# exit 1 +# ;; +# esac +# } + +# Main configuration logic +main_configuration() { + # Validate cloud argument + # validate_cloud_arg "$CLOUD" + + # Configure specific cloud provider + case "$CLOUD" in + aws) + configure_aws + ;; + gcp) + configure_gcp + ;; + azure) + configure_azure + ;; + all) + configure_all_providers + ;; + *) + esac +} + +# Execute main configuration +main_configuration \ No newline at end of file From 02d4569064b90f00d41752e7991ee8916977676c Mon Sep 17 00:00:00 2001 From: HamzaAyoub033 Date: Mon, 20 Jan 2025 22:18:03 +0000 Subject: [PATCH 3/4] Remove tests files from Docker PR --- .gitignore | 3 +- ...k--opt-125m-202410251736-202501091758.yaml | 15 - magemaker/huggingface/test_hf_hub_api.py | 42 -- magemaker/sagemaker/create_model.py | 218 ---------- magemaker/sagemaker/delete_model.py | 17 - magemaker/sagemaker/fine_tune_model.py | 119 ------ magemaker/sagemaker/query_endpoint.py | 205 ---------- magemaker/sagemaker/test_create_model.py | 64 --- magemaker/sagemaker/test_delete_model.py | 23 -- magemaker/sagemaker/test_fine_tune_model.py | 80 ---- magemaker/sagemaker/test_query_endpoint.py | 34 -- scripts/preflight.sh | 384 ------------------ 12 files changed, 1 insertion(+), 1203 deletions(-) delete mode 100644 .magemaker_configs/facebook--opt-125m-202410251736-202501091758.yaml delete mode 100644 magemaker/huggingface/test_hf_hub_api.py delete mode 100644 magemaker/sagemaker/create_model.py delete mode 100644 magemaker/sagemaker/delete_model.py delete mode 100644 magemaker/sagemaker/fine_tune_model.py delete mode 100644 magemaker/sagemaker/query_endpoint.py delete mode 100644 magemaker/sagemaker/test_create_model.py delete mode 100644 magemaker/sagemaker/test_delete_model.py delete mode 100644 magemaker/sagemaker/test_fine_tune_model.py delete mode 100644 magemaker/sagemaker/test_query_endpoint.py delete mode 100755 scripts/preflight.sh diff --git a/.gitignore b/.gitignore index b3fbd5a..65d4606 100644 --- a/.gitignore +++ b/.gitignore @@ -204,5 +204,4 @@ cython_debug/ models/ configs/ .pdm-python -.env.save -.magemaker_configs \ No newline at end of file +.env.save \ No newline at end of file diff --git a/.magemaker_configs/facebook--opt-125m-202410251736-202501091758.yaml b/.magemaker_configs/facebook--opt-125m-202410251736-202501091758.yaml deleted file mode 100644 index 77bc415..0000000 --- a/.magemaker_configs/facebook--opt-125m-202410251736-202501091758.yaml +++ /dev/null @@ -1,15 +0,0 @@ -deployment: !Deployment - destination: aws - endpoint_name: facebook--opt-125m-202410251736-202501091758 - instance_count: 1 - instance_type: ml.t2.medium - num_gpus: null - quantization: null -models: -- !Model - id: facebook/opt-125m - location: null - predict: null - source: huggingface - task: text-generation - version: null diff --git a/magemaker/huggingface/test_hf_hub_api.py b/magemaker/huggingface/test_hf_hub_api.py deleted file mode 100644 index 35bebe7..0000000 --- a/magemaker/huggingface/test_hf_hub_api.py +++ /dev/null @@ -1,42 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock -from magemaker.schemas.model import Model -from magemaker.huggingface.hf_hub_api import get_hf_task - -def test_get_hf_task_successful(): - with patch('magemaker.huggingface.hf_hub_api.hf_api') as mock_hf_api: - mock_model_info = MagicMock() - mock_model_info.pipeline_tag = "text-classification" - mock_model_info.transformers_info = None - mock_hf_api.model_info.return_value = mock_model_info - - model = Model(id="test-model", source="huggingface") - task = get_hf_task(model) - - assert task == "text-classification" - -def test_get_hf_task_with_transformers_info(): - with patch('magemaker.huggingface.hf_hub_api.hf_api') as mock_hf_api: - mock_model_info = MagicMock() - mock_model_info.pipeline_tag = "old-task" - mock_model_info.transformers_info = MagicMock(pipeline_tag="text-classification") - mock_hf_api.model_info.return_value = mock_model_info - - model = Model(id="test-model", source="huggingface") - task = get_hf_task(model) - - assert task == "text-classification" - -def test_get_hf_task_exception(): - with patch('magemaker.huggingface.hf_hub_api.hf_api') as mock_hf_api, \ - patch('magemaker.huggingface.hf_hub_api.console') as mock_console, \ - patch('magemaker.huggingface.hf_hub_api.print_error') as mock_print_error: - - mock_hf_api.model_info.side_effect = Exception("API error") - - model = Model(id="test-model", source="huggingface") - task = get_hf_task(model) - - assert task is None - mock_console.print_exception.assert_called_once() - mock_print_error.assert_called_once() \ No newline at end of file diff --git a/magemaker/sagemaker/create_model.py b/magemaker/sagemaker/create_model.py deleted file mode 100644 index 8612d33..0000000 --- a/magemaker/sagemaker/create_model.py +++ /dev/null @@ -1,218 +0,0 @@ -import json -from dotenv import dotenv_values -from rich.table import Table -from sagemaker import image_uris, model_uris, script_uris -from sagemaker.huggingface import get_huggingface_llm_image_uri -from sagemaker.huggingface.model import HuggingFaceModel -from sagemaker.jumpstart.model import JumpStartModel -from sagemaker.jumpstart.estimator import JumpStartEstimator -from sagemaker.model import Model -from sagemaker.predictor import Predictor -from sagemaker.s3 import S3Uploader -from magemaker.config import write_config -from magemaker.schemas.model import Model, ModelSource -from magemaker.schemas.deployment import Deployment -from magemaker.session import session, sagemaker_session -from magemaker.console import console -from magemaker.utils.aws_utils import construct_s3_uri, is_s3_uri -from magemaker.utils.rich_utils import print_error, print_success -from magemaker.utils.model_utils import get_unique_endpoint_name, get_model_and_task -from magemaker.huggingface import HuggingFaceTask -from magemaker.huggingface.hf_hub_api import get_hf_task - - - -def deploy_huggingface_model_to_sagemaker(deployment, model): - HUGGING_FACE_HUB_TOKEN = dotenv_values(".env").get("HUGGING_FACE_HUB_KEY") - SAGEMAKER_ROLE = dotenv_values(".env").get("SAGEMAKER_ROLE") - - region_name = session.region_name - task = get_hf_task(model) - model.task = task - env = { - 'HF_MODEL_ID': model.id, - 'HF_TASK': task, - } - - if HUGGING_FACE_HUB_TOKEN is not None: - env['HUGGING_FACE_HUB_TOKEN'] = HUGGING_FACE_HUB_TOKEN - - image_uri = None - if deployment.num_gpus: - env['SM_NUM_GPUS'] = json.dumps(deployment.num_gpus) - - if deployment.quantization: - env['HF_MODEL_QUANTIZE'] = deployment.quantization - - if task == HuggingFaceTask.TextGeneration: - # use TGI imageq if llm. - image_uri = get_huggingface_llm_image_uri( - "huggingface", - version="1.4.2" - ) - - huggingface_model = HuggingFaceModel( - env=env, - role=SAGEMAKER_ROLE, - transformers_version="4.37", - pytorch_version="2.1", - py_version="py310", - image_uri=image_uri - ) - - endpoint_name = get_unique_endpoint_name( - model.id, deployment.endpoint_name) - - deployment.endpoint_name = endpoint_name - - console.log( - "Deploying model to AWS. [magenta]This may take up to 10 minutes for very large models.[/magenta] See full logs here:") - console.print( - f"https://{region_name}.console.aws.amazon.com/cloudwatch/home#logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252FEndpoints$252F{endpoint_name}") - - with console.status("[bold green]Deploying model...") as status: - table = Table(show_header=False, header_style="magenta") - table.add_column("Resource", style="dim") - table.add_column("Value", style="blue") - table.add_row("model", model.id) - table.add_row("EC2 instance type", deployment.instance_type) - table.add_row("Number of instances", str( - deployment.instance_count)) - table.add_row("task", task) - console.print(table) - - try: - predictor = huggingface_model.deploy( - initial_instance_count=deployment.instance_count, - instance_type=deployment.instance_type, - endpoint_name=endpoint_name, - ) - except Exception: - console.print_exception() - quit() - - print_success( - f"{model.id} is now up and running at the endpoint [blue]{predictor.endpoint_name}") - - write_config(deployment, model) - return predictor - -def deploy_huggingface_model_to_vertexai(deployment, model): - pass - - - -def deploy_custom_huggingface_model(deployment: Deployment, model: Model): - SAGEMAKER_ROLE = dotenv_values(".env").get("SAGEMAKER_ROLE") - - region_name = session.region_name - if model.location is None: - print_error("Missing model source location.") - return - - s3_path = model.location - if not is_s3_uri(model.location): - # Local file. Upload to s3 before deploying - bucket = sagemaker_session.default_bucket() - s3_path = construct_s3_uri(bucket, f"models/{model.id}") - with console.status(f"[bold green]Uploading custom {model.id} model to S3 at {s3_path}...") as status: - try: - s3_path = S3Uploader.upload( - model.location, s3_path) - except Exception: - print_error("[red] Model failed to upload to S3") - - endpoint_name = get_unique_endpoint_name( - model.id, deployment.endpoint_name) - - deployment.endpoint_name = endpoint_name - model.task = get_model_and_task(model.id)['task'] - - console.log( - "Deploying model to AWS. [magenta]This may take up to 10 minutes for very large models.[/magenta] See full logs here:") - console.print( - f"https://{region_name}.console.aws.amazon.com/cloudwatch/home#logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252FEndpoints$252F{endpoint_name}") - - # create Hugging Face Model Class - huggingface_model = HuggingFaceModel( - # path to your trained sagemaker model - model_data=s3_path, - role=SAGEMAKER_ROLE, # iam role with permissions to create an Endpoint - transformers_version="4.37", - pytorch_version="2.1", - py_version="py310", - ) - - with console.status("[bold green]Deploying model...") as status: - table = Table(show_header=False, header_style="magenta") - table.add_column("Resource", style="dim") - table.add_column("Value", style="blue") - table.add_row("S3 Path", s3_path) - table.add_row("EC2 instance type", deployment.instance_type) - table.add_row("Number of instances", str( - deployment.instance_count)) - console.print(table) - - try: - predictor = huggingface_model.deploy( - initial_instance_count=deployment.instance_count, - instance_type=deployment.instance_type, - endpoint_name=endpoint_name - ) - except Exception: - console.print_exception() - quit() - - print_success( - f"Custom {model.id} is now up and running at the endpoint [blue]{predictor.endpoint_name}") - - write_config(deployment, model) - return predictor - - -def create_and_deploy_jumpstart_model(deployment: Deployment, model: Model): - SAGEMAKER_ROLE = dotenv_values(".env").get("SAGEMAKER_ROLE") - - region_name = session.region_name - endpoint_name = get_unique_endpoint_name( - model.id, deployment.endpoint_name) - deployment.endpoint_name = endpoint_name - model.task = get_model_and_task(model.id)['task'] - - console.log( - "Deploying model to AWS. [magenta]This may take up to 10 minutes for very large models.[/magenta] See full logs here:") - - console.print( - f"https://{region_name}.console.aws.amazon.com/cloudwatch/home#logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252FEndpoints$252F{endpoint_name}") - - with console.status("[bold green]Deploying model...") as status: - table = Table(show_header=False, header_style="magenta") - table.add_column("Resource", style="dim") - table.add_column("Value", style="blue") - table.add_row("model", model.id) - table.add_row("EC2 instance type", deployment.instance_type) - table.add_row("Number of instances", str( - deployment.instance_count)) - console.print(table) - - jumpstart_model = JumpStartModel( - model_id=model.id, instance_type=deployment.instance_type, role=SAGEMAKER_ROLE) - - # Attempt to deploy to AWS - try: - predictor = jumpstart_model.deploy( - initial_instance_count=deployment.instance_count, - instance_type=deployment.instance_type, - endpoint_name=endpoint_name, - accept_eula=True - ) - pass - except Exception: - console.print_exception() - quit() - - write_config(deployment, model) - print_success( - f"{model.id} is now up and running at the endpoint [blue]{predictor.endpoint_name}") - - return predictor diff --git a/magemaker/sagemaker/delete_model.py b/magemaker/sagemaker/delete_model.py deleted file mode 100644 index 7c10078..0000000 --- a/magemaker/sagemaker/delete_model.py +++ /dev/null @@ -1,17 +0,0 @@ -import boto3 -from rich import print -from magemaker.utils.rich_utils import print_success -from typing import List - - -def delete_sagemaker_model(endpoint_names: List[str] = None): - sagemaker_client = boto3.client('sagemaker') - - if len(endpoint_names) == 0: - print_success("No Endpoints to delete!") - return - - # Add validation / error handling - for endpoint in endpoint_names: - print(f"Deleting [blue]{endpoint}") - sagemaker_client.delete_endpoint(EndpointName=endpoint) diff --git a/magemaker/sagemaker/fine_tune_model.py b/magemaker/sagemaker/fine_tune_model.py deleted file mode 100644 index 4fd597d..0000000 --- a/magemaker/sagemaker/fine_tune_model.py +++ /dev/null @@ -1,119 +0,0 @@ -import logging -import os -import sagemaker -from botocore.exceptions import ClientError -from datasets import load_dataset -from rich import print -from rich.table import Table -from sagemaker.jumpstart.estimator import JumpStartEstimator -from magemaker.console import console -from magemaker.schemas.model import Model, ModelSource -from magemaker.schemas.training import Training -from magemaker.session import sagemaker_session -from magemaker.utils.aws_utils import is_s3_uri -from magemaker.utils.rich_utils import print_success, print_error -from transformers import AutoTokenizer - -from dotenv import load_dotenv -load_dotenv() - - -def prep_hf_data(s3_bucket: str, dataset_name_or_path: str, model: Model): - train_dataset, test_dataset = load_dataset( - dataset_name_or_path, split=["train", "test"]) - tokenizer = AutoTokenizer.from_pretrained(model.id) - - def tokenize(batch): - return tokenizer(batch["text"], padding="max_length", truncation=True) - - # tokenize train and test datasets - train_dataset = train_dataset.map(tokenize, batched=True) - test_dataset = test_dataset.map(tokenize, batched=True) - - # set dataset format for PyTorch - train_dataset = train_dataset.rename_column("label", "labels") - train_dataset.set_format( - "torch", columns=["input_ids", "attention_mask", "labels"]) - test_dataset = test_dataset.rename_column("label", "labels") - test_dataset.set_format( - "torch", columns=["input_ids", "attention_mask", "labels"]) - - # save train_dataset to s3 - training_input_path = f's3://{s3_bucket}/datasets/train' - train_dataset.save_to_disk(training_input_path) - - # save test_dataset to s3 - test_input_path = f's3://{s3_bucket}/datasets/test' - test_dataset.save_to_disk(test_input_path) - - return training_input_path, test_input_path - - -def train_model(training: Training, model: Model, estimator): - # TODO: Accept hf datasets or local paths to upload to s3 - if not is_s3_uri(training.training_input_path): - raise Exception("Training data needs to be uploaded to s3") - - # TODO: Implement training, validation, and test split or accept a directory of files - training_dataset_s3_path = training.training_input_path - - table = Table(show_header=False, header_style="magenta") - table.add_column("Resource", style="dim") - table.add_column("Value", style="blue") - table.add_row("model", model.id) - table.add_row("model_version", model.version) - table.add_row("base_model_uri", estimator.model_uri) - table.add_row("image_uri", estimator.image_uri) - table.add_row("EC2 instance type", training.instance_type) - table.add_row("Number of instances", str(training.instance_count)) - console.print(table) - - estimator.fit({"training": training_dataset_s3_path}) - - predictor = estimator.deploy( - initial_instance_count=training.instance_count, instance_type=training.instance_type) - - print_success( - f"Trained model {model.id} is now up and running at the endpoint [blue]{predictor.endpoint_name}") - - -def fine_tune_model(training: Training, model: Model): - SAGEMAKER_ROLE = os.environ.get("SAGEMAKER_ROLE") - - estimator = None - match model.source: - case ModelSource.Sagemaker: - hyperparameters = get_hyperparameters_for_model(training, model) - estimator = JumpStartEstimator( - model_id=model.id, - model_version=model.version, - instance_type=training.instance_type, - instance_count=training.instance_count, - output_path=training.output_path, - environment={"accept_eula": "true"}, - role=SAGEMAKER_ROLE, - sagemaker_session=sagemaker_session, - hyperparameters=hyperparameters - ) - case ModelSource.HuggingFace: - raise NotImplementedError - case ModelSource.Custom: - raise NotImplementedError - - try: - print_success("Enqueuing training job") - res = train_model(training, model, estimator) - except ClientError as e: - logging.error(e) - print_error("Training job enqueue fail") - return False - - -def get_hyperparameters_for_model(training: Training, model: Model): - hyperparameters = sagemaker.hyperparameters.retrieve_default( - model_id=model.id, model_version=model.version) - - if training.hyperparameters is not None: - hyperparameters.update( - (k, v) for k, v in training.hyperparameters.model_dump().items() if v is not None) - return hyperparameters diff --git a/magemaker/sagemaker/query_endpoint.py b/magemaker/sagemaker/query_endpoint.py deleted file mode 100644 index 8c425a4..0000000 --- a/magemaker/sagemaker/query_endpoint.py +++ /dev/null @@ -1,205 +0,0 @@ -import boto3 -import json -import inquirer -from InquirerPy import prompt -from sagemaker.huggingface.model import HuggingFacePredictor -from magemaker.config import ModelDeployment -from magemaker.console import console -from magemaker.sagemaker import SagemakerTask -from magemaker.huggingface import HuggingFaceTask -from magemaker.utils.model_utils import get_model_and_task, is_sagemaker_model, get_text_generation_hyperpameters -from magemaker.utils.rich_utils import print_error -from magemaker.schemas.deployment import Deployment -from magemaker.schemas.model import Model -from magemaker.schemas.query import Query -from magemaker.session import sagemaker_session -from typing import Dict, Tuple, Optional - - -def make_query_request(endpoint_name: str, query: Query, config: Tuple[Deployment, Model]): - if is_sagemaker_model(endpoint_name, config): - return query_sagemaker_endpoint(endpoint_name, query, config) - else: - return query_hugging_face_endpoint(endpoint_name, query, config) - - -def parse_response(query_response): - model_predictions = json.loads(query_response['Body'].read()) - probabilities, labels, predicted_label = model_predictions[ - 'probabilities'], model_predictions['labels'], model_predictions['predicted_label'] - return probabilities, labels, predicted_label - - -def query_hugging_face_endpoint(endpoint_name: str, user_query: Query, config: Tuple[Deployment, Model]): - task = get_model_and_task(endpoint_name, config)['task'] - predictor = HuggingFacePredictor(endpoint_name=endpoint_name, - sagemaker_session=sagemaker_session) - - query = user_query.query - context = user_query.context - - input = {"inputs": query} - if task is not None and task == HuggingFaceTask.QuestionAnswering: - if context is None: - questions = [{ - "type": "input", "message": "What context would you like to provide?:", "name": "context"}] - answers = prompt(questions) - context = answers.get('context', '') - - if not context: - raise Exception("Must provide context for question-answering") - - input = {} - input['context'] = answers['context'] - input['question'] = query - - if task is not None and task == HuggingFaceTask.TextGeneration: - parameters = get_text_generation_hyperpameters(config, user_query) - input['parameters'] = parameters - - if task is not None and task == HuggingFaceTask.ZeroShotClassification: - if context is None: - questions = [ - inquirer.Text('labels', - message="What labels would you like to use? (comma separated values)?", - ) - ] - answers = inquirer.prompt(questions) - context = answers.get('labels', '') - - if not context: - raise Exception( - "Must provide labels for zero shot text classification") - - labels = context.split(',') - input = json.dumps({ - "sequences": query, - "candidate_labels": labels - }) - - try: - result = predictor.predict(input) - except Exception: - console.print_exception() - quit() - - print(result) - return result - - -def query_sagemaker_endpoint(endpoint_name: str, user_query: Query, config: Tuple[Deployment, Model]): - client = boto3.client('runtime.sagemaker') - task = get_model_and_task(endpoint_name, config)['task'] - - if task not in [ - SagemakerTask.ExtractiveQuestionAnswering, - SagemakerTask.TextClassification, - SagemakerTask.SentenceSimilarity, - SagemakerTask.SentencePairClassification, - SagemakerTask.Summarization, - SagemakerTask.NamedEntityRecognition, - SagemakerTask.TextEmbedding, - SagemakerTask.TcEmbedding, - SagemakerTask.TextGeneration, - SagemakerTask.TextGeneration1, - SagemakerTask.TextGeneration2, - SagemakerTask.Translation, - SagemakerTask.FillMask, - SagemakerTask.ZeroShotTextClassification - ]: - print_error(""" -Querying this model type inside of Model Manager isn’t yet supported. -You can query it directly through the API endpoint - see here for documentation on how to do this: -https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html - """) - raise Exception("Unsupported") - - # MIME content type varies per deployment - content_type = "application/x-text" - accept_type = "application/json;verbose" - - # Depending on the task, input needs to be formatted differently. - # e.g. question-answering needs to have {question: , context: } - query = user_query.query - context = user_query.context - input = query.encode("utf-8") - match task: - case SagemakerTask.ExtractiveQuestionAnswering: - if context is None: - questions = [ - { - 'type': 'input', - 'name': 'context', - 'message': "What context would you like to provide?", - } - ] - answers = prompt(questions) - context = answers.get("context", '') - - if not context: - raise Exception("Must provide context for question-answering") - - content_type = "application/list-text" - input = json.dumps([query, context]).encode("utf-8") - - case SagemakerTask.SentencePairClassification: - if context is None: - questions = [ - inquirer.Text('context', - message="What sentence would you like to compare against?", - ) - ] - answers = inquirer.prompt(questions) - context = answers.get("context", '') - if not context: - raise Exception( - "Must provide a second sentence for sentence pair classification") - - content_type = "application/list-text" - input = json.dumps([query, context]).encode("utf-8") - case SagemakerTask.ZeroShotTextClassification: - if context is None: - questions = [ - inquirer.Text('labels', - message="What labels would you like to use? (comma separated values)?", - ) - ] - answers = inquirer.prompt(questions) - context = answers.get('labels', '') - - if not context: - raise Exception( - "must provide labels for zero shot text classification") - labels = context.split(',') - - content_type = "application/json" - input = json.dumps({ - "sequences": query, - "candidate_labels": labels, - }).encode("utf-8") - case SagemakerTask.TextGeneration: - parameters = get_text_generation_hyperpameters(config, user_query) - input = json.dumps({ - "inputs": query, - "parameters": parameters, - }).encode("utf-8") - content_type = "application/json" - - try: - response = client.invoke_endpoint( - EndpointName=endpoint_name, ContentType=content_type, Body=input, Accept=accept_type) - except Exception: - console.print_exception() - quit() - - model_predictions = json.loads(response['Body'].read()) - print(model_predictions) - return model_predictions - - -def test(endpoint_name: str): - text1 = 'astonishing ... ( frames ) profound ethical and philosophical questions in the form of dazzling pop entertainment' - text2 = 'simply stupid , irrelevant and deeply , truly , bottomlessly cynical ' - - for text in [text1, text2]: - query_sagemaker_endpoint(endpoint_name, text.encode('utf-8')) diff --git a/magemaker/sagemaker/test_create_model.py b/magemaker/sagemaker/test_create_model.py deleted file mode 100644 index 15e5d32..0000000 --- a/magemaker/sagemaker/test_create_model.py +++ /dev/null @@ -1,64 +0,0 @@ -import pytest -import tempfile -import os -from unittest.mock import MagicMock, patch -from magemaker.sagemaker.create_model import ( - deploy_huggingface_model_to_sagemaker, - deploy_custom_huggingface_model, - create_and_deploy_jumpstart_model -) -from magemaker.schemas.deployment import Deployment -from magemaker.schemas.model import Model, ModelSource - -@pytest.fixture -def sample_huggingface_model(): - return Model(id="google-bert/bert-base-uncased", source=ModelSource.HuggingFace) - -@pytest.fixture -def sample_deployment(): - return Deployment(destination="aws", instance_type="ml.m5.xlarge", instance_count=1) - -@patch('magemaker.sagemaker.create_model.S3Uploader.upload') -@patch('magemaker.sagemaker.create_model.HuggingFaceModel') -def test_custom_model_deployment(mock_hf_model, mock_s3_upload, sample_deployment, tmp_path): - # Create a mock model file - test_model_file = tmp_path / "model.pt" - test_model_file.write_text("dummy model content") - - # Mock S3 upload and model deployment - mock_s3_upload.return_value = "s3://test-bucket/models/test-custom-model" - mock_predictor = MagicMock() - mock_predictor.endpoint_name = "test-endpoint-001" - mock_hf_model_return = mock_hf_model.return_value - mock_hf_model_return.deploy.return_value = mock_predictor - - custom_model = Model( - id="test-custom-model", - source=ModelSource.Custom, - location=str(test_model_file) - ) - - predictor = deploy_custom_huggingface_model(sample_deployment, custom_model) - - assert predictor.endpoint_name == "test-endpoint-001" - mock_s3_upload.assert_called_once() - mock_hf_model_return.deploy.assert_called_once() - -@patch('magemaker.sagemaker.create_model.JumpStartModel') -def test_jumpstart_model_deployment(mock_jumpstart_model, sample_deployment): - # Use a valid JumpStart model ID - jumpstart_model = Model( - id="jumpstart-dft-bert-base-uncased-text-classification", - source=ModelSource.Sagemaker - ) - - # Mock the JumpStart model deployment - mock_predictor = MagicMock() - mock_predictor.endpoint_name = "test-jumpstart-endpoint" - mock_jumpstart_model_return = mock_jumpstart_model.return_value - mock_jumpstart_model_return.deploy.return_value = mock_predictor - - predictor = create_and_deploy_jumpstart_model(sample_deployment, jumpstart_model) - - assert predictor.endpoint_name == "test-jumpstart-endpoint" - mock_jumpstart_model.assert_called_once() \ No newline at end of file diff --git a/magemaker/sagemaker/test_delete_model.py b/magemaker/sagemaker/test_delete_model.py deleted file mode 100644 index de3ee38..0000000 --- a/magemaker/sagemaker/test_delete_model.py +++ /dev/null @@ -1,23 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock -from magemaker.sagemaker.delete_model import delete_sagemaker_model - -def test_delete_sagemaker_model(): - # Test deleting multiple endpoints - with patch('boto3.client') as mock_boto_client: - mock_sagemaker_client = MagicMock() - mock_boto_client.return_value = mock_sagemaker_client - - endpoints = ['endpoint1', 'endpoint2'] - delete_sagemaker_model(endpoints) - - # Check that delete_endpoint was called for each endpoint - assert mock_sagemaker_client.delete_endpoint.call_count == 2 - mock_sagemaker_client.delete_endpoint.assert_any_call(EndpointName='endpoint1') - mock_sagemaker_client.delete_endpoint.assert_any_call(EndpointName='endpoint2') - -def test_delete_empty_endpoints(): - # Test deleting with empty list - with patch('magemaker.sagemaker.delete_model.print_success') as mock_print_success: - delete_sagemaker_model([]) - mock_print_success.assert_called_once_with("No Endpoints to delete!") \ No newline at end of file diff --git a/magemaker/sagemaker/test_fine_tune_model.py b/magemaker/sagemaker/test_fine_tune_model.py deleted file mode 100644 index 0c9dd81..0000000 --- a/magemaker/sagemaker/test_fine_tune_model.py +++ /dev/null @@ -1,80 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock -import sys - -# Mock required modules before importing -sys.modules['datasets'] = MagicMock() -sys.modules['transformers'] = MagicMock() - -from magemaker.schemas.training import Training -from magemaker.schemas.model import Model, ModelSource -from magemaker.sagemaker.fine_tune_model import fine_tune_model - -@pytest.fixture -def sample_sagemaker_training(): - return Training( - destination="aws", - instance_type="ml.m5.xlarge", - instance_count=1, - output_path="s3://test-bucket/output", - training_input_path="s3://test-bucket/train" - ) - -@pytest.fixture -def sample_sagemaker_model(): - return Model( - id="jumpstart-dft-bert-base-uncased-text-classification", - source=ModelSource.Sagemaker, - version="1.0" - ) - -@patch('magemaker.sagemaker.fine_tune_model.sagemaker.hyperparameters.retrieve_default') -@patch('magemaker.sagemaker.fine_tune_model.train_model') -@patch('magemaker.sagemaker.fine_tune_model.JumpStartEstimator') -def test_fine_tune_sagemaker_model( - mock_jumpstart_estimator, - mock_train_model, - mock_retrieve_default, - sample_sagemaker_training, - sample_sagemaker_model -): - # Mock hyperparameters retrieval - mock_retrieve_default.return_value = {"param1": "value1"} - - # Setup mock estimator - mock_estimator = MagicMock() - mock_jumpstart_estimator.return_value = mock_estimator - - # Call fine_tune_model - fine_tune_model(sample_sagemaker_training, sample_sagemaker_model) - - # Verify method calls - mock_retrieve_default.assert_called_once() - mock_jumpstart_estimator.assert_called_once() - mock_train_model.assert_called_once() - -@patch('magemaker.sagemaker.fine_tune_model.train_model') -def test_fine_tune_unsupported_model_sources(mock_train_model): - # Test HuggingFace model source - huggingface_model = Model( - id="google-bert/bert-base-uncased", - source=ModelSource.HuggingFace - ) - training = Training( - destination="aws", - instance_type="ml.m5.xlarge", - instance_count=1, - training_input_path="s3://test-bucket/train" - ) - - with pytest.raises(NotImplementedError): - fine_tune_model(training, huggingface_model) - - # Test Custom model source - custom_model = Model( - id="custom-model", - source=ModelSource.Custom - ) - - with pytest.raises(NotImplementedError): - fine_tune_model(training, custom_model) \ No newline at end of file diff --git a/magemaker/sagemaker/test_query_endpoint.py b/magemaker/sagemaker/test_query_endpoint.py deleted file mode 100644 index ad1f1bf..0000000 --- a/magemaker/sagemaker/test_query_endpoint.py +++ /dev/null @@ -1,34 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock -import sys - -# Mock problematic imports -sys.modules['inquirer'] = MagicMock() -sys.modules['InquirerPy'] = MagicMock() - -from magemaker.schemas.query import Query -from magemaker.sagemaker.query_endpoint import make_query_request -from magemaker.schemas.deployment import Deployment -from magemaker.schemas.model import Model - -def test_make_query_request(): - with patch('magemaker.sagemaker.query_endpoint.is_sagemaker_model') as mock_is_sagemaker, \ - patch('magemaker.sagemaker.query_endpoint.query_sagemaker_endpoint') as mock_sagemaker_query, \ - patch('magemaker.sagemaker.query_endpoint.query_hugging_face_endpoint') as mock_hf_query: - - # Test Sagemaker model - mock_is_sagemaker.return_value = True - mock_sagemaker_query.return_value = "Sagemaker result" - - query = Query(query="Test query") - config = (MagicMock(), MagicMock()) - - result = make_query_request("test-endpoint", query, config) - assert result == "Sagemaker result" - - # Test HuggingFace model - mock_is_sagemaker.return_value = False - mock_hf_query.return_value = "HuggingFace result" - - result = make_query_request("test-endpoint", query, config) - assert result == "HuggingFace result" \ No newline at end of file diff --git a/scripts/preflight.sh b/scripts/preflight.sh deleted file mode 100755 index 971c2f0..0000000 --- a/scripts/preflight.sh +++ /dev/null @@ -1,384 +0,0 @@ -#!/bin/sh - -# Colors for output -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -NC='\033[0m' # No Color - -# Get the directory where the script is located -SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" - -# Logging functions -log_info() { - echo "[INFO] $1" -} - -log_debug() { - echo "[DEBUG] $1" -} - -log_error() { - echo "[ERROR] $1" >&2 -} - -# Configuration functions -configure_aws() { -echo "Configuring AWS..." -echo "you need to create an aws user with access to Sagemaker" -echo "if you don't know how to do that follow this doc https://docs.google.com/document/d/1NvA6uZmppsYzaOdkcgNTRl7Nb4LbpP9Koc4H_t5xNSg/edit?usp=sharing" - - -# green -if ! command -v aws &> /dev/null -then - OS="$(uname -s)" - case "${OS}" in - Linux*) - curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" - unzip awscliv2.zip - sudo ./aws/install - ;; - Darwin*) - curl "https://awscli.amazonaws.com/AWSCLIV2.pkg" -o "AWSCLIV2.pkg" - sudo installer -pkg AWSCLIV2.pkg -target / - ;; - *) - echo "Unsupported OS: ${OS}. See https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html" - exit 1 - ;; - esac -fi - -# echo green that press enter if you have already done this -echo -e "${GREEN}Press enter in the following configuration steps if you have already done this${NC}" - -aws configure set region us-east-1 && aws configure -touch .env - - -if ! grep -q "SAGEMAKER_ROLE" .env -then - # bash ./setup_role.sh - bash "$SCRIPT_DIR/setup_role.sh" -fi -} - - -# GCP - -configure_gcp() { - echo "Configuring GCP..." -echo "you need to create a GCP service account with access to GCS and vertex ai" -echo "if you don't know how to do that follow this doc https://docs.google.com/document/d/1NvA6uZmppsYzaOdkcgNTRl7Nb4LbpP9Koc4H_t5xNSg/edit?usp=sharing" - -if ! command -v gcloud &> /dev/null -then - echo "you need to install gcloud sdk for the terminal" - echo "https://cloud.google.com/sdk/docs/install" -fi - -# only run this if the credentials are not set - - -echo "Checking for gcloud installation..." - -# Check if gcloud is installed -if ! command -v gcloud &> /dev/null; then - echo -e "${RED}Error: gcloud CLI is not installed${NC}" - echo "Please install the Google Cloud SDK first" - exit 1 -fi - -echo "Checking for active gcloud accounts..." - -# Get list of active accounts -ACCOUNTS=$(gcloud auth list --filter=status:ACTIVE --format="value(account)" 2>/dev/null) - -# Check if command was successful -if [ $? -ne 0 ]; then - echo -e "${RED}Error: Failed to retrieve account information${NC}" - echo "Please check your gcloud installation" - exit 1 -fi - -# Check if any accounts are found -if [ -z "$ACCOUNTS" ]; then - echo -e "${YELLOW}No active gcloud accounts found${NC}" - # echo "To login, use: gcloud auth login" - gcloud auth login - exit 0 -fi - -# echo "Setting up application default credentials..." -# gcloud auth application-default login --no-launch-browser - -# if [ $? -ne 0 ]; then -# echo -e "${RED}Failed to set application default credentials${NC}" -# exit 1 -# fi - -# Get current project ID -if ! grep -q "PROJECT_ID" .env -then - PROJECT_ID=$(gcloud config get-value project 2>/dev/null) - if [ -n "$PROJECT_ID" ]; then - export PROJECT_ID="$PROJECT_ID" - echo "PROJECT_ID=$PROJECT_ID" >> .env - echo -e "${GREEN}Exported PROJECT_ID=${NC}${PROJECT_ID}" - else - echo -e "${YELLOW}No project currently set${NC}" - fi -fi - -if ! grep -q "GCLOUD_REGION" .env -then - CURRENT_REGION=$(gcloud config get-value compute/region 2>/dev/null) - if [ -n "$CURRENT_REGION" ]; then - echo "GCLOUD_REGION=$CURRENT_REGION" >> .env - export GCLOUD_REGION="$CURRENT_REGION" - echo -e "${GREEN}Exported GCLOUD_REGION=${NC}${CURRENT_REGION}" - else - echo -e "${YELLOW}No compute region currently set${NC}" - fi -fi -} - -# AZURE -configure_azure() { -echo "Configuring Azure..." -echo "Checking for Azure CLI installation..." -if ! command -v az &> /dev/null -then - echo "Azure CLI not found. Installing..." - OS="$(uname -s)" - case "${OS}" in - Linux*) - curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash - ;; - Darwin*) - brew update && brew install azure-cli - ;; - *) - echo "Unsupported OS: ${OS}. See https://docs.microsoft.com/en-us/cli/azure/install-azure-cli" - exit 1 - ;; - esac -fi - -# Check Azure login status -echo "Checking Azure login status..." -if ! az account show &> /dev/null; then - echo "Not logged into Azure. Please log in..." - az login - if [ $? -ne 0 ]; then - echo "Azure login failed. Please try again." - exit 1 - fi -fi - -# Get and set subscription -if ! grep -q "AZURE_SUBSCRIPTION_ID" .env; then - SUBSCRIPTION_ID=$(az account show --query id -o tsv) - if [ -n "$SUBSCRIPTION_ID" ]; then - echo "AZURE_SUBSCRIPTION_ID=$SUBSCRIPTION_ID" >> .env - export AZURE_SUBSCRIPTION_ID="$SUBSCRIPTION_ID" - echo "Exported AZURE_SUBSCRIPTION_ID=${SUBSCRIPTION_ID}" - else - echo "No Azure subscription found" - exit 1 - fi -fi - -# Get and set resource group -if ! grep -q "AZURE_RESOURCE_GROUP" .env; then - echo "Listing resource groups..." - az group list -o table - echo "Please enter the resource group name to use:" - read RESOURCE_GROUP - if [ -n "$RESOURCE_GROUP" ]; then - echo "AZURE_RESOURCE_GROUP=$RESOURCE_GROUP" >> .env - export AZURE_RESOURCE_GROUP="$RESOURCE_GROUP" - echo "Exported AZURE_RESOURCE_GROUP=${RESOURCE_GROUP}" - else - echo "No resource group specified" - exit 1 - fi -fi - -# Get and set region -if ! grep -q "AZURE_REGION" .env; then - CURRENT_REGION=$(az group show --name $AZURE_RESOURCE_GROUP --query location -o tsv) - if [ -n "$CURRENT_REGION" ]; then - echo "AZURE_REGION=$CURRENT_REGION" >> .env - export AZURE_REGION="$CURRENT_REGION" - echo "Exported AZURE_REGION=${CURRENT_REGION}" - else - echo "Available Azure regions:" - az account list-locations --query "[].{Region:name}" -o table - echo "Please enter the Azure region to use:" - read AZURE_REGION - echo "AZURE_REGION=$AZURE_REGION" >> .env - export AZURE_REGION="$AZURE_REGION" - echo "Exported AZURE_REGION=${AZURE_REGION}" - fi -fi - -# Check Azure ML workspace -echo "Checking Azure ML workspace..." -if ! grep -q "AZURE_WORKSPACE_NAME" .env; then - # List available workspaces - echo "Available Azure ML workspaces in resource group $AZURE_RESOURCE_GROUP:" - az ml workspace list --resource-group $AZURE_RESOURCE_GROUP -o table - - echo "Please enter the Azure ML workspace name to use:" - read WORKSPACE_NAME - - if [ -n "$WORKSPACE_NAME" ]; then - # Verify workspace exists - if az ml workspace show --name $WORKSPACE_NAME --resource-group $AZURE_RESOURCE_GROUP &> /dev/null; then - echo "AZURE_WORKSPACE_NAME=$WORKSPACE_NAME" >> .env - export AZURE_ML_WORKSPACE="$WORKSPACE_NAME" - echo "Exported AZURE_WORKSPACE_NAME=${WORKSPACE_NAME}" - else - echo "Workspace $WORKSPACE_NAME not found in resource group $AZURE_RESOURCE_GROUP" - exit 1 - fi - else - echo "No workspace specified" - exit 1 - fi -fi - -# Function to check and register Azure resource providers -check_and_register_providers() { - local providers=( - "Microsoft.MachineLearningServices" - "Microsoft.ContainerRegistry" - "Microsoft.KeyVault" - "Microsoft.Storage" - "Microsoft.Insights" - "Microsoft.ContainerService" - "Microsoft.PolicyInsights" - "Microsoft.Cdn" - ) - - echo "Checking Azure resource providers..." - for provider in "${providers[@]}"; do - echo "Checking registration status for: $provider" - - # Get the registration state - state=$(az provider show --namespace $provider --query registrationState -o tsv 2>/dev/null) - - if [ "$state" != "Registered" ]; then - echo "$provider is not registered. Registering now..." - az provider register --namespace $provider - - # Wait for registration to complete - echo "Waiting for $provider registration to complete..." - while true; do - state=$(az provider show --namespace $provider --query registrationState -o tsv) - if [ "$state" == "Registered" ]; then - echo "$provider registration completed" - break - fi - echo "Registration in progress... waiting 10 seconds" - sleep 10 - done - else - echo "$provider is already registered" - fi - done - - echo "All required resource providers are registered" -} - -# Add this line after the Azure login check -echo "Checking and registering required Azure resource providers..." -check_and_register_providers - - -# Verify all required Azure environment variables are set -echo "Verifying Azure environment variables..." -REQUIRED_VARS=("AZURE_SUBSCRIPTION_ID" "AZURE_RESOURCE_GROUP" "AZURE_REGION" "AZURE_WORKSPACE_NAME") -for var in "${REQUIRED_VARS[@]}"; do - if ! grep -q "$var" .env; then - echo "Missing required environment variable: $var" - exit 1 - fi -done - -echo "Azure environment setup completed successfully!" - -# touch .env -} -configure_all_providers() { - log_info "Performing comprehensive multi-cloud configuration..." - - - # Detailed configuration for each cloud - configure_aws - configure_gcp - configure_azure - log_info "Multi-cloud configuration completed successfully" -} - -# Argument parsing -CLOUD="" -while [ $# -gt 0 ]; do - case "$1" in - --cloud) - shift - CLOUD="$1" - break - ;; - --cloud=*) - CLOUD="${1#*=}" - break - ;; - esac - shift -done - -# log_debug "Raw arguments: $@" -# log_debug "Cloud argument received: '$CLOUD'" - -# Validate cloud argument -# validate_cloud_arg() { -# case "$1" in -# aws|gcp|azure) -# return 0 -# ;; -# *) -# log_error "Invalid cloud provider: '$1'" -# log_error "Supported providers: aws, gcp, azure" -# exit 1 -# ;; -# esac -# } - -# Main configuration logic -main_configuration() { - # Validate cloud argument - # validate_cloud_arg "$CLOUD" - - # Configure specific cloud provider - case "$CLOUD" in - aws) - configure_aws - ;; - gcp) - configure_gcp - ;; - azure) - configure_azure - ;; - all) - configure_all_providers - ;; - *) - esac -} - -# Execute main configuration -main_configuration \ No newline at end of file From a4266311300136baf9d4661cd04f13124de95146 Mon Sep 17 00:00:00 2001 From: HamzaAyoub033 Date: Mon, 20 Jan 2025 22:28:28 +0000 Subject: [PATCH 4/4] Remove tests/test_cli.py from git tracking --- tests/test_cli.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tests/test_cli.py diff --git a/tests/test_cli.py b/tests/test_cli.py deleted file mode 100644 index e69de29..0000000