Skip to content

generate 函数中 is_greedy_gen_mode 和 is_sample_gen_stream_mode 同时为 True #8

@AIxyz

Description

@AIxyz

pip install transformers_stream_generator==0.0.4 后调试 llama 时,发现若使用如下命令

    tokens = None
    for token in torch_model.generate(
            input_ids=input_ids,
            max_length=1024,
            num_beams=1,
            num_return_sequences=1,
            no_repeat_ngram_size=15,
            repetition_penalty=1,
            temperature=0.65,
            do_stream=True):
        if tokens is None:
            tokens = token
        else:
            tokens = torch.cat((tokens, token))  # pylint: disable=no-member
        words = tokenizer.decode(tokens)
        yield words

会使得 NewGenerationMixin.generate(……) 函数中 is_greedy_gen_mode 和 is_sample_gen_stream_mode 同时为 True,这会使得 ~/.local/lib/python3.8/site-packages/transformers_stream_generator/main.py 里直接进入 382 行的 if is_greedy_gen_mode 块中 return self.greedy_search(……),导致无法正常流式输出。

为解决该问题,将~/.local/lib/python3.8/site-packages/transformers_stream_generator/main.py 里 292 行之后各个非 stream 的 is_xxx_mode 后添加 “and generation_config.do_stream is False”,如下图所示,就可以了

image

可以在下一个版本中进行修改

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions