55# 
66# ---------------------------------------------------------------------------- 
77
8- import  hashlib 
98import  inspect 
109import  logging 
1110import  shutil 
2221from  QEfficient .base .pytorch_transforms  import  PytorchTransform 
2322from  QEfficient .compile .qnn_compiler  import  compile  as  qnn_compile 
2423from  QEfficient .generation .cloud_infer  import  QAICInferenceSession 
25- from  QEfficient .utils  import  constants , create_json , dump_qconfig , generate_mdp_partition_config , load_json 
26- from  QEfficient .utils .cache  import  QEFF_HOME , to_hashable 
24+ from  QEfficient .utils  import  (
25+     constants ,
26+     create_json ,
27+     create_model_params ,
28+     dump_qconfig ,
29+     export_wrapper ,
30+     generate_mdp_partition_config ,
31+     hash_dict_params ,
32+     load_json ,
33+ )
2734
2835logger  =  logging .getLogger (__name__ )
2936
@@ -45,12 +52,16 @@ class QEFFBaseModel(ABC):
4552    def  _transform_names (cls ) ->  List [str ]:
4653        return  [x .__name__  for  x  in  cls ._pytorch_transforms  +  cls ._onnx_transforms ]
4754
48-     def  __init__ (self , model : torch .nn .Module ) ->  None :
55+     def  __init__ (self , model : torch .nn .Module ,  ** kwargs ) ->  None :
4956        super ().__init__ ()
5057        self .model  =  model 
58+         self .hash_params  =  create_model_params (self , ** kwargs )
5159        self .onnx_path : Optional [str ] =  None 
5260        self .qpc_path : Optional [str ] =  None 
5361        self .qpc_session : Optional [QAICInferenceSession ] =  None 
62+         self .model_architecture  =  (
63+             (arch  :=  getattr (self .model .config , "architectures" , None )) and  len (arch ) >  0  and  arch [0 ]
64+         ) or  None 
5465
5566        # Apply the transformations 
5667        any_transformed  =  False 
@@ -67,10 +78,6 @@ def __init__(self, model: torch.nn.Module) -> None:
6778    @abstractmethod  
6879    def  model_name (self ) ->  str : ...
6980
70-     @property  
71-     @abstractmethod  
72-     def  model_hash (self ) ->  str : ...
73- 
7481    @abstractmethod  
7582    def  export (self , export_dir : Optional [str ] =  None ) ->  Path :
7683        """ 
@@ -114,6 +121,7 @@ def compile(self, *args, **kwargs) -> Path:
114121            :str: Path of the compiled ``qpc`` package. 
115122        """ 
116123
124+     @export_wrapper  
117125    def  _export (
118126        self ,
119127        example_inputs : Dict [str , torch .Tensor ],
@@ -134,8 +142,6 @@ def _export(
134142            :onnx_transform_kwargs (dict): Additional arguments to be passed to `Transform.apply` for this class. 
135143            :export_dir (str): Specify the export directory. The export_dir will be suffixed with a hash corresponding to current model. 
136144        """ 
137-         export_dir  =  Path (export_dir  or  (QEFF_HOME  /  self .model_name ))
138-         export_dir  =  export_dir .with_name (export_dir .name  +  "-"  +  self .model_hash )
139145        onnx_path  =  export_dir  /  f"{ self .model_name }  
140146        if  onnx_path .is_file ():
141147            self .onnx_path  =  onnx_path 
@@ -299,23 +305,16 @@ def _compile(
299305        else :
300306            mdp_ts_json  =  None 
301307
302-         compile_hash  =  hashlib .sha256 (to_hashable (command ))
303- 
304-         if  specializations  is  not None :
305-             compile_hash .update (to_hashable (specializations ))
306- 
307-         if  custom_io  is  not None :
308-             compile_hash .update (to_hashable (custom_io ))
309- 
310-         if  num_speculative_tokens :
311-             compile_hash .update (to_hashable ({"num_speculative_tokens" : num_speculative_tokens }))
312- 
313-         # Hash the MDP partition config and the number of devices. 
314-         compile_hash .update (to_hashable (mdp_ts_json ))
315-         compile_hash .update (to_hashable ({"mdp_ts_num_devices" : mdp_ts_num_devices }))
308+         compile_hash_params  =  {
309+             "command" : command ,
310+             "specializations" : specializations ,
311+             "custom_io" : custom_io ,
312+             "mdp_ts_num_devices" : mdp_ts_num_devices ,
313+             "mdp_ts_json" : mdp_ts_json ,
314+             "num_speculative_tokens" : num_speculative_tokens ,
315+         }
316+         compile_hash  =  hash_dict_params (compile_hash_params )
316317
317-         # Check if already compiled 
318-         compile_hash  =  compile_hash .hexdigest ()[:16 ]
319318        compile_dir  =  qpc_path .with_name (qpc_path .name  +  "-"  +  compile_hash )
320319        qpc_path  =  compile_dir  /  "qpc" 
321320        qpc_path .mkdir (parents = True , exist_ok = True )
@@ -366,6 +365,10 @@ def _compile(
366365                    ]
367366                )
368367            )
368+         # Dump JSON file with hashed parameters 
369+         hashed_compile_params_path  =  compile_dir  /  "hashed_compile_params.json" 
370+         create_json (hashed_compile_params_path , compile_hash_params )
371+         logger .info ("Hashed parameters exported successfully." )
369372
370373        self .qpc_path  =  qpc_path 
371374
0 commit comments