Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/write race condition #886

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
import org.geysermc.mcprotocollib.network.crypt.EncryptionConfig;
import org.geysermc.mcprotocollib.network.packet.DefaultPacketHeader;
import org.geysermc.mcprotocollib.network.packet.PacketHeader;
import org.geysermc.mcprotocollib.network.packet.PacketProtocol;
import org.geysermc.mcprotocollib.network.packet.PacketRegistry;
import org.geysermc.mcprotocollib.protocol.MinecraftProtocol;
import org.geysermc.mcprotocollib.protocol.codec.MinecraftTypes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.crypto.SecretKey;
import java.security.GeneralSecurityException;

public class TestProtocol extends PacketProtocol {
public class TestProtocol extends MinecraftProtocol {
private static final Logger log = LoggerFactory.getLogger(TestProtocol.class);
private final PacketHeader header = new DefaultPacketHeader();
private final PacketRegistry registry = new PacketRegistry();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package org.geysermc.mcprotocollib.network;

import org.geysermc.mcprotocollib.network.event.server.ServerListener;
import org.geysermc.mcprotocollib.network.packet.PacketProtocol;
import org.geysermc.mcprotocollib.network.server.AbstractServer;
import org.geysermc.mcprotocollib.protocol.MinecraftProtocol;

import java.net.SocketAddress;
import java.util.List;
Expand All @@ -25,7 +25,7 @@ public interface Server {
*
* @return The server's packet protocol.
*/
Supplier<? extends PacketProtocol> getPacketProtocol();
Supplier<? extends MinecraftProtocol> getPacketProtocol();

/**
* Returns true if the listener is listening.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import org.geysermc.mcprotocollib.network.event.session.SessionListener;
import org.geysermc.mcprotocollib.network.netty.FlushHandler;
import org.geysermc.mcprotocollib.network.packet.Packet;
import org.geysermc.mcprotocollib.network.packet.PacketProtocol;
import org.geysermc.mcprotocollib.protocol.MinecraftProtocol;

import java.net.SocketAddress;
import java.util.List;
Expand Down Expand Up @@ -42,7 +42,7 @@ public interface Session {
*
* @return The session's packet protocol.
*/
PacketProtocol getPacketProtocol();
MinecraftProtocol getPacketProtocol();

/**
* Gets this session's set flags. If this session belongs to a server, the server's
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,11 @@ public Session getSession() {
/**
* Gets the packet involved in this event as the required type.
*
* @param <T> Type of the packet.
* @return The event's packet as the required type.
* @throws IllegalStateException If the packet's value isn't of the required type.
*/
@SuppressWarnings("unchecked")
public <T extends Packet> T getPacket() {
try {
return (T) this.packet;
} catch (ClassCastException e) {
throw new IllegalStateException("Tried to get packet as the wrong type. Actual type: " + this.packet.getClass().getName());
}
public Packet getPacket() {
return this.packet;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import lombok.experimental.Accessors;
import org.geysermc.mcprotocollib.network.ProxyInfo;
import org.geysermc.mcprotocollib.network.netty.DefaultPacketHandlerExecutor;
import org.geysermc.mcprotocollib.network.packet.PacketProtocol;
import org.geysermc.mcprotocollib.network.session.ClientNetworkSession;
import org.geysermc.mcprotocollib.protocol.MinecraftProtocol;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
Expand All @@ -18,7 +18,7 @@
@NoArgsConstructor(access = lombok.AccessLevel.PRIVATE)
public final class ClientNetworkSessionFactory {
private SocketAddress remoteSocketAddress;
private PacketProtocol protocol;
private MinecraftProtocol protocol;
private Executor packetHandlerExecutor;
private SocketAddress bindSocketAddress;
private ProxyInfo proxy;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import org.geysermc.mcprotocollib.network.BuiltinFlags;
import org.geysermc.mcprotocollib.network.NetworkConstants;
import org.geysermc.mcprotocollib.network.Session;
import org.geysermc.mcprotocollib.network.packet.PacketProtocol;
import org.geysermc.mcprotocollib.protocol.MinecraftProtocol;

import java.util.function.Function;

Expand All @@ -31,7 +31,7 @@ protected S createSession(Channel ch) {
}

protected void addHandlers(S session, Channel ch) {
PacketProtocol protocol = session.getPacketProtocol();
MinecraftProtocol protocol = session.getPacketProtocol();
ChannelPipeline pipeline = ch.pipeline();

pipeline.addLast(NetworkConstants.READ_TIMEOUT_NAME, new ReadTimeoutHandler(session.getFlag(BuiltinFlags.READ_TIMEOUT, 30)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,21 @@

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.EncoderException;
import io.netty.handler.codec.MessageToMessageCodec;
import io.netty.handler.codec.ByteToMessageCodec;
import org.geysermc.mcprotocollib.network.Session;
import org.geysermc.mcprotocollib.network.codec.PacketDefinition;
import org.geysermc.mcprotocollib.network.event.session.PacketErrorEvent;
import org.geysermc.mcprotocollib.network.packet.Packet;
import org.geysermc.mcprotocollib.network.packet.PacketProtocol;
import org.geysermc.mcprotocollib.network.packet.PacketRegistry;
import org.geysermc.mcprotocollib.protocol.MinecraftProtocol;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Marker;
import org.slf4j.MarkerFactory;

import java.util.List;

public class PacketCodec extends MessageToMessageCodec<ByteBuf, Packet> {
public class PacketCodec extends ByteToMessageCodec<Packet> {
private static final Marker marker = MarkerFactory.getMarker("packet_logging");
private static final Logger log = LoggerFactory.getLogger(PacketCodec.class);

Expand All @@ -32,33 +30,34 @@ public PacketCodec(Session session, boolean client) {

@SuppressWarnings({"rawtypes", "unchecked"})
@Override
public void encode(ChannelHandlerContext ctx, Packet packet, List<Object> out) {
public void encode(ChannelHandlerContext ctx, Packet packet, ByteBuf out) {
if (log.isTraceEnabled()) {
log.trace(marker, "Encoding packet: {}", packet.getClass().getSimpleName());
}

PacketProtocol packetProtocol = this.session.getPacketProtocol();
int initial = out.writerIndex();
MinecraftProtocol packetProtocol = this.session.getPacketProtocol();
PacketRegistry packetRegistry = packetProtocol.getOutboundPacketRegistry();
try {
int packetId = this.client ? packetRegistry.getServerboundId(packet) : packetRegistry.getClientboundId(packet);
PacketDefinition definition = this.client ? packetRegistry.getServerboundDefinition(packetId) : packetRegistry.getClientboundDefinition(packetId);

ByteBuf buf = ctx.alloc().buffer();
packetProtocol.getPacketHeader().writePacketId(buf, packetId);
definition.getSerializer().serialize(buf, packet);

out.add(buf);
packetProtocol.getPacketHeader().writePacketId(out, packetId);
definition.getSerializer().serialize(out, packet);

if (log.isDebugEnabled()) {
log.debug(marker, "Encoded packet {} ({})", packet.getClass().getSimpleName(), packetId);
}
} catch (Throwable t) {
log.debug(marker, "Error encoding packet", t);

// Reset writer index to make sure incomplete data is not written out.
out.writerIndex(initial);

PacketErrorEvent e = new PacketErrorEvent(this.session, t, packet);
this.session.callEvent(e);
if (!e.shouldSuppress()) {
throw new EncoderException(t);
throw t;
}
}
}
Expand All @@ -72,7 +71,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buf, List<Object> out)

int initial = buf.readerIndex();

PacketProtocol packetProtocol = this.session.getPacketProtocol();
MinecraftProtocol packetProtocol = this.session.getPacketProtocol();
PacketRegistry packetRegistry = packetProtocol.getInboundPacketRegistry();
Packet packet = null;
try {
Expand Down Expand Up @@ -104,7 +103,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buf, List<Object> out)
PacketErrorEvent e = new PacketErrorEvent(this.session, t, packet);
this.session.callEvent(e);
if (!e.shouldSuppress()) {
throw new DecoderException(t);
throw t;
}
} finally {
if (packet != null && packet.isTerminal()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageCodec;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.MessageToMessageCodec;
import lombok.RequiredArgsConstructor;
import org.geysermc.mcprotocollib.network.NetworkConstants;
import org.geysermc.mcprotocollib.network.compression.CompressionConfig;
Expand All @@ -12,7 +12,7 @@
import java.util.List;

@RequiredArgsConstructor
public class PacketCompressionCodec extends MessageToMessageCodec<ByteBuf, ByteBuf> {
public class PacketCompressionCodec extends ByteToMessageCodec<ByteBuf> {
private static final int MAX_UNCOMPRESSED_SIZE = 8 * 1024 * 1024; // 8MiB

@Override
Expand All @@ -26,10 +26,10 @@ public void handlerRemoved(ChannelHandlerContext ctx) {
}

@Override
public void encode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) {
public void encode(ChannelHandlerContext ctx, ByteBuf msg, ByteBuf out) {
CompressionConfig config = ctx.channel().attr(NetworkConstants.COMPRESSION_ATTRIBUTE_KEY).get();
if (config == null) {
out.add(msg.retain());
out.writeBytes(msg);
return;
}

Expand All @@ -39,29 +39,33 @@ public void encode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) {
}

ByteBuf outBuf = ctx.alloc().directBuffer(uncompressed);
if (uncompressed < config.threshold()) {
// Under the threshold, there is nothing to do.
MinecraftTypes.writeVarInt(outBuf, 0);
outBuf.writeBytes(msg);
} else {
MinecraftTypes.writeVarInt(outBuf, uncompressed);
config.compression().deflate(msg, outBuf);
}
try {
if (uncompressed < config.threshold()) {
// Under the threshold, there is nothing to do.
MinecraftTypes.writeVarInt(outBuf, 0);
outBuf.writeBytes(msg);
} else {
MinecraftTypes.writeVarInt(outBuf, uncompressed);
config.compression().deflate(msg, outBuf);
}

out.add(outBuf);
out.writeBytes(outBuf);
} finally {
outBuf.release();
}
}

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
CompressionConfig config = ctx.channel().attr(NetworkConstants.COMPRESSION_ATTRIBUTE_KEY).get();
if (config == null) {
out.add(in.retain());
out.add(in.readBytes(in.readableBytes()));
return;
}

int claimedUncompressedSize = MinecraftTypes.readVarInt(in);
if (claimedUncompressedSize == 0) {
out.add(in.retain());
out.add(in.readBytes(in.readableBytes()));
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageCodec;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.EncoderException;
import io.netty.handler.codec.MessageToMessageCodec;
import org.geysermc.mcprotocollib.network.NetworkConstants;
import org.geysermc.mcprotocollib.network.crypt.EncryptionConfig;

import java.util.List;

public class PacketEncryptorCodec extends MessageToMessageCodec<ByteBuf, ByteBuf> {
public class PacketEncryptorCodec extends ByteToMessageCodec<ByteBuf> {
@Override
public void encode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) {
public void encode(ChannelHandlerContext ctx, ByteBuf msg, ByteBuf out) {
EncryptionConfig config = ctx.channel().attr(NetworkConstants.ENCRYPTION_ATTRIBUTE_KEY).get();
if (config == null) {
out.add(msg.retain());
out.writeBytes(msg);
return;
}

Expand All @@ -27,18 +27,19 @@ public void encode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) {

try {
config.encryption().encrypt(heapBuf.array(), baseOffset, inBytes, heapBuf.array(), baseOffset);
out.add(heapBuf);
out.writeBytes(heapBuf);
} catch (Exception e) {
heapBuf.release();
throw new EncoderException("Error encrypting packet", e);
} finally {
heapBuf.release();
}
}

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
EncryptionConfig config = ctx.channel().attr(NetworkConstants.ENCRYPTION_ATTRIBUTE_KEY).get();
if (config == null) {
out.add(in.retain());
out.add(in.readBytes(in.readableBytes()));
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,30 @@ public void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) {
}

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf buf, List<Object> out) {
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
int size = header.getLengthSize();
if (size == 0) {
out.add(buf.retain());
out.add(in.readBytes(in.readableBytes()));
return;
}

buf.markReaderIndex();
in.markReaderIndex();
byte[] lengthBytes = new byte[size];
for (int index = 0; index < lengthBytes.length; index++) {
if (!buf.isReadable()) {
buf.resetReaderIndex();
if (!in.isReadable()) {
in.resetReaderIndex();
return;
}

lengthBytes[index] = buf.readByte();
lengthBytes[index] = in.readByte();
if ((header.isLengthVariable() && lengthBytes[index] >= 0) || index == size - 1) {
int length = header.readLength(Unpooled.wrappedBuffer(lengthBytes), buf.readableBytes());
if (buf.readableBytes() < length) {
buf.resetReaderIndex();
int length = header.readLength(Unpooled.wrappedBuffer(lengthBytes), in.readableBytes());
if (in.readableBytes() < length) {
in.resetReaderIndex();
return;
}

out.add(buf.readBytes(length));
out.add(in.readBytes(length));
return;
}
}
Expand Down
Loading