diff --git a/mkdocs.yml b/mkdocs.yml
index 58d3d47a..d0c61c99 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -160,6 +160,9 @@ plugins:
       python:
         options:
           show_source: false
+        import:
+          - url: https://pytorch.org/docs/1.13/objects.inv
+        enable_inventory: false
 
 watch:
   - src/kernl
@@ -174,6 +177,7 @@ nav:
   - Code Reference:
       - Model Optimization: reference/model_optimization.md
       - Graph Pattern Matching: reference/extended_matcher.md
+      - Autocast: reference/autocast.md
   - Contribution guide:
       - Contributing: contribution-guide/contributing.md
      - Code of conduct: contribution-guide/code-of-conduct.md
diff --git a/src/kernl/autocast.py b/src/kernl/autocast.py
new file mode 100644
index 00000000..aa3390e9
--- /dev/null
+++ b/src/kernl/autocast.py
@@ -0,0 +1,158 @@
+import torch
+import functools
+
+from typing import Any
+
+from kernl.utils.cast import _cast
+
+
+_is_autocast_enabled: bool = True
+_autocast_dtype: torch.dtype = torch.float16
+_is_torch_autocast_dtype_used: bool = True
+
+
+def enable_autocast(is_autocast_enabled: bool):
+    """Change the autocast state for kernl.
+    Args:
+        is_autocast_enabled: True to enable autocast, False to disable
+    """
+    global _is_autocast_enabled
+    _is_autocast_enabled = is_autocast_enabled
+
+
+def is_autocast_enabled() -> bool:
+    """Check if autocast is enabled.
+    Returns:
+        whether autocast is enabled
+    """
+    return _is_autocast_enabled
+
+
+def set_autocast_dtype(dtype: torch.dtype) -> None:
+    """Change the autocast target dtype.
+    Args:
+        dtype: which dtype to autocast to
+    """
+    global _autocast_dtype
+    _autocast_dtype = dtype
+
+
+def get_autocast_dtype() -> torch.dtype:
+    """Get the autocast target dtype.
+    Returns:
+        autocast target dtype
+    """
+    return _autocast_dtype
+
+
+def use_torch_autocast_dtype(use_torch_autocast_dtype: bool) -> None:
+    """Set whether autocast should use the torch autocast dtype instead of the kernl one.
+    Args:
+        use_torch_autocast_dtype: whether the torch autocast dtype is used
+    """
+    global _is_torch_autocast_dtype_used
+    _is_torch_autocast_dtype_used = use_torch_autocast_dtype
+
+
+def is_torch_autocast_dtype_used() -> bool:
+    """Check if the torch autocast dtype is used in autocast.
+    Returns:
+        whether the torch autocast dtype is used
+    """
+    return _is_torch_autocast_dtype_used
+
+
+class autocast(object):
+    """Serves as a context manager or decorator that allows fused kernels to run in mixed precision.
+
+    In these regions, kernels run in an op-specific dtype chosen by autocast
+    to improve performance while maintaining accuracy.
+
+    When entering an autocast-enabled region, tensors may be of any type.
+    You should not call `half()` or `bfloat16()` on your model(s) or inputs when using autocasting.
+
+    `autocast` should wrap only the forward pass(es) of your network, including the loss
+    computation(s). Backward passes under autocast are not recommended.
+    Backward ops run in the same type that autocast used for the corresponding forward ops.
+
+    Heavily inspired by [torch.autocast][].
+    Args:
+        enabled: Whether autocasting should be enabled in the region.
+        dtype: Whether to use torch.float16 or torch.bfloat16.
+        use_torch_autocast_dtype: Whether to use the torch autocast dtype.
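+
+    Example:
+        A minimal usage sketch (illustrative only: `model` and `inputs` stand in for a CUDA
+        model and its inputs; they are not defined in this module):
+
+            with autocast(dtype=torch.float16):
+                output = model(**inputs)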
+ """ + def __init__( + self, enabled: bool = True, dtype: torch.dtype = torch.float16, use_torch_autocast_dtype: bool = False + ): + self._enabled = enabled + self._dtype = dtype + self._use_autocast_dtype = use_torch_autocast_dtype + + def __enter__(self): + self._prev = is_autocast_enabled() + self._prev_dtype = get_autocast_dtype() + self._prev_use_torch_autocast_dtype = is_torch_autocast_dtype_used() + enable_autocast(self._enabled) + set_autocast_dtype(self._dtype) + use_torch_autocast_dtype(self._use_autocast_dtype) + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): + enable_autocast(self._prev) + set_autocast_dtype(self._prev_dtype) + use_torch_autocast_dtype(self._prev_use_torch_autocast_dtype) + + +def custom_fwd(fwd=None, **kwargs): + """ + Helper decorator for ``forward`` methods of custom autograd functions (subclasses of + :class:`torch.autograd.Function`). See the :ref:`example page` for more detail. + + Args: + cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``, + when ``forward`` runs in an autocast-enabled region, casts incoming + floating-point CUDA Tensors to the target dtype (non-floating-point Tensors are not affected), + then executes ``forward`` with autocast disabled. + If ``None``, ``forward``'s internal ops execute with the current autocast state. + + .. note:: + If the decorated ``forward`` is called outside an autocast-enabled region, + :func:`custom_fwd` is a no-op and ``cast_inputs`` has no effect. + """ + if fwd is None: + if len(kwargs) == 0: + cast_inputs = None + else: + assert len(kwargs) == 1 + cast_inputs = kwargs["cast_inputs"] + + return functools.partial(custom_fwd, cast_inputs=cast_inputs) + + if len(kwargs) == 0: + cast_inputs = None + else: + assert len(kwargs) == 1 + cast_inputs = kwargs["cast_inputs"] + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + inputs = cast_inputs + if inputs is None: + if is_autocast_enabled(): + if is_torch_autocast_dtype_used(): + if torch.is_autocast_enabled(): + inputs = torch.get_autocast_gpu_dtype() + else: + inputs = get_autocast_dtype() + if inputs is None: + args[0]._fwd_used_autocast = torch.is_autocast_enabled() + return fwd(*args, **kwargs) + else: + autocast_context = torch.is_autocast_enabled() or is_autocast_enabled() + args[0]._fwd_used_autocast = torch.is_autocast_enabled() + if autocast_context: + with torch.cuda.amp.autocast(enabled=False): + return fwd(*_cast(args, inputs), **_cast(kwargs, inputs)) + else: + return fwd(*args, **kwargs) + + return decorate_fwd diff --git a/src/kernl/implementations/attention.py b/src/kernl/implementations/attention.py index c5ed0fc6..45c3be86 100644 --- a/src/kernl/implementations/attention.py +++ b/src/kernl/implementations/attention.py @@ -19,7 +19,8 @@ import triton import triton.language as tl from torch.autograd.function import FunctionCtx -from torch.cuda.amp import custom_fwd + +from kernl.autocast import custom_fwd # CREDITS: Initially inspired by the Triton tutorial @@ -448,7 +449,7 @@ def _fwd_kernel( class Attention(torch.autograd.Function): @staticmethod - @custom_fwd(cast_inputs=torch.float16) + @custom_fwd() def forward( ctx: FunctionCtx, q: torch.Tensor, diff --git a/src/kernl/implementations/attention_skinny.py b/src/kernl/implementations/attention_skinny.py index bacb61a6..1c35daf4 100644 --- a/src/kernl/implementations/attention_skinny.py +++ b/src/kernl/implementations/attention_skinny.py @@ -18,7 +18,8 @@ import triton import triton.language as tl from torch.autograd.function import 
FunctionCtx -from torch.cuda.amp import custom_fwd + +from kernl.autocast import custom_fwd def attention_split_1_reference( diff --git a/src/kernl/implementations/attention_vec_mat.py b/src/kernl/implementations/attention_vec_mat.py index 3ea94879..a4862eae 100644 --- a/src/kernl/implementations/attention_vec_mat.py +++ b/src/kernl/implementations/attention_vec_mat.py @@ -4,7 +4,8 @@ import triton import triton.language as tl from torch.autograd.function import FunctionCtx -from torch.cuda.amp import custom_fwd + +from kernl.autocast import custom_fwd @triton.autotune( @@ -152,7 +153,7 @@ def grid(args) -> Tuple[int, int, int]: class AttentionVecMat(torch.autograd.Function): @staticmethod - @custom_fwd(cast_inputs=torch.float16) + @custom_fwd() def forward( ctx: FunctionCtx, q: torch.Tensor, diff --git a/src/kernl/implementations/layer_norm.py b/src/kernl/implementations/layer_norm.py index 47c6086c..acf1822e 100644 --- a/src/kernl/implementations/layer_norm.py +++ b/src/kernl/implementations/layer_norm.py @@ -18,9 +18,10 @@ import triton import triton.language as tl from torch.autograd.function import FunctionCtx -from torch.cuda.amp import custom_fwd from triton import JITFunction +from kernl.autocast import custom_fwd + # CREDITS: Initially inspired by the Triton tutorial and xformers implementation @@ -278,7 +279,7 @@ def _layer_norm_fwd_fused_multi_pass( class LayerNorm(torch.autograd.Function): @staticmethod - @custom_fwd(cast_inputs=torch.float16) + @custom_fwd() def forward( ctx: FunctionCtx, x: torch.Tensor, diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index 68ad9144..24c4be54 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -22,9 +22,9 @@ import triton import triton.language as tl from torch.autograd.function import FunctionCtx -from torch.cuda.amp import custom_fwd from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time +from kernl.autocast import custom_fwd from kernl.implementations import activation_func @@ -205,7 +205,7 @@ def kernel_fma( class LinearLayer(torch.autograd.Function): @staticmethod - @custom_fwd(cast_inputs=torch.float16) + @custom_fwd() def forward( ctx: FunctionCtx, x: torch.Tensor, diff --git a/test/test_autocast.py b/test/test_autocast.py new file mode 100644 index 00000000..9ada5181 --- /dev/null +++ b/test/test_autocast.py @@ -0,0 +1,67 @@ +import pytest +import torch + +from kernl.autocast import autocast, custom_fwd, get_autocast_dtype, is_autocast_enabled, is_torch_autocast_dtype_used + + +def test_autocast(): + + with autocast(): + assert is_autocast_enabled() is True + + with autocast(enabled=False): + assert is_autocast_enabled() is False + assert is_autocast_enabled() is True + + with autocast(dtype=torch.float32): + assert get_autocast_dtype() is torch.float32 + assert get_autocast_dtype() is torch.float16 + + with autocast(use_torch_autocast_dtype=False): + assert is_torch_autocast_dtype_used() is False + assert is_torch_autocast_dtype_used() is True + + +@pytest.mark.parametrize( + """ + torch_autocast_enabled, kernl_autocast_enabled, torch_autocast_dtype, + kernl_autocast_dtype, use_torch_autocast_dtype, cast_inputs, expected_dtype""", + [ + # we first check if @custom_fwd behaves the same as torch.cuda.amp.custom_fwd + # when we're not using kernl autocast + (False, False, None, None, False, None, torch.float32), + (False, False, None, None, False, torch.bfloat16, torch.float32), + (True, False, 
torch.float16, None, False, None, torch.float32),
+        (True, False, torch.float16, None, False, torch.bfloat16, torch.bfloat16),
+        # we check @custom_fwd without torch autocast
+        (False, True, None, torch.float16, False, None, torch.float16),
+        (False, True, None, torch.float16, False, torch.bfloat16, torch.bfloat16),
+        (False, True, None, torch.float16, True, None, torch.float32),
+        (False, True, None, torch.float16, True, torch.bfloat16, torch.bfloat16),
+        # we check @custom_fwd within torch autocast context and kernl autocast context
+        (True, True, torch.float16, torch.bfloat16, False, None, torch.bfloat16),
+        (True, True, torch.float16, torch.float16, False, torch.bfloat16, torch.bfloat16),
+        (True, True, torch.float16, torch.bfloat16, False, None, torch.bfloat16),
+        (True, True, torch.float16, torch.bfloat16, True, None, torch.float16),
+        (True, True, torch.float16, torch.float16, True, torch.bfloat16, torch.bfloat16),
+    ],
+)
+def test_custom_fwd(
+    torch_autocast_enabled,
+    kernl_autocast_enabled,
+    torch_autocast_dtype,
+    kernl_autocast_dtype,
+    use_torch_autocast_dtype,
+    cast_inputs,
+    expected_dtype,
+):
+    @custom_fwd(cast_inputs=cast_inputs)
+    def fwd(inputs: torch.Tensor):
+        assert inputs.dtype == expected_dtype
+
+    with torch.cuda.amp.autocast(enabled=torch_autocast_enabled, dtype=torch_autocast_dtype), autocast(
+        enabled=kernl_autocast_enabled, dtype=kernl_autocast_dtype, use_torch_autocast_dtype=use_torch_autocast_dtype
+    ):
+        inputs = torch.empty((100,), dtype=torch.float32, device="cuda")
+        fwd(inputs)
diff --git a/test/test_layer_norm.py b/test/test_layer_norm.py
index 95958ea6..06e8b53c 100644
--- a/test/test_layer_norm.py
+++ b/test/test_layer_norm.py
@@ -17,7 +17,6 @@
 import torch
 
 from conftest import assert_all_close, set_seed
-
 from kernl.implementations.layer_norm import (
     _layer_norm_fwd_fused_multi_pass,
     _layer_norm_fwd_fused_single_pass,
@@ -47,16 +46,20 @@
 @set_seed()
 @pytest.mark.parametrize("shape", [128, 512, 1024, 2048, 4096], ids=lambda x: f"shape={x}x{x}")
 @pytest.mark.parametrize("cuda_graphs", [True, False], ids=["cuda_graphs", "no_cuda_graphs"])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=["fp16", "bfloat16", "fp32"])
 @pytest.mark.parametrize("implementation", implementations_layer_norm.keys())
 def test_benchmark_layer_norm(benchmark, shape: int, dtype, cuda_graphs: bool, implementation: str):
+    if dtype == torch.bfloat16 and implementation == "pytorch_naive":
+        pytest.skip("unsupported configuration")
     M = N = shape
     eps = 1e-5
-    factory_kwargs = {"device": "cuda", "dtype": torch.float32, "requires_grad": False}
+    factory_kwargs = {"device": "cuda", "dtype": torch.bfloat16 if dtype == torch.bfloat16 else torch.float32, "requires_grad": False}
     layer_weight = torch.rand((N,), **factory_kwargs)
     layer_bias = torch.randn_like(layer_weight)
     x = -20 + 0.5 * torch.randn((M, N), **factory_kwargs)
     expected = torch.nn.functional.layer_norm(x, layer_weight.shape, layer_weight, layer_bias, eps)
+    if dtype == torch.bfloat16:
+        expected = expected.float()
 
     # tensors casting
     layer_weight = layer_weight.to(dtype)
@@ -82,7 +85,7 @@ def test_benchmark_layer_norm(benchmark, shape: int, dtype, cuda_graphs: bool, i
 
 
 @pytest.mark.parametrize("cuda_graphs", [True, False], ids=["cuda_graphs", "no_cuda_graphs"])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
+@pytest.mark.parametrize("dtype", 
[torch.float16, torch.bfloat16, torch.float32], ids=["fp16", "bf16", "fp32"]) @pytest.mark.parametrize("shape", [128, 512, 1024, 2048, 4096], ids=lambda x: f"shape={x}x{x}") @pytest.mark.parametrize("implementation", implementations_rms_norm.keys()) def test_benchmark_rms_norm(benchmark, shape: int, dtype, cuda_graphs: bool, implementation: str): diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py index ba777d2d..61fe54e2 100644 --- a/test/test_linear_layer.py +++ b/test/test_linear_layer.py @@ -22,6 +22,7 @@ from kernl.implementations.linear_layer import linear_layer from kernl.optimizer.cuda_graph import cuda_graphs_wrapper +from kernl.autocast import autocast def get_pytorch_activation(activation: str) -> Callable: @@ -54,7 +55,7 @@ def get_pytorch_activation(activation: str) -> Callable: [(1, 8, 8, 8)] + [(bs, M, 768, 768) for bs in [1, 16] for M in [8, 16, 128, 256, 512]], ids=lambda s: "x".join(map(str, s)), ) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"]) @pytest.mark.parametrize("cuda_graphs", [False, True], ids=["no_cuda_graphs", "cuda_graphs"]) @pytest.mark.parametrize("implementation", ["triton", "pytorch"]) def test_benchmark( @@ -70,13 +71,14 @@ def test_benchmark( batch, M, N, K = shape # order of dimensions is wrong so we force contiguous call - x = torch.randn((batch, K, M), device="cuda", dtype=torch.float32, requires_grad=False) + factory_kwargs = {"device": "cuda", "dtype": torch.bfloat16 if dtype == torch.bfloat16 else torch.float32, "requires_grad": False} + x = torch.randn((batch, K, M), **factory_kwargs) x = x.mT if contiguous: x = x.contiguous() else: assert not x.is_contiguous() - factory_kwargs = {"device": "cuda", "dtype": torch.float32, "requires_grad": False} + layer_weight = torch.randn((N, K), **factory_kwargs) layer_bias = torch.randn((K,), **factory_kwargs) if bias else None pytorch_layer_activation = get_pytorch_activation(activation) @@ -96,4 +98,7 @@ def test_benchmark( value = benchmark(fn, x) - assert_all_close(expected, value.float(), rtol=1e-1, atol=1e-1) + if dtype == torch.bfloat16: + assert_all_close(expected, value, rtol=1e-1, atol=1e-1) + else: + assert_all_close(expected, value.float(), rtol=1e-1, atol=1e-1) diff --git a/test/test_model_optimization.py b/test/test_model_optimization.py index 4d8408d8..ed36d2a3 100644 --- a/test/test_model_optimization.py +++ b/test/test_model_optimization.py @@ -20,6 +20,7 @@ from conftest import set_seed from kernl.model_optimization import optimize_model +from kernl.autocast import autocast @set_seed() @@ -30,7 +31,7 @@ def test_optimized_model(): optimized_model = AutoModel.from_pretrained(model_name).eval().cuda() optimize_model(optimized_model) - with torch.inference_mode(), torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True): + with torch.inference_mode(), torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True), autocast(use_torch_autocast_dtype=True): inputs = { "input_ids": torch.randint(2, 10000, shape, device="cuda", dtype=torch.long), "attention_mask": torch.ones(shape, device="cuda", dtype=torch.long), diff --git a/test/test_torchdynamo.py b/test/test_torchdynamo.py index 795af5a6..d8dd411c 100644 --- a/test/test_torchdynamo.py +++ b/test/test_torchdynamo.py @@ -30,9 +30,8 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, WhisperForConditionalGeneration, 
WhisperProcessor from conftest import assert_all_close, set_seed - from kernl.model_optimization import optimize_model - +from kernl.autocast import autocast @dataclasses.dataclass class Implementation: @@ -74,7 +73,7 @@ def test_benchmark_implementations(benchmark, reference_fp32, shape: (int, int), with torch.inference_mode(): expected = reference_fp32(**inputs) model = implementation.model(reference_fp32) - with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True): + with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True), autocast(use_torch_autocast_dtype=True): value = benchmark(model, **inputs) assert_all_close(value["last_hidden_state"].float(), expected["last_hidden_state"].float(), rtol=1e-1, atol=1e-1) @@ -93,7 +92,7 @@ def test_support_shape_change(implementation): pytorch_input = get_input(model_reference_fp32, shape, is_causal=implementation.is_causal) with torch.inference_mode(): expected = model_reference_fp32(**pytorch_input) - with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True): + with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True), autocast(use_torch_autocast_dtype=True): result = model_tested(**pytorch_input) assert_all_close( result["last_hidden_state"].float(), expected["last_hidden_state"].float(), atol=1e-1, rtol=1e-1 @@ -107,7 +106,7 @@ def test_t5(): task = "translate English to French: The house is wonderful." inputs = tokenizer(task, return_tensors="pt", padding=True).to("cuda") - with torch.inference_mode(), torch.autocast(dtype=torch.float16, cache_enabled=True, device_type="cuda"): + with torch.inference_mode(), torch.autocast(dtype=torch.float16, cache_enabled=True, device_type="cuda"), autocast(use_torch_autocast_dtype=True): output_sequences = model.generate( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], @@ -120,7 +119,7 @@ def test_t5(): optimize_model(model.encoder) optimize_model(model.decoder) - with torch.inference_mode(), torch.autocast(dtype=torch.float16, cache_enabled=True, device_type="cuda"): + with torch.inference_mode(), torch.autocast(dtype=torch.float16, cache_enabled=True, device_type="cuda"), autocast(use_torch_autocast_dtype=True): output_sequences = model.generate( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], @@ -155,7 +154,7 @@ def fix_reorder_cache(past, beam_idx): processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") inputs = torch.load("test/data/whisper_input.pt") - with torch.inference_mode(), torch.autocast(dtype=torch.float16, cache_enabled=True, device_type="cuda"): + with torch.inference_mode(), torch.autocast(dtype=torch.float16, cache_enabled=True, device_type="cuda"), autocast(use_torch_autocast_dtype=True): predicted_ids = benchmark( model.generate, inputs, min_length=25, max_length=25, num_beams=num_beam, do_sample=False ) diff --git a/tutorial/bert e2e.ipynb b/tutorial/bert e2e.ipynb index e15f86ca..1ff1c287 100644 --- a/tutorial/bert e2e.ipynb +++ b/tutorial/bert e2e.ipynb @@ -937,4 +937,4 @@ }, "nbformat": 4, "nbformat_minor": 1 -} +} \ No newline at end of file diff --git a/tutorial/t5 e2e.ipynb b/tutorial/t5 e2e.ipynb index e55aa2b8..b07e6f43 100644 --- a/tutorial/t5 e2e.ipynb +++ b/tutorial/t5 e2e.ipynb @@ -16,36 +16,29 @@ "name": "stdout", "output_type": "stream", "text": [ - "Collecting tokenizer\r\n", - " Downloading tokenizer-3.4.2-py2.py3-none-any.whl (79 kB)\r\n", - "\u001B[2K 
\u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m79.1/79.1 KB\u001B[0m \u001B[31m6.1 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\r\n", - "\u001B[?25hCollecting sentencepiece\r\n", - " Downloading sentencepiece-0.1.97-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\r\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m1.3/1.3 MB\u001B[0m \u001B[31m43.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\r\n", - "\u001B[?25hInstalling collected packages: tokenizer, sentencepiece\r\n", - "Successfully installed sentencepiece-0.1.97 tokenizer-3.4.2\r\n", - "\u001B[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001B[0m\u001B[33m\r\n", - "\u001B[0m\u001B[33mWARNING: You are using pip version 22.0.4; however, version 22.3 is available.\r\n", - "You should consider upgrading via the '/usr/bin/python3.9 -m pip install --upgrade pip' command.\u001B[0m\u001B[33m\r\n", - "\u001B[0mSat Oct 29 19:53:58 2022 \r\n", - "+-----------------------------------------------------------------------------+\r\n", - "| NVIDIA-SMI 510.85.02 Driver Version: 510.85.02 CUDA Version: 11.6 |\r\n", - "|-------------------------------+----------------------+----------------------+\r\n", - "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n", - "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\r\n", - "| | | MIG M. |\r\n", - "|===============================+======================+======================|\r\n", - "| 0 NVIDIA GeForce ... Off | 00000000:03:00.0 On | N/A |\r\n", - "| 30% 41C P8 24W / 350W | 300MiB / 24576MiB | 1% Default |\r\n", - "| | | N/A |\r\n", - "+-------------------------------+----------------------+----------------------+\r\n", - " \r\n", - "+-----------------------------------------------------------------------------+\r\n", - "| Processes: |\r\n", - "| GPU GI CI PID Type Process name GPU Memory |\r\n", - "| ID ID Usage |\r\n", - "|=============================================================================|\r\n", - "+-----------------------------------------------------------------------------+\r\n" + "Requirement already satisfied: tokenizer in /usr/local/lib/python3.9/dist-packages (3.4.2)\n", + "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.9/dist-packages (0.1.97)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mWed Mar 15 10:10:01 2023 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 470.161.03 Driver Version: 470.161.03 CUDA Version: 12.0 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. 
|\n", + "|===============================+======================+======================|\n", + "| 0 NVIDIA A10G On | 00000000:00:1E.0 Off | 0 |\n", + "| 0% 28C P0 42W / 300W | 0MiB / 22731MiB | 0% Default |\n", + "| | | N/A |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| No running processes found |\n", + "+-----------------------------------------------------------------------------+\n" ] } ], @@ -87,67 +80,11 @@ } }, "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9eddb20b6abf4b89aed7878e55e01b35", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Downloading: 0%| | 0.00/1.20k [00:00