
Add Remote LLM Support for Perturbation-Based Attribution via RemoteLLMAttribution and VLLMProvider #1544

Closed
5 changes: 5 additions & 0 deletions captum/attr/__init__.py
@@ -27,6 +27,7 @@
    LLMAttribution,
    LLMAttributionResult,
    LLMGradientAttribution,
    RemoteLLMAttribution,
)
from captum.attr._core.lrp import LRP
from captum.attr._core.neuron.neuron_conductance import NeuronConductance
@@ -43,6 +44,7 @@
)
from captum.attr._core.noise_tunnel import NoiseTunnel
from captum.attr._core.occlusion import Occlusion
from captum.attr._core.remote_provider import RemoteLLMProvider, VLLMProvider
from captum.attr._core.saliency import Saliency
from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
from captum.attr._models.base import (
@@ -111,6 +113,9 @@
"LLMAttribution",
"LLMAttributionResult",
"LLMGradientAttribution",
"RemoteLLMAttribution",
"RemoteLLMProvider",
"VLLMProvider",
"InternalInfluence",
"InterpretableInput",
"LayerGradCam",
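The `__init__.py` change above is what makes the new classes part of the public `captum.attr` namespace. A minimal sketch of the resulting import surface, using the names added to `__all__` here alongside an existing perturbation method (`FeatureAblation`) for context:

```python
# New remote-attribution exports sit alongside the existing attribution classes.
from captum.attr import (
    FeatureAblation,       # existing perturbation method, reused unchanged
    RemoteLLMAttribution,  # new: perturbation-based attribution for remote LLMs
    RemoteLLMProvider,     # new: protocol for remote logprob providers
    VLLMProvider,          # new: provider implementation targeting vLLM servers
)
```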
156 changes: 156 additions & 0 deletions captum/attr/_core/llm_attr.py
@@ -23,6 +23,7 @@
from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation
from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
from captum.attr._core.lime import Lime
from captum.attr._core.remote_provider import RemoteLLMProvider
from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
from captum.attr._utils.attribution import (
    Attribution,
@@ -892,3 +893,158 @@ def forward(

        # the attribution target is limited to the log probability
        return token_log_probs


class _PlaceholderModel:
    """
    Simple placeholder model that can be used with
    RemoteLLMAttribution without needing a real model.
    This could be achieved with `lambda *_: 0`, but BaseLLMAttribution expects
    a `device`, so this class exists to set the device.
    """

    def __init__(self) -> None:
        self.device: Union[torch.device, str] = torch.device("cpu")

    def __call__(self, *args: Any, **kwargs: Any) -> int:
        return 0


class RemoteLLMAttribution(LLMAttribution):
    """
    Attribution class for large language models
    that are hosted remotely and offer logprob APIs.
    """

    placeholder_model = _PlaceholderModel()

    def __init__(
        self,
        attr_method: PerturbationAttribution,
        tokenizer: TokenizerLike,
        provider: RemoteLLMProvider,
        attr_target: str = "log_prob",
    ) -> None:
        """
        Args:
            attr_method: Instance of a supported perturbation attribution class
            tokenizer (Tokenizer): tokenizer of the LLM used in the attr_method
            provider: Remote LLM provider that implements the
                RemoteLLMProvider protocol
            attr_target: attribute towards the log probability or the probability.
                Available values ["log_prob", "prob"]
                Default: "log_prob"
        """
        super().__init__(
            attr_method=attr_method,
            tokenizer=tokenizer,
            attr_target=attr_target,
        )

        self.provider = provider
        self.attr_method.forward_func = self._remote_forward_func

    def _get_target_tokens(
        self,
        inp: InterpretableInput,
        target: Union[str, torch.Tensor, None] = None,
        skip_tokens: Union[List[int], List[str], None] = None,
        gen_args: Optional[Dict[str, Any]] = None,
    ) -> Tensor:
        """
        Get the target tokens for the remote LLM provider.
        """
        assert isinstance(
            inp, self.SUPPORTED_INPUTS
        ), f"RemoteLLMAttribution does not support input type {type(inp)}"

        if target is None:
            # generate the target with the remote provider when none is given
            assert hasattr(self.provider, "generate") and callable(
                self.provider.generate
            ), (
                "The provider does not have a generate function"
                " for generating the target sequence."
                " Target must be given for attribution."
            )
            if not gen_args:
                gen_args = DEFAULT_GEN_ARGS

            model_inp = self._format_remote_model_input(inp.to_model_input())
            target_str = self.provider.generate(model_inp, **gen_args)
            target_tokens = self.tokenizer.encode(
                target_str, return_tensors="pt", add_special_tokens=False
            )[0]

        else:
            target_tokens = super()._get_target_tokens(
                inp, target, skip_tokens, gen_args
            )

        return target_tokens

    def _format_remote_model_input(self, model_input: Union[str, Tensor]) -> str:
        """
        Format the model input for the remote LLM provider.
        Convert tokenized tensor to str
        to make RemoteLLMAttribution work with model inputs of both
        raw text and text token tensors
        """
        # return str input
        if isinstance(model_input, Tensor):
            return self.tokenizer.decode(model_input.flatten())
        return model_input

    def _remote_forward_func(
        self,
        perturbed_tensor: Union[None, Tensor],
        inp: InterpretableInput,
        target_tokens: Tensor,
        use_cached_outputs: bool = False,
        _inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
    ) -> Tensor:
        """
        Forward function for the remote LLM provider.

        Raises:
            ValueError: if the number of token logprobs returned by the provider
                does not match the expected target token length
        """
        perturbed_input = self._format_remote_model_input(
            inp.to_model_input(perturbed_tensor)
        )

        target_str: str = self.tokenizer.decode(target_tokens)

        target_token_probs = self.provider.get_logprobs(
            input_prompt=perturbed_input,
            target_str=target_str,
            tokenizer=self.tokenizer,
        )

        if len(target_token_probs) != target_tokens.size()[0]:
            raise ValueError(
                f"Number of token logprobs from provider ({len(target_token_probs)}) "
                f"does not match expected target "
                f"token length ({target_tokens.size()[0]})"
            )

        log_prob_list: List[Tensor] = list(map(torch.tensor, target_token_probs))

        total_log_prob = torch.sum(torch.stack(log_prob_list), dim=0)
        # 1st element is the total log prob; the rest are the per-token log probs
        # add a leading batch dim even though only a single instance is supported
        if self.include_per_token_attr:
            target_log_probs = torch.stack(
                [total_log_prob, *log_prob_list], dim=0
            ).unsqueeze(0)
        else:
            target_log_probs = total_log_prob
        target_probs = torch.exp(target_log_probs)

        if _inspect_forward:
            prompt = perturbed_input
            response = self.tokenizer.decode(target_tokens)

            # callback for externals to inspect (prompt, response, seq_prob)
            _inspect_forward(prompt, response, target_probs[0].tolist())

        return target_probs if self.attr_target != "log_prob" else target_log_probs
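
Taken together, the pieces above support a workflow like the following. This is a hedged sketch rather than code from this PR: the vLLM server URL, the tokenizer name, and the `VLLMProvider` constructor argument (`api_url`) are assumptions for illustration; see `captum/attr/_core/remote_provider.py` for the actual provider signature.

```python
from transformers import AutoTokenizer

from captum.attr import (
    FeatureAblation,
    RemoteLLMAttribution,
    TextTokenInput,
    VLLMProvider,
)

# The tokenizer must match the model served by the remote endpoint.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Provider pointing at a vLLM server that exposes a logprob API
# (the constructor argument name is an assumption in this sketch).
provider = VLLMProvider(api_url="http://localhost:8000/v1")

# The perturbation method is built on the placeholder model; the real forward
# pass is swapped for RemoteLLMAttribution._remote_forward_func in __init__.
fa = FeatureAblation(RemoteLLMAttribution.placeholder_model)

remote_llm_attr = RemoteLLMAttribution(
    attr_method=fa,
    tokenizer=tokenizer,
    provider=provider,
    attr_target="log_prob",
)

inp = TextTokenInput("The capital of France is", tokenizer)
result = remote_llm_attr.attribute(inp, target=" Paris")

print(result.seq_attr)    # attribution of each input token to the whole target
print(result.token_attr)  # per (target token, input token) attribution, if supported
```

Because all log probabilities come from the provider, no model weights are loaded locally; the placeholder model exists only to satisfy the `forward_func` and `device` expectations of the perturbation attribution classes.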