Implement .generate (greedy decoding only) #217

Open · 4 tasks done
tscholak opened this issue Apr 1, 2025 · 6 comments · May be fixed by #222
Labels: enhancement (New feature or request), need update

tscholak (Collaborator) commented Apr 1, 2025

🎯 Goal (What & Why)

Implement generate() for HuggingfaceGPTModelForCausalLM using GenerationMixin, supporting greedy decoding only.
This makes the Fast-LLM model behave like a HuggingFace model for generation. The goal is to enable validation-time text generation directly from the sharded Fast-LLM model in memory, without converting to HF format (which would require extra memory and lead to model eviction).
We use batched greedy decoding and support FlashAttention by padding and attention masking.
No beam search, sampling, or KV caching is needed.

🚀 Execution Plan

Develop a minimal, batched, greedy generation loop using Fast-LLM's .forward() and GenerationMixin integration.

Step 1: What is the smallest working version?

  • Implement the GenerationMixin interface in HuggingfaceGPTModelForCausalLM, i.e. this interface (imports added for context):
    from typing import Optional, Union

    import torch
    from transformers import GenerationConfig
    from transformers.generation.utils import GenerateOutput

    class GenerationMixin:
        def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor,
            attention_mask: Optional[torch.LongTensor] = None,
            **kwargs,
        ):
            # Assemble the model inputs (ids, attention mask, etc.) for the next forward call.
            ...

        @torch.no_grad()
        def generate(
            self,
            inputs: Optional[torch.Tensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            **kwargs,
        ) -> Union[GenerateOutput, torch.LongTensor]:
            # Autoregressive decoding loop; greedy decoding only for this issue.
            ...
  • Implement generate() to (see the sketch after this list):
    • Accept inputs, plus max_new_tokens, eos_token_id, and pad_token_id via GenerationConfig or kwargs.
    • Support batched greedy decoding only (no sampling, no beam search).
    • Left-pad input sequences for FlashAttention.
    • Use an attention mask to exclude padding tokens.
    • Track completion using eos_token_id and max_new_tokens.
    • Reuse .forward() on the Fast-LLM model at each step.
  • Implement prepare_inputs_for_generation() minimally to satisfy HF's expectations.
  • Use the tokenizer to:
    • Tokenize prompts to input_ids.
    • Detokenize generated tokens to strings.
  • Add a slow integration test that:
    • Loads HuggingFaceTB/SmolLM2-135M-Instruct in both HF and Fast-LLM formats.
    • Runs .generate() on both and compares outputs.
    • Uses a fixed prompt, a fixed seed, and no sampling to make it deterministic.
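
A minimal sketch of what the batched greedy loop could look like. Everything here is illustrative: greedy_generate, the tokenizer handling, and the HF-style model(...).logits call are assumptions rather than the actual Fast-LLM API, and no KV cache is used, matching the scope above.

    import torch

    @torch.no_grad()
    def greedy_generate(model, tokenizer, prompts, max_new_tokens=32):
        # Left-pad so the newest token of every sequence sits in the last column,
        # which keeps the layout FlashAttention-friendly.
        tokenizer.padding_side = "left"
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        batch = tokenizer(prompts, return_tensors="pt", padding=True)
        input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]
        eos_id, pad_id = tokenizer.eos_token_id, tokenizer.pad_token_id
        finished = torch.zeros(input_ids.shape[0], dtype=torch.bool)

        for _ in range(max_new_tokens):
            # Reuse the model's regular forward at every step (no KV cache).
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            next_tokens = logits[:, -1, :].argmax(dim=-1)  # greedy pick
            still_going = ~finished
            # Sequences that already hit EOS keep emitting pad_id.
            next_tokens = torch.where(still_going, next_tokens, torch.full_like(next_tokens, pad_id))
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            # Mask out the padding tokens appended to finished sequences.
            attention_mask = torch.cat([attention_mask, still_going.long()[:, None]], dim=-1)
            finished |= next_tokens == eos_id
            if finished.all():
                break

        return tokenizer.batch_decode(input_ids, skip_special_tokens=True)

On a standard HF causal LM, this loop should match model.generate(..., do_sample=False) up to padding handling, which is what the integration test below checks.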

Step 2: What additional optimizations are possible (but optional)?

  • Add past key/value caching for faster autoregressive generation.
  • Fuse decode-time batching for speed-up (e.g., with CUDA graphs).

📌 Acceptance Criteria (Must-Haves for Completion)

  • Must implement the required methods from HF's GenerationMixin.
  • Must support batched greedy decoding with:
    • FlashAttention-compatible padding and attention masking.
    • EOS handling and padding with pad_token_id.
  • Must include:
    • Integration test comparing Fast-LLM to HuggingFace output using a shared checkpoint (a rough sketch follows this list).
    • Benchmark test comparing generation speed (Fast-LLM vs HF w/o cache) on one GPU/shard.
  • Implementation must be documented:
    • Explain how padding and masks are used.
    • List all unsupported features and error behavior.
  • Must not refactor unrelated code.
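
A rough shape for that comparison test (a sketch only: the Fast-LLM wrapper's import path and from_pretrained() loading are assumptions, and the test should be marked slow or moved behind an integration marker, per the discussion below):

    import pytest
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    CHECKPOINT = "HuggingFaceTB/SmolLM2-135M-Instruct"

    @pytest.mark.slow
    def test_greedy_generate_matches_hf():
        tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
        tokenizer.padding_side = "left"
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        prompts = ["The capital of France is", "1 + 1 ="]
        batch = tokenizer(prompts, return_tensors="pt", padding=True)

        hf_model = AutoModelForCausalLM.from_pretrained(CHECKPOINT)
        hf_out = hf_model.generate(**batch, max_new_tokens=16, do_sample=False,
                                   pad_token_id=tokenizer.pad_token_id)

        # Hypothetical import path; the real wrapper lives in Fast-LLM.
        from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM
        fl_model = HuggingfaceGPTModelForCausalLM.from_pretrained(CHECKPOINT)
        fl_out = fl_model.generate(**batch, max_new_tokens=16, do_sample=False,
                                   pad_token_id=tokenizer.pad_token_id)

        # Greedy decoding on a fixed prompt is deterministic, so token ids should match exactly.
        assert torch.equal(hf_out, fl_out)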

📎 Relevant Links

🛠️ Project Management

  • Assign the project to the Fast-LLM project.
  • Set the Estimate field (in days) in the GitHub project.
  • Use the Size field to categorize the PR size (Small/Medium/Large).
  • Assign an owner when opening the issue.
tscholak added the enhancement (New feature or request) label on Apr 1, 2025
tscholak changed the title from "Implement .generate" to "Implement .generate (greedy decoding only)" on Apr 1, 2025
jlamypoirier (Collaborator) commented:

  • There is a very good chance that we don't need to touch generate at all, since the bulk of the work is in prepare_inputs_for_generation.
  • Using HuggingFaceTB/SmolLM2-135M-Instruct would be too slow and resource-hungry for our current test suite. A more manageable (and equally useful) alternative would be to use the saved and converted models from test_checkpoint.

tscholak (Collaborator, Author) commented Apr 3, 2025

Just a quick update here: @bigximik already confirmed that the stock generate() works with our model when attention masks are handled properly. We're unblocked and on track. The remaining work is mostly around batch support and FlashAttention-compatible padding, as scoped.

The goal was never to reimplement generate() wholesale, just to support greedy decoding using the minimal hooks required, and that's exactly what's happening. The issue text already reflects that intent, even if the title could be read literally.

We're using HuggingFaceTB/SmolLM2-135M-Instruct because it gives consistent and meaningful outputs for prompt-response comparisons in our test-driven development. We're aware of its memory footprint and will adjust if needed.

Thanks for checking in.

jlamypoirier (Collaborator) commented:

To clarify, I'm not against this kind of large-scale integration test, but adding such tests to our current testing workflow would severely hurt our ability to run it. I run the full tests on a daily basis, including the ones marked as "slow"; it's already 10 minutes long, and anything more will prevent me from debugging efficiently.

So if we do want long and resource-hungry integration tests, we need to exclude them from the normal testing scheme (e.g., with a separate "integration" marker) and find an alternate workflow for running them (automated CI?). That would obviously leave gaps in the "normal" testing suite, which could be filled by a small and fast version of the test (we can, and should, have both).
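
For reference, one common way to wire that up (a sketch, assuming pytest; the marker name follows the "integration" suggestion above, and the --run-integration option name is illustrative):

    # conftest.py
    import pytest

    def pytest_configure(config):
        # Register the marker so pytest does not warn about it.
        config.addinivalue_line("markers", "integration: long, resource-hungry integration tests")

    def pytest_addoption(parser):
        parser.addoption("--run-integration", action="store_true", default=False,
                         help="run tests marked as integration")

    def pytest_collection_modifyitems(config, items):
        # Skip integration tests unless explicitly requested, e.g. by a scheduled CI job.
        if config.getoption("--run-integration"):
            return
        skip = pytest.mark.skip(reason="needs --run-integration")
        for item in items:
            if "integration" in item.keywords:
                item.add_marker(skip)

The normal local run then stays fast by default, while a scheduled CI job can pass --run-integration to cover the long tests.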

bigximik (Contributor) commented:

This is now being implemented together with #199 in PR #222.

bigximik (Contributor) commented:

How should we implement distributed generate?

  • Hugging Face (HF) generate and forward assume there is a data shard per process. If a model is further split across multiple GPUs (e.g., tensor or pipeline parallelism), there is still a single data-parallel process accessing it.

    • For example, if we have 8 GPUs and the model is split across 4 GPUs, there will be 2 data-parallel processes, each causing forward to run on one model replica across 4 GPUs.
    • For inference, those model replicas are completely independent and do not depend on the global batch (e.g., they do not process sequences of microbatches together to accumulate).
  • However, Fast-LLM is implemented differently:

    • We have one process per GPU.
    • We expect processes to be aware of their batch_data_parallel rank.
    • forward will either return a shard of output tensors via kwargs['logits'], or nothing, depending on the pipeline_parallel rank.

So how do we want to implement forward and eventually generate in this setting?

Variant (a), which we discussed last week with @tscholak:

  • Treat the model as one block regardless of distribution.
  • Run something like lm_eval (and generate) only on rank 0.
  • The model on rank 0 receives the batch, splits it into microbatches, feeds them into all model replicas, and collects results back on rank 0.

Variant (b)

  • Do it per model replica.
  • lm_eval and generate will operate once per model replica.
  • Each forward distributes data among the processes corresponding to one replica and collects logits back to the local rank 0.

Another dimension to consider:

  • Whether logits are collected on rank 0 of each model replica (in variant (b)).
  • Or whether logits are collected across all model replicas (in variant (a)).
  • Or whether each process returns the full result independently (globally or per replica, depending on approach (a) or (b)).

@tscholak @jlamypoirier What do you think?

Thanks!

tscholak (Collaborator, Author) commented:

Thanks for the detailed write-up, @bigximik.
Variant (a) isn’t what we had in mind, and (b) would require patching lm_eval, which we should avoid because:

  • Version lock-in. Any local patch ties us to a specific commit; every upstream update means re-patching.
  • Maintenance drag. Patched deps break reproducible builds and add CI noise.
  • Repro friction. Contributors can't run stock lm_eval; bug reports get messy fast.

Here's a plan I would endorse:

Variant (c)

  1. Run lm_eval only on global rank 0. It owns the evaluation loop; no upstream code changes.
  2. Rank 0 broadcasts each prompt batch to all data-parallel ranks. The wrapper lives in the lm_eval model template and is called once per decoding step inside generate_until (not per token).
  3. Every rank runs Fast-LLM’s regular forward; tensor/pipeline parallelism is already baked in.
  4. Gather logits back to rank 0 (single gather, not all_gather) and perform the greedy-decoding step there.
  5. lm_eval gets the texts, updates metrics, repeats.

This keeps HF generate pristine, leaves lm_eval unmodified, and uses every GPU.
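
A very rough communication sketch of variant (c) (assumptions, not the real Fast-LLM API: torch.distributed is already initialized, every rank produces a logits shard of identical shape, and shard_forward() stands in for Fast-LLM's forward on this rank's slice of the batch):

    import torch
    import torch.distributed as dist

    def variant_c_decode_step(full_batch, shard_forward):
        rank, world = dist.get_rank(), dist.get_world_size()

        # 1. Global rank 0, which owns the lm_eval loop, broadcasts the prompt batch.
        payload = [full_batch] if rank == 0 else [None]
        dist.broadcast_object_list(payload, src=0)
        full_batch = payload[0]

        # 2. Every rank runs the regular forward on its shard; tensor/pipeline
        #    parallelism stays inside Fast-LLM.
        shard_logits = shard_forward(full_batch, rank=rank, world_size=world)

        # 3. Gather the logits shards on rank 0 only (gather, not all_gather).
        shards = [torch.empty_like(shard_logits) for _ in range(world)] if rank == 0 else None
        dist.gather(shard_logits, gather_list=shards, dst=0)

        # 4. Rank 0 reassembles the batch and performs the greedy-decoding step.
        if rank == 0:
            logits = torch.cat(shards, dim=0)
            return logits[:, -1, :].argmax(dim=-1)
        return None

In practice the gather is more subtle (under pipeline parallelism only the last stage holds logits, and shard shapes can differ), but the broadcast/forward/gather/greedy flow is the part this plan relies on.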
