forked from alibaba/MNN
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathVulkanLSTM.hpp
More file actions
65 lines (54 loc) · 2.08 KB
/
Copy pathVulkanLSTM.hpp
File metadata and controls
65 lines (54 loc) · 2.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
//
// VulkanLSTM.hpp
// MNN
//
// Created by MNN on 2019/01/31.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef VulkanLSTM_hpp
#define VulkanLSTM_hpp
#include "VulkanBasicExecution.hpp"
namespace MNN {
// Per-channel piece of the Vulkan LSTM "channel loop".
// VulkanLSTM keeps a vector of these (VulkanLSTM::mLSTMChannels); presumably one
// instance per channel index, each with its own descriptor set and parameter
// buffer so they can be encoded back-to-back — NOTE(review): confirm in VulkanLSTM.cpp.
class LSTMChannel {
public:
    // vulkanLSTMPipeline: pipeline this channel step dispatches with. Held as a
    //                     non-owning pointer; NOTE(review): presumably owned by the
    //                     backend's pipeline cache and outlives this object — confirm.
    // vkbackend:          backend used during construction (assumed to allocate the
    //                     descriptor set / parameter buffer — confirm).
    // channel:            channel index handled by this instance.
    LSTMChannel(const VulkanPipeline* vulkanLSTMPipeline, VulkanBackend* vkbackend, int channel);
    ~LSTMChannel();

    // Encodes this channel's step into cmdBuffer using the given gate, cell,
    // hidden-weight, bias and output buffers.
    // ow: NOTE(review): presumably the output width / sequence extent — confirm.
    // Returns an MNN ErrorCode.
    ErrorCode encodeImpl(std::shared_ptr<VulkanBuffer>& gates, std::shared_ptr<VulkanBuffer>& cells,
                         std::shared_ptr<VulkanBuffer>& weightH, std::shared_ptr<VulkanBuffer>& bias,
                         std::shared_ptr<VulkanBuffer>& out, const VulkanCommandPool::Buffer* cmdBuffer, int ow);

private:
    // Default-initialized so the object never holds indeterminate values,
    // whatever the out-of-line constructor does.
    int mChannel                                = 0;
    const VulkanPipeline* mVulkanLSTMPipeline   = nullptr; // non-owning
    std::shared_ptr<VulkanPipeline::DescriptorSet> mDescriptorSet;
    std::shared_ptr<VulkanBuffer> mParamBuffer;
};
// Vulkan execution for the LSTM operator.
// Work is split into three stages (per the member groupings below): a "gates"
// pipeline, an nchw -> nc4hw4 "save" pipeline, and a per-channel loop driven by
// LSTMChannel instances.
class VulkanLSTM : public VulkanBasicExecution {
public:
    // lstm: LSTM op parameters; kept as a non-owning pointer in mLSTM
    //       (NOTE(review): presumably owned by the model flatbuffer — confirm).
    // bn:   backend this execution runs on (assumed to be a VulkanBackend).
    VulkanLSTM(const LSTM* lstm, Backend* bn);
    virtual ~VulkanLSTM();

    // Records the LSTM computation for inputs -> outputs into cmdBuffer.
    ErrorCode onEncode(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                       const VulkanCommandPool::Buffer* cmdBuffer) override;

private:
    // Shape-dependent setup helper; NOTE(review): presumably (re)allocates the
    // temporary buffers below for the given input/output shapes — confirm.
    ErrorCode _resize(const Tensor* input, const Tensor* output);

    // Non-owning pointers; default to nullptr so they are never indeterminate.
    const LSTM* mLSTM           = nullptr;
    VulkanBackend* mVKbackend   = nullptr;

    // gates
    const VulkanPipeline* mVulkanLSTMGatePipeline = nullptr;
    std::shared_ptr<VulkanPipeline::DescriptorSet> mGateDescriptorSet;
    std::shared_ptr<VulkanBuffer> mGateParamBuffer;

    // nchw -> nc4hw4
    const VulkanPipeline* mVulkanLSTMSavePipeline = nullptr;
    std::shared_ptr<VulkanPipeline::DescriptorSet> mSaveDescriptorSet;
    std::shared_ptr<VulkanBuffer> mSaveParamBuffer;

    // channel loop
    std::vector<std::shared_ptr<LSTMChannel>> mLSTMChannels;
    const VulkanPipeline* mVulkanLSTMPipeline = nullptr;

    // Device buffers. Names suggest input weights, hidden weights, bias, gate
    // scratch, cell state and a temporary output — NOTE(review): confirm in .cpp.
    std::shared_ptr<VulkanBuffer> mWeightI;
    std::shared_ptr<VulkanBuffer> mWeightH;
    std::shared_ptr<VulkanBuffer> mBias;
    std::shared_ptr<VulkanBuffer> mGate;
    std::shared_ptr<VulkanBuffer> mCell;
    std::shared_ptr<VulkanBuffer> mOutputTemp;
};
} // namespace MNN
#endif