tests/test_gpu_examples.py: 86 additions, 0 deletions
@@ -4638,6 +4638,7 @@ def test_boft_half_conv(self):

class TestPTuningReproducibility:
device = infer_device()
causal_lm_model_id = "facebook/opt-125m"

@require_non_cpu
@require_deterministic_for_xpu
@@ -4674,6 +4675,48 @@ def test_p_tuning_exactly_reproducible_after_loading(self, tmp_path):
torch.testing.assert_close(output_loaded, output_peft)
torch.testing.assert_close(gen_loaded, gen_peft)

@require_bitsandbytes
@pytest.mark.single_gpu_tests
def test_p_tuning_causal_lm_training_8bit_bnb(self):
# test is analogous to PeftBnbGPUExampleTests.test_causal_lm_training
with tempfile.TemporaryDirectory() as tmp_dir:
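# load the base model in 8-bit via bitsandbytes so the prompt encoder is trained on top of a quantized backbone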
model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
config = PromptEncoderConfig(task_type="CAUSAL_LM", num_virtual_tokens=20, encoder_hidden_size=128)
model = get_peft_model(model, config)

data = load_dataset_english_quotes()
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)

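# run a very short training loop (3 steps) just to verify that training on the quantized model works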
trainer = Trainer(
model=model,
train_dataset=data["train"],
args=TrainingArguments(
per_device_train_batch_size=4,
warmup_steps=2,
max_steps=3,
learning_rate=2e-4,
fp16=True,
logging_steps=1,
output_dir=tmp_dir,
),
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()

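# move to CPU and save; the checkpoint should contain the adapter config and safetensors weights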
model.cpu().save_pretrained(tmp_dir)

assert "adapter_config.json" in os.listdir(tmp_dir)
assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)

# assert loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None


@pytest.mark.single_gpu_tests
class TestLowCpuMemUsageDifferentDevices:
@@ -4947,6 +4990,7 @@ def test_alora_forward_consistency(self, model, model_bnb, peft_config):
@pytest.mark.multi_gpu_tests
class TestPrefixTuning:
device = infer_device()
causal_lm_model_id = "facebook/opt-125m"

@require_torch_multi_accelerator
def test_prefix_tuning_multiple_devices_decoder_model(self):
@@ -5007,6 +5051,48 @@ def test_prefix_tuning_multiple_devices_encoder_decoder_model(self):
model = get_peft_model(model, peft_config)
model.generate(**inputs) # does not raise

@require_bitsandbytes
@pytest.mark.single_gpu_tests
def test_prefix_tuning_causal_lm_training_8bit_bnb(self):
# test is analogous to PeftBnbGPUExampleTests.test_causal_lm_training
with tempfile.TemporaryDirectory() as tmp_dir:
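# as above, load the base model quantized to 8-bit with bitsandbytes before attaching the prefix tuning adapter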
model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
config = PrefixTuningConfig(num_virtual_tokens=10, task_type=TaskType.CAUSAL_LM)
model = get_peft_model(model, config)

data = load_dataset_english_quotes()
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)

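# again only a few optimizer steps, enough to confirm that training runs and logs a loss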
trainer = Trainer(
model=model,
train_dataset=data["train"],
args=TrainingArguments(
per_device_train_batch_size=4,
warmup_steps=2,
max_steps=3,
learning_rate=2e-4,
fp16=True,
logging_steps=1,
output_dir=tmp_dir,
),
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()

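# saving should write the adapter config and safetensors weights to the temporary directory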
model.cpu().save_pretrained(tmp_dir)

assert "adapter_config.json" in os.listdir(tmp_dir)
assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)

# assert loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None


@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="test requires a GPU or XPU")
@pytest.mark.single_gpu_tests