Skip to content

Commit

Permalink
Merge pull request #687 from XJDKC/master
Browse files Browse the repository at this point in the history
fix bugs of recycling
  • Loading branch information
nudles authored May 1, 2020
2 parents 12efad4 + 61d1ee5 commit e4082c6
Show file tree
Hide file tree
Showing 5 changed files with 303 additions and 176 deletions.
2 changes: 1 addition & 1 deletion examples/rnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def train(data,
cuda = device.create_cuda_gpu()
model = CharRNN(data.vocab_size, hidden_size)
model.on_device(cuda)
model.graph(True, True)
model.graph(True, False)

inputs, labels = None, None

Expand Down
21 changes: 10 additions & 11 deletions include/singa/core/scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,16 @@ class Edge {
class BlkInfo {
public:
BlkInfo(int id, Block *blk, BlockType type = BlockType::kUnknow)
: id_(id),
blk_(blk),
type_(type),
graph_ref_(0),
write_node_(nullptr),
last_node_(nullptr) {}
: id_(id), blk_(blk), type_(type), graph_ref_(0), write_edge_(nullptr) {}

// getters of BlkInfo
int id() const { return id_; }
Block *block() const { return blk_; }
BlockType type() const { return type_; }
int graph_ref() const { return graph_ref_; }
Node *write_node() const { return write_node_; }
Node *last_node() const { return last_node_; }
Edge *write_edge() const { return write_edge_; }
const NodeVec &used_nodes() const { return used_nodes_; }
Node *used_node(const size_t idx) const;

private:
friend Graph;
Expand All @@ -119,8 +115,8 @@ class BlkInfo {
Block *blk_;
BlockType type_;
int graph_ref_;
Node *write_node_; // last node that writes the block
Node *last_node_; // last node that uses the block
Edge *write_edge_; // the edge of last node that writes data into blk
NodeVec used_nodes_; // the nodes that use this block(in order of execution)
};

class Graph {
Expand Down Expand Up @@ -165,8 +161,10 @@ class Graph {
const BlockVec &free_blocks(const size_t idx) const;

private:
void Analysis();
void Analyze();
void FreeLoop();
void AnalyzeNodes();
void AnalyzeEdges();
void AddSyncOp(function<void(Context *)> &&op);

// static void CUDART_CB Callback(cudaStream_t stream, cudaError_t status,
Expand All @@ -185,6 +183,7 @@ class Graph {

// Calculation graph analysis
bool dirty_ = false;
bool in_serial_ = false;
NodeVec begin_nodes_;
std::vector<NodeVec> next_nodes_;
std::vector<BlockVec> free_blocks_;
Expand Down
6 changes: 5 additions & 1 deletion python/singa/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ def wrapper(self, *args, **kwargs):
# deconstruct Operations before running the entire graph
if name == 'optim':
for fname in self._results:
self._results[fname].creator = None
if isinstance(self._results[fname], list):
for _matrix in self._results[fname]:
_matrix.creator = None
else:
self._results[fname].creator = None
# make sure all Operations are deallocated
gc.collect()
# add result tensor
Expand Down
Loading

0 comments on commit e4082c6

Please sign in to comment.