-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathmltk_tflite_micro_context.hpp
82 lines (64 loc) · 2.24 KB
/
mltk_tflite_micro_context.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#pragma once
#include <cassert>
#include <new>

#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_interpreter_context.h"
namespace mltk
{
typedef TfLiteStatus (*TfliteMicroLayerCallback)(
int index,
TfLiteContext& context,
const tflite::NodeAndRegistration& node_and_registration,
TfLiteStatus invoke_status,
void* arg
);
/**
 * Function-pointer type for a processing callback.
 *
 * Takes a single opaque user pointer and returns nothing. A pointer of this
 * type is stored in TfliteMicroContext::processing_callback, with its argument
 * in TfliteMicroContext::processing_callback_arg.
 */
using TfliteMicroProcessingCallback = void (*)(void* arg);
// Forward declaration: this header only stores and passes pointers to
// TfliteMicroAccelerator, so the full definition is not required here.
class TfliteMicroAccelerator;
/**
 * Context object that is registered as the "external context" of a TFLM
 * MicroContext and carries the model flatbuffer pointer, the accelerator,
 * the allocator and the optional user callbacks.
 *
 * Instances are built with create(), which constructs the object with
 * placement new inside a buffer returned by
 * TfLiteContext::AllocatePersistentBuffer. Because of that, an instance
 * returned by create() must never be passed to `delete`.
 */
class TfliteMicroContext
{
public:
    /**
     * Allocate storage through `context->AllocatePersistentBuffer` and
     * construct a TfliteMicroContext in it.
     *
     * @param context TfLite context used for the allocation.
     * @return The new object, or nullptr when `context` (or its allocation
     *         hook) is null or the allocation itself returns nullptr.
     */
    static TfliteMicroContext* create(TfLiteContext *context)
    {
        // Guard the dereference and the indirect call below.
        if(context == nullptr || context->AllocatePersistentBuffer == nullptr)
        {
            return nullptr;
        }

        void* storage = context->AllocatePersistentBuffer(context, sizeof(TfliteMicroContext));
        if(storage == nullptr)
        {
            return nullptr;
        }

        return new(storage) TfliteMicroContext();
    }

    /**
     * Record the supplied pointers in this object and register `this` as the
     * external context of the MicroContext associated with `context`.
     *
     * @param flatbuffer  Pointer stored in `this->flatbuffer`.
     * @param context     TfLite context; stored and used to look up the MicroContext.
     * @param accelerator Pointer stored in `this->accelerator` (may be null as far as this method is concerned).
     * @param allocator   Pointer stored in `this->allocator` (may be null as far as this method is concerned).
     * @return true on success; false when `context` is null or no
     *         MicroContext can be obtained from it. On failure no member of
     *         this object is modified.
     */
    virtual bool init(
        const void* flatbuffer,
        TfLiteContext *context,
        TfliteMicroAccelerator* accelerator,
        tflite::MicroAllocator* allocator
    )
    {
        if(context == nullptr)
        {
            return false;
        }

        auto micro_context = tflite::GetMicroContext(context);
        // Keep the debug-build trap, but also fail cleanly when assertions
        // are compiled out (NDEBUG) instead of dereferencing a null pointer.
        assert(micro_context != nullptr);
        if(micro_context == nullptr)
        {
            return false;
        }

        this->flatbuffer = flatbuffer;
        this->context = context;
        this->accelerator = accelerator;
        this->allocator = allocator;

        // Base-to-derived conversion within the MicroContext hierarchy.
        auto interpreter_micro_context = static_cast<tflite::MicroInterpreterContext*>(micro_context);

        // The external context is set while the interpreter state is
        // temporarily switched to kPrepare; the previous state is restored
        // right after. Presumably the setter is only permitted in specific
        // states - confirm against MicroInterpreterContext.
        const auto previous_state = interpreter_micro_context->GetInterpreterState();
        interpreter_micro_context->SetInterpreterState(tflite::MicroInterpreterContext::InterpreterState::kPrepare);
        // NOTE(review): the result of set_external_context() is not inspected
        // here (unchanged from the previous implementation) - confirm whether
        // a failure should be propagated to the caller.
        micro_context->set_external_context(this);
        interpreter_micro_context->SetInterpreterState(previous_state);

        return true;
    }

    /**
     * Hook for subclasses; the base implementation does nothing.
     *
     * @param context TfLite context (unused by the base implementation).
     * @return Always true in the base implementation.
     */
    virtual bool load(TfLiteContext *context)
    {
        (void)context; // unused in the base implementation
        return true;
    }

    // TfLite context supplied to init().
    TfLiteContext *context = nullptr;
    // Model flatbuffer pointer supplied to init().
    const void* flatbuffer = nullptr;
    // Accelerator supplied to init().
    TfliteMicroAccelerator* accelerator = nullptr;
    // Allocator supplied to init().
    tflite::MicroAllocator* allocator = nullptr;
    // Opcode of the current layer; initialised to -1. Not written by this class.
    tflite::BuiltinOperator current_layer_opcode = static_cast<tflite::BuiltinOperator>(-1);
    // Index of the current layer; initialised to -1. Not written by this class.
    int current_layer_index = -1;
    // Optional per-layer callback and its opaque argument.
    TfliteMicroLayerCallback layer_callback = nullptr;
    void *layer_callback_arg = nullptr;
    // Optional processing callback and its opaque argument.
    TfliteMicroProcessingCallback processing_callback = nullptr;
    void *processing_callback_arg = nullptr;

protected:
    // Construction goes through create() (or a subclass).
    TfliteMicroContext() = default;
};
} // namespace mltk