diff --git a/src/observer/event/session_event.cpp b/src/observer/event/session_event.cpp index a503f1212..41ca80d4f 100644 --- a/src/observer/event/session_event.cpp +++ b/src/observer/event/session_event.cpp @@ -15,21 +15,10 @@ See the Mulan PSL v2 for more details. */ #include "session_event.h" #include "net/communicator.h" -SessionEvent::SessionEvent(Communicator *comm) - : communicator_(comm), - sql_result_(communicator_->session()) -{} +SessionEvent::SessionEvent(Communicator *comm) : communicator_(comm), sql_result_(communicator_->session()) {} -SessionEvent::~SessionEvent() -{ -} +SessionEvent::~SessionEvent() {} -Communicator *SessionEvent::get_communicator() const -{ - return communicator_; -} +Communicator *SessionEvent::get_communicator() const { return communicator_; } -Session *SessionEvent::session() const -{ - return communicator_->session(); -} +Session *SessionEvent::session() const { return communicator_->session(); } diff --git a/src/observer/event/session_event.h b/src/observer/event/session_event.h index ec2ba5206..97e933d39 100644 --- a/src/observer/event/session_event.h +++ b/src/observer/event/session_event.h @@ -14,36 +14,33 @@ See the Mulan PSL v2 for more details. */ #pragma once -#include #include #include "common/seda/stage_event.h" -#include "sql/executor/sql_result.h" #include "event/sql_debug.h" +#include "sql/executor/sql_result.h" class Session; class Communicator; /** * @brief 表示一个SQL请求 - * + * */ -class SessionEvent : public common::StageEvent +class SessionEvent : public common::StageEvent { public: SessionEvent(Communicator *client); virtual ~SessionEvent(); Communicator *get_communicator() const; - Session *session() const; + Session *session() const; void set_query(const std::string &query) { query_ = query; } const std::string &query() const { return query_; } - - SqlResult *sql_result() { return &sql_result_; } - - SqlDebug &sql_debug() { return sql_debug_; } + SqlResult *sql_result() { return &sql_result_; } + SqlDebug &sql_debug() { return sql_debug_; } private: Communicator *communicator_ = nullptr; ///< 与客户端通讯的对象 diff --git a/src/observer/event/sql_debug.cpp b/src/observer/event/sql_debug.cpp index 4dd90d44c..4f34951cd 100644 --- a/src/observer/event/sql_debug.cpp +++ b/src/observer/event/sql_debug.cpp @@ -14,43 +14,34 @@ See the Mulan PSL v2 for more details. */ #include +#include "event/session_event.h" #include "event/sql_debug.h" #include "session/session.h" -#include "event/session_event.h" using namespace std; -void SqlDebug::add_debug_info(const std::string &debug_info) -{ - debug_infos_.push_back(debug_info); -} +void SqlDebug::add_debug_info(const string &debug_info) { debug_infos_.push_back(debug_info); } -void SqlDebug::clear_debug_info() -{ - debug_infos_.clear(); -} +void SqlDebug::clear_debug_info() { debug_infos_.clear(); } -const list &SqlDebug::get_debug_infos() const -{ - return debug_infos_; -} +const list &SqlDebug::get_debug_infos() const { return debug_infos_; } void sql_debug(const char *fmt, ...) { - Session *session = Session::current_session(); + Session *session = Session::current_session(); if (nullptr == session) { return; } SessionEvent *request = session->current_request(); if (nullptr == request) { - return ; + return; } SqlDebug &sql_debug = request->sql_debug(); const int buffer_size = 4096; - char *str = new char[buffer_size]; + char *str = new char[buffer_size]; va_list ap; va_start(ap, fmt); diff --git a/src/observer/event/sql_debug.h b/src/observer/event/sql_debug.h index 493afecd8..3f3f5c132 100644 --- a/src/observer/event/sql_debug.h +++ b/src/observer/event/sql_debug.h @@ -19,7 +19,7 @@ See the Mulan PSL v2 for more details. */ /** * @brief SQL调试信息 - * @details + * @details * 希望在运行SQL时,可以直接输出一些调试信息到客户端。 * 当前把调试信息都放在了session上,可以随着SQL语句输出。 * 但是现在还不支持与输出调试信息与行数据同步输出。 @@ -27,7 +27,7 @@ See the Mulan PSL v2 for more details. */ class SqlDebug { public: - SqlDebug() = default; + SqlDebug() = default; virtual ~SqlDebug() = default; void add_debug_info(const std::string &debug_info); diff --git a/src/observer/event/sql_event.cpp b/src/observer/event/sql_event.cpp index 63a23e402..a6e98e53c 100644 --- a/src/observer/event/sql_event.cpp +++ b/src/observer/event/sql_event.cpp @@ -14,14 +14,10 @@ See the Mulan PSL v2 for more details. */ #include "event/sql_event.h" -#include - #include "event/session_event.h" -#include "sql/parser/parse_defs.h" #include "sql/stmt/stmt.h" -SQLStageEvent::SQLStageEvent(SessionEvent *event, const std::string &sql) : session_event_(event), sql_(sql) -{} +SQLStageEvent::SQLStageEvent(SessionEvent *event, const std::string &sql) : session_event_(event), sql_(sql) {} SQLStageEvent::~SQLStageEvent() noexcept { diff --git a/src/observer/event/sql_event.h b/src/observer/event/sql_event.h index bda7d1c9a..6d0e0a00e 100644 --- a/src/observer/event/sql_event.h +++ b/src/observer/event/sql_event.h @@ -14,10 +14,10 @@ See the Mulan PSL v2 for more details. */ #pragma once -#include -#include #include "common/seda/stage_event.h" #include "sql/operator/physical_operator.h" +#include +#include class SessionEvent; class Stmt; @@ -32,53 +32,23 @@ class SQLStageEvent : public common::StageEvent SQLStageEvent(SessionEvent *event, const std::string &sql); virtual ~SQLStageEvent() noexcept; - SessionEvent *session_event() const - { - return session_event_; - } + SessionEvent *session_event() const { return session_event_; } - const std::string &sql() const - { - return sql_; - } - const std::unique_ptr &sql_node() const - { - return sql_node_; - } - Stmt *stmt() const - { - return stmt_; - } - std::unique_ptr &physical_operator() - { - return operator_; - } - const std::unique_ptr &physical_operator() const - { - return operator_; - } + const std::string &sql() const { return sql_; } + const std::unique_ptr &sql_node() const { return sql_node_; } + Stmt *stmt() const { return stmt_; } + std::unique_ptr &physical_operator() { return operator_; } + const std::unique_ptr &physical_operator() const { return operator_; } - void set_sql(const char *sql) - { - sql_ = sql; - } - void set_sql_node(std::unique_ptr sql_node) - { - sql_node_ = std::move(sql_node); - } - void set_stmt(Stmt *stmt) - { - stmt_ = stmt; - } - void set_operator(std::unique_ptr oper) - { - operator_ = std::move(oper); - } + void set_sql(const char *sql) { sql_ = sql; } + void set_sql_node(std::unique_ptr sql_node) { sql_node_ = std::move(sql_node); } + void set_stmt(Stmt *stmt) { stmt_ = stmt; } + void set_operator(std::unique_ptr oper) { operator_ = std::move(oper); } private: - SessionEvent *session_event_ = nullptr; - std::string sql_; ///< 处理的SQL语句 - std::unique_ptr sql_node_; ///< 语法解析后的SQL命令 - Stmt *stmt_ = nullptr; ///< Resolver之后生成的数据结构 - std::unique_ptr operator_; ///< 生成的执行计划,也可能没有 + SessionEvent *session_event_ = nullptr; + std::string sql_; ///< 处理的SQL语句 + std::unique_ptr sql_node_; ///< 语法解析后的SQL命令 + Stmt *stmt_ = nullptr; ///< Resolver之后生成的数据结构 + std::unique_ptr operator_; ///< 生成的执行计划,也可能没有 }; diff --git a/src/observer/event/storage_event.h b/src/observer/event/storage_event.h index 274439e6c..567aca7e1 100644 --- a/src/observer/event/storage_event.h +++ b/src/observer/event/storage_event.h @@ -18,17 +18,14 @@ See the Mulan PSL v2 for more details. */ class SQLStageEvent; -class StorageEvent : public common::StageEvent { +class StorageEvent : public common::StageEvent +{ public: - StorageEvent(SQLStageEvent *sql_event) : sql_event_(sql_event) - {} + StorageEvent(SQLStageEvent *sql_event) : sql_event_(sql_event) {} virtual ~StorageEvent(); - SQLStageEvent *sql_event() const - { - return sql_event_; - } + SQLStageEvent *sql_event() const { return sql_event_; } private: SQLStageEvent *sql_event_; diff --git a/src/observer/net/buffered_writer.cpp b/src/observer/net/buffered_writer.cpp index 6f844ade7..be9ee3379 100644 --- a/src/observer/net/buffered_writer.cpp +++ b/src/observer/net/buffered_writer.cpp @@ -12,26 +12,19 @@ See the Mulan PSL v2 for more details. */ // Created by Wangyunlai on 2023/06/16. // +#include #include #include -#include #include "net/buffered_writer.h" using namespace std; -BufferedWriter::BufferedWriter(int fd) - : fd_(fd), buffer_() -{} +BufferedWriter::BufferedWriter(int fd) : fd_(fd), buffer_() {} -BufferedWriter::BufferedWriter(int fd, int32_t size) - : fd_(fd), buffer_(size) -{} +BufferedWriter::BufferedWriter(int fd, int32_t size) : fd_(fd), buffer_(size) {} -BufferedWriter::~BufferedWriter() -{ - close(); -} +BufferedWriter::~BufferedWriter() { close(); } RC BufferedWriter::close() { @@ -74,6 +67,7 @@ RC BufferedWriter::writen(const char *data, int32_t size) int32_t write_size = 0; while (write_size < size) { int32_t tmp_write_size = 0; + RC rc = write(data + write_size, size - write_size, tmp_write_size); if (OB_FAIL(rc)) { return rc; @@ -105,11 +99,12 @@ RC BufferedWriter::flush_internal(int32_t size) } RC rc = RC::SUCCESS; + int32_t write_size = 0; while (OB_SUCC(rc) && buffer_.size() > 0 && size > write_size) { - const char *buf = nullptr; - int32_t read_size = 0; - rc = buffer_.buffer(buf, read_size); + const char *buf = nullptr; + int32_t read_size = 0; + rc = buffer_.buffer(buf, read_size); if (OB_FAIL(rc)) { return rc; } diff --git a/src/observer/net/buffered_writer.h b/src/observer/net/buffered_writer.h index 7d0e9ad31..3f7d8fc43 100644 --- a/src/observer/net/buffered_writer.h +++ b/src/observer/net/buffered_writer.h @@ -38,7 +38,7 @@ class BufferedWriter * @brief 写数据到文件/socket * @details 缓存满会自动刷新缓存 * @param data 要写入的数据 - * @param size 要写入的数据大小 + * @param size 要写入的数据大小 * @param write_size 实际写入的数据大小 */ RC write(const char *data, int32_t size, int32_t &write_size); @@ -67,6 +67,6 @@ class BufferedWriter RC flush_internal(int32_t size); private: - int fd_ = -1; + int fd_ = -1; RingBuffer buffer_; }; \ No newline at end of file diff --git a/src/observer/net/cli_communicator.cpp b/src/observer/net/cli_communicator.cpp index 7e34e9f9a..921d17ada 100644 --- a/src/observer/net/cli_communicator.cpp +++ b/src/observer/net/cli_communicator.cpp @@ -13,26 +13,27 @@ See the Mulan PSL v2 for more details. */ // #include "net/cli_communicator.h" -#include "net/buffered_writer.h" -#include "common/log/log.h" #include "common/lang/string.h" +#include "common/log/log.h" #include "event/session_event.h" +#include "net/buffered_writer.h" #ifdef USE_READLINE -#include "readline/readline.h" #include "readline/history.h" +#include "readline/readline.h" #endif #define MAX_MEM_BUFFER_SIZE 8192 #define PORT_DEFAULT 6789 +using namespace std; using namespace common; #ifdef USE_READLINE -const std::string HISTORY_FILE = std::string(getenv("HOME")) + "/.miniob.history"; -time_t last_history_write_time = 0; +const string HISTORY_FILE = string(getenv("HOME")) + "/.miniob.history"; +time_t last_history_write_time = 0; -char *my_readline(const char *prompt) +char *my_readline(const char *prompt) { int size = history_length; if (size == 0) { @@ -55,7 +56,7 @@ char *my_readline(const char *prompt) } return line; } -#else // USE_READLINE +#else // USE_READLINE char *my_readline(const char *prompt) { char *buffer = (char *)malloc(MAX_MEM_BUFFER_SIZE); @@ -74,32 +75,29 @@ char *my_readline(const char *prompt) } return buffer; } -#endif // USE_READLINE +#endif // USE_READLINE /* this function config a exit-cmd list, strncasecmp func truncate the command from terminal according to the number, - 'strncasecmp("exit", cmd, 4)' means that obclient read command string from terminal, truncate it to 4 chars from + 'strncasecmp("exit", cmd, 4)' means that obclient read command string from terminal, truncate it to 4 chars from the beginning, then compare the result with 'exit', if they match, exit the obclient. */ -bool is_exit_command(const char *cmd) { - return 0 == strncasecmp("exit", cmd, 4) || - 0 == strncasecmp("bye", cmd, 3) || - 0 == strncasecmp("\\q", cmd, 2) ; +bool is_exit_command(const char *cmd) +{ + return 0 == strncasecmp("exit", cmd, 4) || 0 == strncasecmp("bye", cmd, 3) || 0 == strncasecmp("\\q", cmd, 2); } char *read_command() { - const char *prompt_str = "miniob > "; - char *input_command = nullptr; - for (input_command = my_readline(prompt_str); - common::is_blank(input_command); - input_command = my_readline(prompt_str)) { + const char *prompt_str = "miniob > "; + char *input_command = nullptr; + for (input_command = my_readline(prompt_str); is_blank(input_command); input_command = my_readline(prompt_str)) { free(input_command); input_command = nullptr; } return input_command; } -RC CliCommunicator::init(int fd, Session *session, const std::string &addr) +RC CliCommunicator::init(int fd, Session *session, const string &addr) { RC rc = PlainCommunicator::init(fd, session, addr); if (OB_FAIL(rc)) { @@ -115,7 +113,7 @@ RC CliCommunicator::init(int fd, Session *session, const std::string &addr) const char delimiter = '\n'; send_message_delimiter_.assign(1, delimiter); - fd_ = -1; // 防止被父类析构函数关闭 + fd_ = -1; // 防止被父类析构函数关闭 } else { rc = RC::INVALID_ARGUMENT; LOG_WARN("only stdin supported"); @@ -125,7 +123,7 @@ RC CliCommunicator::init(int fd, Session *session, const std::string &addr) RC CliCommunicator::read_event(SessionEvent *&event) { - event = nullptr; + event = nullptr; char *command = read_command(); if (is_exit_command(command)) { @@ -135,7 +133,7 @@ RC CliCommunicator::read_event(SessionEvent *&event) } event = new SessionEvent(this); - event->set_query(std::string(command)); + event->set_query(string(command)); free(command); return RC::SUCCESS; } @@ -143,6 +141,7 @@ RC CliCommunicator::read_event(SessionEvent *&event) RC CliCommunicator::write_result(SessionEvent *event, bool &need_disconnect) { RC rc = PlainCommunicator::write_result(event, need_disconnect); + need_disconnect = false; return rc; } diff --git a/src/observer/net/cli_communicator.h b/src/observer/net/cli_communicator.h index fcaeb79d5..6ca6cc034 100644 --- a/src/observer/net/cli_communicator.h +++ b/src/observer/net/cli_communicator.h @@ -25,7 +25,7 @@ See the Mulan PSL v2 for more details. */ class CliCommunicator : public PlainCommunicator { public: - CliCommunicator() = default; + CliCommunicator() = default; virtual ~CliCommunicator() = default; RC init(int fd, Session *session, const std::string &addr) override; @@ -33,5 +33,5 @@ class CliCommunicator : public PlainCommunicator RC write_result(SessionEvent *event, bool &need_disconnect) override; private: - int write_fd_ = -1; ///< 与使用远程通讯模式不同,如果读数据使用标准输入,那么输出应该是标准输出 + int write_fd_ = -1; ///< 与使用远程通讯模式不同,如果读数据使用标准输入,那么输出应该是标准输出 }; diff --git a/src/observer/net/communicator.cpp b/src/observer/net/communicator.cpp index e02442475..5cdad197a 100644 --- a/src/observer/net/communicator.cpp +++ b/src/observer/net/communicator.cpp @@ -13,20 +13,20 @@ See the Mulan PSL v2 for more details. */ // #include "net/communicator.h" +#include "net/buffered_writer.h" +#include "net/cli_communicator.h" #include "net/mysql_communicator.h" #include "net/plain_communicator.h" -#include "net/cli_communicator.h" -#include "net/buffered_writer.h" #include "session/session.h" #include "common/lang/mutex.h" RC Communicator::init(int fd, Session *session, const std::string &addr) { - fd_ = fd; + fd_ = fd; session_ = session; - addr_ = addr; - writer_ = new BufferedWriter(fd_); + addr_ = addr; + writer_ = new BufferedWriter(fd_); return RC::SUCCESS; } diff --git a/src/observer/net/communicator.h b/src/observer/net/communicator.h index 8fcae3273..c2d82fd33 100644 --- a/src/observer/net/communicator.h +++ b/src/observer/net/communicator.h @@ -14,9 +14,9 @@ See the Mulan PSL v2 for more details. */ #pragma once -#include -#include #include "common/rc.h" +#include +#include struct ConnectionContext; class SessionEvent; @@ -38,7 +38,7 @@ class BufferedWriter; * 在server中监听到某个连接有新的消息,就通过Communicator::read_event接收消息。 */ -class Communicator +class Communicator { public: virtual ~Communicator(); @@ -66,41 +66,32 @@ class Communicator /** * @brief 关联的会话信息 */ - Session *session() const - { - return session_; - } + Session *session() const { return session_; } /** * @brief libevent使用的数据,参考server.cpp */ - struct event &read_event() - { - return read_event_; - } + struct event &read_event() { return read_event_; } /** * @brief 对端地址 * 如果是unix socket,可能没有意义 */ - const char *addr() const - { - return addr_.c_str(); - } + const char *addr() const { return addr_.c_str(); } protected: - Session *session_ = nullptr; - struct event read_event_; - std::string addr_; + Session *session_ = nullptr; + struct event read_event_; + std::string addr_; BufferedWriter *writer_ = nullptr; - int fd_ = -1; + int fd_ = -1; }; /** * @brief 当前支持的通讯协议 * @ingroup Communicator */ -enum class CommunicateProtocol +enum class CommunicateProtocol { PLAIN, ///< 以'\0'结尾的协议 CLI, ///< 与客户端进行交互的协议 @@ -111,7 +102,7 @@ enum class CommunicateProtocol * @brief 通讯协议工厂 * @ingroup Communicator */ -class CommunicatorFactory +class CommunicatorFactory { public: Communicator *create(CommunicateProtocol protocol); diff --git a/src/observer/net/mysql_communicator.cpp b/src/observer/net/mysql_communicator.cpp index 37fa1abb8..3e56cc624 100644 --- a/src/observer/net/mysql_communicator.cpp +++ b/src/observer/net/mysql_communicator.cpp @@ -15,11 +15,11 @@ See the Mulan PSL v2 for more details. */ #include #include -#include "common/log/log.h" #include "common/io/io.h" -#include "net/mysql_communicator.h" -#include "net/buffered_writer.h" +#include "common/log/log.h" #include "event/session_event.h" +#include "net/buffered_writer.h" +#include "net/mysql_communicator.h" #include "sql/operator/string_list_physical_operator.h" /** @@ -31,7 +31,7 @@ See the Mulan PSL v2 for more details. */ // the flags below are negotiate by handshake packet const uint32_t CLIENT_PROTOCOL_41 = 512; // const uint32_t CLIENT_INTERACTIVE = 1024; // This is an interactive client -const uint32_t CLIENT_TRANSACTIONS = 8192; // Client knows about transactions. +const uint32_t CLIENT_TRANSACTIONS = 8192; // Client knows about transactions. const uint32_t CLIENT_SESSION_TRACK = (1UL << 23); // Capable of handling server state change information const uint32_t CLIENT_DEPRECATE_EOF = (1UL << 24); // Client no longer needs EOF_Packet and will use OK_Packet instead const uint32_t CLIENT_OPTIONAL_RESULTSET_METADATA = @@ -53,7 +53,7 @@ const uint32_t CLIENT_OPTIONAL_RESULTSET_METADATA = * @details 这些枚举值都是从MySQL的协议中抄过来的 * @ingroup MySQLProtocol */ -enum ResultSetMetaData +enum ResultSetMetaData { RESULTSET_METADATA_NONE = 0, RESULTSET_METADATA_FULL = 1, @@ -64,7 +64,7 @@ enum ResultSetMetaData * @details 枚举值类型是从MySQL的协议中抄过来的 * @ingroup MySQLProtocol */ -enum enum_field_types +enum enum_field_types { MYSQL_TYPE_DECIMAL, MYSQL_TYPE_TINY, @@ -87,19 +87,19 @@ enum enum_field_types MYSQL_TYPE_DATETIME2, /**< Internal to MySQL. Not used in protocol */ MYSQL_TYPE_TIME2, /**< Internal to MySQL. Not used in protocol */ MYSQL_TYPE_TYPED_ARRAY, /**< Used for replication only */ - MYSQL_TYPE_INVALID = 243, - MYSQL_TYPE_BOOL = 244, /**< Currently just a placeholder */ - MYSQL_TYPE_JSON = 245, - MYSQL_TYPE_NEWDECIMAL = 246, - MYSQL_TYPE_ENUM = 247, - MYSQL_TYPE_SET = 248, - MYSQL_TYPE_TINY_BLOB = 249, + MYSQL_TYPE_INVALID = 243, + MYSQL_TYPE_BOOL = 244, /**< Currently just a placeholder */ + MYSQL_TYPE_JSON = 245, + MYSQL_TYPE_NEWDECIMAL = 246, + MYSQL_TYPE_ENUM = 247, + MYSQL_TYPE_SET = 248, + MYSQL_TYPE_TINY_BLOB = 249, MYSQL_TYPE_MEDIUM_BLOB = 250, - MYSQL_TYPE_LONG_BLOB = 251, - MYSQL_TYPE_BLOB = 252, - MYSQL_TYPE_VAR_STRING = 253, - MYSQL_TYPE_STRING = 254, - MYSQL_TYPE_GEOMETRY = 255 + MYSQL_TYPE_LONG_BLOB = 251, + MYSQL_TYPE_BLOB = 252, + MYSQL_TYPE_VAR_STRING = 253, + MYSQL_TYPE_STRING = 254, + MYSQL_TYPE_GEOMETRY = 255 }; /** @@ -110,7 +110,7 @@ enum enum_field_types /** * @brief 将数据写入到缓存中 - * + * * @param buf 数据缓存 * @param value 要写入的值 * @return int 写入的字节数 @@ -124,7 +124,7 @@ int store_int1(char *buf, int8_t value) /** * @brief 将数据写入到缓存中 - * + * * @param buf 数据缓存 * @param value 要写入的值 * @return int 写入的字节数 @@ -138,7 +138,7 @@ int store_int2(char *buf, int16_t value) /** * @brief 将数据写入到缓存中 - * + * * @param buf 数据缓存 * @param value 要写入的值 * @return int 写入的字节数 @@ -152,7 +152,7 @@ int store_int3(char *buf, int32_t value) /** * @brief 将数据写入到缓存中 - * + * * @param buf 数据缓存 * @param value 要写入的值 * @return int 写入的字节数 @@ -166,7 +166,7 @@ int store_int4(char *buf, int32_t value) /** * @brief 将数据写入到缓存中 - * + * * @param buf 数据缓存 * @param value 要写入的值 * @return int 写入的字节数 @@ -180,7 +180,7 @@ int store_int6(char *buf, int64_t value) /** * @brief 将数据写入到缓存中 - * + * * @param buf 数据缓存 * @param value 要写入的值 * @return int 写入的字节数 @@ -226,7 +226,7 @@ int store_lenenc_int(char *buf, uint64_t value) /** * @brief 将以'\0'结尾的字符串写入到缓存中 - * + * * @param buf 数据缓存 * @param s 要写入的字符串 * @return int 写入的字节数 @@ -245,7 +245,7 @@ int store_null_terminated_string(char *buf, const char *s) /** * @brief 将指定长度的字符串写入到缓存中 - * + * * @param buf 数据缓存 * @param s 要写入的字符串 * @param len 字符串的长度 @@ -264,7 +264,7 @@ int store_fix_length_string(char *buf, const char *s, int len) /** * @brief 按照带有长度标识的字符串写入到缓存,长度标识以变长整数编码 - * + * * @param buf 数据缓存 * @param s 要写入的字符串 * @return int 写入的字节数 @@ -272,7 +272,7 @@ int store_fix_length_string(char *buf, const char *s, int len) */ int store_lenenc_string(char *buf, const char *s) { - int len = strlen(s); + int len = static_cast(strlen(s)); int pos = store_lenenc_int(buf, len); store_fix_length_string(buf + pos, s, len); return pos + len; @@ -284,10 +284,10 @@ int store_lenenc_string(char *buf, const char *s) * [MariaDB Packet](https://mariadb.com/kb/en/0-packet/) * @ingroup MySQLProtocol */ -struct PacketHeader +struct PacketHeader { int32_t payload_length : 24; //! 当前packet的除掉头的长度 - int8_t sequence_id = 0; //! 当前packet在当前处理过程中是第几个包 + int8_t sequence_id = 0; //! 当前packet在当前处理过程中是第几个包 }; /** @@ -295,21 +295,18 @@ struct PacketHeader * @details 所有的包都有一个包头,所以BasePacket中包含了一个 @ref PacketHeader * @ingroup MySQLProtocol */ -class BasePacket +class BasePacket { public: PacketHeader packet_header; - BasePacket(int8_t sequence = 0) - { - packet_header.sequence_id = sequence; - } + BasePacket(int8_t sequence = 0) { packet_header.sequence_id = sequence; } virtual ~BasePacket() = default; /** * @brief 将当前包编码成网络包 - * + * * @param[in] capabilities MySQL协议中的capability标志 * @param[out] net_packet 编码后的网络包 */ @@ -321,25 +318,25 @@ class BasePacket * @ingroup MySQLProtocol * @details 先由服务端发送到客户端。 * 这个包会交互capability与用户名密码。 - * [MySQL Handshake]https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html + * [MySQL + * Handshake]https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html */ -struct HandshakeV10 : public BasePacket +struct HandshakeV10 : public BasePacket { - int8_t protocol = 10; - char server_version[7] = "5.7.25"; - int32_t thread_id = 21501807; // conn id - char auth_plugin_data_part_1[9] = - "12345678"; // first 8 bytes of the plugin provided data (scramble) // and the filler - int16_t capability_flags_1 = 0xF7DF; // The lower 2 bytes of the Capabilities Flags - int8_t character_set = 83; - int16_t status_flags = 0; - int16_t capability_flags_2 = 0x0000; - int8_t auth_plugin_data_len = 0; - char reserved[10] = {0}; - char auth_plugin_data_part_2[13] = "bbbbbbbbbbbb"; - - HandshakeV10(int8_t sequence = 0) : BasePacket(sequence) - {} + int8_t protocol = 10; + char server_version[7] = "5.7.25"; + int32_t thread_id = 21501807; // conn id + char auth_plugin_data_part_1[9] = + "12345678"; // first 8 bytes of the plugin provided data (scramble) // and the filler + int16_t capability_flags_1 = 0xF7DF; // The lower 2 bytes of the Capabilities Flags + int8_t character_set = 83; + int16_t status_flags = 0; + int16_t capability_flags_2 = 0x0000; + int8_t auth_plugin_data_len = 0; + char reserved[10] = {0}; + char auth_plugin_data_part_2[13] = "bbbbbbbbbbbb"; + + HandshakeV10(int8_t sequence = 0) : BasePacket(sequence) {} virtual ~HandshakeV10() = default; /** @@ -350,8 +347,8 @@ struct HandshakeV10 : public BasePacket net_packet.resize(100); char *buf = net_packet.data(); - int pos = 0; - pos += 3; + int pos = 0; + pos += 3; // skip packet length pos += store_int1(buf + pos, packet_header.sequence_id); pos += store_int1(buf + pos, protocol); @@ -380,29 +377,28 @@ struct HandshakeV10 : public BasePacket * @brief 响应包,在很多场景中都会使用 * @ingroup MySQLProtocol */ -struct OkPacket : public BasePacket +struct OkPacket : public BasePacket { - int8_t header = 0; // 0x00 for ok and 0xFE for EOF - int32_t affected_rows = 0; - int32_t last_insert_id = 0; - int16_t status_flags = 0x22; - int16_t warnings = 0; + int8_t header = 0; // 0x00 for ok and 0xFE for EOF + int32_t affected_rows = 0; + int32_t last_insert_id = 0; + int16_t status_flags = 0x22; + int16_t warnings = 0; std::string info; // human readable status information - OkPacket(int8_t sequence = 0) : BasePacket(sequence) - {} + OkPacket(int8_t sequence = 0) : BasePacket(sequence) {} virtual ~OkPacket() = default; /** * https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_ok_packet.html */ - virtual RC encode(uint32_t capabilities, std::vector &net_packet) const override + RC encode(uint32_t capabilities, std::vector &net_packet) const override { net_packet.resize(100); char *buf = net_packet.data(); - int pos = 0; + int pos = 0; - pos += 3; + pos += 3; // skip packet length pos += store_int1(buf + pos, packet_header.sequence_id); pos += store_int1(buf + pos, header); pos += store_lenenc_int(buf + pos, affected_rows); @@ -434,21 +430,20 @@ struct OkPacket : public BasePacket * @ingroup MySQLProtocol * @details [basic_err_packet](https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_err_packet.html) */ -struct EofPacket : public BasePacket +struct EofPacket : public BasePacket { - int8_t header = 0xFE; - int16_t warnings = 0; + int8_t header = 0xFE; + int16_t warnings = 0; int16_t status_flags = 0x22; - EofPacket(int8_t sequence = 0) : BasePacket(sequence) - {} + EofPacket(int8_t sequence = 0) : BasePacket(sequence) {} virtual ~EofPacket() = default; - virtual RC encode(uint32_t capabilities, std::vector &net_packet) const override + RC encode(uint32_t capabilities, std::vector &net_packet) const override { net_packet.resize(10); char *buf = net_packet.data(); - int pos = 0; + int pos = 0; pos += 3; store_int1(buf + pos, packet_header.sequence_id); @@ -476,23 +471,22 @@ struct EofPacket : public BasePacket * @details [eof_packet](https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_eof_packet.html) * @ingroup MySQLProtocol */ -struct ErrPacket : public BasePacket +struct ErrPacket : public BasePacket { - int8_t header = 0xFF; - int16_t error_code = 0; - char sql_state_marker[1] = {'#'}; + int8_t header = 0xFF; + int16_t error_code = 0; + char sql_state_marker[1] = {'#'}; std::string sql_state{"HY000"}; std::string error_message; - ErrPacket(int8_t sequence = 0) : BasePacket(sequence) - {} + ErrPacket(int8_t sequence = 0) : BasePacket(sequence) {} virtual ~ErrPacket() = default; virtual RC encode(uint32_t capabilities, std::vector &net_packet) const override { net_packet.resize(1000); char *buf = net_packet.data(); - int pos = 0; + int pos = 0; pos += 3; @@ -519,14 +513,15 @@ struct ErrPacket : public BasePacket /** * @brief MySQL客户端发过来的请求包 * @ingroup MySQLProtocol - * @details [MySQL Protocol Command Phase](https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_command_phase.html) - * [MariaDB Text Protocol](https://mariadb.com/kb/en/2-text-protocol/) + * @details [MySQL Protocol Command + * Phase](https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_command_phase.html) [MariaDB Text + * Protocol](https://mariadb.com/kb/en/2-text-protocol/) */ -struct QueryPacket +struct QueryPacket { PacketHeader packet_header; - int8_t command; // 0x03: COM_QUERY - std::string query; // the text of the SQL query to execute + int8_t command; // 0x03: COM_QUERY + std::string query; // the text of the SQL query to execute }; /** @@ -549,7 +544,7 @@ RC decode_query_packet(std::vector &net_packet, QueryPacket &query_packet) */ RC create_version_comment_sql_result(SqlResult *sql_result) { - TupleSchema tuple_schema; + TupleSchema tuple_schema; TupleCellSpec cell_spec("", "", "@@version_comment"); tuple_schema.append_cell(cell_spec); @@ -595,12 +590,13 @@ RC MysqlCommunicator::init(int fd, Session *session, const std::string &addr) /** * @brief MySQL客户端连接时会发起一个"select @@version_comment"的查询,这里对这个查询进行特殊处理 - * + * * @param[out] need_disconnect 连接上如果出现异常,通过这个标识来判断是否需要断开连接 */ RC MysqlCommunicator::handle_version_comment(bool &need_disconnect) { SessionEvent session_event(this); + RC rc = create_version_comment_sql_result(session_event.sql_result()); if (rc != RC::SUCCESS) { LOG_WARN("failed to handle version comment. rc=%s", strrc(rc)); @@ -613,7 +609,7 @@ RC MysqlCommunicator::handle_version_comment(bool &need_disconnect) /** * @brief 读取客户端发过来的请求 - * + * * @param[out] event 如果有新的请求,就会生成一个SessionEvent */ RC MysqlCommunicator::read_event(SessionEvent *&event) @@ -622,6 +618,7 @@ RC MysqlCommunicator::read_event(SessionEvent *&event) /// 读取一个完整的数据包 PacketHeader packet_header; + int ret = common::readn(fd_, &packet_header, sizeof(packet_header)); if (ret != 0) { LOG_WARN("failed to read packet header. length=%d, addr=%s. error=%s", @@ -651,6 +648,7 @@ RC MysqlCommunicator::read_event(SessionEvent *&event) // send ok packet and return OkPacket ok_packet; ok_packet.packet_header.sequence_id = sequence_id_; + rc = send_packet(ok_packet); if (rc != RC::SUCCESS) { LOG_WARN("failed to send ok packet while auth"); @@ -698,8 +696,8 @@ RC MysqlCommunicator::write_state(SessionEvent *event, bool &need_disconnect) { SqlResult *sql_result = event->sql_result(); - const int buf_size = 2048; - char *buf = new char[buf_size]; + const int buf_size = 2048; + char *buf = new char[buf_size]; const std::string &state_string = sql_result->state_string(); if (state_string.empty()) { const char *result = RC::SUCCESS == sql_result->return_code() ? "SUCCESS" : "FAILURE"; @@ -718,9 +716,9 @@ RC MysqlCommunicator::write_state(SessionEvent *event, bool &need_disconnect) } else { ErrPacket err_packet; err_packet.packet_header.sequence_id = sequence_id_++; - err_packet.error_code = static_cast(sql_result->return_code()); - err_packet.error_message = buf; - rc = send_packet(err_packet); + err_packet.error_code = static_cast(sql_result->return_code()); + err_packet.error_message = buf; + rc = send_packet(err_packet); } if (rc != RC::SUCCESS) { LOG_WARN("failed to send ok packet to client. addr=%s, error=%s", addr(), strrc(rc)); @@ -738,13 +736,13 @@ RC MysqlCommunicator::write_result(SessionEvent *event, bool &need_disconnect) { RC rc = RC::SUCCESS; - need_disconnect = true; + need_disconnect = true; SqlResult *sql_result = event->sql_result(); if (nullptr == sql_result) { const char *response = "Unexpected error: no result"; - const int len = strlen(response); - OkPacket ok_packet; // TODO if error occurs, we should send an error packet to client + const int len = strlen(response); + OkPacket ok_packet; // TODO if error occurs, we should send an error packet to client ok_packet.info.assign(response, len); rc = send_packet(ok_packet); if (rc != RC::SUCCESS) { @@ -767,7 +765,7 @@ RC MysqlCommunicator::write_result(SessionEvent *event, bool &need_disconnect) } const TupleSchema &tuple_schema = sql_result->tuple_schema(); - const int cell_num = tuple_schema.cell_num(); + const int cell_num = tuple_schema.cell_num(); if (cell_num == 0) { // maybe a dml that send nothing to client } else { @@ -794,6 +792,7 @@ RC MysqlCommunicator::write_result(SessionEvent *event, bool &need_disconnect) RC MysqlCommunicator::send_packet(const BasePacket &packet) { std::vector net_packet; + RC rc = packet.encode(client_capabilities_flag_, net_packet); if (rc != RC::SUCCESS) { LOG_WARN("failed to encode ok packet. rc=%s", strrc(rc)); @@ -821,8 +820,9 @@ RC MysqlCommunicator::send_packet(const BasePacket &packet) RC MysqlCommunicator::send_column_definition(SqlResult *sql_result, bool &need_disconnect) { RC rc = RC::SUCCESS; + const TupleSchema &tuple_schema = sql_result->tuple_schema(); - const int cell_num = tuple_schema.cell_num(); + const int cell_num = tuple_schema.cell_num(); if (cell_num == 0) { return rc; @@ -831,7 +831,7 @@ RC MysqlCommunicator::send_column_definition(SqlResult *sql_result, bool &need_d std::vector net_packet; net_packet.resize(1024); char *buf = net_packet.data(); - int pos = 0; + int pos = 0; pos += 3; store_int1(buf + pos, sequence_id_++); @@ -867,20 +867,20 @@ RC MysqlCommunicator::send_column_definition(SqlResult *sql_result, bool &need_d store_int1(buf + pos, sequence_id_++); pos += 1; - const TupleCellSpec &spec = tuple_schema.cell_at(i); - const char *catalog = "def"; // The catalog used. Currently always "def" - const char *schema = "sys"; // schema name - const char *table = spec.table_name(); - const char *org_table = spec.table_name(); - const char *name = spec.alias(); + const TupleCellSpec &spec = tuple_schema.cell_at(i); + const char *catalog = "def"; // The catalog used. Currently always "def" + const char *schema = "sys"; // schema name + const char *table = spec.table_name(); + const char *org_table = spec.table_name(); + const char *name = spec.alias(); // const char *org_name = spec.field_name(); - const char *org_name = spec.alias(); - int fixed_len_fields = 0x0c; - int character_set = 33; - int column_length = 16384; - int type = MYSQL_TYPE_VAR_STRING; - int16_t flags = 0; - int8_t decimals = 0x1f; + const char *org_name = spec.alias(); + int fixed_len_fields = 0x0c; + int character_set = 33; + int column_length = 16384; + int type = MYSQL_TYPE_VAR_STRING; + int16_t flags = 0; + int8_t decimals = 0x1f; pos += store_lenenc_string(buf + pos, catalog); pos += store_lenenc_string(buf + pos, schema); @@ -917,8 +917,8 @@ RC MysqlCommunicator::send_column_definition(SqlResult *sql_result, bool &need_d if (!(client_capabilities_flag_ & CLIENT_DEPRECATE_EOF)) { EofPacket eof_packet; eof_packet.packet_header.sequence_id = sequence_id_++; - eof_packet.status_flags = 0x02; - rc = send_packet(eof_packet); + eof_packet.status_flags = 0x02; + rc = send_packet(eof_packet); if (rc != RC::SUCCESS) { need_disconnect = true; LOG_WARN("failed to send eof packet to client. addr=%s, error=%s", addr(), strerror(errno)); @@ -941,11 +941,12 @@ RC MysqlCommunicator::send_column_definition(SqlResult *sql_result, bool &need_d RC MysqlCommunicator::send_result_rows(SqlResult *sql_result, bool no_column_def, bool &need_disconnect) { RC rc = RC::SUCCESS; + std::vector packet; packet.resize(4 * 1024 * 1024); // TODO warning: length cannot be fix - int affected_rows = 0; - Tuple *tuple = nullptr; + int affected_rows = 0; + Tuple *tuple = nullptr; while (RC::SUCCESS == (rc = sql_result->next_tuple(tuple))) { assert(tuple != nullptr); @@ -960,7 +961,7 @@ RC MysqlCommunicator::send_result_rows(SqlResult *sql_result, bool no_column_def // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset_row.html // note: if some field is null, send a 0xFB char *buf = packet.data(); - int pos = 0; + int pos = 0; pos += 3; pos += store_int1(buf + pos, sequence_id_++); @@ -991,13 +992,13 @@ RC MysqlCommunicator::send_result_rows(SqlResult *sql_result, bool no_column_def LOG_TRACE("client has CLIENT_DEPRECATE_EOF or has empty column, send ok packet"); OkPacket ok_packet; ok_packet.packet_header.sequence_id = sequence_id_++; - ok_packet.affected_rows = affected_rows; - rc = send_packet(ok_packet); + ok_packet.affected_rows = affected_rows; + rc = send_packet(ok_packet); } else { LOG_TRACE("send eof packet to client"); EofPacket eof_packet; eof_packet.packet_header.sequence_id = sequence_id_++; - rc = send_packet(eof_packet); + rc = send_packet(eof_packet); } LOG_TRACE("send rows to client done"); diff --git a/src/observer/net/mysql_communicator.h b/src/observer/net/mysql_communicator.h index 3bdf026c2..1d6677008 100644 --- a/src/observer/net/mysql_communicator.h +++ b/src/observer/net/mysql_communicator.h @@ -27,7 +27,7 @@ class BasePacket; * 可以参考 [MySQL Page Protocol](https://dev.mysql.com/doc/dev/mysql-server/latest/PAGE_PROTOCOL.html) * 或 [MariaDB Protocol](https://mariadb.com/kb/en/clientserver-protocol/) */ -class MysqlCommunicator : public Communicator +class MysqlCommunicator : public Communicator { public: /** @@ -53,14 +53,14 @@ class MysqlCommunicator : public Communicator private: /** * @brief 发送数据包到客户端 - * + * * @param[in] packet 要发送的数据包 */ RC send_packet(const BasePacket &packet); /** * @brief 有些情况下不需要给客户端返回一行行的记录结果,而是返回执行是否成功即可,比如create table等 - * + * * @param[in] event 处理的结果 * @param[out] need_disconnect 是否需要断开连接 */ @@ -75,11 +75,11 @@ class MysqlCommunicator : public Communicator /** * @brief 返回客户端行数据 - * + * * @param[in] sql_result 返回的结果 * @param no_column_def 是否没有列描述信息 * @param[out] need_disconnect 是否需要断开连接 - * @return RC + * @return RC */ RC send_result_rows(SqlResult *sql_result, bool no_column_def, bool &need_disconnect); diff --git a/src/observer/net/plain_communicator.cpp b/src/observer/net/plain_communicator.cpp index e70e6d619..6c5e1604f 100644 --- a/src/observer/net/plain_communicator.cpp +++ b/src/observer/net/plain_communicator.cpp @@ -13,12 +13,14 @@ See the Mulan PSL v2 for more details. */ // #include "net/plain_communicator.h" -#include "net/buffered_writer.h" -#include "sql/expr/tuple.h" -#include "event/session_event.h" -#include "session/session.h" #include "common/io/io.h" #include "common/log/log.h" +#include "event/session_event.h" +#include "net/buffered_writer.h" +#include "session/session.h" +#include "sql/expr/tuple.h" + +using namespace std; PlainCommunicator::PlainCommunicator() { @@ -37,8 +39,8 @@ RC PlainCommunicator::read_event(SessionEvent *&event) int data_len = 0; int read_len = 0; - const int max_packet_size = 8192; - std::vector buf(max_packet_size); + const int max_packet_size = 8192; + vector buf(max_packet_size); // 持续接收消息,直到遇到'\0'。将'\0'遇到的后续数据直接丢弃没有处理,因为目前仅支持一收一发的模式 while (true) { @@ -88,16 +90,16 @@ RC PlainCommunicator::read_event(SessionEvent *&event) LOG_INFO("receive command(size=%d): %s", data_len, buf.data()); event = new SessionEvent(this); - event->set_query(std::string(buf.data())); + event->set_query(string(buf.data())); return rc; } RC PlainCommunicator::write_state(SessionEvent *event, bool &need_disconnect) { - SqlResult *sql_result = event->sql_result(); - const int buf_size = 2048; - char *buf = new char[buf_size]; - const std::string &state_string = sql_result->state_string(); + SqlResult *sql_result = event->sql_result(); + const int buf_size = 2048; + char *buf = new char[buf_size]; + const string &state_string = sql_result->state_string(); if (state_string.empty()) { const char *result = RC::SUCCESS == sql_result->return_code() ? "SUCCESS" : "FAILURE"; snprintf(buf, buf_size, "%s\n", result); @@ -126,7 +128,8 @@ RC PlainCommunicator::write_debug(SessionEvent *request, bool &need_disconnect) } SqlDebug &sql_debug = request->sql_debug(); - const std::list &debug_infos = sql_debug.get_debug_infos(); + + const list &debug_infos = sql_debug.get_debug_infos(); for (auto &debug_info : debug_infos) { RC rc = writer_->writen(debug_message_prefix_.data(), debug_message_prefix_.size()); if (OB_FAIL(rc)) { @@ -143,6 +146,7 @@ RC PlainCommunicator::write_debug(SessionEvent *request, bool &need_disconnect) } char newline = '\n'; + rc = writer_->writen(&newline, 1); if (OB_FAIL(rc)) { LOG_WARN("failed to send new line to client. err=%s", strerror(errno)); @@ -172,13 +176,14 @@ RC PlainCommunicator::write_result(SessionEvent *event, bool &need_disconnect) return rc; } } - writer_->flush(); // TODO handle error + writer_->flush(); // TODO handle error return rc; } RC PlainCommunicator::write_result_internal(SessionEvent *event, bool &need_disconnect) { RC rc = RC::SUCCESS; + need_disconnect = true; SqlResult *sql_result = event->sql_result(); @@ -194,15 +199,16 @@ RC PlainCommunicator::write_result_internal(SessionEvent *event, bool &need_disc return write_state(event, need_disconnect); } - const TupleSchema &schema = sql_result->tuple_schema(); - const int cell_num = schema.cell_num(); + const TupleSchema &schema = sql_result->tuple_schema(); + const int cell_num = schema.cell_num(); for (int i = 0; i < cell_num; i++) { - const TupleCellSpec &spec = schema.cell_at(i); - const char *alias = spec.alias(); + const TupleCellSpec &spec = schema.cell_at(i); + const char *alias = spec.alias(); if (nullptr != alias || alias[0] != 0) { if (0 != i) { const char *delim = " | "; + rc = writer_->writen(delim, strlen(delim)); if (OB_FAIL(rc)) { LOG_WARN("failed to send data to client. err=%s", strerror(errno)); @@ -211,6 +217,7 @@ RC PlainCommunicator::write_result_internal(SessionEvent *event, bool &need_disc } int len = strlen(alias); + rc = writer_->writen(alias, len); if (OB_FAIL(rc)) { LOG_WARN("failed to send data to client. err=%s", strerror(errno)); @@ -222,6 +229,7 @@ RC PlainCommunicator::write_result_internal(SessionEvent *event, bool &need_disc if (cell_num > 0) { char newline = '\n'; + rc = writer_->writen(&newline, 1); if (OB_FAIL(rc)) { LOG_WARN("failed to send data to client. err=%s", strerror(errno)); @@ -231,6 +239,7 @@ RC PlainCommunicator::write_result_internal(SessionEvent *event, bool &need_disc } rc = RC::SUCCESS; + Tuple *tuple = nullptr; while (RC::SUCCESS == (rc = sql_result->next_tuple(tuple))) { assert(tuple != nullptr); @@ -239,6 +248,7 @@ RC PlainCommunicator::write_result_internal(SessionEvent *event, bool &need_disc for (int i = 0; i < cell_num; i++) { if (i != 0) { const char *delim = " | "; + rc = writer_->writen(delim, strlen(delim)); if (OB_FAIL(rc)) { LOG_WARN("failed to send data to client. err=%s", strerror(errno)); @@ -254,7 +264,8 @@ RC PlainCommunicator::write_result_internal(SessionEvent *event, bool &need_disc return rc; } - std::string cell_str = value.to_string(); + string cell_str = value.to_string(); + rc = writer_->writen(cell_str.data(), cell_str.size()); if (OB_FAIL(rc)) { LOG_WARN("failed to send data to client. err=%s", strerror(errno)); @@ -264,6 +275,7 @@ RC PlainCommunicator::write_result_internal(SessionEvent *event, bool &need_disc } char newline = '\n'; + rc = writer_->writen(&newline, 1); if (OB_FAIL(rc)) { LOG_WARN("failed to send data to client. err=%s", strerror(errno)); diff --git a/src/observer/net/plain_communicator.h b/src/observer/net/plain_communicator.h index 9453cccf6..01db59b09 100644 --- a/src/observer/net/plain_communicator.h +++ b/src/observer/net/plain_communicator.h @@ -23,7 +23,7 @@ See the Mulan PSL v2 for more details. */ * @ingroup Communicator * @details 使用简单的文本通讯协议,每个消息使用'\0'结尾 */ -class PlainCommunicator : public Communicator +class PlainCommunicator : public Communicator { public: PlainCommunicator(); @@ -38,6 +38,6 @@ class PlainCommunicator : public Communicator RC write_result_internal(SessionEvent *event, bool &need_disconnect); protected: - std::vector send_message_delimiter_; ///< 发送消息分隔符 - std::vector debug_message_prefix_; ///< 调试信息前缀 + std::vector send_message_delimiter_; ///< 发送消息分隔符 + std::vector debug_message_prefix_; ///< 调试信息前缀 }; diff --git a/src/observer/net/ring_buffer.cpp b/src/observer/net/ring_buffer.cpp index 28c16cb58..7468cbf8f 100644 --- a/src/observer/net/ring_buffer.cpp +++ b/src/observer/net/ring_buffer.cpp @@ -14,23 +14,18 @@ See the Mulan PSL v2 for more details. */ #include -#include "net/ring_buffer.h" #include "common/log/log.h" +#include "net/ring_buffer.h" using namespace std; const int32_t DEFAULT_BUFFER_SIZE = 16 * 1024; -RingBuffer::RingBuffer() - : RingBuffer(DEFAULT_BUFFER_SIZE) -{} +RingBuffer::RingBuffer() : RingBuffer(DEFAULT_BUFFER_SIZE) {} -RingBuffer::RingBuffer(int32_t size) - : buffer_(size) -{} +RingBuffer::RingBuffer(int32_t size) : buffer_(size) {} -RingBuffer::~RingBuffer() -{} +RingBuffer::~RingBuffer() {} RC RingBuffer::read(char *buf, int32_t size, int32_t &read_size) { @@ -38,12 +33,12 @@ RC RingBuffer::read(char *buf, int32_t size, int32_t &read_size) return RC::INVALID_ARGUMENT; } - RC rc = RC::SUCCESS; + RC rc = RC::SUCCESS; read_size = 0; - while (OB_SUCC(rc) && read_size < size && this->size() > 0) { - const char *tmp_buf = nullptr; - int32_t tmp_size = 0; - rc = buffer(tmp_buf, tmp_size); + while (OB_SUCC(rc) && read_sizesize()> 0) { + const char *tmp_buf = nullptr; + int32_t tmp_size = 0; + rc = buffer(tmp_buf, tmp_size); if (OB_SUCC(rc)) { int32_t copy_size = min(size - read_size, tmp_size); memcpy(buf + read_size, tmp_buf, copy_size); @@ -60,7 +55,7 @@ RC RingBuffer::buffer(const char *&buf, int32_t &read_size) { const int32_t size = this->size(); if (size == 0) { - buf = buffer_.data(); + buf = buffer_.data(); read_size = 0; return RC::SUCCESS; } @@ -96,11 +91,11 @@ RC RingBuffer::write(const char *data, int32_t size, int32_t &write_size) return RC::INVALID_ARGUMENT; } - RC rc = RC::SUCCESS; + RC rc = RC::SUCCESS; write_size = 0; - while (OB_SUCC(rc) && write_size < size && this->remain() > 0) { + while (OB_SUCC(rc) && write_sizeremain()> 0) { - const int32_t read_pos = this->read_pos(); + const int32_t read_pos = this->read_pos(); const int32_t tmp_buf_size = (read_pos <= write_pos_) ? (capacity() - write_pos_) : (read_pos - write_pos_); const int32_t copy_size = min(size - write_size, tmp_buf_size); diff --git a/src/observer/net/ring_buffer.h b/src/observer/net/ring_buffer.h index d25b3fe82..5e4bf84dc 100644 --- a/src/observer/net/ring_buffer.h +++ b/src/observer/net/ring_buffer.h @@ -32,7 +32,7 @@ class RingBuffer /** * @brief 指定初始化大小的构造函数 - * + * */ explicit RingBuffer(int32_t size); @@ -86,9 +86,9 @@ class RingBuffer private: int32_t read_pos() const { return (write_pos_ - this->size() + capacity()) % capacity(); } - + private: - std::vector buffer_; ///< 缓存使用的内存,使用vector方便管理 - int32_t data_size_ = 0; ///< 已经写入的数据量 - int32_t write_pos_ = 0; ///< 当前写指针的位置,范围不会超出[0, capacity) + std::vector buffer_; ///< 缓存使用的内存,使用vector方便管理 + int32_t data_size_ = 0; ///< 已经写入的数据量 + int32_t write_pos_ = 0; ///< 当前写指针的位置,范围不会超出[0, capacity) }; \ No newline at end of file diff --git a/src/observer/net/server.cpp b/src/observer/net/server.cpp index e6692aa6e..bab3917f5 100644 --- a/src/observer/net/server.cpp +++ b/src/observer/net/server.cpp @@ -16,6 +16,7 @@ See the Mulan PSL v2 for more details. */ #include #include +#include #include #include #include @@ -26,16 +27,15 @@ See the Mulan PSL v2 for more details. */ #include #include #include -#include +#include "common/ini_setting.h" +#include "common/io/io.h" #include "common/lang/mutex.h" #include "common/log/log.h" -#include "common/io/io.h" #include "common/seda/seda_config.h" #include "event/session_event.h" -#include "session/session.h" -#include "common/ini_setting.h" #include "net/communicator.h" +#include "session/session.h" using namespace common; @@ -43,14 +43,12 @@ Stage *Server::session_stage_ = nullptr; ServerParam::ServerParam() { - listen_addr = INADDR_ANY; + listen_addr = INADDR_ANY; max_connection_num = MAX_CONNECTION_NUM_DEFAULT; - port = PORT_DEFAULT; + port = PORT_DEFAULT; } -Server::Server(ServerParam input_server_param) : server_param_(input_server_param) -{ -} +Server::Server(ServerParam input_server_param) : server_param_(input_server_param) {} Server::~Server() { @@ -59,10 +57,7 @@ Server::~Server() } } -void Server::init() -{ - session_stage_ = get_seda_config()->get_stage(SESSION_STAGE_NAME); -} +void Server::init() { session_stage_ = get_seda_config()->get_stage(SESSION_STAGE_NAME); } int Server::set_non_block(int fd) { @@ -92,6 +87,7 @@ void Server::recv(int fd, short ev, void *arg) Communicator *comm = (Communicator *)arg; SessionEvent *event = nullptr; + RC rc = comm->read_event(event); if (rc != RC::SUCCESS) { close_connection(comm); @@ -107,9 +103,9 @@ void Server::recv(int fd, short ev, void *arg) void Server::accept(int fd, short ev, void *arg) { - Server *instance = (Server *)arg; + Server *instance = (Server *)arg; struct sockaddr_in addr; - socklen_t addrlen = sizeof(addr); + socklen_t addrlen = sizeof(addr); int ret = 0; @@ -139,7 +135,7 @@ void Server::accept(int fd, short ev, void *arg) if (!instance->server_param_.use_unix_socket) { // unix socket不支持设置NODELAY int yes = 1; - ret = setsockopt(client_fd, IPPROTO_TCP, TCP_NODELAY, &yes, sizeof(yes)); + ret = setsockopt(client_fd, IPPROTO_TCP, TCP_NODELAY, &yes, sizeof(yes)); if (ret < 0) { LOG_ERROR("Failed to set socket of %s option as : TCP_NODELAY %s\n", addr_str.c_str(), strerror(errno)); ::close(client_fd); @@ -148,6 +144,7 @@ void Server::accept(int fd, short ev, void *arg) } Communicator *communicator = instance->communicator_factory_.create(instance->server_param_.protocol); + RC rc = communicator->init(client_fd, new Session(Session::default_session()), addr_str); if (rc != RC::SUCCESS) { LOG_WARN("failed to init communicator. rc=%s", strrc(rc)); @@ -188,7 +185,7 @@ int Server::start() int Server::start_tcp_server() { - int ret = 0; + int ret = 0; struct sockaddr_in sa; server_socket_ = socket(AF_INET, SOCK_STREAM, 0); @@ -198,7 +195,7 @@ int Server::start_tcp_server() } int yes = 1; - ret = setsockopt(server_socket_, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes)); + ret = setsockopt(server_socket_, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes)); if (ret < 0) { LOG_ERROR("Failed to set socket option of reuse address: %s.", strerror(errno)); ::close(server_socket_); @@ -213,8 +210,8 @@ int Server::start_tcp_server() } memset(&sa, 0, sizeof(sa)); - sa.sin_family = AF_INET; - sa.sin_port = htons(server_param_.port); + sa.sin_family = AF_INET; + sa.sin_port = htons(server_param_.port); sa.sin_addr.s_addr = htonl(server_param_.listen_addr); ret = ::bind(server_socket_, (struct sockaddr *)&sa, sizeof(sa)); @@ -253,7 +250,7 @@ int Server::start_tcp_server() int Server::start_unix_socket_server() { - int ret = 0; + int ret = 0; server_socket_ = socket(PF_UNIX, SOCK_STREAM, 0); if (server_socket_ < 0) { LOG_ERROR("socket(): can not create unix socket: %s.", strerror(errno)); @@ -311,6 +308,7 @@ int Server::start_unix_socket_server() int Server::start_stdin_server() { Communicator *communicator = communicator_factory_.create(server_param_.protocol); + RC rc = communicator->init(STDIN_FILENO, new Session(Session::default_session()), "stdin"); if (OB_FAIL(rc)) { LOG_WARN("failed to init cli communicator. rc=%s", strrc(rc)); @@ -321,6 +319,7 @@ int Server::start_stdin_server() while (started_) { SessionEvent *event = nullptr; + rc = communicator->read_event(event); if (OB_FAIL(rc)) { LOG_WARN("failed to read event. rc=%s", strrc(rc)); diff --git a/src/observer/net/server.h b/src/observer/net/server.h index 08172e0f1..6325d3170 100644 --- a/src/observer/net/server.h +++ b/src/observer/net/server.h @@ -14,7 +14,6 @@ See the Mulan PSL v2 for more details. */ #pragma once -#include "common/defs.h" #include "common/seda/stage.h" #include "net/server_param.h" @@ -26,7 +25,7 @@ class Communicator; * @details 当前支持网络连接,有TCP和Unix Socket两种方式。通过命令行参数来指定使用哪种方式。 * 启动后监听端口或unix socket,使用libevent来监听事件,当有新的连接到达时,创建一个Communicator对象进行处理。 */ -class Server +class Server { public: Server(ServerParam input_server_param); @@ -37,7 +36,7 @@ class Server static void close_connection(Communicator *comm); public: - int serve(); + int serve(); void shutdown(); private: @@ -61,7 +60,7 @@ class Server private: /** * @brief 将socket描述符设置为非阻塞模式 - * + * * @param fd 指定的描述符 */ int set_non_block(int fd); @@ -83,13 +82,13 @@ class Server private: volatile bool started_ = false; - int server_socket_ = -1; ///< 监听套接字,是一个描述符 - struct event_base *event_base_ = nullptr; ///< libevent对象 - struct event *listen_ev_ = nullptr; ///< libevent监听套接字事件 + int server_socket_ = -1; ///< 监听套接字,是一个描述符 + struct event_base *event_base_ = nullptr; ///< libevent对象 + struct event *listen_ev_ = nullptr; ///< libevent监听套接字事件 ServerParam server_param_; ///< 服务启动参数 - CommunicatorFactory communicator_factory_; ///< 通过这个对象创建新的Communicator对象 + CommunicatorFactory communicator_factory_; ///< 通过这个对象创建新的Communicator对象 static common::Stage *session_stage_; ///< 通过这个对象创建新的请求任务 }; diff --git a/src/observer/net/server_param.h b/src/observer/net/server_param.h index d2031990d..1a196784e 100644 --- a/src/observer/net/server_param.h +++ b/src/observer/net/server_param.h @@ -14,20 +14,20 @@ See the Mulan PSL v2 for more details. */ #pragma once -#include #include "net/communicator.h" +#include /** * @brief 服务端启动参数 * @ingroup Communicator */ -class ServerParam +class ServerParam { public: ServerParam(); ServerParam(const ServerParam &other) = default; - ~ServerParam() = default; + ~ServerParam() = default; public: // accpet client's address, default is INADDR_ANY, means accept every address @@ -35,9 +35,9 @@ class ServerParam int max_connection_num; ///< 最大连接数 - int port; ///< 监听的端口号 + int port; ///< 监听的端口号 - std::string unix_socket_path; ///< unix socket的路径 + std::string unix_socket_path; ///< unix socket的路径 bool use_std_io = false; ///< 是否使用标准输入输出作为通信条件 @@ -45,5 +45,5 @@ class ServerParam ///< 后面如果改成支持多种通讯方式,就不需要这个参数了 bool use_unix_socket = false; - CommunicateProtocol protocol; ///< 通讯协议,目前支持文本协议和mysql协议 + CommunicateProtocol protocol; ///< 通讯协议,目前支持文本协议和mysql协议 };