Skip to content

Commit 88e0712

Browse files
committed
build: 添加 llm 子项目用于前端大模型自定义算子
Signed-off-by: YdrMaster <[email protected]>
1 parent 1375be3 commit 88e0712

File tree

7 files changed

+107
-0
lines changed

7 files changed

+107
-0
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,5 @@ add_subdirectory(src/05computation)
7272
add_subdirectory(src/06frontend)
7373
add_subdirectory(src/07onnx)
7474
add_subdirectory(src/08communication)
75+
add_subdirectory(src/08-01llm)
7576
add_subdirectory(src/09python_ffi)

src/08-01llm/CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
cmake_minimum_required(VERSION 3.12 FATAL_ERROR)
2+
project(llm VERSION 0.0.0 LANGUAGES CXX)
3+
message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION})
4+
5+
file(GLOB_RECURSE LLM_SRC src/*.cc src/*.cpp)
6+
add_library(llm STATIC ${LLM_SRC})
7+
target_link_libraries(llm PUBLIC frontend)
8+
target_include_directories(llm PUBLIC include)
9+
10+
file(GLOB_RECURSE LLM_TEST test/*.cpp)
11+
if(LLM_TEST)
12+
add_executable(llm_test ${LLM_TEST})
13+
add_test(llm_test llm_test)
14+
target_link_libraries(llm_test llm GTest::gtest_main ${BACKWARD_ENABLE})
15+
add_backward(llm_test)
16+
endif()
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#ifndef LLM_OPERATORS_H
2+
#define LLM_OPERATORS_H
3+
4+
namespace refactor::llm {
5+
6+
void register_();
7+
8+
}// namespace refactor::llm
9+
10+
#endif// LLM_OPERATORS_H

src/08-01llm/src/operators.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#include "llm/operators.h"
2+
#include "operators/rms_normalization.hh"
3+
4+
namespace refactor::llm {
5+
using namespace frontend;
6+
7+
void register_() {
8+
// clang-format off
9+
#define REGISTER(NAME, CLASS) Operator::register_<CLASS>("llm::" #NAME)
10+
#undef REGISTER
11+
// clang-format on
12+
}
13+
14+
}// namespace refactor::llm
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef LLM_COMMON_H
2+
#define LLM_COMMON_H
3+
4+
#include "common.h"
5+
6+
#define EXPECT_SIZE(N) \
7+
if (inputs.size() != (N)) { \
8+
return Err(InferError(ERROR_MSG("Input size error"))); \
9+
}
10+
11+
#define EXPECT_VAL(DIM, VAL) \
12+
int64_t VAL; \
13+
if ((DIM).hasValue()) { \
14+
VAL = (DIM).value(); \
15+
} else { \
16+
return Err(InferError(UnknownVariable{(DIM.variable()->name)})); \
17+
}
18+
19+
#endif// LLM_COMMON_H
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#include "rms_normalization.hh"
2+
#include "common.h"
3+
4+
namespace refactor::llm {
5+
using Op = RmsNormalization;
6+
7+
auto Op::build(ModelContext const &, std::string_view, Attributes) -> OpBox {
8+
return OpBox(std::make_unique<Op>());
9+
}
10+
auto Op::typeId() -> size_t {
11+
static uint8_t ID = 1;
12+
return reinterpret_cast<size_t>(&ID);
13+
}
14+
15+
auto Op::opTypeId() const -> size_t { return typeId(); }
16+
auto Op::opTypeName() const -> std::string_view { return "llm::RmsNormalization"; }
17+
18+
auto Op::infer(TensorRefs inputs, InferOptions const &) const -> InferResult {
19+
EXPECT_SIZE(2)
20+
21+
TODO("");
22+
}
23+
24+
}// namespace refactor::llm
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#ifndef LLM_RMS_NORMALIZATION_HH
2+
#define LLM_RMS_NORMALIZATION_HH
3+
4+
#include "frontend/operator.h"
5+
6+
namespace refactor::llm {
7+
using namespace frontend;
8+
9+
struct RmsNormalization final : public Operator {
10+
11+
constexpr RmsNormalization() noexcept = default;
12+
13+
static OpBox build(ModelContext const &, std::string_view, Attributes);
14+
static size_t typeId();
15+
16+
size_t opTypeId() const final;
17+
std::string_view opTypeName() const final;
18+
InferResult infer(TensorRefs, InferOptions const &) const final;
19+
};
20+
21+
}// namespace refactor::llm
22+
23+
#endif// LLM_RMS_NORMALIZATION_HH

0 commit comments

Comments
 (0)