14 changes: 12 additions & 2 deletions circuit_tracer/attribution.py
@@ -36,6 +36,7 @@
from circuit_tracer.graph import Graph
from circuit_tracer.replacement_model import ReplacementModel
from circuit_tracer.utils.disk_offload import offload_modules
+from circuit_tracer.utils.device import get_default_device


class AttributionContext:
@@ -303,7 +304,16 @@ def select_scaled_decoder_vecs(
for layer, row in enumerate(activations):
_, feat_idx = row.coalesce().indices()
rows.append(transcoders[layer].W_dec[feat_idx])
-    return torch.cat(rows) * activations.values()[:, None]
+    activation_vals = activations.values()[:, None]
+    decoder_rows = torch.cat(rows)
+
+    # The sparse activations may have been moved to CPU (on MPS, to avoid
+    # SparseMPS backend issues), but activation_vals is dense, so it can
+    # simply be moved to match the device of decoder_rows.
+    # If the device is not MPS, the devices already match and this is a no-op.
+    # If the device is MPS, the "move" stays within unified memory anyway.
+    activation_vals = activation_vals.to(decoder_rows.device)
+    return decoder_rows * activation_vals


@torch.no_grad()
@@ -320,7 +330,7 @@ def select_encoder_rows(


def compute_partial_influences(edge_matrix, logit_p, row_to_node_index, max_iter=128, device=None):
-    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = device or get_default_device()

normalized_matrix = torch.empty_like(edge_matrix, device=device).copy_(edge_matrix)
normalized_matrix = normalized_matrix.abs_()
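The device-alignment dance in `select_scaled_decoder_vecs` is easy to miss: the decoder rows may have been forced to CPU by the sparse-tensor workaround in `replacement_model.py` below, and the dense activation values then follow wherever the rows landed. A minimal sketch of the pattern, with illustrative shapes and tensors (not from the repo):

```python
import torch

# Hypothetical shapes: N selected feature rows, d_model decoder width.
decoder_rows = torch.randn(10, 64)    # may sit on CPU after the MPS workaround
activation_vals = torch.randn(10, 1)  # dense values pulled from the sparse activations

# No-op when devices already match; on MPS the copy stays in unified memory.
activation_vals = activation_vals.to(decoder_rows.device)
scaled = decoder_rows * activation_vals  # broadcasts, scaling each decoder row
```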
12 changes: 10 additions & 2 deletions circuit_tracer/replacement_model.py
@@ -9,6 +9,7 @@
from transformer_lens.hook_points import HookPoint

from circuit_tracer.transcoder import SingleLayerTranscoder, load_transcoder_set
+from circuit_tracer.utils.device import get_default_device


class ReplacementMLP(nn.Module):
@@ -126,7 +127,7 @@ def from_pretrained(
cls,
model_name: str,
transcoder_set: str,
-        device: Optional[torch.device] = torch.device("cuda"),
+        device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = torch.float32,
**kwargs,
) -> "ReplacementModel":
@@ -138,11 +139,14 @@
transcoder_set (str): Either a predefined transcoder set name, or a config file
defining where to load them from
device (torch.device, Optional): the device onto which to load the transcoders
-                and HookedTransformer.
+                and HookedTransformer. Defaults to None (auto-detect).

Returns:
ReplacementModel: The loaded ReplacementModel
"""
+        if device is None:
+            device = get_default_device()
+
transcoders, feature_input_hook, feature_output_hook, scan = load_transcoder_set(
transcoder_set, device=device, dtype=dtype
)
@@ -288,6 +292,10 @@ def cache_activations(acts, hook, layer, zero_bos):
if zero_bos:
transcoder_acts[0] = 0
if sparse:
+                    # The MPS backend does not currently support sparse tensors,
+                    # so convert to CPU first when MPS is in use.
+                    if transcoder_acts.device.type == "mps":
+                        transcoder_acts = transcoder_acts.cpu()
activation_matrix[layer] = transcoder_acts.to_sparse()
else:
activation_matrix[layer] = transcoder_acts
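For context on the `cache_activations` change: `to_sparse()` currently has no MPS kernel, so dense activations hop to CPU before sparsifying. A repro sketch, assuming an Apple-silicon machine with MPS available (the exact exception type may vary by PyTorch build):

```python
import torch

if torch.backends.mps.is_available():
    acts = torch.randn(4, 8, device="mps")
    try:
        acts.to_sparse()  # expected to raise on current builds (no SparseMPS backend)
    except (NotImplementedError, RuntimeError):
        sparse_acts = acts.cpu().to_sparse()  # the fallback cache_activations now applies
```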
17 changes: 13 additions & 4 deletions circuit_tracer/transcoder/single_layer_transcoder.py
@@ -14,6 +14,7 @@
import circuit_tracer
from circuit_tracer.transcoder.activation_functions import JumpReLU
from circuit_tracer.utils.hf_utils import download_hf_uris, parse_hf_uri
+from circuit_tracer.utils import get_default_device


class SingleLayerTranscoder(nn.Module):
@@ -102,10 +103,13 @@ def forward(self, input_acts):
def load_gemma_scope_transcoder(
path: str,
layer: int,
-    device: Optional[torch.device] = torch.device("cuda"),
+    device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = torch.float32,
revision: Optional[str] = None,
) -> SingleLayerTranscoder:
+    if device is None:
+        device = get_default_device()
+
if os.path.isfile(path):
path_to_params = path
else:
@@ -138,9 +142,12 @@
def load_relu_transcoder(
path: str,
layer: int,
-    device: torch.device = torch.device("cuda"),
+    device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = torch.float32,
):
+    if device is None:
+        device = get_default_device()
+
param_dict = load_file(path, device=device.type)
W_enc = param_dict["W_enc"]
d_sae, d_model = W_enc.shape
@@ -169,19 +176,21 @@ def load_relu_transcoder(

def load_transcoder_set(
transcoder_config_file: str,
-    device: Optional[torch.device] = torch.device("cuda"),
+    device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = torch.float32,
) -> TranscoderSettings:
"""Loads either a preset set of transformers, or a set specified by a file.

Args:
transcoder_config_file (str): _description_
-        device (Optional[torch.device], optional): _description_. Defaults to torch.device('cuda').
+        device (Optional[torch.device], optional): the device onto which to load the transcoders. Defaults to None (auto-detect).

Returns:
TranscoderSettings: A namedtuple consisting of the transcoder dict,
and their feature input hook, feature output hook and associated scan.
"""
+    if device is None:
+        device = get_default_device()

scan = None
# try to match a preset, and grab its config
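Worth spelling out why the defaults changed from `device=torch.device("cuda")` to `None` across these loaders: a default argument is evaluated once, at definition time, and unconditionally pins CUDA even on machines without it. The None-plus-detect idiom defers the choice to call time. A sketch of the idiom with a hypothetical loader (`load_weights` is illustrative, not part of the repo):

```python
from typing import Optional

import torch

from circuit_tracer.utils.device import get_default_device


def load_weights(path: str, device: Optional[torch.device] = None):
    """Hypothetical loader showing the None-default idiom used across this PR."""
    if device is None:
        device = get_default_device()  # resolved at call time: CUDA > MPS > CPU
    return torch.load(path, map_location=device)
```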
3 changes: 2 additions & 1 deletion circuit_tracer/utils/__init__.py
@@ -1,3 +1,4 @@
from circuit_tracer.utils.create_graph_files import create_graph_files
+from circuit_tracer.utils.device import get_default_device

__all__ = ["create_graph_files"]
__all__ = ["create_graph_files", "get_default_device"]
3 changes: 2 additions & 1 deletion circuit_tracer/utils/create_graph_files.py
@@ -9,6 +9,7 @@
from circuit_tracer.frontend.graph_models import Metadata, Model, Node, QParams
from circuit_tracer.frontend.utils import add_graph_metadata, process_token
from circuit_tracer.graph import Graph, prune_graph
+from circuit_tracer.utils.device import get_default_device

logger = logging.getLogger(__name__)

@@ -181,7 +182,7 @@ def create_graph_files(
)
scan = graph.scan

device = "cuda" if torch.cuda.is_available() else "cpu"
device = get_default_device()
graph.to(device)
node_mask, edge_mask, cumulative_scores = (
el.cpu() for el in prune_graph(graph, node_threshold, edge_threshold)
11 changes: 11 additions & 0 deletions circuit_tracer/utils/device.py
@@ -0,0 +1,11 @@
+# circuit_tracer/utils/device.py
+import torch
+
+def get_default_device() -> torch.device:
+    """Return the default device, preferring CUDA, then MPS, then CPU."""
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        return torch.device("mps")
+    else:
+        return torch.device("cpu")
19 changes: 11 additions & 8 deletions tests/test_attributions_gemma.py
@@ -12,6 +12,7 @@
from circuit_tracer.replacement_model import ReplacementModel
from circuit_tracer.transcoder import SingleLayerTranscoder
from circuit_tracer.transcoder.activation_functions import JumpReLU
+from circuit_tracer.utils.device import get_default_device


def verify_token_and_error_edges(
@@ -24,9 +25,10 @@
logit_rtol=1e-3,
):
s = graph.input_tokens
-    adjacency_matrix = graph.adjacency_matrix.cuda()
-    active_features = graph.active_features.cuda()
-    logit_tokens = graph.logit_tokens.cuda()
+    device = get_default_device()
+    adjacency_matrix = graph.adjacency_matrix.to(device)
+    active_features = graph.active_features.to(device)
+    logit_tokens = graph.logit_tokens.to(device)
total_active_features = active_features.size(0)
pos_start = 1 if delete_bos else 0

@@ -114,9 +116,10 @@ def verify_feature_edges(
logit_rtol=1e-3,
):
s = graph.input_tokens
-    adjacency_matrix = graph.adjacency_matrix.cuda()
-    active_features = graph.active_features.cuda()
-    logit_tokens = graph.logit_tokens.cuda()
+    device = get_default_device()
+    adjacency_matrix = graph.adjacency_matrix.to(device)
+    active_features = graph.active_features.to(device)
+    logit_tokens = graph.logit_tokens.to(device)
total_active_features = active_features.size(0)

logits, activation_cache = model.get_activations(s, apply_activation_function=False)
@@ -223,7 +226,7 @@ def verify_small_gemma_model(s: torch.Tensor):
"attn_types": ["global", "local"],
"init_mode": "gpt2",
"normalization_type": "RMSPre",
"device": device(type="cuda"),
"device": get_default_device(),
"n_devices": 1,
"attention_dir": "causal",
"attn_only": False,
@@ -317,7 +320,7 @@ def verify_large_gemma_model(s: torch.Tensor):
],
"init_mode": "gpt2",
"normalization_type": "RMSPre",
"device": device(type="cuda"),
"device": get_default_device(),
"n_devices": 1,
"attention_dir": "causal",
"attn_only": False,
5 changes: 3 additions & 2 deletions tests/test_attributions_llama.py
@@ -11,6 +11,7 @@
from circuit_tracer.replacement_model import ReplacementModel
from circuit_tracer.transcoder import SingleLayerTranscoder
from circuit_tracer.transcoder.activation_functions import TopK
+from circuit_tracer.utils.device import get_default_device

sys.path.append(os.path.dirname(__file__))
from test_attributions_gemma import verify_feature_edges, verify_token_and_error_edges
@@ -65,7 +66,7 @@ def verify_small_llama_model(s: torch.Tensor):
"attn_types": None,
"init_mode": "gpt2",
"normalization_type": "RMSPre",
"device": device(type="cuda"),
"device": get_default_device(),
"n_devices": 1,
"attention_dir": "causal",
"attn_only": False,
@@ -144,7 +145,7 @@ def verify_large_llama_model(s: torch.Tensor):
"attn_types": None,
"init_mode": "gpt2",
"normalization_type": "RMSPre",
"device": device(type="cuda"),
"device": get_default_device(),
"n_devices": 1,
"attention_dir": "causal",
"attn_only": False,
3 changes: 2 additions & 1 deletion tests/test_graph.py
@@ -4,6 +4,7 @@
from transformer_lens import HookedTransformerConfig

from circuit_tracer.graph import Graph, compute_edge_influence, compute_node_influence
+from circuit_tracer.utils.device import get_default_device


def test_small_graph():
@@ -67,7 +68,7 @@ def test_small_graph():
"attn_types": ["global", "local"],
"init_mode": "gpt2",
"normalization_type": "RMSPre",
"device": device(type="cuda"),
"device": get_default_device(),
"n_devices": 1,
"attention_dir": "causal",
"attn_only": False,