
Commit

I'm going to start copying it through once
sybs5968 committed Jan 10, 2024
1 parent ea3f852 commit eb91f71
Showing 4 changed files with 103 additions and 10 deletions.
1 change: 1 addition & 0 deletions main.py
@@ -66,6 +66,7 @@ def main():
return
(G, labels), task = read_file(args, logger)
dataloaders, out_features = get_data(G, task=task, labels=labels, args=args, logger=logger)
# out_features is the number of unique labels after deduplication
storage = estimate_storage(dataloaders, ['train_loader', 'val_loader', 'test_loader'], logger)
model = get_model(layers=args.layers, in_features=dataloaders[0].dataset[0].x.shape[-1], out_features=out_features,
prop_depth=args.prop_depth, args=args, logger=logger)
2 changes: 2 additions & 0 deletions models/models.py
@@ -58,8 +58,10 @@ def get_minibatch_embeddings(self, x, batch):
device = x.device
set_indices, batch, num_graphs = batch.set_indices, batch.batch, batch.num_graphs
num_nodes = torch.eye(num_graphs)[batch].to(device).sum(dim=0)
# count the number of nodes in each graph; num_nodes shape = [1 * num_graphs]
zero = torch.tensor([0], dtype=torch.long).to(device)
index_bases = torch.cat([zero, torch.cumsum(num_nodes, dim=0, dtype=torch.long)[:-1]])
# drop the last entry of the cumulative node counts and prepend a zero
index_bases = index_bases.unsqueeze(1).expand(-1, set_indices.size(-1))
assert(index_bases.size(0) == set_indices.size(0))
set_indices_batch = index_bases + set_indices
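A minimal sketch of the offset trick these comments describe, using a made-up batch of two graphs with 3 and 2 nodes (tensor names mirror the snippet above; the values are only illustrative):

import torch

# Toy batch: graph 0 has 3 nodes, graph 1 has 2 nodes.
batch = torch.tensor([0, 0, 0, 1, 1])                        # graph id of every node in the batch
num_graphs = 2
num_nodes = torch.eye(num_graphs)[batch].sum(dim=0)          # tensor([3., 2.]) -> nodes per graph

zero = torch.tensor([0], dtype=torch.long)
# Drop the last cumulative count and prepend a zero: the starting offset of each graph
# inside the concatenated node tensor.
index_bases = torch.cat([zero, torch.cumsum(num_nodes, dim=0, dtype=torch.long)[:-1]])
print(index_bases)                                           # tensor([0, 3])

# Within-graph indices of the target node set for each graph ...
set_indices = torch.tensor([[0, 2], [1, 0]])
# ... become batch-global indices once the per-graph offsets are added.
print(index_bases.unsqueeze(1) + set_indices)                # tensor([[0, 2], [4, 3]])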
66 changes: 62 additions & 4 deletions test.ipynb
@@ -2,18 +2,68 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'0'}\n",
"0\n",
"176468\n",
"0\n",
"0\n",
"0\n"
]
}
],
"source": [
"import utils\n",
"path = r\"data/link_prediction/facebook/edges.txt\"\n",
"nodes = []\n",
"with open(path) as f:\n",
" for line in f.readlines():\n",
" # print(line.strip().split()[:2])\n",
" nodes.extend(line.strip().split()[:2])\n",
"nodes = sorted(list(set(nodes)))\n",
"print(len(nodes))"
" # break\n",
"print(set(nodes[0]))\n",
"print(nodes[0])\n",
"# nodes = sorted(list(set(nodes)))\n",
"print(len(nodes))\n",
"print(nodes[0])\n",
"for new_id , old_id in enumerate(nodes):\n",
" print(new_id)\n",
" print(old_id)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([ 40, 85, 135, 190])\n",
"tensor([ 40, 85, 135])\n",
"tensor([ 0, 40, 85, 135])\n"
]
}
],
"source": [
"import torch\n",
"A = torch.arange(20).resize(5 , 4)\n",
"A = A.sum(dim=0)\n",
"print(torch.cumsum(A , dim=0))\n",
"\n",
"print(torch.cumsum(A , dim=0)[:-1])\n",
"print(torch.cat([torch.tensor([0]) , torch.cumsum(A , dim=0)[:-1]]))\n",
"tm = torch.cat([torch.tensor([0]) , torch.cumsum(A , dim=0)[:-1]])\n",
"tm = tm.unsqueeze(1)\n",
"print(tm)"
]
}
],
@@ -24,7 +74,15 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
}
},
44 changes: 38 additions & 6 deletions utils.py
@@ -79,8 +79,8 @@ def read_label(dir, task):
labels: the information read from labels.txt
node_id_mapping: the remapped node IDs
otherwise
labels None
node_id_mapping
labels = None
node_id_mapping: the remapped node IDs
"""
if task == 'node_classification':
f_path = dir + 'labels.txt'
@@ -104,6 +104,15 @@


def read_edges(dir, node_id_mapping):
"""读取edge
Args:
dir (string): edge所在的路径
node_id_mapping (dict): 数据中的点对应的新编号
Returns:
list: 返回边的列表[(u1 , v1) , (u2 , v2)]
"""
edges = []
fin_edges = open(dir + 'edges.txt')
for line in fin_edges.readlines():
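The loop body is collapsed in the diff; a hedged sketch of how such a reader can work, assuming (as in the notebook cell above that inspects data/link_prediction/facebook/edges.txt) that each line starts with the two endpoint IDs:

# Illustrative only -- the real body of read_edges continues below the fold and may differ.
def read_edges_sketch(dir, node_id_mapping):
    edges = []
    with open(dir + 'edges.txt') as fin_edges:
        for line in fin_edges.readlines():
            u, v = line.strip().split()[:2]                  # first two tokens are the endpoints
            edges.append((node_id_mapping[u], node_id_mapping[v]))
    return edges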
@@ -169,12 +178,16 @@ def get_data(G, task, args, labels, logger):
loader = DataLoader(data_list, batch_size=args.bs, shuffle=False, num_workers=0)
return loader
G, labels, set_indices, (train_mask, val_test_mask) = generate_samples_labels_graph(G, labels, task, args, logger)
# at this point we have all of the data that will be used
if args.debug:
logger.info(list(G.edges))

data_list = extract_subgaphs(G, labels, set_indices, prop_depth=args.prop_depth, layers=args.layers,
feature_flags=feature_flags, task=task,
max_sprw=(args.max_sp, args.rw_depth), parallel=args.parallel, logger=logger, debug=args.debug)
# build the data centered on each sample node
train_set, val_set, test_set = split_datalist(data_list, (train_mask, val_test_mask))
# take the training set via train_mask, then split val_test_mask in half into validation and test sets
if args.debug:
print_dataset(train_set, logger)
print_dataset(val_set, logger)
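A rough sketch of the split the new comment describes, assuming split_datalist receives 0/1 masks and halves the val/test portion itself (the actual helper is not shown here and may differ):

def split_datalist_sketch(data_list, masks):
    train_mask, val_test_mask = masks
    train_set = [d for d, m in zip(data_list, train_mask) if m]        # rows flagged by train_mask
    val_test = [d for d, m in zip(data_list, val_test_mask) if m]
    half = len(val_test) // 2                                          # split the remainder in half
    return train_set, val_test[:half], val_test[half:]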
@@ -195,7 +208,7 @@ def generate_samples_labels_graph(G, labels, task, args, logger):
logger (class): _description_
Returns:
_type_: mainly returns train_mask and test_mask
networkx , list , np.array , (np.array , np.array) : mainly returns train_mask and test_mask
"""
if labels is None:
logger.info('Labels unavailable. Generating training/test instances from dataset ...')
@@ -238,10 +251,29 @@ def generate_set_indices_labels(G, task, test_ratio, data_usage=1.0):


def extract_subgaphs(G, labels, set_indices, prop_depth, layers, feature_flags, task, max_sprw, parallel, logger, debug=False):
"""抓取以各个点为中心的子图信息
Args:
G (networkx): 图
labels (list): 标签
set_indices (array[n * 1]): 可用点的下标
prop_depth (int): 邻居深度
layers (int): 邻居深度
feature_flags (Tuple(sp , wr)): 要获取的特征
task (sting): 任务
max_sprw (Tuple(sp , wr)): 特征参数
parallel (boolen): 是否并行获取数据
logger (_type_): _description_
debug (bool, optional): _description_. Defaults to False.
Returns:
list: 以每个点为中心的邻居信息,每个元素都是一个Data类
"""
# deal with adj and features
logger.info('Encode positions ... (Parallel: {})'.format(parallel))
data_list = []
hop_num = get_hop_num(prop_depth, layers, max_sprw, feature_flags)
# hop_num = int(prop_depth * layers) + 1
n_samples = set_indices.shape[0]
if not parallel:
for sample_i in tqdm(range(n_samples)):
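For intuition about hop_num, a toy example with networkx.ego_graph standing in for the real extraction (get_data_sample below does considerably more, e.g. encoding positional features):

import networkx as nx

G = nx.path_graph(10)                          # 0 - 1 - 2 - ... - 9
prop_depth, layers = 1, 2
hop_num = int(prop_depth * layers) + 1         # matches the commented formula above

# Every node within hop_num hops of node 0 -- the neighbourhood one sample subgraph covers.
sub = nx.ego_graph(G, 0, radius=hop_num)
print(sorted(sub.nodes()))                     # [0, 1, 2, 3]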
@@ -284,7 +316,7 @@ def get_data_sample(G, set_index, hop_num, feature_flags, max_sprw, label, debug
debug (bool, optional): _description_. Defaults to False.
Returns:
_type_: _description_
Data: returns a Data object
"""


@@ -413,7 +445,7 @@ def split_dataset(n_samples, test_ratio, stratify=None):
stratify (_type_, optional): if not None, the per-class proportions in the train and test sets follow stratify. Defaults to None.
Returns:
_type_: train_mask , test_mask , np.array arrays of length n_samples, 1 where the index is selected and 0 otherwise.
np.array , np.array: train_mask , test_mask , np.array arrays of length n_samples, 1 where the index is selected and 0 otherwise.
"""
train_indices, test_indices = train_test_split(list(range(n_samples)), test_size=test_ratio, stratify=stratify)
train_mask = get_mask(train_indices, n_samples)
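A small usage sketch of this mask-based split (the labels and sizes are made up; get_mask is the helper defined below):

import numpy as np
from sklearn.model_selection import train_test_split

n_samples, test_ratio = 10, 0.3
labels = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])            # toy labels for stratification

train_idx, test_idx = train_test_split(list(range(n_samples)),
                                       test_size=test_ratio, stratify=labels)
train_mask = np.zeros(n_samples)                              # same idea as get_mask below
train_mask[train_idx] = 1
print(int(train_mask.sum()), len(test_idx))                   # 7 3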
@@ -428,7 +460,7 @@ def get_mask(idx, length):
length (_type_): the length
Returns:
_type_: np.array, a 0/1 array
np.array: mask, a 0/1 array
"""
mask = np.zeros(length)
mask[idx] = 1
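For example, assuming the function simply returns mask right after the lines shown (the closing line is collapsed in the diff):

import numpy as np

def get_mask(idx, length):
    mask = np.zeros(length)
    mask[idx] = 1
    return mask                                # assumed return; not visible in the hunk above

print(get_mask([1, 3], 5))                     # [0. 1. 0. 1. 0.]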
