Skip to content

Commit 343f8b0

Browse files
committed
fix version problem if researcher did not install from pip
1 parent d03df6a commit 343f8b0

File tree

5 files changed

+8
-7
lines changed

5 files changed

+8
-7
lines changed

dalle_pytorch/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
33

44
from pkg_resources import get_distribution
5-
__version__ = get_distribution('dalle_pytorch').version
5+
from dalle_pytorch.version import __version__

dalle_pytorch/version.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = '1.6.3'

generate.py

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
# dalle related classes and utils
1717

18+
from dalle_pytorch import __version__
1819
from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE
1920
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer
2021

setup.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
from setuptools import setup, find_packages
2+
exec(open('dalle_pytorch/version.py').read())
23

34
setup(
45
name = 'dalle-pytorch',
56
packages = find_packages(),
67
include_package_data = True,
7-
version = '1.6.2',
8+
version = __version__,
89
license='MIT',
910
description = 'DALL-E - Pytorch',
1011
author = 'Phil Wang',
1112
author_email = '[email protected]',
13+
long_description_content_type = 'text/markdown',
1214
url = 'https://github.com/lucidrains/dalle-pytorch',
1315
keywords = [
1416
'artificial intelligence',

train_dalle.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from torch.optim.lr_scheduler import ReduceLROnPlateau
1313
from torch.utils.data import DataLoader
1414

15+
from dalle_pytorch import __version__
1516
from dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE, DiscreteVAE, DALLE
1617
from dalle_pytorch import distributed_utils
1718
from dalle_pytorch.loader import TextImageDataset
@@ -147,10 +148,6 @@ def exists(val):
147148
def get_trainable_params(model):
148149
return [params for params in model.parameters() if params.requires_grad]
149150

150-
def get_pkg_version():
151-
from pkg_resources import get_distribution
152-
return get_distribution('dalle_pytorch').version
153-
154151
def cp_path_to_dir(cp_path, tag):
155152
"""Convert a checkpoint path to a directory with `tag` inserted.
156153
If `cp_path` is already a directory, return it unchanged.
@@ -540,7 +537,7 @@ def save_model(path, epoch=0):
540537
'hparams': dalle_params,
541538
'vae_params': vae_params,
542539
'epoch': epoch,
543-
'version': get_pkg_version(),
540+
'version': __version__,
544541
'vae_class_name': vae.__class__.__name__
545542
}
546543

0 commit comments

Comments
 (0)