@@ -8,8 +8,8 @@ index 2bd4e90..e18a296 100644
  import timeit
 + import time
 + import sys
-+ import threading  
-  
++ import threading
+ 
  import numpy as np
  import torch
 @@ -45,19 +48,22 @@  from transformers.data.metrics.squad_metrics import (
@@ -18,29 +18,29 @@ index 2bd4e90..e18a296 100644
  from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor
 - 
 + import intel_extension_for_pytorch as ipex
-  
+ 
  try:
      from torch.utils.tensorboard import SummaryWriter
  except ImportError:
      from tensorboardX import SummaryWriter
-  
+ 
 - 
  logger = logging.getLogger(__name__)
-  
+ 
  MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
-  
+ 
 + def trace_handler(prof):
 +     print(prof.key_averages().table(
 +         sort_by="self_cpu_time_total", row_limit=-1))
 +     prof.export_chrome_trace("./log/test_trace_" + str(prof.step_num) + ".json")
-  
+ 
  def set_seed(args):
      random.seed(args.seed)
 @@ -264,19 +270,95 @@  def train(args, train_dataset, model, tokenizer):
-  
+ 
      return global_step, tr_loss / global_step
-  
+ 
 + def wrap_model(model, args):
 +     model.eval()
 +     ipex.nn.utils._model_convert.replace_dropout_with_identity(model)
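
Aside: the trace_handler defined above only fires once it is registered with a torch.profiler schedule, and the patch leaves the corresponding prof.step() call commented out further down. A minimal sketch of the intended wiring, reusing trace_handler plus the model and eval_dataloader from this file (the wait/warmup/active counts are illustrative):

    import torch
    from torch.profiler import ProfilerActivity, profile, schedule

    with profile(
        activities=[ProfilerActivity.CPU],
        schedule=schedule(wait=1, warmup=1, active=3),
        on_trace_ready=trace_handler,  # the handler defined in the patch above
    ) as prof:
        with torch.no_grad():
            for batch in eval_dataloader:
                model(batch[0], batch[1], batch[2])
                prof.step()  # advance the schedule; trace_handler runs after each active window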
@@ -63,7 +63,7 @@ index 2bd4e90..e18a296 100644
 +         # enable the fusion path (it needs to run for two iterations).
 +         with torch.no_grad():
 +             y = model(dumpy_tensor, dumpy_tensor, dumpy_tensor)
-+             y = model(dumpy_tensor, dumpy_tensor, dumpy_tensor)  
++             y = model(dumpy_tensor, dumpy_tensor, dumpy_tensor)
 +             #dumpy_tensor = torch.ones((128, 384), dtype=torch.long)
 +             #y = model(dumpy_tensor, dumpy_tensor, dumpy_tensor)
 +             #dumpy_tensor = torch.ones((81, 384), dtype=torch.long)
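
The duplicated forward pass above is not a typo: per the patch's own comment, the fused execution path is only selected after the first run, so the model has to be warmed up before anything is measured. A generic helper capturing that idea (names are illustrative, not part of the patch):

    import torch

    def warm_up(model, example_inputs, iters=2):
        # A couple of throwaway forward passes so the JIT/oneDNN fusion
        # path kicks in before timing starts.
        with torch.no_grad():
            for _ in range(iters):
                model(*example_inputs)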
@@ -79,9 +79,9 @@ index 2bd4e90..e18a296 100644
 +     elif args.use_jit:
 +         with torch.no_grad():
 +             model = torch.jit.trace(model, jit_inputs, strict=False)
-+             #model = torch.jit._recursive.wrap_cpp_module(torch._C._freeze_module(model._c, preserveParameters=True))  
++             #model = torch.jit._recursive.wrap_cpp_module(torch._C._freeze_module(model._c, preserveParameters=True))
 +             model = torch.jit.freeze(model)
-+     return model  
++     return model
 + 
 + def benchmark_evaluate(args, model, eval_dataloader):
 +     total_time = 0
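
Standalone, the trace-and-freeze path in wrap_model reduces to the sketch below. The all-ones 384-token inputs mirror the dumpy_tensor warm-up above; the batch size of 8 is an assumption:

    import torch

    def jit_optimize(model, batch_size=8, seq_len=384):
        model.eval()  # freeze() requires an eval-mode ScriptModule
        ids = torch.ones((batch_size, seq_len), dtype=torch.long)
        jit_inputs = (ids, ids, ids)  # input_ids, attention_mask, token_type_ids
        with torch.no_grad():
            traced = torch.jit.trace(model, jit_inputs, strict=False)
            traced = torch.jit.freeze(traced)  # fold weights and constants into the graph
        return traced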
@@ -107,32 +107,32 @@ index 2bd4e90..e18a296 100644
 +                 "token_type_ids": batch[2],
 +             }
 +             time_start = time.time()
-+             outputs = model(**inputs)           
++             outputs = model(**inputs)
 +             #prof.step()
 +             time_end = time.time()
 +             i += 1
 +             if i > args.perf_begin_iter:
 +                 total_time += (time_end - time_start)
 +             if i >= args.perf_run_iters + args.perf_begin_iter:
 +                 throughput = args.eval_batch_size * args.perf_run_iters / total_time
-+                 print("Throughput: {:.3f} sentence/s".format(throughput))      
++                 print("Throughput: {:.3f} sentence/s".format(throughput))
 +                 break
-  
+ 
  def evaluate(args, model, tokenizer, prefix=""):
      dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
 - 
      if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
          os.makedirs(args.output_dir)
 - 
-+   
++ 
      args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
-  
+ 
 +     model = wrap_model(model, args)
      # Note that DistributedSampler samples randomly
      eval_sampler = SequentialSampler(dataset)
      eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
 - 
-+        
++ 
      # multi-gpu evaluate
      if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
          model = torch.nn.DataParallel(model)
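
The benchmark_evaluate loop is the standard skip-warmup-then-time pattern. A self-contained version of the same arithmetic, with hypothetical warmup/iters defaults standing in for perf_begin_iter/perf_run_iters (assumes the loader yields dict batches and enough of them):

    import time
    import torch

    def measure_throughput(model, loader, batch_size, warmup=15, iters=100):
        done, total_time = 0, 0.0
        with torch.no_grad():
            for batch in loader:
                start = time.time()
                model(**batch)
                elapsed = time.time() - start
                done += 1
                if done > warmup:           # discard warm-up iterations
                    total_time += elapsed
                if done >= warmup + iters:  # sentences per second over the timed window
                    return batch_size * iters / total_time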
@@ -141,7 +141,7 @@ index 2bd4e90..e18a296 100644
      logger.info("  Num examples = %d", len(dataset))
      logger.info("  Batch size = %d", args.eval_batch_size)
 - 
-+      
++ 
 +     if args.do_calibration:
 +         conf = ipex.quantization.QuantConf(qscheme=torch.per_tensor_affine)
 +         for step, batch in enumerate(eval_dataloader):
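
For orientation, the do_calibration branch follows the legacy IPEX 1.x post-training quantization flow. The sketch below reconstructs it from that era's documented API; ipex.quantization.calibrate and QuantConf.save are assumptions about the old interface and are gone from current IPEX releases:

    import torch
    import intel_extension_for_pytorch as ipex

    def run_calibration(model, loader, config_path, max_batches=100):
        conf = ipex.quantization.QuantConf(qscheme=torch.per_tensor_affine)
        model.eval()
        with torch.no_grad():
            for step, batch in enumerate(loader):
                if step >= max_batches:
                    break
                with ipex.quantization.calibrate(conf):  # record observer ranges
                    model(batch[0], batch[1], batch[2])
        conf.save(config_path)  # e.g. args.int8_config, "config.json" by default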
@@ -168,18 +168,18 @@ index 2bd4e90..e18a296 100644
 +             for t in threads:
 +                 t.join()
 +         else:
-+             benchmark_evaluate(args, model, eval_dataloader)         
++             benchmark_evaluate(args, model, eval_dataloader)
 +         exit()
      all_results = []
      start_time = timeit.default_timer()
 - 
-+      
++ 
      for batch in tqdm(eval_dataloader, desc="Evaluating"):
          model.eval()
          batch = tuple(t.to(args.device) for t in batch)
 @@ -658,6 +768,36 @@  def main():
      parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
-  
+ 
      parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
 +     parser.add_argument(
 +         "--bf16",
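
The share-weight branch whose tail is visible at the top of this hunk runs benchmark_evaluate on several Python threads against one model instance, so the weights are shared in memory instead of duplicated per process. A minimal reconstruction, where args.instances is a hypothetical name for the instance-count option:

    import threading

    def run_share_weight(args, model, eval_dataloader):
        threads = []
        for _ in range(args.instances):  # hypothetical flag, one thread per instance
            t = threading.Thread(
                target=benchmark_evaluate,
                args=(args, model, eval_dataloader),
            )
            t.start()
            threads.append(t)
        for t in threads:  # matches the t.join() loop in the patch
            t.join()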
@@ -198,7 +198,7 @@ index 2bd4e90..e18a296 100644
 +                         help='use LLGA int8 in the PyTorch JIT model')
 +     parser.add_argument('--int8_fp32', dest='int8_fp32', action='store_true',
 +                         help='use int8/fp32 mixed precision')
-+     parser.add_argument("--int8_config", type=str, default="config.json",  
++     parser.add_argument("--int8_config", type=str, default="config.json",
 +                         help="quantization config file for int8 mode")
 +     parser.add_argument("--do_calibration", action='store_true',
 +                         help="Enable calibration process")
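
How the new precision flags typically take effect at inference time, sketched with public APIs (ipex.optimize and torch.autocast exist as written, but their use here is an assumption about what wrap_model does with args.bf16):

    import torch
    import intel_extension_for_pytorch as ipex

    def apply_bf16(model):
        model.eval()
        # Prepack weights and pick bf16-friendly kernels.
        return ipex.optimize(model, dtype=torch.bfloat16)

    # Usage: run the forward pass under CPU autocast so matmuls execute in bfloat16.
    # with torch.no_grad(), torch.autocast("cpu", dtype=torch.bfloat16):
    #     outputs = model(**inputs)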
@@ -212,20 +212,20 @@ index 2bd4e90..e18a296 100644
 +                         help="Total cores used for this process, used for share_weight mode")
 + 
      args = parser.parse_args()
-  
+ 
      if args.doc_stride >= args.max_seq_length - args.max_query_length:
 diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py
 index 23d25cf..b281147 100644
 --- a/src/transformers/modeling_bert.py
 +++ b/src/transformers/modeling_bert.py
 @@ -139,7 +139,7 @@  def mish(x):
      return x * torch.tanh(nn.functional.softplus(x))
-  
-  
+ 
+ 
 - ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}
 + ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}
-  
-  
+ 
+ 
  BertLayerNorm = torch.nn.LayerNorm
 @@ -239,6 +239,8 @@  class BertSelfAttention(nn.Module):
          attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
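
The ACT2FN change swaps transformers' Python-level gelu for the fused torch.nn.functional.gelu kernel; both compute the same erf-based GELU, so the substitution is numerically safe. A quick check, where gelu_python reproduces the old eager definition from modeling_bert.py:

    import math
    import torch

    def gelu_python(x):
        # The definition being replaced in modeling_bert.py.
        return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

    x = torch.randn(4, 16)
    assert torch.allclose(gelu_python(x), torch.nn.functional.gelu(x), atol=1e-6)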
@@ -235,4 +235,4 @@ index 23d25cf..b281147 100644
 +                 attention_mask = attention_mask.to(attention_scores.dtype)
              # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
              attention_scores = attention_scores + attention_mask
-  
+ 
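
The added cast matters for the low-precision paths: under bf16 the attention_scores matmul produces bfloat16, while the precomputed attention_mask is still float32, and without the .to(...) the addition would silently promote the result back to float32. A small demonstration of the promotion rule:

    import torch

    scores = torch.randn(2, 12, 384, 384, dtype=torch.bfloat16)
    mask = torch.zeros(2, 1, 1, 384)               # float32 by default

    print((scores + mask).dtype)                   # torch.float32 (promoted)
    print((scores + mask.to(scores.dtype)).dtype)  # torch.bfloat16 (kept)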