diff --git a/src/http/http_client.cc b/src/http/http_client.cc index 89f4ac1a..5aadd480 100644 --- a/src/http/http_client.cc +++ b/src/http/http_client.cc @@ -105,7 +105,10 @@ void HttpClient::callback(WFHttpTask *task) // http_get_header_module_data(task->get_resp(), *resp_data); for (RPCModule *module : module_list) - module->client_task_end(task, *resp_data); + { + if (!module->client_task_end(task, *resp_data)) + break; + } } if (client_task->user_callback_) diff --git a/src/http/http_task.cc b/src/http/http_task.cc index 72839b2f..2c432571 100644 --- a/src/http/http_task.cc +++ b/src/http/http_task.cc @@ -62,11 +62,8 @@ std::string HttpClientTask::get_uri_scheme() const return ""; } -CommMessageOut *HttpClientTask::message_out() +bool HttpClientTask::check_request() { - HttpRequest *req = this->get_req(); - struct HttpMessageHeader header; - bool is_alive; void *series_data = series_of(this)->get_specific(SRPC_MODULE_DATA); RPCModuleData *data = (RPCModuleData *)series_data; @@ -75,11 +72,25 @@ CommMessageOut *HttpClientTask::message_out() data = this->mutable_module_data(); for (auto *module : modules_) - module->client_task_begin(this, *data); + { + if (!module->client_task_begin(this, *data)) + { + this->state = WFT_STATE_TASK_ERROR; + this->error = RPCStatusModuleFilterFailed; + return false; + } + } - http_set_header_module_data(*data, req); + http_set_header_module_data(*data, this->get_req()); + return true; +} + +CommMessageOut *HttpClientTask::message_out() +{ + HttpRequest *req = this->get_req(); + struct HttpMessageHeader header; + bool is_alive; - // from ComplexHttpTask::message_out() if (!req->is_chunked() && !req->has_content_length_header()) { size_t body_size = req->get_output_body_size(); @@ -387,9 +398,6 @@ void HttpServerTask::handle(int state, int error) } } - this->state = WFT_STATE_TOREPLY; - this->target = this->get_target(); - // fill module data from request to series ModuleSeries *series = new ModuleSeries(this); @@ -397,10 +405,20 @@ void HttpServerTask::handle(int state, int error) for (auto *module : this->modules_) { if (module) - module->server_task_begin(this, this->module_data_); + { + if (!module->server_task_begin(this, this->module_data_)) + { + delete this; + return; + } + } } series->set_module_data(this->mutable_module_data()); + + this->state = WFT_STATE_TOREPLY; + this->target = this->get_target(); + series->start(); } else if (this->state == WFT_STATE_TOREPLY) @@ -412,9 +430,17 @@ void HttpServerTask::handle(int state, int error) // prepare module_data from series to response for (auto *module : modules_) - module->server_task_end(this, this->module_data_); + { + if (!module->server_task_end(this, this->module_data_)) + { + this->noreply(); + this->error = RPCStatusModuleFilterFailed; + break; + } + } - http_set_header_module_data(this->module_data_, this->get_resp()); + if (this->error != RPCStatusModuleFilterFailed) + http_set_header_module_data(this->module_data_, this->get_resp()); this->subtask_done(); } diff --git a/src/http/http_task.h b/src/http/http_task.h index f9daa6da..8e551f45 100644 --- a/src/http/http_task.h +++ b/src/http/http_task.h @@ -60,6 +60,7 @@ class HttpClientTask : public WFComplexClientTaskclient_begin(task, data); + if (!this->client_begin(task, data)) + return false; for (RPCFilter *filter : this->filters) - ret = ret && filter->client_begin(task, data); + { + if (!filter->client_begin(task, data)) + return false; + } - return ret; + return true; } bool RPCModule::server_task_begin(SubTask *task, RPCModuleData& data) { - bool ret = this->server_begin(task, data); + if (!this->server_begin(task, data)) + return false; for (RPCFilter *filter : this->filters) - ret = ret && filter->server_begin(task, data); + { + if (!filter->server_begin(task, data)) + return false; + } - return ret; + return true; } bool RPCModule::client_task_end(SubTask *task, RPCModuleData& data) { SubTask *filter_task; - bool ret = this->client_end(task, data); + if (!this->client_end(task, data)) + return false; for (RPCFilter *filter : this->filters) { - if (filter->client_end(task, data)) - { - filter_task = filter->create_filter_task(data); - series_of(task)->push_front(filter_task); - } - else - ret = false; + if (!filter->client_end(task, data)) + return false; + + filter_task = filter->create_filter_task(data); + series_of(task)->push_front(filter_task); } - return ret; + return true; } bool RPCModule::server_task_end(SubTask *task, RPCModuleData& data) { SubTask *filter_task; - bool ret = this->server_end(task, data); + if (!this->server_end(task, data)) + return false; for (RPCFilter *filter : this->filters) { - if (filter->server_end(task, data)) - { - filter_task = filter->create_filter_task(data); - series_of(task)->push_front(filter_task); - } - else - ret = false; + if (!filter->server_end(task, data)) + return false; + + filter_task = filter->create_filter_task(data); + series_of(task)->push_front(filter_task); } - return ret; + return true; } SnowFlake::SnowFlake(int timestamp_bits, int group_bits, int machine_bits) diff --git a/src/rpc_basic.h b/src/rpc_basic.h index 3c9f730c..58ab3986 100644 --- a/src/rpc_basic.h +++ b/src/rpc_basic.h @@ -118,6 +118,7 @@ enum RPCStatusCode RPCStatusRespDeserializeError = 20, RPCStatusIDLSerializeNotSupported = 21, RPCStatusIDLDeserializeNotSupported = 22, + RPCStatusModuleFilterFailed = 23, RPCStatusURIInvalid = 30, RPCStatusUpstreamFailed = 31, diff --git a/src/rpc_server.h b/src/rpc_server.h index 6d5142d3..715f3d9e 100644 --- a/src/rpc_server.h +++ b/src/rpc_server.h @@ -242,15 +242,25 @@ void RPCServer::server_process(NETWORKTASK *task) const status_code = RPCStatusMethodNotFound; else { - for (auto *module : this->modules) - { - if (module) - module->server_task_begin(server_task, *task_data); - } - status_code = req->decompress(); + if (status_code == RPCStatusOK) - status_code = (*rpc)(server_task->worker); + { + for (auto *module : this->modules) + { + if (module) + { + if (!module->server_task_begin(server_task, *task_data)) + { + status_code = RPCStatusModuleFilterFailed; + break; + } + } + } + + if (status_code == RPCStatusOK) + status_code = (*rpc)(server_task->worker); + } } } diff --git a/src/rpc_task.inl b/src/rpc_task.inl index 5082b9f5..c4812550 100644 --- a/src/rpc_task.inl +++ b/src/rpc_task.inl @@ -305,12 +305,19 @@ CommMessageOut *RPCServerTask::message_out() RPCModuleData *data = this->mutable_module_data(); for (auto *module : modules_) - module->server_task_end(this, *data); - - this->resp.set_meta_module_data(*data); + { + if (!module->server_task_end(this, *data)) + { + status_code = RPCStatusModuleFilterFailed; + break; + } + } if (status_code == RPCStatusOK) + { + this->resp.set_meta_module_data(*data); return this->WFServerTask::message_out(); + } errno = EBADMSG; return NULL; @@ -438,15 +445,9 @@ template bool RPCClientTask::check_request() { int status_code = this->resp.get_status_code(); - return status_code == RPCStatusOK || status_code == RPCStatusUndefined; -} - -template -CommMessageOut *RPCClientTask::message_out() -{ - this->req.set_seqid(this->get_task_seq()); - int status_code = this->req.compress(); + if (status_code != RPCStatusOK && status_code != RPCStatusUndefined) + return false; void *series_data = series_of(this)->get_specific(SRPC_MODULE_DATA); RPCModuleData *data = (RPCModuleData *)series_data; @@ -456,9 +457,24 @@ CommMessageOut *RPCClientTask::message_out() data = this->mutable_module_data(); for (auto *module : modules_) - module->client_task_begin(this, *data); + { + if (!module->client_task_begin(this, *data)) + { + this->resp.set_status_code(RPCStatusModuleFilterFailed); + return false; + } + } this->req.set_meta_module_data(*data); + return true; +} + +template +CommMessageOut *RPCClientTask::message_out() +{ + this->req.set_seqid(this->get_task_seq()); + + int status_code = this->req.compress(); if (status_code == RPCStatusOK) { @@ -575,7 +591,11 @@ void RPCClientTask::rpc_callback(WFNetworkTask // this->resp.get_meta_module_data(resp_data); for (auto *module : modules_) - module->client_task_end(this, *resp_data); + { + // do not affect status_code, which is important for user_done_ + if (!module->client_task_end(this, *resp_data)) + break; + } } if (status_code != RPCStatusOK) diff --git a/tutorial/CMakeLists.txt b/tutorial/CMakeLists.txt index e5915b6c..f2f8eff7 100644 --- a/tutorial/CMakeLists.txt +++ b/tutorial/CMakeLists.txt @@ -119,6 +119,7 @@ set(TUTORIAL_PB_LIST tutorial-10-server_async tutorial-15-srpc_pb_proxy tutorial-16-server_with_metrics + tutorial-19-custom_filter ) if (APPLE) diff --git a/tutorial/tutorial-19-custom_filter.cc b/tutorial/tutorial-19-custom_filter.cc new file mode 100644 index 00000000..ccd2c9c0 --- /dev/null +++ b/tutorial/tutorial-19-custom_filter.cc @@ -0,0 +1,131 @@ +/* + Copyright (c) 2024 sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#include +#include "echo_pb.srpc.h" +#include "workflow/WFFacilities.h" +#include "srpc/rpc_types.h" + +#include "srpc/rpc_module.h" +#include "srpc/rpc_basic.h" + +using namespace srpc; + +static WFFacilities::WaitGroup wait_group(1); + +// Filter is available for both server and client. +// Please choose the function to implement the logic at the corresponding time. +class MyFilter : public RPCFilter +{ +public: + MyFilter() : RPCFilter(RPCModuleTypeCustom) + { + } + + bool server_begin(SubTask *task, RPCModuleData& data) override + { + auto iter = data.find("my_auth_key"); + if (iter != data.end() && iter->second.compare("my_auth_value") == 0) + { + fprintf(stderr, "[FILTER] auth success : %s\n", iter->second.c_str()); + return true; + } + + fprintf(stderr, "[FILTER] auth failed : %s\n", + iter == data.end() ? "No auth" : iter->second.c_str()); + return false; + } +}; + +class ExampleServiceImpl : public Example::Service +{ +public: + void Echo(EchoRequest *req, EchoResponse *resp, RPCContext *ctx) override + { + resp->set_message("Hi back"); + fprintf(stderr, "[SERVER] Echo() get req: %s\n", req->message().c_str()); + } +}; + +static void sig_handler(int signo) +{ + wait_group.done(); +} + +void send_client_task() +{ + Example::SRPCClient client("127.0.0.1", 1412); + + auto callback = [](EchoResponse *resp, RPCContext *ctx) + { + if (ctx->success()) + fprintf(stderr, "[CLIENT] callback success\n"); + else + fprintf(stderr, "[CLIENT] callback status[%d] error[%d] errmsg : %s\n", + ctx->get_status_code(), ctx->get_error(), ctx->get_errmsg()); + }; + + EchoRequest req; + req.set_name("Tutorial 19"); + + // send one task to test the success case + auto *task = client.create_Echo_task(callback); + req.set_message("For success case"); + task->serialize_input(&req); + task->add_baggage("my_auth_key", "my_auth_value"); + task->start(); + + // send another task to test the failure case + task = client.create_Echo_task(callback); + req.set_message("For failure case"); + task->serialize_input(&req); + task->add_baggage("my_auth_key", "randomxxx"); + task->start(); +} + +int main() +{ + GOOGLE_PROTOBUF_VERIFY_VERSION; + signal(SIGINT, sig_handler); + signal(SIGTERM, sig_handler); + + // 1. prepare server + SRPCServer server; + ExampleServiceImpl impl; + server.add_service(&impl); + + // 2. add filter into server + MyFilter my; + server.add_filter(&my); + + // 3. run server + if (server.start(1412) == 0) + { + fprintf(stderr, "[SERVER] Server with filter is running on 1412\n"); + + // 4. send client task to test + send_client_task(); + + wait_group.wait(); + server.stop(); + } + else + perror("[SERVER] server start"); + + google::protobuf::ShutdownProtobufLibrary(); + return 0; +} +