55#
66# ----------------------------------------------------------------------------
77
8- import hashlib
98import inspect
109import logging
1110import shutil
2221from QEfficient .base .pytorch_transforms import PytorchTransform
2322from QEfficient .compile .qnn_compiler import compile as qnn_compile
2423from QEfficient .generation .cloud_infer import QAICInferenceSession
25- from QEfficient .utils import constants , create_json , dump_qconfig , generate_mdp_partition_config , load_json
26- from QEfficient .utils .cache import QEFF_HOME , to_hashable
24+ from QEfficient .utils import (
25+ constants ,
26+ create_json ,
27+ create_model_params ,
28+ dump_qconfig ,
29+ export_wrapper ,
30+ generate_mdp_partition_config ,
31+ hash_dict_params ,
32+ load_json ,
33+ )
2734
2835logger = logging .getLogger (__name__ )
2936
@@ -45,12 +52,16 @@ class QEFFBaseModel(ABC):
4552 def _transform_names (cls ) -> List [str ]:
4653 return [x .__name__ for x in cls ._pytorch_transforms + cls ._onnx_transforms ]
4754
48- def __init__ (self , model : torch .nn .Module ) -> None :
55+ def __init__ (self , model : torch .nn .Module , ** kwargs ) -> None :
4956 super ().__init__ ()
5057 self .model = model
58+ self .hash_params = create_model_params (self , ** kwargs )
5159 self .onnx_path : Optional [str ] = None
5260 self .qpc_path : Optional [str ] = None
5361 self .qpc_session : Optional [QAICInferenceSession ] = None
62+ self .model_architecture = (
63+ (arch := getattr (self .model .config , "architectures" , None )) and len (arch ) > 0 and arch [0 ]
64+ ) or None
5465
5566 # Apply the transformations
5667 any_transformed = False
@@ -67,10 +78,6 @@ def __init__(self, model: torch.nn.Module) -> None:
6778 @abstractmethod
6879 def model_name (self ) -> str : ...
6980
70- @property
71- @abstractmethod
72- def model_hash (self ) -> str : ...
73-
7481 @abstractmethod
7582 def export (self , export_dir : Optional [str ] = None ) -> Path :
7683 """
@@ -114,6 +121,7 @@ def compile(self, *args, **kwargs) -> Path:
114121 :str: Path of the compiled ``qpc`` package.
115122 """
116123
124+ @export_wrapper
117125 def _export (
118126 self ,
119127 example_inputs : Dict [str , torch .Tensor ],
@@ -134,8 +142,6 @@ def _export(
134142 :onnx_transform_kwargs (dict): Additional arguments to be passed to `Transform.apply` for this class.
135143 :export_dir (str): Specify the export directory. The export_dir will be suffixed with a hash corresponding to current model.
136144 """
137- export_dir = Path (export_dir or (QEFF_HOME / self .model_name ))
138- export_dir = export_dir .with_name (export_dir .name + "-" + self .model_hash )
139145 onnx_path = export_dir / f"{ self .model_name } .onnx"
140146 if onnx_path .is_file ():
141147 self .onnx_path = onnx_path
@@ -299,23 +305,16 @@ def _compile(
299305 else :
300306 mdp_ts_json = None
301307
302- compile_hash = hashlib .sha256 (to_hashable (command ))
303-
304- if specializations is not None :
305- compile_hash .update (to_hashable (specializations ))
306-
307- if custom_io is not None :
308- compile_hash .update (to_hashable (custom_io ))
309-
310- if num_speculative_tokens :
311- compile_hash .update (to_hashable ({"num_speculative_tokens" : num_speculative_tokens }))
312-
313- # Hash the MDP partition config and the number of devices.
314- compile_hash .update (to_hashable (mdp_ts_json ))
315- compile_hash .update (to_hashable ({"mdp_ts_num_devices" : mdp_ts_num_devices }))
308+ compile_hash_params = {
309+ "command" : command ,
310+ "specializations" : specializations ,
311+ "custom_io" : custom_io ,
312+ "mdp_ts_num_devices" : mdp_ts_num_devices ,
313+ "mdp_ts_json" : mdp_ts_json ,
314+ "num_speculative_tokens" : num_speculative_tokens ,
315+ }
316+ compile_hash = hash_dict_params (compile_hash_params )
316317
317- # Check if already compiled
318- compile_hash = compile_hash .hexdigest ()[:16 ]
319318 compile_dir = qpc_path .with_name (qpc_path .name + "-" + compile_hash )
320319 qpc_path = compile_dir / "qpc"
321320 qpc_path .mkdir (parents = True , exist_ok = True )
@@ -366,6 +365,10 @@ def _compile(
366365 ]
367366 )
368367 )
368+ # Dump JSON file with hashed parameters
369+ hashed_compile_params_path = compile_dir / "hashed_compile_params.json"
370+ create_json (hashed_compile_params_path , compile_hash_params )
371+ logger .info ("Hashed parameters exported successfully." )
369372
370373 self .qpc_path = qpc_path
371374
0 commit comments