Skip to content

Commit

Permalink
Merge pull request #3 from nateraw/huggingpics-package
Browse files Browse the repository at this point in the history
Add package
  • Loading branch information
nateraw authored Nov 17, 2021
2 parents b8fe068 + feb0474 commit 246a821
Show file tree
Hide file tree
Showing 7 changed files with 374 additions and 0 deletions.
31 changes: 31 additions & 0 deletions .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package

on:
  release:
    types: [created]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel twine
    - name: Build and publish
      env:
        # PyPI credentials are injected from repository secrets; twine reads
        # these environment variables automatically.
        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
      run: |
        python setup.py sdist bdist_wheel
        twine upload dist/*
140 changes: 140 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

examples/
4 changes: 4 additions & 0 deletions huggingpics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Package version string; kept in sync with the `version` field in setup.py.
__version__ = "0.0.1"

# Re-export the public API at the package root so users can write
# `from huggingpics import Classifier, HuggingPicsData`.
from .classifier import Classifier
from .data import HuggingPicsData, make_huggingpics_imagefolder
27 changes: 27 additions & 0 deletions huggingpics/classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytorch_lightning as pl
import torch
from torchmetrics import Accuracy


class Classifier(pl.LightningModule):
    """LightningModule wrapper for fine-tuning an image-classification model.

    The wrapped ``model`` is called with keyword arguments from each batch and
    is expected to return an output object exposing ``.loss`` and ``.logits``
    (as ``transformers`` vision models do — e.g. ViTForImageClassification).

    Args:
        model: The model to fine-tune.
        lr: Learning rate for the Adam optimizer.
        **kwargs: Extra hyperparameters to record in ``self.hparams``.
    """

    def __init__(self, model, lr: float = 2e-5, **kwargs):
        super().__init__()
        # Store lr (and any extra kwargs) in self.hparams so they are
        # checkpointed and logged alongside the weights.
        self.save_hyperparameters('lr', *list(kwargs))
        self.model = model
        # Delegate forward directly to the wrapped model so `self(**batch)` works.
        self.forward = self.model.forward
        self.val_acc = Accuracy()

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        # Plain strings here — the originals were f-strings with no
        # placeholders (ruff F541).
        self.log("train_loss", outputs.loss)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        outputs = self(**batch)
        self.log("val_loss", outputs.loss)
        # argmax over the class dimension gives predicted label ids.
        acc = self.val_acc(outputs.logits.argmax(1), batch['labels'])
        self.log("val_acc", acc, prog_bar=True)
        return outputs.loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
152 changes: 152 additions & 0 deletions huggingpics/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import logging
import math
import shutil
from io import BytesIO
from pathlib import Path

import pytorch_lightning as pl
import requests
import torch
from PIL import Image, UnidentifiedImageError
from torchvision.datasets import ImageFolder
from torchvision.transforms import (CenterCrop, Compose, Normalize,
RandomHorizontalFlip, RandomResizedCrop,
Resize, ToTensor)
from transformers import ViTFeatureExtractor

logger = logging.getLogger(__name__)

SEARCH_URL = "https://huggingface.co/api/experimental/images/search"


def get_image_urls_by_term(search_term: str, count=150):
    """Query the Hugging Face experimental image-search API for *search_term*.

    Args:
        search_term: Text to search images for.
        count: Maximum number of results to request.

    Returns:
        A list of thumbnail URL strings.

    Raises:
        requests.HTTPError: If the search request fails.
    """
    query = {
        "q": search_term,
        "license": "public",
        "imageType": "photo",
        "count": count,
    }
    response = requests.get(SEARCH_URL, params=query)
    response.raise_for_status()
    payload = response.json()
    return [result['thumbnailUrl'] for result in payload['value']]


def gen_images_from_urls(urls):
    """Yield PIL images downloaded from *urls*, skipping failed downloads.

    A URL is counted as skipped when the HTTP request does not return 200 or
    when the payload cannot be decoded as an image. Prints a summary of how
    many images were retrieved vs. skipped when exhausted.
    """
    num_skipped = 0
    for url in urls:
        response = requests.get(url)
        if not response.status_code == 200:
            num_skipped += 1
            # BUG FIX: previously fell through and still tried to decode the
            # error body with PIL, which double-counted the skip when
            # UnidentifiedImageError was then raised.
            continue
        try:
            img = Image.open(BytesIO(response.content))
            yield img
        except UnidentifiedImageError:
            num_skipped += 1

    print(f"Retrieved {len(urls) - num_skipped} images. Skipped {num_skipped}.")


def urls_to_image_folder(urls, save_directory):
    """Download every image in *urls* and save it under *save_directory*.

    Files are named by their position in the (successfully downloaded)
    stream: ``0.jpg``, ``1.jpg``, ...
    """
    for index, img in enumerate(gen_images_from_urls(urls)):
        target = save_directory / f'{index}.jpg'
        img.save(target)


def make_huggingpics_imagefolder(data_dir, search_terms, count=150, overwrite=False, transform=None):
    """Build (or reuse) an on-disk image dataset from web search results.

    One subdirectory per search term is created under *data_dir* and filled
    with downloaded images, then the whole tree is wrapped in a torchvision
    ``ImageFolder`` (each term becomes a class).

    Args:
        data_dir: Root directory for the dataset.
        search_terms: Iterable of search terms; one class folder each.
        count: Number of images to request per term.
        overwrite: If True, delete any existing directory and re-download.
        transform: Optional transform passed through to ``ImageFolder``.

    Returns:
        A ``torchvision.datasets.ImageFolder`` over *data_dir*.
    """
    root = Path(data_dir)

    # Reuse an existing directory unless the caller asked for a rebuild.
    if root.exists() and not overwrite:
        logger.warning(f"Using existing HuggingPics data directory: '{root}'")
        return ImageFolder(str(root), transform=transform)

    if root.exists():
        logger.warning(f"Deleting existing HuggingPics data directory to create new one: {root}")
        shutil.rmtree(root)

    for term in search_terms:
        term_dir = root / term
        term_dir.mkdir(exist_ok=True, parents=True)
        image_urls = get_image_urls_by_term(term, count)
        logger.info(f"Saving images of {term} to {str(term_dir)}...")
        urls_to_image_folder(image_urls, term_dir)

    return ImageFolder(str(root), transform=transform)


class HuggingPicsData(pl.LightningDataModule):
    """LightningDataModule that builds train/val splits from web-searched images.

    Downloads (or reuses) a HuggingPics image folder, derives label mappings
    from its class folders, and exposes train/val dataloaders with
    ViT-appropriate transforms.

    Args:
        data_dir: Root directory for the downloaded image folder.
        search_terms: Search terms; each becomes one class.
        model_name_or_path: Feature-extractor checkpoint used for image
            size and normalization statistics.
        count: Images to request per search term.
        val_split_pct: Fraction of examples held out for validation.
        batch_size: Dataloader batch size.
        num_workers: Dataloader worker processes.
        pin_memory: Whether dataloaders pin host memory.
    """

    def __init__(
        self,
        data_dir,
        search_terms,
        model_name_or_path='google/vit-base-patch16-224-in21k',
        count=150,
        val_split_pct=0.15,
        batch_size=16,
        num_workers=0,
        pin_memory=True,
    ):
        super().__init__()
        self.save_hyperparameters()
        ds = make_huggingpics_imagefolder(self.hparams.data_dir, self.hparams.search_terms, self.hparams.count)

        # Label <-> id mappings in the string-keyed format expected by
        # transformers model configs.
        classes = ds.classes
        self.num_labels = len(classes)
        self.id2label = {str(i): label for i, label in enumerate(classes)}
        self.label2id = {label: str(i) for i, label in enumerate(classes)}

        feature_extractor = ViTFeatureExtractor.from_pretrained(self.hparams.model_name_or_path)
        normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
        # Light augmentation for training; deterministic resize/crop for eval.
        self.train_transform = Compose(
            [
                RandomResizedCrop(feature_extractor.size),
                RandomHorizontalFlip(),
                ToTensor(),
                normalize,
            ]
        )
        self.val_transform = Compose(
            [
                Resize(feature_extractor.size),
                CenterCrop(feature_extractor.size),
                ToTensor(),
                normalize,
            ]
        )

        # Shuffle once, then split at an explicit boundary.
        # BUG FIX: the original used indices[:-n_val] / indices[-n_val:],
        # which yields an EMPTY training set when n_val == 0 (tiny dataset
        # or val_split_pct=0), because indices[:-0] == [].
        indices = torch.randperm(len(ds)).tolist()
        n_val = math.floor(len(indices) * self.hparams.val_split_pct)
        split = len(indices) - n_val
        self.train_ds = SubsetWithTransform(ds, indices[:split], transform=self.train_transform)
        self.val_ds = SubsetWithTransform(ds, indices[split:], transform=self.val_transform)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_ds,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_ds,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
            collate_fn=self.collate_fn,
        )

    def collate_fn(self, batch):
        # Pack (image, label) pairs into the kwargs expected by
        # transformers image models.
        imgs = torch.stack([ex[0] for ex in batch])
        labels = torch.LongTensor([ex[1] for ex in batch])
        return {'pixel_values': imgs, 'labels': labels}


class SubsetWithTransform(torch.utils.data.Subset):
    """A ``Subset`` that applies *transform* to the sample of each example.

    The underlying dataset yields ``(sample, label)`` pairs; only the sample
    is transformed, the label passes through unchanged.
    """

    def __init__(self, dataset, indices, transform):
        super().__init__(dataset, indices)
        self.transform = transform

    def __getitem__(self, idx):
        sample, target = self.dataset[self.indices[idx]]
        return self.transform(sample), target
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
torch
torchvision
pytorch-lightning
transformers
16 changes: 16 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from setuptools import find_packages, setup


def _load_requirements(path="requirements.txt"):
    """Return the dependency list declared in *path*, one entry per line."""
    with open(path, "r") as requirements_file:
        return requirements_file.read().splitlines()


# Install requirements mirror requirements.txt so pip installs stay in sync.
requirements = _load_requirements()

setup(
    name='huggingpics',
    packages=find_packages(exclude=['examples']),
    version='0.0.1',
    license='MIT',
    description='🤗🖼️ HuggingPics: Fine-tune Vision Transformers for anything using images found on the web.',
    author='Nathan Raw',
    author_email='[email protected]',
    url='https://github.com/nateraw/huggingpics',
    install_requires=requirements,
)

0 comments on commit 246a821

Please sign in to comment.