Skip to content

Commit 71a3637

Browse files
authored
Gate debug messages in dump_gguf.py behind -v flag (#1105)
I need this because I've been using `dump_gguf.py` for some integration tests (see https://github.com/nod-ai/shark-ai/blob/41a09a836c4558745ec1ca01e80a819cd761e608/app_tests/integration_tests/llm/model_management.py#L223-L239)
1 parent 4ce3326 commit 71a3637

File tree

1 file changed

+44
-24
lines changed

1 file changed

+44
-24
lines changed

sharktank/sharktank/tools/dump_gguf.py

+44-24
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from pathlib import Path
88
import re
9+
import logging
910

1011
import numpy as np
1112
import torch
@@ -17,6 +18,9 @@
1718
def main():
1819
from ..utils import cli
1920

21+
# Set up logging
22+
logger = logging.getLogger(__name__)
23+
2024
parser = cli.create_parser()
2125
cli.add_input_dataset_options(parser)
2226
parser.add_argument(
@@ -28,63 +32,79 @@ def main():
2832
parser.add_argument(
2933
"--save", type=Path, help="Save the GGUF dataset to an IRPA file"
3034
)
35+
parser.add_argument(
36+
"--verbose", "-v", action="store_true", help="Enable verbose output"
37+
)
3138
args = cli.parse(parser)
39+
40+
# Configure logging based on verbosity
41+
if args.verbose:
42+
logging.basicConfig(level=logging.DEBUG)
43+
else:
44+
logging.basicConfig(level=logging.INFO)
45+
3246
config = cli.get_input_dataset(args)
3347

3448
if args.save is not None:
3549

3650
def report(s):
37-
print(f"Save: {s}")
51+
logger.info(f"Save: {s}")
3852

39-
print(f"Saving to: {args.save}")
53+
logger.info(f"Saving to: {args.save}")
4054
config.save(args.save, io_report_callback=report)
4155
return
4256

43-
print(f"Properties:")
57+
logger.debug("Properties:")
4458
for key, value in config.properties.items():
45-
print(f" {key} = {value} (of {type(value)})")
46-
print("Tensors:")
59+
logger.debug(f" {key} = {value} (of {type(value)})")
60+
61+
logger.debug("Tensors:")
4762
for tensor in config.root_theta.flatten().values():
4863
if args.tensor_regex is not None:
4964
if not re.search(args.tensor_regex, tensor.name):
5065
continue
51-
print(f" {tensor}")
66+
67+
logger.debug(f" {tensor}")
5268
if isinstance(tensor, PrimitiveTensor):
5369
torch_tensor = tensor.as_torch()
54-
print(
70+
logger.debug(
5571
f" : torch.Tensor({list(torch_tensor.shape)}, "
5672
f"dtype={torch_tensor.dtype}) = {tensor.as_torch()}"
5773
)
5874
elif isinstance(tensor, QuantizedTensor):
59-
print(f" : QuantizedTensor({tensor.layout_type.__name__})")
75+
logger.debug(f" : QuantizedTensor({tensor.layout_type.__name__})")
6076
try:
6177
unpacked = tensor.unpack()
62-
print(f" {unpacked}")
78+
logger.debug(f" {unpacked}")
6379
except NotImplementedError:
64-
print(f" NOT IMPLEMENTED")
80+
logger.warning(f" Unpacking NOT IMPLEMENTED for {tensor.name}")
6581
elif isinstance(tensor, ShardedTensor):
6682
for i, pt in enumerate(tensor.shards):
67-
print(f" {i}: {pt}")
83+
logger.debug(f" {i}: {pt}")
6884

69-
_maybe_dump_tensor(args, tensor)
85+
_maybe_dump_tensor(args, tensor, logger)
7086

7187

72-
def _maybe_dump_tensor(args, t: InferenceTensor):
88+
def _maybe_dump_tensor(args, t: InferenceTensor, logger: logging.Logger):
7389
if not args.dump_tensor_dir:
7490
return
7591
dir: Path = args.dump_tensor_dir
7692
dir.mkdir(parents=True, exist_ok=True)
77-
print(f" (Dumping to {dir})")
78-
79-
if isinstance(t, PrimitiveTensor):
80-
torch_tensor = t.as_torch()
81-
np.save(dir / f"{t.name}.npy", torch_tensor.detach().numpy())
82-
elif isinstance(t, QuantizedTensor):
83-
layout: QuantizedLayout = t.unpack()
84-
dq = layout.dequant()
85-
np.save(dir / f"{t.name}.dequant.npy", dq.detach().numpy())
86-
else:
87-
raise AssertionError(f"Unexpected tensor type: {type(t)}")
93+
logger.info(f"Dumping tensor {t.name} to {dir}")
94+
95+
try:
96+
if isinstance(t, PrimitiveTensor):
97+
torch_tensor = t.as_torch()
98+
np.save(dir / f"{t.name}.npy", torch_tensor.detach().numpy())
99+
elif isinstance(t, QuantizedTensor):
100+
layout: QuantizedLayout = t.unpack()
101+
dq = layout.dequant()
102+
np.save(dir / f"{t.name}.dequant.npy", dq.detach().numpy())
103+
else:
104+
logger.error(f"Unexpected tensor type: {type(t)}")
105+
raise AssertionError(f"Unexpected tensor type: {type(t)}")
106+
except Exception as e:
107+
logger.error(f"Failed to dump tensor {t.name}: {str(e)}")
88108

89109

90110
if __name__ == "__main__":

0 commit comments

Comments
 (0)