diff --git a/circuit_tracer/attribution.py b/circuit_tracer/attribution.py index d6de3df..2eacfad 100644 --- a/circuit_tracer/attribution.py +++ b/circuit_tracer/attribution.py @@ -36,6 +36,7 @@ from circuit_tracer.graph import Graph from circuit_tracer.replacement_model import ReplacementModel from circuit_tracer.utils.disk_offload import offload_modules +from circuit_tracer.utils.device import get_default_device class AttributionContext: @@ -320,7 +321,7 @@ def select_encoder_rows( def compute_partial_influences(edge_matrix, logit_p, row_to_node_index, max_iter=128, device=None): - device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = device or get_default_device() normalized_matrix = torch.empty_like(edge_matrix, device=device).copy_(edge_matrix) normalized_matrix = normalized_matrix.abs_() diff --git a/circuit_tracer/replacement_model.py b/circuit_tracer/replacement_model.py index 1dfe31e..cb4799e 100644 --- a/circuit_tracer/replacement_model.py +++ b/circuit_tracer/replacement_model.py @@ -9,6 +9,7 @@ from transformer_lens.hook_points import HookPoint from circuit_tracer.transcoder import SingleLayerTranscoder, load_transcoder_set +from circuit_tracer.utils.device import get_default_device class ReplacementMLP(nn.Module): @@ -126,7 +127,7 @@ def from_pretrained( cls, model_name: str, transcoder_set: str, - device: Optional[torch.device] = torch.device("cuda"), + device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = torch.float32, **kwargs, ) -> "ReplacementModel": @@ -138,11 +139,14 @@ def from_pretrained( transcoder_set (str): Either a predefined transcoder set name, or a config file defining where to load them from device (torch.device, Optional): the device onto which to load the transcoders - and HookedTransformer. + and HookedTransformer. Defaults to None (auto-detect). 
Returns: ReplacementModel: The loaded ReplacementModel """ + if device is None: + device = get_default_device() + transcoders, feature_input_hook, feature_output_hook, scan = load_transcoder_set( transcoder_set, device=device, dtype=dtype ) diff --git a/circuit_tracer/transcoder/single_layer_transcoder.py b/circuit_tracer/transcoder/single_layer_transcoder.py index 3b20eb0..004f4d5 100644 --- a/circuit_tracer/transcoder/single_layer_transcoder.py +++ b/circuit_tracer/transcoder/single_layer_transcoder.py @@ -14,6 +14,7 @@ import circuit_tracer from circuit_tracer.transcoder.activation_functions import JumpReLU from circuit_tracer.utils.hf_utils import download_hf_uris, parse_hf_uri +from circuit_tracer.utils.device import get_default_device class SingleLayerTranscoder(nn.Module): @@ -102,10 +103,13 @@ def forward(self, input_acts): def load_gemma_scope_transcoder( path: str, layer: int, - device: Optional[torch.device] = torch.device("cuda"), + device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = torch.float32, revision: Optional[str] = None, ) -> SingleLayerTranscoder: + if device is None: + device = get_default_device() + if os.path.isfile(path): path_to_params = path else: @@ -138,9 +142,12 @@ def load_gemma_scope_transcoder( def load_relu_transcoder( path: str, layer: int, - device: torch.device = torch.device("cuda"), + device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = torch.float32, ): + if device is None: + device = get_default_device() + param_dict = load_file(path, device=device.type) W_enc = param_dict["W_enc"] d_sae, d_model = W_enc.shape @@ -169,19 +176,21 @@ def load_relu_transcoder( def load_transcoder_set( transcoder_config_file: str, - device: Optional[torch.device] = torch.device("cuda"), + device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = torch.float32, ) -> TranscoderSettings: """Loads either a preset set of transformers, or a set specified by a file. 
Args: transcoder_config_file (str): _description_ - device (Optional[torch.device], optional): _description_. Defaults to torch.device('cuda'). + device (Optional[torch.device], optional): _description_. Defaults to None (auto-detect). Returns: TranscoderSettings: A namedtuple consisting of the transcoder dict, and their feature input hook, feature output hook and associated scan. """ + if device is None: + device = get_default_device() scan = None # try to match a preset, and grab its config diff --git a/circuit_tracer/utils/__init__.py b/circuit_tracer/utils/__init__.py index 333ff36..20c75c3 100644 --- a/circuit_tracer/utils/__init__.py +++ b/circuit_tracer/utils/__init__.py @@ -1,3 +1,4 @@ from circuit_tracer.utils.create_graph_files import create_graph_files +from circuit_tracer.utils.device import get_default_device -__all__ = ["create_graph_files"] +__all__ = ["create_graph_files", "get_default_device"] diff --git a/circuit_tracer/utils/create_graph_files.py b/circuit_tracer/utils/create_graph_files.py index ca9c0e5..259bddc 100644 --- a/circuit_tracer/utils/create_graph_files.py +++ b/circuit_tracer/utils/create_graph_files.py @@ -9,6 +9,7 @@ from circuit_tracer.frontend.graph_models import Metadata, Model, Node, QParams from circuit_tracer.frontend.utils import add_graph_metadata, process_token from circuit_tracer.graph import Graph, prune_graph +from circuit_tracer.utils.device import get_default_device logger = logging.getLogger(__name__) @@ -181,7 +182,7 @@ def create_graph_files( ) scan = graph.scan - device = "cuda" if torch.cuda.is_available() else "cpu" + device = get_default_device() graph.to(device) node_mask, edge_mask, cumulative_scores = ( el.cpu() for el in prune_graph(graph, node_threshold, edge_threshold) diff --git a/circuit_tracer/utils/device.py b/circuit_tracer/utils/device.py new file mode 100644 index 0000000..f862bb5 --- /dev/null +++ b/circuit_tracer/utils/device.py @@ -0,0 +1,13 @@ +# circuit_tracer/utils/device.py +import torch + 
+def get_default_device() -> torch.device: + """Smart device detection - CUDA > MPS > CPU""" + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + # MPS is not supported for sparse tensors + #return torch.device("mps") + return torch.device("cpu") + else: + return torch.device("cpu") \ No newline at end of file diff --git a/tests/test_attributions_gemma.py b/tests/test_attributions_gemma.py index 3573d40..811a41d 100644 --- a/tests/test_attributions_gemma.py +++ b/tests/test_attributions_gemma.py @@ -12,6 +12,7 @@ from circuit_tracer.replacement_model import ReplacementModel from circuit_tracer.transcoder import SingleLayerTranscoder from circuit_tracer.transcoder.activation_functions import JumpReLU +from circuit_tracer.utils.device import get_default_device def verify_token_and_error_edges( @@ -24,9 +25,10 @@ def verify_token_and_error_edges( logit_rtol=1e-3, ): s = graph.input_tokens - adjacency_matrix = graph.adjacency_matrix.cuda() - active_features = graph.active_features.cuda() - logit_tokens = graph.logit_tokens.cuda() + device = get_default_device() + adjacency_matrix = graph.adjacency_matrix.to(device) + active_features = graph.active_features.to(device) + logit_tokens = graph.logit_tokens.to(device) total_active_features = active_features.size(0) pos_start = 1 if delete_bos else 0 @@ -114,9 +116,10 @@ def verify_feature_edges( logit_rtol=1e-3, ): s = graph.input_tokens - adjacency_matrix = graph.adjacency_matrix.cuda() - active_features = graph.active_features.cuda() - logit_tokens = graph.logit_tokens.cuda() + device = get_default_device() + adjacency_matrix = graph.adjacency_matrix.to(device) + active_features = graph.active_features.to(device) + logit_tokens = graph.logit_tokens.to(device) total_active_features = active_features.size(0) logits, activation_cache = model.get_activations(s, apply_activation_function=False) @@ -223,7 +226,7 @@ def verify_small_gemma_model(s: torch.Tensor): "attn_types": 
["global", "local"], "init_mode": "gpt2", "normalization_type": "RMSPre", - "device": device(type="cuda"), + "device": get_default_device(), "n_devices": 1, "attention_dir": "causal", "attn_only": False, @@ -317,7 +320,7 @@ def verify_large_gemma_model(s: torch.Tensor): ], "init_mode": "gpt2", "normalization_type": "RMSPre", - "device": device(type="cuda"), + "device": get_default_device(), "n_devices": 1, "attention_dir": "causal", "attn_only": False, diff --git a/tests/test_attributions_llama.py b/tests/test_attributions_llama.py index bb46bcd..7d2b26c 100644 --- a/tests/test_attributions_llama.py +++ b/tests/test_attributions_llama.py @@ -11,6 +11,7 @@ from circuit_tracer.replacement_model import ReplacementModel from circuit_tracer.transcoder import SingleLayerTranscoder from circuit_tracer.transcoder.activation_functions import TopK +from circuit_tracer.utils.device import get_default_device sys.path.append(os.path.dirname(__file__)) from test_attributions_gemma import verify_feature_edges, verify_token_and_error_edges @@ -65,7 +66,7 @@ def verify_small_llama_model(s: torch.Tensor): "attn_types": None, "init_mode": "gpt2", "normalization_type": "RMSPre", - "device": device(type="cuda"), + "device": get_default_device(), "n_devices": 1, "attention_dir": "causal", "attn_only": False, @@ -144,7 +145,7 @@ def verify_large_llama_model(s: torch.Tensor): "attn_types": None, "init_mode": "gpt2", "normalization_type": "RMSPre", - "device": device(type="cuda"), + "device": get_default_device(), "n_devices": 1, "attention_dir": "causal", "attn_only": False, diff --git a/tests/test_graph.py b/tests/test_graph.py index 58f5e13..5e5b810 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -4,6 +4,7 @@ from transformer_lens import HookedTransformerConfig from circuit_tracer.graph import Graph, compute_edge_influence, compute_node_influence +from circuit_tracer.utils.device import get_default_device def test_small_graph(): @@ -67,7 +68,7 @@ def test_small_graph(): 
"attn_types": ["global", "local"], "init_mode": "gpt2", "normalization_type": "RMSPre", - "device": device(type="cuda"), + "device": get_default_device(), "n_devices": 1, "attention_dir": "causal", "attn_only": False,