diff --git a/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/instrumentation/decorator/HttpClientDecorator.java b/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/instrumentation/decorator/HttpClientDecorator.java index 924c5311078..f87e339bfda 100644 --- a/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/instrumentation/decorator/HttpClientDecorator.java +++ b/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/instrumentation/decorator/HttpClientDecorator.java @@ -10,6 +10,7 @@ import datadog.trace.api.DDTags; import datadog.trace.api.InstrumenterConfig; import datadog.trace.api.ProductActivation; +import datadog.trace.api.appsec.HttpClientRequest; import datadog.trace.api.gateway.BlockResponseFunction; import datadog.trace.api.gateway.Flow; import datadog.trace.api.gateway.RequestContext; @@ -99,7 +100,7 @@ public AgentSpan onRequest(final AgentSpan span, final REQUEST request) { HTTP_RESOURCE_DECORATOR.withClientPath(span, method, url.getPath()); } // SSRF exploit prevention check - onNetworkConnection(url.toString()); + onHttpClientRequest(span, url.toString()); } else if (shouldSetResourceName()) { span.setResourceName(DEFAULT_RESOURCE_NAME); } @@ -178,24 +179,20 @@ public long getResponseContentLength(final RESPONSE response) { return 0; } - private void onNetworkConnection(final String networkConnection) { + protected void onHttpClientRequest(final AgentSpan span, final String url) { if (!APPSEC_RASP_ENABLED) { return; } - if (networkConnection == null) { + if (url == null) { return; } - final BiFunction> networkConnectionCallback = + final long requestId = span.getSpanId(); + final BiFunction> requestCb = AgentTracer.get() .getCallbackProvider(RequestContextSlot.APPSEC) - .getCallback(EVENTS.networkConnection()); + .getCallback(EVENTS.httpClientRequest()); - if (networkConnectionCallback == null) { - return; - } - - final AgentSpan span = AgentTracer.get().activeSpan(); - if (span == null) { + if (requestCb == null) { return; } @@ -204,7 +201,7 @@ private void onNetworkConnection(final String networkConnection) { return; } - Flow flow = networkConnectionCallback.apply(ctx, networkConnection); + Flow flow = requestCb.apply(ctx, new HttpClientRequest(requestId, url)); Flow.Action action = flow.getAction(); if (action instanceof Flow.Action.RequestBlockingAction) { BlockResponseFunction brf = ctx.getBlockResponseFunction(); diff --git a/dd-java-agent/agent-bootstrap/src/test/groovy/datadog/trace/bootstrap/instrumentation/decorator/HttpClientDecoratorTest.groovy b/dd-java-agent/agent-bootstrap/src/test/groovy/datadog/trace/bootstrap/instrumentation/decorator/HttpClientDecoratorTest.groovy index 61c5e651598..bd98bc53226 100644 --- a/dd-java-agent/agent-bootstrap/src/test/groovy/datadog/trace/bootstrap/instrumentation/decorator/HttpClientDecoratorTest.groovy +++ b/dd-java-agent/agent-bootstrap/src/test/groovy/datadog/trace/bootstrap/instrumentation/decorator/HttpClientDecoratorTest.groovy @@ -1,6 +1,7 @@ package datadog.trace.bootstrap.instrumentation.decorator import datadog.trace.api.DDTags +import datadog.trace.api.appsec.HttpClientRequest import datadog.trace.api.config.AppSecConfig import datadog.trace.api.gateway.CallbackProvider import static datadog.trace.api.gateway.Events.EVENTS @@ -249,8 +250,8 @@ class HttpClientDecoratorTest extends ClientDecoratorTest { decorator.onRequest(span2, req) then: - 1 * callbackProvider.getCallback(EVENTS.networkConnection()) >> listener - 1 * listener.apply(reqCtx, _ as String) + 1 * callbackProvider.getCallback(EVENTS.httpClientRequest()) >> listener + 1 * listener.apply(reqCtx, _ as HttpClientRequest) } @Override diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/AppSecSystem.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/AppSecSystem.java index 992e7dddace..39055aef475 100644 --- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/AppSecSystem.java +++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/AppSecSystem.java @@ -1,5 +1,6 @@ package com.datadog.appsec; +import com.datadog.appsec.api.security.ApiSecurityDownstreamSampler; import com.datadog.appsec.api.security.ApiSecuritySampler; import com.datadog.appsec.api.security.ApiSecuritySamplerImpl; import com.datadog.appsec.api.security.AppSecSpanPostProcessor; @@ -81,11 +82,14 @@ private static void doStart(SubscriptionService gw, SharedCommunicationObjects s } sco.createRemaining(config); + final double maxDownstreamRequestsRate = + config.getApiSecurityDownstreamRequestAnalysisSampleRate(); GatewayBridge gatewayBridge = new GatewayBridge( gw, REPLACEABLE_EVENT_PRODUCER, () -> API_SECURITY_SAMPLER, + ApiSecurityDownstreamSampler.build(maxDownstreamRequestsRate), APP_SEC_CONFIG_SERVICE.getTraceSegmentPostProcessors()); loadModules( diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecurityDownstreamSampler.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecurityDownstreamSampler.java new file mode 100644 index 00000000000..e860548d08a --- /dev/null +++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecurityDownstreamSampler.java @@ -0,0 +1,42 @@ +package com.datadog.appsec.api.security; + +import com.datadog.appsec.gateway.AppSecRequestContext; + +public interface ApiSecurityDownstreamSampler { + + boolean sampleHttpClientRequest(AppSecRequestContext ctx, long requestId); + + boolean isSampled(AppSecRequestContext ctx, long requestId); + + ApiSecurityDownstreamSampler INCLUDE_ALL = + new ApiSecurityDownstreamSampler() { + @Override + public boolean sampleHttpClientRequest(AppSecRequestContext ctx, long requestId) { + return true; + } + + @Override + public boolean isSampled(AppSecRequestContext ctx, long requestId) { + return true; + } + }; + + ApiSecurityDownstreamSampler INCLUDE_NONE = + new ApiSecurityDownstreamSampler() { + @Override + public boolean sampleHttpClientRequest(AppSecRequestContext ctx, long requestId) { + return false; + } + + @Override + public boolean isSampled(AppSecRequestContext ctx, long requestId) { + return false; + } + }; + + static ApiSecurityDownstreamSampler build(double rate) { + return rate <= 0D + ? INCLUDE_NONE + : (rate >= 1D ? INCLUDE_ALL : new ApiSecurityDownstreamSamplerImpl(rate)); + } +} diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecurityDownstreamSamplerImpl.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecurityDownstreamSamplerImpl.java new file mode 100644 index 00000000000..b2e9c6bffaf --- /dev/null +++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecurityDownstreamSamplerImpl.java @@ -0,0 +1,49 @@ +package com.datadog.appsec.api.security; + +import com.datadog.appsec.gateway.AppSecRequestContext; +import java.util.concurrent.atomic.AtomicLong; + +public class ApiSecurityDownstreamSamplerImpl implements ApiSecurityDownstreamSampler { + + private static final long KNUTH_FACTOR = 1111111111111111111L; + private static final double SAMPLING_MAX = Math.pow(2, 64) - 1; + + private final AtomicLong globalRequestCount = new AtomicLong(0); + private final double threshold; + + public ApiSecurityDownstreamSamplerImpl(double rate) { + threshold = samplingCutoff(rate); + } + + private static double samplingCutoff(double rate) { + if (rate < 0.5) { + return (long) (rate * SAMPLING_MAX) + Long.MIN_VALUE; + } + if (rate < 1.0) { + return (long) ((rate * SAMPLING_MAX) + Long.MIN_VALUE); + } + return Long.MAX_VALUE; + } + + /** + * First sample the request to ensure we randomize the request and then check if the current + * server request has budget to analyze the downstream request. + */ + @Override + public boolean sampleHttpClientRequest(final AppSecRequestContext ctx, final long requestId) { + final long counter = updateRequestCount(); + if (counter * KNUTH_FACTOR + Long.MIN_VALUE > threshold) { + return false; + } + return ctx.sampleHttpClientRequest(requestId); + } + + @Override + public boolean isSampled(final AppSecRequestContext ctx, final long requestId) { + return ctx.isHttpClientRequestSampled(requestId); + } + + private long updateRequestCount() { + return globalRequestCount.updateAndGet(cur -> (cur == Long.MAX_VALUE) ? 0L : cur + 1L); + } +} diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/event/data/KnownAddresses.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/event/data/KnownAddresses.java index d88c2fb0311..27f38db7ef9 100644 --- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/event/data/KnownAddresses.java +++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/event/data/KnownAddresses.java @@ -118,6 +118,26 @@ public interface KnownAddresses { /** The URL of a network resource being requested (outgoing request) */ Address IO_NET_URL = new Address<>("server.io.net.url"); + /** The headers of a network resource being requested (outgoing request) */ + Address>> IO_NET_REQUEST_HEADERS = + new Address<>("server.io.net.request.headers"); + + /** The method of a network resource being requested (outgoing request) */ + Address IO_NET_REQUEST_METHOD = new Address<>("server.io.net.request.method"); + + /** The body of a network resource being requested (outgoing request) */ + Address IO_NET_REQUEST_BODY = new Address<>("server.io.net.request.body"); + + /** The status of a network resource being requested (outgoing request) */ + Address IO_NET_RESPONSE_STATUS = new Address<>("server.io.net.response.status"); + + /** The response headers of a network resource being requested (outgoing request) */ + Address>> IO_NET_RESPONSE_HEADERS = + new Address<>("server.io.net.response.headers"); + + /** The response body of a network resource being requested (outgoing request) */ + Address IO_NET_RESPONSE_BODY = new Address<>("server.io.net.response.body"); + /** The representation of opened file on the filesystem */ Address IO_FS_FILE = new Address<>("server.io.fs.file"); @@ -206,6 +226,18 @@ static Address forName(String name) { return SESSION_ID; case "server.io.net.url": return IO_NET_URL; + case "server.io.net.request.headers": + return IO_NET_REQUEST_HEADERS; + case "server.io.net.request.method": + return IO_NET_REQUEST_METHOD; + case "server.io.net.request.body": + return IO_NET_REQUEST_BODY; + case "server.io.net.response.status": + return IO_NET_RESPONSE_STATUS; + case "server.io.net.response.headers": + return IO_NET_RESPONSE_HEADERS; + case "server.io.net.response.body": + return IO_NET_RESPONSE_BODY; case "server.io.fs.file": return IO_FS_FILE; case "server.db.system": diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/AppSecRequestContext.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/AppSecRequestContext.java index 904d60318ff..b6335429423 100644 --- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/AppSecRequestContext.java +++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/AppSecRequestContext.java @@ -149,6 +149,9 @@ public class AppSecRequestContext implements DataBundle, Closeable { private volatile Long apiSecurityEndpointHash; private volatile byte keepType = PrioritySampling.SAMPLER_KEEP; + private static final AtomicInteger httpClientRequestCount = new AtomicInteger(0); + private static final Set sampledHttpClientRequests = new HashSet<>(); + private static final AtomicIntegerFieldUpdater WAF_TIMEOUTS_UPDATER = AtomicIntegerFieldUpdater.newUpdater(AppSecRequestContext.class, "wafTimeouts"); private static final AtomicIntegerFieldUpdater RASP_TIMEOUTS_UPDATER = @@ -235,6 +238,29 @@ public void increaseRaspTimeouts() { RASP_TIMEOUTS_UPDATER.incrementAndGet(this); } + public void increaseHttpClientRequestCount() { + httpClientRequestCount.incrementAndGet(); + } + + public boolean sampleHttpClientRequest(final long id) { + synchronized (sampledHttpClientRequests) { + if (sampledHttpClientRequests.size() + < Config.get().getApiSecurityMaxDownstreamRequestBodyAnalysis()) { + sampledHttpClientRequests.add(id); + return true; + } + } + return false; + } + + public boolean isHttpClientRequestSampled(final long id) { + return sampledHttpClientRequests.contains(id); + } + + public int getHttpClientRequestCount() { + return httpClientRequestCount.get(); + } + public int getWafTimeouts() { return wafTimeouts; } diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java index 76565c0e4cb..1b2e7276fe0 100644 --- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java +++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java @@ -1,13 +1,16 @@ package com.datadog.appsec.gateway; import static com.datadog.appsec.event.data.MapDataBundle.Builder.CAPACITY_0_2; +import static com.datadog.appsec.event.data.MapDataBundle.Builder.CAPACITY_3_4; import static com.datadog.appsec.event.data.MapDataBundle.Builder.CAPACITY_6_10; import static com.datadog.appsec.gateway.AppSecRequestContext.DEFAULT_REQUEST_HEADERS_ALLOW_LIST; import static com.datadog.appsec.gateway.AppSecRequestContext.REQUEST_HEADERS_ALLOW_LIST; import static com.datadog.appsec.gateway.AppSecRequestContext.RESPONSE_HEADERS_ALLOW_LIST; +import static datadog.trace.api.telemetry.LogCollector.SEND_TELEMETRY; import static datadog.trace.bootstrap.instrumentation.api.Tags.SAMPLING_PRIORITY; import com.datadog.appsec.AppSecSystem; +import com.datadog.appsec.api.security.ApiSecurityDownstreamSampler; import com.datadog.appsec.api.security.ApiSecuritySampler; import com.datadog.appsec.config.TraceSegmentPostProcessor; import com.datadog.appsec.event.EventProducerService; @@ -21,8 +24,12 @@ import com.datadog.appsec.event.data.SingletonDataBundle; import com.datadog.appsec.report.AppSecEvent; import com.datadog.appsec.report.AppSecEventWrapper; +import com.datadog.appsec.util.BodyParser; import datadog.trace.api.Config; import datadog.trace.api.ProductTraceSource; +import datadog.trace.api.appsec.HttpClientPayload; +import datadog.trace.api.appsec.HttpClientRequest; +import datadog.trace.api.appsec.HttpClientResponse; import datadog.trace.api.gateway.Events; import datadog.trace.api.gateway.Flow; import datadog.trace.api.gateway.IGSpanInfo; @@ -93,6 +100,7 @@ public class GatewayBridge { private final SubscriptionService subscriptionService; private final EventProducerService producerService; private final Supplier requestSamplerSupplier; + private final ApiSecurityDownstreamSampler downstreamSampler; private final List traceSegmentPostProcessors; // subscriber cache @@ -107,7 +115,8 @@ public class GatewayBridge { private volatile DataSubscriberInfo graphqlServerRequestMsgSubInfo; private volatile DataSubscriberInfo requestEndSubInfo; private volatile DataSubscriberInfo dbSqlQuerySubInfo; - private volatile DataSubscriberInfo ioNetUrlSubInfo; + private volatile DataSubscriberInfo httpClientRequestSubInfo; + private volatile DataSubscriberInfo httpClientResponseSubInfo; private volatile DataSubscriberInfo ioFileSubInfo; private volatile DataSubscriberInfo sessionIdSubInfo; private volatile DataSubscriberInfo userIdSubInfo; @@ -120,10 +129,12 @@ public GatewayBridge( SubscriptionService subscriptionService, EventProducerService producerService, @Nonnull Supplier requestSamplerSupplier, + ApiSecurityDownstreamSampler downstreamSampler, List traceSegmentPostProcessors) { this.subscriptionService = subscriptionService; this.producerService = producerService; this.requestSamplerSupplier = requestSamplerSupplier; + this.downstreamSampler = downstreamSampler; this.traceSegmentPostProcessors = traceSegmentPostProcessors; } @@ -154,7 +165,9 @@ public void init() { EVENTS.graphqlServerRequestMessage(), this::onGraphqlServerRequestMessage); subscriptionService.registerCallback(EVENTS.databaseConnection(), this::onDatabaseConnection); subscriptionService.registerCallback(EVENTS.databaseSqlQuery(), this::onDatabaseSqlQuery); - subscriptionService.registerCallback(EVENTS.networkConnection(), this::onNetworkConnection); + subscriptionService.registerCallback(EVENTS.httpClientSampling(), this::onHttpClientSampling); + subscriptionService.registerCallback(EVENTS.httpClientRequest(), this::onHttpClientRequest); + subscriptionService.registerCallback(EVENTS.httpClientResponse(), this::onHttpClientResponse); subscriptionService.registerCallback(EVENTS.fileLoaded(), this::onFileLoaded); subscriptionService.registerCallback(EVENTS.requestSession(), this::onRequestSession); subscriptionService.registerCallback(EVENTS.execCmd(), this::onExecCmd); @@ -188,7 +201,8 @@ public void reset() { graphqlServerRequestMsgSubInfo = null; requestEndSubInfo = null; dbSqlQuerySubInfo = null; - ioNetUrlSubInfo = null; + httpClientRequestSubInfo = null; + httpClientResponseSubInfo = null; ioFileSubInfo = null; sessionIdSubInfo = null; userIdSubInfo = null; @@ -312,31 +326,116 @@ private Flow onRequestSession(final RequestContext ctx_, final String sess } } - private Flow onNetworkConnection(RequestContext ctx_, String url) { + private Flow onHttpClientSampling(RequestContext ctx_, final long requestId) { + AppSecRequestContext ctx = ctx_.getData(RequestContextSlot.APPSEC); + if (ctx == null) { + return new Flow.ResultFlow<>(null); + } + ctx.increaseHttpClientRequestCount(); + return new Flow.ResultFlow<>(downstreamSampler.sampleHttpClientRequest(ctx, requestId)); + } + + private Flow onHttpClientRequest(RequestContext ctx_, HttpClientRequest request) { AppSecRequestContext ctx = ctx_.getData(RequestContextSlot.APPSEC); if (ctx == null) { return NoopFlow.INSTANCE; } + + final MapDataBundle.Builder bundleBuilder = + new MapDataBundle.Builder(CAPACITY_3_4) + .add(KnownAddresses.IO_NET_URL, request.getUrl()) + .add(KnownAddresses.IO_NET_REQUEST_METHOD, request.getMethod()) + .add(KnownAddresses.IO_NET_REQUEST_HEADERS, request.getHeaders()); + ; + if (downstreamSampler.isSampled(ctx, request.getRequestId())) { + final Object body = parseHttpClientBody(ctx, request); + if (body != null) { + bundleBuilder.add(KnownAddresses.IO_NET_REQUEST_BODY, body); + } + } + final DataBundle bundle = bundleBuilder.build(); + while (true) { - DataSubscriberInfo subInfo = ioNetUrlSubInfo; + DataSubscriberInfo subInfo = httpClientRequestSubInfo; if (subInfo == null) { - subInfo = producerService.getDataSubscribers(KnownAddresses.IO_NET_URL); - ioNetUrlSubInfo = subInfo; + subInfo = + producerService.getDataSubscribers( + KnownAddresses.IO_NET_URL, + KnownAddresses.IO_NET_REQUEST_METHOD, + KnownAddresses.IO_NET_REQUEST_HEADERS, + KnownAddresses.IO_NET_REQUEST_BODY); + httpClientRequestSubInfo = subInfo; } - if (subInfo == null || subInfo.isEmpty()) { - return NoopFlow.INSTANCE; + try { + final boolean raspActive = Config.get().isAppSecRaspEnabled(); + GatewayContext gwCtx = new GatewayContext(true, raspActive ? RuleType.SSRF : null); + return producerService.publishDataEvent(subInfo, ctx, bundle, gwCtx); + } catch (ExpiredSubscriberInfoException e) { + httpClientRequestSubInfo = null; + } + } + } + + private Flow onHttpClientResponse(RequestContext ctx_, HttpClientResponse response) { + AppSecRequestContext ctx = ctx_.getData(RequestContextSlot.APPSEC); + if (ctx == null) { + return NoopFlow.INSTANCE; + } + + final MapDataBundle.Builder bundleBuilder = + new MapDataBundle.Builder(CAPACITY_3_4) + .add(KnownAddresses.IO_NET_RESPONSE_STATUS, response.getStatus()) + .add(KnownAddresses.IO_NET_RESPONSE_HEADERS, response.getHeaders()); + // ignore the response if not sampled + if (downstreamSampler.isSampled(ctx, response.getRequestId())) { + final Object body = parseHttpClientBody(ctx, response); + if (body != null) { + bundleBuilder.add(KnownAddresses.IO_NET_RESPONSE_BODY, body); + } + } + + final DataBundle bundle = bundleBuilder.build(); + + while (true) { + DataSubscriberInfo subInfo = httpClientResponseSubInfo; + if (subInfo == null) { + subInfo = + producerService.getDataSubscribers( + KnownAddresses.IO_NET_RESPONSE_STATUS, + KnownAddresses.IO_NET_RESPONSE_HEADERS, + KnownAddresses.IO_NET_RESPONSE_BODY); + httpClientResponseSubInfo = subInfo; } - DataBundle bundle = - new MapDataBundle.Builder(CAPACITY_0_2).add(KnownAddresses.IO_NET_URL, url).build(); try { - GatewayContext gwCtx = new GatewayContext(true, RuleType.SSRF); + GatewayContext gwCtx = new GatewayContext(true); return producerService.publishDataEvent(subInfo, ctx, bundle, gwCtx); } catch (ExpiredSubscriberInfoException e) { - ioNetUrlSubInfo = null; + httpClientResponseSubInfo = null; } } } + private Object parseHttpClientBody( + final AppSecRequestContext ctx, final HttpClientPayload payload) { + if (payload.getContentType() == null || payload.getBody() == null) { + return null; + } + + final BodyParser parser = BodyParser.forMediaType(payload.getContentType()); + if (parser == null) { + log.debug(SEND_TELEMETRY, "Received non parseable content type {}", payload.getContentType()); + return null; + } + final BodyParser.State state = new BodyParser.State(); + final Object result = parser.parse(state, payload.getBody()); + if (state.stringTooLong || state.listMapTooLarge || state.objectTooDeep) { + ctx.setWafTruncated(); + WafMetricCollector.get() + .wafInputTruncated(state.stringTooLong, state.listMapTooLarge, state.objectTooDeep); + } + return result; + } + private Flow onExecCmd(RequestContext ctx_, String[] command) { AppSecRequestContext ctx = ctx_.getData(RequestContextSlot.APPSEC); if (ctx == null) { @@ -748,6 +847,11 @@ private NoopFlow onRequestEnded(RequestContext ctx_, IGSpanInfo spanInfo) { pp.processTraceSegment(traceSeg, ctx, collectedEvents); } + final int clientRequests = ctx.getHttpClientRequestCount(); + if (clientRequests > 0) { + traceSeg.setTagTop("_dd.appsec.downstream_request", clientRequests); + } + // If detected any events - mark span at appsec.event if (!collectedEvents.isEmpty()) { // Set asm keep in case that root span was not available when events are detected diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/util/BodyParser.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/util/BodyParser.java new file mode 100644 index 00000000000..866c1bb5b3b --- /dev/null +++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/util/BodyParser.java @@ -0,0 +1,147 @@ +package com.datadog.appsec.util; + +import static com.datadog.appsec.ddwaf.WAFModule.MAX_DEPTH; +import static com.datadog.appsec.ddwaf.WAFModule.MAX_ELEMENTS; +import static com.datadog.appsec.ddwaf.WAFModule.MAX_STRING_SIZE; + +import com.squareup.moshi.JsonAdapter; +import com.squareup.moshi.JsonDataException; +import com.squareup.moshi.JsonReader; +import com.squareup.moshi.JsonWriter; +import datadog.trace.api.appsec.MediaType; +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import okio.Okio; + +public interface BodyParser { + + Object parse(State state, InputStream inputStream); + + static BodyParser forJson() { + return JsonParser.INSTANCE; + } + + static BodyParser forMediaType(final MediaType type) { + if (type.isJson()) { + return JsonParser.INSTANCE; + } + return null; + } + + class State { + private int elemsLeft = MAX_ELEMENTS; + public boolean objectTooDeep = false; + public boolean listMapTooLarge = false; + public boolean stringTooLong = false; + } + + class JsonParser implements BodyParser { + + private static final BodyParser INSTANCE = new JsonParser(); + + @Override + public Object parse(final State state, final InputStream inputStream) { + try { + final JsonAdapter adapter = new BoundedObjectAdapter(state); + return adapter.fromJson(Okio.buffer(Okio.source(inputStream))); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static final class BoundedObjectAdapter extends JsonAdapter { + + private final State state; + + public BoundedObjectAdapter(final State state) { + this.state = state; + } + + @Override + public void toJson(final JsonWriter writer, @Nullable final Object value) throws IOException { + throw new UnsupportedOperationException("Parsing-only adapter"); + } + + @Override + public Object fromJson(final JsonReader reader) throws IOException { + return readValue(reader, 0); + } + + private Object readValue(final JsonReader r, final int depth) throws IOException { + if (depth >= MAX_DEPTH) { + state.objectTooDeep = true; + r.skipValue(); + return null; + } + + if (state.elemsLeft == 0) { + state.listMapTooLarge = true; + r.skipValue(); + return null; + } + state.elemsLeft--; + + switch (r.peek()) { + case BEGIN_OBJECT: + return readObject(r, depth); + case BEGIN_ARRAY: + state.elemsLeft--; + return readArray(r, depth); + case STRING: + String value = r.nextString(); + if (value.length() > MAX_STRING_SIZE) { + state.stringTooLong = true; + value = value.substring(0, MAX_STRING_SIZE); + } + return value; + case NUMBER: + return r.nextDouble(); + case BOOLEAN: + return r.nextBoolean(); + case NULL: + return r.nextNull(); + default: + throw new JsonDataException("Unexpected token at value boundary: " + r.peek()); + } + } + + private Map readObject(final JsonReader r, final int depth) + throws IOException { + Map map = new LinkedHashMap<>(); + r.beginObject(); + while (r.hasNext()) { + String name = r.nextName(); + if (state.elemsLeft > 0) { + Object val = readValue(r, depth + 1); + map.put(name, val); + } else { + state.listMapTooLarge = true; + r.skipValue(); + } + } + r.endObject(); + return map; + } + + private List readArray(final JsonReader r, final int depth) throws IOException { + List list = new ArrayList<>(); + r.beginArray(); + while (r.hasNext()) { + if (state.elemsLeft > 0) { + list.add(readValue(r, depth + 1)); + } else { + state.listMapTooLarge = true; + r.skipValue(); + } + } + r.endArray(); + return list; + } + } + } +} diff --git a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecurityDownstreamSamplerTest.groovy b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecurityDownstreamSamplerTest.groovy new file mode 100644 index 00000000000..ab0f8b9443c --- /dev/null +++ b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecurityDownstreamSamplerTest.groovy @@ -0,0 +1,50 @@ +package com.datadog.appsec.api.security + +import com.datadog.appsec.gateway.AppSecRequestContext +import datadog.trace.test.util.DDSpecification + +class ApiSecurityDownstreamSamplerTest extends DDSpecification { + + void 'test include all/none'() { + given: + final ctx = Mock(AppSecRequestContext) + final sampler = ApiSecurityDownstreamSampler.build(rate) + + when: + final initialDecisions = (1..10).collect { sampler.sampleHttpClientRequest(ctx, it)} + + then: + initialDecisions.every { it == expected } + + when: + final sampled = (1..10).collect { sampler.isSampled(ctx, it)} + + then: + sampled.every { it == expected } + + where: + rate | expected + 0D | false + 1D | true + } + + void 'test sampling algorithm'() { + given: + final epsilon = 0.05 + final ctx = Mock(AppSecRequestContext) { + sampleHttpClientRequest(_ as long) >> true + isHttpClientRequestSampled(_ as long) >> true + } + final sampler = ApiSecurityDownstreamSampler.build(expectedRate) + + when: + final samples = (1..100).collect { sampler.sampleHttpClientRequest(ctx, it)} + + then: + final rate = samples.count { it } / samples.size() + rate.subtract(expectedRate).abs() <= epsilon + + where: + expectedRate << [0.1, 0.25, 0.5, 0.75, 0.9] + } +} diff --git a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/event/data/KnownAddressesSpecificationForkedTest.groovy b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/event/data/KnownAddressesSpecificationForkedTest.groovy index e634a2bf978..432363f9073 100644 --- a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/event/data/KnownAddressesSpecificationForkedTest.groovy +++ b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/event/data/KnownAddressesSpecificationForkedTest.groovy @@ -42,6 +42,12 @@ class KnownAddressesSpecificationForkedTest extends Specification { 'server.business_logic.users.login.success', 'server.business_logic.users.signup', 'server.io.net.url', + 'server.io.net.request.headers', + 'server.io.net.request.method', + 'server.io.net.request.body', + 'server.io.net.response.status', + 'server.io.net.response.headers', + 'server.io.net.response.body', 'server.io.fs.file', 'server.sys.exec.cmd', 'server.sys.shell.cmd', @@ -51,7 +57,7 @@ class KnownAddressesSpecificationForkedTest extends Specification { void 'number of known addresses is expected number'() { expect: - Address.instanceCount() == 39 + Address.instanceCount() == 45 KnownAddresses.WAF_CONTEXT_PROCESSOR.serial == Address.instanceCount() - 1 } } diff --git a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/GatewayBridgeSpecification.groovy b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/GatewayBridgeSpecification.groovy index 2d67fdc6276..32da82e669c 100644 --- a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/GatewayBridgeSpecification.groovy +++ b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/GatewayBridgeSpecification.groovy @@ -1,6 +1,7 @@ package com.datadog.appsec.gateway import com.datadog.appsec.AppSecSystem +import com.datadog.appsec.api.security.ApiSecurityDownstreamSampler import com.datadog.appsec.api.security.ApiSecuritySamplerImpl import com.datadog.appsec.config.TraceSegmentPostProcessor import com.datadog.appsec.event.EventDispatcher @@ -11,6 +12,9 @@ import com.datadog.appsec.report.AppSecEvent import com.datadog.appsec.report.AppSecEventWrapper import datadog.trace.api.ProductTraceSource import datadog.trace.api.TagMap +import datadog.trace.api.appsec.HttpClientRequest +import datadog.trace.api.appsec.HttpClientResponse +import datadog.trace.api.appsec.MediaType import datadog.trace.api.config.GeneralConfig import datadog.trace.api.function.TriConsumer import datadog.trace.api.function.TriFunction @@ -87,7 +91,8 @@ class GatewayBridgeSpecification extends DDSpecification { TraceSegmentPostProcessor pp = Mock() ApiSecuritySamplerImpl requestSampler = Mock(ApiSecuritySamplerImpl) - GatewayBridge bridge = new GatewayBridge(ig, eventDispatcher, () -> requestSampler, [pp]) + ApiSecurityDownstreamSampler downstreamSampler = Mock(ApiSecurityDownstreamSampler) + GatewayBridge bridge = new GatewayBridge(ig, eventDispatcher, () -> requestSampler, downstreamSampler, [pp]) Supplier> requestStartedCB BiFunction> requestEndedCB @@ -109,7 +114,9 @@ class GatewayBridgeSpecification extends DDSpecification { BiFunction, Flow> graphqlServerRequestMessageCB BiConsumer databaseConnectionCB BiFunction> databaseSqlQueryCB - BiFunction> networkConnectionCB + BiFunction> httpClientRequestCB + BiFunction> httpClientResponseCB + BiFunction> httpClientSamplingCB BiFunction> fileLoadedCB BiFunction> requestSessionCB BiFunction> execCmdCB @@ -474,7 +481,9 @@ class GatewayBridgeSpecification extends DDSpecification { 1 * ig.registerCallback(EVENTS.graphqlServerRequestMessage(), _) >> { graphqlServerRequestMessageCB = it[1]; null } 1 * ig.registerCallback(EVENTS.databaseConnection(), _) >> { databaseConnectionCB = it[1]; null } 1 * ig.registerCallback(EVENTS.databaseSqlQuery(), _) >> { databaseSqlQueryCB = it[1]; null } - 1 * ig.registerCallback(EVENTS.networkConnection(), _) >> { networkConnectionCB = it[1]; null } + 1 * ig.registerCallback(EVENTS.httpClientRequest(), _) >> { httpClientRequestCB = it[1]; null } + 1 * ig.registerCallback(EVENTS.httpClientResponse(), _) >> { httpClientResponseCB = it[1]; null } + 1 * ig.registerCallback(EVENTS.httpClientSampling(), _) >> { httpClientSamplingCB = it[1]; null } 1 * ig.registerCallback(EVENTS.fileLoaded(), _) >> { fileLoadedCB = it[1]; null } 1 * ig.registerCallback(EVENTS.requestSession(), _) >> { requestSessionCB = it[1]; null } 1 * ig.registerCallback(EVENTS.execCmd(), _) >> { execCmdCB = it[1]; null } @@ -844,24 +853,89 @@ class GatewayBridgeSpecification extends DDSpecification { gatewayContext.isRasp == true } - void 'process network connection URL'() { + void 'process http client request sampling'() { + setup: + eventDispatcher.getDataSubscribers({ KnownAddresses.IO_NET_URL in it }) >> nonEmptyDsInfo + + when: + Flow flow = httpClientSamplingCB.apply(ctx, 1L) + + then: + 1 * downstreamSampler.sampleHttpClientRequest(arCtx, 1L) >> { sampled } + flow.result == sampled + + where: + sampled << [true, false] + } + + void 'process http client request'() { setup: final url = 'https://www.datadoghq.com/' + final method = 'POST' + final headers = ['X-AppSec-TEst': ['true']] + final contentType = MediaType.parse('application/json') + final body = new ByteArrayInputStream('{"hello": "World!"}'.bytes) eventDispatcher.getDataSubscribers({ KnownAddresses.IO_NET_URL in it }) >> nonEmptyDsInfo DataBundle bundle GatewayContext gatewayContext when: - Flow flow = networkConnectionCB.apply(ctx, url) + final request = new HttpClientRequest(1L, url, method, headers) + request.setBody(contentType, body) + Flow flow = httpClientRequestCB.apply(ctx, request) then: + downstreamSampler.isSampled(arCtx, _ as long) >> { sampled } 1 * eventDispatcher.publishDataEvent(nonEmptyDsInfo, ctx.data, _ as DataBundle, _ as GatewayContext) >> { a, b, db, gw -> bundle = db; gatewayContext = gw; NoopFlow.INSTANCE } + bundle.size() == (sampled ? 4 : 3) bundle.get(KnownAddresses.IO_NET_URL) == url + bundle.get(KnownAddresses.IO_NET_REQUEST_METHOD) == method + bundle.get(KnownAddresses.IO_NET_REQUEST_HEADERS) == headers + if (sampled) { + bundle.get(KnownAddresses.IO_NET_REQUEST_BODY) == ['Hello': 'World!'] + } flow.result == null flow.action == Flow.Action.Noop.INSTANCE gatewayContext.isTransient == true gatewayContext.isRasp == true + + where: + sampled << [true, false] + } + + void 'process http client response'() { + setup: + final status = 200 + final headers = ['X-AppSec-TEst': ['true']] + final contentType = MediaType.parse('application/json') + final body = new ByteArrayInputStream('{"hello": "World!"}'.bytes) + eventDispatcher.getDataSubscribers({ KnownAddresses.IO_NET_RESPONSE_STATUS in it }) >> nonEmptyDsInfo + DataBundle bundle + GatewayContext gatewayContext + + when: + final response = new HttpClientResponse(1L, status, headers) + response.setBody(contentType, body) + Flow flow = httpClientResponseCB.apply(ctx, response) + + then: + downstreamSampler.isSampled(arCtx, _ as long) >> { sampled } + 1 * eventDispatcher.publishDataEvent(nonEmptyDsInfo, ctx.data, _ as DataBundle, _ as GatewayContext) >> + { a, b, db, gw -> bundle = db; gatewayContext = gw; NoopFlow.INSTANCE } + bundle.size() == (sampled ? 3 : 2) + bundle.get(KnownAddresses.IO_NET_RESPONSE_STATUS) == status + bundle.get(KnownAddresses.IO_NET_RESPONSE_HEADERS) == headers + if (sampled) { + bundle.get(KnownAddresses.IO_NET_RESPONSE_BODY) == ['Hello': 'World!'] + } + flow.result == null + flow.action == Flow.Action.Noop.INSTANCE + gatewayContext.isTransient == true + gatewayContext.isRasp == false + + where: + sampled << [true, false] } void 'process file loaded'() { diff --git a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/util/BodyParserSpecification.groovy b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/util/BodyParserSpecification.groovy new file mode 100644 index 00000000000..24bcaf12b8c --- /dev/null +++ b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/util/BodyParserSpecification.groovy @@ -0,0 +1,292 @@ +package com.datadog.appsec.util + +import com.datadog.appsec.ddwaf.WAFModule +import spock.lang.Specification + +import java.nio.charset.StandardCharsets + +class BodyParserSpecification extends Specification { + + void 'test parse simple JSON object'() { + given: + def parser = BodyParser.forJson() + def state = new BodyParser.State() + def json = '{"name":"John","age":30}' + def inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)) + + when: + def result = parser.parse(state, inputStream) + + then: + result instanceof Map + result.name == 'John' + result.age == 30.0 + !state.objectTooDeep + !state.listMapTooLarge + !state.stringTooLong + } + + void 'test parse simple JSON array'() { + given: + def parser = BodyParser.forJson() + def state = new BodyParser.State() + def json = '[1,2,3,"test"]' + def inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)) + + when: + def result = parser.parse(state, inputStream) + + then: + result instanceof List + result.size() == 4 + result[0] == 1.0 + result[1] == 2.0 + result[2] == 3.0 + result[3] == 'test' + !state.objectTooDeep + !state.listMapTooLarge + !state.stringTooLong + } + + void 'test parse JSON with various data types'() { + given: + def parser = BodyParser.forJson() + def state = new BodyParser.State() + def json = '{"string":"hello","number":42,"boolean":true,"null":null,"array":[1,2],"object":{"nested":"value"}}' + def inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)) + + when: + def result = parser.parse(state, inputStream) + + then: + result instanceof Map + result.string == 'hello' + result.number == 42.0 + result.boolean == true + result.null == null + result.array instanceof List + result.array.size() == 2 + result.object instanceof Map + result.object.nested == 'value' + } + + void 'test parse nested JSON object'() { + given: + def parser = BodyParser.forJson() + def state = new BodyParser.State() + def json = '{"level1":{"level2":{"level3":"deep"}}}' + def inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)) + + when: + def result = parser.parse(state, inputStream) + + then: + result instanceof Map + result.level1 instanceof Map + result.level1.level2 instanceof Map + result.level1.level2.level3 == 'deep' + !state.objectTooDeep + } + + void 'test parse empty JSON object'() { + given: + def parser = BodyParser.forJson() + def state = new BodyParser.State() + def json = '{}' + def inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)) + + when: + def result = parser.parse(state, inputStream) + + then: + result instanceof Map + result.isEmpty() + } + + void 'test parse empty JSON array'() { + given: + def parser = BodyParser.forJson() + def state = new BodyParser.State() + def json = '[]' + def inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)) + + when: + def result = parser.parse(state, inputStream) + + then: + result instanceof List + result.isEmpty() + } + + void 'test parse JSON with special characters'() { + given: + def parser = BodyParser.forJson() + def state = new BodyParser.State() + def json = '{"unicode":"\\u0048\\u0065\\u006c\\u006c\\u006f","escaped":"line1\\nline2\\ttab"}' + def inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)) + + when: + def result = parser.parse(state, inputStream) + + then: + result instanceof Map + result.unicode == 'Hello' + result.escaped == 'line1\nline2\ttab' + } + + void 'test IOException wrapped in RuntimeException'() { + given: + def parser = BodyParser.forJson() + def state = new BodyParser.State() + def inputStream = new InputStream() { + @Override + int read() throws IOException { + throw new IOException("Simulated IO error") + } + } + + when: + parser.parse(state, inputStream) + + then: + RuntimeException ex = thrown() + ex.cause instanceof IOException + ex.cause.message == "Simulated IO error" + } + + void 'test depth limit exceeded - objectTooDeep flag set'() { + given: + def parser = BodyParser.forJson() + def state = new BodyParser.State() + def deepJson = generateDeepNestedJson(WAFModule.MAX_DEPTH + 5) + def inputStream = new ByteArrayInputStream(deepJson.getBytes(StandardCharsets.UTF_8)) + + when: + parser.parse(state, inputStream) + + then: + state.objectTooDeep + } + + void 'test string length limit exceeded - stringTooLong flag set'() { + given: + def parser = BodyParser.forJson() + def state = new BodyParser.State() + def longString = "a" * (WAFModule.MAX_STRING_SIZE + 10) + def json = "{\"longString\":\"${longString}\"}" + def inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)) + + when: + def result = parser.parse(state, inputStream) + + then: + result instanceof Map + result.longString.length() <= WAFModule.MAX_STRING_SIZE + state.stringTooLong + } + + void 'test elements limit exceeded in object - listMapTooLarge flag set'() { + given: + def parser = BodyParser.forJson() + def state = new BodyParser.State() + def largeObjectJson = generateLargeObjectJson(WAFModule.MAX_ELEMENTS + 10) + def inputStream = new ByteArrayInputStream(largeObjectJson.getBytes(StandardCharsets.UTF_8)) + + when: + def result = parser.parse(state, inputStream) + + then: + result instanceof Map + result.size() <= WAFModule.MAX_ELEMENTS + state.listMapTooLarge + } + + void 'test elements limit exceeded in array - listMapTooLarge flag set'() { + given: + def parser = BodyParser.forJson() + def state = new BodyParser.State() + def largeArrayJson = generateLargeArrayJson(WAFModule.MAX_ELEMENTS + 10) + def inputStream = new ByteArrayInputStream(largeArrayJson.getBytes(StandardCharsets.UTF_8)) + + when: + def result = parser.parse(state, inputStream) + + then: + result instanceof List + result.size() <= WAFModule.MAX_ELEMENTS + state.listMapTooLarge + } + + void 'test mixed nested structure with limits'() { + given: + def parser = BodyParser.forJson() + def state = new BodyParser.State() + def complexJson = '{"array":[1,2,3],"object":{"nested":{"deep":"value"}},"string":"normal"}' + def inputStream = new ByteArrayInputStream(complexJson.getBytes(StandardCharsets.UTF_8)) + + when: + def result = parser.parse(state, inputStream) + + then: + result instanceof Map + result.array instanceof List + result.object instanceof Map + result.string == 'normal' + !state.objectTooDeep + !state.listMapTooLarge + !state.stringTooLong + } + + void 'test invalid JSON throws JsonDataException'() { + given: + def parser = BodyParser.forJson() + def state = new BodyParser.State() + def invalidJson = '{"invalid":}' + def inputStream = new ByteArrayInputStream(invalidJson.getBytes(StandardCharsets.UTF_8)) + + when: + parser.parse(state, inputStream) + + then: + RuntimeException ex = thrown() + ex.cause instanceof Exception + } + + private static String generateDeepNestedJson(int depth) { + def sb = new StringBuilder() + for (int i = 0; i < depth; i++) { + sb.append('{"level').append(i).append('":') + } + sb.append('"deep"') + for (int i = 0; i < depth; i++) { + sb.append('}') + } + return sb.toString() + } + + private static String generateLargeObjectJson(int size) { + def sb = new StringBuilder() + sb.append('{') + for (int i = 0; i < size; i++) { + if (i > 0) { + sb.append(',') + } + sb.append('"key').append(i).append('":"value').append(i).append('"') + } + sb.append('}') + return sb.toString() + } + + private static String generateLargeArrayJson(int size) { + def sb = new StringBuilder() + sb.append('[') + for (int i = 0; i < size; i++) { + if (i > 0) { + sb.append(',') + } + sb.append(i) + } + sb.append(']') + return sb.toString() + } +} diff --git a/dd-java-agent/instrumentation-testing/src/main/groovy/datadog/trace/agent/test/base/HttpClientTest.groovy b/dd-java-agent/instrumentation-testing/src/main/groovy/datadog/trace/agent/test/base/HttpClientTest.groovy index 57de7a2b776..952261bcd17 100644 --- a/dd-java-agent/instrumentation-testing/src/main/groovy/datadog/trace/agent/test/base/HttpClientTest.groovy +++ b/dd-java-agent/instrumentation-testing/src/main/groovy/datadog/trace/agent/test/base/HttpClientTest.groovy @@ -7,13 +7,23 @@ import datadog.trace.agent.test.naming.VersionedNamingTestBase import datadog.trace.agent.test.server.http.HttpProxy import datadog.trace.api.DDSpanTypes import datadog.trace.api.DDTags +import datadog.trace.api.appsec.HttpClientRequest +import datadog.trace.api.appsec.HttpClientResponse import datadog.trace.api.config.TracerConfig import datadog.trace.api.datastreams.DataStreamsContext +import datadog.trace.api.gateway.Events +import datadog.trace.api.gateway.Flow +import datadog.trace.api.gateway.RequestContext +import datadog.trace.api.gateway.RequestContextSlot +import datadog.trace.bootstrap.instrumentation.api.AgentTracer +import datadog.trace.bootstrap.instrumentation.api.TagContext import datadog.trace.bootstrap.instrumentation.api.Tags import datadog.trace.bootstrap.instrumentation.api.URIUtils import datadog.trace.core.DDSpan import datadog.trace.core.datastreams.StatsGroup import datadog.trace.test.util.Flaky +import groovy.json.JsonOutput +import groovy.json.JsonSlurper import spock.lang.AutoCleanup import spock.lang.IgnoreIf import spock.lang.Requires @@ -21,6 +31,7 @@ import spock.lang.Shared import java.util.concurrent.ExecutionException import java.util.concurrent.TimeUnit +import java.util.function.BiFunction import static datadog.trace.agent.test.server.http.TestHttpServer.httpServer import static datadog.trace.agent.test.utils.PortUtils.UNUSABLE_PORT @@ -28,7 +39,11 @@ import static datadog.trace.agent.test.utils.TraceUtils.basicSpan import static datadog.trace.agent.test.utils.TraceUtils.runUnderTrace import static datadog.trace.api.config.TraceInstrumentationConfig.HTTP_CLIENT_HOST_SPLIT_BY_DOMAIN import static datadog.trace.api.config.TraceInstrumentationConfig.HTTP_CLIENT_TAG_QUERY_STRING -import static datadog.trace.api.config.TracerConfig.* +import static datadog.trace.api.config.TracerConfig.HEADER_TAGS +import static datadog.trace.api.config.TracerConfig.REQUEST_HEADER_TAGS +import static datadog.trace.api.config.TracerConfig.RESPONSE_HEADER_TAGS +import static datadog.trace.bootstrap.instrumentation.api.AgentTracer.activeSpan +import static datadog.trace.bootstrap.instrumentation.api.AgentTracer.get abstract class HttpClientTest extends VersionedNamingTestBase { protected static final BODY_METHODS = ["POST", "PUT"] @@ -80,13 +95,24 @@ abstract class HttpClientTest extends VersionedNamingTestBase { handleDistributedRequest() String msg = "Hello." response.status(200) - .addHeader('x-datadog-test-response-header', 'baz') - .send(msg) + .addHeader('x-datadog-test-response-header', 'baz') + .send(msg) } prefix("/timeout") { Thread.sleep(10_000) throw new IllegalStateException("Should never happen") } + prefix("/json") { + if (request.getContentType() != 'application/json') { + response.status(400).send('Bad content type') + } else { + response + .status(200) + .addHeader('Content-Type', 'application/json') + .addHeader('X-AppSec-Test', 'true') + .sendWithType('application/json', request.body) + } + } } } @@ -120,19 +146,27 @@ abstract class HttpClientTest extends VersionedNamingTestBase { def setupSpec() { List proxyList = Collections.singletonList(new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxy.port))) proxySelector = new ProxySelector() { - @Override - List select(URI uri) { - if (uri.fragment == "proxy") { - return proxyList - } - return getDefault().select(uri) + @Override + List select(URI uri) { + if (uri.fragment == "proxy") { + return proxyList } + return getDefault().select(uri) + } - @Override - void connectFailed(URI uri, SocketAddress sa, IOException ioe) { - getDefault().connectFailed(uri, sa, ioe) - } + @Override + void connectFailed(URI uri, SocketAddress sa, IOException ioe) { + getDefault().connectFailed(uri, sa, ioe) } + } + + // Register the Instrumentation Gateway callbacks + def ss = get().getSubscriptionService(RequestContextSlot.APPSEC) + def callbacks = new IGCallbacks() + Events events = Events.get() + ss.registerCallback(events.httpClientRequest(), callbacks.httpClientRequestCb) + ss.registerCallback(events.httpClientResponse(), callbacks.httpClientResponseCb) + ss.registerCallback(events.httpClientSampling(), callbacks.httpClientBodySamplingCb) } /** @@ -174,7 +208,9 @@ abstract class HttpClientTest extends VersionedNamingTestBase { } and: if (isDataStreamsEnabled()) { - StatsGroup first = TEST_DATA_STREAMS_WRITER.groups.find { it.parentHash == 0 } + StatsGroup first = TEST_DATA_STREAMS_WRITER.groups.find { + it.parentHash == 0 + } verifyAll(first) { getTags() == DSM_EDGE_TAGS } @@ -810,19 +846,58 @@ abstract class HttpClientTest extends VersionedNamingTestBase { 'GET' | 'X-Datadog-Test-Response-Header' | 'response_header_tag' | [ 'response_header_tag': 'baz' ] } + + @IgnoreIf({ !instance.testAppSecAnalysis() }) + void 'test appsec client request analysis'() { + given: + final url = server.address.resolve(endpoint) + final tags = [ + 'downstream.request.url': url.toString(), + 'downstream.request.method': method, + 'downstream.request.body': body, + 'downstream.response.status': 200, + 'downstream.response.body': body, + ] + + when: + final status = runUnderAppSecTrace { + doRequest(method, url, ['Content-Type': contentType] + headers, body) { InputStream response -> + assert response.text == body + } + } + + then: + status == 200 + TEST_WRITER.waitForTraces(1) + final span = TEST_WRITER.get(0).find { it.spanType == 'http'} + tags.each { + assert span.getTag(it.key) == it.value + } + final requestHeaders = new JsonSlurper().parseText(span.getTag("downstream.request.headers") as String) as Map> + final responseHeaders = new JsonSlurper().parseText(span.getTag("downstream.response.headers") as String) as Map> + headers.each { + assert requestHeaders[it.key] == [it.value] + assert responseHeaders[it.key] == [it.value] + } + + where: + endpoint | method | contentType | headers | body + '/json' | 'POST' | 'application/json' | ['X-AppSec-Test': 'true'] | '{"hello": "world!" }' + } + // parent span must be cast otherwise it breaks debugging classloading (junit loads it early) void clientSpan( - TraceAssert trace, - Object parentSpan, - String method = "GET", - boolean renameService = false, - boolean tagQueryString = false, - URI uri = server.address.resolve("/success"), - Integer status = 200, - boolean error = false, - Throwable exception = null, - boolean ignorePeer = false, - Map extraTags = null) { + TraceAssert trace, + Object parentSpan, + String method = "GET", + boolean renameService = false, + boolean tagQueryString = false, + URI uri = server.address.resolve("/success"), + Integer status = 200, + boolean error = false, + Throwable exception = null, + boolean ignorePeer = false, + Map extraTags = null) { def expectedQuery = tagQueryString ? uri.query : null def expectedUrl = URIUtils.buildURL(uri.scheme, uri.host, uri.port, uri.path) @@ -916,4 +991,59 @@ abstract class HttpClientTest extends VersionedNamingTestBase { // function is used. There is no way to stop a test from a derived class hence the flag true } + + boolean testAppSecAnalysis() { + false + } + + protected E runUnderAppSecTrace(Closure cl) { + final ddctx = new TagContext().withRequestContextDataAppSec(new IGCallbacks.Context()) + final span = TEST_TRACER.startSpan("test", "test-appsec-span", ddctx) + try { + return AgentTracer.activateSpan(span).withCloseable(cl) + } finally { + span.finish() + } + } + + static class IGCallbacks { + + static class Context { + boolean hasAppSecData + } + + final BiFunction> httpClientBodySamplingCb = + { RequestContext rqCtxt, final long requestId -> + return new Flow.ResultFlow<>(true) + } as BiFunction> + + final BiFunction> httpClientRequestCb = + { RequestContext rqCtxt, HttpClientRequest req -> + if (req.headers?.containsKey('X-AppSec-Test')) { + final context = rqCtxt.getData(RequestContextSlot.APPSEC) as Context + context.hasAppSecData = true + activeSpan() + .setTag('downstream.request.url', req.url) + .setTag('downstream.request.method', req.method) + .setTag('downstream.request.headers', JsonOutput.toJson(req.headers)) + .setTag('downstream.request.body', req.body?.text) + + } + Flow.ResultFlow.empty() + } as BiFunction> + + final BiFunction> httpClientResponseCb = + { RequestContext rqCtxt, HttpClientResponse res -> + final context = rqCtxt.getData(RequestContextSlot.APPSEC) as Context + if (context.hasAppSecData) { + activeSpan() + .setTag('downstream.response.status', res.status) + .setTag('downstream.response.headers', JsonOutput.toJson(res.headers)) + .setTag('downstream.response.body', res.body?.text) + } + Flow.ResultFlow.empty() + } as BiFunction> + + + } } diff --git a/dd-java-agent/instrumentation/java-net/src/main/java/datadog/trace/instrumentation/java/net/URLSinkCallSite.java b/dd-java-agent/instrumentation/java-net/src/main/java/datadog/trace/instrumentation/java/net/URLSinkCallSite.java index d275e303bbb..437c2a50e2c 100644 --- a/dd-java-agent/instrumentation/java-net/src/main/java/datadog/trace/instrumentation/java/net/URLSinkCallSite.java +++ b/dd-java-agent/instrumentation/java-net/src/main/java/datadog/trace/instrumentation/java/net/URLSinkCallSite.java @@ -5,6 +5,7 @@ import datadog.appsec.api.blocking.BlockingException; import datadog.trace.agent.tooling.csi.CallSite; import datadog.trace.api.Config; +import datadog.trace.api.appsec.HttpClientRequest; import datadog.trace.api.appsec.RaspCallSites; import datadog.trace.api.gateway.BlockResponseFunction; import datadog.trace.api.gateway.Flow; @@ -59,11 +60,11 @@ private static void raspCallback(@Nonnull final URL url) { } try { - final BiFunction> networkConnectionCallback = + final BiFunction> httpClientRequestCb = AgentTracer.get() .getCallbackProvider(RequestContextSlot.APPSEC) - .getCallback(EVENTS.networkConnection()); - if (networkConnectionCallback == null) { + .getCallback(EVENTS.httpClientRequest()); + if (httpClientRequestCb == null) { return; } @@ -77,7 +78,8 @@ private static void raspCallback(@Nonnull final URL url) { return; } - Flow flow = networkConnectionCallback.apply(ctx, url.toString()); + Flow flow = + httpClientRequestCb.apply(ctx, new HttpClientRequest(span.getSpanId(), url.toString())); Flow.Action action = flow.getAction(); if (action instanceof Flow.Action.RequestBlockingAction) { BlockResponseFunction brf = ctx.getBlockResponseFunction(); diff --git a/dd-java-agent/instrumentation/java-net/src/test/groovy/datadog/trace/instrumentation/java/net/URLSinkCallSiteTest.groovy b/dd-java-agent/instrumentation/java-net/src/test/groovy/datadog/trace/instrumentation/java/net/URLSinkCallSiteTest.groovy index 3be4e1c68e5..64e8f2b6e97 100644 --- a/dd-java-agent/instrumentation/java-net/src/test/groovy/datadog/trace/instrumentation/java/net/URLSinkCallSiteTest.groovy +++ b/dd-java-agent/instrumentation/java-net/src/test/groovy/datadog/trace/instrumentation/java/net/URLSinkCallSiteTest.groovy @@ -1,6 +1,7 @@ package datadog.trace.instrumentation.java.net import datadog.trace.agent.test.InstrumentationSpecification +import datadog.trace.api.appsec.HttpClientRequest import datadog.trace.api.config.AppSecConfig import datadog.trace.api.config.IastConfig import datadog.trace.api.gateway.CallbackProvider @@ -85,8 +86,8 @@ class URLSinkCallSiteTest extends InstrumentationSpecification { TestURLCallSiteSuite.&"$method".call(args as Object[]) then: - 1 * callbackProvider.getCallback(EVENTS.networkConnection()) >> listener - 1 * listener.apply(reqCtx, URL.toString()) + 1 * callbackProvider.getCallback(EVENTS.httpClientRequest()) >> listener + 1 * listener.apply(reqCtx, _ as HttpClientRequest) where: suite << tests() diff --git a/dd-java-agent/instrumentation/okhttp-2/src/main/java/datadog/trace/instrumentation/okhttp2/AppSecInterceptor.java b/dd-java-agent/instrumentation/okhttp-2/src/main/java/datadog/trace/instrumentation/okhttp2/AppSecInterceptor.java new file mode 100644 index 00000000000..a9f82217b85 --- /dev/null +++ b/dd-java-agent/instrumentation/okhttp-2/src/main/java/datadog/trace/instrumentation/okhttp2/AppSecInterceptor.java @@ -0,0 +1,221 @@ +package datadog.trace.instrumentation.okhttp2; + +import static datadog.trace.api.gateway.Events.EVENTS; + +import com.squareup.okhttp.Headers; +import com.squareup.okhttp.Interceptor; +import com.squareup.okhttp.Request; +import com.squareup.okhttp.RequestBody; +import com.squareup.okhttp.Response; +import com.squareup.okhttp.ResponseBody; +import datadog.appsec.api.blocking.BlockingException; +import datadog.trace.api.Config; +import datadog.trace.api.appsec.HttpClientPayload; +import datadog.trace.api.appsec.HttpClientRequest; +import datadog.trace.api.appsec.HttpClientResponse; +import datadog.trace.api.appsec.MediaType; +import datadog.trace.api.gateway.BlockResponseFunction; +import datadog.trace.api.gateway.CallbackProvider; +import datadog.trace.api.gateway.Flow; +import datadog.trace.api.gateway.RequestContext; +import datadog.trace.api.gateway.RequestContextSlot; +import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.api.AgentTracer; +import datadog.trace.bootstrap.instrumentation.api.Tags; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import okio.BufferedSink; +import okio.BufferedSource; +import okio.Okio; +import okio.Sink; + +public class AppSecInterceptor implements Interceptor { + + private static final int BODY_PARSING_SIZE_LIMIT = Config.get().getAppSecBodyParsingSizeLimit(); + + @Override + public Response intercept(final Chain chain) throws IOException { + final AgentSpan span = AgentTracer.activeSpan(); + final RequestContext ctx = span.getRequestContext(); + final long requestId = span.getSpanId(); + final boolean sampled = sampleRequest(ctx, requestId); + final Request request = onRequest(span, sampled, chain.request()); + final Response response = chain.proceed(request); + return onResponse(span, sampled, response); + } + + private Request onRequest(final AgentSpan span, final boolean sampled, final Request request) { + Request result = request; + CallbackProvider cbp = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC); + BiFunction> requestCb = + cbp.getCallback(EVENTS.httpClientRequest()); + if (requestCb == null) { + return request; + } + + final RequestBody requestBody = request.body(); + final RequestContext ctx = span.getRequestContext(); + final long requestId = span.getSpanId(); + final String url = span.getTag(Tags.HTTP_URL).toString(); + final HttpClientRequest clientRequest = + new HttpClientRequest(requestId, url, request.method(), mapHeaders(request.headers())); + if (sampled && requestBody != null) { + // we are going to effectively read all the request body in memory to be analyzed by the WAF, + // we also + // modify the outbound request accordingly + final MediaType mediaType = contentType(requestBody); + try { + final long contentLength = requestBody.contentLength(); + if (shouldProcessBody(contentLength, mediaType)) { + final byte[] payload = readBody(requestBody, (int) contentLength); + if (payload.length <= BODY_PARSING_SIZE_LIMIT) { + clientRequest.setBody(mediaType, new ByteArrayInputStream(payload)); + } + result = + request + .newBuilder() + .method(request.method(), RequestBody.create(requestBody.contentType(), payload)) + .build(); // update request + } + } catch (IOException e) { + // ignore it and keep the original request + } + } + publish(ctx, clientRequest, requestCb); + return result; + } + + private Response onResponse( + final AgentSpan span, final boolean sampled, final Response response) { + Response result = response; + CallbackProvider cbp = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC); + BiFunction> responseCb = + cbp.getCallback(EVENTS.httpClientResponse()); + if (responseCb == null) { + return response; + } + final ResponseBody responseBody = response.body(); + final RequestContext ctx = span.getRequestContext(); + final long requestId = span.getSpanId(); + final HttpClientResponse clientResponse = + new HttpClientResponse(requestId, response.code(), mapHeaders(response.headers())); + if (sampled && responseBody != null) { + // we are going to effectively read all the response body in memory to be analyzed by the WAF, + // we also + // modify the inbound response accordingly + final MediaType mediaType = contentType(responseBody); + try { + final long contentLength = responseBody.contentLength(); + if (shouldProcessBody(contentLength, mediaType)) { + final byte[] payload = readBody(responseBody, (int) contentLength); + if (payload.length <= BODY_PARSING_SIZE_LIMIT) { + clientResponse.setBody(mediaType, new ByteArrayInputStream(payload)); + } + result = + response + .newBuilder() + .body(ResponseBody.create(responseBody.contentType(), payload)) + .build(); + } + } catch (IOException e) { + // ignore it and keep the original response + } + } + + publish(ctx, clientResponse, responseCb); + return result; + } + + private

void publish( + final RequestContext ctx, + final P request, + final BiFunction> callback) { + Flow flow = callback.apply(ctx, request); + Flow.Action action = flow.getAction(); + if (action instanceof Flow.Action.RequestBlockingAction) { + BlockResponseFunction brf = ctx.getBlockResponseFunction(); + if (brf != null) { + Flow.Action.RequestBlockingAction rba = (Flow.Action.RequestBlockingAction) action; + brf.tryCommitBlockingResponse( + ctx.getTraceSegment(), + rba.getStatusCode(), + rba.getBlockingContentType(), + rba.getExtraHeaders()); + } + throw new BlockingException("Blocked request (for http downstream request)"); + } + } + + private boolean sampleRequest(final RequestContext ctx, final long requestId) { + // Check if the current http request was sampled + CallbackProvider cbp = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC); + BiFunction> samplingCb = + cbp.getCallback(EVENTS.httpClientSampling()); + if (samplingCb == null) { + return false; + } + final Flow sampled = samplingCb.apply(ctx, requestId); + return sampled.getResult() != null && sampled.getResult(); + } + + /** + * Ensure we are only consuming payloads we can safely deserialize with a bounded size to prevent + * from OOM + */ + private boolean shouldProcessBody(final long contentLength, final MediaType mediaType) { + if (contentLength <= 0) { + return false; // prevent from copying from unbounded source (just to be safe) + } + if (BODY_PARSING_SIZE_LIMIT <= 0) { + return false; // effectively disabled by configuration + } + if (contentLength > BODY_PARSING_SIZE_LIMIT) { + return false; + } + return mediaType.isDeserializable(); + } + + private byte[] readBody(final RequestBody body, final int contentLength) throws IOException { + final ByteArrayOutputStream buffer = new ByteArrayOutputStream(contentLength); + try (final BufferedSink sink = Okio.buffer(Okio.sink(buffer))) { + body.writeTo(sink); + } + return buffer.toByteArray(); + } + + private byte[] readBody(final ResponseBody body, final int contentLength) throws IOException { + final ByteArrayOutputStream buffer = new ByteArrayOutputStream(contentLength); + try (final BufferedSource source = body.source(); + final Sink sink = Okio.sink(buffer)) { + source.readAll(sink); + } + return buffer.toByteArray(); + } + + private Map> mapHeaders(final Headers headers) { + if (headers == null) { + return Collections.emptyMap(); + } + final Map> result = new HashMap<>(headers.size()); + for (final String name : headers.names()) { + result.put(name, headers.values(name)); + } + return result; + } + + private MediaType contentType(final RequestBody body) { + return MediaType.parse( + body == null || body.contentType() == null ? null : body.contentType().toString()); + } + + private MediaType contentType(final ResponseBody body) { + return MediaType.parse( + body == null || body.contentType() == null ? null : body.contentType().toString()); + } +} diff --git a/dd-java-agent/instrumentation/okhttp-2/src/main/java/datadog/trace/instrumentation/okhttp2/OkHttp2Instrumentation.java b/dd-java-agent/instrumentation/okhttp-2/src/main/java/datadog/trace/instrumentation/okhttp2/OkHttp2Instrumentation.java index fb357b594f4..d08f6c4c628 100644 --- a/dd-java-agent/instrumentation/okhttp-2/src/main/java/datadog/trace/instrumentation/okhttp2/OkHttp2Instrumentation.java +++ b/dd-java-agent/instrumentation/okhttp-2/src/main/java/datadog/trace/instrumentation/okhttp2/OkHttp2Instrumentation.java @@ -7,6 +7,7 @@ import com.squareup.okhttp.OkHttpClient; import datadog.trace.agent.tooling.Instrumenter; import datadog.trace.agent.tooling.InstrumenterModule; +import datadog.trace.bootstrap.ActiveSubsystems; import net.bytebuddy.asm.Advice; @AutoService(InstrumenterModule.class) @@ -27,6 +28,7 @@ public String[] helperClassNames() { packageName + ".RequestBuilderInjectAdapter", packageName + ".OkHttpClientDecorator", packageName + ".TracingInterceptor", + packageName + ".AppSecInterceptor", }; } @@ -44,8 +46,10 @@ public static void addTracingInterceptor(@Advice.This final OkHttpClient client) return; } } - client.interceptors().add(new TracingInterceptor()); + if (ActiveSubsystems.APPSEC_ACTIVE) { + client.interceptors().add(new AppSecInterceptor()); + } } } } diff --git a/dd-java-agent/instrumentation/okhttp-2/src/main/java/datadog/trace/instrumentation/okhttp2/OkHttpClientDecorator.java b/dd-java-agent/instrumentation/okhttp-2/src/main/java/datadog/trace/instrumentation/okhttp2/OkHttpClientDecorator.java index 3af5209dfa6..7f128bbbff9 100644 --- a/dd-java-agent/instrumentation/okhttp-2/src/main/java/datadog/trace/instrumentation/okhttp2/OkHttpClientDecorator.java +++ b/dd-java-agent/instrumentation/okhttp-2/src/main/java/datadog/trace/instrumentation/okhttp2/OkHttpClientDecorator.java @@ -2,6 +2,7 @@ import com.squareup.okhttp.Request; import com.squareup.okhttp.Response; +import datadog.trace.bootstrap.instrumentation.api.AgentSpan; import datadog.trace.bootstrap.instrumentation.api.UTF8BytesString; import datadog.trace.bootstrap.instrumentation.decorator.HttpClientDecorator; import java.net.URI; @@ -53,4 +54,10 @@ protected String getRequestHeader(Request request, String headerName) { protected String getResponseHeader(Response response, String headerName) { return response.header(headerName); } + + /** Overridden by {@link AppSecInterceptor} */ + @Override + protected void onHttpClientRequest(AgentSpan span, String url) { + // do nothing + } } diff --git a/dd-java-agent/instrumentation/okhttp-2/src/test/groovy/OkHttp2AsyncTest.groovy b/dd-java-agent/instrumentation/okhttp-2/src/test/groovy/OkHttp2AsyncTest.groovy index 246d842023c..7d0299581b6 100644 --- a/dd-java-agent/instrumentation/okhttp-2/src/test/groovy/OkHttp2AsyncTest.groovy +++ b/dd-java-agent/instrumentation/okhttp-2/src/test/groovy/OkHttp2AsyncTest.groovy @@ -19,7 +19,8 @@ abstract class OkHttp2AsyncTest extends OkHttp2Test { @Override int doRequest(String method, URI uri, Map headers, String body, Closure callback) { - def reqBody = HttpMethod.requiresRequestBody(method) ? RequestBody.create(MediaType.parse("text/plain"), body) : null + final contentType = headers.remove("Content-Type") + def reqBody = HttpMethod.requiresRequestBody(method) ? RequestBody.create(MediaType.parse(contentType ?: "text/plain"), body) : null def request = new Request.Builder() .url(uri.toURL()) .method(method, reqBody) @@ -33,13 +34,13 @@ abstract class OkHttp2AsyncTest extends OkHttp2Test { client.newCall(request).enqueue(new Callback() { void onResponse(Response response) { responseRef.set(response) - callback?.call() + callback?.call(response.body().byteStream()) latch.countDown() } void onFailure(Request req, IOException e) { exRef.set(e) - callback?.call() + callback?.call(e) latch.countDown() } }) diff --git a/dd-java-agent/instrumentation/okhttp-2/src/test/groovy/OkHttp2Test.groovy b/dd-java-agent/instrumentation/okhttp-2/src/test/groovy/OkHttp2Test.groovy index 5476112fa10..3ba2cb14d2d 100644 --- a/dd-java-agent/instrumentation/okhttp-2/src/test/groovy/OkHttp2Test.groovy +++ b/dd-java-agent/instrumentation/okhttp-2/src/test/groovy/OkHttp2Test.groovy @@ -26,7 +26,8 @@ abstract class OkHttp2Test extends HttpClientTest { @Override int doRequest(String method, URI uri, Map headers, String body, Closure callback) { - def reqBody = HttpMethod.requiresRequestBody(method) ? RequestBody.create(MediaType.parse("text/plain"), body) : null + final contentType = headers.remove("Content-Type") + def reqBody = HttpMethod.requiresRequestBody(method) ? RequestBody.create(MediaType.parse(contentType ?: "text/plain"), body) : null def request = new Request.Builder() .url(uri.toURL()) @@ -34,7 +35,7 @@ abstract class OkHttp2Test extends HttpClientTest { .headers(Headers.of(HeadersUtil.headersToArray(headers))) .build() def response = client.newCall(request).execute() - callback?.call() + callback?.call(response.body().byteStream()) return response.code() } @@ -47,6 +48,11 @@ abstract class OkHttp2Test extends HttpClientTest { boolean testRedirects() { false } + + @Override + boolean testAppSecAnalysis() { + true + } } @Timeout(5) diff --git a/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/SpringbootApplication.java b/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/SpringbootApplication.java index 26e59cc0dfa..7c115bd857a 100644 --- a/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/SpringbootApplication.java +++ b/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/SpringbootApplication.java @@ -25,8 +25,9 @@ private static void activateAppSec() throws Exception { Field appSecClassLoaderField = agentClass.getDeclaredField("AGENT_CLASSLOADER"); appSecClassLoaderField.setAccessible(true); ClassLoader appSecClassLoader = (ClassLoader) appSecClassLoaderField.get(null); - Class appSecSystemClass = appSecClassLoader.loadClass("com.datadog.appsec.AppSecSystem"); - Field activeField = appSecSystemClass.getField("ACTIVE"); + Class appSecSystemClass = + appSecClassLoader.loadClass("datadog.trace.bootstrap.ActiveSubsystems"); + Field activeField = appSecSystemClass.getField("APPSEC_ACTIVE"); boolean curActiveValue = (boolean) activeField.get(null); if (curActiveValue) { System.out.println("AppSec is already active"); diff --git a/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/WebController.java b/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/WebController.java index 791bd30c29c..7667afcc29d 100644 --- a/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/WebController.java +++ b/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/WebController.java @@ -1,10 +1,18 @@ package datadog.smoketest.appsec.springboot.controller; +import static org.springframework.web.bind.annotation.RequestMethod.GET; +import static org.springframework.web.bind.annotation.RequestMethod.POST; +import static org.springframework.web.bind.annotation.RequestMethod.PUT; + import com.fasterxml.jackson.databind.JsonNode; import com.squareup.okhttp.OkHttpClient; import com.squareup.okhttp.Request; +import com.squareup.okhttp.Response; import datadog.smoketest.appsec.springboot.service.AsyncService; +import java.io.ByteArrayOutputStream; import java.io.File; +import java.io.IOException; +import java.io.InputStream; import java.net.URL; import java.nio.file.Paths; import java.sql.Connection; @@ -30,6 +38,7 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.servlet.support.ServletUriComponentsBuilder; @RestController public class WebController { @@ -248,6 +257,74 @@ public ResponseEntity> apiSecurityResponse( return ResponseEntity.ok(body); } + @RequestMapping( + value = "/api_security/http_client/okHttp2", + method = {POST, GET, PUT}) + public ResponseEntity apiSecurityHttpClientOkHttp2(final HttpServletRequest request) + throws IOException { + // create an internal http request to the echo endpoint to validate the http client library + final String url = + ServletUriComponentsBuilder.fromRequestUri(request) + .replacePath("/echo") + .build() + .toUriString(); + Request.Builder clientRequest = new Request.Builder().url(url); + if (request.getMethod().equalsIgnoreCase("POST")) { + final String contentType = request.getContentType(); + final byte[] data = readFully(request.getInputStream()); + clientRequest = + clientRequest.post( + com.squareup.okhttp.RequestBody.create( + com.squareup.okhttp.MediaType.parse(contentType), data)); + } else { + clientRequest.method(request.getMethod(), null); + } + final String statusCode = request.getHeader("Status"); + if (statusCode != null) { + clientRequest = clientRequest.header("Status", statusCode); + } + final String witness = request.getHeader("Witness"); + if (witness != null) { + clientRequest = clientRequest.header("Witness", witness); + } + final String echoHeaders = request.getHeader("echo-headers"); + if (echoHeaders != null) { + clientRequest = clientRequest.header("echo-headers", echoHeaders); + } + final Response clientResponse = new OkHttpClient().newCall(clientRequest.build()).execute(); + return ResponseEntity.status(200).body(clientResponse.body().string()); + } + + @RequestMapping( + value = "/echo", + method = {POST, GET, PUT}) + public ResponseEntity echo(final HttpServletRequest request) throws IOException { + final String statusHeader = request.getHeader("Status"); + final int statusCode = statusHeader == null ? 200 : Integer.parseInt(statusHeader); + ResponseEntity.BodyBuilder response = ResponseEntity.status(statusCode); + final String echoHeaders = request.getHeader("echo-headers"); + if (echoHeaders != null) { + response = response.header("echo-headers", echoHeaders); + } + if (request.getMethod().equalsIgnoreCase("POST")) { + final String contentType = request.getContentType(); + final byte[] data = readFully(request.getInputStream()); + return response.contentType(MediaType.parseMediaType(contentType)).body(new String(data)); + } else { + return response.body("OK"); + } + } + + private static byte[] readFully(final InputStream in) throws IOException { + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + byte[] data = new byte[4096]; // 4KB buffer + int bytesRead; + while ((bytesRead = in.read(data, 0, data.length)) != -1) { + buffer.write(data, 0, bytesRead); + } + return buffer.toByteArray(); + } + private void withProcess(final Operation op) { Process process = null; try { diff --git a/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/SpringBootSmokeTest.groovy b/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/SpringBootSmokeTest.groovy index 5514ec7dc51..c89031a3e8f 100644 --- a/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/SpringBootSmokeTest.groovy +++ b/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/SpringBootSmokeTest.groovy @@ -207,7 +207,196 @@ class SpringBootSmokeTest extends AbstractAppSecServerSmokeTest { ], transformers: [], on_match : ['block'] - ] + ], + [ + id : "apiA-100-001", + name: "API 10 tag rule on request headers", + tags: [ + type : "api10 request headers", + category: "attack_attempt" + ], + conditions: [ + [ + parameters: [ + inputs: [ + [ + address : "server.io.net.request.headers", + key_path: ["Witness"] + ] + ], + list: ["pwq3ojtropiw3hjtowir"] + ], + operator: "exact_match" + ] + ], + output: [ + event: true, + keep : true, + attributes: [ + "_dd.appsec.trace.req_headers": [ + value: "TAG_API10_REQ_HEADERS" + ] + ] + ], + on_match: [] + ], + [ + id : "apiA-100-002", + name: "API 10 tag rule on request body", + tags: [ + type : "api10 request body", + category: "attack_attempt" + ], + conditions: [ + [ + parameters: [ + inputs: [ + [ + address : "server.io.net.request.body", + key_path: ["payload_in"] + ] + ], + list: ["qw2jedrkjerbgol23ewpfirj2qw3or"] + ], + operator: "exact_match" + ] + ], + output: [ + event: true, + keep : true, + attributes: [ + "_dd.appsec.trace.req_body": [ + value: "TAG_API10_REQ_BODY" + ] + ] + ], + on_match: [] + ], + [ + id : "apiA-100-003", + name: "API 10 tag rule on request method", + tags: [ + type : "api10 request method", + category: "attack_attempt" + ], + conditions: [ + [ + parameters: [ + inputs: [ + [ + address: "server.io.net.request.method" + ] + ], + list: ["PUT"] + ], + operator: "exact_match" + ] + ], + output: [ + event: true, + keep : true, + attributes: [ + "_dd.appsec.trace.req_method": [ + value: "TAG_API10_REQ_METHOD" + ] + ] + ], + on_match: [] + ], + [ + id : "apiA-100-004", + name: "API 10 tag rule on response status", + tags: [ + type : "api10 response status", + category: "attack_attempt" + ], + conditions: [ + [ + parameters: [ + inputs: [ + [ + address: "server.io.net.response.status" + ] + ], + list: [201] + ], + operator: "exact_match" + ] + ], + output: [ + event: true, + keep : true, + attributes: [ + "_dd.appsec.trace.res_status": [ + value: "TAG_API10_RES_STATUS" + ] + ] + ], + on_match: [] + ], + [ + id : "apiA-100-005", + name: "API 10 tag rule on response headers", + tags: [ + type : "api10 response headers", + category: "attack_attempt" + ], + conditions: [ + [ + parameters: [ + inputs: [ + [ + address : "server.io.net.response.headers", + key_path: ["echo-headers"] + ] + ], + list: ["qwoierj12l3"] + ], + operator: "exact_match" + ] + ], + output: [ + event: true, + keep : true, + attributes: [ + "_dd.appsec.trace.res_headers": [ + value: "TAG_API10_RES_HEADERS" + ] + ] + ], + on_match: [] + ], + [ + id : "apiA-100-006", + name: "API 10 tag rule on response body", + tags: [ + type : "api10 reponse body", + category: "attack_attempt" + ], + conditions: [ + [ + parameters: [ + inputs: [ + [ + address: "server.io.net.response.body" + ] + ], + list: ["kqehf09123r4lnksef"] + ], + operator: "exact_match" + ] + ], + output: [ + event: true, + keep : true, + attributes: [ + "_dd.appsec.trace.res_body": [ + value: "TAG_API10_RES_BODY" + ] + ] + ], + on_match: [] + ], ]) } @@ -221,6 +410,7 @@ class SpringBootSmokeTest extends AbstractAppSecServerSmokeTest { List command = new ArrayList<>() command.add(javaPath()) + command.add('-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=5005') command.addAll(defaultJavaProperties) command.addAll(defaultAppSecProperties) command.addAll((String[]) ["-jar", springBootShadowJar, "--server.port=${httpPort}"]) @@ -717,6 +907,131 @@ class SpringBootSmokeTest extends AbstractAppSecServerSmokeTest { assert schema['letters'][1]["len"] == 3 } + void 'API Security downstream request header analysis'() { + when: + final url = "http://localhost:${httpPort}/api_security/http_client/${variant}" + final request = new Request.Builder() + .url(url) + .get() + .header('Witness', "pwq3ojtropiw3hjtowir") + .build() + final response = client.newCall(request).execute() + + then: + response.code() == 200 + final span = assertDownstreamTrace() + span.meta['_dd.appsec.trace.req_headers'] == 'TAG_API10_REQ_HEADERS' + + where: + variant << httpClientDownstreamAnalysisVariants() + } + + void 'API Security downstream request body analysis'() { + when: + final url = "http://localhost:${httpPort}/api_security/http_client/${variant}" + final request = new Request.Builder() + .url(url) + .post(RequestBody.create(MediaType.parse('application/json'), '{"payload_in": "qw2jedrkjerbgol23ewpfirj2qw3or"}')) + .build() + final response = client.newCall(request).execute() + + then: + response.code() == 200 + final span = assertDownstreamTrace() + span.meta['_dd.appsec.trace.req_body'] == 'TAG_API10_REQ_BODY' + + where: + variant << httpClientDownstreamAnalysisVariants() + } + + void 'API Security downstream request method analysis'() { + when: + final url = "http://localhost:${httpPort}/api_security/http_client/${variant}" + final request = new Request.Builder() + .url(url) + .method("PUT", RequestBody.create(MediaType.parse("text/plain"), "hello".bytes)) + .build() + final response = client.newCall(request).execute() + + then: + response.code() == 200 + final span = assertDownstreamTrace() + span.meta['_dd.appsec.trace.req_method'] == 'TAG_API10_REQ_METHOD' + + where: + variant << httpClientDownstreamAnalysisVariants() + } + + void 'API Security downstream response status analysis'() { + when: + final url = "http://localhost:${httpPort}/api_security/http_client/${variant}" + final request = new Request.Builder() + .url(url) + .get() + .header('Status', "201") + .build() + final response = client.newCall(request).execute() + + then: + response.code() == 200 + final span = assertDownstreamTrace() + span.meta['_dd.appsec.trace.res_status'] == 'TAG_API10_RES_STATUS' + + where: + variant << httpClientDownstreamAnalysisVariants() + } + + void 'API Security downstream response header analysis'() { + when: + final url = "http://localhost:${httpPort}/api_security/http_client/${variant}" + final request = new Request.Builder() + .url(url) + .get() + .header('echo-headers', "qwoierj12l3") + .build() + final response = client.newCall(request).execute() + + then: + response.code() == 200 + final span = assertDownstreamTrace() + span.meta['_dd.appsec.trace.res_headers'] == 'TAG_API10_RES_HEADERS' + + where: + variant << httpClientDownstreamAnalysisVariants() + } + + void 'API Security downstream response body analysis'() { + when: + final url = "http://localhost:${httpPort}/api_security/http_client/${variant}" + final request = new Request.Builder() + .url(url) + .post(RequestBody.create(MediaType.parse('application/json'), '{"payload_out": "kqehf09123r4lnksef"}')) + .build() + final response = client.newCall(request).execute() + + then: + response.code() == 200 + final span = assertDownstreamTrace() + span.meta['_dd.appsec.trace.res_body'] == 'TAG_API10_RES_BODY' + + where: + variant << httpClientDownstreamAnalysisVariants() + } + + private RootSpan assertDownstreamTrace() { + waitForTraceCount(2) // original + echo + + final rootSpans = this.rootSpans.toList() + final span = rootSpans.find { it.getSpan().resource.contains('/api_security/http_client') } + span.metrics['_dd.appsec.downstream_request'] == 1 + + return span + } + + private static List httpClientDownstreamAnalysisVariants() { + return ['okHttp2'] + } + private static byte[] unzip(final String text) { final inflaterStream = new GZIPInputStream(new ByteArrayInputStream(text.decodeBase64())) return inflaterStream.getBytes() diff --git a/dd-smoke-tests/appsec/src/main/groovy/datadog/smoketest/appsec/AbstractAppSecServerSmokeTest.groovy b/dd-smoke-tests/appsec/src/main/groovy/datadog/smoketest/appsec/AbstractAppSecServerSmokeTest.groovy index 7c47625ee8f..01bca5a83f6 100644 --- a/dd-smoke-tests/appsec/src/main/groovy/datadog/smoketest/appsec/AbstractAppSecServerSmokeTest.groovy +++ b/dd-smoke-tests/appsec/src/main/groovy/datadog/smoketest/appsec/AbstractAppSecServerSmokeTest.groovy @@ -53,7 +53,9 @@ abstract class AbstractAppSecServerSmokeTest extends AbstractServerSmokeTest { "-Ddd.appsec.waf.timeout=300000", "-DPOWERWAF_EXIT_ON_LEAK=true", // disable AppSec rate limit - "-Ddd.appsec.trace.rate.limit=-1" + "-Ddd.appsec.trace.rate.limit=-1", + // disable http client sampling + "-Ddd.api-security.downstream.request.analysis.sample_rate=1" ] + (System.getProperty('smoke_test.appsec.enabled') == 'inactive' ? // enable remote config so that appsec is partially enabled (rc is now enabled by default) [ diff --git a/dd-trace-api/src/main/java/datadog/trace/api/ConfigDefaults.java b/dd-trace-api/src/main/java/datadog/trace/api/ConfigDefaults.java index ad94c6bed52..0affaa3c784 100644 --- a/dd-trace-api/src/main/java/datadog/trace/api/ConfigDefaults.java +++ b/dd-trace-api/src/main/java/datadog/trace/api/ConfigDefaults.java @@ -120,11 +120,14 @@ public final class ConfigDefaults { // TODO: change to true once the RFC is approved static final boolean DEFAULT_API_SECURITY_ENDPOINT_COLLECTION_ENABLED = false; static final int DEFAULT_API_SECURITY_ENDPOINT_COLLECTION_MESSAGE_LIMIT = 300; + static final double DEFAULT_API_SECURITY_DOWNSTREAM_REQUEST_ANALYSIS_SAMPLE_RATE = 0.5D; + static final int DEFAULT_API_SECURITY_MAX_DOWNSTREAM_REQUEST_BODY_ANALYSIS = 1; static final boolean DEFAULT_APPSEC_RASP_ENABLED = true; static final boolean DEFAULT_APPSEC_STACK_TRACE_ENABLED = true; static final int DEFAULT_APPSEC_MAX_STACK_TRACES = 2; static final int DEFAULT_APPSEC_MAX_STACK_TRACE_DEPTH = 32; static final int DEFAULT_APPSEC_MAX_COLLECTED_HEADERS = 50; + static final int DEFAULT_APPSEC_BODY_PARSING_SIZE_LIMIT = 10_000_000; static final String DEFAULT_IAST_ENABLED = "false"; static final boolean DEFAULT_IAST_DEBUG_ENABLED = false; public static final int DEFAULT_IAST_MAX_CONCURRENT_REQUESTS = 4; diff --git a/dd-trace-api/src/main/java/datadog/trace/api/config/AppSecConfig.java b/dd-trace-api/src/main/java/datadog/trace/api/config/AppSecConfig.java index e65fbbfbf07..ed519f62817 100644 --- a/dd-trace-api/src/main/java/datadog/trace/api/config/AppSecConfig.java +++ b/dd-trace-api/src/main/java/datadog/trace/api/config/AppSecConfig.java @@ -23,6 +23,7 @@ public final class AppSecConfig { "appsec.automated-user-events-tracking"; public static final String APPSEC_AUTO_USER_INSTRUMENTATION_MODE = "appsec.auto-user-instrumentation-mode"; + public static final String APPSEC_BODY_PARSING_SIZE_LIMIT = "appsec.body-parsing-size-limit"; public static final String API_SECURITY_ENABLED = "api-security.enabled"; public static final String API_SECURITY_ENABLED_EXPERIMENTAL = "experimental.api-security.enabled"; @@ -31,6 +32,10 @@ public final class AppSecConfig { "api-security.endpoint.collection.enabled"; public static final String API_SECURITY_ENDPOINT_COLLECTION_MESSAGE_LIMIT = "api-security.endpoint.collection.message.limit"; + public static final String API_SECURITY_DOWNSTREAM_REQUEST_ANALYSIS_SAMPLE_RATE = + "api-security.downstream.request.analysis.sample_rate"; + public static final String API_SECURITY_MAX_DOWNSTREAM_REQUEST_BODY_ANALYSIS = + "api-security.max.downstream.request.body.analysis"; public static final String APPSEC_SCA_ENABLED = "appsec.sca.enabled"; public static final String APPSEC_RASP_ENABLED = "appsec.rasp.enabled"; diff --git a/internal-api/build.gradle.kts b/internal-api/build.gradle.kts index c4f0dfd019f..0691f356bb3 100644 --- a/internal-api/build.gradle.kts +++ b/internal-api/build.gradle.kts @@ -65,6 +65,10 @@ val excludedClassesCoverage by extra( // These are almost fully abstract classes so nothing to test "datadog.trace.api.profiling.RecordingData", "datadog.trace.api.appsec.AppSecEventTracker", + // POJOs + "datadog.trace.api.appsec.HttpClientPayload", + "datadog.trace.api.appsec.HttpClientRequest", + "datadog.trace.api.appsec.HttpClientResponse", // A plain enum "datadog.trace.api.profiling.RecordingType", // Data Streams Monitoring diff --git a/internal-api/src/main/java/datadog/trace/api/Config.java b/internal-api/src/main/java/datadog/trace/api/Config.java index 9903cdd92fd..2043e2c82e3 100644 --- a/internal-api/src/main/java/datadog/trace/api/Config.java +++ b/internal-api/src/main/java/datadog/trace/api/Config.java @@ -7,10 +7,13 @@ import static datadog.trace.api.ConfigDefaults.DEFAULT_AGENT_TIMEOUT; import static datadog.trace.api.ConfigDefaults.DEFAULT_AGENT_WRITER_TYPE; import static datadog.trace.api.ConfigDefaults.DEFAULT_ANALYTICS_SAMPLE_RATE; +import static datadog.trace.api.ConfigDefaults.DEFAULT_API_SECURITY_DOWNSTREAM_REQUEST_ANALYSIS_SAMPLE_RATE; import static datadog.trace.api.ConfigDefaults.DEFAULT_API_SECURITY_ENABLED; import static datadog.trace.api.ConfigDefaults.DEFAULT_API_SECURITY_ENDPOINT_COLLECTION_ENABLED; import static datadog.trace.api.ConfigDefaults.DEFAULT_API_SECURITY_ENDPOINT_COLLECTION_MESSAGE_LIMIT; +import static datadog.trace.api.ConfigDefaults.DEFAULT_API_SECURITY_MAX_DOWNSTREAM_REQUEST_BODY_ANALYSIS; import static datadog.trace.api.ConfigDefaults.DEFAULT_API_SECURITY_SAMPLE_DELAY; +import static datadog.trace.api.ConfigDefaults.DEFAULT_APPSEC_BODY_PARSING_SIZE_LIMIT; import static datadog.trace.api.ConfigDefaults.DEFAULT_APPSEC_MAX_COLLECTED_HEADERS; import static datadog.trace.api.ConfigDefaults.DEFAULT_APPSEC_MAX_STACK_TRACES; import static datadog.trace.api.ConfigDefaults.DEFAULT_APPSEC_MAX_STACK_TRACE_DEPTH; @@ -181,13 +184,16 @@ import static datadog.trace.api.DDTags.SCHEMA_VERSION_TAG_KEY; import static datadog.trace.api.DDTags.SERVICE; import static datadog.trace.api.DDTags.SERVICE_TAG; +import static datadog.trace.api.config.AppSecConfig.API_SECURITY_DOWNSTREAM_REQUEST_ANALYSIS_SAMPLE_RATE; import static datadog.trace.api.config.AppSecConfig.API_SECURITY_ENABLED; import static datadog.trace.api.config.AppSecConfig.API_SECURITY_ENABLED_EXPERIMENTAL; import static datadog.trace.api.config.AppSecConfig.API_SECURITY_ENDPOINT_COLLECTION_ENABLED; import static datadog.trace.api.config.AppSecConfig.API_SECURITY_ENDPOINT_COLLECTION_MESSAGE_LIMIT; +import static datadog.trace.api.config.AppSecConfig.API_SECURITY_MAX_DOWNSTREAM_REQUEST_BODY_ANALYSIS; import static datadog.trace.api.config.AppSecConfig.API_SECURITY_SAMPLE_DELAY; import static datadog.trace.api.config.AppSecConfig.APPSEC_AUTOMATED_USER_EVENTS_TRACKING; import static datadog.trace.api.config.AppSecConfig.APPSEC_AUTO_USER_INSTRUMENTATION_MODE; +import static datadog.trace.api.config.AppSecConfig.APPSEC_BODY_PARSING_SIZE_LIMIT; import static datadog.trace.api.config.AppSecConfig.APPSEC_COLLECT_ALL_HEADERS; import static datadog.trace.api.config.AppSecConfig.APPSEC_HEADER_COLLECTION_REDACTION_ENABLED; import static datadog.trace.api.config.AppSecConfig.APPSEC_HTTP_BLOCKED_TEMPLATE_HTML; @@ -943,10 +949,13 @@ public static String getHostName() { private final boolean appSecHeaderCollectionRedactionEnabled; private final int appSecMaxCollectedHeaders; private final boolean appSecRaspCollectRequestBody; + private final int appSecBodyParsingSizeLimit; private final boolean apiSecurityEnabled; private final float apiSecuritySampleDelay; private final boolean apiSecurityEndpointCollectionEnabled; private final int apiSecurityEndpointCollectionMessageLimit; + private final int apiSecurityMaxDownstreamRequestBodyAnalysis; + private final double apiSecurityDownstreamRequestAnalysisSampleRate; private final IastDetectionMode iastDetectionMode; private final int iastMaxConcurrentRequests; @@ -2091,6 +2100,9 @@ PROFILING_DATADOG_PROFILER_ENABLED, isDatadogProfilerSafeInCurrentEnvironment()) APPSEC_MAX_COLLECTED_HEADERS, DEFAULT_APPSEC_MAX_COLLECTED_HEADERS); appSecRaspCollectRequestBody = configProvider.getBoolean(APPSEC_RASP_COLLECT_REQUEST_BODY, false); + appSecBodyParsingSizeLimit = + configProvider.getInteger( + APPSEC_BODY_PARSING_SIZE_LIMIT, DEFAULT_APPSEC_BODY_PARSING_SIZE_LIMIT); apiSecurityEnabled = configProvider.getBoolean( API_SECURITY_ENABLED, DEFAULT_API_SECURITY_ENABLED, API_SECURITY_ENABLED_EXPERIMENTAL); @@ -2104,6 +2116,14 @@ PROFILING_DATADOG_PROFILER_ENABLED, isDatadogProfilerSafeInCurrentEnvironment()) configProvider.getInteger( API_SECURITY_ENDPOINT_COLLECTION_MESSAGE_LIMIT, DEFAULT_API_SECURITY_ENDPOINT_COLLECTION_MESSAGE_LIMIT); + apiSecurityMaxDownstreamRequestBodyAnalysis = + configProvider.getInteger( + API_SECURITY_MAX_DOWNSTREAM_REQUEST_BODY_ANALYSIS, + DEFAULT_API_SECURITY_MAX_DOWNSTREAM_REQUEST_BODY_ANALYSIS); + apiSecurityDownstreamRequestAnalysisSampleRate = + configProvider.getDouble( + API_SECURITY_DOWNSTREAM_REQUEST_ANALYSIS_SAMPLE_RATE, + DEFAULT_API_SECURITY_DOWNSTREAM_REQUEST_ANALYSIS_SAMPLE_RATE); iastDebugEnabled = configProvider.getBoolean(IAST_DEBUG_ENABLED, DEFAULT_IAST_DEBUG_ENABLED); @@ -3588,6 +3608,14 @@ public int getApiSecurityEndpointCollectionMessageLimit() { return apiSecurityEndpointCollectionMessageLimit; } + public int getApiSecurityMaxDownstreamRequestBodyAnalysis() { + return apiSecurityMaxDownstreamRequestBodyAnalysis; + } + + public double getApiSecurityDownstreamRequestAnalysisSampleRate() { + return apiSecurityDownstreamRequestAnalysisSampleRate; + } + public boolean isApiSecurityEndpointCollectionEnabled() { return apiSecurityEndpointCollectionEnabled; } @@ -5087,6 +5115,10 @@ public boolean isAppSecRaspCollectRequestBody() { return appSecRaspCollectRequestBody; } + public int getAppSecBodyParsingSizeLimit() { + return appSecBodyParsingSizeLimit; + } + public boolean isCloudPayloadTaggingEnabledFor(String serviceName) { return cloudPayloadTaggingServices.contains(serviceName); } diff --git a/internal-api/src/main/java/datadog/trace/api/appsec/HttpClientPayload.java b/internal-api/src/main/java/datadog/trace/api/appsec/HttpClientPayload.java new file mode 100644 index 00000000000..9be785a1037 --- /dev/null +++ b/internal-api/src/main/java/datadog/trace/api/appsec/HttpClientPayload.java @@ -0,0 +1,39 @@ +package datadog.trace.api.appsec; + +import java.io.InputStream; +import java.util.List; +import java.util.Map; + +public abstract class HttpClientPayload { + + private final long requestId; + private final Map> headers; + private MediaType contentType; + private InputStream body; + + protected HttpClientPayload(final long requestId, final Map> headers) { + this.requestId = requestId; + this.headers = headers; + } + + public long getRequestId() { + return requestId; + } + + public MediaType getContentType() { + return contentType; + } + + public Map> getHeaders() { + return headers; + } + + public InputStream getBody() { + return body; + } + + public void setBody(MediaType contentType, InputStream body) { + this.contentType = contentType; + this.body = body; + } +} diff --git a/internal-api/src/main/java/datadog/trace/api/appsec/HttpClientRequest.java b/internal-api/src/main/java/datadog/trace/api/appsec/HttpClientRequest.java new file mode 100644 index 00000000000..3ba862be58f --- /dev/null +++ b/internal-api/src/main/java/datadog/trace/api/appsec/HttpClientRequest.java @@ -0,0 +1,32 @@ +package datadog.trace.api.appsec; + +import java.util.List; +import java.util.Map; + +public class HttpClientRequest extends HttpClientPayload { + + private final String url; + private final String method; + + public HttpClientRequest(final long id, final String url) { + this(id, url, null, null); + } + + public HttpClientRequest( + final long id, + final String url, + final String method, + final Map> headers) { + super(id, headers); + this.url = url; + this.method = method; + } + + public String getUrl() { + return url; + } + + public String getMethod() { + return method; + } +} diff --git a/internal-api/src/main/java/datadog/trace/api/appsec/HttpClientResponse.java b/internal-api/src/main/java/datadog/trace/api/appsec/HttpClientResponse.java new file mode 100644 index 00000000000..0c26b96402a --- /dev/null +++ b/internal-api/src/main/java/datadog/trace/api/appsec/HttpClientResponse.java @@ -0,0 +1,19 @@ +package datadog.trace.api.appsec; + +import java.util.List; +import java.util.Map; + +public class HttpClientResponse extends HttpClientPayload { + + private final int status; + + public HttpClientResponse( + final long requestId, final int status, final Map> headers) { + super(requestId, headers); + this.status = status; + } + + public int getStatus() { + return status; + } +} diff --git a/internal-api/src/main/java/datadog/trace/api/appsec/MediaType.java b/internal-api/src/main/java/datadog/trace/api/appsec/MediaType.java new file mode 100644 index 00000000000..979562160c8 --- /dev/null +++ b/internal-api/src/main/java/datadog/trace/api/appsec/MediaType.java @@ -0,0 +1,87 @@ +package datadog.trace.api.appsec; + +import java.util.Locale; + +public class MediaType { + + public static final MediaType UNKNOWN = new MediaType(null, null, null); + + private final String type; + private final String subtype; + private final String charset; + + private MediaType(final String type, final String subtype, final String charset) { + this.type = type; + this.subtype = subtype; + this.charset = charset; + } + + public String getType() { + return type; + } + + public String getSubtype() { + return subtype; + } + + public String getCharset() { + return charset; + } + + public boolean isJson() { + return subtype != null && ("json".equals(subtype) || subtype.endsWith("+json")); + } + + public boolean isDeserializable() { + // TODO add other supported types + return isJson(); + } + + @Override + public String toString() { + String contentType = type + "/" + subtype; + if (charset != null) { + contentType += "; charset=" + charset; + } + return contentType; + } + + public static MediaType parse(final String header) { + if (header == null) { + return UNKNOWN; + } + final String mediaType = header.trim().toLowerCase(Locale.ROOT); + final int semicolonIndex = mediaType.indexOf(';'); + String contentType, charset = null; + if (semicolonIndex != -1) { + contentType = mediaType.substring(0, semicolonIndex); + final String parameter = mediaType.substring(semicolonIndex + 1); + final int charsetIndex = parameter.indexOf("charset="); + if (charsetIndex != -1) { + charset = parameter.substring(charsetIndex + 8); + } + } else { + contentType = mediaType; + } + final int slashIndex = contentType.indexOf('/'); + if (slashIndex != -1) { + final String type = contentType.substring(0, slashIndex); + final String subtype = contentType.substring(slashIndex + 1); + return create(type, subtype, charset); + } else { + return create(contentType, null, charset); + } + } + + public static MediaType create(final String type, final String subtype, final String charset) { + return new MediaType(trimValue(type), trimValue(subtype), trimValue(charset)); + } + + public static String trimValue(String value) { + if (value == null) { + return null; + } + value = value.trim(); + return value.isEmpty() ? null : value; + } +} diff --git a/internal-api/src/main/java/datadog/trace/api/gateway/Events.java b/internal-api/src/main/java/datadog/trace/api/gateway/Events.java index 41d32658d4b..394459bc7c8 100644 --- a/internal-api/src/main/java/datadog/trace/api/gateway/Events.java +++ b/internal-api/src/main/java/datadog/trace/api/gateway/Events.java @@ -1,5 +1,7 @@ package datadog.trace.api.gateway; +import datadog.trace.api.appsec.HttpClientRequest; +import datadog.trace.api.appsec.HttpClientResponse; import datadog.trace.api.function.TriConsumer; import datadog.trace.api.function.TriFunction; import datadog.trace.api.http.StoredBodySupplier; @@ -237,16 +239,17 @@ public EventType>> grpcServerMetho return (EventType>>) GRPC_SERVER_METHOD; } - static final int NETWORK_CONNECTION_ID = 19; + static final int HTTP_CLIENT_REQUEST_ID = 19; @SuppressWarnings("rawtypes") - private static final EventType NETWORK_CONNECTION = - new ET<>("network.connection", NETWORK_CONNECTION_ID); + private static final EventType HTTP_CLIENT_REQUEST = + new ET<>("http.client.request", HTTP_CLIENT_REQUEST_ID); - /** A I/O network URL */ + /** An http downstream request */ @SuppressWarnings("unchecked") - public EventType>> networkConnection() { - return (EventType>>) NETWORK_CONNECTION; + public EventType>> httpClientRequest() { + return (EventType>>) + HTTP_CLIENT_REQUEST; } static final int FILE_LOADED_ID = 20; @@ -334,6 +337,32 @@ public EventType>> responseBody() return (EventType>>) RESPONSE_BODY; } + static final int HTTP_CLIENT_RESPONSE_ID = 28; + + @SuppressWarnings("rawtypes") + private static final EventType HTTP_CLIENT_RESPONSE = + new ET<>("http.client.response", HTTP_CLIENT_RESPONSE_ID); + + /** An http downstream response */ + @SuppressWarnings("unchecked") + public EventType>> + httpClientResponse() { + return (EventType>>) + HTTP_CLIENT_RESPONSE; + } + + static final int HTTP_CLIENT_SAMPLING_ID = 29; + + @SuppressWarnings("rawtypes") + private static final EventType HTTP_CLIENT_SAMPLING = + new ET<>("http.client.sampling", HTTP_CLIENT_SAMPLING_ID); + + /** Check sampling status for a downstream request */ + @SuppressWarnings("unchecked") + public EventType>> httpClientSampling() { + return (EventType>>) HTTP_CLIENT_SAMPLING; + } + static final int MAX_EVENTS = nextId.get(); private static final class ET extends EventType { diff --git a/internal-api/src/main/java/datadog/trace/api/gateway/InstrumentationGateway.java b/internal-api/src/main/java/datadog/trace/api/gateway/InstrumentationGateway.java index d8ba93910a3..6a1cc258de1 100644 --- a/internal-api/src/main/java/datadog/trace/api/gateway/InstrumentationGateway.java +++ b/internal-api/src/main/java/datadog/trace/api/gateway/InstrumentationGateway.java @@ -7,10 +7,12 @@ import static datadog.trace.api.gateway.Events.GRAPHQL_SERVER_REQUEST_MESSAGE_ID; import static datadog.trace.api.gateway.Events.GRPC_SERVER_METHOD_ID; import static datadog.trace.api.gateway.Events.GRPC_SERVER_REQUEST_MESSAGE_ID; +import static datadog.trace.api.gateway.Events.HTTP_CLIENT_REQUEST_ID; +import static datadog.trace.api.gateway.Events.HTTP_CLIENT_RESPONSE_ID; +import static datadog.trace.api.gateway.Events.HTTP_CLIENT_SAMPLING_ID; import static datadog.trace.api.gateway.Events.HTTP_ROUTE_ID; import static datadog.trace.api.gateway.Events.LOGIN_EVENT_ID; import static datadog.trace.api.gateway.Events.MAX_EVENTS; -import static datadog.trace.api.gateway.Events.NETWORK_CONNECTION_ID; import static datadog.trace.api.gateway.Events.REQUEST_BODY_CONVERTED_ID; import static datadog.trace.api.gateway.Events.REQUEST_BODY_DONE_ID; import static datadog.trace.api.gateway.Events.REQUEST_BODY_START_ID; @@ -30,6 +32,7 @@ import static datadog.trace.api.gateway.Events.SHELL_CMD_ID; import static datadog.trace.api.gateway.Events.USER_ID; +import datadog.trace.api.appsec.HttpClientPayload; import datadog.trace.api.function.TriConsumer; import datadog.trace.api.function.TriFunction; import datadog.trace.api.http.StoredBodySupplier; @@ -433,7 +436,6 @@ public Flow apply(RequestContext ctx, String arg) { } }; case DATABASE_SQL_QUERY_ID: - case NETWORK_CONNECTION_ID: case FILE_LOADED_ID: case SHELL_CMD_ID: return (C) @@ -449,6 +451,35 @@ public Flow apply(RequestContext ctx, String arg) { } } }; + case HTTP_CLIENT_REQUEST_ID: + case HTTP_CLIENT_RESPONSE_ID: + return (C) + new BiFunction>() { + @Override + public Flow apply(RequestContext ctx, HttpClientPayload arg) { + try { + return ((BiFunction>) callback) + .apply(ctx, arg); + } catch (Throwable t) { + log.warn("Callback for {} threw.", eventType, t); + return Flow.ResultFlow.empty(); + } + } + }; + case HTTP_CLIENT_SAMPLING_ID: + return (C) + new BiFunction>() { + @Override + public Flow apply(RequestContext ctx, Long requestId) { + try { + return ((BiFunction>) callback) + .apply(ctx, requestId); + } catch (Throwable t) { + log.warn("Callback for {} threw.", eventType, t); + return Flow.ResultFlow.empty(); + } + } + }; case EXEC_CMD_ID: return (C) new BiFunction>() { diff --git a/internal-api/src/test/groovy/datadog/trace/api/appsec/MediaTypeSpecification.groovy b/internal-api/src/test/groovy/datadog/trace/api/appsec/MediaTypeSpecification.groovy new file mode 100644 index 00000000000..b0d7b9462ef --- /dev/null +++ b/internal-api/src/test/groovy/datadog/trace/api/appsec/MediaTypeSpecification.groovy @@ -0,0 +1,64 @@ +package datadog.trace.api.appsec + + +import spock.lang.Specification + +class MediaTypeSpecification extends Specification { + + void 'test media type parsing'() { + when: + final media = MediaType.parse(header) + + then: + media.type == type + media.subtype == subtype + media.charset == charset + media.json == json + media.deserializable == deserializable + + where: + header | type | subtype | charset | json | deserializable + // Standard JSON types + 'application/json' | 'application' | 'json' | null | true | true + 'text/json' | 'text' | 'json' | null | true | true + + // JSON subtypes + 'application/vnd.api+json' | 'application' | 'vnd.api+json' | null | true | true + 'application/ld+json' | 'application' | 'ld+json' | null | true | true + 'application/hal+json' | 'application' | 'hal+json' | null | true | true + 'application/problem+json' | 'application' | 'problem+json' | null | true | true + 'application/merge-patch+json' | 'application' | 'merge-patch+json' | null | true | true + 'application/json-patch+json' | 'application' | 'json-patch+json' | null | true | true + + // With parameters + 'application/json; charset=utf-8' | 'application' | 'json' | 'utf-8' | true | true + 'application/json;charset=UTF-8' | 'application' | 'json' | 'utf-8' | true | true + 'text/json; charset=iso-8859-1' | 'text' | 'json' | 'iso-8859-1' | true | true + 'application/vnd.api+json; charset=utf-8' | 'application' | 'vnd.api+json' | 'utf-8' | true | true + + // Case variations + 'APPLICATION/JSON' | 'application' | 'json' | null | true | true + 'Text/Json' | 'text' | 'json' | null | true | true + 'application/VND.API+JSON' | 'application' | 'vnd.api+json' | null | true | true + + // With whitespace + ' application/json ' | 'application' | 'json' | null | true | true + 'application/json ; charset=utf-8' | 'application' | 'json' | 'utf-8' | true | true + ' text/json; charset=utf-8 ' | 'text' | 'json' | 'utf-8' | true | true + + // Non-JSON types + 'application/xml' | 'application' | 'xml' | null | false | false + 'text/html' | 'text' | 'html' | null | false | false + 'text/plain' | 'text' | 'plain' | null | false | false + 'application/pdf' | 'application' | 'pdf' | null | false | false + 'image/png' | 'image' | 'png' | null | false | false + 'application/octet-stream' | 'application' | 'octet-stream' | null | false | false + + // Edge cases + null | null | null | null | false | false + '' | null | null | null | false | false + ' ' | null | null | null | false | false + 'json' | 'json' | null | null | false | false + 'application/' | 'application' | null | null | false | false + } +} diff --git a/internal-api/src/test/java/datadog/trace/api/gateway/InstrumentationGatewayTest.java b/internal-api/src/test/java/datadog/trace/api/gateway/InstrumentationGatewayTest.java index 1d8d1cbf01b..a0fb9794c52 100644 --- a/internal-api/src/test/java/datadog/trace/api/gateway/InstrumentationGatewayTest.java +++ b/internal-api/src/test/java/datadog/trace/api/gateway/InstrumentationGatewayTest.java @@ -216,8 +216,12 @@ public void testNormalCalls() { cbp.getCallback(events.databaseConnection()).accept(null, null); ss.registerCallback(events.databaseSqlQuery(), callback); cbp.getCallback(events.databaseSqlQuery()).apply(null, null); - ss.registerCallback(events.networkConnection(), callback); - cbp.getCallback(events.networkConnection()).apply(null, null); + ss.registerCallback(events.httpClientRequest(), callback); + cbp.getCallback(events.httpClientRequest()).apply(null, null); + ss.registerCallback(events.httpClientResponse(), callback); + cbp.getCallback(events.httpClientResponse()).apply(null, null); + ss.registerCallback(events.httpClientSampling(), callback); + cbp.getCallback(events.httpClientSampling()).apply(null, null); ss.registerCallback(events.fileLoaded(), callback); cbp.getCallback(events.fileLoaded()).apply(null, null); ss.registerCallback(events.user(), callback); @@ -298,8 +302,12 @@ public void testThrowableBlocking() { cbp.getCallback(events.databaseConnection()).accept(null, null); ss.registerCallback(events.databaseSqlQuery(), throwback); cbp.getCallback(events.databaseSqlQuery()).apply(null, null); - ss.registerCallback(events.networkConnection(), throwback); - cbp.getCallback(events.networkConnection()).apply(null, null); + ss.registerCallback(events.httpClientRequest(), throwback); + cbp.getCallback(events.httpClientRequest()).apply(null, null); + ss.registerCallback(events.httpClientResponse(), throwback); + cbp.getCallback(events.httpClientResponse()).apply(null, null); + ss.registerCallback(events.httpClientSampling(), throwback); + cbp.getCallback(events.httpClientSampling()).apply(null, null); ss.registerCallback(events.fileLoaded(), throwback); cbp.getCallback(events.fileLoaded()).apply(null, null); ss.registerCallback(events.user(), throwback);