Description
Hi,
Thank you for sharing the code from your insightful paper!
I'm attempting to train the model using the end-to-end (e2e) setup, and I've run into a question about how the embeddings are handled. As I understand it, you use the TextDataset_NoCache class for the dataset, which stores the model's embedding module (model_emb).
Diffusion-LM/improved-diffusion/improved_diffusion/text_datasets.py
Lines 815 to 828 in 759889d
```python
class TextDataset_NoCache(Dataset):
    def __init__(self, text_datasets, resolution, data_args, model_arch='conv-unet',
                 classes=None, shard=0, num_shards=1, eigen_transform=None,
                 mapping_func=None, model_emb=None):
        super().__init__()
        self.resolution = resolution
        self.text_datasets = text_datasets
        self.length = len(self.text_datasets['train'])
        self.model_arch = model_arch
        self.data_args = data_args
        print(self.resolution)
        self.eigen_transform = eigen_transform
        self.mapping_func = mapping_func
        self.model_emb = model_emb
```
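For context on why this matters, here is a minimal, self-contained sketch (not the repo's code; `TinyDataset` is a hypothetical stand-in for TextDataset_NoCache) showing that if the dataset merely stores a *reference* to the model's nn.Embedding, any in-place update to that module would be visible through the dataset automatically:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for TextDataset_NoCache: it keeps a reference to
# an nn.Embedding (model_emb) rather than a copy of its weights.
class TinyDataset(torch.utils.data.Dataset):
    def __init__(self, token_ids, model_emb=None):
        self.token_ids = token_ids
        self.model_emb = model_emb  # shared reference, not a snapshot

    def __len__(self):
        return len(self.token_ids)

    def __getitem__(self, idx):
        ids = torch.tensor(self.token_ids[idx])
        with torch.no_grad():
            # Embeds with whatever the weights are *right now*.
            return self.model_emb(ids)

emb = nn.Embedding(10, 4)
ds = TinyDataset([[1, 2, 3]], model_emb=emb)
before = ds[0].clone()

# Simulate a gradient step mutating the shared embedding table in place.
with torch.no_grad():
    emb.weight += 1.0

after = ds[0]
# The dataset picks up the new weights because it aliases the same module.
assert not torch.allclose(before, after)
```

So if model_emb aliases the module the optimizer updates, no explicit "refresh" step would be needed; my question below is whether that is actually the case here.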
In the training script, you're passing model=None to the load_data_text function.
Diffusion-LM/improved-diffusion/scripts/train.py
Lines 81 to 105 in 759889d
```python
if args.experiment == 'random1':
    args.experiment = 'random'
    print('loading from the vocabs here.')
    assert args.in_channel == 64
    assert args.modality == 'roc'
    model22 = torch.nn.Embedding(args.vocab_size, args.in_channel)
    model22_weight = torch.load('predictability/diffusion_models_v7/diff_roc-aug_pad_rand64_'
                                'transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart_e2e/'
                                'ema_0.9999_200000.pt', map_location='cpu')['word_embedding.weight']
    model22.weight = model22_weight
    model22.weight.requires_grad = False
else:
    model22 = None
data = load_data_text(
    data_dir=args.data_dir,
    batch_size=args.batch_size,
    image_size=args.image_size,
    class_cond=args.class_cond,
    data_args=args,
    task_mode=args.modality,
    padding_mode=args.padding_mode,  # block, pad
    load_vocab=rev_tokenizer,
    model=model22,
)
```
I assume that the embeddings are initialized at:
However, in the e2e setup, one would presumably want the dataset to use the continuously updated embeddings from the model. Looking through the training loop, I couldn't find any point where the dataset's embeddings are refreshed from the model after a gradient step. Could you please shed light on how the embeddings can be trained end-to-end when they are housed within the dataset class?
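To make the concern concrete, here is a small illustrative sketch (my own example, not code from this repo) of the failure mode I'm worried about: if embedding vectors are computed once at dataset-construction time and cached, later updates to the model's embedding table never reach the training data:

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)
ids = torch.tensor([1, 2, 3])

# Snapshot taken when the dataset is built (cached tensors, not a reference).
with torch.no_grad():
    cached = emb(ids).clone()

# Pretend a gradient step updated the embedding table afterwards.
with torch.no_grad():
    emb.weight += 1.0

# The cached vectors are now stale relative to the live embedding module.
assert not torch.allclose(cached, emb(ids))
```

If, on the other hand, the dataset only holds a reference to the live embedding module and embeds tokens on the fly, the staleness problem disappears; I'd appreciate clarification on which of the two happens here.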
Thank you for your time and clarification!