File tree Expand file tree Collapse file tree 7 files changed +107
-0
lines changed Expand file tree Collapse file tree 7 files changed +107
-0
lines changed Original file line number Diff line number Diff line change @@ -72,4 +72,5 @@ add_subdirectory(src/05computation)
7272add_subdirectory (src/06frontend)
7373add_subdirectory (src/07onnx)
7474add_subdirectory (src/08communication)
75+ add_subdirectory (src/08-01llm)
7576add_subdirectory (src/09python_ffi)
Original file line number Diff line number Diff line change 1+ cmake_minimum_required (VERSION 3.12 FATAL_ERROR)
2+ project (llm VERSION 0.0.0 LANGUAGES CXX)
3+ message (STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION} )
4+
5+ file (GLOB_RECURSE LLM_SRC src/*.cc src/*.cpp)
6+ add_library (llm STATIC ${LLM_SRC} )
7+ target_link_libraries (llm PUBLIC frontend)
8+ target_include_directories (llm PUBLIC include )
9+
10+ file (GLOB_RECURSE LLM_TEST test /*.cpp)
11+ if (LLM_TEST)
12+ add_executable (llm_test ${LLM_TEST} )
13+ add_test (llm_test llm_test)
14+ target_link_libraries (llm_test llm GTest::gtest_main ${BACKWARD_ENABLE} )
15+ add_backward(llm_test)
16+ endif ()
Original file line number Diff line number Diff line change 1+ #ifndef LLM_OPERATORS_H
2+ #define LLM_OPERATORS_H
3+
4+ namespace refactor ::llm {
5+
6+ void register_ ();
7+
8+ }// namespace refactor::llm
9+
10+ #endif // LLM_OPERATORS_H
Original file line number Diff line number Diff line change 1+ #include " llm/operators.h"
2+ #include " operators/rms_normalization.hh"
3+
4+ namespace refactor ::llm {
5+ using namespace frontend ;
6+
7+ void register_ () {
8+ // clang-format off
9+ #define REGISTER (NAME, CLASS ) Operator::register_<CLASS>(" llm::" #NAME)
10+ #undef REGISTER
11+ // clang-format on
12+ }
13+
14+ }// namespace refactor::llm
Original file line number Diff line number Diff line change 1+ #ifndef LLM_COMMON_H
2+ #define LLM_COMMON_H
3+
4+ #include "common.h"
5+
6+ #define EXPECT_SIZE (N ) \
7+ if (inputs.size() != (N)) { \
8+ return Err(InferError(ERROR_MSG("Input size error"))); \
9+ }
10+
11+ #define EXPECT_VAL (DIM , VAL ) \
12+ int64_t VAL; \
13+ if ((DIM).hasValue()) { \
14+ VAL = (DIM).value(); \
15+ } else { \
16+ return Err(InferError(UnknownVariable{(DIM.variable()->name)})); \
17+ }
18+
19+ #endif // LLM_COMMON_H
Original file line number Diff line number Diff line change 1+ #include " rms_normalization.hh"
2+ #include " common.h"
3+
4+ namespace refactor ::llm {
5+ using Op = RmsNormalization;
6+
7+ auto Op::build (ModelContext const &, std::string_view, Attributes) -> OpBox {
8+ return OpBox (std::make_unique<Op>());
9+ }
10+ auto Op::typeId () -> size_t {
11+ static uint8_t ID = 1 ;
12+ return reinterpret_cast <size_t >(&ID);
13+ }
14+
15+ auto Op::opTypeId () const -> size_t { return typeId (); }
16+ auto Op::opTypeName () const -> std::string_view { return " llm::RmsNormalization" ; }
17+
18+ auto Op::infer (TensorRefs inputs, InferOptions const &) const -> InferResult {
19+ EXPECT_SIZE (2 )
20+
21+ TODO (" " );
22+ }
23+
24+ }// namespace refactor::llm
Original file line number Diff line number Diff line change 1+ #ifndef LLM_RMS_NORMALIZATION_HH
2+ #define LLM_RMS_NORMALIZATION_HH
3+
4+ #include " frontend/operator.h"
5+
6+ namespace refactor ::llm {
7+ using namespace frontend ;
8+
9+ struct RmsNormalization final : public Operator {
10+
11+ constexpr RmsNormalization () noexcept = default;
12+
13+ static OpBox build (ModelContext const &, std::string_view, Attributes);
14+ static size_t typeId ();
15+
16+ size_t opTypeId () const final ;
17+ std::string_view opTypeName () const final ;
18+ InferResult infer (TensorRefs, InferOptions const &) const final ;
19+ };
20+
21+ }// namespace refactor::llm
22+
23+ #endif // LLM_RMS_NORMALIZATION_HH
You can’t perform that action at this time.
0 commit comments