diff --git a/setfit_doc/en/pytorch/zero_shot.ipynb b/setfit_doc/en/pytorch/zero_shot.ipynb index 780f6299..68bb35c0 100644 --- a/setfit_doc/en/pytorch/zero_shot.ipynb +++ b/setfit_doc/en/pytorch/zero_shot.ipynb @@ -11,23 +11,61 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Your class names are likely already good descriptors of the text that you're looking to classify. With 🤗 SetFit, you can use these class names with strong pretrained Sentence Transformer models to get a strong baseline model without any training samples.\n", + "Although SetFit was designed for few-shot learning, the method can also be applied in scenarios where no labeled data is available. The main trick is to create _synthetic examples_ that resemble the classification task, and then train a SetFit model on them. \n", "\n", - "This guide will show you how to perform zero-shot text classification." + "Remarkably, this simple technique typically outperforms the zero-shot pipeline in 🤗 Transformers, and can generate predictions by a factor of 5x (or more) faster!\n", + "\n", + "In this tutorial, we'll explore how:\n", + "\n", + "* SetFit can be applied for zero-shot classification\n", + "* Adding synthetic examples can also provide a performance boost to few-shot classification." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you're running this Notebook on Colab or some other cloud platform, you will need to install the `setfit` library. Uncomment the following cell and run it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %pip install setfit matplotlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Testing dataset" + "To benchmark the performance of the \"zero-shot\" method, we'll use the following dataset and pretrained model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_id = \"emotion\"\n", + "model_id = \"sentence-transformers/paraphrase-mpnet-base-v2\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We'll use the [dair-ai/emotion](https://huggingface.co/datasets/dair-ai/emotion) dataset to test the performance of our zero-shot model." + "Next, we'll download the reference dataset from the Hugging Face Hub:" ] }, { @@ -38,14 +76,57 @@ "source": [ "from datasets import load_dataset\n", "\n", - "test_dataset = load_dataset(\"dair-ai/emotion\", \"split\", split=\"test\")" + "reference_dataset = load_dataset(dataset_id)\n", + "reference_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "DatasetDict({\n", + " train: Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 16000\n", + " })\n", + " validation: Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 2000\n", + " })\n", + " test: Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 2000\n", + " })\n", + "})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we're set up, let's create some synthetic data to train on!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating a synthetic dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "This dataset stores the class names within the dataset `Features`, so we'll extract the classes like so:" + "The first thing we need to do is create a dataset of synthetic examples. In `setfit`, we can do this by applying the `get_templated_dataset()` function to a dummy dataset. This function expects a few main things:\n", + "\n", + "* A list of candidate labels to classify with. We'll use the labels from the reference dataset here, but this could be anything that's relevant to the task and dataset at hand.\n", + "* A template to generate examples with. By default, it is `\"This sentence is {}\"`, where the `{}` will be filled by one of the candidate labels\n", + "* A sample size $N$, which will create $N$ synthetic examples per class. We find $N=8$ usually works best.\n", + "\n", + "Armed with this information, let's first extract some candidate labels from the dataset:" ] }, { @@ -54,29 +135,52 @@ "metadata": {}, "outputs": [], "source": [ - "classes = test_dataset.features[\"label\"].names\n", - "# => ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']" + "# Extract ClassLabel feature from \"label\" column\n", + "label_features = reference_dataset[\"train\"].features[\"label\"]\n", + "# Label names to classify with\n", + "candidate_labels = label_features.names\n", + "candidate_labels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Otherwise, we could manually set the list of classes." + "```\n", + "['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']\n", + "```\n", + "\n", + "\n", + "\n", + "Some datasets on the Hugging Face Hub don't have a `ClassLabel` feature for the label column. In these cases, you should compute the candidate labels manually by first computing the id2label mapping as follows:\n", + "\n", + "" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "## Synthetic dataset" + "def get_id2label(dataset):\n", + " # The column with the label names\n", + " label_names = dataset.unique(\"label_text\")\n", + " # The column with the label IDs\n", + " label_ids = dataset.unique(\"label\")\n", + " id2label = dict(zip(label_ids, label_names))\n", + " # Sort by label ID\n", + " return {key: val for key, val in sorted(id2label.items(), key = lambda x: x[0])}\n", + "\n", + "id2label = get_id2label(reference_dataset[\"train\"])\n", + "candidate_labels = list(id2label.values())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Then, we can use [get_templated_dataset()](https://huggingface.co/docs/setfit/main/en/reference/utility#setfit.get_templated_dataset) to synthetically generate a dummy dataset given these class names." + "Now that we have the labels, it's a simple matter to create synthetic examples:" ] }, { @@ -85,9 +189,34 @@ "metadata": {}, "outputs": [], "source": [ + "from datasets import Dataset\n", "from setfit import get_templated_dataset\n", "\n", - "train_dataset = get_templated_dataset()" + "# A dummy dataset to fill with synthetic examples\n", + "dummy_dataset = Dataset.from_dict({})\n", + "train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8)\n", + "train_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```\n", + "Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 48\n", + "})\n", + "```\n", + "\n", + "\n", + "\n", + "You might find you can get better performance by tweaking the `template` argument from the default of `\"The sentence is {}\"` to variants like `\"This sentence is {}\"` or `\"This example is {}\"`.\n", + "\n", + "\n", + "\n", + "\n", + "Since our dataset has 6 classes and we chose a sample size of 8, our synthetic dataset contains $6\\times 8=48$ examples. If we take a look at a few of the examples:" ] }, { @@ -96,27 +225,37 @@ "metadata": {}, "outputs": [], "source": [ - "print(train_dataset)\n", - "# => Dataset({\n", - "# features: ['text', 'label'],\n", - "# num_rows: 48\n", - "# })\n", - "print(train_dataset[0])\n", - "# {'text': 'This sentence is sadness', 'label': 0}" + "train_dataset.shuffle()[:3]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```\n", + "{'text': ['This sentence is love',\n", + " 'This sentence is fear',\n", + " 'This sentence is joy'],\n", + " 'label': [2, 4, 1]}\n", + "```\n", + "\n", + "We can see that each input takes the form of the template and has a corresponding label associated with it. \n", + "\n", + "Let's not train a SetFit model on these examples!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Training" + "## Fine-tuning the model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We can use this dataset to train a SetFit model just like normal:" + "To train a SetFit model, the first thing to do is download a pretrained checkpoint from the Hub. We can do so by using the `SetFitModel.from_pretrained()` method:" ] }, { @@ -125,41 +264,109 @@ "metadata": {}, "outputs": [], "source": [ - "from setfit import SetFitModel, Trainer, TrainingArguments\n", - "\n", - "model = SetFitModel.from_pretrained(\"BAAI/bge-small-en-v1.5\")\n", + "from setfit import SetFitModel\n", "\n", - "args = TrainingArguments(\n", - " batch_size=32,\n", - " num_epochs=1,\n", - ")\n", + "model = SetFitModel.from_pretrained(model_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here, we've downloaded a pretrained Sentence Transformer from the Hub and added a logistic classification head to the create the SetFit model. As indicated in the message, we need to train this model on some labeled examples. We can do so by using the [Trainer](https://huggingface.co/docs/setfit/main/en/reference/trainer#setfit.Trainer) class as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from setfit import Trainer\n", "\n", "trainer = Trainer(\n", " model=model,\n", - " args=args,\n", " train_dataset=train_dataset,\n", - " eval_dataset=test_dataset,\n", - ")\n", - "trainer.train()" + " eval_dataset=reference_dataset[\"test\"]\n", + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "```\n", + "Now that we've created a trainer, we can train it! While we're at it, let's time how long it takes to train and evaluate the model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "trainer.train()\n", + "zeroshot_metrics = trainer.evaluate()\n", + "zeroshot_metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "***** Running training *****\n", - " Num examples = 60\n", + " Num examples = 1920\n", " Num epochs = 1\n", - " Total optimization steps = 60\n", - " Total train batch size = 32\n", - "{'embedding_loss': 0.2628, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.02} \n", - "{'embedding_loss': 0.0222, 'learning_rate': 3.7037037037037037e-06, 'epoch': 0.83} \n", - "{'train_runtime': 15.4717, 'train_samples_per_second': 124.098, 'train_steps_per_second': 3.878, 'epoch': 1.0} \n", - "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:09<00:00, 6.35it/s]\n", + " Total optimization steps = 120\n", + " Total train batch size = 16\n", + "***** Running evaluation *****\n", + "{'accuracy': 0.5345}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```\n", + "CPU times: user 12.9 s, sys: 2.37 s, total: 15.2 s\n", + "Wall time: 11 s\n", "```\n", "\n", - "Once trained, we can evaluate the model:" + "Great, now that we have a reference score let's compare against the zero-shot pipeline from 🤗 Transformers." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Comparing against the zero-shot pipeline from 🤗 Transformers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "🤗 Transformers provides a zero-shot pipeline that frames text classification as a natural language inference task. Let's load the pipeline and place it on the GPU for fast inference:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import pipeline\n", + "\n", + "pipe = pipeline(\"zero-shot-classification\", device=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have the model, let's generate some predictions. We'll use the same candidate labels as we did with SetFit and increase the batch size for to speed things up:" ] }, { @@ -168,8 +375,8 @@ "metadata": {}, "outputs": [], "source": [ - "metrics = trainer.evaluate()\n", - "print(metrics)" + "%%time\n", + "zeroshot_preds = pipe(reference_dataset[\"test\"][\"text\"], batch_size=16, candidate_labels=candidate_labels)" ] }, { @@ -177,11 +384,11 @@ "metadata": {}, "source": [ "```\n", - "***** Running evaluation *****\n", - "{'accuracy': 0.591}\n", + "CPU times: user 1min 10s, sys: 166 ms, total: 1min 11s\n", + "Wall time: 53.1 s\n", "```\n", "\n", - "And run predictions:" + "Note that this took almost 5x longer to generate predictions than SetFit! OK, so how well does it perform? Since each prediction is a dictionary of label names ranked by score:" ] }, { @@ -190,14 +397,7 @@ "metadata": {}, "outputs": [], "source": [ - "preds = model.predict([\n", - " \"i am just feeling cranky and blue\",\n", - " \"i feel incredibly lucky just to be able to talk to her\",\n", - " \"you're pissing me off right now\",\n", - " \"i definitely have thalassophobia, don't get me near water like that\",\n", - " \"i did not see that coming at all\",\n", - "])\n", - "print([classes[idx] for idx in preds])" + "zeroshot_preds[0]" ] }, { @@ -206,28 +406,55 @@ "metadata": {}, "outputs": [], "source": [ - "['sadness', 'joy', 'anger', 'fear', 'surprise']" + "{'sequence': 'im feeling rather rotten so im not very ambitious right now',\n", + " 'labels': ['sadness', 'anger', 'surprise', 'fear', 'joy', 'love'],\n", + " 'scores': [0.7367985844612122,\n", + " 0.10041674226522446,\n", + " 0.09770156443119049,\n", + " 0.05880110710859299,\n", + " 0.004266355652362108,\n", + " 0.0020156768150627613]}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "These predictions all look right!" + "We can use the `str2int()` function from the `label` column to convert them to integers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preds = [label_features.str2int(pred[\"labels\"][0]) for pred in zeroshot_preds]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Baseline" + "**Note:** As noted earlier, if you're using a dataset that doesn't have a `ClassLabel` feature for the label column, you'll need to compute the label mapping manually with something like:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "id2label = get_id2label(reference_dataset[\"train\"])\n", + "label2id = {v:k for k,v in id2label.items()}\n", + "preds = [label2id[pred[\"labels\"][0]] for pred in zeroshot_preds]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "To show that the zero-shot performance of SetFit works well, we'll compare it against a zero-shot classification model from `transformers`." + "The last step is to compute accuracy using 🤗 Evaluate:" ] }, { @@ -236,24 +463,11 @@ "metadata": {}, "outputs": [], "source": [ - "from transformers import pipeline\n", - "from datasets import load_dataset\n", "import evaluate\n", "\n", - "# Prepare the testing dataset\n", - "test_dataset = load_dataset(\"dair-ai/emotion\", \"split\", split=\"test\")\n", - "classes = test_dataset.features[\"label\"].names\n", - "\n", - "# Set up the zero-shot classification pipeline from transformers\n", - "# Uses 'facebook/bart-large-mnli' by default\n", - "pipe = pipeline(\"zero-shot-classification\", device=0)\n", - "zeroshot_preds = pipe(test_dataset[\"text\"], batch_size=16, candidate_labels=classes)\n", - "preds = [classes.index(pred[\"labels\"][0]) for pred in zeroshot_preds]\n", - "\n", - "# Compute the accuracy\n", "metric = evaluate.load(\"accuracy\")\n", - "transformers_accuracy = metric.compute(predictions=preds, references=test_dataset[\"label\"])\n", - "print(transformers_accuracy)" + "transformers_metrics = metric.compute(predictions=preds, references=reference_dataset[\"test\"][\"label\"])\n", + "transformers_metrics" ] }, { @@ -269,21 +483,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "With its 59.1% accuracy, the 0-shot SetFit heavily outperforms the recommended zero-shot model by `transformers`." + "Compared to SetFit, this approach performs significantly worse. Let's wrap up our analysis by combining synthetic examples with a few labeled ones." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Prediction latency" + "## Augmenting labeled data with synthetic examples" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Beyond getting higher accuracies, SetFit is much faster too. Let's compute the latency of SetFit with `BAAI/bge-small-en-v1.5` versus the latency of `transformers` with `facebook/bart-large-mnli`. Both tests were performed on a GPU." + "If you have a few labeled examples, adding synthetic data can often boost performance. To simulate this, let's first sample 8 labeled examples from our reference dataset:" ] }, { @@ -292,21 +506,47 @@ "metadata": {}, "outputs": [], "source": [ - "import time\n", + "from setfit import sample_dataset\n", "\n", - "start_t = time.time()\n", - "pipe(test_dataset[\"text\"], batch_size=32, candidate_labels=classes)\n", - "delta_t = time.time() - start_t\n", - "print(f\"`transformers` with `facebook/bart-large-mnli` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence\")" + "train_dataset = sample_dataset(reference_dataset[\"train\"])\n", + "train_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 48\n", + "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "```\n", - "`transformers` with `facebook/bart-large-mnli` latency: 31.1765ms per sentence\n", - "```" + "To warm up, we'll train a SetFit model on these true labels:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = SetFitModel.from_pretrained(model_id)\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=reference_dataset[\"test\"]\n", + ")\n", + "trainer.train()\n", + "fewshot_metrics = trainer.evaluate()\n", + "fewshot_metrics" ] }, { @@ -315,12 +555,63 @@ "metadata": {}, "outputs": [], "source": [ - "import time\n", + "{'accuracy': 0.4705}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that for this particular dataset, the performance with true labels is _worse_ than training on synthetic examples! In our experiments, we found that the difference depends strongly on the dataset in question. Since SetFit models are fast to train, you can always try both approaches and pick the best one.\n", "\n", - "start_t = time.time()\n", - "model.predict(test_dataset[\"text\"])\n", - "delta_t = time.time() - start_t\n", - "print(f\"SetFit with `BAAI/bge-small-en-v1.5` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence\")" + "In any case, let's now add some synthetic examples to our training set:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "augmented_dataset = get_templated_dataset(train_dataset, candidate_labels=candidate_labels, sample_size=8)\n", + "augmented_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 96\n", + "})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As before, we can train and evaluate SetFit with the augmented dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = SetFitModel.from_pretrained(model_id)\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataset=augmented_dataset,\n", + " eval_dataset=reference_dataset[\"test\"]\n", + ")\n", + "trainer.train()\n", + "augmented_metrics = trainer.evaluate()\n", + "augmented_metrics" ] }, { @@ -328,12 +619,31 @@ "metadata": {}, "source": [ "```\n", - "SetFit with `BAAI/bge-small-en-v1.5` latency: 0.4600ms per sentence\n", + "{'accuracy': 0.613}\n", "```\n", "\n", - "So, SetFit with `BAAI/bge-small-en-v1.5` is 67x faster than `transformers` with `facebook/bart-large-mnli`, alongside being more accurate:\n", + "Great, this has given us a significant boost in performance and given us a few percentage points over the purely synthetic example. \n", "\n", - "![zero_shot_transformers_vs_setfit](https://github.com/huggingface/setfit/assets/37621491/33f574d9-c51b-4e02-8d98-6e04e18427ef)" + "Let's plot the final results for comparison:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.DataFrame.from_dict({\"Method\":[\"Transformers (zero-shot)\", \"SetFit (zero-shot)\", \"SetFit (augmented)\"], \"Accuracy\": [transformers_metrics[\"accuracy\"], zeroshot_metrics[\"accuracy\"], augmented_metrics[\"accuracy\"]]})\n", + "df.plot(kind=\"barh\", x=\"Method\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![setfit_zero_shot_results](https://github.com/huggingface/setfit/assets/37621491/b02d3e62-d51c-4506-91f6-2fe9b7ef554d)" ] } ], diff --git a/setfit_doc/en/tensorflow/zero_shot.ipynb b/setfit_doc/en/tensorflow/zero_shot.ipynb index 780f6299..68bb35c0 100644 --- a/setfit_doc/en/tensorflow/zero_shot.ipynb +++ b/setfit_doc/en/tensorflow/zero_shot.ipynb @@ -11,23 +11,61 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Your class names are likely already good descriptors of the text that you're looking to classify. With 🤗 SetFit, you can use these class names with strong pretrained Sentence Transformer models to get a strong baseline model without any training samples.\n", + "Although SetFit was designed for few-shot learning, the method can also be applied in scenarios where no labeled data is available. The main trick is to create _synthetic examples_ that resemble the classification task, and then train a SetFit model on them. \n", "\n", - "This guide will show you how to perform zero-shot text classification." + "Remarkably, this simple technique typically outperforms the zero-shot pipeline in 🤗 Transformers, and can generate predictions by a factor of 5x (or more) faster!\n", + "\n", + "In this tutorial, we'll explore how:\n", + "\n", + "* SetFit can be applied for zero-shot classification\n", + "* Adding synthetic examples can also provide a performance boost to few-shot classification." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you're running this Notebook on Colab or some other cloud platform, you will need to install the `setfit` library. Uncomment the following cell and run it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %pip install setfit matplotlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Testing dataset" + "To benchmark the performance of the \"zero-shot\" method, we'll use the following dataset and pretrained model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_id = \"emotion\"\n", + "model_id = \"sentence-transformers/paraphrase-mpnet-base-v2\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We'll use the [dair-ai/emotion](https://huggingface.co/datasets/dair-ai/emotion) dataset to test the performance of our zero-shot model." + "Next, we'll download the reference dataset from the Hugging Face Hub:" ] }, { @@ -38,14 +76,57 @@ "source": [ "from datasets import load_dataset\n", "\n", - "test_dataset = load_dataset(\"dair-ai/emotion\", \"split\", split=\"test\")" + "reference_dataset = load_dataset(dataset_id)\n", + "reference_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "DatasetDict({\n", + " train: Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 16000\n", + " })\n", + " validation: Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 2000\n", + " })\n", + " test: Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 2000\n", + " })\n", + "})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we're set up, let's create some synthetic data to train on!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating a synthetic dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "This dataset stores the class names within the dataset `Features`, so we'll extract the classes like so:" + "The first thing we need to do is create a dataset of synthetic examples. In `setfit`, we can do this by applying the `get_templated_dataset()` function to a dummy dataset. This function expects a few main things:\n", + "\n", + "* A list of candidate labels to classify with. We'll use the labels from the reference dataset here, but this could be anything that's relevant to the task and dataset at hand.\n", + "* A template to generate examples with. By default, it is `\"This sentence is {}\"`, where the `{}` will be filled by one of the candidate labels\n", + "* A sample size $N$, which will create $N$ synthetic examples per class. We find $N=8$ usually works best.\n", + "\n", + "Armed with this information, let's first extract some candidate labels from the dataset:" ] }, { @@ -54,29 +135,52 @@ "metadata": {}, "outputs": [], "source": [ - "classes = test_dataset.features[\"label\"].names\n", - "# => ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']" + "# Extract ClassLabel feature from \"label\" column\n", + "label_features = reference_dataset[\"train\"].features[\"label\"]\n", + "# Label names to classify with\n", + "candidate_labels = label_features.names\n", + "candidate_labels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Otherwise, we could manually set the list of classes." + "```\n", + "['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']\n", + "```\n", + "\n", + "\n", + "\n", + "Some datasets on the Hugging Face Hub don't have a `ClassLabel` feature for the label column. In these cases, you should compute the candidate labels manually by first computing the id2label mapping as follows:\n", + "\n", + "" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "## Synthetic dataset" + "def get_id2label(dataset):\n", + " # The column with the label names\n", + " label_names = dataset.unique(\"label_text\")\n", + " # The column with the label IDs\n", + " label_ids = dataset.unique(\"label\")\n", + " id2label = dict(zip(label_ids, label_names))\n", + " # Sort by label ID\n", + " return {key: val for key, val in sorted(id2label.items(), key = lambda x: x[0])}\n", + "\n", + "id2label = get_id2label(reference_dataset[\"train\"])\n", + "candidate_labels = list(id2label.values())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Then, we can use [get_templated_dataset()](https://huggingface.co/docs/setfit/main/en/reference/utility#setfit.get_templated_dataset) to synthetically generate a dummy dataset given these class names." + "Now that we have the labels, it's a simple matter to create synthetic examples:" ] }, { @@ -85,9 +189,34 @@ "metadata": {}, "outputs": [], "source": [ + "from datasets import Dataset\n", "from setfit import get_templated_dataset\n", "\n", - "train_dataset = get_templated_dataset()" + "# A dummy dataset to fill with synthetic examples\n", + "dummy_dataset = Dataset.from_dict({})\n", + "train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8)\n", + "train_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```\n", + "Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 48\n", + "})\n", + "```\n", + "\n", + "\n", + "\n", + "You might find you can get better performance by tweaking the `template` argument from the default of `\"The sentence is {}\"` to variants like `\"This sentence is {}\"` or `\"This example is {}\"`.\n", + "\n", + "\n", + "\n", + "\n", + "Since our dataset has 6 classes and we chose a sample size of 8, our synthetic dataset contains $6\\times 8=48$ examples. If we take a look at a few of the examples:" ] }, { @@ -96,27 +225,37 @@ "metadata": {}, "outputs": [], "source": [ - "print(train_dataset)\n", - "# => Dataset({\n", - "# features: ['text', 'label'],\n", - "# num_rows: 48\n", - "# })\n", - "print(train_dataset[0])\n", - "# {'text': 'This sentence is sadness', 'label': 0}" + "train_dataset.shuffle()[:3]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```\n", + "{'text': ['This sentence is love',\n", + " 'This sentence is fear',\n", + " 'This sentence is joy'],\n", + " 'label': [2, 4, 1]}\n", + "```\n", + "\n", + "We can see that each input takes the form of the template and has a corresponding label associated with it. \n", + "\n", + "Let's not train a SetFit model on these examples!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Training" + "## Fine-tuning the model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We can use this dataset to train a SetFit model just like normal:" + "To train a SetFit model, the first thing to do is download a pretrained checkpoint from the Hub. We can do so by using the `SetFitModel.from_pretrained()` method:" ] }, { @@ -125,41 +264,109 @@ "metadata": {}, "outputs": [], "source": [ - "from setfit import SetFitModel, Trainer, TrainingArguments\n", - "\n", - "model = SetFitModel.from_pretrained(\"BAAI/bge-small-en-v1.5\")\n", + "from setfit import SetFitModel\n", "\n", - "args = TrainingArguments(\n", - " batch_size=32,\n", - " num_epochs=1,\n", - ")\n", + "model = SetFitModel.from_pretrained(model_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here, we've downloaded a pretrained Sentence Transformer from the Hub and added a logistic classification head to the create the SetFit model. As indicated in the message, we need to train this model on some labeled examples. We can do so by using the [Trainer](https://huggingface.co/docs/setfit/main/en/reference/trainer#setfit.Trainer) class as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from setfit import Trainer\n", "\n", "trainer = Trainer(\n", " model=model,\n", - " args=args,\n", " train_dataset=train_dataset,\n", - " eval_dataset=test_dataset,\n", - ")\n", - "trainer.train()" + " eval_dataset=reference_dataset[\"test\"]\n", + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "```\n", + "Now that we've created a trainer, we can train it! While we're at it, let's time how long it takes to train and evaluate the model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "trainer.train()\n", + "zeroshot_metrics = trainer.evaluate()\n", + "zeroshot_metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "***** Running training *****\n", - " Num examples = 60\n", + " Num examples = 1920\n", " Num epochs = 1\n", - " Total optimization steps = 60\n", - " Total train batch size = 32\n", - "{'embedding_loss': 0.2628, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.02} \n", - "{'embedding_loss': 0.0222, 'learning_rate': 3.7037037037037037e-06, 'epoch': 0.83} \n", - "{'train_runtime': 15.4717, 'train_samples_per_second': 124.098, 'train_steps_per_second': 3.878, 'epoch': 1.0} \n", - "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:09<00:00, 6.35it/s]\n", + " Total optimization steps = 120\n", + " Total train batch size = 16\n", + "***** Running evaluation *****\n", + "{'accuracy': 0.5345}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```\n", + "CPU times: user 12.9 s, sys: 2.37 s, total: 15.2 s\n", + "Wall time: 11 s\n", "```\n", "\n", - "Once trained, we can evaluate the model:" + "Great, now that we have a reference score let's compare against the zero-shot pipeline from 🤗 Transformers." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Comparing against the zero-shot pipeline from 🤗 Transformers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "🤗 Transformers provides a zero-shot pipeline that frames text classification as a natural language inference task. Let's load the pipeline and place it on the GPU for fast inference:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import pipeline\n", + "\n", + "pipe = pipeline(\"zero-shot-classification\", device=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have the model, let's generate some predictions. We'll use the same candidate labels as we did with SetFit and increase the batch size for to speed things up:" ] }, { @@ -168,8 +375,8 @@ "metadata": {}, "outputs": [], "source": [ - "metrics = trainer.evaluate()\n", - "print(metrics)" + "%%time\n", + "zeroshot_preds = pipe(reference_dataset[\"test\"][\"text\"], batch_size=16, candidate_labels=candidate_labels)" ] }, { @@ -177,11 +384,11 @@ "metadata": {}, "source": [ "```\n", - "***** Running evaluation *****\n", - "{'accuracy': 0.591}\n", + "CPU times: user 1min 10s, sys: 166 ms, total: 1min 11s\n", + "Wall time: 53.1 s\n", "```\n", "\n", - "And run predictions:" + "Note that this took almost 5x longer to generate predictions than SetFit! OK, so how well does it perform? Since each prediction is a dictionary of label names ranked by score:" ] }, { @@ -190,14 +397,7 @@ "metadata": {}, "outputs": [], "source": [ - "preds = model.predict([\n", - " \"i am just feeling cranky and blue\",\n", - " \"i feel incredibly lucky just to be able to talk to her\",\n", - " \"you're pissing me off right now\",\n", - " \"i definitely have thalassophobia, don't get me near water like that\",\n", - " \"i did not see that coming at all\",\n", - "])\n", - "print([classes[idx] for idx in preds])" + "zeroshot_preds[0]" ] }, { @@ -206,28 +406,55 @@ "metadata": {}, "outputs": [], "source": [ - "['sadness', 'joy', 'anger', 'fear', 'surprise']" + "{'sequence': 'im feeling rather rotten so im not very ambitious right now',\n", + " 'labels': ['sadness', 'anger', 'surprise', 'fear', 'joy', 'love'],\n", + " 'scores': [0.7367985844612122,\n", + " 0.10041674226522446,\n", + " 0.09770156443119049,\n", + " 0.05880110710859299,\n", + " 0.004266355652362108,\n", + " 0.0020156768150627613]}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "These predictions all look right!" + "We can use the `str2int()` function from the `label` column to convert them to integers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preds = [label_features.str2int(pred[\"labels\"][0]) for pred in zeroshot_preds]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Baseline" + "**Note:** As noted earlier, if you're using a dataset that doesn't have a `ClassLabel` feature for the label column, you'll need to compute the label mapping manually with something like:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "id2label = get_id2label(reference_dataset[\"train\"])\n", + "label2id = {v:k for k,v in id2label.items()}\n", + "preds = [label2id[pred[\"labels\"][0]] for pred in zeroshot_preds]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "To show that the zero-shot performance of SetFit works well, we'll compare it against a zero-shot classification model from `transformers`." + "The last step is to compute accuracy using 🤗 Evaluate:" ] }, { @@ -236,24 +463,11 @@ "metadata": {}, "outputs": [], "source": [ - "from transformers import pipeline\n", - "from datasets import load_dataset\n", "import evaluate\n", "\n", - "# Prepare the testing dataset\n", - "test_dataset = load_dataset(\"dair-ai/emotion\", \"split\", split=\"test\")\n", - "classes = test_dataset.features[\"label\"].names\n", - "\n", - "# Set up the zero-shot classification pipeline from transformers\n", - "# Uses 'facebook/bart-large-mnli' by default\n", - "pipe = pipeline(\"zero-shot-classification\", device=0)\n", - "zeroshot_preds = pipe(test_dataset[\"text\"], batch_size=16, candidate_labels=classes)\n", - "preds = [classes.index(pred[\"labels\"][0]) for pred in zeroshot_preds]\n", - "\n", - "# Compute the accuracy\n", "metric = evaluate.load(\"accuracy\")\n", - "transformers_accuracy = metric.compute(predictions=preds, references=test_dataset[\"label\"])\n", - "print(transformers_accuracy)" + "transformers_metrics = metric.compute(predictions=preds, references=reference_dataset[\"test\"][\"label\"])\n", + "transformers_metrics" ] }, { @@ -269,21 +483,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "With its 59.1% accuracy, the 0-shot SetFit heavily outperforms the recommended zero-shot model by `transformers`." + "Compared to SetFit, this approach performs significantly worse. Let's wrap up our analysis by combining synthetic examples with a few labeled ones." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Prediction latency" + "## Augmenting labeled data with synthetic examples" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Beyond getting higher accuracies, SetFit is much faster too. Let's compute the latency of SetFit with `BAAI/bge-small-en-v1.5` versus the latency of `transformers` with `facebook/bart-large-mnli`. Both tests were performed on a GPU." + "If you have a few labeled examples, adding synthetic data can often boost performance. To simulate this, let's first sample 8 labeled examples from our reference dataset:" ] }, { @@ -292,21 +506,47 @@ "metadata": {}, "outputs": [], "source": [ - "import time\n", + "from setfit import sample_dataset\n", "\n", - "start_t = time.time()\n", - "pipe(test_dataset[\"text\"], batch_size=32, candidate_labels=classes)\n", - "delta_t = time.time() - start_t\n", - "print(f\"`transformers` with `facebook/bart-large-mnli` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence\")" + "train_dataset = sample_dataset(reference_dataset[\"train\"])\n", + "train_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 48\n", + "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "```\n", - "`transformers` with `facebook/bart-large-mnli` latency: 31.1765ms per sentence\n", - "```" + "To warm up, we'll train a SetFit model on these true labels:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = SetFitModel.from_pretrained(model_id)\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=reference_dataset[\"test\"]\n", + ")\n", + "trainer.train()\n", + "fewshot_metrics = trainer.evaluate()\n", + "fewshot_metrics" ] }, { @@ -315,12 +555,63 @@ "metadata": {}, "outputs": [], "source": [ - "import time\n", + "{'accuracy': 0.4705}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that for this particular dataset, the performance with true labels is _worse_ than training on synthetic examples! In our experiments, we found that the difference depends strongly on the dataset in question. Since SetFit models are fast to train, you can always try both approaches and pick the best one.\n", "\n", - "start_t = time.time()\n", - "model.predict(test_dataset[\"text\"])\n", - "delta_t = time.time() - start_t\n", - "print(f\"SetFit with `BAAI/bge-small-en-v1.5` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence\")" + "In any case, let's now add some synthetic examples to our training set:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "augmented_dataset = get_templated_dataset(train_dataset, candidate_labels=candidate_labels, sample_size=8)\n", + "augmented_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 96\n", + "})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As before, we can train and evaluate SetFit with the augmented dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = SetFitModel.from_pretrained(model_id)\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataset=augmented_dataset,\n", + " eval_dataset=reference_dataset[\"test\"]\n", + ")\n", + "trainer.train()\n", + "augmented_metrics = trainer.evaluate()\n", + "augmented_metrics" ] }, { @@ -328,12 +619,31 @@ "metadata": {}, "source": [ "```\n", - "SetFit with `BAAI/bge-small-en-v1.5` latency: 0.4600ms per sentence\n", + "{'accuracy': 0.613}\n", "```\n", "\n", - "So, SetFit with `BAAI/bge-small-en-v1.5` is 67x faster than `transformers` with `facebook/bart-large-mnli`, alongside being more accurate:\n", + "Great, this has given us a significant boost in performance and given us a few percentage points over the purely synthetic example. \n", "\n", - "![zero_shot_transformers_vs_setfit](https://github.com/huggingface/setfit/assets/37621491/33f574d9-c51b-4e02-8d98-6e04e18427ef)" + "Let's plot the final results for comparison:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.DataFrame.from_dict({\"Method\":[\"Transformers (zero-shot)\", \"SetFit (zero-shot)\", \"SetFit (augmented)\"], \"Accuracy\": [transformers_metrics[\"accuracy\"], zeroshot_metrics[\"accuracy\"], augmented_metrics[\"accuracy\"]]})\n", + "df.plot(kind=\"barh\", x=\"Method\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![setfit_zero_shot_results](https://github.com/huggingface/setfit/assets/37621491/b02d3e62-d51c-4506-91f6-2fe9b7ef554d)" ] } ], diff --git a/setfit_doc/en/zero_shot.ipynb b/setfit_doc/en/zero_shot.ipynb index 780f6299..68bb35c0 100644 --- a/setfit_doc/en/zero_shot.ipynb +++ b/setfit_doc/en/zero_shot.ipynb @@ -11,23 +11,61 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Your class names are likely already good descriptors of the text that you're looking to classify. With 🤗 SetFit, you can use these class names with strong pretrained Sentence Transformer models to get a strong baseline model without any training samples.\n", + "Although SetFit was designed for few-shot learning, the method can also be applied in scenarios where no labeled data is available. The main trick is to create _synthetic examples_ that resemble the classification task, and then train a SetFit model on them. \n", "\n", - "This guide will show you how to perform zero-shot text classification." + "Remarkably, this simple technique typically outperforms the zero-shot pipeline in 🤗 Transformers, and can generate predictions by a factor of 5x (or more) faster!\n", + "\n", + "In this tutorial, we'll explore how:\n", + "\n", + "* SetFit can be applied for zero-shot classification\n", + "* Adding synthetic examples can also provide a performance boost to few-shot classification." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you're running this Notebook on Colab or some other cloud platform, you will need to install the `setfit` library. Uncomment the following cell and run it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %pip install setfit matplotlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Testing dataset" + "To benchmark the performance of the \"zero-shot\" method, we'll use the following dataset and pretrained model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_id = \"emotion\"\n", + "model_id = \"sentence-transformers/paraphrase-mpnet-base-v2\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We'll use the [dair-ai/emotion](https://huggingface.co/datasets/dair-ai/emotion) dataset to test the performance of our zero-shot model." + "Next, we'll download the reference dataset from the Hugging Face Hub:" ] }, { @@ -38,14 +76,57 @@ "source": [ "from datasets import load_dataset\n", "\n", - "test_dataset = load_dataset(\"dair-ai/emotion\", \"split\", split=\"test\")" + "reference_dataset = load_dataset(dataset_id)\n", + "reference_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "DatasetDict({\n", + " train: Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 16000\n", + " })\n", + " validation: Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 2000\n", + " })\n", + " test: Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 2000\n", + " })\n", + "})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we're set up, let's create some synthetic data to train on!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating a synthetic dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "This dataset stores the class names within the dataset `Features`, so we'll extract the classes like so:" + "The first thing we need to do is create a dataset of synthetic examples. In `setfit`, we can do this by applying the `get_templated_dataset()` function to a dummy dataset. This function expects a few main things:\n", + "\n", + "* A list of candidate labels to classify with. We'll use the labels from the reference dataset here, but this could be anything that's relevant to the task and dataset at hand.\n", + "* A template to generate examples with. By default, it is `\"This sentence is {}\"`, where the `{}` will be filled by one of the candidate labels\n", + "* A sample size $N$, which will create $N$ synthetic examples per class. We find $N=8$ usually works best.\n", + "\n", + "Armed with this information, let's first extract some candidate labels from the dataset:" ] }, { @@ -54,29 +135,52 @@ "metadata": {}, "outputs": [], "source": [ - "classes = test_dataset.features[\"label\"].names\n", - "# => ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']" + "# Extract ClassLabel feature from \"label\" column\n", + "label_features = reference_dataset[\"train\"].features[\"label\"]\n", + "# Label names to classify with\n", + "candidate_labels = label_features.names\n", + "candidate_labels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Otherwise, we could manually set the list of classes." + "```\n", + "['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']\n", + "```\n", + "\n", + "\n", + "\n", + "Some datasets on the Hugging Face Hub don't have a `ClassLabel` feature for the label column. In these cases, you should compute the candidate labels manually by first computing the id2label mapping as follows:\n", + "\n", + "" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "## Synthetic dataset" + "def get_id2label(dataset):\n", + " # The column with the label names\n", + " label_names = dataset.unique(\"label_text\")\n", + " # The column with the label IDs\n", + " label_ids = dataset.unique(\"label\")\n", + " id2label = dict(zip(label_ids, label_names))\n", + " # Sort by label ID\n", + " return {key: val for key, val in sorted(id2label.items(), key = lambda x: x[0])}\n", + "\n", + "id2label = get_id2label(reference_dataset[\"train\"])\n", + "candidate_labels = list(id2label.values())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Then, we can use [get_templated_dataset()](https://huggingface.co/docs/setfit/main/en/reference/utility#setfit.get_templated_dataset) to synthetically generate a dummy dataset given these class names." + "Now that we have the labels, it's a simple matter to create synthetic examples:" ] }, { @@ -85,9 +189,34 @@ "metadata": {}, "outputs": [], "source": [ + "from datasets import Dataset\n", "from setfit import get_templated_dataset\n", "\n", - "train_dataset = get_templated_dataset()" + "# A dummy dataset to fill with synthetic examples\n", + "dummy_dataset = Dataset.from_dict({})\n", + "train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8)\n", + "train_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```\n", + "Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 48\n", + "})\n", + "```\n", + "\n", + "\n", + "\n", + "You might find you can get better performance by tweaking the `template` argument from the default of `\"The sentence is {}\"` to variants like `\"This sentence is {}\"` or `\"This example is {}\"`.\n", + "\n", + "\n", + "\n", + "\n", + "Since our dataset has 6 classes and we chose a sample size of 8, our synthetic dataset contains $6\\times 8=48$ examples. If we take a look at a few of the examples:" ] }, { @@ -96,27 +225,37 @@ "metadata": {}, "outputs": [], "source": [ - "print(train_dataset)\n", - "# => Dataset({\n", - "# features: ['text', 'label'],\n", - "# num_rows: 48\n", - "# })\n", - "print(train_dataset[0])\n", - "# {'text': 'This sentence is sadness', 'label': 0}" + "train_dataset.shuffle()[:3]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```\n", + "{'text': ['This sentence is love',\n", + " 'This sentence is fear',\n", + " 'This sentence is joy'],\n", + " 'label': [2, 4, 1]}\n", + "```\n", + "\n", + "We can see that each input takes the form of the template and has a corresponding label associated with it. \n", + "\n", + "Let's not train a SetFit model on these examples!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Training" + "## Fine-tuning the model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We can use this dataset to train a SetFit model just like normal:" + "To train a SetFit model, the first thing to do is download a pretrained checkpoint from the Hub. We can do so by using the `SetFitModel.from_pretrained()` method:" ] }, { @@ -125,41 +264,109 @@ "metadata": {}, "outputs": [], "source": [ - "from setfit import SetFitModel, Trainer, TrainingArguments\n", - "\n", - "model = SetFitModel.from_pretrained(\"BAAI/bge-small-en-v1.5\")\n", + "from setfit import SetFitModel\n", "\n", - "args = TrainingArguments(\n", - " batch_size=32,\n", - " num_epochs=1,\n", - ")\n", + "model = SetFitModel.from_pretrained(model_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here, we've downloaded a pretrained Sentence Transformer from the Hub and added a logistic classification head to the create the SetFit model. As indicated in the message, we need to train this model on some labeled examples. We can do so by using the [Trainer](https://huggingface.co/docs/setfit/main/en/reference/trainer#setfit.Trainer) class as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from setfit import Trainer\n", "\n", "trainer = Trainer(\n", " model=model,\n", - " args=args,\n", " train_dataset=train_dataset,\n", - " eval_dataset=test_dataset,\n", - ")\n", - "trainer.train()" + " eval_dataset=reference_dataset[\"test\"]\n", + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "```\n", + "Now that we've created a trainer, we can train it! While we're at it, let's time how long it takes to train and evaluate the model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "trainer.train()\n", + "zeroshot_metrics = trainer.evaluate()\n", + "zeroshot_metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "***** Running training *****\n", - " Num examples = 60\n", + " Num examples = 1920\n", " Num epochs = 1\n", - " Total optimization steps = 60\n", - " Total train batch size = 32\n", - "{'embedding_loss': 0.2628, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.02} \n", - "{'embedding_loss': 0.0222, 'learning_rate': 3.7037037037037037e-06, 'epoch': 0.83} \n", - "{'train_runtime': 15.4717, 'train_samples_per_second': 124.098, 'train_steps_per_second': 3.878, 'epoch': 1.0} \n", - "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:09<00:00, 6.35it/s]\n", + " Total optimization steps = 120\n", + " Total train batch size = 16\n", + "***** Running evaluation *****\n", + "{'accuracy': 0.5345}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```\n", + "CPU times: user 12.9 s, sys: 2.37 s, total: 15.2 s\n", + "Wall time: 11 s\n", "```\n", "\n", - "Once trained, we can evaluate the model:" + "Great, now that we have a reference score let's compare against the zero-shot pipeline from 🤗 Transformers." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Comparing against the zero-shot pipeline from 🤗 Transformers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "🤗 Transformers provides a zero-shot pipeline that frames text classification as a natural language inference task. Let's load the pipeline and place it on the GPU for fast inference:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import pipeline\n", + "\n", + "pipe = pipeline(\"zero-shot-classification\", device=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have the model, let's generate some predictions. We'll use the same candidate labels as we did with SetFit and increase the batch size for to speed things up:" ] }, { @@ -168,8 +375,8 @@ "metadata": {}, "outputs": [], "source": [ - "metrics = trainer.evaluate()\n", - "print(metrics)" + "%%time\n", + "zeroshot_preds = pipe(reference_dataset[\"test\"][\"text\"], batch_size=16, candidate_labels=candidate_labels)" ] }, { @@ -177,11 +384,11 @@ "metadata": {}, "source": [ "```\n", - "***** Running evaluation *****\n", - "{'accuracy': 0.591}\n", + "CPU times: user 1min 10s, sys: 166 ms, total: 1min 11s\n", + "Wall time: 53.1 s\n", "```\n", "\n", - "And run predictions:" + "Note that this took almost 5x longer to generate predictions than SetFit! OK, so how well does it perform? Since each prediction is a dictionary of label names ranked by score:" ] }, { @@ -190,14 +397,7 @@ "metadata": {}, "outputs": [], "source": [ - "preds = model.predict([\n", - " \"i am just feeling cranky and blue\",\n", - " \"i feel incredibly lucky just to be able to talk to her\",\n", - " \"you're pissing me off right now\",\n", - " \"i definitely have thalassophobia, don't get me near water like that\",\n", - " \"i did not see that coming at all\",\n", - "])\n", - "print([classes[idx] for idx in preds])" + "zeroshot_preds[0]" ] }, { @@ -206,28 +406,55 @@ "metadata": {}, "outputs": [], "source": [ - "['sadness', 'joy', 'anger', 'fear', 'surprise']" + "{'sequence': 'im feeling rather rotten so im not very ambitious right now',\n", + " 'labels': ['sadness', 'anger', 'surprise', 'fear', 'joy', 'love'],\n", + " 'scores': [0.7367985844612122,\n", + " 0.10041674226522446,\n", + " 0.09770156443119049,\n", + " 0.05880110710859299,\n", + " 0.004266355652362108,\n", + " 0.0020156768150627613]}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "These predictions all look right!" + "We can use the `str2int()` function from the `label` column to convert them to integers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preds = [label_features.str2int(pred[\"labels\"][0]) for pred in zeroshot_preds]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Baseline" + "**Note:** As noted earlier, if you're using a dataset that doesn't have a `ClassLabel` feature for the label column, you'll need to compute the label mapping manually with something like:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "id2label = get_id2label(reference_dataset[\"train\"])\n", + "label2id = {v:k for k,v in id2label.items()}\n", + "preds = [label2id[pred[\"labels\"][0]] for pred in zeroshot_preds]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "To show that the zero-shot performance of SetFit works well, we'll compare it against a zero-shot classification model from `transformers`." + "The last step is to compute accuracy using 🤗 Evaluate:" ] }, { @@ -236,24 +463,11 @@ "metadata": {}, "outputs": [], "source": [ - "from transformers import pipeline\n", - "from datasets import load_dataset\n", "import evaluate\n", "\n", - "# Prepare the testing dataset\n", - "test_dataset = load_dataset(\"dair-ai/emotion\", \"split\", split=\"test\")\n", - "classes = test_dataset.features[\"label\"].names\n", - "\n", - "# Set up the zero-shot classification pipeline from transformers\n", - "# Uses 'facebook/bart-large-mnli' by default\n", - "pipe = pipeline(\"zero-shot-classification\", device=0)\n", - "zeroshot_preds = pipe(test_dataset[\"text\"], batch_size=16, candidate_labels=classes)\n", - "preds = [classes.index(pred[\"labels\"][0]) for pred in zeroshot_preds]\n", - "\n", - "# Compute the accuracy\n", "metric = evaluate.load(\"accuracy\")\n", - "transformers_accuracy = metric.compute(predictions=preds, references=test_dataset[\"label\"])\n", - "print(transformers_accuracy)" + "transformers_metrics = metric.compute(predictions=preds, references=reference_dataset[\"test\"][\"label\"])\n", + "transformers_metrics" ] }, { @@ -269,21 +483,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "With its 59.1% accuracy, the 0-shot SetFit heavily outperforms the recommended zero-shot model by `transformers`." + "Compared to SetFit, this approach performs significantly worse. Let's wrap up our analysis by combining synthetic examples with a few labeled ones." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Prediction latency" + "## Augmenting labeled data with synthetic examples" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Beyond getting higher accuracies, SetFit is much faster too. Let's compute the latency of SetFit with `BAAI/bge-small-en-v1.5` versus the latency of `transformers` with `facebook/bart-large-mnli`. Both tests were performed on a GPU." + "If you have a few labeled examples, adding synthetic data can often boost performance. To simulate this, let's first sample 8 labeled examples from our reference dataset:" ] }, { @@ -292,21 +506,47 @@ "metadata": {}, "outputs": [], "source": [ - "import time\n", + "from setfit import sample_dataset\n", "\n", - "start_t = time.time()\n", - "pipe(test_dataset[\"text\"], batch_size=32, candidate_labels=classes)\n", - "delta_t = time.time() - start_t\n", - "print(f\"`transformers` with `facebook/bart-large-mnli` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence\")" + "train_dataset = sample_dataset(reference_dataset[\"train\"])\n", + "train_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 48\n", + "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "```\n", - "`transformers` with `facebook/bart-large-mnli` latency: 31.1765ms per sentence\n", - "```" + "To warm up, we'll train a SetFit model on these true labels:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = SetFitModel.from_pretrained(model_id)\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=reference_dataset[\"test\"]\n", + ")\n", + "trainer.train()\n", + "fewshot_metrics = trainer.evaluate()\n", + "fewshot_metrics" ] }, { @@ -315,12 +555,63 @@ "metadata": {}, "outputs": [], "source": [ - "import time\n", + "{'accuracy': 0.4705}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that for this particular dataset, the performance with true labels is _worse_ than training on synthetic examples! In our experiments, we found that the difference depends strongly on the dataset in question. Since SetFit models are fast to train, you can always try both approaches and pick the best one.\n", "\n", - "start_t = time.time()\n", - "model.predict(test_dataset[\"text\"])\n", - "delta_t = time.time() - start_t\n", - "print(f\"SetFit with `BAAI/bge-small-en-v1.5` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence\")" + "In any case, let's now add some synthetic examples to our training set:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "augmented_dataset = get_templated_dataset(train_dataset, candidate_labels=candidate_labels, sample_size=8)\n", + "augmented_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Dataset({\n", + " features: ['text', 'label'],\n", + " num_rows: 96\n", + "})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As before, we can train and evaluate SetFit with the augmented dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = SetFitModel.from_pretrained(model_id)\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataset=augmented_dataset,\n", + " eval_dataset=reference_dataset[\"test\"]\n", + ")\n", + "trainer.train()\n", + "augmented_metrics = trainer.evaluate()\n", + "augmented_metrics" ] }, { @@ -328,12 +619,31 @@ "metadata": {}, "source": [ "```\n", - "SetFit with `BAAI/bge-small-en-v1.5` latency: 0.4600ms per sentence\n", + "{'accuracy': 0.613}\n", "```\n", "\n", - "So, SetFit with `BAAI/bge-small-en-v1.5` is 67x faster than `transformers` with `facebook/bart-large-mnli`, alongside being more accurate:\n", + "Great, this has given us a significant boost in performance and given us a few percentage points over the purely synthetic example. \n", "\n", - "![zero_shot_transformers_vs_setfit](https://github.com/huggingface/setfit/assets/37621491/33f574d9-c51b-4e02-8d98-6e04e18427ef)" + "Let's plot the final results for comparison:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.DataFrame.from_dict({\"Method\":[\"Transformers (zero-shot)\", \"SetFit (zero-shot)\", \"SetFit (augmented)\"], \"Accuracy\": [transformers_metrics[\"accuracy\"], zeroshot_metrics[\"accuracy\"], augmented_metrics[\"accuracy\"]]})\n", + "df.plot(kind=\"barh\", x=\"Method\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![setfit_zero_shot_results](https://github.com/huggingface/setfit/assets/37621491/b02d3e62-d51c-4506-91f6-2fe9b7ef554d)" ] } ],