diff --git a/.gitignore b/.gitignore index fd8c907..5eff2a9 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,9 @@ __pycache__/ # Ignore some data folders /data/acled/ +# course slides +slides/ + # C extensions *.so diff --git a/course-overview.pptx b/course-overview.pptx new file mode 100644 index 0000000..95e2e16 Binary files /dev/null and b/course-overview.pptx differ diff --git a/data/docs/mw-parliament/report-commitee-social-affairs.pdf b/data/docs/mw-parliament/report-commitee-social-affairs.pdf new file mode 100644 index 0000000..410cbed Binary files /dev/null and b/data/docs/mw-parliament/report-commitee-social-affairs.pdf differ diff --git a/docs/_toc.yml b/docs/_toc.yml index 84a3eed..269fe2d 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -56,7 +56,7 @@ parts: title: PhosoAI- A Food Prices Chatbot - file: notebooks/text-classification/intro-text-classification sections: - - file: notebooks/text-classification/protest-classification-gpt.ipynb + - file: notebooks/text-classification/protest-classification.ipynb - caption: Additional Resources chapters: - file: docs/additional-resources/intro-open-source-llms diff --git a/notebooks/text-classification/protest-classification-gpt.ipynb b/notebooks/text-classification/protest-classification-gpt.ipynb deleted file mode 100644 index 04698f9..0000000 --- a/notebooks/text-classification/protest-classification-gpt.ipynb +++ /dev/null @@ -1,1125 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Protest Classification Using LLM\n", - "In this notebook, the objective is to use a Language Model (LLM) to classify protests into predefined categories. The input data is sourced from ACLED. Each row represents a protest with multiple columns, with the most relevant for classification being the `notes` column, which provides a description of the protest.\n", - "\n", - "## Overview of Approach\n", - "1. 
**LLM Family**: For this, we utilize the OpenAI family of `GPT` models.\n", - "\n", - "2. **Design Classification Prompt and Assess Performance**: Using a manually curated training dataset with labeled protests, we experiment with various prompting strategies and evaluate performance. We also examine the impact of the number of few-shot examples used on results.\n", - "\n", - "3. **Apply Classification to the Dataset**: Once the optimal prompting strategy and number of few-shot examples are determined, we apply the classification approach to the entire dataset. This involves using the refined prompt to categorize each protest event based on its description, ensuring consistency and accuracy across all entries.\n", - "\n", - "## Limitations \n", - "1. **Cost**. It is expensive to run OpenAI models when you have many tokens. This classification task costed around $100\n", - "2. **Not utilizing all examples** In order to reduce cost and processing time, we are not utilizing all available examples to perfom the classification\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "tags": [ - "remove-cell" - ] - }, - "outputs": [], - "source": [ - "import os\n", - "from pathlib import Path\n", - "\n", - "import pandas as pd\n", - "import geopandas as gpd\n", - "\n", - "from datetime import datetime\n", - "\n", - "import bokeh\n", - "from bokeh.models import Tabs, TabPanel\n", - "from bokeh.core.validation.warnings import EMPTY_LAYOUT, MISSING_RENDERERS\n", - "from bokeh.plotting import show, output_notebook\n", - "\n", - "from langchain_openai import OpenAI, ChatOpenAI\n", - "from langchain.prompts import PromptTemplate\n", - "from langchain.chains import LLMChain\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\", category=UserWarning, module=\"langchain\")\n", - "\n", - "\n", - "from dotenv import load_dotenv\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Global variables \n" - ] - }, - 
{ - "cell_type": "code", - "execution_count": 2, - "metadata": { - "tags": [ - "remove-cell" - ] - }, - "outputs": [], - "source": [ - "# ==================\n", - "# SETUP INPUT\n", - "# ==================\n", - "DIR_DATA = Path.cwd().parents[1].joinpath(\"data\", \"conflict\")\n", - "FILE_PROTESTS = DIR_DATA.joinpath(\"protests_iran_20160101_20241009.csv\")\n", - "FILE_PROTESTS_CLASSIFIED = DIR_DATA.joinpath(\"protests_sample_acled_iran.csv\")\n", - "FILE_PROTESTS_CLASSES = DIR_DATA.joinpath(\"protest_classification.csv\")\n", - "FILE_PROTESTS_TRAINING = DIR_DATA.joinpath(\"protests-labeled-sample-training.csv\")\n", - "\n", - "# ==================\n", - "# CLASSIFICATION \n", - "# ==================\n", - "PROP_TRAIN = 0.4\n", - "NUM_EXAMPLES = 10\n", - "SAMPLE_PROP = 0.5\n", - "OPENAI_MODEL = \"gpt-3.5-turbo\"\n", - "\n", - "# For testing, classify only a portion of the documents\n", - "SAMPLE_SIZE = 0.3" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preprocess Data" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "df = pd.read_csv(FILE_PROTESTS_TRAINING)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "category\n", - "Livelihood (Prices, jobs and salaries) 64\n", - "Political/Security 56\n", - "Business and legal 42\n", - "Social 26\n", - "Public service delivery 25\n", - "Climate and environment 11\n", - "Name: count, dtype: int64" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.category.value_counts()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "ename": "FileNotFoundError", - "evalue": "[Errno 2] No such file or directory: '/Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/data/conflict/protest_classification.csv'", - 
"output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[7], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m df_prot \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mread_csv(FILE_PROTESTS, dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mstr\u001b[39m)\n\u001b[1;32m 2\u001b[0m df_prot_labels \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mread_csv(FILE_PROTESTS_CLASSIFIED)\n\u001b[0;32m----> 3\u001b[0m df_tmp2 \u001b[38;5;241m=\u001b[39m \u001b[43mpd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread_csv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mFILE_PROTESTS_CLASSES\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/.venv/lib/python3.12/site-packages/pandas/io/parsers/readers.py:1026\u001b[0m, in \u001b[0;36mread_csv\u001b[0;34m(filepath_or_buffer, sep, delimiter, header, names, index_col, usecols, dtype, engine, converters, true_values, false_values, skipinitialspace, skiprows, skipfooter, nrows, na_values, keep_default_na, na_filter, verbose, skip_blank_lines, parse_dates, infer_datetime_format, keep_date_col, date_parser, date_format, dayfirst, cache_dates, iterator, chunksize, compression, thousands, decimal, lineterminator, quotechar, quoting, doublequote, escapechar, comment, encoding, encoding_errors, dialect, on_bad_lines, delim_whitespace, low_memory, memory_map, float_precision, storage_options, dtype_backend)\u001b[0m\n\u001b[1;32m 1013\u001b[0m kwds_defaults \u001b[38;5;241m=\u001b[39m _refine_defaults_read(\n\u001b[1;32m 1014\u001b[0m dialect,\n\u001b[1;32m 1015\u001b[0m delimiter,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1022\u001b[0m dtype_backend\u001b[38;5;241m=\u001b[39mdtype_backend,\n\u001b[1;32m 1023\u001b[0m )\n\u001b[1;32m 1024\u001b[0m 
kwds\u001b[38;5;241m.\u001b[39mupdate(kwds_defaults)\n\u001b[0;32m-> 1026\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_read\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfilepath_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/.venv/lib/python3.12/site-packages/pandas/io/parsers/readers.py:620\u001b[0m, in \u001b[0;36m_read\u001b[0;34m(filepath_or_buffer, kwds)\u001b[0m\n\u001b[1;32m 617\u001b[0m _validate_names(kwds\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnames\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m))\n\u001b[1;32m 619\u001b[0m \u001b[38;5;66;03m# Create the parser.\u001b[39;00m\n\u001b[0;32m--> 620\u001b[0m parser \u001b[38;5;241m=\u001b[39m \u001b[43mTextFileReader\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfilepath_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 622\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m chunksize \u001b[38;5;129;01mor\u001b[39;00m iterator:\n\u001b[1;32m 623\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m parser\n", - "File \u001b[0;32m~/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/.venv/lib/python3.12/site-packages/pandas/io/parsers/readers.py:1620\u001b[0m, in \u001b[0;36mTextFileReader.__init__\u001b[0;34m(self, f, engine, **kwds)\u001b[0m\n\u001b[1;32m 1617\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptions[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhas_index_names\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m kwds[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhas_index_names\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 1619\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandles: 
IOHandles \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 1620\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_engine \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_make_engine\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mengine\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/.venv/lib/python3.12/site-packages/pandas/io/parsers/readers.py:1880\u001b[0m, in \u001b[0;36mTextFileReader._make_engine\u001b[0;34m(self, f, engine)\u001b[0m\n\u001b[1;32m 1878\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m mode:\n\u001b[1;32m 1879\u001b[0m mode \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m-> 1880\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandles \u001b[38;5;241m=\u001b[39m \u001b[43mget_handle\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1881\u001b[0m \u001b[43m \u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1882\u001b[0m \u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1883\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mencoding\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1884\u001b[0m \u001b[43m \u001b[49m\u001b[43mcompression\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcompression\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1885\u001b[0m \u001b[43m \u001b[49m\u001b[43mmemory_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmemory_map\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1886\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_text\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_text\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1887\u001b[0m \u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mencoding_errors\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstrict\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1888\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstorage_options\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1889\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1890\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandles \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1891\u001b[0m f \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandles\u001b[38;5;241m.\u001b[39mhandle\n", - "File \u001b[0;32m~/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/.venv/lib/python3.12/site-packages/pandas/io/common.py:873\u001b[0m, in \u001b[0;36mget_handle\u001b[0;34m(path_or_buf, mode, encoding, compression, memory_map, is_text, errors, storage_options)\u001b[0m\n\u001b[1;32m 868\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(handle, \u001b[38;5;28mstr\u001b[39m):\n\u001b[1;32m 869\u001b[0m \u001b[38;5;66;03m# Check whether the filename is to be opened in binary mode.\u001b[39;00m\n\u001b[1;32m 870\u001b[0m \u001b[38;5;66;03m# Binary mode does not support 'encoding' and 'newline'.\u001b[39;00m\n\u001b[1;32m 871\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ioargs\u001b[38;5;241m.\u001b[39mencoding \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m ioargs\u001b[38;5;241m.\u001b[39mmode:\n\u001b[1;32m 872\u001b[0m 
\u001b[38;5;66;03m# Encoding\u001b[39;00m\n\u001b[0;32m--> 873\u001b[0m handle \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m 874\u001b[0m \u001b[43m \u001b[49m\u001b[43mhandle\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 875\u001b[0m \u001b[43m \u001b[49m\u001b[43mioargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 876\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mioargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoding\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 877\u001b[0m \u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merrors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 878\u001b[0m \u001b[43m \u001b[49m\u001b[43mnewline\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 879\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 880\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 881\u001b[0m \u001b[38;5;66;03m# Binary mode\u001b[39;00m\n\u001b[1;32m 882\u001b[0m handle \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mopen\u001b[39m(handle, ioargs\u001b[38;5;241m.\u001b[39mmode)\n", - "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/data/conflict/protest_classification.csv'" - ] - } - ], - "source": [ - "df_prot = pd.read_csv(FILE_PROTESTS, dtype=str)\n", - "df_prot_labels = pd.read_csv(FILE_PROTESTS_CLASSIFIED)\n", - "df_tmp2 = pd.read_csv(FILE_PROTESTS_CLASSES)" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [], - "source": [ - "description_code = dict(df_tmp2[[\"code\", \"description\"]].values)\n", - "category_code = dict(df_tmp2[[\"code\", 
\"major_category\"]].values)" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [], - "source": [ - "df_prot_labels.rename(columns={'Classification code': \"code\"}, inplace=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [], - "source": [ - "df_prot_labels['description'] = df_prot_labels.code.map(description_code)\n", - "df_prot_labels['category'] = df_prot_labels.code.map(category_code)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [], - "source": [ - "def pretty_print_value_counts(\n", - " df, column, title=None, line_length=None, top_n=None, table_number=None\n", - "):\n", - " \"\"\"\n", - " Pretty prints the value counts of a specified column in a Pandas DataFrame,\n", - " with counts formatted with thousand separators, percentages, and cumulative percentages.\n", - "\n", - " Parameters:\n", - " -----------\n", - " df : pandas.DataFrame\n", - " The DataFrame containing the data.\n", - " column : str\n", - " The name of the column for which to calculate value counts.\n", - " title : str, optional\n", - " A title to print above the formatted output. If None, no title is printed.\n", - " line_length : int, optional\n", - " The length of the separator line. If None, it will be determined based on\n", - " the length of the title or default to 50 if no title is provided.\n", - " top_n : int, optional\n", - " The number of top categories to display. If None, all categories are displayed.\n", - " table_number : int, optional\n", - " The numeric value for the table number. 
If provided, the table number will be displayed as 'Table-X'.\n", - "\n", - " Returns:\n", - " --------\n", - " None\n", - " Displays a styled DataFrame with counts, percentages, and cumulative percentages.\n", - " \"\"\"\n", - " # Calculate the value counts and convert to DataFrame\n", - " count_df = pd.DataFrame(df[column].value_counts(normalize=False).reset_index())\n", - " count_df.columns = [\"Category\", \"Count\"]\n", - "\n", - " # Add a percentage column\n", - " count_df[\"Percent\"] = (count_df[\"Count\"] / count_df[\"Count\"].sum()) * 100\n", - "\n", - " # Add a cumulative percentage column\n", - " count_df[\"Cum. Percent\"] = count_df[\"Percent\"].cumsum()\n", - "\n", - " # Limit the output to top_n categories if specified\n", - " if top_n:\n", - " count_df = count_df.head(top_n)\n", - "\n", - " # Print the table number if provided\n", - " if table_number is not None:\n", - " print(f\"Table-{table_number}\")\n", - "\n", - " # Determine the length of the line if line_length is not provided\n", - " if title:\n", - " if line_length is None:\n", - " line_length = max(\n", - " 50, len(title) + 4\n", - " ) # Ensure at least 50 characters, or more based on the title\n", - "\n", - " # Calculate padding to center the title\n", - " total_padding = line_length - len(title)\n", - " left_padding = total_padding // 2\n", - " right_padding = total_padding - left_padding\n", - "\n", - " # Print the centered title with the \"=\" line\n", - " print(\"=\" * line_length)\n", - " print(\" \" * left_padding + title + \" \" * right_padding)\n", - " print(\"=\" * line_length)\n", - "\n", - " # Display the styled DataFrame without index, formatting Count, Percent, and Cumulative Percent columns\n", - " display(\n", - " count_df.style.hide(axis=\"index\").format(\n", - " {\n", - " \"Count\": \"{:,.0f}\", # Thousand separator for Count\n", - " \"Percent\": \"{:.2f}%\", # Format Percent to 2 decimal places with a % symbol\n", - " \"Cum. 
Percent\": \"{:.2f}%\", # Format Cumulative Percent to 2 decimal places with a % symbol\n", - " }\n", - " )\n", - " )\n", - "\n", - " # Print footer or separator\n", - " print(\"-\" * line_length)" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "============================================================\n", - " Distribution of Protest Classes \n", - "============================================================\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
| Category | Count | Percent | Cum. Percent |
|---|---|---|---|
| Livelihood (Prices, jobs and salaries) | 64 | 28.57% | 28.57% |
| Political/Security | 56 | 25.00% | 53.57% |
| Business and legal | 42 | 18.75% | 72.32% |
| Social | 26 | 11.61% | 83.93% |
| Public service delivery | 25 | 11.16% | 95.09% |
| Climate and environment | 11 | 4.91% | 100.00% |
\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "------------------------------------------------------------\n" - ] - } - ], - "source": [ - "pretty_print_value_counts(df_prot_labels, \"category\", \n", - "title=\"Distribution of Protest Classes\", line_length=60)" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [], - "source": [ - "df_prot_labels.to_csv(DIR_DATA.joinpath(\"protest-with-labels.csv\"), \n", - "index=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "## Check Zero Shot Classification Accuracy" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Check Classification Accuracy " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def gpt3_5_classification(dataset, categories):\n", - " predictions = []\n", - " for note in dataset['notes']:\n", - " # Format the prompt with the note and available categories\n", - " prompt = prompt_template.format(categories=\", \".join(categories), note=note)\n", - " response = llm(prompt)\n", - " predictions.append(response.strip())\n", - " \n", - " dataset['gpt3_5_classification'] = predictions\n", - " \n", - " # Calculate accuracy\n", - " accuracy = accuracy_score(dataset['category'], dataset['gpt3_5_classification'])\n", - " return accuracy" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Split the examples into training and test" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/k5/p4nvl2pj4gq2ks0j7qvrhwbh0000gp/T/ipykernel_96019/4000967764.py:4: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. 
This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n", - " train_df = df_prot_labels.groupby('category', group_keys=False).apply(lambda x: x.sample(frac=PROP_TRAIN,\n" - ] - }, - { - "data": { - "text/plain": [ - "((89, 6), (135, 6))" - ] - }, - "execution_count": 34, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Let's perform a stratified split, ensuring that 80% of the examples from each category are used for training\n", - "\n", - "# Split the data while maintaining the distribution of categories\n", - "train_df = df_prot_labels.groupby('category', group_keys=False).apply(lambda x: x.sample(frac=PROP_TRAIN, \n", - "random_state=42))\n", - "test_df = df_prot_labels.drop(train_df.index)\n", - "\n", - "(train_df.shape, test_df.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "============================================================\n", - " Train-Distribution of Protest Classes \n", - "============================================================\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
| Category | Count | Percent | Cum. Percent |
|---|---|---|---|
| Livelihood (Prices, jobs and salaries) | 26 | 29.21% | 29.21% |
| Political/Security | 22 | 24.72% | 53.93% |
| Business and legal | 17 | 19.10% | 73.03% |
| Public service delivery | 10 | 11.24% | 84.27% |
| Social | 10 | 11.24% | 95.51% |
| Climate and environment | 4 | 4.49% | 100.00% |
\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "------------------------------------------------------------\n" - ] - } - ], - "source": [ - "pretty_print_value_counts(train_df, \"category\", \n", - "title=\"Train-Distribution of Protest Classes\", line_length=60)" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "============================================================\n", - " Test-Distribution of Protest Classes \n", - "============================================================\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
| Category | Count | Percent | Cum. Percent |
|---|---|---|---|
| Livelihood (Prices, jobs and salaries) | 38 | 28.15% | 28.15% |
| Political/Security | 34 | 25.19% | 53.33% |
| Business and legal | 25 | 18.52% | 71.85% |
| Social | 16 | 11.85% | 83.70% |
| Public service delivery | 15 | 11.11% | 94.81% |
| Climate and environment | 7 | 5.19% | 100.00% |
\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "------------------------------------------------------------\n" - ] - } - ], - "source": [ - "pretty_print_value_counts(test_df, \"category\", \n", - "title=\"Test-Distribution of Protest Classes\", line_length=60)" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [], - "source": [ - "# Sample up to 5 examples from each category in the training set\n", - "examples = []\n", - "for category in train_df['category'].unique():\n", - " # Get all samples if fewer than 5 exist, otherwise take 5\n", - " category_samples = train_df[train_df['category'] == category].sample(\n", - " min(NUM_EXAMPLES, len(train_df[train_df['category'] == category])), random_state=42\n", - " )\n", - " examples.extend(category_samples[['notes', 'description', 'category']].values.tolist())" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": {}, - "outputs": [], - "source": [ - "# Step 3: Define a prompt template for classification, adding detailed examples\n", - "classification_prompt_template = \"\"\"\n", - "You are a highly intelligent assistant. Your task is to classify each document into one of the following categories:\n", - "- Political/Security\n", - "- Livelihood (Prices, jobs and salaries)\n", - "- Public service delivery\n", - "- Business and legal\n", - "- Climate and environment\n", - "- Social\n", - "\n", - "Each category has a description that helps explain its purpose.\n", - "\n", - "Here are some examples:\n", - "\n", - "{examples}\n", - "\n", - "Given the following document:\n", - "{document}\n", - "\n", - "Based on the descriptions and the examples, which category does this document fall into? 
Please respond with one of the categories listed above.\n", - "\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [], - "source": [ - "# Step 4: Format the examples for the prompt, including notes, descriptions, and categories\n", - "formatted_examples = \"\\n\".join(\n", - " [\n", - " f\"Example {i+1}:\\nNotes: {example[0]}\\nDescription: {example[1]}\\nCategory: {example[2]}\"\n", - " for i, example in enumerate(examples)\n", - " ]\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# Step 2: Initialize the LLM (OpenAI in this case)\n", - "llm = ChatOpenAI(model=OPENAI_MODEL, temperature=0.7)\n", - "\n", - "# Step 5: Create a prompt template using Langchain\n", - "classification_prompt = PromptTemplate(\n", - " input_variables=[\"document\", \"examples\"],\n", - " template=classification_prompt_template\n", - ")\n", - "\n", - "# Step 6: Create a Langchain with the LLM and the classification prompt template\n", - "classification_chain = LLMChain(\n", - " llm=llm,\n", - " prompt=classification_prompt\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "19.23076923076923" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "50000/2600" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "=======================================================\n", - " Accuracy using model: gpt-3.5-turbo with 10 Examples\n", - "=======================================================\n", - "Accuracy: 87.41%\n", - "-------------------------------------------------------\n" - ] - } - ], - "source": [ - "# Test with all the 'notes' column from test_df\n", - "documents = test_df['notes'].tolist()\n", - "\n", - "# 
Classify each document and store results\n", - "classifications = []\n", - "for doc in documents:\n", - " result = classification_chain.run(document=doc, examples=formatted_examples)\n", - " # Clean up the result by removing the \"Category: \" prefix if it exists\n", - " cleaned_result = result.strip().replace(\"Category: \", \"\")\n", - " classifications.append(cleaned_result) # Store the cleaned classification\n", - "\n", - "# Add the cleaned classification results to the DataFrame as a new column\n", - "test_df['classification'] = classifications\n", - "\n", - "\n", - "# Calculate accuracy\n", - "correct_predictions = (test_df['classification'] == test_df['category']).sum()\n", - "total_predictions = len(test_df)\n", - "accuracy = correct_predictions / total_predictions\n", - "\n", - "# Print the accuracy\n", - "print(\"=\"*55)\n", - "print(f\" Accuracy using model: {OPENAI_MODEL} with {NUM_EXAMPLES} Examples\")\n", - "print(\"=\"*55)\n", - "print(f\"Accuracy: {accuracy * 100:.2f}%\")\n", - "print(\"-\"*55)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Classify All Documents " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# df_prot_sample = df_prot.sample(frac=SAMPLE_SIZE)\n", - "# for index, row in df_prot_sample.iterrows():\n", - "# try:\n", - "# # Run classification for the \"notes\" column of the current row\n", - "# result = classification_chain.run(document=row['notes'], examples=formatted_examples)\n", - "# # Clean up the result by removing the \"Category: \" prefix if it exists\n", - "# cleaned_result = result.strip().replace(\"Category: \", \"\")\n", - "# # Store the cleaned classification in the \"classification\" column of the current row\n", - "# df_prot_sample.at[index, 'classification'] = cleaned_result\n", - "# except Exception as e:\n", - "# print(f\"Error processing row {index}: {e}\")\n", - "# df_prot_sample.at[index, 'classification'] = None # 
Optionally, mark as None if there was an error\n" - ] - }, - { - "cell_type": "code", - "execution_count": 68, - "metadata": {}, - "outputs": [], - "source": [ - "def classify_with_retry(document, examples, max_retries=1):\n", - " \"\"\"\n", - " Classifies the document, retrying if the initial classification is not in the specified categories.\n", - " \n", - " Parameters:\n", - " - document (str): The document to classify.\n", - " - examples (list): Examples to use for classification.\n", - " - max_retries (int): Number of retries allowed if classification is not in specified categories.\n", - "\n", - " Returns:\n", - " - str: The final classification or 'Failed2Classify' if classification is unsuccessful.\n", - " \"\"\"\n", - " for _ in range(max_retries + 1):\n", - " result = classification_chain.run(document=document, examples=examples)\n", - " cleaned_result = result.strip().replace(\"Category: \", \"\")\n", - " \n", - " if cleaned_result in categories:\n", - " return cleaned_result # Return if classification is in the categories\n", - " \n", - " # If classification failed after retries, label as \"Failed2Classify\"\n", - " return \"Failed2Classify\"" - ] - }, - { - "cell_type": "code", - "execution_count": 69, - "metadata": {}, - "outputs": [], - "source": [ - "# Classify and ignore rows already classified in df_pro_sample\n", - "# Also, classify when value in df_prot_sample == \"Failed2Classify\"\n", - "for index, row in df_prot.iterrows():\n", - " # Check if this 'note' has already been classified in df_prot_sample\n", - " sample_classification = df_prot_sample.loc[\n", - " (df_prot_sample['event_date'] == row['event_date']) &\n", - " (df_prot_sample['source'] == row['source']) &\n", - " (df_prot_sample['admin1'] == row['admin1']) &\n", - " (df_prot_sample['admin2'] == row['admin2']) &\n", - " (df_prot_sample['admin3'] == row['admin3']) &\n", - " (df_prot_sample['notes'] == row['notes'])\n", - " ]['classification']\n", - "\n", - " if 
sample_classification.notnull().any():\n", - " # Only reclassify if the current classification is \"Failed2Classify\"\n", - " if \"Failed2Classify\" in sample_classification.values:\n", - " # Run classification with retry mechanism\n", - " try:\n", - " classification = classify_with_retry(document=row['notes'], examples=formatted_examples)\n", - " df_prot.at[index, 'classification'] = classification\n", - " except Exception as e:\n", - " print(f\"Error processing row {index}: {e}\")\n", - " df_prot.at[index, 'classification'] = None # Optionally, mark as None if there was an error\n", - " else:\n", - " continue # Skip if already classified\n", - " else:\n", - " # Run classification if not in df_prot_sample\n", - " try:\n", - " classification = classify_with_retry(document=row['notes'], examples=formatted_examples)\n", - " df_prot.at[index, 'classification'] = classification\n", - " except Exception as e:\n", - " print(f\"Error processing row {index}: {e}\")\n", - " df_prot.at[index, 'classification'] = None\n" - ] - }, - { - "cell_type": "code", - "execution_count": 77, - "metadata": {}, - "outputs": [], - "source": [ - "# Merge the dataframes on all columns except \"classification\"\n", - "merged_df = df_prot.merge(\n", - " df_prot_sample,\n", - " on=['event_date', 'source', 'admin1', 'admin2', 'admin3', 'event_type', 'sub_event_type', \n", - " 'interaction', 'fatalities', 'latitude', 'longitude', 'actor1', 'actor2', 'notes'],\n", - " how='left',\n", - " suffixes=('', '_sample')\n", - ")\n", - "\n", - "# Update the \"classification\" column to prioritize non-\"Failed2Classify\" values from df_prot\n", - "merged_df['classification'] = merged_df.apply(\n", - " lambda row: row['classification']\n", - " if pd.notnull(row['classification'])\n", - " else (row['classification_sample'] if row['classification_sample'] != \"Failed2Classify\" else \"Failed2Classify\"),\n", - " axis=1\n", - ")\n", - "\n", - "# Drop the extra \"classification_sample\" column\n", - 
"merged_df.drop(columns=['classification_sample'], inplace=True)\n", - "\n", - "# Drop duplicates based on all columns except \"classification\"\n", - "merged_df.drop_duplicates(\n", - " subset=['event_date', 'source', 'admin1', 'admin2', 'admin3', 'event_type', 'sub_event_type', \n", - " 'interaction', 'fatalities', 'latitude', 'longitude', 'actor1', 'actor2', 'notes'],\n", - " inplace=True\n", - ")\n", - "\n", - "# Reset index for a clean merged dataframe\n", - "merged_df.reset_index(drop=True, inplace=True)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 79, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 79, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "merged_df.shape[0] == df_prot.shape[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 78, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "classification\n", - "Livelihood (Prices, jobs and salaries) 13459\n", - "Business and legal 3369\n", - "Social 3119\n", - "Political/Security 3109\n", - "Public service delivery 984\n", - "Climate and environment 915\n", - "Name: count, dtype: int64" - ] - }, - "execution_count": 78, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "merged_df.classification.value_counts(dropna=False)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Preprocess Results\n" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "metadata": {}, - "outputs": [], - "source": [ - "# ===============================================\n", - "# REMOVE CATEGORIES NOT IN THE AVAILABLE CLASSES\n", - "# ===============================================\n", - "categories = list(category_code.values())\n", - "df_prot_sample['classification'] = df_prot_sample['classification'].apply(lambda x: x if x in categories else \"Failed2Classify\")" - ] - }, - { - "cell_type": "code", - "execution_count": 83, - "metadata": {}, 
- "outputs": [], - "source": [ - "merged_df.drop(columns=[\"Unnamed: 0\"], inplace=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 85, - "metadata": {}, - "outputs": [], - "source": [ - "merged_df.to_csv(DIR_DATA.joinpath(\"protests-labeled-gpt.csv\"), index=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 80, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "============================================================\n", - " Distribution of Labeled Protests \n", - "============================================================\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Category | Count | Percent | Cum. Percent
Livelihood (Prices, jobs and salaries) | 13,459 | 53.93% | 53.93%
Business and legal | 3,369 | 13.50% | 67.43%
Social | 3,119 | 12.50% | 79.93%
Political/Security | 3,109 | 12.46% | 92.39%
Public service delivery | 984 | 3.94% | 96.33%
Climate and environment | 915 | 3.67% | 100.00%
\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "------------------------------------------------------------\n" - ] - } - ], - "source": [ - "pretty_print_value_counts(merged_df, \"classification\", \n", - "\"Distribution of Labeled Protests\",line_length=60 )" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "df = pd.read_csv(DIR_DATA.joinpath(\"protests-labeled-all-gpt.csv\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.2" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/text-classification/protest-classification.ipynb b/notebooks/text-classification/protest-classification.ipynb new file mode 100644 index 0000000..b18a454 --- /dev/null +++ b/notebooks/text-classification/protest-classification.ipynb @@ -0,0 +1,5130 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Protest Classification Using LLM\n", + "\n", + "## Objectives\n", + "The objective of this notebook is to use a Large Language Model (LLM) to classify each protest event in an ACLED dataset into predefined categories. The dataset contains a ```notes``` attribute that describes each protest event in detail. Using an LLM and a set of human-labeled example events, we classify all protest events in the data so that each event ends up with a ```predicted_category``` label. 
\n", + "## Dataset\n", + "In this task, we are using the same ACLED data being used for the rest of the conflict analysis. However, in addition, we are using an annotated dataset which provides a classification of each protest event in the ACLED dataset. In this annotated dataset, which we will refer to as ```training``` data, a team of experts working on Iran painstakingly labelled 224 instances of protest with two attributes as follows:\n", + "**category**. Each protest is classified into 6 categories. \n", + "\n", + "> **Categories**. 'Livelihood (Prices, jobs and salaries)',\n", + " 'Political/Security',\n", + " 'Business and legal',\n", + " 'Social',\n", + " 'Public service delivery',\n", + " 'Climate and environment'\n", + " \n", + " > **Description** This attribute provides a template defining core characteristics of each protest.\n", + "\n", + "## Overview of Approach\n", + "![GPT-4-turbo Performance with Few-shot Examples](../../docs/images/protest-classification/methodology.png)\n", + "The figure above provides an overview of the main steps we followed in this task.\n", + "1. **Data Annotation**. In order to use an LLM to perform this kind of classification accurately, the first step was to generate training data by annotating a few hundred protest examples. \n", + "\n", + "2. **Zero-shot model evaluation and selection**: To get a sense of whether out-of-the-box LLMs can accurately perform this task, as well as to pick which model to use, we tested a variety of open source and commercial LLMs on the classification task without any examples at all and found that the OpenAI family of models gave superior performance. We therefore ended up using OpenAI models. \n", + "\n", + "3. **Assess few-shot classification and optimize LLM**. 
At this stage, we profiled the performance of the selected OpenAI model to see how well it performs in a few-shot setting and to determine suitable settings, such as the number of examples to use and the prompting strategy.\n", + "\n", + "4. **Classify the entire dataset** Apply classification to the entire dataset using settings determined above.\n", + "\n", + "5. **Conduct quality assurance** The final stage is to perform various sanity checks to make sure the classifications generated by the LLM are sensible and consistent with what happened on the ground.\n", + "\n", + "More details about each of these steps are provided in the respective section." + ] + }, + { + "cell_type": "code", + "execution_count": 165, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "# ======================\n", + "# BASE MODULES\n", + "# ======================\n", + "import os\n", + "import time\n", + "from pathlib import Path\n", + "from datetime import datetime\n", + "\n", + "# =======================\n", + "# ENVIRONMENT HANDLING\n", + "# =======================\n", + "from dotenv import load_dotenv\n", + "load_dotenv()\n", + "\n", + "# ======================\n", + "# DATA HANDLING MODULES\n", + "# ======================\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "# ======================\n", + "# LLMs \n", + "# ======================\n", + "from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments\n", + "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n", + "import torch\n", + "from torch.utils.data import Dataset\n", + "\n", + "from langchain_openai import OpenAI, ChatOpenAI\n", + "from langchain.prompts import PromptTemplate\n", + "from langchain.chains import LLMChain\n", + "from langchain.embeddings import OpenAIEmbeddings\n", + "from langchain.vectorstores import Chroma\n", + "from langchain.schema import Document\n", + "from 
langchain.evaluation.embedding_distance import EmbeddingDistance\n", + "from langchain.evaluation import load_evaluator\n", + "\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\", category=UserWarning, module=\"langchain\")\n", + "\n", + "# ===================================\n", + "# CLASSIFICATION METRICS FROM SKLEARN\n", + "# ===================================\n", + "from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix\n", + "from sklearn.metrics.pairwise import cosine_similarity\n", + "from scipy.spatial.distance import cdist\n", + "\n", + "# ==================\n", + "# PLOT\n", + "# =================\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "# =================\n", + "# OTHER UTILS\n", + "# =================\n", + "from tqdm import tqdm\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Global variables \n" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "# ==================\n", + "# SETUP INPUT\n", + "# ==================\n", + "DIR_DATA = Path.cwd().parents[1].joinpath(\"data\", \"conflict\")\n", + "FILE_PROTESTS = DIR_DATA.joinpath(\"protests_iran_20160101_20241009.csv\")\n", + "FILE_PROTESTS_CLASSIFIED = DIR_DATA.joinpath(\"protests_sample_acled_iran.csv\")\n", + "FILE_PROTESTS_CLASSES = DIR_DATA.joinpath(\"protest_classification.csv\")\n", + "FILE_PROTESTS_LABELS = DIR_DATA.joinpath(\"protests-labeled-sample-training.csv\")\n", + "\n", + "# ==================\n", + "# CLASSIFICATION \n", + "# ==================\n", + "PROP_TRAIN = 0.5\n", + "NUM_EXAMPLES = 10\n", + "SAMPLE_PROP = 0.5\n", + "OPENAI_MODEL = \"gpt-4-turbo\"\n", + "\n", + "# This will be updated with all categories\n", + "CATEGORIES = None\n", + "\n", + "# For testing, classify only a portion of the documents\n", + 
"SAMPLE_SIZE = 0.3\n", + "\n", + "# whether to combine description with notes column\n", + "COMBINE_NOTES_DESCRIPTION = False\n", + "\n", + "# ==================\n", + "# API KEYS\n", + "# ==================\n", + "OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "source": [ + "## Define Utility Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "def pretty_print_value_counts(\n", + " df, column, title=None, line_length=None, top_n=None, table_number=None\n", + "):\n", + " \"\"\"\n", + " Pretty prints the value counts of a specified column in a Pandas DataFrame,\n", + " with counts formatted with thousand separators, percentages, and cumulative percentages.\n", + "\n", + " Parameters:\n", + " -----------\n", + " df : pandas.DataFrame\n", + " The DataFrame containing the data.\n", + " column : str\n", + " The name of the column for which to calculate value counts.\n", + " title : str, optional\n", + " A title to print above the formatted output. If None, no title is printed.\n", + " line_length : int, optional\n", + " The length of the separator line. If None, it will be determined based on\n", + " the length of the title or default to 50 if no title is provided.\n", + " top_n : int, optional\n", + " The number of top categories to display. If None, all categories are displayed.\n", + " table_number : int, optional\n", + " The numeric value for the table number. 
If provided, the table number will be displayed as 'Table-X'.\n", + "\n", + " Returns:\n", + " --------\n", + " None\n", + " Displays a styled DataFrame with counts, percentages, and cumulative percentages.\n", + " \"\"\"\n", + " # Calculate the value counts and convert to DataFrame\n", + " count_df = pd.DataFrame(df[column].value_counts(normalize=False).reset_index())\n", + " count_df.columns = [\"Category\", \"Count\"]\n", + "\n", + " # Add a percentage column\n", + " count_df[\"Percent\"] = (count_df[\"Count\"] / count_df[\"Count\"].sum()) * 100\n", + "\n", + " # Add a cumulative percentage column\n", + " count_df[\"Cum. Percent\"] = count_df[\"Percent\"].cumsum()\n", + "\n", + " # Limit the output to top_n categories if specified\n", + " if top_n:\n", + " count_df = count_df.head(top_n)\n", + "\n", + " # Print the table number if provided\n", + " if table_number is not None:\n", + " print(f\"Table-{table_number}\")\n", + "\n", + " # Determine the length of the line if line_length is not provided\n", + " if title:\n", + " if line_length is None:\n", + " line_length = max(\n", + " 50, len(title) + 4\n", + " ) # Ensure at least 50 characters, or more based on the title\n", + "\n", + " # Calculate padding to center the title\n", + " total_padding = line_length - len(title)\n", + " left_padding = total_padding // 2\n", + " right_padding = total_padding - left_padding\n", + "\n", + " # Print the centered title with the \"=\" line\n", + " print(\"=\" * line_length)\n", + " print(\" \" * left_padding + title + \" \" * right_padding)\n", + " print(\"=\" * line_length)\n", + "\n", + " # Display the styled DataFrame without index, formatting Count, Percent, and Cumulative Percent columns\n", + " display(\n", + " count_df.style.hide(axis=\"index\").format(\n", + " {\n", + " \"Count\": \"{:,.0f}\", # Thousand separator for Count\n", + " \"Percent\": \"{:.2f}%\", # Format Percent to 2 decimal places with a % symbol\n", + " \"Cum. 
Percent\": \"{:.2f}%\", # Format Cumulative Percent to 2 decimal places with a % symbol\n", + " }\n", + " )\n", + " )\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Preprocess Data and Explore the Labeled Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "if FILE_PROTESTS_LABELS.exists():\n", + " df_prot = pd.read_csv(FILE_PROTESTS_LABELS)\n", + " df_prot.dropna(subset=['notes', 'description', 'category'], inplace=True)\n", + " \n", + "else:\n", + " df_prot = pd.read_csv(FILE_PROTESTS_CLASSIFIED)\n", + "\n", + " description_code = dict(df_prot[[\"code\", \"description\"]].values)\n", + " category_code = dict(df_prot[[\"code\", \"major_category\"]].values)\n", + "\n", + " df_prot.rename(columns={'Classification code': \"code\"}, inplace=True)\n", + " df_prot['description'] = df_prot.code.map(description_code)\n", + " df_prot['category'] = df_prot.code.map(category_code)\n", + " df_prot.dropna(subset=['notes', 'description', 'category'], inplace=True)\n", + "\n", + " df_prot.to_csv(DIR_DATA.joinpath(\"protest-with-labels.csv\"), \n", + "index=False)\n", + "\n", + "# Update CATEGORIES variable\n", + "CATEGORIES = list(df_prot.category.unique())\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1.1 Explore the Labeled Protest Data\n", + "In machine learning (ML) and working with LLMs, the labeled data containing examples of the six protest categories is referred to as `training data`. This data is used to train or guide the ML/LLM model in distinguishing between the different categories.\n", + "\n", + "Our dataset includes a total of 224 examples. 
Understanding the distribution of this training data is crucial, as it directly impacts the performance of the ML model.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 143, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================================================\n", + " Distribution of Protest Classes \n", + "============================================================\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Category | Count | Percent | Cum. Percent
Livelihood (Prices, jobs and salaries) | 64 | 28.57% | 28.57%
Political/Security | 56 | 25.00% | 53.57%
Business and legal | 42 | 18.75% | 72.32%
Social | 26 | 11.61% | 83.93%
Public service delivery | 25 | 11.16% | 95.09%
Climate and environment | 11 | 4.91% | 100.00%
\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pretty_print_value_counts(df_prot, \"category\", \n", + "title=\"Distribution of Protest Classes\", line_length=60)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Zero-shot Classification Performance \n", + "\n", + "### What is zero-shot classification\n", + "To select the best LLM or BERT-based model for our task, we first assess each model’s ability to classify text “out of the box,” without fine-tuning or examples. This zero-shot classification evaluation helps us gauge whether a model has the inherent capability to accurately categorize text in our dataset. Strong zero-shot performance indicates a model’s suitability and potential for further fine-tuning, allowing us to confidently narrow down our options before investing in training.\n", + "\n", + "### Measuring model performance \n", + "We will use macro-precision, macro-recall and accuracy as our metrics for measuring model performance. \n", + "\n", + "1. **Accuracy**: \n", + " - **Definition**: Accuracy is the percentage of correct predictions made by the model out of all predictions.\n", + " - **Formula**: \n", + " $$\n", + " \\text{Accuracy} = \\frac{\\text{Number of Correct Predictions}}{\\text{Total Predictions}} \\times 100\n", + " $$\n", + " - **Explanation**: If a model’s accuracy is 90%, it means the model correctly classified 90 out of every 100 items in the dataset. Accuracy is a helpful indicator of overall performance but might not give the full picture when dealing with multiple categories or unbalanced data (where some categories are much larger than others).\n", + "\n", + "2. 
**Precision**: \n", + " - **Definition**: Precision is the measure of how often the model’s predictions for a certain class are correct out of all the times it predicted that class.\n", + " - **Formula**: \n", + " $$\n", + " \\text{Precision} = \\frac{\\text{True Positives}}{\\text{True Positives} + \\text{False Positives}}\n", + " $$\n", + " - **Explanation**: Think of precision as answering the question, \"When the model says something is in a particular category, how often is it right?\" High precision means the model doesn’t make many false claims for a class. For example, if precision for the \"positive\" category is 80%, then out of all items the model labeled \"positive,\" 80% were truly positive. Precision is especially useful when we want to avoid falsely predicting a category.\n", + "\n", + "3. **Recall**:\n", + " - **Definition**: Recall is the measure of how well the model captures all items of a certain class out of the actual occurrences of that class in the dataset.\n", + " - **Formula**: \n", + " $$\n", + " \\text{Recall} = \\frac{\\text{True Positives}}{\\text{True Positives} + \\text{False Negatives}}\n", + " $$\n", + " - **Explanation**: Recall answers the question, \"How many of the actual items in a class did the model successfully identify?\" High recall means the model is good at finding most instances of a particular category. For example, if recall for \"positive\" is 85%, the model correctly labeled 85% of all truly positive items as \"positive.\" Recall is particularly important when we want to make sure we capture as many true items in a category as possible, even if we occasionally include incorrect ones. 
\n", + "\n", + "When these metrics are averaged across multiple categories (like our six protest classes), they are referred to as **macro precision** and **macro recall**, ensuring all categories are given equal importance in the evaluation.\n", + "\n", + "### Comparing models\n", + "We begin by comparing BERT, OpenAI's GPT, and Llama3 models for text classification to assess each model’s strengths and find the best option for our needs. BERT represents a unique approach, focusing specifically on text classification, so we want to see if it achieves higher accuracy compared to the broader language models. Testing Llama3, an open-source model, lets us explore whether it can match or exceed the performance of proprietary models like GPT, providing us with a potentially powerful, cost-effective option.\n", + "\n", + "For OpenAI's GPT, we try two versions, ```GPT-4o``` and ```GPT-4o-mini```. We also try the older ```GPT-4-Turbo``` because it may offer similar performance while being cheaper and faster to run." + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "def zero_shot_evaluate_model(model_name, input_dataset, label2id, id2label):\n", + " \"\"\"\n", + " Evaluates an open-source model (e.g., BERT, LLAMA3) on a zero-shot classification task.\n", + "\n", + " This function loads a specified model and tokenizer, tokenizes the input text from a dataset,\n", + " and performs zero-shot classification without additional training. 
It then calculates the model's\n", + " classification accuracy, macro-averaged precision, and macro-averaged recall by comparing the predicted \n", + " labels to the true labels.\n", + "\n", + " Parameters:\n", + " model_name (str): The name of the pre-trained model to load (e.g., \"bert-base-uncased\" or \"llama3\").\n", + " input_dataset (pd.DataFrame): A DataFrame containing the data for classification, with columns 'notes' \n", + " (text data) and 'category' (true labels).\n", + " label2id (dict): A dictionary mapping each label to a unique integer ID.\n", + " id2label (dict): A dictionary mapping each integer ID back to its corresponding label.\n", + "\n", + " Returns:\n", + " tuple: (accuracy, macro_precision, macro_recall) for the model's \n", + " predictions compared to the true labels in the dataset.\n", + " \"\"\"\n", + " dataset = input_dataset.copy()\n", + " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + " model = AutoModelForSequenceClassification.from_pretrained(\n", + " model_name, num_labels=len(label2id), ignore_mismatched_sizes=True\n", + " )\n", + " encodings = tokenizer(list(dataset['notes']), truncation=True, padding=True, max_length=128, return_tensors=\"pt\")\n", + " \n", + " with torch.no_grad():\n", + " outputs = model(**encodings)\n", + " predictions = torch.argmax(outputs.logits, dim=1).numpy()\n", + " \n", + " dataset[f'{model_name}_classification'] = [id2label[pred] for pred in predictions]\n", + " \n", + " # Calculate metrics\n", + " accuracy = accuracy_score(dataset['category'], dataset[f'{model_name}_classification'])\n", + " macro_precision = precision_score(dataset['category'], dataset[f'{model_name}_classification'], average='macro')\n", + " macro_recall = recall_score(dataset['category'], dataset[f'{model_name}_classification'], average='macro')\n", + " \n", + " return accuracy, macro_precision, macro_recall" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + 
"metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "def gpt_classification(input_dataset, categories, model=\"gpt-3.5-turbo\"):\n", + " \"\"\"\n", + " Classify protest notes into specified categories using a GPT model.\n", + "\n", + " Parameters:\n", + " - input_dataset (pd.DataFrame): A dataset containing notes and true categories.\n", + " - categories (list): List of possible classification categories.\n", + " - model (str): GPT model to use for classification (e.g., \"gpt-3.5-turbo\").\n", + "\n", + " Returns:\n", + " - overall_accuracy (float): Overall accuracy of the classifications.\n", + " - macro_precision (float): Macro-average precision across categories.\n", + " - macro_recall (float): Macro-average recall across categories.\n", + " - dataset (pd.DataFrame): Copy of the input with a 'predicted_category' column added.\n", + " \"\"\"\n", + " dataset = input_dataset.copy()\n", + " llm = ChatOpenAI(model=model, api_key=OPENAI_API_KEY)\n", + " predictions = []\n", + "\n", + " for note in dataset['notes']:\n", + " # Format the prompt with the note and available categories\n", + " prompt = prompt_template.format(categories=\", \".join(categories), note=note)\n", + " response = llm(prompt)\n", + " predictions.append(response.content.strip())\n", + " \n", + " dataset['predicted_category'] = predictions\n", + "\n", + " # Calculate overall accuracy\n", + " overall_accuracy = accuracy_score(dataset['category'], dataset['predicted_category'])\n", + " \n", + " # Calculate macro precision and macro recall\n", + " macro_precision = precision_score(dataset['category'], dataset['predicted_category'], average='macro')\n", + " macro_recall = recall_score(dataset['category'], dataset['predicted_category'], average='macro')\n", + "\n", + " return overall_accuracy, macro_precision, macro_recall, dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "# 
===========================\n", + "# LOAD THE DATA \n", + "# 
===========================\n", + "\n", + "# Combine \"notes\" and \"description\" columns to create a single input text for each example\n", + "if COMBINE_NOTES_DESCRIPTION:\n", + " df_prot['text'] = df_prot['notes'] + \" \" + df_prot['description']\n", + "else:\n", + " df_prot['text'] = df_prot['notes']\n", + "\n", + "\n", + "# Categories\n", + "categories = list(df_prot.category.unique())" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": { + "tags": [ + "remove-cell", + "hide-output" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "/Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d36ed2209e6e47f884d514792942aa8a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 These misclassifications can be further analyzed to determine whether there are inherent, indistinguishable similarities in how these categories are described in the dataset. 
A spreadsheet, `social-misclassified.csv`, has been saved to facilitate this inspection by individuals who are highly familiar with the dataset.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Generate the confusion matrix\n", + "conf_matrix = confusion_matrix(df_gpt4_turbo['category'], df_gpt4_turbo['predicted_category'], labels=CATEGORIES)\n", + "\n", + "# Plot the confusion matrix with a red color palette\n", + "# Tick labels must follow the same order as the `labels` argument above\n", + "plt.figure(figsize=(10, 8))\n", + "sns.heatmap(conf_matrix, annot=True, fmt=\"d\", cmap=\"Reds\", xticklabels=CATEGORIES, yticklabels=CATEGORIES, cbar=False)\n", + "plt.xlabel(\"Predicted Label\")\n", + "plt.ylabel(\"True Label\")\n", + "plt.title(\"Confusion Matrix for GPT 4-Turbo Model\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "# =======================================================\n", + "# EXAMINE MISCLASSIFIED INSTANCES FOR CATEGORY SOCIAL\n", + "# =======================================================\n", + "df_misclassified_social = df_gpt4_turbo[\n", + " (df_gpt4_turbo['category'] == 'Social') & \n", + " (df_gpt4_turbo['predicted_category'] == 'Political/Security')\n", + "]\n", + "\n", + "# Save misclassified examples \n", + "df_misclassified_social.to_csv(DIR_DATA.joinpath(\"social-misclassified.csv\"), index=False)\n", + "\n", + "# df_misclassified = df_misclassified [['notes', 'description',\n", + "# 'category', 'predicted_category']]\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Classify the Entire Dataset\n", + "\n", + "In the previous section, we evaluated the performance of several state-of-the-art LLMs and smaller language models (e.g., BERT) to understand how well they could classify the `notes` into our predetermined six categories. Among these models, OpenAI's `GPT-4-turbo` performed exceptionally well. 
Moving forward, we will use the `GPT-4-turbo` model to classify our entire target dataset, which includes the `notes` and `description` columns providing detailed information about each protest. We will follow this process:\n", + "\n", + "### Use Few-Shot Learning to Improve GPT-4-turbo Performance\n", + "\n", + "In the earlier section, we assessed the performance of `GPT-4-turbo` and other models without providing them with any examples or additional training (zero-shot learning). In this scenario, `GPT-4-turbo` achieved an accuracy of 82%. However, the confusion matrix revealed significant variation in performance across categories. For instance, the precision for categories such as **social** was notably poor.\n", + "\n", + "Now that we are classifying the entire dataset, we will leverage a technique called `few-shot learning`. This involves providing labeled examples to the LLM in the prompt and then asking it to classify new data based on the patterns demonstrated in those examples. This approach typically improves the model’s performance.\n", + "\n", + "\n", + "### Check Performance of GPT-4-turbo with Few-Shot Learning\n", + "\n", + "While we gained an initial understanding of how well `GPT-4-turbo` performed in the zero-shot classification scenario, it is important to quantify its performance when examples are provided. To achieve this, we will split our dataset of 212 training examples into a train set (e.g., 80% of the data) and a test set. We will then randomly select examples from the train set (e.g., 5, 10, 15 examples) to pass to the LLM and evaluate its performance. 
We will also revisit the confusion matrix to determine whether few-shot learning improves the model’s performance for categories such as **social**, which previously suffered from high misclassification rates.\n", + "\n", + "### Classify the entire dataset\n", + "Once we determine the optimal number of examples to use, the most effective prompt strategy, and how well the few-shot example setting performs, we will apply the same prompt template, formatting, and settings to classify each protest. This will result in a new column, ```classification```, being added to the dataset. Given that there are approximately 25,000 protest instances, this process typically takes around 2–3 hours, depending on the number of examples provided.\n", + "\n", + "### Perform Sanity Checks\n", + "\n", + "The steps above will give us confidence in how well the LLM performs. For instance, if the LLM achieves 90% accuracy, we can reasonably expect that approximately 90% of the classifications for the entire dataset will be correct. Beyond this, there are additional sanity checks we can perform to further validate the model’s classifications:\n", + "\n", + "- **Compare the distribution of model classifications with the training data (212 examples):** \n", + " While the distributions do not need to be identical, we expect them to be comparable. Significant discrepancies might indicate issues with the model's classifications.\n", + "\n", + "- **Analyze the similarity of notes within a category versus across categories:** \n", + " Intuitively, protest descriptions within a category (e.g., **social**) should exhibit greater semantic similarity (measured using techniques like cosine similarity) compared to descriptions from different categories. We will perform this analysis to confirm whether this pattern holds true.\n", + "\n", + "- **Verify classifications against real-world knowledge:** \n", + " As previously demonstrated, we can cross-check classifications against actual historical data. 
For instance, if the LLM classifies a large number of protests as **social** in January 2024 but there is no record of such a spike, it may indicate misclassification by the LLM." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.1 Assessing Classification Accuracy with Few-Shot Examples\n", + "To evaluate `GPT-4-turbo` under few-shot learning, we will split our dataset of 212 training examples into a train set (e.g., 80%) and a test set. By providing varying numbers of examples (e.g., 5, 10, 15) from the train set, we aim to balance performance, running time, and cost while identifying the most effective prompting strategy. We will also revisit the confusion matrix to assess improvements in category-specific performance, such as **social**, which previously showed high misclassification rates." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "def gpt4_turbo_classification(dataset, categories):\n", + " predictions = []\n", + " for note in dataset['notes']:\n", + " # Format the prompt with the note and available categories\n", + " prompt = prompt_template.format(categories=\", \".join(categories), note=note)\n", + " response = llm(prompt)\n", + " predictions.append(response.strip())\n", + " \n", + " dataset['predicted_category'] = predictions\n", + " \n", + " # Calculate accuracy\n", + " accuracy = accuracy_score(dataset['category'], dataset['predicted_category'])\n", + " # Calculate macro precision and macro recall\n", + " macro_precision = precision_score(dataset['category'], dataset['predicted_category'], average='macro')\n", + " macro_recall = recall_score(dataset['category'], dataset['predicted_category'], average='macro')\n", + " return accuracy, macro_precision, macro_recall" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.1.1 Split the examples into training and test\n", + "It’s important to note 
that this approach differs from the typical machine learning paradigm, where around 70-80% of the dataset is used for training and the remainder for evaluation. With LLMs like GPT-4-turbo, feeding too many examples slows down the process and significantly increases costs. \n", + "\n", + "To evaluate how well the LLM performs under few-shot learning, it’s more practical to allocate a larger portion of the data for testing and a smaller portion for providing examples to the model. Therefore, we split the dataset into 40% for training and 60% for testing.\n", + "\n", + "Examples for the model will be drawn from the 40% training portion, and the model’s performance will be evaluated on the 60% test portion. This strategy ensures we can thoroughly evaluate the model while keeping costs and runtime manageable.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/k5/p4nvl2pj4gq2ks0j7qvrhwbh0000gp/T/ipykernel_49121/1674196748.py:5: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. 
Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n", + " train_df = df_prot.groupby('category', group_keys=False).apply(lambda x: x.sample(frac=PROP_TRAIN,\n" + ] + }, + { + "data": { + "text/plain": [ + "((112, 6), (112, 6))" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# ===========================================\n", + "# SPLIT THE EXAMPLES INTO TRAINING AND TEST\n", + "# ===========================================\n", + "# Split the data while maintaining the distribution of categories\n", + "train_df = df_prot.groupby('category', group_keys=False).apply(lambda x: x.sample(frac=PROP_TRAIN, \n", + "random_state=42))\n", + "test_df = df_prot.drop(train_df.index)\n", + "\n", + "(train_df.shape, test_df.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================================================\n", + " Train-Distribution of Protest Classes \n", + "============================================================\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
| Category | Count | Percent | Cum. Percent |
| --- | --- | --- | --- |
| Livelihood (Prices, jobs and salaries) | 32 | 28.57% | 28.57% |
| Political/Security | 28 | 25.00% | 53.57% |
| Business and legal | 21 | 18.75% | 72.32% |
| Social | 13 | 11.61% | 83.93% |
| Public service delivery | 12 | 10.71% | 94.64% |
| Climate and environment | 6 | 5.36% | 100.00% |
\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pretty_print_value_counts(train_df, \"category\", \n", + "title=\"Train-Distribution of Protest Classes\", line_length=60)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================================================\n", + " Test-Distribution of Protest Classes \n", + "============================================================\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
| Category | Count | Percent | Cum. Percent |
| --- | --- | --- | --- |
| Livelihood (Prices, jobs and salaries) | 32 | 28.57% | 28.57% |
| Political/Security | 28 | 25.00% | 53.57% |
| Business and legal | 21 | 18.75% | 72.32% |
| Social | 13 | 11.61% | 83.93% |
| Public service delivery | 13 | 11.61% | 95.54% |
| Climate and environment | 5 | 4.46% | 100.00% |
\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pretty_print_value_counts(test_df, \"category\", \n", + "title=\"Test-Distribution of Protest Classes\", line_length=60)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.1.2 Prepare Examples \n", + "For each example, we use three columns to provide information to the LLM: `notes`, `description`, and `category`. These columns must be formatted appropriately before being passed to the model.\n", + "\n", + "We use a variable, `NUM_EXAMPLES`, to specify the number of examples per category to include. After experimenting with 5, 10, and 15 examples, we determined that 10 examples (which gives a total of 60 examples with 6 categories) strike the optimal balance. However, it’s important to note that some classes, such as `climate`, have only 11 examples available in the training dataset.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "def generate_formatted_examples(train_df, num_examples):\n", + " \"\"\"\n", + " Generates and formats examples from the training dataset for prompt construction.\n", + "\n", + " Parameters:\n", + " - train_df (pd.DataFrame): The training dataset containing 'notes', 'description', and 'category' columns.\n", + " - num_examples (int): Number of examples to use per category.\n", + "\n", + " Returns:\n", + " - str: Formatted examples as a single string for use in a prompt.\n", + " \"\"\"\n", + " examples = []\n", + "\n", + " # Select examples based on num_examples\n", + " for category in CATEGORIES:\n", + " # Get all samples if fewer than num_examples exist, otherwise take num_examples\n", + " category_samples = train_df[train_df['category'] == category].sample(\n", + " min(num_examples, len(train_df[train_df['category'] == category])), random_state=42\n", + " )\n", + " 
examples.extend(category_samples[['notes', 'description', 'category']].values.tolist())\n", + "\n", + " # Format the examples for the prompt\n", + " formatted_examples = \"\\n\".join(\n", + " [\n", + " f\"Example {i+1}:\\nNotes: {example[0]}\\nDescription: {example[1]}\\nCategory: {example[2]}\"\n", + " for i, example in enumerate(examples)\n", + " ]\n", + " )\n", + "\n", + " return formatted_examples\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.1.3 Prepare Prompt Template\n", + "Experimenting with multiple templates is essential because the way information is presented to an LLM can significantly impact its performance. In this case, the template must ensure the model pays attention to all three columns: `notes`, `description`, and `category`. Additionally, the template should instruct the LLM to output only the classification, as models sometimes prepend additional text before the category, which can lead to inconsistent results. By refining the template, we can maximize accuracy and ensure the outputs are structured correctly for downstream data processing and use." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "# =============================\n", + "# PREPARE TEMPLATE \n", + "# =============================\n", + "classification_prompt_template = \"\"\"\n", + "You are a highly intelligent assistant. 
Your task is to classify each document into one of the following categories:\n", + "- Political/Security\n", + "- Livelihood (Prices, jobs and salaries)\n", + "- Public service delivery\n", + "- Business and legal\n", + "- Climate and environment\n", + "- Social\n", + "\n", + "Each category has a description that helps explain its purpose.\n", + "\n", + "Here are some examples:\n", + "\n", + "{examples}\n", + "\n", + "Given the following document:\n", + "{document}\n", + "\n", + "Based on the descriptions and the examples, which category does this document fall into? Please respond with one of the categories listed above.\n", + "Please respond with only the category name without adding other information.\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/k5/p4nvl2pj4gq2ks0j7qvrhwbh0000gp/T/ipykernel_49121/2134438236.py:11: LangChainDeprecationWarning: The class `LLMChain` was deprecated in LangChain 0.1.17 and will be removed in 1.0. 
Use :meth:`~RunnableSequence, e.g., `prompt | llm`` instead.\n", + " classification_chain = LLMChain(\n" + ] + } + ], + "source": [ + "# Step 2: Initialize the LLM (OpenAI in this case)\n", + "llm = ChatOpenAI(model=OPENAI_MODEL)\n", + "\n", + "# Step 5: Create a prompt template using Langchain\n", + "classification_prompt = PromptTemplate(\n", + " input_variables=[\"document\", \"examples\"],\n", + " template=classification_prompt_template\n", + ")\n", + "\n", + "# Step 6: Create a Langchain with the LLM and the classification prompt template\n", + "classification_chain = LLMChain(\n", + " llm=llm,\n", + " prompt=classification_prompt\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "# Test with all the 'notes' column from test_df\n", + "documents = test_df['notes'].tolist()\n", + "\n", + "# Formatted examples\n", + "formatted_examples = generate_formatted_examples(train_df, num_examples=NUM_EXAMPLES)\n", + "\n", + "# Classify each document and store results\n", + "classifications = []\n", + "for doc in documents:\n", + " result = classification_chain.run(document=doc, examples=formatted_examples)\n", + " # Clean up the result by removing the \"Category: \" prefix if it exists\n", + " cleaned_result = result.strip().replace(\"Category: \", \"\")\n", + " classifications.append(cleaned_result) # Store the cleaned classification\n", + "\n", + "# Add the cleaned classification results to the DataFrame as a new column\n", + "test_df['predicted_category'] = classifications\n", + "\n", + "# Calculate accuracy\n", + "correct_predictions = (test_df['predicted_category'] == test_df['category']).sum()\n", + "total_predictions = len(test_df)\n", + "accuracy = correct_predictions / total_predictions\n", + "\n", + "# Calculate macro-precision and macro-recall\n", + "macro_precision = precision_score(test_df['category'], test_df['predicted_category'], average='macro', 
zero_division=0)\n", + "macro_recall = recall_score(test_df['category'], test_df['predicted_category'], average='macro', zero_division=0)\n", + "\n", + "# Print the results\n", + "print(\"=\"*62)\n", + "print(f\" Accuracy and Other Metrics for: {OPENAI_MODEL} with {NUM_EXAMPLES} Examples\")\n", + "print(\"=\"*62)\n", + "print(f\" Accuracy: {accuracy * 100:.2f}%\")\n", + "print(f\" Macro-Precision: {macro_precision * 100:.2f}%\")\n", + "print(f\" Macro-Recall: {macro_recall * 100:.2f}%\")\n", + "print(\"-\"*62)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![GPT-4-turbo Perfomance with Few-shot Examples](../../docs/images/protest-classification/results-few-shot.png)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================================================\n", + " Distribution- Predicted Categories in Test Set \n", + "============================================================\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
| Category | Count | Percent | Cum. Percent |
| --- | --- | --- | --- |
| Livelihood (Prices, jobs and salaries) | 31 | 27.68% | 27.68% |
| Political/Security | 26 | 23.21% | 50.89% |
| Business and legal | 26 | 23.21% | 74.11% |
| Social | 13 | 11.61% | 85.71% |
| Public service delivery | 11 | 9.82% | 95.54% |
| Climate and environment | 5 | 4.46% | 100.00% |
\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pretty_print_value_counts(test_df, \"predicted_category\", \n", + "title=\"Distribution- Predicted Categories in Test Set\", line_length=60)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Generate the confusion matrix\n", + "conf_matrix = confusion_matrix(test_df['category'], test_df['predicted_category'], labels=CATEGORIES)\n", + "\n", + "# Plot the confusion matrix with a red color palette\n", + "# Tick labels must follow the same order as the `labels` argument above\n", + "plt.figure(figsize=(10, 8))\n", + "sns.heatmap(conf_matrix, annot=True, fmt=\"d\", cmap=\"Reds\", xticklabels=CATEGORIES, yticklabels=CATEGORIES, cbar=False)\n", + "plt.xlabel(\"Predicted Label\")\n", + "plt.ylabel(\"True Label\")\n", + "plt.title(\"Confusion Matrix for GPT 4-Turbo Model\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.2 Classify All Documents \n", + "We are now ready to classify all the documents in the dataset using the previously defined prompt template and example formatting strategy. This approach has achieved approximately 90% accuracy and precision for 112 test examples, utilizing 10 examples per category for a total of 60 examples." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 85, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "def classify_with_retry(document, examples, categories, max_retries=5):\n", + " \"\"\"\n", + " Classifies the document, retrying if the initial classification is not in the specified categories.\n", + " \n", + " Parameters:\n", + " - document (str): The document to classify.\n", + " - examples (list): Examples to use for classification.\n", + " - categories (list) : Categories to ensure classification is within required categories\n", + " - max_retries (int): Number of retries allowed if classification is not in specified categories.\n", + "\n", + " Returns:\n", + " - str: The final classification or 'Failed2Classify' if classification is unsuccessful.\n", + " \"\"\"\n", + " for _ in range(max_retries + 1):\n", + " result = classification_chain.run(document=document, examples=examples)\n", + " cleaned_result = result.strip().replace(\"Category: \", \"\")\n", + " \n", + " if cleaned_result in categories:\n", + " return cleaned_result \n", + " \n", + " # If classification failed after retries, label as \"Failed2Classify\"\n", + " return \"Failed2Classify\"" + ] + }, + { + "cell_type": "code", + "execution_count": 148, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "def classify_entire_dataset(df_train, df_target, output_file, categories, checkpoint_interval=100, sample=1):\n", + " \"\"\"\n", + " Classifies a large dataset with the ability to resume from a checkpoint in case of interruptions.\n", + "\n", + " Parameters:\n", + " - df_train (pd.DataFrame): DataFrame where we are drawing examples from.\n", + " - df_target (pd.DataFrame): Inference dataset.\n", + " - output_file (str): File path to save progress.\n", + " - categories (list): Categories defining all classes.\n", + " - checkpoint_interval (int): Number of rows to process before saving a checkpoint.\n", + " - sample (float): 
Whether to run on everything or a sampled fraction.\n", + "\n", + " Returns:\n", + " - pd.DataFrame: DataFrame with classifications added.\n", + " \"\"\"\n", + " # Generate formatted examples\n", + " examples = generate_formatted_examples(df_train, num_examples=NUM_EXAMPLES)\n", + " \n", + " # Load progress if output_file exists, merge with df_target\n", + " try:\n", + " df_classified = pd.read_csv(output_file)\n", + " print(f\"Loaded progress from {output_file}. Resuming classification.\")\n", + " except FileNotFoundError:\n", + " df_classified = df_target.copy()\n", + " df_classified['predicted_category'] = None\n", + "\n", + " # Whether to run on everything or sample\n", + " num_rows = int(sample*df_classified.shape[0])\n", + " print(f\"WILL STOP AT {num_rows} ROWS FOR DEBUGGING PURPOSES\")\n", + " print(\"-\"*40)\n", + "\n", + " # Ensure all rows are included by merging the progress\n", + " if 'predicted_category' in df_classified.columns:\n", + " df_classified = df_target.merge(\n", + " df_classified[['notes', 'predicted_category']], on='notes', how='left', suffixes=('', '_classified')\n", + " )\n", + " df_classified['predicted_category'] = df_classified['predicted_category'].combine_first(df_classified.get('predicted_category_classified', pd.Series(None, index=df_classified.index)))\n", + " if 'predicted_category_classified' in df_classified.columns:\n", + " df_classified.drop(columns=['predicted_category_classified'], inplace=True)\n", + " else:\n", + " print(\"No progress file found or 'predicted_category' column missing; starting fresh.\")\n", + "\n", + " # Classify unprocessed or failed rows\n", + " count = 0\n", + " for idx, row in tqdm(df_classified.iterrows(), total=len(df_classified)):\n", + " if pd.isna(row['predicted_category']) or row['predicted_category'] == \"Failed2Classify\":\n", + " try:\n", + " # Validate against the categories passed to this function\n", + " classification = classify_with_retry(row['notes'], examples, categories, max_retries=5)\n", + " df_classified.at[idx, 'predicted_category'] = 
classification\n", + " except Exception as e:\n", + " print(f\"Error at row {idx}: {e}\")\n", + " df_classified.at[idx, 'predicted_category'] = \"Failed2Classify\"\n", + "\n", + " # Save progress at intervals\n", + " if idx % checkpoint_interval == 0:\n", + " df_classified.to_csv(output_file, index=False)\n", + " print(f\"Checkpoint saved to {output_file} after {idx} rows.\")\n", + " \n", + " # Stop once we hit num_rows as per sample size\n", + " if count == num_rows:\n", + " print(\"STOPPING EARLY FOR DEBUGGING\")\n", + " break\n", + " count += 1\n", + "\n", + " # Save final results\n", + " df_classified.to_csv(output_file, index=False)\n", + " print(f\"Classification completed. Results saved to {output_file}.\")\n", + " return df_classified\n" + ] + }, + { + "cell_type": "code", + "execution_count": 153, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded progress from /Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/data/conflict/protests-labeled-all-gpt4-turbo.csv. 
Resuming classification.\n", + "WILL STOP AT 1294 ROWS FOR DEBUGGING PURPOSES\n", + "----------------------------------------\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 1/31527 [00:00<1:33:52, 5.60it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checkpoint saved to /Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/data/conflict/protests-labeled-all-gpt4-turbo.csv after 0 rows.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 101/31527 [00:00<01:32, 338.93it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checkpoint saved to /Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/data/conflict/protests-labeled-all-gpt4-turbo.csv after 100 rows.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 1%| | 201/31527 [00:00<01:11, 440.13it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checkpoint saved to /Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/data/conflict/protests-labeled-all-gpt4-turbo.csv after 200 rows.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 1%| | 301/31527 [00:00<01:02, 496.42it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checkpoint saved to /Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/data/conflict/protests-labeled-all-gpt4-turbo.csv after 300 rows.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 1%|▏ | 401/31527 [00:00<00:59, 525.71it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checkpoint saved to /Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/data/conflict/protests-labeled-all-gpt4-turbo.csv 
after 400 rows.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 2%|▏ | 501/31527 [00:01<00:58, 534.69it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checkpoint saved to /Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/data/conflict/protests-labeled-all-gpt4-turbo.csv after 500 rows.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 2%|▏ | 601/31527 [00:01<00:56, 552.06it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checkpoint saved to /Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/data/conflict/protests-labeled-all-gpt4-turbo.csv after 600 rows.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 2%|▏ | 701/31527 [01:22<5:21:45, 1.60it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checkpoint saved to /Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/data/conflict/protests-labeled-all-gpt4-turbo.csv after 700 rows.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 3%|▎ | 801/31527 [02:59<8:53:28, 1.04s/it] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checkpoint saved to /Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/data/conflict/protests-labeled-all-gpt4-turbo.csv after 800 rows.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 3%|▎ | 901/31527 [04:47<11:37:24, 1.37s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checkpoint saved to /Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/data/conflict/protests-labeled-all-gpt4-turbo.csv after 900 rows.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 3%|▎ | 
1001/31527 [06:51<11:13:32, 1.32s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checkpoint saved to /Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/data/conflict/protests-labeled-all-gpt4-turbo.csv after 1000 rows.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 3%|▎ | 1101/31527 [08:55<10:35:15, 1.25s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checkpoint saved to /Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/data/conflict/protests-labeled-all-gpt4-turbo.csv after 1100 rows.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 4%|▍ | 1201/31527 [10:59<10:43:13, 1.27s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checkpoint saved to /Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/data/conflict/protests-labeled-all-gpt4-turbo.csv after 1200 rows.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 4%|▍ | 1294/31527 [12:56<5:02:15, 1.67it/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "STOPPING EARLY FOR DEBUGGING\n", + "Classification completed. 
Results saved to /Users/dunstanmatekenya/Library/CloudStorage/OneDrive-WBG(2)/Data-Lab/iran-economic-monitoring/data/conflict/protests-labeled-all-gpt4-turbo.csv.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# ========================================\n", + "# LOAD THE DATA AND PERFORM CLASSIFICATION\n", + "# ========================================\n", + "df_all = pd.read_csv(FILE_PROTESTS)\n", + "df_all.drop(columns=[\"Unnamed: 0\"], inplace=True)\n", + "csv_output = DIR_DATA.joinpath(\"protests-labeled-all-gpt4-turbo.csv\")\n", + "df_prot_classified = classify_entire_dataset(df_prot, df_all, csv_output, categories, \n", + "checkpoint_interval=100, sample=0.05)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.3 Conduct Quality Assurance on Classified Protests\n", + "Now that we have used the LLM to add a category, in the column ```predicted_category```, for all protest instances in our dataset, we will perform some sanity checks to make sure that the predicted categories actually make sense.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 155, + "metadata": {}, + "outputs": [], + "source": [ + "df_prot_pred = pd.read_csv(DIR_DATA.joinpath(\"protests-labeled-all-gpt35-turbo.csv\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.3.1 Compare Distributions\n", + "In the training dataset, we check the frequency distribution of the variable ```category``` and compare it with the frequency distribution of ```predicted_category```, the categories generated by the LLM. Although we don't expect an exact match, we expect these distributions to be comparable. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 170, + "metadata": {}, + "outputs": [], + "source": [ + "predicted_cnts = df_prot_pred.classification.value_counts(normalize=True)*100\n", + "actual_cnts = df_prot.category.value_counts(normalize=True)*100" + ] + }, + { + "cell_type": "code", + "execution_count": 171, + "metadata": {}, + "outputs": [], + "source": [ + "all_labels = sorted(set(df_prot_pred.classification.unique()))\n", + "predicted_cnts = predicted_cnts.reindex(all_labels, fill_value=0)\n", + "actual_cnts = actual_cnts.reindex(all_labels, fill_value=0)\n", + "\n", + "df_comparison = pd.DataFrame({'predicated_category': predicted_cnts,\n", + "'actual_cnts': actual_cnts})" + ] + }, + { + "cell_type": "code", + "execution_count": 173, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df_comparison.plot(kind='bar', figsize=(10,6))\n", + "plt.title(\"Comparison of Predicted vs. Actual Categories\")\n", + "plt.ylabel(\"Percent\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 157, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================================================\n", + " Distribution of Predicted Categories \n", + "============================================================\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Category | Count | Percent | Cum. Percent
Livelihood (Prices, jobs and salaries) | 13,459 | 53.93% | 53.93%
Business and legal | 3,369 | 13.50% | 67.43%
Social | 3,119 | 12.50% | 79.93%
Political/Security | 3,109 | 12.46% | 92.39%
Public service delivery | 984 | 3.94% | 96.33%
Climate and environment | 915 | 3.67% | 100.00%
\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pretty_print_value_counts(df_prot_pred, \"classification\", \n", + "\"Distribution of Predicted Categories\",line_length=60 )" + ] + }, + { + "cell_type": "code", + "execution_count": 174, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================================================\n", + " Distribution of Actual Categories in Training Data \n", + "============================================================\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Category | Count | Percent | Cum. Percent
Livelihood (Prices, jobs and salaries) | 64 | 28.57% | 28.57%
Political/Security | 56 | 25.00% | 53.57%
Business and legal | 42 | 18.75% | 72.32%
Social | 26 | 11.61% | 83.93%
Public service delivery | 25 | 11.16% | 95.09%
Climate and environment | 11 | 4.91% | 100.00%
\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pretty_print_value_counts(df_prot, \"category\", \n", + "\"Distribution of Actual Categories in Training Data\",line_length=60 )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.3.2 Similarity of notes within and across categories\n", + "To build confidence in the classifications generated by the model, \n", + "let's perform some sanity checks: if the predicted categories are meaningful, notes within a category should be more similar to each other than to notes in other categories. " + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [], + "source": [ + "def calculate_cosine_similarity_matrix(df, categories):\n", + "    \"\"\"\n", + "    Calculate a cosine similarity matrix for all categories by comparing embeddings within and across categories.\n", + "\n", + "    Parameters:\n", + "    - df: DataFrame with an 'embedding' column (embeddings of the notes) and a 'predicted_category' column.\n", + "    - categories: List of unique categories to compare.\n", + "\n", + "    Returns:\n", + "    - similarity_df (pd.DataFrame): A DataFrame representing the cosine similarity matrix.\n", + "    \"\"\"\n", + "    # Initialize an empty matrix for similarity scores\n", + "    similarity_matrix = np.zeros((len(categories), len(categories)))\n", + "\n", + "    # Calculate cosine similarities within and across categories\n", + "    for i, category1 in enumerate(categories):\n", + "        # Filter embeddings for category1\n", + "        embeddings_category1 = np.vstack(df[df['predicted_category'] == category1]['embedding'].values)\n", + "\n", + "        for j, category2 in enumerate(categories):\n", + "            # Filter embeddings for category2\n", + "            embeddings_category2 = np.vstack(df[df['predicted_category'] == category2]['embedding'].values)\n", + "\n", + "            # Calculate cosine similarities between embeddings in category1 and category2\n", + "            similarity_scores = cosine_similarity(embeddings_category1, embeddings_category2).flatten()\n", + "            avg_similarity = np.mean(similarity_scores) if 
len(similarity_scores) > 0 else 0\n", + "            similarity_matrix[i, j] = avg_similarity\n", + "\n", + "    # Convert the matrix to a DataFrame for readability\n", + "    similarity_df = pd.DataFrame(similarity_matrix, index=categories, columns=categories)\n", + "    return similarity_df" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "metadata": {}, + "outputs": [], + "source": [ + "def calculate_distance_matrix(df, categories, distance_metric=\"euclidean\"):\n", + "    \"\"\"\n", + "    Calculate a distance matrix for all categories by comparing embeddings within and across categories.\n", + "\n", + "    Parameters:\n", + "    - df: DataFrame with an 'embedding' column (embeddings of the notes) and a 'predicted_category' column.\n", + "    - categories: List of unique categories to compare.\n", + "    - distance_metric: Type of distance to compute (e.g., 'euclidean', 'cityblock').\n", + "\n", + "    Returns:\n", + "    - distance_df (pd.DataFrame): A DataFrame representing the distance matrix.\n", + "    \"\"\"\n", + "    # Initialize an empty matrix for distance scores\n", + "    distance_matrix = np.zeros((len(categories), len(categories)))\n", + "\n", + "    # Calculate distances within and across categories\n", + "    for i, category1 in enumerate(categories):\n", + "        # Filter embeddings for category1\n", + "        embeddings_category1 = np.vstack(df[df['predicted_category'] == category1]['embedding'].values)\n", + "\n", + "        for j, category2 in enumerate(categories):\n", + "            # Filter embeddings for category2\n", + "            embeddings_category2 = np.vstack(df[df['predicted_category'] == category2]['embedding'].values)\n", + "\n", + "            # Calculate pairwise distances between embeddings in category1 and category2\n", + "            distance_scores = cdist(embeddings_category1, embeddings_category2, metric=distance_metric).flatten()\n", + "            avg_distance = np.mean(distance_scores) if len(distance_scores) > 0 else 0\n", + "            distance_matrix[i, j] = avg_distance\n", + "\n", + "    # Convert the matrix to a DataFrame for readability\n", + "    distance_df = 
pd.DataFrame(distance_matrix, index=categories, columns=categories)\n", + "    return distance_df" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": {}, + "outputs": [], + "source": [ + "# Load dataset\n", + "df_classified_protests = pd.read_csv(DIR_DATA.joinpath(\"protests-labeled-all-gpt.csv\"))\n", + "categories = list(df_classified_protests.classification.unique())\n", + "df_classified_protests.rename(columns={\"classification\":\"predicted_category\"}, inplace=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "metadata": {}, + "outputs": [], + "source": [ + "# =========================================\n", + "# LOAD AND CREATE EMBEDDINGS FOR DATAFRAME\n", + "# =========================================\n", + "# Initialize OpenAI embeddings (the API key is passed via openai_api_key)\n", + "embedding_model = OpenAIEmbeddings(model=\"text-embedding-ada-002\", \n", + "openai_api_key=OPENAI_API_KEY)\n", + "\n", + "# Embed each note and store the embeddings in the DataFrame\n", + "df_classified_protests['embedding'] = embedding_model.embed_documents(df_classified_protests['notes'].tolist())" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================================================\n", + " Cosine Similarity Matrix (Within and Across Categories)\n", + "============================================================\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Category | Political/Security | Livelihood (Prices, jobs and salaries) | Climate and environment | Business and legal | Social | Public service delivery
Political/Security | 0.859897 | 0.824101 | 0.832180 | 0.827453 | 0.842172 | 0.831968
Livelihood (Prices, jobs and salaries) | 0.824101 | 0.862665 | 0.835149 | 0.847127 | 0.836176 | 0.838875
Climate and environment | 0.832180 | 0.835149 | 0.871537 | 0.835849 | 0.833948 | 0.853963
Business and legal | 0.827453 | 0.847127 | 0.835849 | 0.854858 | 0.835603 | 0.838907
Social | 0.842172 | 0.836176 | 0.833948 | 0.835603 | 0.856032 | 0.837756
Public service delivery | 0.831968 | 0.838875 | 0.853963 | 0.838907 | 0.837756 | 0.855566
\n", + "
" + ], + "text/plain": [ + " Political/Security \\\n", + "Political/Security 0.859897 \n", + "Livelihood (Prices, jobs and salaries) 0.824101 \n", + "Climate and environment 0.832180 \n", + "Business and legal 0.827453 \n", + "Social 0.842172 \n", + "Public service delivery 0.831968 \n", + "\n", + " Livelihood (Prices, jobs and salaries) \\\n", + "Political/Security 0.824101 \n", + "Livelihood (Prices, jobs and salaries) 0.862665 \n", + "Climate and environment 0.835149 \n", + "Business and legal 0.847127 \n", + "Social 0.836176 \n", + "Public service delivery 0.838875 \n", + "\n", + " Climate and environment \\\n", + "Political/Security 0.832180 \n", + "Livelihood (Prices, jobs and salaries) 0.835149 \n", + "Climate and environment 0.871537 \n", + "Business and legal 0.835849 \n", + "Social 0.833948 \n", + "Public service delivery 0.853963 \n", + "\n", + " Business and legal Social \\\n", + "Political/Security 0.827453 0.842172 \n", + "Livelihood (Prices, jobs and salaries) 0.847127 0.836176 \n", + "Climate and environment 0.835849 0.833948 \n", + "Business and legal 0.854858 0.835603 \n", + "Social 0.835603 0.856032 \n", + "Public service delivery 0.838907 0.837756 \n", + "\n", + " Public service delivery \n", + "Political/Security 0.831968 \n", + "Livelihood (Prices, jobs and salaries) 0.838875 \n", + "Climate and environment 0.853963 \n", + "Business and legal 0.838907 \n", + "Social 0.837756 \n", + "Public service delivery 0.855566 " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "------------------------------------------------------------\n" + ] + } + ], + "source": [ + "# =========================================\n", + "# SKLEARN COSINE SIMILARITY\n", + "# =========================================\n", + "cosine_similarity_matrix = calculate_cosine_similarity_matrix(df_classified_protests, categories)\n", + "\n", + "print(\"=\"*60)\n", + "print(\" Cosine Similarity Matrix 
(Within and Across Categories)\")\n", + "print(\"=\"*60)\n", + "display(cosine_similarity_matrix)\n", + "print(\"-\"*60)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================================================\n", + " Euclidean Distance Similarity Matrix (Within and Across Categories)\n", + "============================================================\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Category | Political/Security | Livelihood (Prices, jobs and salaries) | Climate and environment | Business and legal | Social | Public service delivery
Political/Security | 0.524813 | 0.592096 | 0.578313 | 0.586535 | 0.560115 | 0.578551
Livelihood (Prices, jobs and salaries) | 0.592096 | 0.520636 | 0.573156 | 0.550857 | 0.570876 | 0.566408
Climate and environment | 0.578313 | 0.573156 | 0.502413 | 0.572053 | 0.575235 | 0.538168
Business and legal | 0.586535 | 0.550857 | 0.572053 | 0.534887 | 0.572438 | 0.566481
Social | 0.560115 | 0.570876 | 0.575235 | 0.572438 | 0.532147 | 0.568317
Public service delivery | 0.578551 | 0.566408 | 0.538168 | 0.566481 | 0.568317 | 0.534158
\n", + "
" + ], + "text/plain": [ + " Political/Security \\\n", + "Political/Security 0.524813 \n", + "Livelihood (Prices, jobs and salaries) 0.592096 \n", + "Climate and environment 0.578313 \n", + "Business and legal 0.586535 \n", + "Social 0.560115 \n", + "Public service delivery 0.578551 \n", + "\n", + " Livelihood (Prices, jobs and salaries) \\\n", + "Political/Security 0.592096 \n", + "Livelihood (Prices, jobs and salaries) 0.520636 \n", + "Climate and environment 0.573156 \n", + "Business and legal 0.550857 \n", + "Social 0.570876 \n", + "Public service delivery 0.566408 \n", + "\n", + " Climate and environment \\\n", + "Political/Security 0.578313 \n", + "Livelihood (Prices, jobs and salaries) 0.573156 \n", + "Climate and environment 0.502413 \n", + "Business and legal 0.572053 \n", + "Social 0.575235 \n", + "Public service delivery 0.538168 \n", + "\n", + " Business and legal Social \\\n", + "Political/Security 0.586535 0.560115 \n", + "Livelihood (Prices, jobs and salaries) 0.550857 0.570876 \n", + "Climate and environment 0.572053 0.575235 \n", + "Business and legal 0.534887 0.572438 \n", + "Social 0.572438 0.532147 \n", + "Public service delivery 0.566481 0.568317 \n", + "\n", + " Public service delivery \n", + "Political/Security 0.578551 \n", + "Livelihood (Prices, jobs and salaries) 0.566408 \n", + "Climate and environment 0.538168 \n", + "Business and legal 0.566481 \n", + "Social 0.568317 \n", + "Public service delivery 0.534158 " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "------------------------------------------------------------\n" + ] + } + ], + "source": [ + "# =========================================\n", + "# SKLEARN COSINE SIMILARITY\n", + "# =========================================\n", + "euclidean_similarity_matrix = calculate_distance_matrix(df_classified_protests, categories)\n", + "\n", + "print(\"=\"*60)\n", + "print(\" Euclidean Distance Similarity Matrix 
(Within and Across Categories)\")\n", + "print(\"=\"*60)\n", + "display(euclidean_similarity_matrix)\n", + "print(\"-\"*60)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Limitations and Challenges\n", + "The main limitations fall into three categories, described below.\n", + "### Cost-related challenges\n", + "It is expensive to run OpenAI models when you have many tokens. The main cost drivers when using commercial LLMs are listed below.\n", + "- **Need to use the best-performing model**. To get good performance without investing in further model fine-tuning, we need to use commercial models such as the OpenAI models used here. Furthermore, OpenAI's best-performing models are its latest ones, which are expensive.\n", + "- **Large number of examples**. Passing more examples to the model generally improves performance because the model can learn from them. However, more examples mean more tokens sent to the OpenAI API, which drives the cost up because charging is done per token. \n", + "- **Large dataset**. We need to generate classifications for a dataset of 24k rows. This isn't by any stretch of the imagination a large dataset; however, for tasks involving LLMs where you pay per token, it again drives the cost up because each ```note``` contains many tokens.\n", + "- **Price of experimentation**. Determining the optimal parameters (for example, the prompting strategy and the number of few-shot examples) requires repeated runs, and each run adds to the cost.\n", + "### Compute resources and processing time\n", + "Although compute resources can be accessed when budget is available, LLMs generally take a long time to run due to their large number of parameters. To experiment quickly, get results faster, and stay cost-effective, you sometimes have to sacrifice a little performance. 
\n", + "\n", + "### Classification results are not 100% accurate\n", + "In most cases, achieving 90% classification accuracy is good enough. Ultimately, what counts as accurate enough depends on the use case. In this case, we still have to live with the fact that potentially 5-10% of the classifications are wrong and there is no way to know which ones. However, the sanity checks we conducted do help increase confidence in the LLM's results.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Potential Next Steps\n", + "### Further sanity checks\n", + "In a separate notebook, more analysis will be done, focusing particularly on the time dimension, to check and demonstrate that what we got from the LLM is reasonable and matches what was happening on the ground.\n", + "\n", + "### Further tuning of the model\n", + "As mentioned above, the experimentation in this work was limited by available resources. It's possible to squeeze out more performance by using a more capable OpenAI model (e.g., ```gpt-4o```). Secondly, more experiments could check whether further increasing the number of examples would improve results, or whether selecting examples in a smarter way (e.g., semantic selection, which picks examples based on similarity to the target ```note```) would yield better results. Finally, we could also try several prompting strategies to see whether they give better performance. \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. 
References\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}