[BUG] Unable to Use quantization_setting for Customizing MoQ in DeepSpeed Inference #6853

Open
cyx96 opened this issue on Dec 11, 2024 · 3 comments
Labels: bug (Something isn't working), compression


cyx96 commented Dec 11, 2024

Describe the bug
Unable to customize MoQ by passing quantization_setting to deepspeed.init_inference: DeepSpeed rejects the argument with a validation error.

To Reproduce
Follow the example from the DeepSpeed inference tutorial on datatypes and quantized models.

Below is the full script to reproduce the issue:

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import deepspeed

# Load T5 model and tokenizer
model_name = "t5-small"  # You can change this to other T5 models like 't5-base' or 't5-large'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Define quantization settings
quantize_groups = 8  # Example setting; adjust as needed
mlp_extra_grouping = True  # Example setting; adjust as needed

# Initialize DeepSpeed inference with quantization
model = deepspeed.init_inference(
    model=model,
    mp_size=1,  # Model parallel size (1 if no model parallelism is used)
    quantization_setting=(quantize_groups, mlp_extra_grouping)
)

# Tokenize input text
input_text = "Translate English to French: Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt")

# Perform inference
outputs = model.generate(**inputs)
output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Print the result
print("Input:", input_text)
print("Output:", output_text)
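For context on the quantize_groups value in the script: in group-wise quantization, a weight tensor is split into several groups and each group gets its own scale, so a few large values in one group do not wipe out the precision of small values in another. The sketch below is a generic, stdlib-only illustration of that idea (symmetric int8, contiguous groups). It is not DeepSpeed's MoQ kernel, the function names are hypothetical, and it does not model mlp_extra_grouping.

```python
# Generic group-wise symmetric quantization sketch (NOT DeepSpeed's MoQ code).
# Assumption: "groups" means contiguous, equally sized chunks, each with its own scale.
from typing import List, Tuple


def quantize_groupwise(weights: List[float], groups: int, bits: int = 8) -> Tuple[List[int], List[float]]:
    """Quantize each chunk of `weights` to signed `bits`-bit ints with a per-chunk scale."""
    if groups <= 0 or len(weights) % groups != 0:
        raise ValueError("len(weights) must be a positive multiple of groups")
    qmax = 2 ** (bits - 1) - 1  # 127 for 8 bits
    size = len(weights) // groups
    q: List[int] = []
    scales: List[float] = []
    for g in range(groups):
        chunk = weights[g * size:(g + 1) * size]
        max_abs = max(abs(w) for w in chunk)
        scale = max_abs / qmax if max_abs > 0 else 1.0  # avoid dividing by zero
        scales.append(scale)
        # Round to the nearest level and clamp to the symmetric range [-qmax, qmax].
        q.extend(max(-qmax, min(qmax, round(w / scale))) for w in chunk)
    return q, scales


def dequantize_groupwise(q: List[int], scales: List[float]) -> List[float]:
    """Map integers back to floats using the scale of the group each value belongs to."""
    size = len(q) // len(scales)
    return [v * scales[i // size] for i, v in enumerate(q)]


if __name__ == "__main__":
    # Small weights in the first half, large weights in the second half.
    w = [0.01, -0.02, 0.03, 0.015, 5.0, -4.0, 3.0, 2.5]
    for g in (1, 2):
        q, s = quantize_groupwise(w, g)
        err = max(abs(a - b) for a, b in zip(w[:4], dequantize_groupwise(q, s)[:4]))
        print(f"groups={g}: max error on the small weights = {err:.5f}")
```

With a single group, the shared scale is set by the 5.0 outlier, so the four small weights collapse to 0 or ±1; with two groups, the small half gets its own, much finer scale.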

Expected behavior
The script should translate the English input into French with the T5 model. Instead, init_inference raises this error:

pydantic_core._pydantic_core.ValidationError: 1 validation error for DeepSpeedInferenceConfig
quantization_setting
  Extra inputs are not permitted [type=extra_forbidden, input_value=(8, True), input_type=tuple]
    For further information visit https://errors.pydantic.dev/2.10/v/extra_forbidden
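The full traceback under Screenshots shows where this comes from: init_inference forwards its keyword arguments to DeepSpeedInferenceConfig(**config_dict), a pydantic model, and the extra_forbidden error type means that model refuses any key it does not declare as a field. Because quantization_setting is not a declared field in this version, validation fails before any quantization code runs. A stdlib-only sketch of that forbid-extra behavior follows; the class name and its small set of fields are hypothetical, and it only imitates pydantic's extra="forbid" rather than reproducing DeepSpeed code.

```python
# Stdlib-only imitation of a config model that forbids undeclared fields,
# similar in spirit to pydantic's extra="forbid". This is NOT DeepSpeed code.
from dataclasses import dataclass, fields
from typing import Any


class ExtraForbiddenError(ValueError):
    """Raised when a keyword argument does not match any declared config field."""


@dataclass
class InferenceConfigSketch:
    # Hypothetical subset of fields for illustration only.
    mp_size: int = 1
    dtype: str = "float16"

    @classmethod
    def from_kwargs(cls, **kwargs: Any) -> "InferenceConfigSketch":
        allowed = {f.name for f in fields(cls)}
        extra = sorted(set(kwargs) - allowed)
        if extra:
            # Reject unknown keys up front instead of silently ignoring them.
            raise ExtraForbiddenError(f"Extra inputs are not permitted: {extra}")
        return cls(**kwargs)
```

InferenceConfigSketch.from_kwargs(mp_size=1) succeeds, while adding quantization_setting=(8, True) raises, which mirrors the report: the argument is rejected at config-validation time, not ignored.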

ds_report output

[2024-12-11 10:01:03,448] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
 [WARNING]  FP Quantizer is using an untested triton version (3.1.0), only 2.3.(0, 1) and 3.0.0 are known to be compatible with these kernels
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
/opt/conda/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlvsym'
/opt/conda/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlopen'
/opt/conda/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlclose'
/opt/conda/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlerror'
/opt/conda/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlsym'
collect2: error: ld returned 1 exit status
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.5
 [WARNING]  using untested triton version (3.1.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/opt/conda/lib/python3.10/site-packages/torch']
torch version .................... 2.5.1+cu124
deepspeed install path ........... ['/opt/conda/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.16.1, unknown, unknown
torch cuda version ............... 12.4
torch hip version ................ None
nvcc version ..................... 12.4
deepspeed wheel compiled w. ...... torch 2.5, cuda 12.4
shared memory (/dev/shm) size .... 188.94 GB

Screenshots
Full terminal output from running the script above (saved as test_moq.py) on my machine:

python3 test_moq.py
[2024-12-11 09:56:28,176] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
[2024-12-11 09:56:30,285] [INFO] [logging.py:128:log_dist] [Rank -1] DeepSpeed info: version=0.16.1, git-hash=unknown, git-branch=unknown
Traceback (most recent call last):
  File "/home/chenyuxu/platform/ml/hhemv2/experiments/precision_test/ds_quant_test/test_moq.py", line 22, in <module>
    model = deepspeed.init_inference(
  File "/opt/conda/lib/python3.10/site-packages/deepspeed/__init__.py", line 362, in init_inference
    ds_inference_config = DeepSpeedInferenceConfig(**config_dict)
  File "/opt/conda/lib/python3.10/site-packages/deepspeed/runtime/config_utils.py", line 57, in __init__
    super().__init__(**data)
  File "/opt/conda/lib/python3.10/site-packages/pydantic/main.py", line 214, in __init__
    validated_self = self.__pydantic_validator__.validate_python(data, self_instance=self)
pydantic_core._pydantic_core.ValidationError: 1 validation error for DeepSpeedInferenceConfig
quantization_setting
  Extra inputs are not permitted [type=extra_forbidden, input_value=(8, True), input_type=tuple]
    For further information visit https://errors.pydantic.dev/2.10/v/extra_forbidden

System info (please complete the following information):

  • OS: Debian GNU/Linux 11 (bullseye)
  • GPU count and types: 8x L4
  • Interconnects (if applicable): Just one machine
  • Python version: 3.10.15
  • Any other relevant info about your setup: Nothing else for now
cyx96 added the bug and compression labels on Dec 11, 2024
jomayeri self-assigned this on Dec 12, 2024
tjruwase (Contributor) commented:

@sfc-gh-reyazda, any thoughts on this? Thanks

rlanday commented Dec 19, 2024

This code right here (and the method it calls) is the only place quantization_setting is referenced in code, right?

quantization_setting = None
self._init_quantization_setting(
    quantization_setting)  # todo: update with the new quant config for weight quant

It appears that this parameter is not currently implemented (as of b5d18a6)
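To illustrate the point above: in the quoted engine code the value handed to _init_quantization_setting is always None, whatever the user supplied. A hypothetical sketch of that control flow follows; the function names, the default values (1 group, no MLP extra grouping), and the tuple order (taken from the reproduction script) are assumptions, not DeepSpeed's actual implementation.

```python
# Hypothetical sketch of the control flow in the quoted engine snippet.
# Names, defaults, and tuple order are illustrative assumptions, not DeepSpeed's code.
from typing import Dict, Optional, Tuple, Union

QuantSetting = Optional[Union[int, Tuple[int, bool]]]


def init_quantization_setting(setting: QuantSetting) -> Dict[str, object]:
    """Turn a quantization_setting value into concrete parameters."""
    params: Dict[str, object] = {"quantize_groups": 1, "mlp_extra_grouping": False}  # assumed defaults
    if isinstance(setting, tuple):
        # Order follows the reproduction script: (quantize_groups, mlp_extra_grouping).
        params["quantize_groups"], params["mlp_extra_grouping"] = setting
    elif setting is not None:
        params["quantize_groups"] = setting
    return params


def engine_init(user_setting: QuantSetting) -> Dict[str, object]:
    """Mimics the quoted snippet: the caller's value is never consulted."""
    del user_setting  # ignored, as in the snippet
    quantization_setting = None
    return init_quantization_setting(quantization_setting)
```

Here engine_init((8, True)) returns the defaults, whereas init_quantization_setting((8, True)) called directly would honor the request; that gap is what "not currently implemented" amounts to.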

sfc-gh-reyazda (Contributor) commented:

Hi @rlanday @cyx96

Thanks for reporting this issue. This part was modified during the revision of the inference system. Let me take a look and get back to you.
Thanks.
Reza
