forked from isHuangXin/deepcs4plugin
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrepr_code.py
More file actions
78 lines (66 loc) · 3.43 KB
/
repr_code.py
File metadata and controls
78 lines (66 loc) · 3.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import sys
from datetime import datetime
import numpy as np
import argparse
from tqdm import tqdm
import logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(message)s")
import torch
from utils import normalize
from data_loader import CodeSearchDataset, save_vecs
import models, configs
##### Compute Representation #####
def _save_chunk(vecs, data_path, config, chunk_id):
    """Stack the buffered batch vectors and write them out as chunk ``chunk_id``.

    The file name is built from ``config['use_codevecs']`` with its last three
    characters removed (presumably a ".h5" extension -- TODO confirm), followed
    by ``_part{chunk_id}.h5``, and is placed under ``data_path``.
    """
    output_path = f"{data_path}{config['use_codevecs'][:-3]}_part{chunk_id}.h5"
    save_vecs(np.vstack(vecs), output_path)


def repr_code(args):
    """Encode the code corpus with the model and store the vectors in chunk files.

    The model class and its configuration are looked up by ``args.model``; when
    ``args.reload_from`` is positive, the weights are loaded from
    ``./trained_model/step{args.reload_from}.h5``. Every batch of the dataset is
    passed through ``model.code_encoding``, cast to float32, optionally
    normalized, and buffered. The buffer is written to disk each time at least
    ``args.chunk_size`` items have been accumulated, and once more at the end
    for any remaining items.

    Args:
        args: Parsed command-line options. The attributes read here are
            ``gpu_id``, ``model``, ``reload_from``, ``data_path``, ``dataset``,
            ``batch_size`` and ``chunk_size``.
    """
    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
    config = getattr(configs, 'config_' + args.model)()

    ##### Define model ######
    logger.info('Constructing Model..')
    model = getattr(models, args.model)(config)  # initialize the model
    if args.reload_from > 0:
        ckpt_path = f"./trained_model/step{args.reload_from}.h5"
        # ckpt_path = f'./output/{args.model}/{args.dataset}/{args.timestamp}/models/step{args.reload_from}.h5'
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model = model.to(device)
    model.eval()

    data_path = args.data_path + args.dataset + '/'
    # NOTE(review): eval() turns the configured dataset name into a class. This
    # executes whatever string the config holds, so the config must be trusted.
    use_set = eval(config['dataset_name'])(
        data_path,
        config['use_names'], config['name_len'],
        config['use_apis'], config['api_len'],
        config['use_tokens'], config['tokens_len'],
    )
    data_loader = torch.utils.data.DataLoader(
        dataset=use_set, batch_size=args.batch_size,
        shuffle=False, drop_last=False, num_workers=1,
    )

    # do normalization for fast cosine computation
    normalize_reprs = config['sim_measure'] == 'cos'

    chunk_id = 0
    vecs, n_processed = [], 0
    for batch in tqdm(data_loader):
        batch_gpu = [tensor.to(device) for tensor in batch]
        with torch.no_grad():
            reprs = model.code_encoding(*batch_gpu).detach().cpu().numpy()
        reprs = reprs.astype(np.float32)  # [batch x dim]
        if normalize_reprs:
            reprs = normalize(reprs)
        vecs.append(reprs)
        n_processed += batch[0].size(0)
        # NOTE(review): chunks are cut on batch boundaries, so a chunk can hold
        # more than chunk_size items when batch_size does not divide chunk_size
        # -- confirm this matches what the consumer of the chunk files expects.
        if n_processed >= args.chunk_size:
            _save_chunk(vecs, data_path, config, chunk_id)
            chunk_id += 1
            vecs, n_processed = [], 0

    # Save the last chunk (probably incomplete). The buffer is empty when the
    # item count filled the previous chunk exactly (or the dataset is empty);
    # np.vstack([]) would raise ValueError in that case, so skip the write.
    if vecs:
        _save_chunk(vecs, data_path, config, chunk_id)
    elif chunk_id == 0:
        logger.warning('No code vectors were produced; nothing was saved.')
def parse_args():
    """Parse the command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace: with the attributes ``data_path``, ``model``,
        ``dataset``, ``timestamp``, ``reload_from``, ``batch_size``,
        ``chunk_size`` and ``gpu_id``.
    """
    # Pass the text as `description`: the first positional parameter of
    # ArgumentParser is `prog`, which would replace the program name in the
    # usage line with this sentence.
    parser = argparse.ArgumentParser(description="Train and Test Code Search(Embedding) Model")
    parser.add_argument('--data_path', type=str, default='./data/', help='location of the data corpus')
    parser.add_argument('--model', type=str, default='JointEmbeder', help='model name')
    parser.add_argument('-d', '--dataset', type=str, default='example', help='dataset')
    parser.add_argument('-t', '--timestamp', type=str, help='time stamp')
    parser.add_argument('--reload_from', type=int, default=4000000, help='step to reload from')
    parser.add_argument('--batch_size', type=int, default=10000,
                        help='how many instances for encoding and normalization at each step')
    parser.add_argument('--chunk_size', type=int, default=2000000,
                        help='split code vector into chunks and store them individually. '
                             'Note: should be consistent with the same argument in the search.py')
    parser.add_argument('-g', '--gpu_id', type=int, default=0, help='GPU ID')
    return parser.parse_args()
# Script entry point: parse the command-line options and hand them to repr_code.
if __name__ == "__main__":
    repr_code(parse_args())