diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml new file mode 100644 index 0000000..4e1ef42 --- /dev/null +++ b/.github/workflows/python-publish.yml @@ -0,0 +1,31 @@ +# This workflows will upload a Python Package using Twine when a release is created +# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries + +name: Upload Python Package + +on: + release: + types: [created] + +jobs: + deploy: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.x' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install setuptools wheel twine + - name: Build and publish + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: | + python setup.py sdist bdist_wheel + twine upload dist/* diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1749ca8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,140 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
class Classifier(pl.LightningModule):
    """Lightning wrapper for fine-tuning an image-classification model.

    Args:
        model: A model whose ``forward(**batch)`` returns an object with
            ``.loss`` and ``.logits`` attributes (e.g. a Hugging Face
            ``ViTForImageClassification``).
        lr: Learning rate for the Adam optimizer.
        **kwargs: Extra hyperparameters; recorded via ``save_hyperparameters``
            so they appear in ``self.hparams`` and in logged checkpoints.
    """

    def __init__(self, model, lr: float = 2e-5, **kwargs):
        super().__init__()
        # Record lr plus any extra kwargs; the model object itself is not a
        # scalar hyperparameter, so it is deliberately excluded.
        self.save_hyperparameters('lr', *list(kwargs))
        self.model = model
        self.val_acc = Accuracy()

    def forward(self, *args, **kwargs):
        # Delegate to the wrapped model. Defining a real method (instead of
        # the original `self.forward = self.model.forward` instance-attribute
        # assignment) keeps Lightning's hook/ scripting machinery intact while
        # behaving identically for callers.
        return self.model(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        # Plain string: the original used an f-string with no placeholders.
        self.log("train_loss", outputs.loss)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        outputs = self(**batch)
        self.log("val_loss", outputs.loss)
        # Compare predicted class ids (argmax over the class dimension)
        # against the integer labels supplied in the batch.
        acc = self.val_acc(outputs.logits.argmax(1), batch['labels'])
        self.log("val_acc", acc, prog_bar=True)
        return outputs.loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
logger = logging.getLogger(__name__)

# Experimental Hugging Face image-search endpoint used to gather training data.
SEARCH_URL = "https://huggingface.co/api/experimental/images/search"


def get_image_urls_by_term(search_term: str, count=150):
    """Return up to *count* thumbnail URLs for images matching *search_term*.

    Queries the Hugging Face experimental image-search API; raises
    ``requests.HTTPError`` (via ``raise_for_status``) on a failed request.
    """
    params = {"q": search_term, "license": "public", "imageType": "photo", "count": count}
    response = requests.get(SEARCH_URL, params=params)
    response.raise_for_status()
    response_data = response.json()
    image_urls = [img['thumbnailUrl'] for img in response_data['value']]
    return image_urls


def gen_images_from_urls(urls):
    """Yield ``PIL.Image`` objects downloaded from *urls*, skipping failures.

    A summary line ("Retrieved N images. Skipped M.") is printed once the
    generator is exhausted.
    """
    num_skipped = 0
    for url in urls:
        response = requests.get(url)
        if response.status_code != 200:
            # BUG FIX: the original counted the failure but then fell through
            # and still called Image.open on the failed response body, which
            # double-counted the skip (via UnidentifiedImageError) or could
            # yield garbage. Skip the URL entirely instead.
            num_skipped += 1
            continue
        try:
            img = Image.open(BytesIO(response.content))
            yield img
        except UnidentifiedImageError:
            num_skipped += 1

    print(f"Retrieved {len(urls) - num_skipped} images. Skipped {num_skipped}.")


def urls_to_image_folder(urls, save_directory):
    """Save each image downloaded from *urls* as ``<save_directory>/<i>.jpg``.

    *save_directory* must be a ``pathlib.Path`` (the ``/`` operator is used).
    """
    for i, image in enumerate(gen_images_from_urls(urls)):
        # JPEG cannot encode alpha/palette modes (RGBA, P, ...); convert first
        # so non-RGB thumbnails don't raise on save.
        image.convert('RGB').save(save_directory / f'{i}.jpg')


def make_huggingpics_imagefolder(data_dir, search_terms, count=150, overwrite=False, transform=None):
    """Build (or reuse) an ``ImageFolder`` with one subfolder per search term.

    Args:
        data_dir: Root directory for the downloaded dataset.
        search_terms: Iterable of search strings; each becomes one class folder.
        count: Maximum number of images to request per term.
        overwrite: If True, delete any existing *data_dir* and re-download.
        transform: Optional torchvision transform passed to ``ImageFolder``.

    Returns:
        A ``torchvision.datasets.ImageFolder`` rooted at *data_dir*.
    """
    data_dir = Path(data_dir)

    if data_dir.exists():
        if overwrite:
            logger.warning(f"Deleting existing HuggingPics data directory to create new one: {data_dir}")
            shutil.rmtree(data_dir)
        else:
            # Reuse whatever is already on disk rather than re-downloading.
            logger.warning(f"Using existing HuggingPics data directory: '{data_dir}'")
            return ImageFolder(str(data_dir), transform=transform)

    for search_term in search_terms:
        search_term_dir = data_dir / search_term
        search_term_dir.mkdir(exist_ok=True, parents=True)
        urls = get_image_urls_by_term(search_term, count)
        logger.info(f"Saving images of {search_term} to {str(search_term_dir)}...")
        urls_to_image_folder(urls, search_term_dir)

    return ImageFolder(str(data_dir), transform=transform)
class SubsetWithTransform(torch.utils.data.Subset):
    """A ``Subset`` that additionally applies *transform* to each sample.

    The underlying dataset is expected to yield ``(sample, target)`` pairs;
    *transform* is applied to the sample only, the target passes through
    unchanged.
    """

    def __init__(self, dataset, indices, transform):
        super().__init__(dataset, indices)
        self.transform = transform

    def __getitem__(self, idx):
        # Let the parent class resolve the subset index into the wrapped
        # dataset, then transform just the sample half of the pair.
        sample, target = super().__getitem__(idx)
        return self.transform(sample), target
Transformers for anything using images found on the web.', + author='Nathan Raw', + author_email='naterawdata@gmail.com', + url='https://github.com/nateraw/huggingpics', + install_requires=requirements, +) \ No newline at end of file