forked from FFrankyy/FINDER
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmvc_env.pyx
118 lines (93 loc) · 3.95 KB
/
mvc_env.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from cython.operator import dereference as deref
from libcpp.memory cimport shared_ptr
import numpy as np
import graph
from graph cimport Graph
import gc
from libc.stdlib cimport free
cdef class py_MvcEnv:
    """Python-facing wrapper around the C++ ``MvcEnv`` environment.

    Owns the C++ environment and the C++ ``Graph`` it was last reset with
    (both held through ``shared_ptr``) and forwards method calls to them.

    The read-only properties return the same-named members of the C++
    ``MvcEnv``; Cython converts them to Python objects on every access, so
    C++ containers come back as fresh copies rather than live views.
    """
    cdef shared_ptr[MvcEnv] inner_MvcEnv
    # Graph most recently handed to s0(); an empty default Graph until then.
    cdef shared_ptr[Graph] inner_Graph

    def __cinit__(self, double _norm):
        """Create the C++ environment, forwarding ``_norm`` to its constructor.

        The value is readable afterwards through the ``norm`` property.
        """
        self.inner_MvcEnv = shared_ptr[MvcEnv](new MvcEnv(_norm))
        self.inner_Graph = shared_ptr[Graph](new Graph())

    def s0(self, _g):
        """Reset the environment on a C++ copy of the Python graph ``_g``.

        Copies ``num_nodes``, ``num_edges``, ``edge_list``, ``adj_list``,
        ``nodes_weight`` and ``total_nodes_weight`` from ``_g`` (presumably a
        ``graph.py_Graph`` -- TODO confirm) into a freshly allocated C++
        ``Graph``, keeps it in ``inner_Graph`` and passes the same shared
        pointer to ``MvcEnv.s0``.
        """
        self.inner_Graph = shared_ptr[Graph](new Graph())
        deref(self.inner_Graph).num_nodes = _g.num_nodes
        deref(self.inner_Graph).num_edges = _g.num_edges
        deref(self.inner_Graph).edge_list = _g.edge_list
        deref(self.inner_Graph).adj_list = _g.adj_list
        deref(self.inner_Graph).nodes_weight = _g.nodes_weight
        deref(self.inner_Graph).total_nodes_weight = _g.total_nodes_weight
        deref(self.inner_MvcEnv).s0(self.inner_Graph)

    def step(self, int a):
        """Apply action ``a`` via ``MvcEnv.step`` and return its result."""
        return deref(self.inner_MvcEnv).step(a)

    def stepWithoutReward(self, int a):
        """Apply action ``a`` via ``MvcEnv.stepWithoutReward``; returns None."""
        deref(self.inner_MvcEnv).stepWithoutReward(a)

    def randomAction(self):
        """Return the result of ``MvcEnv.randomAction``."""
        return deref(self.inner_MvcEnv).randomAction()

    def betweenAction(self):
        """Return the result of ``MvcEnv.betweenAction``."""
        return deref(self.inner_MvcEnv).betweenAction()

    def isTerminal(self):
        """Return the result of ``MvcEnv.isTerminal``."""
        return deref(self.inner_MvcEnv).isTerminal()

    def getReward(self, a):
        """Return the result of ``MvcEnv.getReward(a)``."""
        return deref(self.inner_MvcEnv).getReward(a)

    def getMaxConnectedNodesNum(self):
        """Return the result of ``MvcEnv.getMaxConnectedNodesNum``."""
        return deref(self.inner_MvcEnv).getMaxConnectedNodesNum()

    @property
    def norm(self):
        """``MvcEnv.norm`` (the value given to the constructor)."""
        return deref(self.inner_MvcEnv).norm

    @property
    def graph(self):
        """A new ``graph.py_Graph`` rebuilt from the stored C++ graph."""
        return self.G2P(deref(self.inner_Graph))

    @property
    def state_seq(self):
        """``MvcEnv.state_seq``."""
        return deref(self.inner_MvcEnv).state_seq

    @property
    def act_seq(self):
        """``MvcEnv.act_seq``."""
        return deref(self.inner_MvcEnv).act_seq

    @property
    def action_list(self):
        """``MvcEnv.action_list``."""
        return deref(self.inner_MvcEnv).action_list

    @property
    def reward_seq(self):
        """``MvcEnv.reward_seq``."""
        return deref(self.inner_MvcEnv).reward_seq

    @property
    def sum_rewards(self):
        """``MvcEnv.sum_rewards``."""
        return deref(self.inner_MvcEnv).sum_rewards

    @property
    def numCoveredEdges(self):
        """``MvcEnv.numCoveredEdges``."""
        return deref(self.inner_MvcEnv).numCoveredEdges

    @property
    def covered_set(self):
        """``MvcEnv.covered_set``."""
        return deref(self.inner_MvcEnv).covered_set

    @property
    def avail_list(self):
        """``MvcEnv.avail_list``."""
        return deref(self.inner_MvcEnv).avail_list

    cdef G2P(self, Graph graph1):
        """Convert a C++ ``Graph`` into a ``graph.py_Graph``.

        Copies the node weights and the two endpoints of every edge into
        NumPy arrays and passes them, together with the node and edge counts,
        to ``graph.py_Graph``. ``adj_list`` and ``total_nodes_weight`` are not
        passed; presumably ``py_Graph`` derives them itself -- TODO confirm.
        """
        cdef int num_nodes = graph1.num_nodes  # number of nodes in the Graph
        cdef int num_edges = graph1.num_edges  # number of edges in the Graph
        cdef int i
        # ``np.int`` (an alias of builtin ``int``) was removed in NumPy 1.24;
        # ``np.int_`` is the same default integer dtype the alias produced.
        cint_edges_from = np.zeros([num_edges], dtype=np.int_)
        cint_edges_to = np.zeros([num_edges], dtype=np.int_)
        cdouble_nodes_weight = np.zeros([num_nodes], dtype=np.double)
        # Index the typed C++ members directly instead of copying the vectors
        # into untyped locals, so ``.first``/``.second`` never depend on
        # Cython's type inference.
        for i in range(num_nodes):
            cdouble_nodes_weight[i] = graph1.nodes_weight[i]
        for i in range(num_edges):
            cint_edges_from[i] = graph1.edge_list[i].first
            cint_edges_to[i] = graph1.edge_list[i].second
        return graph.py_Graph(num_nodes, num_edges, cint_edges_from,
                              cint_edges_to, cdouble_nodes_weight)
# cdef reshape_Graph(self, int _num_nodes, int _num_edges, int[:] edges_from, int[:] edges_to):
# cdef int *cint_edges_from = <int*>malloc(_num_edges*sizeof(int))
# cdef int *cint_edges_to = <int*>malloc(_num_edges*sizeof(int))
# cdef int i
# for i in range(_num_edges):
# cint_edges_from[i] = edges_from[i]
# for i in range(_num_edges):
# cint_edges_to[i] = edges_to[i]
# free(cint_edges_from)
# free(cint_edges_to)
# return new Graph(_num_nodes,_num_edges,&cint_edges_from[0],&cint_edges_to[0])