
Add Remote LLM Support for Perturbation-Based Attribution via RemoteLLMAttribution and VLLMProvider #1544


Open · wants to merge 2 commits into master
5 changes: 5 additions & 0 deletions captum/attr/__init__.py
@@ -27,7 +27,9 @@
LLMAttribution,
LLMAttributionResult,
LLMGradientAttribution,
RemoteLLMAttribution,
)
from captum.attr._core.remote_provider import RemoteLLMProvider, VLLMProvider
from captum.attr._core.lrp import LRP
from captum.attr._core.neuron.neuron_conductance import NeuronConductance
from captum.attr._core.neuron.neuron_deep_lift import NeuronDeepLift, NeuronDeepLiftShap
@@ -111,6 +113,9 @@
"LLMAttribution",
"LLMAttributionResult",
"LLMGradientAttribution",
"RemoteLLMAttribution",
"RemoteLLMProvider",
"VLLMProvider",
"InternalInfluence",
"InterpretableInput",
"LayerGradCam",
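
With these exports in place, the new surface should be importable straight from the package root (a quick sanity check, assuming the diff applies cleanly):

from captum.attr import RemoteLLMAttribution, RemoteLLMProvider, VLLMProvider
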
120 changes: 120 additions & 0 deletions captum/attr/_core/llm_attr.py
@@ -35,6 +35,7 @@
TextTokenInput,
)
from torch import nn, Tensor
from captum.attr._core.remote_provider import RemoteLLMProvider

DEFAULT_GEN_ARGS: Dict[str, Any] = {
"max_new_tokens": 25,
@@ -892,3 +893,122 @@ def forward(

# the attribution target is limited to the log probability
return token_log_probs


class RemoteLLMAttribution(LLMAttribution):
"""
Attribution class for large language models that are hosted remotely and offer logprob APIs.
"""
def __init__(
self,
attr_method: PerturbationAttribution,
tokenizer: TokenizerLike,
provider: RemoteLLMProvider,
attr_target: str = "log_prob",
) -> None:
"""
Args:
attr_method: Instance of a supported perturbation attribution class
tokenizer (Tokenizer): tokenizer of the LLM used in the attr_method
provider: Remote LLM provider that implements the RemoteLLMProvider protocol
attr_target: attribute towards log probability or probability.
Available values ["log_prob", "prob"]
Default: "log_prob"
"""
super().__init__(
attr_method=attr_method,
tokenizer=tokenizer,
attr_target=attr_target,
)

self.provider = provider
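# route every forward call through the remote provider's logprob API instead of a local model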
self.attr_method.forward_func = self._remote_forward_func

def _get_target_tokens(
self,
inp: InterpretableInput,
target: Union[str, torch.Tensor, None] = None,
skip_tokens: Union[List[int], List[str], None] = None,
gen_args: Optional[Dict[str, Any]] = None
) -> Tensor:
"""
Get the target tokens for the remote LLM provider.
"""
assert isinstance(
inp, self.SUPPORTED_INPUTS
), f"RemoteLLMAttribution does not support input type {type(inp)}"

if target is None:
# generate when None with remote provider
assert hasattr(self.provider, "generate") and callable(self.provider.generate), (
"The provider does not have generate function for generating target sequence."
"Target must be given for attribution"
)
if not gen_args:
gen_args = DEFAULT_GEN_ARGS

model_inp = self._format_model_input(inp.to_model_input())
target_str = self.provider.generate(model_inp, **gen_args)
target_tokens = self.tokenizer.encode(target_str, return_tensors="pt", add_special_tokens=False)[0]

else:
target_tokens = super()._get_target_tokens(inp, target, skip_tokens, gen_args)

return target_tokens

def _format_model_input(self, model_input: Union[str, Tensor]) -> str:
"""
Format the model input for the remote LLM provider.
"""
# return str input
if isinstance(model_input, Tensor):
return self.tokenizer.decode(model_input.flatten())
return model_input

def _remote_forward_func(
self,
perturbed_tensor: Union[None, Tensor],
inp: InterpretableInput,
target_tokens: Tensor,
use_cached_outputs: bool = False,
_inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
) -> Tensor:
"""
Forward function for the remote LLM provider.

Raises:
ValueError: If the number of token logprobs doesn't match expected length
"""
perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))

target_str: str = self.tokenizer.decode(target_tokens)

target_token_probs = self.provider.get_logprobs(input_prompt=perturbed_input, target_str=target_str, tokenizer=self.tokenizer)

if len(target_token_probs) != target_tokens.size()[0]:
raise ValueError(
f"Number of token logprobs from provider ({len(target_token_probs)}) "
f"does not match expected target token length ({target_tokens.size()[0]})"
)

log_prob_list: List[Tensor] = list(map(torch.tensor, target_token_probs))

total_log_prob = torch.sum(torch.stack(log_prob_list), dim=0)
# 1st element is the total log prob, the rest are the per-token log probs
# add a leading batch dim even though we only support a single instance for now
if self.include_per_token_attr:
target_log_probs = torch.stack(
[total_log_prob, *log_prob_list], dim=0
).unsqueeze(0)
else:
target_log_probs = total_log_prob
target_probs = torch.exp(target_log_probs)

if _inspect_forward:
prompt = perturbed_input
response = self.tokenizer.decode(target_tokens)

# callback for externals to inspect (prompt, response, seq_prob)
_inspect_forward(prompt, response, target_probs[0].tolist())

return target_probs if self.attr_target != "log_prob" else target_log_probs
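
For reviewers, a minimal end-to-end sketch of how the new pieces are meant to be wired together (not part of this diff; the endpoint URL, the Hugging Face tokenizer, and the prompt are placeholders, and it assumes an OpenAI-compatible vLLM server is already running):

import torch
from torch import nn
from transformers import AutoTokenizer

from captum.attr import FeatureAblation, RemoteLLMAttribution, TextTokenInput, VLLMProvider

# dummy stand-in for a local model: RemoteLLMAttribution routes every forward
# call to the remote provider, so this module is never actually executed
placeholder_model = nn.Module()
placeholder_model.device = torch.device("cpu")

provider = VLLMProvider(api_url="http://localhost:8000/v1")  # assumed local vLLM endpoint
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # should match the served model in practice

remote_attr = RemoteLLMAttribution(
    attr_method=FeatureAblation(placeholder_model),
    tokenizer=tokenizer,
    provider=provider,
)

inp = TextTokenInput("The capital of France is", tokenizer)
result = remote_attr.attribute(inp, target="Paris")
print(result.seq_attr)

The same pattern should work with the other perturbation methods LLMAttribution already supports (e.g. ShapleyValueSampling), since only forward_func is swapped out.
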
191 changes: 191 additions & 0 deletions captum/attr/_core/remote_provider.py
@@ -0,0 +1,191 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from captum._utils.typing import TokenizerLike
from openai import OpenAI
import os

class RemoteLLMProvider(ABC):
"""All remote LLM providers that offer logprob via API (like vLLM) extends this class."""

api_url: str

@abstractmethod
def generate(
self,
prompt: str,
**gen_args: Any
) -> str:
"""
Args:
prompt: The input prompt to generate from
gen_args: Additional generation arguments

Returns:
The generated text.
"""
...

@abstractmethod
def get_logprobs(
self,
input_prompt: str,
target_str: str,
tokenizer: Optional[TokenizerLike] = None
) -> List[float]:
"""
Get the log probabilities for all tokens in the target string.

Args:
input_prompt: The input prompt
target_str: The target string
tokenizer: The tokenizer to use

Returns:
A list of log probabilities, one per token of the target string.
For a `target_str` of `t` tokens, this method returns a list of logprobs of length `t`.
"""
...

class VLLMProvider(RemoteLLMProvider):
def __init__(self, api_url: str, model_name: Optional[str] = None):
"""
Initialize a vLLM provider.

Args:
api_url: The URL of the vLLM API
model_name: The name of the model to use. If None, the first model from
the API's model list will be used.

Raises:
ValueError: If api_url is empty or model_name is not in the API's model list
ConnectionError: If API connection fails
"""
if not api_url.strip():
raise ValueError("API URL is required")

self.api_url = api_url

try:
self.client = OpenAI(base_url=self.api_url,
api_key=os.getenv("OPENAI_API_KEY", "EMPTY")
)

# If model_name is not provided, get the first available model from the API
if model_name is None:
models = self.client.models.list().data
if not models:
raise ValueError("No models available from the vLLM API")
self.model_name = models[0].id
else:
self.model_name = model_name

except ConnectionError as e:
raise ConnectionError(f"Failed to connect to vLLM API: {str(e)}")
except Exception as e:
raise Exception(f"Unexpected error while initializing vLLM provider: {str(e)}")

def generate(self, prompt: str, **gen_args: Any) -> str:
"""
Generate text using the vLLM API.

Args:
prompt: The input prompt for text generation
**gen_args: Additional generation arguments

Returns:
str: The generated text

Raises:
KeyError: If API response is missing expected data
ConnectionError: If connection to API fails
"""
# Normalize HuggingFace-style generation args to OpenAI-compatible ones
# (map max_new_tokens to max_tokens, drop do_sample)
if 'max_tokens' not in gen_args:
gen_args['max_tokens'] = gen_args.pop('max_new_tokens', 25)
if 'do_sample' in gen_args:
gen_args.pop('do_sample')

try:
response = self.client.completions.create(
model=self.model_name,
prompt=prompt,
**gen_args
)
if not hasattr(response, 'choices') or not response.choices:
raise KeyError("API response missing expected 'choices' data")

return response.choices[0].text

except ConnectionError as e:
raise ConnectionError(f"Failed to connect to vLLM API: {str(e)}")
except Exception as e:
raise Exception(f"Unexpected error during text generation: {str(e)}")

def get_logprobs(
self,
input_prompt: str,
target_str: str,
tokenizer: Optional[TokenizerLike] = None
) -> List[float]:
"""
Get the log probabilities for all tokens in the target string.

Args:
input_prompt: The input prompt
target_str: The target string
tokenizer: The tokenizer to use

Returns:
A list of log probabilities, one per token of the target string.
For a `target_str` of `t` tokens, this method returns a list of logprobs of length `t`.

Raises:
ValueError: If tokenizer is None or target_str is empty or response format is invalid
KeyError: If API response is missing expected data
IndexError: If response format is unexpected
ConnectionError: If connection to API fails
"""
if tokenizer is None:
raise ValueError("Tokenizer is required for vLLM provider")
if not target_str:
raise ValueError("Target string cannot be empty")

num_target_str_tokens = len(tokenizer.encode(target_str, add_special_tokens=False))

prompt = input_prompt + target_str

try:
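# score the concatenated prompt + target in one call: max_tokens=1 keeps generation
# to a minimum, while prompt_logprobs=0 (a vLLM extension passed via extra_body)
# asks the server to return a logprob for every prompt token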
response = self.client.completions.create(
model=self.model_name,
prompt=prompt,
temperature=0.0,
max_tokens=1,
extra_body={"prompt_logprobs": 0}
)

if not hasattr(response, 'choices') or not response.choices:
raise KeyError("API response missing expected 'choices' data")

if not hasattr(response.choices[0], 'prompt_logprobs'):
raise KeyError("API response missing 'prompt_logprobs' data")

prompt_logprobs = []
try:
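# vLLM reports no logprob (None) for the very first prompt token, hence the [1:] skip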
for probs in response.choices[0].prompt_logprobs[1:]:
if not probs:
raise ValueError("Empty probability data in API response")
prompt_logprobs.append(list(probs.values())[0]['logprob'])
except (IndexError, KeyError) as e:
raise IndexError(f"Unexpected format in log probability data: {str(e)}")

if len(prompt_logprobs) < num_target_str_tokens:
raise ValueError(f"Not enough logprobs received: expected {num_target_str_tokens}, got {len(prompt_logprobs)}")

return prompt_logprobs[-num_target_str_tokens:]

except ConnectionError as e:
raise ConnectionError(f"Failed to connect to vLLM API when getting logprobs: {str(e)}")
except Exception as e:
raise Exception(f"Unexpected error while getting log probabilities: {str(e)}")


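Backends other than vLLM can be wired in by subclassing RemoteLLMProvider; the contract is just generate (used only when no target is supplied) and get_logprobs. A toy provider as a sketch (purely illustrative, e.g. for unit tests; every name here is hypothetical):

from typing import Any, List, Optional

from captum._utils.typing import TokenizerLike
from captum.attr import RemoteLLMProvider


class EchoProvider(RemoteLLMProvider):
    """Toy provider showing the minimal surface a backend must implement."""

    api_url = "local://echo"  # no real endpoint; illustrative only

    def generate(self, prompt: str, **gen_args: Any) -> str:
        # a real provider would call its HTTP API here
        return " a canned response"

    def get_logprobs(
        self,
        input_prompt: str,
        target_str: str,
        tokenizer: Optional[TokenizerLike] = None,
    ) -> List[float]:
        assert tokenizer is not None, "a tokenizer is required to count target tokens"
        num_tokens = len(tokenizer.encode(target_str, add_special_tokens=False))
        # one (fake) logprob per target token
        return [-1.0] * num_tokens
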
4 changes: 4 additions & 0 deletions setup.py
100755 → 100644
@@ -63,9 +63,12 @@ def report(*args):

TEST_REQUIRES = ["pytest", "pytest-cov", "parameterized", "flask", "flask-compress"]

REMOTE_REQUIRES = ["openai"]

DEV_REQUIRES = (
INSIGHTS_REQUIRES
+ TEST_REQUIRES
+ REMOTE_REQUIRES
+ [
"black",
"flake8",
@@ -169,6 +172,7 @@ def get_package_files(root, subdirs):
"insights": INSIGHTS_REQUIRES,
"test": TEST_REQUIRES,
"tutorials": TUTORIALS_REQUIRES,
"remote": REMOTE_REQUIRES,
},
package_data={"captum": package_files},
data_files=[
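
With the new remote extra declared, the optional openai dependency can be pulled in via pip install "captum[remote]"; it is presumably also covered by the dev extra, since DEV_REQUIRES now includes REMOTE_REQUIRES.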