diff --git a/.gitignore b/.gitignore index bd3b0822..28e4946e 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ pyenv/ .env generation-*.pb.json Pipfile* +.DS_Store diff --git a/nbs/Fine_Tuning_REST_API_External_Test.ipynb b/nbs/Fine_Tuning_REST_API_External_Test.ipynb new file mode 100644 index 00000000..a1e9f4fe --- /dev/null +++ b/nbs/Fine_Tuning_REST_API_External_Test.ipynb @@ -0,0 +1,992 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "GM-nO4oBegPm" + }, + "source": [ + "# Stability Fine-Tuning REST API Demo\n", + "\n", + "Thank you for trying the first external beta of the Stability Fine-Tuning service! Note that this is a **developer beta** - bugs and quality issues with the generated fine-tunes may occur. Please reach out to Stability if this is the case - and share what you've made as well!\n", + "\n", + "The code below hits the Stability REST API. This REST API contract is rather solid, so it's unlikely to see large changes before the production release of fine-tuning.\n", + "\n", + "Known issues:\n", + "\n", + "* Style fine-tunes may result in overfitting - if this is the case, uncomment the `# weight=1.0` field of `DiffusionFineTune` in the diffusion section(s) and provide a value between -1 and 1. You may need to go as low as 0.2 or 0.1.\n", + "* We will be exposing test parameters soon - please reach out with examples of datasets that produce overfitting or errors if you have them.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "T9ma2X7bhH8y", + "cellView": "form" + }, + "outputs": [], + "source": [ + "#@title Add your API key\n", + "import getpass\n", + "\n", + "#@markdown Execute this step and paste your API key in the box that appears.

Visit https://platform.stability.ai/account/keys to get your API key!
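(Optional) If you want to verify a key before kicking off a training run, a minimal sanity check is sketched below. It assumes the public REST host and the Bearer-token auth used elsewhere in these notebooks; note that the fine-tuning wrapper below passes the raw key to the preview host instead.

```python
# Minimal key sanity check (assumption: public REST host + Bearer auth;
# the fine-tuning preview host used later takes the raw key instead).
import requests

def check_api_key(api_key: str) -> bool:
    response = requests.get(
        "https://api.stability.ai/v1/user/account",
        headers={"Authorization": f"Bearer {api_key}"},
    )
    return response.status_code == 200
```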
Note: If you are not on the fine-tuning whitelist you will receive an error during training!\n", + "API_KEY = getpass.getpass('Paste your Stability API Key here and press Enter: ')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5LM5SUUOhH8z" + }, + "outputs": [], + "source": [ + "#@title Install Dependencies & Helper Class\n", + "\n", + "#@markdown ## Install Dependencies & the Helper class\n", + "#@markdown To simplify implementing this fine-tuning service in your application, a small helper class is available for you to copy to your environment.

While this helper class will get you started, it is missing components that would make it more robust (e.g. retries) and you are encouraged to heavily modify it before using it to power your own fine-tuning applications.\n", + "\n", + "import io\n", + "import logging\n", + "import requests\n", + "import os\n", + "import shutil\n", + "import sys\n", + "import time\n", + "import json\n", + "import base64\n", + "from enum import Enum\n", + "from dataclasses import dataclass, is_dataclass, field, asdict\n", + "from typing import List, Optional, Any\n", + "from IPython.display import clear_output\n", + "from pathlib import Path\n", + "from PIL import Image\n", + "from zipfile import ZipFile\n", + "from google.colab import files\n", + "\n", + "\n", + "ENGINE_ID = \"stable-diffusion-xl-1024-v1-0\"\n", + "\n", + "class Printable:\n", + " \"\"\" Helper class for printing a class to the console. \"\"\"\n", + "\n", + " @staticmethod\n", + " def to_json(obj: Any) -> Any:\n", + " if isinstance(obj, Enum):\n", + " return obj.value\n", + " if is_dataclass(obj):\n", + " return asdict(obj)\n", + "\n", + " return obj\n", + "\n", + " def __str__(self):\n", + " return f\"{self.__class__.__name__}: {json.dumps(self, default=self.to_json, indent=4)}\"\n", + "\n", + "\n", + "class ToDict:\n", + " \"\"\" Helper class to simplify converting dataclasses to dicts. \"\"\"\n", + "\n", + " def to_dict(self):\n", + " return {k: v for k, v in asdict(self).items() if v is not None}\n", + "\n", + "\n", + "@dataclass\n", + "class FineTune(Printable):\n", + " id: str\n", + " user_id: str\n", + " name: str\n", + " mode: str\n", + " engine_id: str\n", + " training_set_id: str\n", + " status: str\n", + " failure_reason: Optional[str] = field(default=None)\n", + " duration_seconds: Optional[int] = field(default=None)\n", + " object_prompt: Optional[str] = field(default=None)\n", + "\n", + "\n", + "@dataclass\n", + "class DiffusionFineTune(Printable, ToDict):\n", + " id: str\n", + " token: str\n", + " weight: Optional[float] = field(default=None)\n", + "\n", + "\n", + "@dataclass\n", + "class TextPrompt(Printable, ToDict):\n", + " text: str\n", + " weight: Optional[float] = field(default=None)\n", + "\n", + "\n", + "class Sampler(Enum):\n", + " DDIM = \"DDIM\"\n", + " DDPM = \"DDPM\"\n", + " K_DPMPP_2M = \"K_DPMPP_2M\"\n", + " K_DPMPP_2S_ANCESTRAL = \"K_DPMPP_2S_ANCESTRAL\"\n", + " K_DPM_2 = \"K_DPM_2\"\n", + " K_DPM_2_ANCESTRAL = \"K_DPM_2_ANCESTRAL\"\n", + " K_EULER = \"K_EULER\"\n", + " K_EULER_ANCESTRAL = \"K_EULER_ANCESTRAL\"\n", + " K_HEUN = \"K_HEUN\"\n", + " K_LMS = \"K_LMS\"\n", + "\n", + " @staticmethod\n", + " def from_string(val) -> Enum or None:\n", + " for sampler in Sampler:\n", + " if sampler.value == val:\n", + " return sampler\n", + " raise Exception(f\"Unknown Sampler: {val}\")\n", + "\n", + "\n", + "@dataclass\n", + "class TextToImageParams(Printable):\n", + " fine_tunes: List[DiffusionFineTune]\n", + " text_prompts: List[TextPrompt]\n", + " samples: int\n", + " sampler: Sampler\n", + " engine_id: str\n", + " steps: int\n", + " seed: Optional[int] = field(default=0)\n", + " cfg_scale: Optional[int] = field(default=7)\n", + " width: Optional[int] = field(default=1024)\n", + " height: Optional[int] = field(default=1024)\n", + "\n", + "\n", + "@dataclass\n", + "class DiffusionResult:\n", + " base64: str\n", + " seed: int\n", + " finish_reason: str\n", + "\n", + " def __str__(self):\n", + " return f\"DiffusionResult(base64='too long to print', seed='{self.seed}', finish_reason='{self.finish_reason}')\"\n", + "\n", 
+ " def __repr__(self):\n", + " return self.__str__()\n", + "\n", + "\n", + "@dataclass\n", + "class TrainingSetBase(Printable):\n", + " id: str\n", + " name: str\n", + "\n", + "\n", + "@dataclass\n", + "class TrainingSetImage(Printable):\n", + " id: str\n", + "\n", + "\n", + "@dataclass\n", + "class TrainingSet(TrainingSetBase):\n", + " images: List[TrainingSetImage]\n", + "\n", + "\n", + "class FineTuningRESTWrapper:\n", + " \"\"\"\n", + " Helper class to simplify interacting with the fine-tuning service via\n", + " Stability's REST API.\n", + "\n", + " While this class can be copied to your local environment, it is not likely\n", + " robust enough for your needs and does not support all of the features that\n", + " the REST API offers.\n", + " \"\"\"\n", + "\n", + " def __init__(self, api_key: str, api_host: str):\n", + " self.api_key = api_key\n", + " self.api_host = api_host\n", + "\n", + " def create_fine_tune(self,\n", + " name: str,\n", + " images: List[str],\n", + " engine_id: str,\n", + " mode: str,\n", + " object_prompt: Optional[str] = None) -> FineTune:\n", + " print(f\"Creating {mode} fine-tune called '{name}' using {len(images)} images...\")\n", + "\n", + " payload = {\"name\": name, \"engine_id\": engine_id, \"mode\": mode}\n", + " if object_prompt is not None:\n", + " payload[\"object_prompt\"] = object_prompt\n", + "\n", + " # Create a training set\n", + " training_set_id = self.create_training_set(name=name)\n", + " payload[\"training_set_id\"] = training_set_id\n", + " print(f\"\\tCreated training set {training_set_id}\")\n", + "\n", + " # Add images to the training set\n", + " for image in images:\n", + " print(f\"\\t\\tAdding {os.path.basename(image)}\")\n", + " self.add_image_to_training_set(\n", + " training_set_id=training_set_id,\n", + " image=image\n", + " )\n", + "\n", + " # Create the fine-tune\n", + " print(f\"\\tCreating a fine-tune from the training set\")\n", + " response = requests.post(\n", + " f\"{self.api_host}/v1/fine-tunes\",\n", + " json=payload,\n", + " headers={\n", + " \"Authorization\": self.api_key,\n", + " \"Content-Type\": \"application/json\"\n", + " }\n", + " )\n", + " raise_on_non200(response)\n", + " print(f\"\\tCreated fine-tune {response.json()['id']}\")\n", + "\n", + " print(f\"Success\")\n", + " return FineTune(**response.json())\n", + "\n", + " def get_fine_tune(self, fine_tune_id: str) -> FineTune:\n", + " response = requests.get(\n", + " f\"{self.api_host}/v1/fine-tunes/{fine_tune_id}\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return FineTune(**response.json())\n", + "\n", + " def list_fine_tunes(self) -> List[FineTune]:\n", + " response = requests.get(\n", + " f\"{self.api_host}/v1/fine-tunes\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return [FineTune(**ft) for ft in response.json()]\n", + "\n", + " def rename_fine_tune(self, fine_tune_id: str, name: str) -> FineTune:\n", + " response = requests.patch(\n", + " f\"{self.api_host}/v1/fine-tunes/{fine_tune_id}\",\n", + " json={\"operation\": \"RENAME\", \"name\": name},\n", + " headers={\n", + " \"Authorization\": self.api_key,\n", + " \"Content-Type\": \"application/json\"\n", + " }\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return FineTune(**response.json())\n", + "\n", + " def retrain_fine_tune(self, fine_tune_id: str) -> FineTune:\n", + " response = requests.patch(\n", + " 
f\"{self.api_host}/v1/fine-tunes/{fine_tune_id}\",\n", + " json={\"operation\": \"RETRAIN\"},\n", + " headers={\n", + " \"Authorization\": self.api_key,\n", + " \"Content-Type\": \"application/json\"\n", + " }\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return FineTune(**response.json())\n", + "\n", + " def delete_fine_tune(self, fine_tune: FineTune):\n", + " # Delete the underlying training set\n", + " self.delete_training_set(fine_tune.training_set_id)\n", + "\n", + " # Delete the fine-tune\n", + " response = requests.delete(\n", + " f\"{self.api_host}/v1/fine-tunes/{fine_tune.id}\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " def create_training_set(self, name: str) -> str:\n", + " response = requests.post(\n", + " f\"{self.api_host}/v1/training-sets\",\n", + " json={\"name\": name},\n", + " headers={\n", + " \"Authorization\": self.api_key,\n", + " \"Content-Type\": \"application/json\"\n", + " }\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return response.json().get('id')\n", + "\n", + " def get_training_set(self, training_set_id: str) -> TrainingSet:\n", + " response = requests.get(\n", + " f\"{self.api_host}/v1/training-sets/{training_set_id}\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return TrainingSet(**response.json())\n", + "\n", + " def list_training_sets(self) -> List[TrainingSetBase]:\n", + " response = requests.get(\n", + " f\"{self.api_host}/v1/training-sets\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return [TrainingSetBase(**tsb) for tsb in response.json()]\n", + "\n", + " def add_image_to_training_set(self, training_set_id: str, image: str) -> str:\n", + " with open(image, 'rb') as image_file:\n", + " response = requests.post(\n", + " f\"{self.api_host}/v1/training-sets/{training_set_id}/images\",\n", + " headers={\"Authorization\": self.api_key},\n", + " files={'image': image_file}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return response.json().get('id')\n", + "\n", + " def remove_image_from_training_set(self, training_set_id: str, image_id: str) -> None:\n", + " response = requests.delete(\n", + " f\"{self.api_host}/v1/training-sets/{training_set_id}/images/{image_id}\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " def delete_training_set(self, training_set_id: str) -> None:\n", + " response = requests.delete(\n", + " f\"{self.api_host}/v1/training-sets/{training_set_id}\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " def text_to_image(self, params: TextToImageParams) -> List[DiffusionResult]:\n", + " payload = {\n", + " \"fine_tunes\": [ft.to_dict() for ft in params.fine_tunes],\n", + " \"text_prompts\": [tp.to_dict() for tp in params.text_prompts],\n", + " \"samples\": params.samples,\n", + " \"sampler\": params.sampler.value,\n", + " \"steps\": params.steps,\n", + " \"seed\": params.seed,\n", + " \"width\": params.width,\n", + " \"height\": params.height,\n", + " \"cfg_scale\": params.cfg_scale,\n", + " }\n", + "\n", + " response = requests.post(\n", + " f\"{self.api_host}/v1/generation/{params.engine_id}/text-to-image\",\n", + " json=payload,\n", + " headers={\n", + " \"Authorization\": self.api_key,\n", + " \"Accept\": 
\"application/json\",\n", + " }\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return [\n", + " DiffusionResult(base64=item[\"base64\"], seed=item[\"seed\"], finish_reason=item[\"finishReason\"])\n", + " for item in response.json()[\"artifacts\"]\n", + " ]\n", + "\n", + "\n", + "def raise_on_non200(response):\n", + " if 200 <= response.status_code < 300:\n", + " return\n", + " raise Exception(f\"Status code {response.status_code}: {json.dumps(response.json(), indent=4)}\")\n", + "\n", + "\n", + "# Redirect logs to print statements so we can see them in the notebook\n", + "class PrintHandler(logging.Handler):\n", + " def emit(self, record):\n", + " print(self.format(record))\n", + "logging.getLogger().addHandler(PrintHandler())\n", + "logging.getLogger().setLevel(logging.INFO)\n", + "\n", + "# Initialize the fine-tune service\n", + "rest_api = FineTuningRESTWrapper(API_KEY, \"https://preview-api.stability.ai\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "t_gOf0i_gmCd" + }, + "source": [ + "## Add Training Images\n", + "\n", + "For training, we need a dataset of images in a `.zip` file.\n", + "\n", + "Please only upload images that you have the permission to use.\n", + "\n", + "\n", + "### Dataset image dimensions\n", + "\n", + "- Images **cannot** have any side less than 328px\n", + "- Images **cannot** be larger than 10MB\n", + "\n", + "There is no upper-bound for what we'll accept for an image's dimensions, but any side above 1024px will be scaled down to 1024px, while preserving aspect ratio. For example:\n", + "- `3024x4032` will be scaled down to `768x1024`\n", + "- `1118x1118` will be scaled down to `1024x1024`\n", + "\n", + "\n", + "### Dataset size\n", + "\n", + "- Datasets **cannot** have fewer than 3 images\n", + "- Datasets **cannot** have more than 64 images\n", + "\n", + "A larger dataset often tends to result in a more accurate fine-tune, but will also take longer to train.\n", + "\n", + "While each mode can accept up to 64 images, we have a few suggestions for a starter dataset based on the mode you are using:\n", + "* `FACE`: 6 or more images.\n", + "* `OBJECT`: 6 - 10 images.\n", + "* `STYLE`: 20 - 30 images.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9C1YOFxIhTJp" + }, + "outputs": [], + "source": [ + "#@title Upload ZIP file of images\n", + "training_dir = \"./train\"\n", + "Path(training_dir).mkdir(exist_ok=True)\n", + "try:\n", + " from google.colab import files\n", + "\n", + " upload_res = files.upload()\n", + " print(upload_res.keys())\n", + " extracted_dir = list(upload_res.keys())[0]\n", + " print(f\"Received {extracted_dir}\")\n", + " if not extracted_dir.endswith(\".zip\"):\n", + " raise ValueError(\"Uploaded file must be a zip file\")\n", + "\n", + " zf = ZipFile(io.BytesIO(upload_res[extracted_dir]), \"r\")\n", + " extracted_dir = Path(extracted_dir).stem\n", + " print(f\"Extracting to {extracted_dir}\")\n", + " zf.extractall(extracted_dir)\n", + "\n", + " for root, dirs, files in os.walk(extracted_dir):\n", + " for file in files:\n", + " source_path = os.path.join(root, file)\n", + " target_path = os.path.join(training_dir, file)\n", + "\n", + " if 'MACOSX' in source_path or 'DS' in source_path:\n", + " continue\n", + " print('Copying', source_path, '==>', target_path)\n", + " # Move the file to the target directory\n", + " shutil.move(source_path, target_path)\n", + "\n", + "\n", + "except ImportError:\n", + " pass\n", + "\n", + "print(f\"Using training images 
from: {training_dir}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8xJLYZ4fgoU9" + }, + "source": [ + "## Train a Fine-Tune\n", + "\n", + "Now we're ready to train our fine-tune. Use the parameters below to configure the name and the kind of fine-tune\n", + "\n", + "Please note that the training duration will vary based on:\n", + "- The number of images in your dataset\n", + "- The `training_mode` used\n", + "- The `engine_id` that is being fine-tuned on\n", + "\n", + "However, the following are some rough estimates for the training duration for each mode based on our recommended dataset sizes:\n", + "\n", + "* `FACE`: 4 - 5 minutes.\n", + "* `OBJECT`: 5 - 10 minutes.\n", + "* `STYLE`: 20 - 30 minutes.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VLyYQVM3hH8z" + }, + "outputs": [], + "source": [ + "#@title Create a fine-tune\n", + "fine_tune_name = \"my dog spot\" #@param {type:\"string\"}\n", + "#@markdown > Requirements: \n", + "training_mode = \"OBJECT\" #@param [\"FACE\", \"STYLE\", \"OBJECT\"] {type:\"string\"}\n", + "#@markdown > Determines the kind of fine-tune you're creating: \n", + "object_prompt = \"dog\" #@param {type:\"string\"}\n", + "#@markdown > (This field is ignored if `training_mode` is `FACE` or `STYLE`).
Used for segmenting out your subject when the `training_mode` is `OBJECT`. If you want to fine tune on a cat, use `cat` - for a bottle of liquor, use `bottle`. In general, it's best to use the most general word you can to describe your object.\n", + "\n", + "# Gather training images\n", + "images = []\n", + "for filename in os.listdir(training_dir):\n", + " if os.path.splitext(filename)[1].lower() in ['.png', '.jpg', '.jpeg', '.heic']:\n", + " images.append(os.path.join(training_dir, filename))\n", + "\n", + "# Create the fine-tune\n", + "fine_tune = rest_api.create_fine_tune(\n", + " name=fine_tune_name,\n", + " images=images,\n", + " mode=training_mode,\n", + " object_prompt=object_prompt if training_mode == \"OBJECT\" else None,\n", + " engine_id=ENGINE_ID,\n", + ")\n", + "\n", + "print()\n", + "print(fine_tune)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yEKyO3-bhH8z" + }, + "outputs": [], + "source": [ + "#@title Check on training status\n", + "start_time = time.time()\n", + "while fine_tune.status != \"COMPLETED\" and fine_tune.status != \"FAILED\":\n", + " fine_tune = rest_api.get_fine_tune(fine_tune.id)\n", + " elapsed = time.time() - start_time\n", + " clear_output(wait=True)\n", + " print(f\"Training '{fine_tune.name}' ({fine_tune.id}) status: {fine_tune.status} for {elapsed:.0f} seconds\")\n", + " time.sleep(10)\n", + "\n", + "clear_output(wait=True)\n", + "status_message = \"completed\" if fine_tune.status == \"COMPLETED\" else \"failed\"\n", + "print(f\"Training '{fine_tune.name}' ({fine_tune.id}) {status_message} after {elapsed:.0f} seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Qr4jBHX7hH8z" + }, + "outputs": [], + "source": [ + "#@title (Optional) Retrain if training failed\n", + "if fine_tune.status == \"FAILED\":\n", + " print(f\"Training failed, due to \\\"{fine_tune.failure_reason}\\\". Retraining...\")\n", + " fine_tune = rest_api.retrain_fine_tune(fine_tune.id)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Diffuse with your new Fine-Tune\n", + "\n", + "The example below uses the fine-tune you just finished training in the steps above. \n", + "If you want to diffuse using an _existing_ fine-tune please jump to the next section entitled [\"Diffuse with Existing Fine-Tunes\"](#scrollTo=L5LC8VWhsTVD)." + ], + "metadata": { + "id": "hHr0rq2Xo7Cz" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-Ugkjgy2hH8z" + }, + "outputs": [], + "source": [ + "#@title Generate Images\n", + "#@markdown ## Diffusion parameters\n", + "\n", + "fine_tune_alias=\"$my-dog\" #@param {type:\"string\"}\n", + "#@markdown > This is an alias for your fine-tune, allowing you to refer to the fine-tune directly in the prompt. This token is ephemeral and can be any valid text (though we recommend starting it with a `$` and using dashes instead of spaces e.g. `$my-dog`). This is _not_ the fine_tune_name you assigned in a prior step, this is just a short-hand we use to determine where to apply your fine-tune in your prompt.

For example, if your token was `$my-dog` you might use a prompt like: `a picture of $my-dog` or `$my-dog chasing a rabbit`. This syntax really shines when you have more than one fine-tune too! Given some fine-tune of film noir images you could use a prompt like `$my-dog in the style of $film-noir`.\n", + "prompt=\"a photo of $my-dog\" #@param {type:\"string\"}\n", + "#@markdown > The prompt to diffuse with. Must contain the `fine_tune_alias` at least once.\n", + "dimensions=\"1024x1024\" #@param ['1024x1024', '1152x896', '1216x832', '1344x768', '1536x640', '640x1536', '768x1344', '832x1216', '896x1152']\n", + "#@markdown > The dimensions of the image to generate, in pixels, and in the format width x height.\n", + "samples=2 #@param {type:\"slider\", min:1, max:10, step:1}\n", + "#@markdown > The number of images to generate. The higher the value, the longer the generation takes.\n", + "steps=32 #@param {type:\"slider\", min:30, max:60, step:1}\n", + "#@markdown > The number of iterations or stages a diffusion model goes through in the process of generating an image from a given text prompt. Lower steps will generate more quickly, but if steps are lowered too much, image quality will suffer. Images with higher steps take longer to generate, but often give more detailed results.\n", + "cfg_scale=7 #@param {type:\"slider\", min:0, max:35, step:1}\n", + "#@markdown > CFG (Classifier Free Guidance) scale determines how strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt).\n", + "seed=0 #@param {type:\"number\"}\n", + "#@markdown > The noise seed to use during diffusion. Using `0` means a random seed will be generated for each image. If you provide a non-zero value, images will be far less random.\n", + "download_results = False # @param {type:\"boolean\"}\n", + "#@markdown > Results are displayed inline below this section. 
By checking this box, the generated images will also be downloaded to your local machine.\n", + "\n", + "params = TextToImageParams(\n", + " fine_tunes=[\n", + " DiffusionFineTune(\n", + " id=fine_tune.id,\n", + " token=fine_tune_alias,\n", + " # Uncomment the following to provide a weight for the fine-tune\n", + " # weight=1.0\n", + " ),\n", + " ],\n", + " text_prompts=[\n", + " TextPrompt(\n", + " text=prompt,\n", + " # Uncomment the following to provide a weight for the prompt\n", + " # weight=1.0\n", + " ),\n", + " ],\n", + " engine_id=ENGINE_ID,\n", + " samples=samples,\n", + " steps=steps,\n", + " seed=seed,\n", + " cfg_scale=cfg_scale,\n", + " width=int(dimensions.split(\"x\")[0]),\n", + " height=int(dimensions.split(\"x\")[1]),\n", + " sampler=Sampler.K_DPMPP_2S_ANCESTRAL\n", + ")\n", + "\n", + "start_time = time.time()\n", + "\n", + "images = rest_api.text_to_image(params)\n", + "\n", + "elapsed = time.time() - start_time\n", + "\n", + "print(f\"Diffusion completed in {elapsed:.0f} seconds!\")\n", + "print(f\"The {len(images)} result{'s' if len(images) > 1 else ''} will be displayed below momentarily (depending on the speed of Colab).\\n\")\n", + "\n", + "for image in images:\n", + " display(Image.open(io.BytesIO(base64.b64decode(image.base64))))\n", + "\n", + "if download_results:\n", + " print(f\"Downloading {len(images)} images to disk.\")\n", + " from google.colab import files\n", + " Path('./out').mkdir(parents=True, exist_ok=True)\n", + " for index, image in enumerate(images):\n", + " with open(f'./out/txt2img_{image.seed}_{index}.png', \"wb\") as f:\n", + " f.write(base64.b64decode(image.base64))\n", + " files.download(f.name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3BZLVniihH8z" + }, + "outputs": [], + "source": [ + "#@title (Optional) Rename your new fine-tune\n", + "#@markdown Running this section will rename the fine-tune you just trained. The same naming rules from [this step](#scrollTo=VLyYQVM3hH8z) still apply.\n", + "\n", + "name = \"\" #@param {type:\"string\"}\n", + "rest_api.rename_fine_tune(fine_tune.id, name=name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eUFTMZOvhH80" + }, + "outputs": [], + "source": [ + "#@title (Optional) Delete your new fine-tune\n", + "#@markdown Running this section will **delete** the fine-tune you just trained. To prevent accidental deletions you need to check the box below before running this section in order for the delete to occur.\n", + "\n", + "confirm_deletion=False #@param {type:\"boolean\"}\n", + "if confirm_deletion:\n", + " rest_api.delete_fine_tune(fine_tune)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Diffuse with Existing Fine-Tunes\n", + "\n", + "Here you can diffuse with an existing fine-tune or multiple fine-tunes at once! Using multiple fine-tunes together is where this service really shines, so train a few and give the examples below a try!"
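To make the multi-fine-tune case concrete before the examples, the sketch below shows roughly the JSON body that the helper's `text_to_image` assembles when two fine-tunes appear in one prompt. The IDs and tokens are placeholders, and the weights are illustrative, not recommendations.

```python
# Placeholder IDs/tokens for illustration; the payload shape mirrors
# FineTuningRESTWrapper.text_to_image above.
payload = {
    "fine_tunes": [
        {"id": "ft-aaa111", "token": "$my-dog", "weight": 0.8},
        {"id": "ft-bbb222", "token": "$film-noir", "weight": 0.5},
    ],
    "text_prompts": [{"text": "$my-dog in the style of $film-noir"}],
    "samples": 1,
    "sampler": "K_DPMPP_2S_ANCESTRAL",
    "steps": 32,
    "seed": 0,
    "cfg_scale": 7,
    "width": 1024,
    "height": 1024,
}
# POSTed to f"{api_host}/v1/generation/{ENGINE_ID}/text-to-image", exactly as
# the helper's text_to_image method does.
```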
+ ], + "metadata": { + "id": "L5LC8VWhsTVD" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iQL5dFsfhH8z" + }, + "outputs": [], + "source": [ + "#@title List all of your Fine-Tunes\n", + "fine_tunes = rest_api.list_fine_tunes()\n", + "print(f\"Found {len(fine_tunes)} models\")\n", + "for fine_tune in fine_tunes:\n", + " print(f\" Model {fine_tune.id} {fine_tune.status:<9} {fine_tune.name}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ADUHK3Y9yIhr" + }, + "outputs": [], + "source": [ + "#@title Generate using 1 Fine-Tune\n", + "#@markdown # Diffusion parameters\n", + "\n", + "fine_tune_id=\"\" #@param {type:\"string\"}\n", + "#@markdown > The ID of the fine-tune to diffuse with. Run [this step](#scrollTo=iQL5dFsfhH8z) to list all of your existing fine-tunes.\n", + "fine_tune_alias=\"\" #@param {type:\"string\"}\n", + "#@markdown > This token acts as an alias for your fine-tune, allowing you to refer to the fine-tune directly in the prompt. This token is ephemeral and can be any valid text (though we recommend starting it with a `$` and using dashes instead of spaces).\n", + "prompt=\"\" #@param {type:\"string\"}\n", + "#@markdown > The prompt to diffuse with. Must contain the `fine_tune_alias` at least once.\n", + "dimensions=\"1024x1024\" #@param ['1024x1024', '1152x896', '1216x832', '1344x768', '1536x640', '640x1536', '768x1344', '832x1216', '896x1152']\n", + "#@markdown > The dimensions of the image to generate, in pixels, and in the format width x height.\n", + "samples=1 #@param {type:\"slider\", min:1, max:10, step:1}\n", + "#@markdown > The number of images to generate. Requesting a large number of images may negatively impact response time.\n", + "steps=32 #@param {type:\"slider\", min:30, max:60, step:1}\n", + "#@markdown > The number of iterations or stages a diffusion model goes through in the process of generating an image from a given text prompt.\n", + "cfg_scale=7 #@param {type:\"slider\", min:0, max:35, step:1}\n", + "#@markdown > CFG (Classifier Free Guidance) scale determines how strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt).\n", + "seed=0 #@param {type:\"number\"}\n", + "#@markdown > The noise seed to use during diffusion. Using `0` means a random seed will be generated for each image. If you provide a non-zero value, images will be far less random.\n", + "download_results = False # @param {type:\"boolean\"}\n", + "#@markdown > Results are displayed inline below this section. 
By checking this box, the generated images will also be downloaded to your local machine.\n", + "\n", + "params = TextToImageParams(\n", + " fine_tunes=[\n", + " DiffusionFineTune(\n", + " id=fine_tune_id,\n", + " token=fine_tune_alias,\n", + " # Uncomment the following to provide a weight for the fine-tune\n", + " # weight=1.0\n", + " ),\n", + " ],\n", + " text_prompts=[\n", + " TextPrompt(\n", + " text=prompt,\n", + " # Uncomment the following to provide a weight for the prompt\n", + " # weight=1.0\n", + " ),\n", + " ],\n", + " engine_id=ENGINE_ID,\n", + " samples=samples,\n", + " steps=steps,\n", + " seed=seed,\n", + " cfg_scale=cfg_scale,\n", + " width=int(dimensions.split(\"x\")[0]),\n", + " height=int(dimensions.split(\"x\")[1]),\n", + " sampler=Sampler.K_DPMPP_2S_ANCESTRAL\n", + ")\n", + "\n", + "start_time = time.time()\n", + "\n", + "images = rest_api.text_to_image(params)\n", + "\n", + "elapsed = time.time() - start_time\n", + "print(f\"Diffusion completed in {elapsed:.0f} seconds!\")\n", + "print(f\"The {len(images)} result{'s' if len(images) > 1 else ''} will be displayed below momentarily (depending on the speed of Colab).\\n\")\n", + "\n", + "for image in images:\n", + " display(Image.open(io.BytesIO(base64.b64decode(image.base64))))\n", + "\n", + "if download_results:\n", + " print(f\"Downloading {len(images)} images to disk.\")\n", + " from google.colab import files\n", + " Path('./out').mkdir(parents=True, exist_ok=True)\n", + " for index, image in enumerate(images):\n", + " with open(f'./out/txt2img_{image.seed}_{index}.png', \"wb\") as f:\n", + " f.write(base64.b64decode(image.base64))\n", + " files.download(f.name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XL_tKpuUKnXP" + }, + "outputs": [], + "source": [ + "#@title Generate using 2 Fine-Tunes\n", + "#@markdown # Diffusion parameters\n", + "#@markdown Documentation of some of the fields below has been omitted for brevity. See [this step](#scrollTo=ADUHK3Y9yIhr) for more detailed documentation around these diffusion parameters.\n", + "\n", + "first_fine_tune_id=\"\" #@param {type:\"string\"}\n", + "first_fine_tune_alias=\"\" #@param {type:\"string\"}\n", + "second_fine_tune_id=\"\" #@param {type:\"string\"}\n", + "second_fine_tune_alias=\"\" #@param {type:\"string\"}\n", + "prompt=\"\" #@param {type:\"string\"}\n", + "#@markdown > The prompt to diffuse with. Must contain both the `first_fine_tune_alias` and `second_fine_tune_alias` at least once.\n", + "dimensions=\"1024x1024\" #@param ['1024x1024', '1152x896', '1216x832', '1344x768', '1536x640', '640x1536', '768x1344', '832x1216', '896x1152']\n", + "#@markdown > The dimensions of the image to generate, in pixels, and in the format width x height.\n", + "samples=10 #@param {type:\"slider\", min:1, max:10, step:1}\n", + "#@markdown > The number of images to generate. Requesting a large number of images may negatively impact response time.\n", + "steps=32 #@param {type:\"slider\", min:30, max:60, step:1}\n", + "#@markdown > The number of iterations or stages a diffusion model goes through in the process of generating an image from a given text prompt.\n", + "cfg_scale=7 #@param {type:\"slider\", min:0, max:35, step:1}\n", + "#@markdown > CFG (Classifier Free Guidance) scale determines how strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt).\n", + "seed=0 #@param {type:\"number\"}\n", + "#@markdown > The noise seed to use during diffusion. 
Using `0` means a random seed will be generated for each image. If you provide a non-zero value, images will be far less random.\n", + "download_results = False # @param {type:\"boolean\"}\n", + "#@markdown > Results are displayed inline below this section. By checking this box, the generated images will also be downloaded to your local machine.\n", + "\n", + "params = TextToImageParams(\n", + " fine_tunes=[\n", + " DiffusionFineTune(\n", + " id=first_fine_tune_id,\n", + " token=first_fine_tune_alias,\n", + " # Uncomment the following to provide a weight for this fine-tune\n", + " # weight=1.0\n", + " ),\n", + " DiffusionFineTune(\n", + " id=second_fine_tune_id,\n", + " token=second_fine_tune_alias,\n", + " # Uncomment the following to provide a weight for this fine-tune\n", + " # weight=1.0\n", + " ),\n", + " ],\n", + " text_prompts=[TextPrompt(text=prompt)],\n", + " engine_id=ENGINE_ID,\n", + " samples=samples,\n", + " steps=steps,\n", + " seed=seed,\n", + " cfg_scale=cfg_scale,\n", + " width=int(dimensions.split(\"x\")[0]),\n", + " height=int(dimensions.split(\"x\")[1]),\n", + " sampler=Sampler.K_DPMPP_2S_ANCESTRAL\n", + ")\n", + "\n", + "start_time = time.time()\n", + "\n", + "images = rest_api.text_to_image(params)\n", + "\n", + "elapsed = time.time() - start_time\n", + "print(f\"Diffusion completed in {elapsed:.0f} seconds!\")\n", + "print(f\"The {len(images)} result{'s' if len(images) > 1 else ''} will be displayed below momentarily (depending on the speed of Colab).\\n\")\n", + "\n", + "for image in images:\n", + " display(Image.open(io.BytesIO(base64.b64decode(image.base64))))\n", + "\n", + "if download_results:\n", + " print(f\"Downloading {len(images)} images to disk.\")\n", + " from google.colab import files\n", + " Path('./out').mkdir(parents=True, exist_ok=True)\n", + " for index, image in enumerate(images):\n", + " with open(f'./out/txt2img_{image.seed}_{index}.png', \"wb\") as f:\n", + " f.write(base64.b64decode(image.base64))\n", + " files.download(f.name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UBGM2vbFKqbM" + }, + "outputs": [], + "source": [ + "#@title Generate using 3 Fine-Tunes\n", + "#@markdown # Diffusion parameters\n", + "#@markdown Documentation of some of the fields below has been omitted for brevity. See [this step](#scrollTo=ADUHK3Y9yIhr) for more detailed documentation around these diffusion parameters.
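Per the overfitting note in the known issues, combined fine-tunes often need weights well below `1.0`. Below is a sketch of assigning a distinct weight to each of the three fine-tunes used in the next cell; the values are illustrative only, not recommendations.

```python
# Illustrative weights only - tune each one independently, dropping as low as
# 0.1-0.2 for a fine-tune that overfits (see the known issues at the top).
fine_tunes = [
    DiffusionFineTune(id=first_fine_tune_id, token=first_fine_tune_alias, weight=0.7),
    DiffusionFineTune(id=second_fine_tune_id, token=second_fine_tune_alias, weight=0.4),
    DiffusionFineTune(id=third_fine_tune_id, token=third_fine_tune_alias, weight=0.2),
]
```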
When using this many fine-tunes together at once, you may need to tweak the weights for each fine-tune individually. The code for this section contains comments for how to assign an individual weight for each fine-tune.\n", + "\n", + "first_fine_tune_id=\"\" #@param {type:\"string\"}\n", + "first_fine_tune_alias=\"\" #@param {type:\"string\"}\n", + "second_fine_tune_id=\"\" #@param {type:\"string\"}\n", + "second_fine_tune_alias=\"\" #@param {type:\"string\"}\n", + "third_fine_tune_id=\"\" #@param {type:\"string\"}\n", + "third_fine_tune_alias=\"\" #@param {type:\"string\"}\n", + "prompt=\"\" #@param {type:\"string\"}\n", + "#@markdown > The prompt to diffuse with. Must contain the `first_fine_tune_alias`, `second_fine_tune_alias`, and `third_fine_tune_alias` at least once.\n", + "dimensions=\"1024x1024\" #@param ['1024x1024', '1152x896', '1216x832', '1344x768', '1536x640', '640x1536', '768x1344', '832x1216', '896x1152']\n", + "#@markdown > The dimensions of the image to generate, in pixels, and in the format width x height.\n", + "samples=1 #@param {type:\"slider\", min:1, max:10, step:1}\n", + "#@markdown > The number of images to generate. Requesting a large number of images may negatively impact response time.\n", + "steps=32 #@param {type:\"slider\", min:30, max:60, step:1}\n", + "#@markdown > The number of iterations or stages a diffusion model goes through in the process of generating an image from a given text prompt.\n", + "cfg_scale=7 #@param {type:\"slider\", min:0, max:35, step:1}\n", + "#@markdown > CFG (Classifier Free Guidance) scale determines how strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt).\n", + "seed=0 #@param {type:\"number\"}\n", + "#@markdown > The noise seed to use during diffusion. Using `0` means a random seed will be generated for each image. If you provide a non-zero value, images will be far less random.\n", + "download_results = False # @param {type:\"boolean\"}\n", + "#@markdown > Results are displayed inline below this section. 
By checking this box, the generated images will also be downloaded to your local machine.\n", + "\n", + "params = TextToImageParams(\n", + " fine_tunes=[\n", + " DiffusionFineTune(\n", + " id=first_fine_tune_id,\n", + " token=first_fine_tune_alias,\n", + " # Uncomment the following to provide a weight for this fine-tune\n", + " # weight=1.0\n", + " ),\n", + " DiffusionFineTune(\n", + " id=second_fine_tune_id,\n", + " token=second_fine_tune_alias,\n", + " # Uncomment the following to provide a weight for this fine-tune\n", + " # weight=1.0\n", + " ),\n", + " DiffusionFineTune(\n", + " id=third_fine_tune_id,\n", + " token=third_fine_tune_alias,\n", + " # Uncomment the following to provide a weight for this fine-tune\n", + " # weight=1.0\n", + " ),\n", + " ],\n", + " text_prompts=[TextPrompt(text=prompt)],\n", + " engine_id=ENGINE_ID,\n", + " samples=samples,\n", + " steps=steps,\n", + " seed=seed,\n", + " cfg_scale=cfg_scale,\n", + " width=int(dimensions.split(\"x\")[0]),\n", + " height=int(dimensions.split(\"x\")[1]),\n", + " sampler=Sampler.K_DPMPP_2S_ANCESTRAL\n", + ")\n", + "\n", + "start_time = time.time()\n", + "images = rest_api.text_to_image(params)\n", + "elapsed = time.time() - start_time\n", + "print(f\"Diffusion completed in {elapsed:.0f} seconds!\")\n", + "print(f\"The {len(images)} result{'s' if len(images) > 1 else ''} will be displayed below momentarily (depending on the speed of Colab).\\n\")\n", + "\n", + "for image in images:\n", + " display(Image.open(io.BytesIO(base64.b64decode(image.base64))))\n", + "\n", + "if download_results:\n", + " print(f\"Downloading {len(images)} images to disk.\")\n", + " from google.colab import files\n", + " Path('./out').mkdir(parents=True, exist_ok=True)\n", + " for index, image in enumerate(images):\n", + " with open(f'./out/txt2img_{image.seed}_{index}.png', \"wb\") as f:\n", + " f.write(base64.b64decode(image.base64))\n", + " files.download(f.name)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/nbs/Fine_tuning_SDK_external_test.ipynb b/nbs/Fine_tuning_SDK_external_test.ipynb new file mode 100644 index 00000000..554e3e81 --- /dev/null +++ b/nbs/Fine_tuning_SDK_external_test.ipynb @@ -0,0 +1,357 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "GM-nO4oBegPm" + }, + "source": [ + "# Stability Fine-Tuning SDK - Dev Test\n", + "\n", + "Thank you for trying the first external beta of the Stability Fine Tuning SDK! Please reach out to us if you have any questions or run into issues using the service. Note that this is a **developer beta** - bugs and quality issues with the generated fine-tunes may occur. Please reach out to Stability if this is the case - and share what you've made as well!\n", + "\n", + "Feel free to implement the gRPC SDK below in your own code, though be warned that the API below is subject to change before public release. A REST API will also be available in the near future.\n", + "\n", + "Known issues:\n", + "\n", + "* Style fine-tunes may result in overfitting - if this is the case, lower the model strength in the prompt - i.e. 
the `0.7` in `` within the prompt. You may need to go as low as 0.2 or 0.1.\n", + "* We will be exposing test parameters soon - please reach out with examples of datasets that produce overfitting or errors if you have them.\n", + "* Current input image limits are 3 minimum for all modes, 128 maximum for style fine-tuning, and 64 maximum for all other modes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "T9ma2X7bhH8y", + "outputId": "857cf355-b95c-43c6-c630-d2e680ee201b" + }, + "outputs": [], + "source": [ + "#@title Install Stability SDK with fine-tuning support\n", + "import getpass\n", + "import io\n", + "import logging\n", + "import os\n", + "import shutil\n", + "import sys\n", + "import time\n", + "from IPython.display import clear_output\n", + "from pathlib import Path\n", + "from zipfile import ZipFile\n", + "\n", + "if os.path.exists(\"../src/stability_sdk\"):\n", + " sys.path.append(\"../src\") # use local SDK src\n", + "else:\n", + " path = Path('stability-sdk')\n", + " if path.exists():\n", + " shutil.rmtree(path)\n", + " !pip uninstall -y stability-sdk\n", + " !git clone -b \"PLATFORM-339\" --recurse-submodules https://github.com/Stability-AI/stability-sdk\n", + " !pip install ./stability-sdk" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5LM5SUUOhH8z", + "outputId": "ace3ee80-7405-40f0-f18c-476101200900" + }, + "outputs": [], + "source": [ + "#@title Connect to the Stability API\n", + "from stability_sdk.api import Context, generation\n", + "from stability_sdk.finetune import (\n", + " create_model, delete_model, get_model, list_models, resubmit_model, update_model,\n", + " FineTuneMode, FineTuneParameters, FineTuneStatus\n", + ")\n", + "\n", + "# @markdown To get your API key visit https://dreamstudio.ai/account. Ensure you are added to the whitelist during external test!\n", + "STABILITY_HOST = \"grpc.stability.ai:443\"\n", + "STABILITY_KEY = \"\"\n", + "\n", + "engine_id = \"stable-diffusion-xl-1024-v1-0\"\n", + "\n", + "# Create API context to query user info and generate images\n", + "context = Context(STABILITY_HOST, STABILITY_KEY, generate_engine_id=engine_id)\n", + "(balance, pfp) = context.get_user_info()\n", + "print(f\"Logged in org:{context._user_organization_id} with balance:{balance}\")\n", + "\n", + "# Redirect logs to print statements so we can see them in the notebook\n", + "class PrintHandler(logging.Handler):\n", + " def emit(self, record):\n", + " print(self.format(record))\n", + "logging.getLogger().addHandler(PrintHandler())\n", + "logging.getLogger().setLevel(logging.INFO)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "iQL5dFsfhH8z", + "outputId": "bfc71c3d-e235-4e10-e352-14add4a100c9" + }, + "outputs": [], + "source": [ + "#@title List fine-tuned models for this user / organization.\n", + "models = list_models(context, org_id=context._user_organization_id)\n", + "print(f\"Found {len(models)} models\")\n", + "for model in models:\n", + " print(f\" Model {model.id} {model.name} {model.status}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "t_gOf0i_gmCd" + }, + "source": [ + "For training, we need a dataset of images. Please only upload images that you have the permission to use. 
This can be a folder of images or a .zip file containing your images. Images can be of any aspect ratio, as long as they obey a minimum size of 384px on the shortest side, and a maximum size of 1024px on the longest side. Datasets can range from a minimum of 4 images to a maximum of 128 images.\n", + "\n", + "A larger dataset often tends to result in a more accurate model, but will also take longer to train.\n", + "\n", + "While each mode can accept up to 128 images, we have a few suggestions for a starter dataset based on the mode you are using:\n", + "\n", + "\n", + "\n", + "* Face: 6 or more images.\n", + "* Object: 6 - 10 images.\n", + "* Style: 20 - 30 images.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 395 + }, + "id": "9C1YOFxIhTJp", + "outputId": "f65c6762-6bd3-4cf6-f3e4-a736c206e32f" + }, + "outputs": [], + "source": [ + "#@title Upload ZIP file of images.\n", + "training_dir = \"./train\"\n", + "Path(training_dir).mkdir(exist_ok=True)\n", + "try:\n", + " from google.colab import files\n", + "\n", + " upload_res = files.upload()\n", + " extracted_dir = list(upload_res.keys())[0]\n", + " print(f\"Received {extracted_dir}\")\n", + " if not extracted_dir.endswith(\".zip\"):\n", + " raise ValueError(\"Uploaded file must be a zip file\")\n", + "\n", + " zf = ZipFile(io.BytesIO(upload_res[extracted_dir]), \"r\")\n", + " extracted_dir = Path(extracted_dir).stem\n", + " print(f\"Extracting to {extracted_dir}\")\n", + " zf.extractall(extracted_dir)\n", + "\n", + " for root, dirs, files in os.walk(extracted_dir):\n", + " for file in files:\n", + "\n", + " source_path = os.path.join(root, file)\n", + " target_path = os.path.join(training_dir, file)\n", + "\n", + " if 'MACOSX' in source_path or 'DS' in source_path:\n", + " continue\n", + " print('Adding input image: ', source_path, target_path)\n", + " # Move the file to the target directory\n", + " shutil.move(source_path, target_path)\n", + "\n", + "\n", + "except ImportError:\n", + " pass\n", + "\n", + "print(f\"Using training images from: {training_dir}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8xJLYZ4fgoU9" + }, + "source": [ + "Now we're ready to train our model. Specify parameters like the name of your model, the training mode, and the guiding prompt for object mode training.\n", + "\n", + "Please note that the training duration will vary based on the size of your dataset, the training mode or the engine that is being fine-tuned on.\n", + "\n", + "However, the following are some rough estimates for the training duration for each mode based on our recommended dataset sizes:\n", + "\n", + "* Face: 4 - 5 minutes.\n", + "* Object: 5 - 10 minutes.\n", + "* Style: 20 - 30 minutes.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "VLyYQVM3hH8z", + "outputId": "ab994fbd-71cc-49fe-9829-d1b456dc15c4" + }, + "outputs": [], + "source": [ + "#@title Perform fine-tuning\n", + "model_name = \"elliot-dev\" #@param {type:\"string\"}\n", + "#@markdown > Model names are unique, and may only contain numbers, letters, and hyphens.\n", + "training_mode = \"face\" #@param [\"face\", \"style\", \"object\"] {type:\"string\"}\n", + "#@markdown > The Face training_mode expects pictures containing a face, and automatically crops and centers on the face detected in the input photos. 
Object segments out the object specified with the prompt below; and Style simply crops the images and filters for image quality.\n", + "object_prompt = \"cat\" #@param {type:\"string\"}\n", + "#@markdown > The Object Prompt is used for segmenting out your subject in the Object fine tuning mode - i.e. if you want to fine tune on a cat, put `cat` - for a bottle of liquor, use `bottle`. In general, it's best to use the most general word you can to describe your object.\n", + "\n", + "print(training_dir)\n", + "print(len(os.listdir(training_dir)))\n", + "# Gather training images\n", + "images = []\n", + "for filename in os.listdir(training_dir):\n", + " if os.path.splitext(filename)[1].lower() in ['.png', '.jpg', '.jpeg']:\n", + " images.append(os.path.join(training_dir, filename))\n", + "\n", + "# Create the fine-tune model\n", + "params = FineTuneParameters(\n", + " name=model_name,\n", + " mode=FineTuneMode(training_mode),\n", + " object_prompt=object_prompt,\n", + " engine_id=engine_id,\n", + ")\n", + "model = create_model(context, params, images)\n", + "print(f\"Model {model_name} created.\")\n", + "print(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yEKyO3-bhH8z", + "outputId": "50b70059-833a-45ca-ae49-66fa6564ed2c" + }, + "outputs": [], + "source": [ + "#@title Check on training status\n", + "start_time = time.time()\n", + "while model.status != FineTuneStatus.COMPLETED and model.status != FineTuneStatus.FAILED:\n", + " model = get_model(context, model.id)\n", + " elapsed = time.time() - start_time\n", + " clear_output(wait=True)\n", + " print(f\"Model {model.name} ({model.id}) status: {model.status} for {elapsed:.0f} seconds\")\n", + " time.sleep(5)\n", + "\n", + "clear_output(wait=True)\n", + "status_message = \"completed\" if model.status == FineTuneStatus.COMPLETED else \"failed\"\n", + "print(f\"Model {model.name} ({model.id}) {status_message} after {elapsed:.0f} seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Qr4jBHX7hH8z" + }, + "outputs": [], + "source": [ + "#@title If fine-tuning fails for some reason, you can resubmit the model\n", + "if model.status == FineTuneStatus.FAILED:\n", + " print(\"Training failed, resubmitting\")\n", + " model = resubmit_model(context, model.id)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "-Ugkjgy2hH8z", + "outputId": "93517776-9578-4d1c-ec84-68dba4b75085" + }, + "outputs": [], + "source": [ + "#@title 9. 
Generate images from your fine-tuned model\n", + "results = context.generate(\n", + " prompts=[f\"Illustration of <{model.id}:1> as a wizard\"],\n", + " weights=[1],\n", + " width=1024,\n", + " height=1024,\n", + " seed=42,\n", + " steps=40,\n", + " sampler=generation.SAMPLER_DDIM,\n", + " preset=\"photographic\",\n", + ")\n", + "image = results[generation.ARTIFACT_IMAGE][0]\n", + "display(image)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3BZLVniihH8z" + }, + "outputs": [], + "source": [ + "#@title Models can be updated to change settings before a resubmit or after training to rename\n", + "update_model(context, model.id, name=\"cat-ft-01-renamed\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eUFTMZOvhH80" + }, + "outputs": [], + "source": [ + "#@title Delete the model when it's no longer needed\n", + "delete_model(context, model.id)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/nbs/Quick_REST_API_DEMO.ipynb b/nbs/Quick_REST_API_DEMO.ipynb new file mode 100644 index 00000000..c1801a44 --- /dev/null +++ b/nbs/Quick_REST_API_DEMO.ipynb @@ -0,0 +1,1153 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Stability API Standard Feature Demo\n", + "\n", + "This notebook showcases a few key features available through our API.\n", + "You will need to obtain keys, and may need to be whitelisted for some features\n", + "\n", + "* Stability SDXL Keys are available here: https://platform.stability.ai/account/keys\n", + "\n", + "*For a complete reference of the Stability API, please visit https://platform.stability.ai/docs/api-reference*
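Before running the examples, a quick smoke test is to list the available engines; a minimal sketch against the v1 REST API is below (`YOUR-API-KEY` is a placeholder).

```python
# List available engines as a quick smoke test of the key and host.
import requests

api_key = "YOUR-API-KEY"  # placeholder - substitute your real key
response = requests.get(
    "https://api.stability.ai/v1/engines/list",
    headers={"Authorization": f"Bearer {api_key}"},
)
response.raise_for_status()
for engine in response.json():
    print(engine["id"])
```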
\n", + "Please note that a REST API and gRPC API are available." + ], + "metadata": { + "id": "ej6SVLpa7Jtz" + } + }, + { + "cell_type": "code", + "source": [ + "#@title Install Dependencies\n", + "import requests\n", + "import shutil\n", + "import getpass\n", + "import os\n", + "import base64\n", + "from google.colab import files\n", + "from PIL import Image" + ], + "metadata": { + "id": "I0zR1GllPIRe" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Load in Sample Images\n", + "#Feel free to replace these with your own images\n", + "url_mappings = {\"dog_with_armor\": \"https://i.imgur.com/4nnSP8q.png\",\n", + " \"dog_with_armor_inpaint\": \"https://i.imgur.com/eu44gJe.png\",\n", + " \"dog_with_armor_inpaint_just_armor\": \"https://i.imgur.com/Mw6QU6P.png\",\n", + " \"dog_outpaint_example\": \"https://i.imgur.com/yv9RxjQ.png\",\n", + " \"outpaint_mask_1024_1024\": \"https://i.imgur.com/L1lqrXm.png\"\n", + " }\n", + "for name in url_mappings:\n", + " response = requests.get(url_mappings[name], stream=True)\n", + " with open(f'/content/{name}.png', 'wb') as out_file:\n", + " response.raw.decode_content = True\n", + " shutil.copyfileobj(response.raw, out_file)\n", + " del response" + ], + "metadata": { + "id": "5EL8d-S7jSdM", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Stability API Key\n", + "# You will be prompted to enter your api keys after running this code\n", + "# You can view your API key here: https://next.platform.stability.ai/account/keys\n", + "api_key = getpass.getpass('Enter your API Key')" + ], + "metadata": { + "id": "_8RqX3BgQUSr" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Text To Image Example\n", + "url = \"https://api.stability.ai/v1/generation/stable-diffusion-xl-1024-v1-0/text-to-image\"\n", + "\n", + "body = {\n", + " \"steps\": 30,\n", + " \"width\": 1024,\n", + " \"height\": 1024,\n", + " \"seed\": 0,\n", + " \"cfg_scale\": 5,\n", + " \"samples\": 1,\n", + " \"text_prompts\": [\n", + " {\n", + " \"text\": \"A painting of a cat wearing armor, intricate filigree, cinematic masterpiece digital art\",\n", + " \"weight\": 1\n", + " },\n", + " {\n", + " \"text\": \"blurry, bad\",\n", + " \"weight\": -1\n", + " }\n", + " ],\n", + "}\n", + "\n", + "headers = {\n", + " \"Accept\": \"application/json\",\n", + " \"Content-Type\": \"application/json\",\n", + " \"Authorization\": f\"Bearer {api_key}\",\n", + "}\n", + "\n", + "response = requests.post(\n", + " url,\n", + " headers=headers,\n", + " json=body,\n", + ")\n", + "\n", + "if response.status_code != 200:\n", + " raise Exception(\"Non-200 response: \" + str(response.text))\n", + "\n", + "data = response.json()\n", + "\n", + "# make sure the out directory exists\n", + "if not os.path.exists(\"./out\"):\n", + " os.makedirs(\"./out\")\n", + "\n", + "for i, image in enumerate(data[\"artifacts\"]):\n", + " with open(f'./out/txt2img_{image[\"seed\"]}.png', \"wb\") as f:\n", + " f.write(base64.b64decode(image[\"base64\"]))\n", + " files.download(f.name)" + ], + "metadata": { + "id": "vXlxmX1_MnNw" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Inpainting Example\n", + "response = requests.post(\n", + " \"https://api.stability.ai/v1/generation/stable-diffusion-xl-1024-v1-0/image-to-image/masking\",\n", + " headers={\n", + " \"Accept\": \"application/json\",\n", + " 
\"Authorization\": f\"Bearer {api_key}\"\n", + " },\n", + " files={\n", + " #replace init image and mask image with your image and mask\n", + " \"init_image\": open(\"/content/dog_with_armor.png\", \"rb\"),\n", + " \"mask_image\": open(\"/content/dog_with_armor_inpaint_just_armor.png\", \"rb\")\n", + " },\n", + " data={\n", + " \"mask_source\": \"MASK_IMAGE_BLACK\",\n", + "\t\t\"steps\": 40,\n", + "\t\t\"seed\": 0,\n", + "\t\t\"cfg_scale\": 5,\n", + "\t\t\"samples\": 1,\n", + "\t\t\"text_prompts[0][text]\": 'Dog Armor made of chocolate',\n", + "\t\t\"text_prompts[0][weight]\": 1,\n", + "\t\t\"text_prompts[1][text]\": 'blurry, bad',\n", + "\t\t\"text_prompts[1][weight]\": -1,\n", + " }\n", + ")\n", + "\n", + "if response.status_code != 200:\n", + " raise Exception(\"Non-200 response: \" + str(response.text))\n", + "\n", + "data = response.json()\n", + "\n", + "# make sure the out directory exists\n", + "if not os.path.exists(\"./out\"):\n", + " os.makedirs(\"./out\")\n", + "\n", + "for i, image in enumerate(data[\"artifacts\"]):\n", + " with open(f'./out/img2img_{image[\"seed\"]}.png', \"wb\") as f:\n", + " f.write(base64.b64decode(image[\"base64\"]))\n", + " files.download(f.name)" + ], + "metadata": { + "id": "Sq14bHVOO__i", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Inpainting - Change Background Example\n", + "response = requests.post(\n", + " \"https://api.stability.ai/v1/generation/stable-diffusion-xl-1024-v1-0/image-to-image/masking\",\n", + " headers={\n", + " \"Accept\": \"application/json\",\n", + " \"Authorization\": f\"Bearer {api_key}\"\n", + " },\n", + " files={\n", + " \"init_image\": open(\"/content/dog_with_armor.png\", \"rb\"),\n", + " \"mask_image\": open(\"/content/dog_with_armor_inpaint.png\", \"rb\")\n", + " },\n", + " data={\n", + " # Flipping to white will make it inpaint but remove background, even though dog is masked black\n", + " \"mask_source\": \"MASK_IMAGE_WHITE\",\n", + "\t\t\"steps\": 40,\n", + "\t\t\"seed\": 0,\n", + "\t\t\"cfg_scale\": 5,\n", + "\t\t\"samples\": 1,\n", + "\t\t\"text_prompts[0][text]\": 'Medieval castle',\n", + "\t\t\"text_prompts[0][weight]\": 1,\n", + "\t\t\"text_prompts[1][text]\": 'blurry, bad',\n", + "\t\t\"text_prompts[1][weight]\": -1,\n", + " }\n", + ")\n", + "\n", + "if response.status_code != 200:\n", + " raise Exception(\"Non-200 response: \" + str(response.text))\n", + "\n", + "data = response.json()\n", + "\n", + "# make sure the out directory exists\n", + "if not os.path.exists(\"./out\"):\n", + " os.makedirs(\"./out\")\n", + "\n", + "for i, image in enumerate(data[\"artifacts\"]):\n", + " with open(f'./out/img2img_{image[\"seed\"]}.png', \"wb\") as f:\n", + " f.write(base64.b64decode(image[\"base64\"]))\n", + " files.download(f.name)" + ], + "metadata": { + "id": "ahInbhG4Y5oM", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Outpainting Example\n", + "\n", + "# Init image has to be the same size as mask image\n", + "# Paste the smaller init image onto the mask\n", + "initial_init_image = Image.open(\"/content/dog_outpaint_example.png\")\n", + "# The mask is already blurred, which will improve coherence\n", + "mask = Image.open(\"/content/outpaint_mask_1024_1024.png\")\n", + "mask.paste(initial_init_image)\n", + "mask.save('/content/dog_outpaint_init_image.png', quality=95)\n", + "\n", + "\n", + "response = requests.post(\n", + " 
\"https://api.stability.ai/v1/generation/stable-diffusion-xl-1024-v1-0/image-to-image/masking\",\n", + " headers={\n", + " \"Accept\": \"application/json\",\n", + " \"Authorization\": f\"Bearer {api_key}\"\n", + " },\n", + " files={\n", + " \"init_image\": open(\"/content/dog_outpaint_init_image.png\", \"rb\"),\n", + " \"mask_image\": open(\"/content/outpaint_mask_1024_1024.png\", \"rb\")\n", + " },\n", + " data={\n", + " \"mask_source\": \"MASK_IMAGE_BLACK\",\n", + "\t\t\"steps\": 40,\n", + "\t\t\"seed\": 0,\n", + "\t\t\"cfg_scale\": 5,\n", + "\t\t\"samples\": 1,\n", + "\t\t\"text_prompts[0][text]\": 'Medieval castle',\n", + "\t\t\"text_prompts[0][weight]\": 1,\n", + "\t\t\"text_prompts[1][text]\": 'blurry, bad',\n", + "\t\t\"text_prompts[1][weight]\": -1,\n", + " }\n", + ")\n", + "\n", + "if response.status_code != 200:\n", + " raise Exception(\"Non-200 response: \" + str(response.text))\n", + "\n", + "data = response.json()\n", + "\n", + "# make sure the out directory exists\n", + "if not os.path.exists(\"./out\"):\n", + " os.makedirs(\"./out\")\n", + "\n", + "for i, image in enumerate(data[\"artifacts\"]):\n", + " with open(f'./out/img2img_{image[\"seed\"]}.png', \"wb\") as f:\n", + " f.write(base64.b64decode(image[\"base64\"]))\n", + " files.download(f.name)" + ], + "metadata": { + "id": "URaWbM6xdXZw", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Image Upscaling Example\n", + "response = requests.post(\n", + " \"https://api.stability.ai/v1/generation/esrgan-v1-x2plus/image-to-image/upscale\",\n", + " headers={\n", + " \"Accept\": \"application/json\",\n", + " \"Authorization\": f\"Bearer {api_key}\"\n", + " },\n", + " files={\n", + " \"image\": open(\"/content/dog_with_armor.png\", \"rb\")\n", + " },\n", + " data={\n", + " \"width\": 2048\n", + " }\n", + ")\n", + "\n", + "if response.status_code != 200:\n", + " raise Exception(\"Non-200 response: \" + str(response.text))\n", + "\n", + "data = response.json()\n", + "\n", + "# make sure the out directory exists\n", + "if not os.path.exists(\"./out\"):\n", + " os.makedirs(\"./out\")\n", + "\n", + "for i, image in enumerate(data[\"artifacts\"]):\n", + " with open(f'./out/img2img_{image[\"seed\"]}.png', \"wb\") as f:\n", + " f.write(base64.b64decode(image[\"base64\"]))\n", + " files.download(f.name)" + ], + "metadata": { + "id": "vAWBk84DezoP", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Stability SDXL Enterprise API Demo\n", + "\n", + "Stability provides enterprise-grade features for customers that require faster speeds and dedicated managed services with support. These nodes can have significantly faster speeds based on the Stability Supercomputer, and may include prototype / preview models.
\n", + "The below section will leverage a demo node that is prepared at request. If you would like to try an enterprise node, please reach out to the Stability team." + ], + "metadata": { + "id": "nmHWiBNZAi4G" + } + }, + { + "cell_type": "code", + "source": [ + "# This demo notebook is designed to help illustrate the latency on prototype nodes with early access models\n", + "# This REST implementation will hit the special demo node. Please note that this is a demo only and results in production should exceed these speeds.\n", + "# Make sure to enable batch downloads from your browser if you want to see the images that will be downloaded locally\n", + "# OUTPUT: You will see Average and Total time for image generation. Note that in the first call there may be a warm-up time of up to 2 seconds, and tha colab will add an additional ~1.5 seconds\n", + "\n", + "import base64\n", + "import requests\n", + "import os\n", + "import time\n", + "from google.colab import files\n", + "\n", + "\n", + "def make_request(index):\n", + " #replace the with the name of the node and module provided to you\n", + " url = \"https://test.api.stability.ai/v1/generation//\"\n", + "#Steps: Increasing can improve quality, and increase latency\n", + " body = {\n", + " \"steps\": 22,\n", + " \"width\": 1024,\n", + " \"height\": 1024,\n", + " \"seed\": 0,\n", + " \"cfg_scale\": 6,\n", + " \"samples\": 1,\n", + " \"text_prompts\": [\n", + " {\n", + " \"text\": \"octane render of a barabaric software engineer\",\n", + " \"weight\": 1\n", + " },\n", + " {\n", + " \"text\": \"blurry, bad\",\n", + " \"weight\": -1\n", + " }\n", + " ],\n", + " }\n", + "\n", + " headers = {\n", + " \"Accept\": \"application/json\",\n", + " \"Content-Type\": \"application/json\",\n", + " #insert your Key\n", + " \"Authorization\": \"Bearer \",\n", + " }\n", + "\n", + " response = requests.post(\n", + " url,\n", + " headers=headers,\n", + " json=body,\n", + " )\n", + "\n", + " if response.status_code != 200:\n", + " raise Exception(\"Non-200 response: \" + str(response.text))\n", + "\n", + " data = response.json()\n", + "\n", + " if not os.path.exists(\"./out\"):\n", + " os.makedirs(\"./out\")\n", + "\n", + " for i, image in enumerate(data[\"artifacts\"]):\n", + " with open(f'./out/txt2img_{image[\"seed\"]}_{index}.png', \"wb\") as f:\n", + " f.write(base64.b64decode(image[\"base64\"]))\n", + "\n", + " #Please comment the below line to execute pure benchmarking without downloading the images\n", + " files.download(f.name)\n", + "\n", + "total_time = 0\n", + "\n", + "#Adjust num to change the number of images to get as batch\n", + "num = 2\n", + "for i in range(num):\n", + " print(i)\n", + " start = time.time()\n", + " make_request(i)\n", + " end = time.time()\n", + " total_time += (end - start)\n", + "\n", + "print(\"Average: \", total_time/num)\n", + "print(\"Total_Time: \", total_time)\n", + "print(\"Num Iterations: \", num)" + ], + "metadata": { + "id": "yDeYgKNXAiQH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# SDXL Finetuning REST API Demo\n", + "\n", + "Stability is offering a private beta of its fine-tuning service to select customers.
\n", + "\n", + "Note that this is a **developer beta** - bugs and quality issues with the generated fine-tunes may occur. Please reach out to Stability if this is the case - and share what you've made as well!\n", + "\n", + "The code below hits the Stability REST API. This REST API contract is rather solid, so it's unlikely to see large changes before the production release of fine-tuning.\n", + "\n", + "Known issues:\n", + "\n", + "* Style fine-tunes may result in overfitting - if this is the case, uncomment the `# weight=1.0` field of `DiffusionFineTune` in the diffusion section and provide a value between -1 and 1. You may need to go as low as 0.2 or 0.1.\n", + "* We will be exposing test parameters soon - please reach out with examples of datasets that produce overfitting or errors if you have them." + ], + "metadata": { + "id": "-xGp9o-iTc8e" + } + }, + { + "cell_type": "code", + "source": [ + "#@title Stability API key\n", + "import getpass\n", + "\n", + "#@markdown Execute this step and paste your API key in the box that appears.
Visit https://platform.stability.ai/account/keys to get your API key!

Note: If you are not on the fine-tuning whitelist you will receive an error during training.\n", + "\n", + "API_KEY = getpass.getpass('Paste your Stability API Key here and press Enter: ')\n", + "\n", + "API_HOST = \"https://preview-api.stability.ai\"\n", + "\n", + "ENGINE_ID = \"stable-diffusion-xl-1024-v1-0\"" + ], + "metadata": { + "id": "t7910RqrlFJc", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Initialize the REST API wrapper\n", + "import io\n", + "import logging\n", + "import requests\n", + "import os\n", + "import shutil\n", + "import sys\n", + "import time\n", + "import json\n", + "import base64\n", + "from enum import Enum\n", + "from dataclasses import dataclass, is_dataclass, field, asdict\n", + "from typing import List, Optional, Any\n", + "from IPython.display import clear_output\n", + "from pathlib import Path\n", + "from PIL import Image\n", + "from zipfile import ZipFile\n", + "\n", + "\n", + "class Printable:\n", + " \"\"\" Helper class for printing a class to the console. \"\"\"\n", + "\n", + " @staticmethod\n", + " def to_json(obj: Any) -> Any:\n", + " if isinstance(obj, Enum):\n", + " return obj.value\n", + " if is_dataclass(obj):\n", + " return asdict(obj)\n", + "\n", + " return obj\n", + "\n", + " def __str__(self):\n", + " return f\"{self.__class__.__name__}: {json.dumps(self, default=self.to_json, indent=4)}\"\n", + "\n", + "\n", + "class ToDict:\n", + " \"\"\" Helper class to simplify converting dataclasses to dicts. \"\"\"\n", + "\n", + " def to_dict(self):\n", + " return {k: v for k, v in asdict(self).items() if v is not None}\n", + "\n", + "\n", + "@dataclass\n", + "class FineTune(Printable):\n", + " id: str\n", + " user_id: str\n", + " name: str\n", + " mode: str\n", + " engine_id: str\n", + " training_set_id: str\n", + " status: str\n", + " failure_reason: Optional[str] = field(default=None)\n", + " duration_seconds: Optional[int] = field(default=None)\n", + " object_prompt: Optional[str] = field(default=None)\n", + "\n", + "\n", + "@dataclass\n", + "class DiffusionFineTune(Printable, ToDict):\n", + " id: str\n", + " token: str\n", + " weight: Optional[float] = field(default=None)\n", + "\n", + "\n", + "@dataclass\n", + "class TextPrompt(Printable, ToDict):\n", + " text: str\n", + " weight: Optional[float] = field(default=None)\n", + "\n", + "\n", + "class Sampler(Enum):\n", + " DDIM = \"DDIM\"\n", + " DDPM = \"DDPM\"\n", + " K_DPMPP_2M = \"K_DPMPP_2M\"\n", + " K_DPMPP_2S_ANCESTRAL = \"K_DPMPP_2S_ANCESTRAL\"\n", + " K_DPM_2 = \"K_DPM_2\"\n", + " K_DPM_2_ANCESTRAL = \"K_DPM_2_ANCESTRAL\"\n", + " K_EULER = \"K_EULER\"\n", + " K_EULER_ANCESTRAL = \"K_EULER_ANCESTRAL\"\n", + " K_HEUN = \"K_HEUN\"\n", + " K_LMS = \"K_LMS\"\n", + "\n", + " @staticmethod\n", + " def from_string(val) -> Enum or None:\n", + " for sampler in Sampler:\n", + " if sampler.value == val:\n", + " return sampler\n", + " raise Exception(f\"Unknown Sampler: {val}\")\n", + "\n", + "\n", + "@dataclass\n", + "class TextToImageParams(Printable):\n", + " fine_tunes: List[DiffusionFineTune]\n", + " text_prompts: List[TextPrompt]\n", + " samples: int\n", + " sampler: Sampler\n", + " engine_id: str\n", + " steps: int\n", + " seed: Optional[int] = field(default=0)\n", + " cfg_scale: Optional[int] = field(default=7)\n", + " width: Optional[int] = field(default=1024)\n", + " height: Optional[int] = field(default=1024)\n", + "\n", + "\n", + "@dataclass\n", + "class DiffusionResult:\n", + " base64: str\n", + " seed: int\n", + 
" finish_reason: str\n", + "\n", + " def __str__(self):\n", + " return f\"DiffusionResult(base64='too long to print', seed='{self.seed}', finish_reason='{self.finish_reason}')\"\n", + "\n", + " def __repr__(self):\n", + " return self.__str__()\n", + "\n", + "\n", + "@dataclass\n", + "class TrainingSetBase(Printable):\n", + " id: str\n", + " name: str\n", + "\n", + "\n", + "@dataclass\n", + "class TrainingSetImage(Printable):\n", + " id: str\n", + "\n", + "\n", + "@dataclass\n", + "class TrainingSet(TrainingSetBase):\n", + " images: List[TrainingSetImage]\n", + "\n", + "\n", + "class FineTuningRESTWrapper:\n", + " \"\"\"\n", + " Helper class to simplify interacting with the fine-tuning service via\n", + " Stability's REST API.\n", + "\n", + " While this class can be copied to your local environment, it is not likely\n", + " robust enough for your needs and does not support all of the features that\n", + " the REST API offers.\n", + " \"\"\"\n", + "\n", + " def __init__(self, api_key: str, api_host: str):\n", + " self.api_key = api_key\n", + " self.api_host = api_host\n", + "\n", + " def create_fine_tune(self,\n", + " name: str,\n", + " images: List[str],\n", + " engine_id: str,\n", + " mode: str,\n", + " object_prompt: Optional[str] = None) -> FineTune:\n", + " print(f\"Creating {mode} fine-tune called '{name}' using {len(images)} images...\")\n", + "\n", + " payload = {\"name\": name, \"engine_id\": engine_id, \"mode\": mode}\n", + " if object_prompt is not None:\n", + " payload[\"object_prompt\"] = object_prompt\n", + "\n", + " # Create a training set\n", + " training_set_id = self.create_training_set(name=name)\n", + " payload[\"training_set_id\"] = training_set_id\n", + " print(f\"\\tCreated training set {training_set_id}\")\n", + "\n", + " # Add images to the training set\n", + " for image in images:\n", + " print(f\"\\t\\tAdding {os.path.basename(image)}\")\n", + " self.add_image_to_training_set(\n", + " training_set_id=training_set_id,\n", + " image=image\n", + " )\n", + "\n", + " # Create the fine-tune\n", + " print(f\"\\tCreating a fine-tune from the training set\")\n", + " response = requests.post(\n", + " f\"{self.api_host}/v1/fine-tunes\",\n", + " json=payload,\n", + " headers={\n", + " \"Authorization\": self.api_key,\n", + " \"Content-Type\": \"application/json\"\n", + " }\n", + " )\n", + " raise_on_non200(response)\n", + " print(f\"\\tCreated fine-tune {response.json()['id']}\")\n", + "\n", + " print(f\"Success\")\n", + " return FineTune(**response.json())\n", + "\n", + " def get_fine_tune(self, fine_tune_id: str) -> FineTune:\n", + " response = requests.get(\n", + " f\"{self.api_host}/v1/fine-tunes/{fine_tune_id}\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return FineTune(**response.json())\n", + "\n", + " def list_fine_tunes(self) -> List[FineTune]:\n", + " response = requests.get(\n", + " f\"{self.api_host}/v1/fine-tunes\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return [FineTune(**ft) for ft in response.json()]\n", + "\n", + " def rename_fine_tune(self, fine_tune_id: str, name: str) -> FineTune:\n", + " response = requests.patch(\n", + " f\"{self.api_host}/v1/fine-tunes/{fine_tune_id}\",\n", + " json={\"operation\": \"RENAME\", \"name\": name},\n", + " headers={\n", + " \"Authorization\": self.api_key,\n", + " \"Content-Type\": \"application/json\"\n", + " }\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + 
"\n", + " return FineTune(**response.json())\n", + "\n", + " def retrain_fine_tune(self, fine_tune_id: str) -> FineTune:\n", + " response = requests.patch(\n", + " f\"{self.api_host}/v1/fine-tunes/{fine_tune_id}\",\n", + " json={\"operation\": \"RETRAIN\"},\n", + " headers={\n", + " \"Authorization\": self.api_key,\n", + " \"Content-Type\": \"application/json\"\n", + " }\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return FineTune(**response.json())\n", + "\n", + " def delete_fine_tune(self, fine_tune: FineTune):\n", + " # Delete the underlying training set\n", + " self.delete_training_set(fine_tune.training_set_id)\n", + "\n", + " # Delete the fine-tune\n", + " response = requests.delete(\n", + " f\"{self.api_host}/v1/fine-tunes/{fine_tune.id}\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " def create_training_set(self, name: str) -> str:\n", + " response = requests.post(\n", + " f\"{self.api_host}/v1/training-sets\",\n", + " json={\"name\": name},\n", + " headers={\n", + " \"Authorization\": self.api_key,\n", + " \"Content-Type\": \"application/json\"\n", + " }\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return response.json().get('id')\n", + "\n", + " def get_training_set(self, training_set_id: str) -> TrainingSet:\n", + " response = requests.get(\n", + " f\"{self.api_host}/v1/training-sets/{training_set_id}\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return TrainingSet(**response.json())\n", + "\n", + " def list_training_sets(self) -> List[TrainingSetBase]:\n", + " response = requests.get(\n", + " f\"{self.api_host}/v1/training-sets\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return [TrainingSetBase(**tsb) for tsb in response.json()]\n", + "\n", + " def add_image_to_training_set(self, training_set_id: str, image: str) -> str:\n", + " with open(image, 'rb') as image_file:\n", + " response = requests.post(\n", + " f\"{self.api_host}/v1/training-sets/{training_set_id}/images\",\n", + " headers={\"Authorization\": self.api_key},\n", + " files={'image': image_file}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return response.json().get('id')\n", + "\n", + " def remove_image_from_training_set(self, training_set_id: str, image_id: str) -> None:\n", + " response = requests.delete(\n", + " f\"{self.api_host}/v1/training-sets/{training_set_id}/images/{image_id}\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " def delete_training_set(self, training_set_id: str) -> None:\n", + " response = requests.delete(\n", + " f\"{self.api_host}/v1/training-sets/{training_set_id}\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " def text_to_image(self, params: TextToImageParams) -> List[DiffusionResult]:\n", + " payload = {\n", + " \"fine_tunes\": [ft.to_dict() for ft in params.fine_tunes],\n", + " \"text_prompts\": [tp.to_dict() for tp in params.text_prompts],\n", + " \"samples\": params.samples,\n", + " \"sampler\": params.sampler.value,\n", + " \"steps\": params.steps,\n", + " \"seed\": params.seed,\n", + " \"width\": params.width,\n", + " \"height\": params.height,\n", + " \"cfg_scale\": params.cfg_scale,\n", + " }\n", + "\n", + " response = requests.post(\n", + " 
f\"{self.api_host}/v1/generation/{params.engine_id}/text-to-image\",\n", + " json=payload,\n", + " headers={\n", + " \"Authorization\": self.api_key,\n", + " \"Accept\": \"application/json\",\n", + " }\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return [\n", + " DiffusionResult(base64=item[\"base64\"], seed=item[\"seed\"], finish_reason=item[\"finishReason\"])\n", + " for item in response.json()[\"artifacts\"]\n", + " ]\n", + "\n", + "\n", + "def raise_on_non200(response):\n", + " if 200 <= response.status_code < 300:\n", + " return\n", + " raise Exception(f\"Status code {response.status_code}: {json.dumps(response.json(), indent=4)}\")\n", + "\n", + "\n", + "# Redirect logs to print statements so we can see them in the notebook\n", + "class PrintHandler(logging.Handler):\n", + " def emit(self, record):\n", + " print(self.format(record))\n", + "logging.getLogger().addHandler(PrintHandler())\n", + "logging.getLogger().setLevel(logging.INFO)\n", + "\n", + "# Initialize the fine-tune service\n", + "rest_api = FineTuningRESTWrapper(API_KEY, API_HOST)" + ], + "metadata": { + "id": "Dr38OlbKTb7Q" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title List your existing fine-tunes\n", + "\n", + "fine_tunes = rest_api.list_fine_tunes()\n", + "print(f\"Found {len(fine_tunes)} models\")\n", + "for fine_tune in fine_tunes:\n", + " print(f\" Model {fine_tune.id} {fine_tune.status:<9} {fine_tune.name}\")" + ], + "metadata": { + "id": "ZqIc2d8FAIW0", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Add Training Images\n", + "\n", + "For training, we need a dataset of images in a `.zip` file.\n", + "\n", + "Please only upload images that you have the permission to use.\n", + "\n", + "\n", + "### Image Dimensions\n", + "\n", + "- Images **cannot** have any side less than 328px\n", + "- Images **cannot** be larger than 10MB\n", + "\n", + "There is no upper-bound for what we'll accept for an image's dimensions, but any side above 1024px will be scaled down to 1024px, while preserving aspect ratio. For example:\n", + "- `3024x4032` will be scaled down to `768x1024`\n", + "- `1118x1118` will be scaled down to `1024x1024`\n", + "\n", + "\n", + "### Image Quantity\n", + "\n", + "- Datasets **cannot** have fewer than 3 images\n", + "- Datasets **cannot** have more than 64 images\n", + "\n", + "A larger dataset often tends to result in a more accurate fine-tune, but will also take longer to train.\n", + "\n", + "While each mode can accept up to 64 images, we have a few suggestions for a starter dataset based on the mode you are using:\n", + "* `FACE`: 6 or more images.\n", + "* `OBJECT`: 6 - 10 images.\n", + "* `STYLE`: 20 - 30 images." 
+ ], + "metadata": { + "id": "vnAPh8ydc3SG" + } + }, + { + "cell_type": "code", + "source": [ + "#@title Upload ZIP file of images\n", + "training_dir = \"./train\"\n", + "Path(training_dir).mkdir(exist_ok=True)\n", + "try:\n", + " from google.colab import files\n", + "\n", + " upload_res = files.upload()\n", + " extracted_dir = list(upload_res.keys())[0]\n", + " print(f\"Received {extracted_dir}\")\n", + " if not extracted_dir.endswith(\".zip\"):\n", + " raise ValueError(\"Uploaded file must be a zip file\")\n", + "\n", + " zf = ZipFile(io.BytesIO(upload_res[extracted_dir]), \"r\")\n", + " extracted_dir = Path(extracted_dir).stem\n", + " print(f\"Extracting to {extracted_dir}\")\n", + " zf.extractall(extracted_dir)\n", + "\n", + " for root, dirs, files in os.walk(extracted_dir):\n", + " for file in files:\n", + " source_path = os.path.join(root, file)\n", + " target_path = os.path.join(training_dir, file)\n", + "\n", + " # Ignore Mac-specific files\n", + " if 'MACOSX' in source_path or 'DS' in source_path:\n", + " continue\n", + "\n", + " # Move the file to the target directory\n", + " print('Copying', source_path, '==>', target_path)\n", + " shutil.move(source_path, target_path)\n", + "\n", + "\n", + "except ImportError:\n", + " pass\n", + "\n", + "print(f\"Using training images from: {training_dir}\")" + ], + "metadata": { + "id": "YKQXWltHANju" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Train a Fine-Tune\n", + "\n", + "Now we're ready to train our fine-tune. Use the parameters below to configure the name and the kind of fine-tune\n", + "\n", + "Please note that the training duration will vary based on:\n", + "- The number of images in your dataset\n", + "- The `training_mode` used\n", + "- The `engine_id` that is being fine-tuned on\n", + "\n", + "The following are some rough estimates for the training duration for each mode based on our recommended dataset sizes:\n", + "\n", + "* `FACE`: 4 - 5 minutes.\n", + "* `OBJECT`: 5 - 10 minutes.\n", + "* `STYLE`: 20 - 30 minutes." + ], + "metadata": { + "id": "UXAn59XibFv5" + } + }, + { + "cell_type": "code", + "source": [ + "#@title Begin Training\n", + "fine_tune_name = \"my dog spot\" #@param {type:\"string\"}\n", + "#@markdown > Requirements:
  • Must be unique (only across your account, not globally)
  • Must be between 3 and 64 characters (inclusive)
  • Must only contain letters, numbers, spaces, or hyphens
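A quick local check of these naming rules might look like the following sketch. `is_valid_fine_tune_name` is our own helper, not part of the API; uniqueness still has to be compared against the names returned by `rest_api.list_fine_tunes()`.

```python
import re

def is_valid_fine_tune_name(name: str, existing_names=()) -> bool:
    # 3-64 characters; letters, numbers, spaces, or hyphens only; unique per account
    return (re.fullmatch(r"[A-Za-z0-9 \-]{3,64}", name) is not None
            and name not in existing_names)

# Usage sketch:
# existing = [ft.name for ft in rest_api.list_fine_tunes()]
# assert is_valid_fine_tune_name("my dog spot", existing)
```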
\n", + "training_mode = \"OBJECT\" #@param [\"FACE\", \"STYLE\", \"OBJECT\"] {type:\"string\"}\n", + "#@markdown > Determines the kind of fine-tune you're creating:
  • FACE - a fine-tune on faces; expects pictures containing a face; automatically crops and centers on the face detected in the input photos.
  • OBJECT - a fine-tune on a particular object (e.g. a bottle); segments out the object using the `object_prompt` below
  • STYLE - a fine-tune on a particular style (e.g. satellite photos of earth); crops the images and filters for image quality.
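As a loose sanity check, the mode chosen here can be compared against the starter dataset sizes suggested in the Image Quantity section above. The mapping below is only a heuristic drawn from those suggestions, not an API rule:

```python
# Suggested starter dataset sizes per mode (from the Image Quantity section above)
SUGGESTED_MIN_IMAGES = {"FACE": 6, "OBJECT": 6, "STYLE": 20}

def warn_if_dataset_small(mode: str, image_count: int) -> None:
    suggested = SUGGESTED_MIN_IMAGES.get(mode, 3)
    if image_count < suggested:
        print(f"Warning: {image_count} images is below the suggested "
              f"minimum of {suggested} for {mode} fine-tunes")
```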
\n", + "object_prompt = \"dog\" #@param {type:\"string\"}\n", + "#@markdown > Used for segmenting out your subject when the `training_mode` is `OBJECT`. (i.e. if you want to fine tune on a cat, put `cat` - for a bottle of liquor, use `bottle`. In general, it's best to use the most general word you can to describe your object.)\n", + "\n", + "# Gather training images\n", + "images = []\n", + "for filename in os.listdir(training_dir):\n", + " if os.path.splitext(filename)[1].lower() in ['.png', '.jpg', '.jpeg', '.heic']:\n", + " images.append(os.path.join(training_dir, filename))\n", + "\n", + "# Create the fine-tune\n", + "fine_tune = rest_api.create_fine_tune(\n", + " name=fine_tune_name,\n", + " images=images,\n", + " mode=training_mode,\n", + " object_prompt=object_prompt if training_mode == \"OBJECT\" else None,\n", + " engine_id=ENGINE_ID,\n", + ")\n", + "\n", + "print()\n", + "print(fine_tune)" + ], + "metadata": { + "id": "DMK3yOrGDLw8" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Wait For Training to Finish\n", + "start_time = time.time()\n", + "while fine_tune.status != \"COMPLETED\" and fine_tune.status != \"FAILED\":\n", + " fine_tune = rest_api.get_fine_tune(fine_tune.id)\n", + " elapsed = time.time() - start_time\n", + " clear_output(wait=True)\n", + " print(f\"Training '{fine_tune.name}' ({fine_tune.id}) status: {fine_tune.status} for {elapsed:.0f} seconds\")\n", + " time.sleep(10)\n", + "\n", + "clear_output(wait=True)\n", + "status_message = \"completed\" if fine_tune.status == \"COMPLETED\" else \"failed\"\n", + "print(f\"Training '{fine_tune.name}' ({fine_tune.id}) {status_message} after {elapsed:.0f} seconds\")" + ], + "metadata": { + "id": "8-iAUX_ODwU6" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title (Optional) Retrain if Training Failed\n", + "if fine_tune.status == \"FAILED\":\n", + " print(f\"Training failed, due to {fine_tune.failure_reason}. Retraining...\")\n", + " fine_tune = rest_api.retrain_fine_tune(fine_tune.id)" + ], + "metadata": { + "id": "eZaWJT_CDyrb" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Use your Fine-Tune\n", + "\n", + "Time to diffuse! The example below uses a single fine-tune, but using multiple fine-tunes is where this process really shines. While this Colab doesn't directly support diffusing with multiple fine-tunes, you can still try it out by commenting out the" + ], + "metadata": { + "id": "vaBl4zuQfO20" + } + }, + { + "cell_type": "code", + "source": [ + "#@title Generate Images\n", + "\n", + "prompt_token=\"$my-dog\" #@param {type:\"string\"}\n", + "#@markdown > This token is an alias for your fine-tune, allowing you to reference your fine-tune directly in your prompt. Each fine-tune you want to diffuse with must provide a unique alias.

For example, if your token was `$my-dog` you might use a prompt like: `a picture of $my-dog` or `$my-dog chasing a rabbit`.
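As a concrete sketch of the multi-fine-tune combination described in the note just below: the second id/token pair here is a placeholder, which you must replace with a completed fine-tune from your own account.

```python
params = TextToImageParams(
    fine_tunes=[
        DiffusionFineTune(id=fine_tune.id, token="$my-dog"),
        # placeholder: substitute the id of another COMPLETED fine-tune you own
        DiffusionFineTune(id="YOUR-SECOND-FINE-TUNE-ID", token="$film-noir", weight=0.8),
    ],
    text_prompts=[TextPrompt(text="$my-dog in the style of $film-noir")],
    engine_id=ENGINE_ID,
    samples=1,
    steps=32,
    sampler=Sampler.K_DPMPP_2S_ANCESTRAL,
)
images = rest_api.text_to_image(params)
```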

If you have more than one fine-tune you can combine them! Given a fine-tune trained on film noir images, you could use a prompt like `$my-dog in the style of $film-noir`.\n", + "prompt=\"a photo of $my-dog\" #@param {type:\"string\"}\n", + "#@markdown > The prompt to diffuse with. Must contain the `prompt_token` at least once.\n", + "dimensions=\"1024x1024\" #@param ['1024x1024', '1152x896', '1216x832', '1344x768', '1536x640', '640x1536', '768x1344', '832x1216', '896x1152']\n", + "#@markdown > The dimensions of the image to generate (width x height).\n", + "samples=4 #@param {type:\"slider\", min:1, max:10, step:1}\n", + "#@markdown > The number of images to generate. Requesting a large number of images may negatively impact response time.\n", + "steps=32 #@param {type:\"slider\", min:30, max:60, step:1}\n", + "#@markdown > The number of iterations or stages a diffusion model goes through in the process of generating an image from a given text prompt. Lower steps will generate more quickly, but if steps are lowered too much, image quality will suffer. Images with higher steps take longer to generate, but often give more detailed results.\n", + "cfg_scale=7 #@param {type:\"slider\", min:0, max:35, step:1}\n", + "#@markdown > CFG (Classifier Free Guidance) scale determines how strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt).\n", + "seed=0 #@param {type:\"number\"}\n", + "#@markdown > The noise seed to use during diffusion. Using `0` means a random seed will be generated for each image. If you provide a non-zero value, images will be far less random.\n", + "\n", + "params = TextToImageParams(\n", + "    fine_tunes=[\n", + "        DiffusionFineTune(\n", + "            id=fine_tune.id,\n", + "            token=prompt_token,\n", + "            # Uncomment the following to provide a weight for the fine-tune\n", + "            # weight=1.0\n", + "        ),\n", + "\n", + "        # Uncomment the following to use multiple fine-tunes at once\n", + "        # DiffusionFineTune(\n", + "        #     id=\"\",\n", + "        #     token=\"\",\n", + "        #     # weight=1.0\n", + "        # ),\n", + "    ],\n", + "    text_prompts=[\n", + "        TextPrompt(\n", + "            text=prompt,\n", + "            # weight=1.0\n", + "        ),\n", + "    ],\n", + "    engine_id=ENGINE_ID,\n", + "    samples=samples,\n", + "    steps=steps,\n", + "    seed=seed,\n", + "    cfg_scale=cfg_scale,\n", + "    width=int(dimensions.split(\"x\")[0]),\n", + "    height=int(dimensions.split(\"x\")[1]),\n", + "    sampler=Sampler.K_DPMPP_2S_ANCESTRAL\n", + ")\n", + "\n", + "start_time = time.time()\n", + "images = rest_api.text_to_image(params)\n", + "\n", + "elapsed = time.time() - start_time\n", + "print(f\"Diffusion completed in {elapsed:.0f} seconds!\")\n", + "print(f\"{len(images)} result{'s' if len(images) > 1 else ''} will be displayed below momentarily (depending on the speed of Colab).\\n\")\n", + "\n", + "for image in images:\n", + "    display(Image.open(io.BytesIO(base64.b64decode(image.base64))))" + ], + "metadata": { + "id": "sy1HcYqLEBXu" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title (Optional) Download Images\n", + "from google.colab import files\n", + "\n", + "if not os.path.exists(\"./out\"):\n", + "    os.makedirs(\"./out\")\n", + "\n", + "for index, image in enumerate(images):\n", + "    with open(f'./out/txt2img_{image.seed}_{index}.png', \"wb\") as f:\n", + "        f.write(base64.b64decode(image.base64))\n", + "        files.download(f.name)" + ], + "metadata": { + "id": "7P-bBnScfaQQ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title 
(Optional) Rename Fine-Tune\n", + "\n", + "name = \"\" #@param {type:\"string\"}\n", + "rest_api.rename_fine_tune(fine_tune.id, name=name)" + ], + "metadata": { + "id": "tg2gkvlDn4Dm" + }, + "execution_count": null, + "outputs": [] + } + ] +} diff --git a/nbs/finetune.ipynb b/nbs/finetune.ipynb new file mode 100644 index 00000000..287e1312 --- /dev/null +++ b/nbs/finetune.ipynb @@ -0,0 +1,295 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "cellView": "form", + "id": "T9ma2X7bhH8y" + }, + "outputs": [], + "source": [ + "#@title Install Stability SDK with fine-tuning support\n", + "import getpass\n", + "import io\n", + "import logging\n", + "import os\n", + "import shutil\n", + "import sys\n", + "import time\n", + "from IPython.display import clear_output\n", + "from pathlib import Path\n", + "from zipfile import ZipFile\n", + "\n", + "if os.path.exists(\"../src/stability_sdk\"):\n", + " sys.path.append(\"../src\") # use local SDK src\n", + "else:\n", + " path = Path('stability-sdk')\n", + " if path.exists():\n", + " shutil.rmtree(path)\n", + " !pip uninstall -y stability-sdk\n", + " !git clone -b \"PLATFORM-339\" --recurse-submodules https://github.com/Stability-AI/stability-sdk\n", + " !pip install ./stability-sdk" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "5LM5SUUOhH8z" + }, + "outputs": [], + "source": [ + "#@title Connect to the Stability API\n", + "from stability_sdk.api import Context, generation\n", + "from stability_sdk.finetune import (\n", + " create_model, delete_model, get_model, list_models, resubmit_model, update_model,\n", + " FineTuneMode, FineTuneParameters, FineTuneStatus\n", + ")\n", + "\n", + "# @markdown To get your API key visit https://dreamstudio.ai/account\n", + "STABILITY_HOST = \"grpc.stability.ai:443\" #@param [\"grpc.stability.ai:443\", \"grpc-staging.stability.ai:443\"] {type:\"string\"}\n", + "STABILITY_KEY = getpass.getpass('Enter your API Key')\n", + "\n", + "engine_id = \"stable-diffusion-xl-1024-v1-0\" #@param [\"stable-diffusion-xl-1024-v0-9\", \"stable-diffusion-xl-1024-v1-0\"] {type:\"string\"}\n", + "\n", + "# Create API context to query user info and generate images\n", + "context = Context(STABILITY_HOST, STABILITY_KEY, generate_engine_id=engine_id)\n", + "(balance, pfp) = context.get_user_info()\n", + "print(f\"Logged in org:{context._user_organization_id} with balance:{balance}\")\n", + "\n", + "# Redirect logs to print statements so we can see them in the notebook\n", + "class PrintHandler(logging.Handler):\n", + " def emit(self, record):\n", + " print(self.format(record))\n", + "logging.getLogger().addHandler(PrintHandler())\n", + "logging.getLogger().setLevel(logging.INFO)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iQL5dFsfhH8z" + }, + "outputs": [], + "source": [ + "# List fine-tuned models for this user / organization\n", + "models = list_models(context, org_id=context._user_organization_id)\n", + "print(f\"Found {len(models)} models\")\n", + "for model in models:\n", + " print(f\" Model {model.id} {model.name} {model.status}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9C1YOFxIhTJp" + }, + "outputs": [], + "source": [ + "#@title Upload ZIP file of images.\n", + "training_dir = \"./train\"\n", + "Path(training_dir).mkdir(exist_ok=True)\n", + "try:\n", + " from google.colab import files\n", + "\n", + " upload_res = files.upload()\n", + " 
extracted_dir = list(upload_res.keys())[0]\n", + " print(f\"Received {extracted_dir}\")\n", + " if not extracted_dir.endswith(\".zip\"):\n", + " raise ValueError(\"Uploaded file must be a zip file\")\n", + "\n", + " zf = ZipFile(io.BytesIO(upload_res[extracted_dir]), \"r\")\n", + " extracted_dir = Path(extracted_dir).stem\n", + " print(f\"Extracting to {extracted_dir}\")\n", + " zf.extractall(extracted_dir)\n", + "\n", + " for root, dirs, files in os.walk(extracted_dir):\n", + " for file in files:\n", + " \n", + " source_path = os.path.join(root, file)\n", + " target_path = os.path.join(training_dir, file)\n", + " print('Adding input image: ', source_path, target_path)\n", + " # Move the file to the target directory\n", + " shutil.move(source_path, target_path)\n", + "\n", + "\n", + "except ImportError:\n", + " pass\n", + "\n", + "print(f\"Using training images from: {training_dir}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VLyYQVM3hH8z" + }, + "outputs": [], + "source": [ + "#@title Perform fine-tuning\n", + "model_name = \"cat-ft-01\" #@param {type:\"string\"}\n", + "training_mode = \"object\" #@param [\"face\", \"style\", \"object\"] {type:\"string\"}\n", + "object_prompt = \"cat\" #@param {type:\"string\"}\n", + "\n", + "# Gather training images\n", + "images = []\n", + "for filename in os.listdir(training_dir):\n", + " if os.path.splitext(filename)[1].lower() in ['.png', '.jpg', '.jpeg']:\n", + " images.append(os.path.join(training_dir, filename))\n", + "\n", + "# Create the fine-tune model\n", + "params = FineTuneParameters(\n", + " name=model_name,\n", + " mode=FineTuneMode(training_mode),\n", + " object_prompt=object_prompt,\n", + " engine_id=engine_id,\n", + ")\n", + "model = create_model(context, params, images)\n", + "print(f\"Model {model_name} created.\")\n", + "print(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "yEKyO3-bhH8z" + }, + "outputs": [], + "source": [ + "# Check on training status\n", + "start_time = time.time()\n", + "while model.status != FineTuneStatus.COMPLETED and model.status != FineTuneStatus.FAILED:\n", + " model = get_model(context, model.id)\n", + " elapsed = time.time() - start_time\n", + " clear_output(wait=True)\n", + " print(f\"Model {model.name} ({model.id}) status: {model.status} for {elapsed:.0f} seconds\")\n", + " time.sleep(5)\n", + "\n", + "clear_output(wait=True)\n", + "status_message = \"completed\" if model.status == FineTuneStatus.COMPLETED else \"failed\"\n", + "print(f\"Model {model.name} ({model.id}) {status_message} after {elapsed:.0f} seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "Qr4jBHX7hH8z" + }, + "outputs": [], + "source": [ + "# If fine-tuning fails for some reason, you can resubmit the model\n", + "if model.status == FineTuneStatus.FAILED:\n", + " print(\"Training failed, resubmitting\")\n", + " model = resubmit_model(context, model.id)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-Ugkjgy2hH8z" + }, + "outputs": [], + "source": [ + "# Generate an image using the fine-tuned model\n", + "results = context.generate(\n", + " prompts=[f\"Illustration of <{model.id}:0.7> as a wizard\"],\n", + " weights=[1],\n", + " width=1024,\n", + " height=1024,\n", + " seed=42,\n", + " sampler=generation.SAMPLER_DDIM,\n", + " preset=\"photographic\",\n", + ")\n", + "image = results[generation.ARTIFACT_IMAGE][0]\n", + "display(image)" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "metadata": { + "id": "3BZLVniihH8z" + }, + "outputs": [], + "source": [ + "# Models can be updated to change settings before a resubmit or after training to rename\n", + "update_model(context, model.id, name=\"cat-ft-01-renamed\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eUFTMZOvhH80" + }, + "outputs": [], + "source": [ + "# Delete the model when it's no longer needed\n", + "delete_model(context, model.id)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Example using StabilityInference class\n", + "import warnings\n", + "from stability_sdk.client import StabilityInference\n", + "from PIL import Image\n", + "\n", + "si = StabilityInference(STABILITY_HOST, STABILITY_KEY, engine=engine_id)\n", + "results = si.generate(\n", + " f\"Illustration of <{model.id}:0.7> as a wizard\",\n", + " width=1024, \n", + " height=1024, \n", + " seed=42,\n", + " sampler=generation.SAMPLER_DDIM,\n", + " style_preset=\"photographic\"\n", + ")\n", + "for resp in results:\n", + " for artifact in resp.artifacts:\n", + " if artifact.finish_reason == generation.FILTER:\n", + " warnings.warn(\n", + " \"Your request activated the API's safety filters and could not be processed.\"\n", + " \"Please modify the prompt and try again.\")\n", + " if artifact.type == generation.ARTIFACT_IMAGE:\n", + " display(Image.open(io.BytesIO(artifact.binary)))" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/setup.py b/setup.py index 450d9655..272549d2 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,7 @@ 'grpcio-tools==1.53.0', 'python-dotenv', 'param', + 'pydantic==1.10.9', 'protobuf==4.21.12' ], extras_require={ diff --git a/src/stability_sdk/__init__.py b/src/stability_sdk/__init__.py index f9a14d1c..c93c2750 100644 --- a/src/stability_sdk/__init__.py +++ b/src/stability_sdk/__init__.py @@ -5,6 +5,7 @@ this_path = pathlib.Path(__file__).parent.resolve() sys.path.extend([ str(this_path / "interfaces/gooseai/dashboard"), + str(this_path / "interfaces/gooseai/finetuning"), str(this_path / "interfaces/gooseai/generation"), str(this_path / "interfaces/gooseai/project"), str(this_path / "interfaces/src/tensorizer/tensors") diff --git a/src/stability_sdk/animation.py b/src/stability_sdk/animation.py index 0495774e..71b4d693 100644 --- a/src/stability_sdk/animation.py +++ b/src/stability_sdk/animation.py @@ -9,6 +9,7 @@ import os import param import random +import re import shutil from collections import OrderedDict, deque @@ -27,9 +28,11 @@ image_mix, image_to_png_bytes, interpolate_mode_from_string, + parse_models_from_prompts, resample_transform, sampler_from_string, ) +import stability_sdk.finetune as ft import stability_sdk.matrix as matrix logger = logging.getLogger(__name__) @@ -322,14 +325,15 @@ def __init__( self.cadence_on: bool = False self.prior_frames: Deque[Image.Image] = deque([], 1) # forward warped prior frames. 
stores one image with cadence off, two images otherwise self.prior_diffused: Deque[Image.Image] = deque([], 1) # results of diffusion. stores one image with cadence off, two images otherwise - self.prior_xforms: Deque[matrix.Matrix] = deque([], 1) # accumulated transforms since last diffusion. stores one with cadence off, two otherwise + self.prior_xforms: Deque[matrix.Matrix] = deque([], 1) # accumulated transforms since last diffusion. stores one with cadence off, two otherwise self.negative_prompt: str = negative_prompt self.negative_prompt_weight: float = negative_prompt_weight self.start_frame_idx: int = 0 self.video_prev_frame: Optional[Image.Image] = None self.video_reader: Optional[cv2.VideoCapture] = None - # configure Api to retry on classifier obfuscations + # configure Api to retry on RPC exceptions and classifier obfuscations + self.api._max_retries = 5 self.api._retry_obfuscation = True # two stage 1024 model requires longer timeout @@ -849,6 +853,22 @@ def setup_animation(self, resume): self.load_video() self.load_init_image() + # remap model names to IDs in prompts + finetunes = self._load_finetunes() + def remap_model_names(prompt: str) -> str: + if not prompt: + return prompt + prompts, models = parse_models_from_prompts(prompt) + prompt = prompts[0] + for model, _ in models: + if model in finetunes: + prompt = re.sub(f"<{model}([^>]*)>", f"<{finetunes[model]}\\1>", prompt) + else: + logging.error(f"No fine-tune model matching name or ID {model}") + return prompt + self.animation_prompts = {k: remap_model_names(v) for k, v in self.animation_prompts.items()} + self.negative_prompt = remap_model_names(self.negative_prompt) + # handle resuming animation from last frames of a previous run if resume: if not self.out_dir: @@ -964,6 +984,19 @@ def transform_video(self, frame_idx) -> Optional[Image.Image]: return mask return None + def _load_finetunes(self) -> dict: + finetunes = {} + if self.api._stub_finetune: + try: + for model in ft.list_models(self.api): + if model.status == ft.FineTuneStatus.COMPLETED: + finetunes[model.id] = model.id + finetunes[model.name] = model.id + except Exception as e: + logger.error(f"Failed loading fine-tune models with exception: {e}") + logger.info(f"Found {len(finetunes)} fine-tune models") + return finetunes + def _postprocess_inpainting_mask( self, mask: Union[Image.Image, np.ndarray], diff --git a/src/stability_sdk/api.py b/src/stability_sdk/api.py index b364cae4..2724336c 100644 --- a/src/stability_sdk/api.py +++ b/src/stability_sdk/api.py @@ -10,12 +10,17 @@ import stability_sdk.interfaces.gooseai.dashboard.dashboard_pb2 as dashboard import stability_sdk.interfaces.gooseai.dashboard.dashboard_pb2_grpc as dashboard_grpc +import stability_sdk.interfaces.gooseai.finetuning.finetuning_pb2 as finetuning +import stability_sdk.interfaces.gooseai.finetuning.finetuning_pb2_grpc as finetuning_grpc import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation import stability_sdk.interfaces.gooseai.generation.generation_pb2_grpc as generation_grpc +import stability_sdk.interfaces.gooseai.project.project_pb2 as project +import stability_sdk.interfaces.gooseai.project.project_pb2_grpc as project_grpc from .utils import ( image_mix, image_to_prompt, + parse_models_from_prompts, tensor_to_prompt, ) @@ -65,33 +70,37 @@ def __init__(self, stub, engine_id): class Context: def __init__( - self, - host: str="", - api_key: str=None, - stub: generation_grpc.GenerationServiceStub=None, - generate_engine_id: str="stable-diffusion-xl-1024-v0-9", 
- inpaint_engine_id: str="stable-inpainting-512-v2-0", - interpolate_engine_id: str="interpolation-server-v1", - transform_engine_id: str="transform-server-v1", - upscale_engine_id: str="esrgan-v1-x2plus", - ): + self, + host: str="", + api_key: str=None, + stub: generation_grpc.GenerationServiceStub=None, + generate_engine_id: str="stable-diffusion-xl-1024-v0-9", + inpaint_engine_id: str="stable-inpainting-512-v2-0", + interpolate_engine_id: str="interpolation-server-v1", + transform_engine_id: str="transform-server-v1", + upscale_engine_id: str="esrgan-v1-x2plus", + ): if not host and stub is None: - raise Exception("Must provide either GRPC host or stub to Api") + raise Exception("Must provide either GRPC host or generation stub to Api") channel = open_channel(host, api_key) if host else None if not stub: stub = generation_grpc.GenerationServiceStub(channel) - self._dashboard_stub = dashboard_grpc.DashboardServiceStub(channel) if channel else None + self._stub_dashboard = dashboard_grpc.DashboardServiceStub(channel) if channel else None + self._stub_finetune = finetuning_grpc.FineTuningServiceStub(channel) if channel else None + self._stub_generation = stub + self._stub_project = project_grpc.ProjectServiceStub(channel) if channel else None - self._generate = Endpoint(stub, generate_engine_id) - self._inpaint = Endpoint(stub, inpaint_engine_id) - self._interpolate = Endpoint(stub, interpolate_engine_id) - self._transform = Endpoint(stub, transform_engine_id) - self._upscale = Endpoint(stub, upscale_engine_id) + # endpoints allow overriding RPC connection for specific engines + self._generate = Endpoint(self._stub_generation, generate_engine_id) + self._inpaint = Endpoint(self._stub_generation, inpaint_engine_id) + self._interpolate = Endpoint(self._stub_generation, interpolate_engine_id) + self._transform = Endpoint(self._stub_generation, transform_engine_id) + self._upscale = Endpoint(self._stub_generation, upscale_engine_id) self._debug_no_chains = False - self._max_retries = 5 # retry request on RPC error + self._max_retries = 1 # retry request on RPC error self._request_timeout = 30.0 # timeout in seconds for each request self._retry_delay = 1.0 # base delay in seconds between retries, each attempt will double self._retry_obfuscation = False # retry request with different seed on classifier obfuscation @@ -153,6 +162,7 @@ def generate( if (mask is not None) and (init_image is None) and not return_request: raise ValueError("If mask_image is provided, init_image must also be provided") + prompts, finetune_models = parse_models_from_prompts(prompts) p = [generation.Prompt(text=prompt, parameters=generation.PromptParameters(weight=weight)) for prompt,weight in zip(prompts, weights)] if init_image is not None: p.append(image_to_prompt(init_image)) @@ -162,10 +172,12 @@ def generate( p.append(image_to_prompt(init_depth, type=generation.ARTIFACT_DEPTH)) start_schedule = 1.0 - init_strength - image_params = self._build_image_params(width, height, sampler, steps, seed, samples, cfg_scale, - start_schedule, init_noise_scale, masked_area_init, - guidance_preset, guidance_cuts, guidance_strength) - + image_params = self._build_image_params( + width, height, sampler, steps, seed, samples, cfg_scale, + start_schedule, init_noise_scale, masked_area_init, finetune_models, + guidance_preset, guidance_cuts, guidance_strength, + ) + extras = Struct() if preset and preset.lower() != 'none': extras.update({ '$IPC': { "preset": preset } }) @@ -181,10 +193,10 @@ def generate( def get_user_info(self) -> 
Tuple[float, str]: """Get the number of credits the user has remaining and their profile picture.""" if not self._user_organization_id: - user = self._dashboard_stub.GetMe(dashboard.EmptyRequest()) + user = self._stub_dashboard.GetMe(dashboard.EmptyRequest()) self._user_profile_picture = user.profile_picture self._user_organization_id = user.organizations[0].organization.id - organization = self._dashboard_stub.GetOrganization(dashboard.GetOrganizationRequest(id=self._user_organization_id)) + organization = self._stub_dashboard.GetOrganization(dashboard.GetOrganizationRequest(id=self._user_organization_id)) return organization.payment_info.balance * 100, self._user_profile_picture def inpaint( @@ -227,15 +239,18 @@ def inpaint( :param preset: Style preset to use :return: dict mapping artifact type to data """ + prompts, finetune_models = parse_models_from_prompts(prompts) p = [generation.Prompt(text=prompt, parameters=generation.PromptParameters(weight=weight)) for prompt,weight in zip(prompts, weights)] p.append(image_to_prompt(image)) p.append(image_to_prompt(mask, type=generation.ARTIFACT_MASK)) width, height = image.size start_schedule = 1.0-init_strength - image_params = self._build_image_params(width, height, sampler, steps, seed, samples, cfg_scale, - start_schedule, init_noise_scale, masked_area_init, - guidance_preset, guidance_cuts, guidance_strength) + image_params = self._build_image_params( + width, height, sampler, steps, seed, samples, cfg_scale, + start_schedule, init_noise_scale, masked_area_init, finetune_models, + guidance_preset, guidance_cuts, guidance_strength, + ) extras = Struct() if preset and preset.lower() != 'none': @@ -537,9 +552,8 @@ def _adjust_request_for_retry(self, request: generation.Request, attempt: int): schedule.start = max(0.0, min(1.0, schedule.start + self._retry_schedule_offset)) def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_scale, - schedule_start, init_noise_scale, masked_area_init, + schedule_start, init_noise_scale, masked_area_init, finetune_models, guidance_preset, guidance_cuts, guidance_strength): - if not seed: seed = [random.randrange(0, 4294967295)] elif isinstance(seed, int): @@ -569,8 +583,14 @@ def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_ ] ) + fine_tuning_parameters = ( + [generation.FineTuningParameters(model_id=model, weight=weight) + for model, weight in finetune_models] + if finetune_models else None + ) + return generation.ImageParameters( - transform=None if sampler is None else generation.TransformType(diffusion=sampler), + transform=generation.TransformType(diffusion=sampler) if sampler else None, height=height, width=width, seed=seed, @@ -578,6 +598,7 @@ def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_ samples=samples, masked_area_init=masked_area_init, parameters=[generation.StepParameter(**step_parameters)], + fine_tuning_parameters=fine_tuning_parameters ) def _process_response(self, response) -> Dict[int, List[Any]]: diff --git a/src/stability_sdk/client.py b/src/stability_sdk/client.py index 86557547..cde811c3 100644 --- a/src/stability_sdk/client.py +++ b/src/stability_sdk/client.py @@ -28,6 +28,7 @@ artifact_type_to_string, image_to_prompt, open_images, + parse_models_from_prompts, sampler_from_string, truncate_fit, ) @@ -193,6 +194,8 @@ def generate( :param samples: Number of samples to generate. :param safety: DEPRECATED/UNUSED - Cannot be disabled. 
:param classifiers: DEPRECATED/UNUSED - Has no effect on image generation. + :param finetune_models: Finetune models to use + :param finetune_weights: Weight of each finetune model :param guidance_preset: Guidance preset to use. See generation.GuidancePreset for supported values. :param guidance_cuts: Number of cuts to use for guidance. :param guidance_strength: Strength of the guidance. We recommend values in range [0.0,1.0]. A good default is 0.25 @@ -230,6 +233,8 @@ def generate( raise TypeError("prompt must be a string or generation.Prompt object") prompts.append(p) + prompts, finetune_models = parse_models_from_prompts(prompts) + step_parameters = dict( scaled_step=0, sampler=generation.SamplerParameters(cfg_scale=cfg_scale), @@ -290,15 +295,22 @@ def generate( transform=None if sampler: - transform=generation.TransformType(diffusion=sampler) + transform = generation.TransformType(diffusion=sampler) - image_parameters=generation.ImageParameters( + fine_tuning_parameters = ( + [generation.FineTuningParameters(model_id=model, weight=weight) + for model, weight in finetune_models] + if finetune_models else None + ) + + image_parameters = generation.ImageParameters( transform=transform, height=height, width=width, seed=seed, steps=steps, samples=samples, + fine_tuning_parameters=fine_tuning_parameters, adapter=adapter_parameters, parameters=[generation.StepParameter(**step_parameters)], ) diff --git a/src/stability_sdk/finetune.py b/src/stability_sdk/finetune.py new file mode 100644 index 00000000..7a14e16e --- /dev/null +++ b/src/stability_sdk/finetune.py @@ -0,0 +1,261 @@ +import logging +import mimetypes + +from enum import Enum +from google.protobuf.struct_pb2 import Struct +from PIL import Image +from pydantic import BaseModel, Field +from typing import Any, Dict, List, Optional + +from .api import Context, finetuning, generation, project +from .utils import image_to_prompt + +TRAINING_IMAGE_MAX_COUNT = {"none":64, "face":64, "style":128, "object":64} +TRAINING_IMAGE_MIN_COUNT = 4 + +TRAINING_IMAGE_MAX_SIZE = 2048 +TRAINING_IMAGE_MIN_SIZE = 384 + + +#============================================================================== +# Types +#============================================================================== + +class FineTuneMode(str, Enum): + NONE = "none" + FACE = "face" + STYLE = "style" + OBJECT = "object" + +class FineTuneStatus(str, Enum): + NOT_STARTED = "not started" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + SUBMITTED = "submitted" + +class FineTuneModel(BaseModel): + id: str = Field(description="UUID") + name: str = Field(description="Name for the fine tuned model") + mode: FineTuneMode = Field(description="Mode for the fine tuning") + object_prompt: Optional[str] = Field(description="Prompt of the object for segmentation") + project_id: str = Field(description="Project ID to fine tune") + engine_id: str = Field(description="Engine ID to fine tune") + user_id: str = Field(description="ID of the user who created the model") + duration: Optional[float] = Field(description="Duration of the fine tuning") + status: Optional[FineTuneStatus] = Field(description="Status of the fine tuning") + +class FineTuneParameters(BaseModel): + name: str = Field(description="Name for the fine tuned model") + mode: FineTuneMode = Field(description="Mode for the fine tuning") + object_prompt: Optional[str] = Field(description="Prompt of the object for segmentation") + engine_id: str = Field(description="Engine ID to fine tune") + +FINETUNE_MODE_MAP = { + 
FineTuneMode.NONE: finetuning.FINE_TUNING_MODE_UNSPECIFIED, + FineTuneMode.FACE: finetuning.FINE_TUNING_MODE_FACE, + FineTuneMode.STYLE: finetuning.FINE_TUNING_MODE_STYLE, + FineTuneMode.OBJECT: finetuning.FINE_TUNING_MODE_OBJECT, +} + +FINETUNE_STATUS_MAP = { + FineTuneStatus.NOT_STARTED: finetuning.FINE_TUNING_STATUS_NOT_STARTED, + FineTuneStatus.RUNNING: finetuning.FINE_TUNING_STATUS_RUNNING, + FineTuneStatus.COMPLETED: finetuning.FINE_TUNING_STATUS_COMPLETED, + FineTuneStatus.FAILED: finetuning.FINE_TUNING_STATUS_FAILED, + FineTuneStatus.SUBMITTED: finetuning.FINE_TUNING_STATUS_SUBMITTED, +} + + +#============================================================================== +# Core fine-tuning functions +#============================================================================== + +def create_model( + context: Context, + params: FineTuneParameters, + image_paths: List[str], + extras: Dict[str, Any] = None +) -> FineTuneModel: + + # Validate number of images + assert params.mode in TRAINING_IMAGE_MAX_COUNT + mode_max_count = TRAINING_IMAGE_MAX_COUNT[params.mode] + if len(image_paths) > mode_max_count: + raise ValueError(f"Too many images for mode \"{params.mode}\", please use at most {mode_max_count}") + if len(image_paths) < TRAINING_IMAGE_MIN_COUNT: + raise ValueError(f"Too few images, please use at least {TRAINING_IMAGE_MIN_COUNT}") + + # Load and validate images + images = [] + for image_path in image_paths: + image = Image.open(image_path) + if min(image.width, image.height) < TRAINING_IMAGE_MIN_SIZE: + raise ValueError(f"Image {image_path} is too small, please use images with dimensions at least 384x384") + if max(image.width, image.height) > TRAINING_IMAGE_MAX_SIZE: + logging.warning(f"Image {image_path} is too large, resizing to max dimension {TRAINING_IMAGE_MAX_SIZE}") + max_size = max(image.width, image.height) + scale = TRAINING_IMAGE_MAX_SIZE / max_size + image = image.resize((int(image.width * scale), int(image.height * scale)), resample=Image.LANCZOS) + images.append(image) + else: + images.append(None) + + # Create training project + request = project.CreateProjectRequest( + title=params.name, + type=project.PROJECT_TYPE_TRAINING, + access=project.PROJECT_ACCESS_PRIVATE, + status=project.PROJECT_STATUS_ACTIVE + ) + proj: project.Project = context._stub_project.Create(request) + logging.info(f"Created project {proj.id}") + + try: + # Upload images + for i, image in enumerate(images): + if image is None: + # Directly use the file from disk if it was already the right size + with open(image_paths[i], 'rb') as f: + bytes = f.read() + mime_type, _ = mimetypes.guess_type(image_paths[i]) + prompt = generation.Prompt(artifact=generation.Artifact( + type=generation.ARTIFACT_IMAGE, + binary=bytes, + mime=mime_type + )) + else: + # Encode the resized image + prompt = image_to_prompt(image) + + request = generation.Request( + engine_id="asset-service", + prompt=[prompt], + asset=generation.AssetParameters( + action=generation.ASSET_PUT, + project_id=proj.id, + use=generation.ASSET_USE_INPUT + ) + ) + + success = False + for response in context._stub_generation.Generate(request): + for artifact in response.artifacts: + if artifact.type == generation.ARTIFACT_TEXT: + logging.info(f"Uploaded image {i}: {artifact.text}") + success = True + break + if not success: + raise RuntimeError(f"Failed to upload image {image_paths[i]}") + + # Pass along extra training data for development and testing + extras_struct = Struct() + if extras is not None: + extras_struct.update(extras) + + # 
+        # Create fine-tuning model
+        request = finetuning.CreateModelRequest(
+            name=params.name,
+            mode=mode_to_proto(params.mode),
+            object_prompt=params.object_prompt if params.mode == FineTuneMode.OBJECT else None,
+            project_id=proj.id,
+            engine_id=params.engine_id,
+            extras=extras_struct
+        )
+        result = context._stub_finetune.CreateModel(request)
+        return model_from_proto(result.model)
+
+    except Exception:
+        # Clean up the orphaned training project before re-raising
+        logging.info(f"Encountered error, deleting training project {proj.id}")
+        context._stub_project.Delete(project.DeleteProjectRequest(id=proj.id))
+        raise
+
+def delete_model(context: Context, model_id: str) -> FineTuneModel:
+    request = finetuning.DeleteModelRequest(id=model_id)
+    result = context._stub_finetune.DeleteModel(request)
+    return model_from_proto(result.model)
+
+def get_model(context: Context, model_id: str) -> FineTuneModel:
+    request = finetuning.GetModelRequest(id=model_id)
+    result = context._stub_finetune.GetModel(request)
+    return model_from_proto(result.model)
+
+def list_models(context: Context, org_id: str = None, user_id: str = None) -> List[FineTuneModel]:
+    if org_id and user_id:
+        raise ValueError("Only one of org_id and user_id can be specified")
+    request = finetuning.ListModelsRequest(org_id=org_id, user_id=user_id)
+    result = context._stub_finetune.ListModels(request)
+    return [model_from_proto(model) for model in result.models]
+
+def resubmit_model(context: Context, model_id: str) -> FineTuneModel:
+    request = finetuning.ResubmitModelRequest(id=model_id)
+    result = context._stub_finetune.ResubmitModel(request)
+    return model_from_proto(result.model)
+
+def update_model(
+    context: Context,
+    model_id: str,
+    name: Optional[str] = None,
+    mode: Optional[FineTuneMode] = None,
+    object_prompt: Optional[str] = None,
+    engine_id: Optional[str] = None
+) -> FineTuneModel:
+    request = finetuning.UpdateModelRequest(
+        id=model_id,
+        name=name,
+        mode=mode_to_proto(mode) if mode is not None else None,
+        object_prompt=object_prompt,
+        engine_id=engine_id,
+    )
+    result = context._stub_finetune.UpdateModel(request)
+    return model_from_proto(result.model)
+
+
+#==============================================================================
+# Utility functions
+#==============================================================================
+
+def mode_from_proto(mode: finetuning.FineTuningMode) -> FineTuneMode:
+    for key, value in FINETUNE_MODE_MAP.items():
+        if value == mode:
+            return key
+    logging.warning(f"Unrecognized fine tuning mode {mode}")
+    return FineTuneMode.NONE
+
+def mode_to_proto(mode: FineTuneMode) -> finetuning.FineTuningMode:
+    value = FINETUNE_MODE_MAP.get(mode)
+    if value is None:
+        raise ValueError(f"Invalid fine tuning mode {mode}")
+    return value
+
+def model_from_proto(model: finetuning.FineTuningModel) -> FineTuneModel:
+    return FineTuneModel(
+        id=model.id,
+        name=model.name,
+        mode=mode_from_proto(model.mode),
+        object_prompt=model.object_prompt,
+        project_id=model.project_id,
+        engine_id=model.engine_id,
+        user_id=model.user_id,
+        duration=model.duration,
+        status=status_from_proto(model.status),
+    )
+
+def status_from_proto(status: finetuning.FineTuningStatus) -> FineTuneStatus:
+    for key, value in FINETUNE_STATUS_MAP.items():
+        if value == status:
+            return key
+    logging.warning(f"Unrecognized fine tuning status {status}")
+    return FineTuneStatus.NOT_STARTED
+
+def status_to_proto(status: FineTuneStatus) -> finetuning.FineTuningStatus:
+    value = FINETUNE_STATUS_MAP.get(status)
+    if value is None:
+        raise ValueError(f"Invalid fine tuning status {status}")
+    return value
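Reviewer note: a sketch of how the new module's pieces compose end to end. The Context constructor arguments, file paths, names, and polling cadence are assumptions for illustration, not part of this diff.

    # Sketch: train a style fine-tune and poll until it finishes (assumed setup).
    import time
    from stability_sdk.api import Context
    from stability_sdk.finetune import (
        FineTuneMode, FineTuneParameters, FineTuneStatus, create_model, get_model,
    )

    context = Context(host="grpc.stability.ai:443", api_key="YOUR_API_KEY")  # assumed args
    params = FineTuneParameters(
        name="my-style",
        mode=FineTuneMode.STYLE,
        engine_id="stable-diffusion-xl-1024-v1-0",
    )
    # Style mode accepts 4-128 images, each at least 384px on the short side.
    model = create_model(context, params, image_paths=[f"train/{i}.png" for i in range(8)])

    # Poll until training completes or fails.
    while model.status not in (FineTuneStatus.COMPLETED, FineTuneStatus.FAILED):
        time.sleep(10)
        model = get_model(context, model.id)
    print(model.status)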
status {status}") + return FineTuneStatus.NOT_STARTED + +def status_to_proto(status: FineTuneStatus) -> finetuning.FineTuningStatus: + value = FINETUNE_STATUS_MAP.get(status) + if value is None: + raise ValueError(f"Invalid fine tuning status {status}") + return value diff --git a/src/stability_sdk/utils.py b/src/stability_sdk/utils.py index fd178a58..22bae2e2 100644 --- a/src/stability_sdk/utils.py +++ b/src/stability_sdk/utils.py @@ -1,10 +1,12 @@ import io import logging +import math import os +import re import subprocess from PIL import Image -from typing import Dict, Generator, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Type, TypeVar, Union from .api import generation from .matrix import Matrix @@ -317,7 +319,8 @@ def image_to_prompt( """ return generation.Prompt(artifact=generation.Artifact( type=type, - binary=image_to_png_bytes(image) + binary=image_to_png_bytes(image), + mime="image/png" )) def open_images( @@ -341,6 +344,45 @@ def open_images( img.show() yield (path, artifact) +def parse_models_from_prompts(prompts: Union[Any, List[Any]]) -> Tuple[List[Any], List[Tuple[str, float]]]: + """ + Parses prompt strings for model names and weights with syntax . + :param prompts: List of prompt strings or objects. + :return: Updated prompts and list of tuples with model names and weights. + """ + if not prompts: + return [], [] + prompts = prompts if isinstance(prompts, List) else [prompts] + pattern = re.compile(r"<([^:>]+)(?::([^>]+))?>") + models = {} + + def _process_prompt(prompt): + text = prompt.text if isinstance(prompt, generation.Prompt) else prompt + matches = pattern.findall(text) + for model, weight in matches: + weight_text = weight if weight else "" + # pass default TI tokens through unmodified + if model in ["s1", "s2", "s3"]: + continue + try: + weight = max(float(weight) if weight else 1.0, models.get(model, -math.inf)) + except ValueError as e: + raise ValueError(f'Invalid weight for model "{model}": "{weight}"') from e + text = text.replace(f'<{model}:{weight_text}>', f'<{model}>') + models[model] = weight + return text + + out_prompts = [] + for prompt in prompts: + text = _process_prompt(prompt) + if isinstance(prompt, generation.Prompt): + prompt.text = text + out_prompts.append(prompt) + else: + out_prompts.append(text) + + return out_prompts, list(models.items()) + def tensor_to_prompt(tensor: 'tensors_pb.Tensor') -> generation.Prompt: """ Create Prompt message type from a tensor. 
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 88c375da..05a61ad7 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -20,6 +20,7 @@
     image_to_jpg_bytes,
     image_to_png_bytes,
     image_to_prompt,
+    parse_models_from_prompts,
     resample_transform,
     sampler_from_string,
     truncate_fit,
@@ -70,6 +71,27 @@ def test_color_match_from_string_invalid():
     with pytest.raises(ValueError, match="invalid color match"):
         color_match_from_string(s='not a real color match mode')

+def test_parse_models_from_prompts():
+    assert parse_models_from_prompts(None) == ([], [])
+    assert parse_models_from_prompts([]) == ([], [])
+    assert parse_models_from_prompts("a <one>")[1] == [("one", 1.0)]
+    assert parse_models_from_prompts("a <one:1.0>")[1] == [("one", 1.0)]
+    assert parse_models_from_prompts("a simple prompt") == (["a simple prompt"], [])
+    assert parse_models_from_prompts("<weight-strip:0.25>") == (["<weight-strip>"], [("weight-strip", 0.25)])
+    assert parse_models_from_prompts("a <my-model>")[1] == [("my-model", 1.0)]
+    assert parse_models_from_prompts("a <model-one> and a <model two>")[1] == [("model-one", 1.0), ("model two", 1.0)]
+    assert parse_models_from_prompts("a <with-weight:0.25>")[1] == [("with-weight", 0.25)]
+    assert parse_models_from_prompts("a <neg-weight:-0.75>")[1] == [("neg-weight", -0.75)]
+    assert parse_models_from_prompts("a <scientific:1e-2>")[1] == [("scientific", 1e-2)]
+    assert parse_models_from_prompts("a <zero:0>")[1] == [("zero", 0.0)]
+    assert parse_models_from_prompts(["none", "<model-one>", "<model-two:2>"])[1] == [("model-one", 1.0), ("model-two", 2.0)]
+    assert parse_models_from_prompts(["<dupe:0.25> and <dupe:0.5>"])[1] == [("dupe", 0.5)]
+    with pytest.raises(ValueError, match="Invalid weight for model \"model-id\": \"bad-weight\""):
+        parse_models_from_prompts("a <model-id:bad-weight>")
+
+    # test support for generation.Prompt objects
+    assert isinstance(parse_models_from_prompts(generation.Prompt(text="a simple prompt"))[0][0], generation.Prompt)
+    assert isinstance(parse_models_from_prompts(["plain str", generation.Prompt(text="a simple prompt")])[0][1], generation.Prompt)

 ####################################
 # to do: pytest.mark.paramaterized #
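Reviewer note: regarding the "to do: pytest.mark.paramaterized" marker in the trailing context, the scalar cases above collapse naturally into a parametrized test. A possible shape (sketch only, not part of this diff):

    import pytest
    from stability_sdk.utils import parse_models_from_prompts

    @pytest.mark.parametrize("prompt, expected", [
        ("a <one>", [("one", 1.0)]),
        ("a <with-weight:0.25>", [("with-weight", 0.25)]),
        ("a <neg-weight:-0.75>", [("neg-weight", -0.75)]),
        ("a <scientific:1e-2>", [("scientific", 1e-2)]),
    ])
    def test_parse_models_from_prompts_parametrized(prompt, expected):
        # Each case checks only the extracted (model, weight) pairs.
        assert parse_models_from_prompts(prompt)[1] == expected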