diff --git a/src/lkcp.c b/src/lkcp.c index 844f5de..c167871 100644 --- a/src/lkcp.c +++ b/src/lkcp.c @@ -48,23 +48,34 @@ static int kcp_output_callback(const char *buf, int len, ikcpcb *kcp, void *arg) struct Callback* c = (struct Callback*)arg; lua_State* L = c -> L; uint64_t handle = c -> handle; - - lua_rawgeti(L, LUA_REGISTRYINDEX, handle); + lua_pushstring(L, "kcp-cbs"); + lua_rawget(L, LUA_REGISTRYINDEX); + lua_pushinteger(L, handle); + lua_rawget(L, -2); lua_pushlstring(L, buf, len); - lua_call(L, 1, 0); + int status = lua_pcall(L, 1, 0, 0); + if (status) { + lua_pop(L, 1); + } + lua_pop(L, 1); return 0; } static int kcp_gc(lua_State* L) { - ikcpcb* kcp = check_kcp(L, 1); - if (kcp == NULL) { + ikcpcb* kcp = check_kcp(L, 1); + if (kcp == NULL) { return 0; - } + } if (kcp->user != NULL) { struct Callback* c = (struct Callback*)kcp -> user; uint64_t handle = c -> handle; - luaL_unref(L, LUA_REGISTRYINDEX, handle); + lua_pushstring(L, "kcp-cbs"); + lua_rawget(L, LUA_REGISTRYINDEX); + lua_pushinteger(L, handle); + lua_pushnil(L); + lua_rawset(L, -3); + lua_pop(L, 1); free(c); kcp->user = NULL; } @@ -74,13 +85,23 @@ static int kcp_gc(lua_State* L) { } static int lkcp_create(lua_State* L){ - uint64_t handle = luaL_ref(L, LUA_REGISTRYINDEX); int32_t conv = luaL_checkinteger(L, 1); + uint64_t idx = luaL_checkinteger(L, 2); + lua_pushstring(L, "kcp-cbs"); + lua_rawget(L, LUA_REGISTRYINDEX); + lua_pushinteger(L, idx); + lua_pushvalue(L, 3); + lua_rawset(L, -3); + lua_pop(L, 1); + + lua_rawgeti(L, LUA_REGISTRYINDEX, LUA_RIDX_MAINTHREAD); + lua_State* mL = lua_tothread(L, -1); + lua_pop(L, 1); struct Callback* c = malloc(sizeof(struct Callback)); memset(c, 0, sizeof(struct Callback)); - c -> handle = handle; - c -> L = L; + c -> handle = idx; + c -> L = mL; ikcpcb* kcp = ikcp_create(conv, (void*)c); if (kcp == NULL) { @@ -235,7 +256,9 @@ static const struct luaL_Reg l_methods[] = { int luaopen_lkcp(lua_State* L) { luaL_checkversion(L); - + lua_pushstring(L, "kcp-cbs"); + lua_newtable(L); + lua_rawset(L, LUA_REGISTRYINDEX); luaL_newmetatable(L, "kcp_meta"); lua_newtable(L); diff --git a/src/testkcp.lua b/src/testkcp.lua index f1b3e56..44cb460 100644 --- a/src/testkcp.lua +++ b/src/testkcp.lua @@ -31,7 +31,7 @@ local function test(mode) a = 'aaa', b = false, } - local kcp1 = LKcp.lkcp_create(session, function (buf) + local kcp1 = LKcp.lkcp_create(session, 1, function (buf) udp_output(buf, info) end) local info2 = { @@ -42,7 +42,7 @@ local function test(mode) print 'hahahah!!!' end, } - local kcp2 = LKcp.lkcp_create(session, function (buf) + local kcp2 = LKcp.lkcp_create(session, 2, function (buf) udp_output(buf, info2) end) @@ -108,69 +108,68 @@ local function test(mode) slap = slap + 20 index = index + 1 end - - --处理虚拟网络:检测是否有udp包从p1->p2 - while 1 do - hrlen, hr = lsm:recv(1) - if hrlen < 0 then - break - end - --如果 p2收到udp,则作为下层协议输入到kcp2 - kcp2:lkcp_input(hr) - end - - --处理虚拟网络:检测是否有udp包从p2->p1 - while 1 do - hrlen, hr = lsm:recv(0) - if hrlen < 0 then - break - end - --如果 p1收到udp,则作为下层协议输入到kcp1 - kcp1:lkcp_input(hr) - end - - --kcp2接收到任何包都返回回去 - while 1 do - hrlen, hr = kcp2:lkcp_recv() - if hrlen <= 0 then - break + --处理虚拟网络:检测是否有udp包从p1->p2 + while 1 do + hrlen, hr = lsm:recv(1) + if hrlen < 0 then + break + end + --如果 p2收到udp,则作为下层协议输入到kcp2 + kcp2:lkcp_input(hr) + end + + --处理虚拟网络:检测是否有udp包从p2->p1 + while 1 do + hrlen, hr = lsm:recv(0) + if hrlen < 0 then + break + end + --如果 p1收到udp,则作为下层协议输入到kcp1 + kcp1:lkcp_input(hr) end - kcp2:lkcp_send(hr) - --kcp2:lkcp_flush() - end - --kcp1收到kcp2的回射数据 - while 1 do - hrlen, hr = kcp1:lkcp_recv() - --没有收到包就退出 - if hrlen <= 0 then - break + --kcp2接收到任何包都返回回去 + while 1 do + hrlen, hr = kcp2:lkcp_recv() + if hrlen <= 0 then + break + end + kcp2:lkcp_send(hr) + --kcp2:lkcp_flush() end - local hr1 = string.sub(hr, 1, 4) - local hr2 = string.sub(hr, 5, 8) - local sn = LUtil.netbytes2uint32(hr1) - local ts = LUtil.netbytes2uint32(hr2) - local rtt = current - ts + --kcp1收到kcp2的回射数据 + while 1 do + hrlen, hr = kcp1:lkcp_recv() + --没有收到包就退出 + if hrlen <= 0 then + break + end + + local hr1 = string.sub(hr, 1, 4) + local hr2 = string.sub(hr, 5, 8) + local sn = LUtil.netbytes2uint32(hr1) + local ts = LUtil.netbytes2uint32(hr2) + local rtt = current - ts - if sn ~= inext then - --如果收到的包不连续 - print(string.format("ERROR sn %d<->%d\n", count, inext)) - return - end - - inext = inext + 1 - sumrtt = sumrtt + rtt - count = count + 1 - if rtt > maxrtt then - maxrtt = rtt - end - - print(string.format("[RECV] mode=%d sn=%d rtt=%d\n", mode, sn, rtt)) - end - if inext > 10 then - break - end + if sn ~= inext then + --如果收到的包不连续 + print(string.format("ERROR sn %d<->%d\n", count, inext)) + return + end + + inext = inext + 1 + sumrtt = sumrtt + rtt + count = count + 1 + if rtt > maxrtt then + maxrtt = rtt + end + + print(string.format("[RECV] mode=%d sn=%d rtt=%d\n", mode, sn, rtt)) + end + if inext > 10 then + break + end end ts1 = getms() - ts1