diff --git a/CMakeLists.txt b/CMakeLists.txt index 6071a4ec8..4326f2359 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,6 +63,8 @@ find_library(BOOSTPO_LIBRARY NAMES boost_program_options) find_library(BOOSTSYSTEM_LIBRARY NAMES boost_system) find_path(TCMALLOC_INCLUDE_DIR gperftools/heap-profiler.h) find_library(TCMALLOC_LIBRARY NAMES tcmalloc_and_profiler) +find_path(MARIADBCLIENT_INCLUDE_DIR mysql/mysql.h) +find_library(MARIADBCLIENT_LIBRARY NAMES mysqlclient) find_path(HIREDIS_INCLUDE_DIR hiredis/hiredis.h) find_library(HIREDIS_LIBRARY NAMES hiredis) find_path(GD_INCLUDE_DIR gd.h) @@ -86,6 +88,9 @@ endif() if(ZLIB_FOUND) message(STATUS "found zlib") endif() +if(MARIADBCLIENT_INCLUDE_DIR AND MARIADBCLIENT_LIBRARY) + message(STATUS "found mariadbclient") +endif() if(HIREDIS_INCLUDE_DIR AND HIREDIS_LIBRARY) message(STATUS "found hiredis") endif() diff --git a/contrib/CMakeLists.txt b/contrib/CMakeLists.txt index 8e4a50ea1..adcbf9e86 100644 --- a/contrib/CMakeLists.txt +++ b/contrib/CMakeLists.txt @@ -4,6 +4,12 @@ else() add_subdirectory(hiredis EXCLUDE_FROM_ALL) endif() +if(MARIADBCLIENT_INCLUDE_DIR AND MARIADBCLIENT_LIBRARY) + add_subdirectory(mariadbclient) +else() + add_subdirectory(mariadbclient EXCLUDE_FROM_ALL) +endif() + if(THRIFT_COMPILER AND THRIFT_INCLUDE_DIR AND THRIFT_LIBRARY) add_subdirectory(thrift) else() diff --git a/contrib/mariadbclient/CMakeLists.txt b/contrib/mariadbclient/CMakeLists.txt new file mode 100644 index 000000000..ca3a0ade2 --- /dev/null +++ b/contrib/mariadbclient/CMakeLists.txt @@ -0,0 +1,5 @@ +add_library(muduo_mariadbclient MariaDBClient.cc) +target_link_libraries(muduo_mariadbclient muduo_net mysqlclient) + +add_executable(mmariadbclient mmariadbclient.cc) +target_link_libraries(mmariadbclient muduo_mariadbclient) diff --git a/contrib/mariadbclient/MariaDBClient.cc b/contrib/mariadbclient/MariaDBClient.cc new file mode 100644 index 000000000..f48c35a3c --- /dev/null +++ b/contrib/mariadbclient/MariaDBClient.cc @@ -0,0 +1,371 @@ +#include "MariaDBClient.h" + +#include +#include +#include +#include + +#include + +#define NEXT_IMMEDIATE(newState) do { state = newState; goto again; } while (0) + +using namespace mariadbclient; +using namespace muduo; +using namespace muduo::net; + +static void dummy(const std::shared_ptr&) +{ +} + +MariaDBClient::MariaDBClient(EventLoop* loop, + const InetAddress& serverAddr, + const string& user, + const string& password, + const string& db) + : loop_(loop), + serverAddr_(serverAddr), + user_(user), + password_(password), + db_(db), + isConnected_(false) +{ +} + +MariaDBClient::~MariaDBClient() +{ + if (isConnected_) + { + disconnect(); + } +} + +void MariaDBClient::connect() +{ + assert(!isConnected_); + + ::mysql_init(&mysql_); + ::mysql_options(&mysql_, MYSQL_OPT_NONBLOCK, 0); + + stateMachineHandler(kRealConnectStart); +} + +void MariaDBClient::disconnect() +{ + assert(isConnected_); + + stateMachineHandler(kCloseStart); +} + +void MariaDBClient::executeUpdate(StringArg sql, const UpdateCallback& cb) +{ + assert(isConnected_); + + sqlQueue_.emplace_back(new SQLData(SQLData::kUpdate, sql.c_str(), cb)); + + if (sqlQueue_.size() == 1) + { + stateMachineHandler(kRealQueryStart); + } +} + +void MariaDBClient::executeQuery(StringArg sql, const QueryCallback& cb) +{ + assert(isConnected_); + + sqlQueue_.emplace_back(new SQLData(SQLData::kQuery, sql.c_str(), cb)); + + if (sqlQueue_.size() == 1) + { + stateMachineHandler(kRealQueryStart); + } +} + +void MariaDBClient::stateMachineHandler(int state, int revents, Timestamp receiveTime) +{ + loop_->assertInLoopThread(); + + int mysqlRevents = toMySQLEvents(revents); + + static MYSQL* ret = NULL; + static int err = 0; + static MYSQL_RES* res = NULL; + + again: + switch (state) + { + case kRealConnectStart: + { + int mysqlEvents = ::mysql_real_connect_start(&ret, + &mysql_, + serverAddr_.toIp().c_str(), + user_.c_str(), + password_.c_str(), + db_.c_str(), + implicit_cast(serverAddr_.toPort()), + NULL, + 0); + if (mysqlEvents != 0) + { + channel_.reset(new Channel(loop_, fd(), false)); + + channel_->setEventsCallback(std::bind(&MariaDBClient::stateMachineHandler, this, kRealConnectCont, _1, _2)); + int events = toEvents(mysqlEvents); + channel_->enableEvents(events); + } + else + { + NEXT_IMMEDIATE(kRealConnectEnd); + } + } + break; + + case kRealConnectCont: + { + int mysqlEvents = ::mysql_real_connect_cont(&ret, &mysql_, mysqlRevents); + if (mysqlEvents != 0) + { + int events = toEvents(mysqlEvents); + channel_->enableEvents(events); + } + else + { + NEXT_IMMEDIATE(kRealConnectEnd); + } + } + break; + + case kRealConnectEnd: + { + if (ret == NULL) + { + LOG_ERROR << "Failed to mysql_real_connect(): " << errorStr(); + } + else + { + logConnection(true); + isConnected_ = true; + channel_->setEventsCallback(Channel::EventsCallback()); + channel_->disableAll(); + } + + if (connectCb_) + { + connectCb_(this); + } + } + break; + + case kRealQueryStart: + { + int mysqlEvents = ::mysql_real_query_start(&err, + &mysql_, + sqlQueue_.front()->sql_.c_str(), + sqlQueue_.front()->sql_.size()); + if (mysqlEvents != 0) + { + channel_->setEventsCallback(std::bind(&MariaDBClient::stateMachineHandler, this, kRealQueryCont, _1, _2)); + int events = toEvents(mysqlEvents); + channel_->enableEvents(events); + } + else + { + NEXT_IMMEDIATE(kRealQueryEnd); + } + } + break; + + case kRealQueryCont: + { + int mysqlEvents = ::mysql_real_query_cont(&err, &mysql_, mysqlRevents); + if (mysqlEvents != 0) + { + int events = toEvents(mysqlEvents); + channel_->enableEvents(events); + } + else + { + NEXT_IMMEDIATE(kRealQueryEnd); + } + } + break; + + case kRealQueryEnd: + { + if (sqlQueue_.front()->type_ == SQLData::kUpdate) + { + if (sqlQueue_.front()->updateCb_) + { + sqlQueue_.front()->updateCb_(this); + } + sqlQueue_.pop_front(); + + if (!sqlQueue_.empty()) + { + NEXT_IMMEDIATE(kRealQueryStart); + } + } + else + { + assert(sqlQueue_.front()->type_ == SQLData::kQuery); + if (err != 0) + { + LOG_ERROR << "mysql_real_query() returns error: " << errorStr(); + if (sqlQueue_.front()->queryCb_) + { + sqlQueue_.front()->queryCb_(this, NULL); + } + sqlQueue_.pop_front(); + + if (!sqlQueue_.empty()) + { + NEXT_IMMEDIATE(kRealQueryStart); + } + } + else + { + NEXT_IMMEDIATE(kStoreResultStart); + } + } + } + break; + + case kStoreResultStart: + { + int mysqlEvents = ::mysql_store_result_start(&res, &mysql_); + if (mysqlEvents != 0) + { + channel_->setEventsCallback(std::bind(&MariaDBClient::stateMachineHandler, this, kStoreResultCont, _1, _2)); + + int events = toEvents(mysqlEvents); + channel_->enableEvents(events); + } + else + { + NEXT_IMMEDIATE(kStoreResultEnd); + } + } + break; + + case kStoreResultCont: + { + int mysqlEvents = ::mysql_store_result_cont(&res, &mysql_, mysqlRevents); + if (mysqlEvents != 0) + { + int events = toEvents(mysqlEvents); + channel_->enableEvents(events); + } + else + { + NEXT_IMMEDIATE(kStoreResultEnd); + } + } + break; + + case kStoreResultEnd: + { + if (sqlQueue_.front()->queryCb_) + { + sqlQueue_.front()->queryCb_(this, res); + } + else + { + if (res != NULL) + { + ::mysql_free_result(res); + } + else + { + assert(::mysql_field_count(&mysql_) != 0); + assert(::mysql_errno(&mysql_) != 0); + LOG_ERROR << "Got error while storing result: " << errorStr(); + } + } + sqlQueue_.pop_front(); + + if (!sqlQueue_.empty()) + { + NEXT_IMMEDIATE(kRealQueryStart); + } + } + break; + + case kCloseStart: + { + logConnection(false); + int mysqlEvents = ::mysql_close_start(&mysql_); + if (mysqlEvents != 0) + { + channel_->setEventsCallback(std::bind(&MariaDBClient::stateMachineHandler, this, kCloseCont, _1, _2)); + int events = toEvents(mysqlEvents); + channel_->enableEvents(events); + } + else + { + NEXT_IMMEDIATE(kCloseEnd); + } + } + break; + + case kCloseCont: + { + int mysqlEvents = ::mysql_close_cont(&mysql_, mysqlRevents); + if (mysqlEvents != 0) + { + int events = toEvents(mysqlEvents); + channel_->enableEvents(events); + } + else + { + NEXT_IMMEDIATE(kCloseEnd); + } + } + break; + + case kCloseEnd: + { + isConnected_ = false; + channel_->disableAll(); + channel_->remove(); + loop_->queueInLoop(std::bind(dummy, channel_)); + channel_.reset(); + + if (disconnectCb_) + { + disconnectCb_(this); + } + } + break; + + default: + { + abort(); + } + } +} + +void MariaDBClient::logConnection(bool up) const +{ + InetAddress localAddr(sockets::getLocalAddr(fd())); + InetAddress peerAddr(sockets::getPeerAddr(fd())); + + LOG_INFO << localAddr.toIpPort() << " -> " + << peerAddr.toIpPort() << " is " + << (up ? "UP" : "DOWN"); +} + +int MariaDBClient::toEvents(int mysqlEvents) +{ + int events = (mysqlEvents & MYSQL_WAIT_READ ? POLLIN : 0) + | (mysqlEvents & MYSQL_WAIT_WRITE ? POLLOUT : 0) + | (mysqlEvents & MYSQL_WAIT_EXCEPT ? POLLPRI : 0); + + return events; +} + +int MariaDBClient::toMySQLEvents(int events) +{ + int mysqlEvents = (events & POLLIN ? MYSQL_WAIT_READ : 0) + | (events & POLLPRI ? MYSQL_WAIT_EXCEPT : 0) + | (events & POLLOUT ? MYSQL_WAIT_WRITE : 0); + return mysqlEvents; +} diff --git a/contrib/mariadbclient/MariaDBClient.h b/contrib/mariadbclient/MariaDBClient.h new file mode 100644 index 000000000..74dabdbce --- /dev/null +++ b/contrib/mariadbclient/MariaDBClient.h @@ -0,0 +1,125 @@ +#ifndef MUDUO_CONTRIB_MARIADBCLIENT_MARIADBCLIENT_H +#define MUDUO_CONTRIB_MARIADBCLIENT_MARIADBCLIENT_H + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace muduo +{ +namespace net +{ +class Channel; +class EventLoop; +} +} + +namespace mariadbclient +{ +using muduo::string; + +class MariaDBClient : muduo::noncopyable +{ + public: + typedef std::function ConnectCallback; + typedef std::function DisconnectCallback; + + typedef std::function UpdateCallback; // similar to WriteCompleteCallback + typedef std::function QueryCallback; + + enum State + { + kRealConnectStart, + kRealConnectCont, + kRealConnectEnd, + + kRealQueryStart, + kRealQueryCont, + kRealQueryEnd, + + kStoreResultStart, + kStoreResultCont, + kStoreResultEnd, + + kCloseStart, + kCloseCont, + kCloseEnd, + }; + + MariaDBClient(muduo::net::EventLoop* loop, + const muduo::net::InetAddress& serverAddr, + const string& user, + const string& password, + const string& db); + ~MariaDBClient(); + + void setConnectCallback(const ConnectCallback& cb) { connectCb_ = cb; } + void setDisconnectCallback(const DisconnectCallback& cb) { disconnectCb_ = cb; } + + void connect(); + void disconnect(); + + // INSERT, UPDATE, or DELETE + void executeUpdate(muduo::StringArg sql, const UpdateCallback& cb = UpdateCallback()); + // SELECT, SHOW, DESCRIBE, EXPLAIN, CHECK TABLE, and so forth + void executeQuery(muduo::StringArg sql, const QueryCallback& cb = QueryCallback()); + + uint32_t errorNo() { return ::mysql_errno(&mysql_); } + const char* errorStr() { return ::mysql_error(&mysql_); } + + private: + void stateMachineHandler(int state, int revents = -1, muduo::Timestamp receiveTime = muduo::Timestamp::invalid()); + + void logConnection(bool up) const; + int fd() const { return ::mysql_get_socket(&mysql_); } + + static int toEvents(int mysqlEvents); + static int toMySQLEvents(int events); + + muduo::net::EventLoop* loop_; + const muduo::net::InetAddress serverAddr_; + const string user_; + const string password_; + const string db_; + std::shared_ptr channel_; + bool isConnected_; + ConnectCallback connectCb_; + DisconnectCallback disconnectCb_; + + MYSQL mysql_; + + struct SQLData + { + enum Type + { + kUpdate, + kQuery + }; + + SQLData(Type type, const string& sql, const UpdateCallback& cb) + : type_(type), sql_(sql), updateCb_(cb) {} + SQLData(Type type, const string& sql, const QueryCallback& cb) + : type_(type), sql_(sql), queryCb_(cb) {} + ~SQLData() {} + + Type type_; + string sql_; + union + { + UpdateCallback updateCb_; + QueryCallback queryCb_; + }; + }; + + std::deque> sqlQueue_; +}; + +} // namespace mariadbclient + +#endif // MUDUO_CONTRIB_MARIADBCLIENT_MARIADBCLIENT_H diff --git a/contrib/mariadbclient/mmariadbclient.cc b/contrib/mariadbclient/mmariadbclient.cc new file mode 100644 index 000000000..e246775d7 --- /dev/null +++ b/contrib/mariadbclient/mmariadbclient.cc @@ -0,0 +1,116 @@ +#include "MariaDBClient.h" + +#include +#include + +using namespace muduo; +using namespace muduo::net; + +static uint64_t g_seqid = 0; + +void updateCallback(mariadbclient::MariaDBClient* c, uint64_t id) +{ + LOG_INFO << "seq id: " << id << "\terrorNo: " << c->errorNo() << "\terrorStr: " << c->errorStr(); +} + +void queryCallback(mariadbclient::MariaDBClient* c, MYSQL_RES* result, uint64_t id) +{ + LOG_INFO << "seq id: " << id << "\terrorNo: " << c->errorNo() << "\terrorStr: " << c->errorStr(); + + uint32_t numFields = ::mysql_num_fields(result); + MYSQL_ROW row; + while ((row = ::mysql_fetch_row(result)) != NULL) + { + uint64_t* lengths = ::mysql_fetch_lengths(result); + for (uint32_t i = 0; i < numFields; ++i) + { + printf("[%.*s] ", static_cast(lengths[i]), row[i] ? row[i] : "NULL"); + } + printf("\n"); + } + + ::mysql_free_result(result); +} + +void connectCallback(mariadbclient::MariaDBClient* c) +{ + if (c->errorNo() == 0) + { + LOG_INFO << "Connected... " << "\terrorNo: " << c->errorNo() << "\terrorStr: " << c->errorStr(); + + string sql0("DROP TABLE muduo_user"); + string sql1("CREATE TABLE muduo_user (" + "id INT(11) NOT NULL AUTO_INCREMENT," + "nick VARCHAR(64) NOT NULL," + "PRIMARY KEY (id)" + ")"); + string sql2("SELECT id, nick " + "FROM muduo_user"); + string sql3("INSERT INTO muduo_user (id, nick)" + "VALUES (1, 'ChenShuo')"); + string sql4("INSERT INTO muduo_user (nick)" + "VALUES ('Jack')"); + string sql5("INSERT INTO muduo_user (nick)" + "VALUES ('Lucy')"); + string sql6("UPDATE muduo_user " + "SET nick = 'Tom' " + "WHERE id = 2"); + string sql7("SELECT id, nick " + "FROM muduo_user"); + string sql8("DELETE FROM muduo_user " + "WHERE id > 1"); + string sql9("SELECT id, nick " + "FROM muduo_user"); + + c->executeUpdate(sql0, std::bind(&updateCallback, _1, g_seqid++)); + c->executeUpdate(sql1, std::bind(&updateCallback, _1, g_seqid++)); + c->executeQuery(sql2, std::bind(&queryCallback, _1, _2, g_seqid++)); + c->executeUpdate(sql3, std::bind(&updateCallback, _1, g_seqid++)); + c->executeUpdate(sql4); g_seqid++; + c->executeUpdate(sql5); g_seqid++; + c->executeUpdate(sql6, std::bind(&updateCallback, _1, g_seqid++)); + c->executeQuery(sql7, std::bind(&queryCallback, _1, _2, g_seqid++)); + c->executeUpdate(sql8, std::bind(&updateCallback, _1, g_seqid++)); + c->executeQuery(sql9, std::bind(&queryCallback, _1, _2, g_seqid++)); + } + else + { + LOG_ERROR << "connectCallback Error: " << "\terrorNo: " << c->errorNo() << "\terrorStr: " << c->errorStr(); + } +} + +void disconnectCallback(mariadbclient::MariaDBClient* c, EventLoop* loop) +{ + if (c->errorNo() == 0) + { + LOG_INFO << "Disconnected... " << "\terrorNo: " << c->errorNo() << "\terrorStr: " << c->errorStr(); + } + else + { + LOG_ERROR << "disconnectCallback Error: " << "\terrorNo: " << c->errorNo() << "\terrorStr: " << c->errorStr(); + } + + loop->quit(); +} + +int main(int argc, char* argv[]) +{ + int err = mysql_library_init(0, NULL, NULL); + if (err != 0) + { + LOG_FATAL << "mysql_library_init() returns error: " << err; + } + + EventLoop loop; + mariadbclient::MariaDBClient mariadbClient(&loop, InetAddress("127.0.0.1", 3306), "root", "123456", "test"); + mariadbClient.setConnectCallback(std::bind(&connectCallback, _1)); + mariadbClient.setDisconnectCallback(std::bind(&disconnectCallback, _1, &loop)); + mariadbClient.connect(); + loop.runAfter(5, std::bind(&mariadbclient::MariaDBClient::disconnect, &mariadbClient)); + + loop.loop(); + + mysql_library_end(); + + return 0; +} diff --git a/muduo/net/Channel.cc b/muduo/net/Channel.cc index 1e9a40ae7..d559d38be 100644 --- a/muduo/net/Channel.cc +++ b/muduo/net/Channel.cc @@ -21,7 +21,7 @@ const int Channel::kNoneEvent = 0; const int Channel::kReadEvent = POLLIN | POLLPRI; const int Channel::kWriteEvent = POLLOUT; -Channel::Channel(EventLoop* loop, int fd__) +Channel::Channel(EventLoop* loop, int fd__, bool classify) : loop_(loop), fd_(fd__), events_(0), @@ -30,7 +30,8 @@ Channel::Channel(EventLoop* loop, int fd__) logHup_(true), tied_(false), eventHandling_(false), - addedToLoop_(false) + addedToLoop_(false), + classify_(classify) { } @@ -84,31 +85,38 @@ void Channel::handleEventWithGuard(Timestamp receiveTime) { eventHandling_ = true; LOG_TRACE << reventsToString(); - if ((revents_ & POLLHUP) && !(revents_ & POLLIN)) + if (classify_) { - if (logHup_) + if ((revents_ & POLLHUP) && !(revents_ & POLLIN)) { - LOG_WARN << "fd = " << fd_ << " Channel::handle_event() POLLHUP"; + if (logHup_) + { + LOG_WARN << "fd = " << fd_ << " Channel::handle_event() POLLHUP"; + } + if (closeCallback_) closeCallback_(); } - if (closeCallback_) closeCallback_(); - } - if (revents_ & POLLNVAL) - { - LOG_WARN << "fd = " << fd_ << " Channel::handle_event() POLLNVAL"; - } + if (revents_ & POLLNVAL) + { + LOG_WARN << "fd = " << fd_ << " Channel::handle_event() POLLNVAL"; + } - if (revents_ & (POLLERR | POLLNVAL)) - { - if (errorCallback_) errorCallback_(); - } - if (revents_ & (POLLIN | POLLPRI | POLLRDHUP)) - { - if (readCallback_) readCallback_(receiveTime); + if (revents_ & (POLLERR | POLLNVAL)) + { + if (errorCallback_) errorCallback_(); + } + if (revents_ & (POLLIN | POLLPRI | POLLRDHUP)) + { + if (readCallback_) readCallback_(receiveTime); + } + if (revents_ & POLLOUT) + { + if (writeCallback_) writeCallback_(); + } } - if (revents_ & POLLOUT) + else { - if (writeCallback_) writeCallback_(); + if (eventsCallback_) eventsCallback_(revents_, receiveTime); } eventHandling_ = false; } diff --git a/muduo/net/Channel.h b/muduo/net/Channel.h index bafa224fa..91b9f5154 100644 --- a/muduo/net/Channel.h +++ b/muduo/net/Channel.h @@ -36,7 +36,9 @@ class Channel : noncopyable typedef std::function EventCallback; typedef std::function ReadEventCallback; - Channel(EventLoop* loop, int fd); + typedef std::function EventsCallback; + + Channel(EventLoop* loop, int fd, bool classify = true); ~Channel(); void handleEvent(Timestamp receiveTime); @@ -48,6 +50,8 @@ class Channel : noncopyable { closeCallback_ = std::move(cb); } void setErrorCallback(EventCallback cb) { errorCallback_ = std::move(cb); } + void setEventsCallback(EventsCallback cb) + { eventsCallback_ = std::move(cb); } /// Tie this channel to the owner object managed by shared_ptr, /// prevent the owner object being destroyed in handleEvent. @@ -66,6 +70,7 @@ class Channel : noncopyable void disableAll() { events_ = kNoneEvent; update(); } bool isWriting() const { return events_ & kWriteEvent; } bool isReading() const { return events_ & kReadEvent; } + void enableEvents(int events__) { events_ = events__; update(); } // for Poller int index() { return index_; } @@ -105,6 +110,9 @@ class Channel : noncopyable EventCallback writeCallback_; EventCallback closeCallback_; EventCallback errorCallback_; + + EventsCallback eventsCallback_; + bool classify_; }; } // namespace net diff --git a/muduo/net/poller/EPollPoller.cc b/muduo/net/poller/EPollPoller.cc index b2f913a4c..c5c01087a 100644 --- a/muduo/net/poller/EPollPoller.cc +++ b/muduo/net/poller/EPollPoller.cc @@ -182,11 +182,11 @@ void EPollPoller::update(int operation, Channel* channel) { if (operation == EPOLL_CTL_DEL) { - LOG_SYSERR << "epoll_ctl op =" << operationToString(operation) << " fd =" << fd; + LOG_SYSERR << "epoll_ctl op = " << operationToString(operation) << " fd =" << fd; } else { - LOG_SYSFATAL << "epoll_ctl op =" << operationToString(operation) << " fd =" << fd; + LOG_SYSFATAL << "epoll_ctl op = " << operationToString(operation) << " fd =" << fd; } } }