Skip to content

Commit

Permalink
Make the return values of module and filter affect.
Browse files Browse the repository at this point in the history
  • Loading branch information
holmes1412 committed Jul 4, 2024
1 parent 0386144 commit 9284ae6
Show file tree
Hide file tree
Showing 12 changed files with 269 additions and 60 deletions.
5 changes: 4 additions & 1 deletion src/http/http_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ void HttpClient::callback(WFHttpTask *task)
// http_get_header_module_data(task->get_resp(), *resp_data);

for (RPCModule *module : module_list)
module->client_task_end(task, *resp_data);
{
if (!module->client_task_end(task, *resp_data))
break;
}
}

if (client_task->user_callback_)
Expand Down
52 changes: 39 additions & 13 deletions src/http/http_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,8 @@ std::string HttpClientTask::get_uri_scheme() const
return "";
}

CommMessageOut *HttpClientTask::message_out()
bool HttpClientTask::check_request()
{
HttpRequest *req = this->get_req();
struct HttpMessageHeader header;
bool is_alive;
void *series_data = series_of(this)->get_specific(SRPC_MODULE_DATA);
RPCModuleData *data = (RPCModuleData *)series_data;

Expand All @@ -75,11 +72,25 @@ CommMessageOut *HttpClientTask::message_out()
data = this->mutable_module_data();

for (auto *module : modules_)
module->client_task_begin(this, *data);
{
if (!module->client_task_begin(this, *data))
{
this->state = WFT_STATE_TASK_ERROR;
this->error = RPCStatusModuleFilterFailed;
return false;
}
}

http_set_header_module_data(*data, req);
http_set_header_module_data(*data, this->get_req());
return true;
}

CommMessageOut *HttpClientTask::message_out()
{
HttpRequest *req = this->get_req();
struct HttpMessageHeader header;
bool is_alive;

// from ComplexHttpTask::message_out()
if (!req->is_chunked() && !req->has_content_length_header())
{
size_t body_size = req->get_output_body_size();
Expand Down Expand Up @@ -387,20 +398,27 @@ void HttpServerTask::handle(int state, int error)
}
}

this->state = WFT_STATE_TOREPLY;
this->target = this->get_target();

// fill module data from request to series
ModuleSeries *series = new ModuleSeries(this);

http_get_header_module_data(req, this->module_data_);
for (auto *module : this->modules_)
{
if (module)
module->server_task_begin(this, this->module_data_);
{
if (!module->server_task_begin(this, this->module_data_))
{
delete this;
return;
}
}
}

series->set_module_data(this->mutable_module_data());

this->state = WFT_STATE_TOREPLY;
this->target = this->get_target();

series->start();
}
else if (this->state == WFT_STATE_TOREPLY)
Expand All @@ -412,9 +430,17 @@ void HttpServerTask::handle(int state, int error)

// prepare module_data from series to response
for (auto *module : modules_)
module->server_task_end(this, this->module_data_);
{
if (!module->server_task_end(this, this->module_data_))
{
this->noreply();
this->error = RPCStatusModuleFilterFailed;
break;
}
}

http_set_header_module_data(this->module_data_, this->get_resp());
if (this->error != RPCStatusModuleFilterFailed)
http_set_header_module_data(this->module_data_, this->get_resp());

this->subtask_done();
}
Expand Down
1 change: 1 addition & 0 deletions src/http/http_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class HttpClientTask : public WFComplexClientTask<protocol::HttpRequest,
*/

protected:
virtual bool check_request();
virtual CommMessageOut *message_out();
virtual CommMessageIn *message_in();
virtual int keep_alive_timeout();
Expand Down
2 changes: 2 additions & 0 deletions src/message/rpc_message_brpc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,8 @@ const char *BRPCResponse::get_errmsg() const
return "IDL Serialize Not Supported";
case RPCStatusIDLDeserializeNotSupported:
return "IDL Deserialize Not Supported";
case RPCStatusModuleFilterFailed:
return "Module or filter check failed";
case RPCStatusURIInvalid:
return "URI Invalid";
case RPCStatusUpstreamFailed:
Expand Down
2 changes: 2 additions & 0 deletions src/message/rpc_message_srpc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,8 @@ const char *SRPCResponse::get_errmsg() const
return "IDL Serialize Not Supported";
case RPCStatusIDLDeserializeNotSupported:
return "IDL Deserialize Not Supported";
case RPCStatusModuleFilterFailed:
return "Module or filter check failed";
case RPCStatusURIInvalid:
return "URI Invalid";
case RPCStatusUpstreamFailed:
Expand Down
10 changes: 8 additions & 2 deletions src/module/rpc_filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,14 @@ class RPCFilter
}

private:
virtual SubTask *create(RPCModuleData& data) = 0;
virtual bool filter(RPCModuleData& data) = 0;
virtual SubTask *create(RPCModuleData& data)
{
return WFTaskFactory::create_empty_task();
}
virtual bool filter(RPCModuleData& data)
{
return false;
}

public:
RPCFilter(enum RPCModuleType module_type)
Expand Down
54 changes: 30 additions & 24 deletions src/module/rpc_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,60 +27,66 @@ namespace srpc

bool RPCModule::client_task_begin(SubTask *task, RPCModuleData& data)
{
bool ret = this->client_begin(task, data);
if (!this->client_begin(task, data))
return false;

for (RPCFilter *filter : this->filters)
ret = ret && filter->client_begin(task, data);
{
if (!filter->client_begin(task, data))
return false;
}

return ret;
return true;
}

bool RPCModule::server_task_begin(SubTask *task, RPCModuleData& data)
{
bool ret = this->server_begin(task, data);
if (!this->server_begin(task, data))
return false;

for (RPCFilter *filter : this->filters)
ret = ret && filter->server_begin(task, data);
{
if (!filter->server_begin(task, data))
return false;
}

return ret;
return true;
}

bool RPCModule::client_task_end(SubTask *task, RPCModuleData& data)
{
SubTask *filter_task;
bool ret = this->client_end(task, data);
if (!this->client_end(task, data))
return false;

for (RPCFilter *filter : this->filters)
{
if (filter->client_end(task, data))
{
filter_task = filter->create_filter_task(data);
series_of(task)->push_front(filter_task);
}
else
ret = false;
if (!filter->client_end(task, data))
return false;

filter_task = filter->create_filter_task(data);
series_of(task)->push_front(filter_task);
}

return ret;
return true;
}

bool RPCModule::server_task_end(SubTask *task, RPCModuleData& data)
{
SubTask *filter_task;
bool ret = this->server_end(task, data);
if (!this->server_end(task, data))
return false;

for (RPCFilter *filter : this->filters)
{
if (filter->server_end(task, data))
{
filter_task = filter->create_filter_task(data);
series_of(task)->push_front(filter_task);
}
else
ret = false;
if (!filter->server_end(task, data))
return false;

filter_task = filter->create_filter_task(data);
series_of(task)->push_front(filter_task);
}

return ret;
return true;
}

SnowFlake::SnowFlake(int timestamp_bits, int group_bits, int machine_bits)
Expand Down
1 change: 1 addition & 0 deletions src/rpc_basic.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ enum RPCStatusCode
RPCStatusRespDeserializeError = 20,
RPCStatusIDLSerializeNotSupported = 21,
RPCStatusIDLDeserializeNotSupported = 22,
RPCStatusModuleFilterFailed = 23,

RPCStatusURIInvalid = 30,
RPCStatusUpstreamFailed = 31,
Expand Down
24 changes: 17 additions & 7 deletions src/rpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,15 +242,25 @@ void RPCServer<RPCTYPE>::server_process(NETWORKTASK *task) const
status_code = RPCStatusMethodNotFound;
else
{
for (auto *module : this->modules)
{
if (module)
module->server_task_begin(server_task, *task_data);
}

status_code = req->decompress();

if (status_code == RPCStatusOK)
status_code = (*rpc)(server_task->worker);
{
for (auto *module : this->modules)
{
if (module)
{
if (!module->server_task_begin(server_task, *task_data))
{
status_code = RPCStatusModuleFilterFailed;
break;
}
}
}

if (status_code == RPCStatusOK)
status_code = (*rpc)(server_task->worker);
}
}
}

Expand Down
46 changes: 33 additions & 13 deletions src/rpc_task.inl
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,19 @@ CommMessageOut *RPCServerTask<RPCREQ, RPCRESP>::message_out()
RPCModuleData *data = this->mutable_module_data();

for (auto *module : modules_)
module->server_task_end(this, *data);

this->resp.set_meta_module_data(*data);
{
if (!module->server_task_end(this, *data))
{
status_code = RPCStatusModuleFilterFailed;
break;
}
}

if (status_code == RPCStatusOK)
{
this->resp.set_meta_module_data(*data);
return this->WFServerTask<RPCREQ, RPCRESP>::message_out();
}

errno = EBADMSG;
return NULL;
Expand Down Expand Up @@ -438,15 +445,9 @@ template<class RPCREQ, class RPCRESP>
bool RPCClientTask<RPCREQ, RPCRESP>::check_request()
{
int status_code = this->resp.get_status_code();
return status_code == RPCStatusOK || status_code == RPCStatusUndefined;
}

template<class RPCREQ, class RPCRESP>
CommMessageOut *RPCClientTask<RPCREQ, RPCRESP>::message_out()
{
this->req.set_seqid(this->get_task_seq());

int status_code = this->req.compress();
if (status_code != RPCStatusOK && status_code != RPCStatusUndefined)
return false;

void *series_data = series_of(this)->get_specific(SRPC_MODULE_DATA);
RPCModuleData *data = (RPCModuleData *)series_data;
Expand All @@ -456,9 +457,24 @@ CommMessageOut *RPCClientTask<RPCREQ, RPCRESP>::message_out()
data = this->mutable_module_data();

for (auto *module : modules_)
module->client_task_begin(this, *data);
{
if (!module->client_task_begin(this, *data))
{
this->resp.set_status_code(RPCStatusModuleFilterFailed);
return false;
}
}

this->req.set_meta_module_data(*data);
return true;
}

template<class RPCREQ, class RPCRESP>
CommMessageOut *RPCClientTask<RPCREQ, RPCRESP>::message_out()
{
this->req.set_seqid(this->get_task_seq());

int status_code = this->req.compress();

if (status_code == RPCStatusOK)
{
Expand Down Expand Up @@ -575,7 +591,11 @@ void RPCClientTask<RPCREQ, RPCRESP>::rpc_callback(WFNetworkTask<RPCREQ, RPCRESP>
// this->resp.get_meta_module_data(resp_data);

for (auto *module : modules_)
module->client_task_end(this, *resp_data);
{
// do not affect status_code, which is important for user_done_
if (!module->client_task_end(this, *resp_data))
break;
}
}

if (status_code != RPCStatusOK)
Expand Down
1 change: 1 addition & 0 deletions tutorial/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ set(TUTORIAL_PB_LIST
tutorial-10-server_async
tutorial-15-srpc_pb_proxy
tutorial-16-server_with_metrics
tutorial-19-custom_filter
)

if (APPLE)
Expand Down
Loading

0 comments on commit 9284ae6

Please sign in to comment.