diff --git a/servers/mqtt_pit.c b/servers/mqtt_pit.c index 13bff74..e5d55a0 100644 --- a/servers/mqtt_pit.c +++ b/servers/mqtt_pit.c @@ -622,10 +622,12 @@ bool sendPubrel(struct mqttClient* client, uint16_t packetId) { ssize_t w = write(client->fd, arr, size); if (w == -1) { fprintf(stderr, "sendPubrel: write failed"); + free(arr); return false; } // syslog(LOG_INFO, "Sent PUBREL to client fd=%d", client->fd); + free(arr); return true; } @@ -869,6 +871,11 @@ int main(int argc, char* argv[]) { // } } else { struct mqttClient* client = lookupClient(currentFd); + if (client == NULL) { + epoll_ctl(epollfd, EPOLL_CTL_DEL, currentFd, NULL); + close(currentFd); + continue; + } ssize_t bytesRead = read(currentFd, client->buffer + client->bytesWrittenToBuffer, // Avoid overwriting existing data sizeof(client->buffer) - client->bytesWrittenToBuffer); @@ -898,11 +905,12 @@ int main(int argc, char* argv[]) { packetLengths, packetStarts, &packetCount); uint32_t processedPackets = 0; + bool clientDisconnected = false; for (uint32_t i = 0; i < packetCount; i++) { uint32_t packetLength = packetLengths[i]; uint32_t packetStart = packetStarts[i]; uint32_t packetEnd = packetStart + packetLength; - + if (packetLength == 0 || processedPackets + packetLength > client->bytesWrittenToBuffer) { // syslog(LOG_INFO, "Incomplete packet"); break; // Incomplete packet @@ -924,12 +932,14 @@ int main(int argc, char* argv[]) { if(!ackSuccess) { fprintf(stderr, "Disconnecting client due to CONNACK failure"); disconnectClient(client, epollfd, now); + clientDisconnected = true; break; } pubSuccess = sendPublish(client, "$SYS/credentials", "username=admin password=admin"); if(!pubSuccess) { fprintf(stderr, "Disconnecting client due to publish failure"); disconnectClient(client, epollfd, now); + clientDisconnected = true; } break; case SUBSCRIBE: @@ -947,6 +957,7 @@ int main(int argc, char* argv[]) { if(!pubSuccess) { fprintf(stderr, "Disconnecting client due to publish failure"); disconnectClient(client, epollfd, now); + clientDisconnected = true; } break; case UNSUBSCRIBE: @@ -957,23 +968,28 @@ int main(int argc, char* argv[]) { if(!pingSuccess){ fprintf(stderr, "Disconnecting client due to ping failure"); disconnectClient(client, epollfd, now); + clientDisconnected = true; break; } break; case DISCONNECT: fprintf(stderr, "Disconnecting client due to receiving DISCONNECT"); disconnectClient(client, epollfd, now); + clientDisconnected = true; break; default: break; } + if (clientDisconnected) break; processedPackets += packetLength; } - uint32_t leftover = client->bytesWrittenToBuffer - processedPackets; - if (leftover > 0) { - memmove(client->buffer, client->buffer + processedPackets, leftover); + if (!clientDisconnected) { + uint32_t leftover = client->bytesWrittenToBuffer - processedPackets; + if (leftover > 0) { + memmove(client->buffer, client->buffer + processedPackets, leftover); + } + client->bytesWrittenToBuffer = leftover; } - client->bytesWrittenToBuffer = leftover; } }