|
| 1 | +--- |
| 2 | +title: Chain-of-thought text classification |
| 3 | +nextjs: |
| 4 | + metadata: |
| 5 | + title: Chain-of-thought text classification |
| 6 | + description: Learn about chain-of-thought text classification. |
| 7 | +--- |
| 8 | + |
| 9 | +## Overview |
| 10 | + |
| 11 | +Chain-of-thought text classification is similar to zero-shot classification since it does not require any labeled data beforehand. The only difference is that, in addition to the label itself, the model generates some additional reasoning behind its choice. In some cases, such an approach might lead to much better performance, but at the cost of higher token consumption. |
| 12 | + |
| 13 | +Example using GPT-4o: |
| 14 | + |
| 15 | +```python |
| 16 | +from skllm.models.gpt.classification.zero_shot import CoTGPTClassifier |
| 17 | +from skllm.datasets import get_classification_dataset |
| 18 | + |
| 19 | +# demo sentiment analysis dataset |
| 20 | +# labels: positive, negative, neutral |
| 21 | +X, y = get_classification_dataset() |
| 22 | + |
| 23 | +clf = CoTGPTClassifier(model="gpt-4o") |
| 24 | +clf.fit(X,y) |
| 25 | +predictions = clf.predict(X) |
| 26 | +labels, reasoning = predictions[:, 0], predictions[:, 1] |
| 27 | +``` |
| 28 | + |
| 29 | +--- |
| 30 | + |
| 31 | +## API Reference |
| 32 | + |
| 33 | +The following API reference only lists the parameters needed for the initialization of the estimator. The remaining methods follow the syntax of a scikit-learn classifier. |
| 34 | + |
| 35 | +### CoTGPTClassifier |
| 36 | +```python |
| 37 | +from skllm.models.gpt.classification.zero_shot import CoTGPTClassifier |
| 38 | +``` |
| 39 | + |
| 40 | +| **Parameter** | **Type** | **Description** | |
| 41 | +| ------------- | -------- | ------------------------ | |
| 42 | +| `model` | `str` | Model to use, by default "gpt-3.5-turbo". | |
| 43 | +| `default_label` | `str` | Default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random". | |
| 44 | +| `prompt_template` | `Optional[str]` | Custom prompt template to use, by default None. | |
| 45 | +| `key` | `Optional[str]` | Estimator-specific API key; if None, retrieved from the global config, by default None. | |
| 46 | +| `org` | `Optional[str]` | Estimator-specific ORG key; if None, retrieved from the global config, by default None. | |
0 commit comments