From ddbbdc4969cc8c35a5c57a93ce8b4f2002897860 Mon Sep 17 00:00:00 2001 From: "clara.poncet" Date: Fri, 23 Jan 2026 17:33:58 +0100 Subject: [PATCH 1/5] send lambda data to the WAF for analysis --- .../datadog/trace/lambda/LambdaHandler.java | 739 +++++++++++++++++- 1 file changed, 729 insertions(+), 10 deletions(-) diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaHandler.java index 4ee2add2940..728910042fe 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaHandler.java @@ -1,5 +1,6 @@ package datadog.trace.lambda; +import static datadog.trace.api.gateway.Events.EVENTS; import static datadog.trace.bootstrap.instrumentation.api.AgentPropagation.extractContextAndGetSpanContext; import static java.util.concurrent.TimeUnit.SECONDS; @@ -7,12 +8,31 @@ import com.squareup.moshi.Moshi; import datadog.trace.api.DDSpanId; import datadog.trace.api.DDTags; +import datadog.trace.api.function.TriConsumer; +import datadog.trace.api.gateway.BlockResponseFunction; +import datadog.trace.api.gateway.Flow; +import datadog.trace.api.gateway.IGSpanInfo; +import datadog.trace.api.gateway.RequestContext; +import datadog.trace.api.gateway.RequestContextSlot; +import datadog.trace.api.internal.TraceSegment; import datadog.trace.bootstrap.instrumentation.api.AgentSpan; import datadog.trace.bootstrap.instrumentation.api.AgentSpanContext; +import datadog.trace.bootstrap.instrumentation.api.TagContext; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.Reader; import java.nio.charset.StandardCharsets; import java.util.Base64; +import java.util.Collections; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; +import datadog.trace.bootstrap.instrumentation.api.AgentTracer; +import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter; +import datadog.trace.bootstrap.instrumentation.api.URIDataAdapterBase; import okhttp3.ConnectionPool; import okhttp3.MediaType; import okhttp3.OkHttpClient; @@ -73,6 +93,21 @@ public class LambdaHandler { private static String EXTENSION_BASE_URL = "http://127.0.0.1:8124"; public static AgentSpanContext notifyStartInvocation(Object event, String lambdaRequestId) { + AgentTracer.TracerAPI tracer = AgentTracer.get(); + AgentSpanContext extractedContext = null; + + // Extract headers and call AppSec gateway events + if (event instanceof ByteArrayInputStream) { + try { + LambdaEventData eventData = extractEventData((ByteArrayInputStream) event); + extractedContext = callIGCallbackStart(tracer, eventData); + } catch (Exception e) { + log.error("Failed to extract data from event stream", e); + } + } else { + log.debug("Event is not a ByteArrayInputStream, type: {}", event != null ? event.getClass().getName() : "null"); + } + RequestBody body = RequestBody.create(jsonMediaType, writeValueAsString(event)); try (Response response = HTTP_CLIENT @@ -85,14 +120,19 @@ public static AgentSpanContext notifyStartInvocation(Object event, String lambda .build()) .execute()) { if (response.isSuccessful()) { - - return extractContextAndGetSpanContext( - response.headers(), - (carrier, classifier) -> { - for (String headerName : carrier.names()) { - classifier.accept(headerName, carrier.get(headerName)); - } - }); + AgentSpanContext extensionContext = + extractContextAndGetSpanContext( + response.headers(), + (carrier, classifier) -> { + for (String headerName : carrier.names()) { + classifier.accept(headerName, carrier.get(headerName)); + } + }); + // Merge the AppSec context with the extension context + AgentSpanContext mergedContext = mergeContexts(extensionContext, extractedContext); + return mergedContext; + } else { + log.debug("Extension call failed with status: {}", response.code()); } } catch (Throwable ignored) { log.error("could not reach the extension"); @@ -103,10 +143,22 @@ public static AgentSpanContext notifyStartInvocation(Object event, String lambda public static boolean notifyEndInvocation( AgentSpan span, Object result, boolean isError, String lambdaRequestId) { if (null == span || null == span.getSamplingPriority()) { - log.error( - "could not notify the extension as the lambda span is null or no sampling priority has been found"); + log.error("could not notify the extension as the lambda span is null or no sampling priority has been found"); return false; } + + // Call requestEnded event + RequestContext requestContext = span.getRequestContext(); + if (requestContext != null) { + AgentTracer.TracerAPI tracer = AgentTracer.get(); + BiFunction> requestEndedCallback = + tracer.getCallbackProvider(RequestContextSlot.APPSEC) + .getCallback(EVENTS.requestEnded()); + if (requestEndedCallback != null) { + requestEndedCallback.apply(requestContext, span); + } + } + RequestBody body = RequestBody.create(jsonMediaType, writeValueAsString(result)); Request.Builder builder = new Request.Builder() @@ -166,4 +218,671 @@ public static String writeValueAsString(Object obj) { public static void setExtensionBaseUrl(String extensionBaseUrl) { EXTENSION_BASE_URL = extensionBaseUrl; } + + private static AgentSpanContext callIGCallbackStart( + AgentTracer.TracerAPI tracer, LambdaEventData eventData) { + Supplier> requestStartedCallback = + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestStarted()); + if (requestStartedCallback == null) { + log.debug("requestStarted callback is null"); + return null; + } + + TagContext tagContext = new TagContext(); + Object appSecRequestContext; + + // Call requestStarted + appSecRequestContext = requestStartedCallback.get().getResult(); + tagContext.withRequestContextDataAppSec(appSecRequestContext); + + if (appSecRequestContext != null) { + TemporaryRequestContext requestContext = new TemporaryRequestContext(appSecRequestContext); + + // Call requestMethodUriRaw + if (eventData.method != null && eventData.path != null) { + datadog.trace.api.function.TriFunction> methodUriCallback = + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestMethodUriRaw()); + if (methodUriCallback != null) { + LambdaURIDataAdapter uriAdapter = new LambdaURIDataAdapter(eventData.path); + methodUriCallback.apply(requestContext, eventData.method, uriAdapter); + } else { + log.debug("requestMethodUriRaw callback is null"); + } + } + + // Call requestHeader for each header + if (eventData.headers != null && !eventData.headers.isEmpty()) { + TriConsumer headerCallback = + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestHeader()); + if (headerCallback != null) { + for (Map.Entry header : eventData.headers.entrySet()) { + headerCallback.accept(requestContext, header.getKey(), header.getValue()); + } + } else { + log.debug("requestHeader callback is null"); + } + } + + // Call requestClientSocketAddress + if (eventData.sourceIp != null) { + datadog.trace.api.function.TriFunction> socketAddrCallback = + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestClientSocketAddress()); + if (socketAddrCallback != null) { + Integer port = eventData.sourcePort != null ? eventData.sourcePort : 0; + socketAddrCallback.apply(requestContext, eventData.sourceIp, port); + } else { + log.debug("requestClientSocketAddress callback is null"); + } + } + + // Call requestHeaderDone + Function> headerDoneCallback = + tracer + .getCallbackProvider(RequestContextSlot.APPSEC) + .getCallback(EVENTS.requestHeaderDone()); + if (headerDoneCallback != null) { + headerDoneCallback.apply(requestContext); + } else { + log.debug("requestHeaderDone callback is null"); + } + + // Call requestPathParams + if (eventData.pathParameters != null && !eventData.pathParameters.isEmpty()) { + BiFunction, Flow> pathParamsCallback = + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestPathParams()); + if (pathParamsCallback != null) { + pathParamsCallback.apply(requestContext, eventData.pathParameters); + } else { + log.debug("requestPathParams callback is null"); + } + } + + // Call requestBodyProcessed + if (eventData.body != null) { + BiFunction> bodyCallback = + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestBodyProcessed()); + if (bodyCallback != null) { + bodyCallback.apply(requestContext, eventData.body); + } else { + log.debug("requestBodyProcessed callback is null"); + } + } + } + return tagContext; + } + + private static AgentSpanContext mergeContexts( + AgentSpanContext extensionContext, AgentSpanContext extractedContext) { + if (extractedContext == null) { + return extensionContext; + } + if (extensionContext == null) { + return extractedContext; + } + + if (extractedContext instanceof TagContext) { + TagContext extracted = (TagContext) extractedContext; + Object appSecData = extracted.getRequestContextDataAppSec(); + Object iastData = extracted.getRequestContextDataIast(); + + if (extensionContext instanceof TagContext) { + TagContext merged = (TagContext) extensionContext; + if (appSecData != null) { + merged.withRequestContextDataAppSec(appSecData); + } + if (iastData != null) { + merged.withRequestContextDataIast(iastData); + } + return merged; + } + + log.warn( + "Cannot merge AppSec data: extension context is not a TagContext: {}", + extensionContext.getClass()); + } + return extensionContext; + } + + private static LambdaEventData extractEventData(ByteArrayInputStream inputStream) + throws IOException { + inputStream.mark(0); + + try { + StringBuilder jsonBuilder = new StringBuilder(inputStream.available()); + try (Reader reader = new InputStreamReader(inputStream, StandardCharsets.UTF_8)) { + char[] buffer = new char[1024]; + int charsRead; + while ((charsRead = reader.read(buffer)) != -1) { + jsonBuilder.append(buffer, 0, charsRead); + } + } + return extractEventDataFromJson(jsonBuilder.toString()); + } finally { + inputStream.reset(); + } + } + + private static LambdaEventData extractEventDataFromJson(String json) { + try { + // Parse JSON into a Map + JsonAdapter adapter = + new Moshi.Builder().build().adapter(Map.class); + + Map event = adapter.fromJson(json); + log.debug("Event JSON parsed successfully"); + + if (event == null) { + return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), null); + } + + // Detect trigger type + LambdaTriggerType triggerType = detectTriggerType(event); + log.debug("Detected Lambda trigger type: {}", triggerType); + + // Extract data based on trigger type + switch (triggerType) { + case API_GATEWAY_V1_REST: + return extractApiGatewayV1Data(event); + case API_GATEWAY_V2_HTTP: + case LAMBDA_URL: + return extractApiGatewayV2HttpData(event, triggerType); + case API_GATEWAY_V2_WEBSOCKET: + return extractApiGatewayV2WebSocketData(event); + case ALB: + case ALB_MULTI_VALUE: + return extractAlbData(event, triggerType); + default: + log.debug("Unknown trigger type, attempting generic extraction"); + return extractGenericData(event); + } + } catch (Exception e) { + log.error("Failed to parse event data from JSON", e); + return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), null); + } + } + + private static LambdaTriggerType detectTriggerType(Map event) { + Object requestContextObj = event.get("requestContext"); + + if (requestContextObj instanceof Map) { + Map requestContext = (Map) requestContextObj; + + // Check for ALB trigger (has elb object) + if (requestContext.containsKey("elb")) { + // Check if event has multiValueHeaders + if (event.containsKey("multiValueHeaders")) { + return LambdaTriggerType.ALB_MULTI_VALUE; + } + return LambdaTriggerType.ALB; + } + + // Check for WebSocket + if (requestContext.containsKey("connectionId") && + (requestContext.containsKey("eventType") || requestContext.containsKey("routeKey"))) { + return LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET; + } + + // Check for API Gateway v2 format + Object httpObj = requestContext.get("http"); + if (httpObj instanceof Map) { + Object domainNameObj = requestContext.get("domainName"); + if (domainNameObj instanceof String) { + String domainName = (String) domainNameObj; + if (domainName.contains("lambda-url")) { + return LambdaTriggerType.LAMBDA_URL; + } else { + return LambdaTriggerType.API_GATEWAY_V2_HTTP; + } + } else { + return LambdaTriggerType.LAMBDA_URL; + } + } + + // Check for API Gateway v1 REST API + if (requestContext.containsKey("httpMethod") || requestContext.containsKey("requestId")) { + return LambdaTriggerType.API_GATEWAY_V1_REST; + } + } + return LambdaTriggerType.UNKNOWN; + } + + /** + * Extracts data from API Gateway v1 (REST API) event + */ + private static LambdaEventData extractApiGatewayV1Data(Map event) { + Map headers = extractHeaders(event.get("headers")); + Map pathParameters = extractPathParameters(event.get("pathParameters")); + Object body = extractBody(event); + + Map requestContext = (Map) event.get("requestContext"); + String method = (String) requestContext.get("httpMethod"); + String path = (String) event.get("path"); + + String sourceIp = null; + Object identityObj = requestContext.get("identity"); + if (identityObj instanceof Map) { + Map identity = (Map) identityObj; + sourceIp = (String) identity.get("sourceIp"); + } + + return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.API_GATEWAY_V1_REST, pathParameters, body); + } + + /** + * Extracts data from API Gateway v2 (HTTP API) or Lambda URL event + */ + private static LambdaEventData extractApiGatewayV2HttpData(Map event, LambdaTriggerType triggerType) { + Map headers = extractHeadersWithCookies(event); + Map pathParameters = extractPathParameters(event.get("pathParameters")); + Object body = extractBody(event); + + Map requestContext = (Map) event.get("requestContext"); + Map http = (Map) requestContext.get("http"); + + String method = (String) http.get("method"); + String path = (String) http.get("path"); + String sourceIp = (String) http.get("sourceIp"); + + // Extract port if available + Integer sourcePort = null; + Object portObj = http.get("sourcePort"); + if (portObj instanceof Number) { + sourcePort = ((Number) portObj).intValue(); + } + + return new LambdaEventData(headers, method, path, sourceIp, sourcePort, triggerType, pathParameters, body); + } + + /** + * Extracts data from API Gateway v2 WebSocket event + */ + private static LambdaEventData extractApiGatewayV2WebSocketData(Map event) { + Map headers = extractHeadersWithCookies(event); + Map pathParameters = extractPathParameters(event.get("pathParameters")); + Object body = extractBody(event); + + Map requestContext = (Map) event.get("requestContext"); + + String method = "WEBSOCKET"; + String routeKey = (String) requestContext.get("routeKey"); + String path = routeKey != null ? routeKey : "/"; + + String sourceIp = null; + Object identityObj = requestContext.get("identity"); + if (identityObj instanceof Map) { + Map identity = (Map) identityObj; + sourceIp = (String) identity.get("sourceIp"); + } + + return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET, pathParameters, body); + } + + /** + * Extracts data from ALB event (with or without multi-value headers) + */ + private static LambdaEventData extractAlbData(Map event, LambdaTriggerType triggerType) { + Map headers; + + if (triggerType == LambdaTriggerType.ALB_MULTI_VALUE) { + // Handle multi-value headers (combine multiple values with comma) + headers = new java.util.HashMap<>(); + Object multiValueHeadersObj = event.get("multiValueHeaders"); + if (multiValueHeadersObj instanceof Map) { + Map rawHeaders = (Map) multiValueHeadersObj; + for (Map.Entry entry : rawHeaders.entrySet()) { + if (entry.getKey() != null && entry.getValue() != null) { + String key = String.valueOf(entry.getKey()); + if (entry.getValue() instanceof java.util.List) { + java.util.List values = (java.util.List) entry.getValue(); + // Join multiple values with comma + String joinedValue = values.stream() + .map(String::valueOf) + .collect(java.util.stream.Collectors.joining(", ")); + headers.put(key, joinedValue); + } else { + headers.put(key, String.valueOf(entry.getValue())); + } + } + } + } + } else { + headers = extractHeaders(event.get("headers")); + } + + Map pathParameters = extractPathParameters(event.get("pathParameters")); + Object body = extractBody(event); + + String method = (String) event.get("httpMethod"); + String path = (String) event.get("path"); + String sourceIp = headers.get("x-forwarded-for"); + + return new LambdaEventData(headers, method, path, sourceIp, null, triggerType, pathParameters, body); + } + + /** + * Generic data extraction for unknown trigger types (fallback) + */ + private static LambdaEventData extractGenericData(Map event) { + Map headers = extractHeadersWithCookies(event); + Map pathParameters = extractPathParameters(event.get("pathParameters")); + Object body = extractBody(event); + + String method = null; + String path = null; + String sourceIp = null; + + // Try to extract from requestContext if available + Object requestContextObj = event.get("requestContext"); + if (requestContextObj instanceof Map) { + Map requestContext = (Map) requestContextObj; + + Object httpObj = requestContext.get("http"); + if (httpObj instanceof Map) { + Map http = (Map) httpObj; + method = (String) http.get("method"); + path = (String) http.get("path"); + sourceIp = (String) http.get("sourceIp"); + } else { + Object methodObj = requestContext.get("httpMethod"); + if (methodObj != null) { + method = String.valueOf(methodObj); + } + + Object identityObj = requestContext.get("identity"); + if (identityObj instanceof Map) { + Map identity = (Map) identityObj; + sourceIp = (String) identity.get("sourceIp"); + } + } + } + + // Try root level fields + if (method == null) { + Object methodObj = event.get("httpMethod"); + if (methodObj != null) { + method = String.valueOf(methodObj); + } + } + if (path == null) { + Object pathObj = event.get("path"); + if (pathObj != null) { + path = String.valueOf(pathObj); + } + } + + return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.UNKNOWN, pathParameters, body); + } + + /** + * Generic helper method to extract string key-value pairs from an object. + * Converts all keys and values to strings, filtering out null entries. + */ + private static Map extractStringMap(Object mapObj) { + Map result = new java.util.HashMap<>(); + if (mapObj instanceof Map) { + Map rawMap = (Map) mapObj; + for (Map.Entry entry : rawMap.entrySet()) { + if (entry.getKey() != null && entry.getValue() != null) { + String key = String.valueOf(entry.getKey()); + String value = String.valueOf(entry.getValue()); + result.put(key, value); + } + } + } + return result; + } + + /** + * Helper method to extract headers from event + */ + private static Map extractHeaders(Object headersObj) { + Map headers = extractStringMap(headersObj); + log.debug("Extracted {} headers", headers.size()); + if (headers.containsKey("cookie")) { + log.debug("Cookie header found with value length: {}", headers.get("cookie").length()); + } + return headers; + } + + /** + * Helper method to extract path parameters from event + */ + private static Map extractPathParameters(Object pathParamsObj) { + Map pathParams = extractStringMap(pathParamsObj); + log.debug("Extracted {} path parameters", pathParams.size()); + return pathParams; + } + + /** + * Helper method to extract and merge headers with cookies array from event. + * API Gateway v2 provides a separate 'cookies' array that should be merged with headers. + */ + private static Map extractHeadersWithCookies(Map event) { + Map headers = extractHeaders(event.get("headers")); + + // API Gateway v2 provides a pre-parsed cookies array + Object cookiesObj = event.get("cookies"); + if (cookiesObj instanceof java.util.List) { + java.util.List cookiesList = (java.util.List) cookiesObj; + if (!cookiesList.isEmpty()) { + // Join cookies with "; " separator per RFC 6265 + String cookieValue = cookiesList.stream() + .map(String::valueOf) + .collect(java.util.stream.Collectors.joining("; ")); + + // Merge with existing cookie header if present + String existingCookie = headers.get("cookie"); + if (existingCookie != null && !existingCookie.isEmpty()) { + headers.put("cookie", existingCookie + "; " + cookieValue); + } else { + headers.put("cookie", cookieValue); + } + } + } + + return headers; + } + + /** + * Helper method to extract and parse body from event + */ + private static Object extractBody(Map event) { + Object bodyObj = event.get("body"); + if (bodyObj == null) { + return null; + } + + String bodyString = String.valueOf(bodyObj); + + // Check if body is base64 encoded (API Gateway feature) + Boolean isBase64Encoded = (Boolean) event.get("isBase64Encoded"); + if (Boolean.TRUE.equals(isBase64Encoded)) { + try { + bodyString = new String(Base64.getDecoder().decode(bodyString), StandardCharsets.UTF_8); + } catch (Exception e) { + log.debug("Failed to decode base64 body", e); + return null; + } + } + + // Try to parse as JSON + Object parsedBody = parseBodyAsJson(bodyString); + if (parsedBody != null) { + log.debug("Body parsed as JSON successfully"); + return parsedBody; + } + + // If not JSON, return the raw string + log.debug("Body is not JSON, returning raw string"); + return bodyString; + } + + /** + * Helper method to parse body as JSON + */ + private static Object parseBodyAsJson(String body) { + if (body == null || body.isEmpty() || "null".equals(body)) { + return null; + } + + try { + JsonAdapter adapter = new Moshi.Builder().build().adapter(Object.class); + Object parsed = adapter.fromJson(body); + return parsed; + } catch (Exception e) { + return null; + } + } + + /** + * Temporary RequestContext implementation to hold AppSecRequestContext + * before a span is created. + */ + private static class TemporaryRequestContext implements RequestContext { + private final Object appSecRequestContext; + + TemporaryRequestContext(Object appSecRequestContext) { + this.appSecRequestContext = appSecRequestContext; + } + + @Override + public T getData(RequestContextSlot slot) { + if (slot == RequestContextSlot.APPSEC) { + return (T) appSecRequestContext; + } + return null; + } + + @Override + public TraceSegment getTraceSegment() { + return TraceSegment.NoOp.INSTANCE; + } + + @Override + public void setBlockResponseFunction(BlockResponseFunction blockResponseFunction) { + // No-op for temporary context + } + + @Override + public BlockResponseFunction getBlockResponseFunction() { + return null; + } + + @Override + public T getOrCreateMetaStructTop(String key, Function defaultValue) { + return null; + } + + @Override + public void close() { + // No-op for temporary context + } + } + + /** + * Enum representing different AWS Lambda trigger types + */ + private enum LambdaTriggerType { + API_GATEWAY_V1_REST, // API Gateway REST API (v1) + API_GATEWAY_V2_HTTP, // API Gateway HTTP API (v2) + API_GATEWAY_V2_WEBSOCKET, // API Gateway WebSocket + ALB, // Application Load Balancer + ALB_MULTI_VALUE, // ALB with multi-value headers + LAMBDA_URL, // Lambda Function URL + UNKNOWN // Unknown or unsupported trigger + } + + /** + * Object for Lambda event data needed for AppSec processing + */ + private static class LambdaEventData { + final Map headers; + final String method; + final String path; + final String sourceIp; + final Integer sourcePort; + final LambdaTriggerType triggerType; + final Map pathParameters; + final Object body; + + LambdaEventData(Map headers, String method, String path, String sourceIp, Integer sourcePort, LambdaTriggerType triggerType, Map pathParameters, Object body) { + this.headers = headers; + this.method = method; + this.path = path; + this.sourceIp = sourceIp; + this.sourcePort = sourcePort; + this.triggerType = triggerType; + this.pathParameters = pathParameters; + this.body = body; + } + } + + /** + * URIDataAdapter implementation for Lambda events. + */ + private static class LambdaURIDataAdapter extends URIDataAdapterBase { + private final String path; + private final String query; + + LambdaURIDataAdapter(String pathWithQuery) { + if (pathWithQuery != null) { + int queryIndex = pathWithQuery.indexOf('?'); + if (queryIndex != -1) { + this.path = pathWithQuery.substring(0, queryIndex); + this.query = pathWithQuery.substring(queryIndex + 1); + } else { + this.path = pathWithQuery; + this.query = null; + } + } else { + this.path = "/"; + this.query = null; + } + } + + @Override + public String scheme() { + return "https"; + } + + @Override + public String host() { + return null; + } + + @Override + public int port() { + return 443; + } + + @Override + public String path() { + return path; + } + + @Override + public String fragment() { + return null; + } + + @Override + public String query() { + return query; + } + + @Override + public boolean supportsRaw() { + return true; + } + + @Override + public String rawPath() { + return path; + } + + @Override + public String rawQuery() { + return query; + } + } } From ca445038cb2fd50e9f3320ed78a541a57e71679b Mon Sep 17 00:00:00 2001 From: "clara.poncet" Date: Fri, 30 Jan 2026 16:47:07 +0100 Subject: [PATCH 2/5] refactor + add appsec data to span --- .../lambda/LambdaHandlerInstrumentation.java | 3 +- .../java/datadog/trace/core/CoreTracer.java | 17 +- .../trace/lambda/LambdaAppSecHandler.java | 765 ++++++++++++++++++ .../datadog/trace/lambda/LambdaHandler.java | 732 +---------------- .../instrumentation/api/AgentTracer.java | 9 +- 5 files changed, 796 insertions(+), 730 deletions(-) create mode 100644 dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java index 1f75e292327..0f020c6623a 100644 --- a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java @@ -89,7 +89,7 @@ static AgentScope enter( return null; } String lambdaRequestId = awsContext.getAwsRequestId(); - AgentSpanContext lambdaContext = AgentTracer.get().notifyExtensionStart(in, lambdaRequestId); + AgentSpanContext lambdaContext = AgentTracer.get().notifyLambdaStart(in, lambdaRequestId); final AgentSpan span; if (null == lambdaContext) { span = startSpan(INVOCATION_SPAN_NAME); @@ -123,6 +123,7 @@ static void exit( } String lambdaRequestId = awsContext.getAwsRequestId(); + AgentTracer.get().notifyAppSecEnd(span); span.finish(); AgentTracer.get().notifyExtensionEnd(span, result, null != throwable, lambdaRequestId); } finally { diff --git a/dd-trace-core/src/main/java/datadog/trace/core/CoreTracer.java b/dd-trace-core/src/main/java/datadog/trace/core/CoreTracer.java index 8916e8d757b..c0ce30eafe1 100644 --- a/dd-trace-core/src/main/java/datadog/trace/core/CoreTracer.java +++ b/dd-trace-core/src/main/java/datadog/trace/core/CoreTracer.java @@ -104,6 +104,7 @@ import datadog.trace.core.taginterceptor.RuleFlags; import datadog.trace.core.taginterceptor.TagInterceptor; import datadog.trace.core.traceinterceptor.LatencyTraceInterceptor; +import datadog.trace.lambda.LambdaAppSecHandler; import datadog.trace.lambda.LambdaHandler; import datadog.trace.relocate.api.RatelimitedLogger; import datadog.trace.util.AgentTaskScheduler; @@ -1214,8 +1215,15 @@ public void closeActive() { } @Override - public AgentSpanContext notifyExtensionStart(Object event, String lambdaRequestId) { - return LambdaHandler.notifyStartInvocation(event, lambdaRequestId); + public AgentSpanContext notifyLambdaStart(Object event, String lambdaRequestId) { + // Get context from AppSec + AgentSpanContext appSecContext = LambdaAppSecHandler.processRequestStart(event); + + // Get context from extension + AgentSpanContext extensionContext = LambdaHandler.notifyStartInvocation(event, lambdaRequestId); + + // Merge contexts + return LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext); } @Override @@ -1224,6 +1232,11 @@ public void notifyExtensionEnd( LambdaHandler.notifyEndInvocation(span, result, isError, lambdaRequestId); } + @Override + public void notifyAppSecEnd(AgentSpan span) { + LambdaAppSecHandler.processRequestEnd(span); + } + @Override public AgentDataStreamsMonitoring getDataStreamsMonitoring() { return dataStreamsMonitoring; diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java new file mode 100644 index 00000000000..08bf8950817 --- /dev/null +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java @@ -0,0 +1,765 @@ +package datadog.trace.lambda; + +import static datadog.trace.api.gateway.Events.EVENTS; + +import com.squareup.moshi.JsonAdapter; +import com.squareup.moshi.Moshi; +import datadog.trace.api.function.TriConsumer; +import datadog.trace.api.gateway.BlockResponseFunction; +import datadog.trace.api.gateway.Flow; +import datadog.trace.api.gateway.IGSpanInfo; +import datadog.trace.api.gateway.RequestContext; +import datadog.trace.api.gateway.RequestContextSlot; +import datadog.trace.api.internal.TraceSegment; +import datadog.trace.bootstrap.ActiveSubsystems; +import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.api.AgentSpanContext; +import datadog.trace.bootstrap.instrumentation.api.AgentTracer; +import datadog.trace.bootstrap.instrumentation.api.TagContext; +import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter; +import datadog.trace.bootstrap.instrumentation.api.URIDataAdapterBase; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.Reader; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Collections; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Handles AppSec processing for AWS Lambda invocations. + * Extracts Lambda event data and invokes AppSec gateway callbacks. + */ +public class LambdaAppSecHandler { + + private static final Logger log = LoggerFactory.getLogger(LambdaAppSecHandler.class); + + /** + * Process AppSec request data at the start of a Lambda invocation. + * Extract event data and invokes all relevant AppSec gateway callbacks. + * + * @param event the Lambda event object + * @return AgentSpanContext containing AppSec data, or null if AppSec is disabled or processing fails + */ + public static AgentSpanContext processRequestStart(Object event) { + if (!ActiveSubsystems.APPSEC_ACTIVE) { + log.debug("AppSec is not active, skipping request start processing"); + return null; + } + + if (!(event instanceof ByteArrayInputStream)) { + log.debug("Event is not a ByteArrayInputStream, type: {}", event != null ? event.getClass().getName() : "null"); + return null; + } + + try { + LambdaEventData eventData = extractEventData((ByteArrayInputStream) event); + return processAppSecRequestData(eventData); + } catch (Exception e) { + log.error("Failed to process AppSec request data", e); + return null; + } + } + + /** + * Invokes the requestEnded gateway callback to add AppSec data to the span. + * + * @param span the current span + */ + public static void processRequestEnd(AgentSpan span) { + if (!ActiveSubsystems.APPSEC_ACTIVE || span == null) { + return; + } + + RequestContext requestContext = span.getRequestContext(); + if (requestContext != null) { + AgentTracer.TracerAPI tracer = AgentTracer.get(); + BiFunction> requestEndedCallback = + tracer.getCallbackProvider(RequestContextSlot.APPSEC) + .getCallback(EVENTS.requestEnded()); + if (requestEndedCallback != null) { + requestEndedCallback.apply(requestContext, span); + } else { + log.warn("requestEnded callback is null"); + } + } + } + + /** + * Merge AppSec context data into extension context. + * + * @param extensionContext context from extension + * @param appSecContext context containing AppSec data + * @return merged context + */ + public static AgentSpanContext mergeContexts( + AgentSpanContext extensionContext, AgentSpanContext appSecContext) { + if (appSecContext == null) { + return extensionContext; + } + if (extensionContext == null) { + return appSecContext; + } + + if (appSecContext instanceof TagContext) { + TagContext extracted = (TagContext) appSecContext; + Object appSecData = extracted.getRequestContextDataAppSec(); + Object iastData = extracted.getRequestContextDataIast(); + + if (extensionContext instanceof TagContext) { + TagContext merged = (TagContext) extensionContext; + if (appSecData != null) { + merged.withRequestContextDataAppSec(appSecData); + } + if (iastData != null) { + merged.withRequestContextDataIast(iastData); + } + return merged; + } + + log.warn( + "Cannot merge AppSec data: extension context is not a TagContext: {}", + extensionContext.getClass()); + } + return extensionContext; + } + + private static AgentSpanContext processAppSecRequestData(LambdaEventData eventData) { + AgentTracer.TracerAPI tracer = AgentTracer.get(); + Supplier> requestStartedCallback = + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestStarted()); + if (requestStartedCallback == null) { + log.warn("requestStarted callback is null"); + return null; + } + + TagContext tagContext = new TagContext(); + Object appSecRequestContext; + + // Call requestStarted + appSecRequestContext = requestStartedCallback.get().getResult(); + tagContext.withRequestContextDataAppSec(appSecRequestContext); + + if (appSecRequestContext != null) { + TemporaryRequestContext requestContext = new TemporaryRequestContext(appSecRequestContext); + + // Call requestMethodUriRaw + if (eventData.method != null && eventData.path != null) { + datadog.trace.api.function.TriFunction> methodUriCallback = + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestMethodUriRaw()); + if (methodUriCallback != null) { + LambdaURIDataAdapter uriAdapter = new LambdaURIDataAdapter(eventData.path); + methodUriCallback.apply(requestContext, eventData.method, uriAdapter); + } else { + log.warn("requestMethodUriRaw callback is null"); + } + } + + // Call requestHeader for each header + if (eventData.headers != null && !eventData.headers.isEmpty()) { + TriConsumer headerCallback = + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestHeader()); + if (headerCallback != null) { + for (Map.Entry header : eventData.headers.entrySet()) { + headerCallback.accept(requestContext, header.getKey(), header.getValue()); + } + } else { + log.warn("requestHeader callback is null"); + } + } + + // Call requestClientSocketAddress + if (eventData.sourceIp != null) { + datadog.trace.api.function.TriFunction> socketAddrCallback = + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestClientSocketAddress()); + if (socketAddrCallback != null) { + Integer port = eventData.sourcePort != null ? eventData.sourcePort : 0; + socketAddrCallback.apply(requestContext, eventData.sourceIp, port); + } else { + log.warn("requestClientSocketAddress callback is null"); + } + } + + // Call requestHeaderDone + Function> headerDoneCallback = + tracer + .getCallbackProvider(RequestContextSlot.APPSEC) + .getCallback(EVENTS.requestHeaderDone()); + if (headerDoneCallback != null) { + headerDoneCallback.apply(requestContext); + } else { + log.warn("requestHeaderDone callback is null"); + } + + // Call requestPathParams + if (eventData.pathParameters != null && !eventData.pathParameters.isEmpty()) { + BiFunction, Flow> pathParamsCallback = + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestPathParams()); + if (pathParamsCallback != null) { + pathParamsCallback.apply(requestContext, eventData.pathParameters); + } else { + log.warn("requestPathParams callback is null"); + } + } + + // Call requestBodyProcessed + if (eventData.body != null) { + BiFunction> bodyCallback = + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestBodyProcessed()); + if (bodyCallback != null) { + bodyCallback.apply(requestContext, eventData.body); + } else { + log.warn("requestBodyProcessed callback is null"); + } + } + } + return tagContext; + } + + private static LambdaEventData extractEventData(ByteArrayInputStream inputStream) + throws IOException { + try { + StringBuilder jsonBuilder = new StringBuilder(inputStream.available()); + try (Reader reader = new InputStreamReader(inputStream, StandardCharsets.UTF_8)) { + char[] buffer = new char[1024]; + int charsRead; + while ((charsRead = reader.read(buffer)) != -1) { + jsonBuilder.append(buffer, 0, charsRead); + } + } + return extractEventDataFromJson(jsonBuilder.toString()); + } finally { + inputStream.reset(); + } + } + + private static LambdaEventData extractEventDataFromJson(String json) { + try { + // Parse JSON into a Map + JsonAdapter adapter = + new Moshi.Builder().build().adapter(Map.class); + + Map event = adapter.fromJson(json); + log.debug("Event JSON parsed successfully"); + + if (event == null) { + return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), null); + } + + // Detect trigger type + LambdaTriggerType triggerType = detectTriggerType(event); + log.debug("Detected Lambda trigger type: {}", triggerType); + + // Extract data based on trigger type + switch (triggerType) { + case API_GATEWAY_V1_REST: + return extractApiGatewayV1Data(event); + case API_GATEWAY_V2_HTTP: + case LAMBDA_URL: + return extractApiGatewayV2HttpData(event, triggerType); + case API_GATEWAY_V2_WEBSOCKET: + return extractApiGatewayV2WebSocketData(event); + case ALB: + case ALB_MULTI_VALUE: + return extractAlbData(event, triggerType); + default: + log.debug("Unknown trigger type, attempting generic extraction"); + return extractGenericData(event); + } + } catch (Exception e) { + log.error("Failed to parse event data from JSON", e); + return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), null); + } + } + + private static LambdaTriggerType detectTriggerType(Map event) { + Object requestContextObj = event.get("requestContext"); + + if (requestContextObj instanceof Map) { + Map requestContext = (Map) requestContextObj; + + // Check for ALB trigger (has elb object) + if (requestContext.containsKey("elb")) { + // Check if event has multiValueHeaders + if (event.containsKey("multiValueHeaders")) { + return LambdaTriggerType.ALB_MULTI_VALUE; + } + return LambdaTriggerType.ALB; + } + + // Check for WebSocket + if (requestContext.containsKey("connectionId") && + (requestContext.containsKey("eventType") || requestContext.containsKey("routeKey"))) { + return LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET; + } + + // Check for API Gateway v2 format + Object httpObj = requestContext.get("http"); + if (httpObj instanceof Map) { + Object domainNameObj = requestContext.get("domainName"); + if (domainNameObj instanceof String) { + String domainName = (String) domainNameObj; + if (domainName.contains("lambda-url")) { + return LambdaTriggerType.LAMBDA_URL; + } else { + return LambdaTriggerType.API_GATEWAY_V2_HTTP; + } + } else { + return LambdaTriggerType.LAMBDA_URL; + } + } + + // Check for API Gateway v1 REST API + if (requestContext.containsKey("httpMethod") || requestContext.containsKey("requestId")) { + return LambdaTriggerType.API_GATEWAY_V1_REST; + } + } + return LambdaTriggerType.UNKNOWN; + } + + /** + * Extracts data from API Gateway v1 (REST API) event + */ + private static LambdaEventData extractApiGatewayV1Data(Map event) { + Map headers = extractHeaders(event.get("headers")); + Map pathParameters = extractPathParameters(event.get("pathParameters")); + Object body = extractBody(event); + + Map requestContext = (Map) event.get("requestContext"); + String method = (String) requestContext.get("httpMethod"); + String path = (String) event.get("path"); + + String sourceIp = null; + Object identityObj = requestContext.get("identity"); + if (identityObj instanceof Map) { + Map identity = (Map) identityObj; + sourceIp = (String) identity.get("sourceIp"); + } + + return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.API_GATEWAY_V1_REST, pathParameters, body); + } + + /** + * Extracts data from API Gateway v2 (HTTP API) or Lambda URL event + */ + private static LambdaEventData extractApiGatewayV2HttpData(Map event, LambdaTriggerType triggerType) { + Map headers = extractHeadersWithCookies(event); + Map pathParameters = extractPathParameters(event.get("pathParameters")); + Object body = extractBody(event); + + Map requestContext = (Map) event.get("requestContext"); + Map http = (Map) requestContext.get("http"); + + String method = (String) http.get("method"); + String path = (String) http.get("path"); + String sourceIp = (String) http.get("sourceIp"); + + // Extract port if available + Integer sourcePort = null; + Object portObj = http.get("sourcePort"); + if (portObj instanceof Number) { + sourcePort = ((Number) portObj).intValue(); + } + + return new LambdaEventData(headers, method, path, sourceIp, sourcePort, triggerType, pathParameters, body); + } + + /** + * Extracts data from API Gateway v2 WebSocket event + */ + private static LambdaEventData extractApiGatewayV2WebSocketData(Map event) { + Map headers = extractHeadersWithCookies(event); + Map pathParameters = extractPathParameters(event.get("pathParameters")); + Object body = extractBody(event); + + Map requestContext = (Map) event.get("requestContext"); + + String method = "WEBSOCKET"; + String routeKey = (String) requestContext.get("routeKey"); + String path = routeKey != null ? routeKey : "/"; + + String sourceIp = null; + Object identityObj = requestContext.get("identity"); + if (identityObj instanceof Map) { + Map identity = (Map) identityObj; + sourceIp = (String) identity.get("sourceIp"); + } + + return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET, pathParameters, body); + } + + /** + * Extracts data from ALB event (with or without multi-value headers) + */ + private static LambdaEventData extractAlbData(Map event, LambdaTriggerType triggerType) { + Map headers; + + if (triggerType == LambdaTriggerType.ALB_MULTI_VALUE) { + // Handle multi-value headers (combine multiple values with comma) + headers = new java.util.HashMap<>(); + Object multiValueHeadersObj = event.get("multiValueHeaders"); + if (multiValueHeadersObj instanceof Map) { + Map rawHeaders = (Map) multiValueHeadersObj; + for (Map.Entry entry : rawHeaders.entrySet()) { + if (entry.getKey() != null && entry.getValue() != null) { + String key = String.valueOf(entry.getKey()); + if (entry.getValue() instanceof java.util.List) { + java.util.List values = (java.util.List) entry.getValue(); + // Join multiple values with comma + String joinedValue = values.stream() + .map(String::valueOf) + .collect(java.util.stream.Collectors.joining(", ")); + headers.put(key, joinedValue); + } else { + headers.put(key, String.valueOf(entry.getValue())); + } + } + } + } + } else { + headers = extractHeaders(event.get("headers")); + } + + Map pathParameters = extractPathParameters(event.get("pathParameters")); + Object body = extractBody(event); + + String method = (String) event.get("httpMethod"); + String path = (String) event.get("path"); + String sourceIp = headers.get("x-forwarded-for"); + + return new LambdaEventData(headers, method, path, sourceIp, null, triggerType, pathParameters, body); + } + + /** + * Generic data extraction for unknown trigger types (fallback) + */ + private static LambdaEventData extractGenericData(Map event) { + Map headers = extractHeadersWithCookies(event); + Map pathParameters = extractPathParameters(event.get("pathParameters")); + Object body = extractBody(event); + + String method = null; + String path = null; + String sourceIp = null; + + // Try to extract from requestContext if available + Object requestContextObj = event.get("requestContext"); + if (requestContextObj instanceof Map) { + Map requestContext = (Map) requestContextObj; + + Object httpObj = requestContext.get("http"); + if (httpObj instanceof Map) { + Map http = (Map) httpObj; + method = (String) http.get("method"); + path = (String) http.get("path"); + sourceIp = (String) http.get("sourceIp"); + } else { + Object methodObj = requestContext.get("httpMethod"); + if (methodObj != null) { + method = String.valueOf(methodObj); + } + + Object identityObj = requestContext.get("identity"); + if (identityObj instanceof Map) { + Map identity = (Map) identityObj; + sourceIp = (String) identity.get("sourceIp"); + } + } + } + + // Try root level fields + if (method == null) { + Object methodObj = event.get("httpMethod"); + if (methodObj != null) { + method = String.valueOf(methodObj); + } + } + if (path == null) { + Object pathObj = event.get("path"); + if (pathObj != null) { + path = String.valueOf(pathObj); + } + } + + return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.UNKNOWN, pathParameters, body); + } + + /** + * Generic helper method to extract string key-value pairs from an object. + * Converts all keys and values to strings, filtering out null entries. + */ + private static Map extractStringMap(Object mapObj) { + Map result = new java.util.HashMap<>(); + if (mapObj instanceof Map) { + Map rawMap = (Map) mapObj; + for (Map.Entry entry : rawMap.entrySet()) { + if (entry.getKey() != null && entry.getValue() != null) { + String key = String.valueOf(entry.getKey()); + String value = String.valueOf(entry.getValue()); + result.put(key, value); + } + } + } + return result; + } + + /** + * Helper method to extract headers from event + */ + private static Map extractHeaders(Object headersObj) { + Map headers = extractStringMap(headersObj); + log.debug("Extracted {} headers", headers.size()); + if (headers.containsKey("cookie")) { + log.debug("Cookie header found with value length: {}", headers.get("cookie").length()); + } + return headers; + } + + /** + * Helper method to extract path parameters from event + */ + private static Map extractPathParameters(Object pathParamsObj) { + Map pathParams = extractStringMap(pathParamsObj); + log.debug("Extracted {} path parameters", pathParams.size()); + return pathParams; + } + + /** + * Helper method to extract and merge headers with cookies array from event. + * API Gateway v2 provides a separate 'cookies' array that should be merged with headers. + */ + private static Map extractHeadersWithCookies(Map event) { + Map headers = extractHeaders(event.get("headers")); + + // API Gateway v2 provides a pre-parsed cookies array + Object cookiesObj = event.get("cookies"); + if (cookiesObj instanceof java.util.List) { + java.util.List cookiesList = (java.util.List) cookiesObj; + if (!cookiesList.isEmpty()) { + // Join cookies with "; " separator per RFC 6265 + String cookieValue = cookiesList.stream() + .map(String::valueOf) + .collect(java.util.stream.Collectors.joining("; ")); + + // Merge with existing cookie header if present + String existingCookie = headers.get("cookie"); + if (existingCookie != null && !existingCookie.isEmpty()) { + headers.put("cookie", existingCookie + "; " + cookieValue); + } else { + headers.put("cookie", cookieValue); + } + } + } + + return headers; + } + + /** + * Helper method to extract and parse body from event + */ + private static Object extractBody(Map event) { + Object bodyObj = event.get("body"); + if (bodyObj == null) { + return null; + } + + String bodyString = String.valueOf(bodyObj); + + // Check if body is base64 encoded (API Gateway feature) + Boolean isBase64Encoded = (Boolean) event.get("isBase64Encoded"); + if (Boolean.TRUE.equals(isBase64Encoded)) { + try { + bodyString = new String(Base64.getDecoder().decode(bodyString), StandardCharsets.UTF_8); + } catch (Exception e) { + log.debug("Failed to decode base64 body", e); + return null; + } + } + + // Try to parse as JSON + Object parsedBody = parseBodyAsJson(bodyString); + if (parsedBody != null) { + log.debug("Body parsed as JSON successfully"); + return parsedBody; + } + + // If not JSON, return the raw string + log.debug("Body is not JSON, returning raw string"); + return bodyString; + } + + /** + * Helper method to parse body as JSON + */ + private static Object parseBodyAsJson(String body) { + if (body == null || body.isEmpty() || "null".equals(body)) { + return null; + } + + try { + JsonAdapter adapter = new Moshi.Builder().build().adapter(Object.class); + Object parsed = adapter.fromJson(body); + return parsed; + } catch (Exception e) { + return null; + } + } + + /** + * Temporary RequestContext implementation to hold AppSecRequestContext + * before a span is created. + */ + private static class TemporaryRequestContext implements RequestContext { + private final Object appSecRequestContext; + + TemporaryRequestContext(Object appSecRequestContext) { + this.appSecRequestContext = appSecRequestContext; + } + + @Override + public T getData(RequestContextSlot slot) { + if (slot == RequestContextSlot.APPSEC) { + return (T) appSecRequestContext; + } + return null; + } + + @Override + public TraceSegment getTraceSegment() { + return TraceSegment.NoOp.INSTANCE; + } + + @Override + public void setBlockResponseFunction(BlockResponseFunction blockResponseFunction) { + // No-op for temporary context + } + + @Override + public BlockResponseFunction getBlockResponseFunction() { + return null; + } + + @Override + public T getOrCreateMetaStructTop(String key, Function defaultValue) { + return null; + } + + @Override + public void close() { + // No-op for temporary context + } + } + + /** + * Enum representing different AWS Lambda trigger types + */ + private enum LambdaTriggerType { + API_GATEWAY_V1_REST, // API Gateway REST API (v1) + API_GATEWAY_V2_HTTP, // API Gateway HTTP API (v2) + API_GATEWAY_V2_WEBSOCKET, // API Gateway WebSocket + ALB, // Application Load Balancer + ALB_MULTI_VALUE, // ALB with multi-value headers + LAMBDA_URL, // Lambda Function URL + UNKNOWN // Unknown or unsupported trigger + } + + /** + * Object for Lambda event data needed for AppSec processing + */ + private static class LambdaEventData { + final Map headers; + final String method; + final String path; + final String sourceIp; + final Integer sourcePort; + final LambdaTriggerType triggerType; + final Map pathParameters; + final Object body; + + LambdaEventData(Map headers, String method, String path, String sourceIp, Integer sourcePort, LambdaTriggerType triggerType, Map pathParameters, Object body) { + this.headers = headers; + this.method = method; + this.path = path; + this.sourceIp = sourceIp; + this.sourcePort = sourcePort; + this.triggerType = triggerType; + this.pathParameters = pathParameters; + this.body = body; + } + } + + /** + * URIDataAdapter implementation for Lambda events. + */ + private static class LambdaURIDataAdapter extends URIDataAdapterBase { + private final String path; + private final String query; + + LambdaURIDataAdapter(String pathWithQuery) { + if (pathWithQuery != null) { + int queryIndex = pathWithQuery.indexOf('?'); + if (queryIndex != -1) { + this.path = pathWithQuery.substring(0, queryIndex); + this.query = pathWithQuery.substring(queryIndex + 1); + } else { + this.path = pathWithQuery; + this.query = null; + } + } else { + this.path = "/"; + this.query = null; + } + } + + @Override + public String scheme() { + return "https"; + } + + @Override + public String host() { + return null; + } + + @Override + public int port() { + return 443; + } + + @Override + public String path() { + return path; + } + + @Override + public String fragment() { + return null; + } + + @Override + public String query() { + return query; + } + + @Override + public boolean supportsRaw() { + return true; + } + + @Override + public String rawPath() { + return path; + } + + @Override + public String rawQuery() { + return query; + } + } +} diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaHandler.java index 728910042fe..5e2eed69469 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaHandler.java @@ -1,6 +1,5 @@ package datadog.trace.lambda; -import static datadog.trace.api.gateway.Events.EVENTS; import static datadog.trace.bootstrap.instrumentation.api.AgentPropagation.extractContextAndGetSpanContext; import static java.util.concurrent.TimeUnit.SECONDS; @@ -8,31 +7,12 @@ import com.squareup.moshi.Moshi; import datadog.trace.api.DDSpanId; import datadog.trace.api.DDTags; -import datadog.trace.api.function.TriConsumer; -import datadog.trace.api.gateway.BlockResponseFunction; -import datadog.trace.api.gateway.Flow; -import datadog.trace.api.gateway.IGSpanInfo; -import datadog.trace.api.gateway.RequestContext; -import datadog.trace.api.gateway.RequestContextSlot; -import datadog.trace.api.internal.TraceSegment; import datadog.trace.bootstrap.instrumentation.api.AgentSpan; import datadog.trace.bootstrap.instrumentation.api.AgentSpanContext; -import datadog.trace.bootstrap.instrumentation.api.TagContext; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStreamReader; -import java.io.Reader; import java.nio.charset.StandardCharsets; import java.util.Base64; -import java.util.Collections; -import java.util.Map; -import java.util.function.BiFunction; -import java.util.function.Function; -import java.util.function.Supplier; -import datadog.trace.bootstrap.instrumentation.api.AgentTracer; -import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter; -import datadog.trace.bootstrap.instrumentation.api.URIDataAdapterBase; import okhttp3.ConnectionPool; import okhttp3.MediaType; import okhttp3.OkHttpClient; @@ -93,21 +73,6 @@ public class LambdaHandler { private static String EXTENSION_BASE_URL = "http://127.0.0.1:8124"; public static AgentSpanContext notifyStartInvocation(Object event, String lambdaRequestId) { - AgentTracer.TracerAPI tracer = AgentTracer.get(); - AgentSpanContext extractedContext = null; - - // Extract headers and call AppSec gateway events - if (event instanceof ByteArrayInputStream) { - try { - LambdaEventData eventData = extractEventData((ByteArrayInputStream) event); - extractedContext = callIGCallbackStart(tracer, eventData); - } catch (Exception e) { - log.error("Failed to extract data from event stream", e); - } - } else { - log.debug("Event is not a ByteArrayInputStream, type: {}", event != null ? event.getClass().getName() : "null"); - } - RequestBody body = RequestBody.create(jsonMediaType, writeValueAsString(event)); try (Response response = HTTP_CLIENT @@ -120,17 +85,13 @@ public static AgentSpanContext notifyStartInvocation(Object event, String lambda .build()) .execute()) { if (response.isSuccessful()) { - AgentSpanContext extensionContext = - extractContextAndGetSpanContext( - response.headers(), - (carrier, classifier) -> { - for (String headerName : carrier.names()) { - classifier.accept(headerName, carrier.get(headerName)); - } - }); - // Merge the AppSec context with the extension context - AgentSpanContext mergedContext = mergeContexts(extensionContext, extractedContext); - return mergedContext; + return extractContextAndGetSpanContext( + response.headers(), + (carrier, classifier) -> { + for (String headerName : carrier.names()) { + classifier.accept(headerName, carrier.get(headerName)); + } + }); } else { log.debug("Extension call failed with status: {}", response.code()); } @@ -147,18 +108,6 @@ public static boolean notifyEndInvocation( return false; } - // Call requestEnded event - RequestContext requestContext = span.getRequestContext(); - if (requestContext != null) { - AgentTracer.TracerAPI tracer = AgentTracer.get(); - BiFunction> requestEndedCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC) - .getCallback(EVENTS.requestEnded()); - if (requestEndedCallback != null) { - requestEndedCallback.apply(requestContext, span); - } - } - RequestBody body = RequestBody.create(jsonMediaType, writeValueAsString(result)); Request.Builder builder = new Request.Builder() @@ -218,671 +167,4 @@ public static String writeValueAsString(Object obj) { public static void setExtensionBaseUrl(String extensionBaseUrl) { EXTENSION_BASE_URL = extensionBaseUrl; } - - private static AgentSpanContext callIGCallbackStart( - AgentTracer.TracerAPI tracer, LambdaEventData eventData) { - Supplier> requestStartedCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestStarted()); - if (requestStartedCallback == null) { - log.debug("requestStarted callback is null"); - return null; - } - - TagContext tagContext = new TagContext(); - Object appSecRequestContext; - - // Call requestStarted - appSecRequestContext = requestStartedCallback.get().getResult(); - tagContext.withRequestContextDataAppSec(appSecRequestContext); - - if (appSecRequestContext != null) { - TemporaryRequestContext requestContext = new TemporaryRequestContext(appSecRequestContext); - - // Call requestMethodUriRaw - if (eventData.method != null && eventData.path != null) { - datadog.trace.api.function.TriFunction> methodUriCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestMethodUriRaw()); - if (methodUriCallback != null) { - LambdaURIDataAdapter uriAdapter = new LambdaURIDataAdapter(eventData.path); - methodUriCallback.apply(requestContext, eventData.method, uriAdapter); - } else { - log.debug("requestMethodUriRaw callback is null"); - } - } - - // Call requestHeader for each header - if (eventData.headers != null && !eventData.headers.isEmpty()) { - TriConsumer headerCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestHeader()); - if (headerCallback != null) { - for (Map.Entry header : eventData.headers.entrySet()) { - headerCallback.accept(requestContext, header.getKey(), header.getValue()); - } - } else { - log.debug("requestHeader callback is null"); - } - } - - // Call requestClientSocketAddress - if (eventData.sourceIp != null) { - datadog.trace.api.function.TriFunction> socketAddrCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestClientSocketAddress()); - if (socketAddrCallback != null) { - Integer port = eventData.sourcePort != null ? eventData.sourcePort : 0; - socketAddrCallback.apply(requestContext, eventData.sourceIp, port); - } else { - log.debug("requestClientSocketAddress callback is null"); - } - } - - // Call requestHeaderDone - Function> headerDoneCallback = - tracer - .getCallbackProvider(RequestContextSlot.APPSEC) - .getCallback(EVENTS.requestHeaderDone()); - if (headerDoneCallback != null) { - headerDoneCallback.apply(requestContext); - } else { - log.debug("requestHeaderDone callback is null"); - } - - // Call requestPathParams - if (eventData.pathParameters != null && !eventData.pathParameters.isEmpty()) { - BiFunction, Flow> pathParamsCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestPathParams()); - if (pathParamsCallback != null) { - pathParamsCallback.apply(requestContext, eventData.pathParameters); - } else { - log.debug("requestPathParams callback is null"); - } - } - - // Call requestBodyProcessed - if (eventData.body != null) { - BiFunction> bodyCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestBodyProcessed()); - if (bodyCallback != null) { - bodyCallback.apply(requestContext, eventData.body); - } else { - log.debug("requestBodyProcessed callback is null"); - } - } - } - return tagContext; - } - - private static AgentSpanContext mergeContexts( - AgentSpanContext extensionContext, AgentSpanContext extractedContext) { - if (extractedContext == null) { - return extensionContext; - } - if (extensionContext == null) { - return extractedContext; - } - - if (extractedContext instanceof TagContext) { - TagContext extracted = (TagContext) extractedContext; - Object appSecData = extracted.getRequestContextDataAppSec(); - Object iastData = extracted.getRequestContextDataIast(); - - if (extensionContext instanceof TagContext) { - TagContext merged = (TagContext) extensionContext; - if (appSecData != null) { - merged.withRequestContextDataAppSec(appSecData); - } - if (iastData != null) { - merged.withRequestContextDataIast(iastData); - } - return merged; - } - - log.warn( - "Cannot merge AppSec data: extension context is not a TagContext: {}", - extensionContext.getClass()); - } - return extensionContext; - } - - private static LambdaEventData extractEventData(ByteArrayInputStream inputStream) - throws IOException { - inputStream.mark(0); - - try { - StringBuilder jsonBuilder = new StringBuilder(inputStream.available()); - try (Reader reader = new InputStreamReader(inputStream, StandardCharsets.UTF_8)) { - char[] buffer = new char[1024]; - int charsRead; - while ((charsRead = reader.read(buffer)) != -1) { - jsonBuilder.append(buffer, 0, charsRead); - } - } - return extractEventDataFromJson(jsonBuilder.toString()); - } finally { - inputStream.reset(); - } - } - - private static LambdaEventData extractEventDataFromJson(String json) { - try { - // Parse JSON into a Map - JsonAdapter adapter = - new Moshi.Builder().build().adapter(Map.class); - - Map event = adapter.fromJson(json); - log.debug("Event JSON parsed successfully"); - - if (event == null) { - return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), null); - } - - // Detect trigger type - LambdaTriggerType triggerType = detectTriggerType(event); - log.debug("Detected Lambda trigger type: {}", triggerType); - - // Extract data based on trigger type - switch (triggerType) { - case API_GATEWAY_V1_REST: - return extractApiGatewayV1Data(event); - case API_GATEWAY_V2_HTTP: - case LAMBDA_URL: - return extractApiGatewayV2HttpData(event, triggerType); - case API_GATEWAY_V2_WEBSOCKET: - return extractApiGatewayV2WebSocketData(event); - case ALB: - case ALB_MULTI_VALUE: - return extractAlbData(event, triggerType); - default: - log.debug("Unknown trigger type, attempting generic extraction"); - return extractGenericData(event); - } - } catch (Exception e) { - log.error("Failed to parse event data from JSON", e); - return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), null); - } - } - - private static LambdaTriggerType detectTriggerType(Map event) { - Object requestContextObj = event.get("requestContext"); - - if (requestContextObj instanceof Map) { - Map requestContext = (Map) requestContextObj; - - // Check for ALB trigger (has elb object) - if (requestContext.containsKey("elb")) { - // Check if event has multiValueHeaders - if (event.containsKey("multiValueHeaders")) { - return LambdaTriggerType.ALB_MULTI_VALUE; - } - return LambdaTriggerType.ALB; - } - - // Check for WebSocket - if (requestContext.containsKey("connectionId") && - (requestContext.containsKey("eventType") || requestContext.containsKey("routeKey"))) { - return LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET; - } - - // Check for API Gateway v2 format - Object httpObj = requestContext.get("http"); - if (httpObj instanceof Map) { - Object domainNameObj = requestContext.get("domainName"); - if (domainNameObj instanceof String) { - String domainName = (String) domainNameObj; - if (domainName.contains("lambda-url")) { - return LambdaTriggerType.LAMBDA_URL; - } else { - return LambdaTriggerType.API_GATEWAY_V2_HTTP; - } - } else { - return LambdaTriggerType.LAMBDA_URL; - } - } - - // Check for API Gateway v1 REST API - if (requestContext.containsKey("httpMethod") || requestContext.containsKey("requestId")) { - return LambdaTriggerType.API_GATEWAY_V1_REST; - } - } - return LambdaTriggerType.UNKNOWN; - } - - /** - * Extracts data from API Gateway v1 (REST API) event - */ - private static LambdaEventData extractApiGatewayV1Data(Map event) { - Map headers = extractHeaders(event.get("headers")); - Map pathParameters = extractPathParameters(event.get("pathParameters")); - Object body = extractBody(event); - - Map requestContext = (Map) event.get("requestContext"); - String method = (String) requestContext.get("httpMethod"); - String path = (String) event.get("path"); - - String sourceIp = null; - Object identityObj = requestContext.get("identity"); - if (identityObj instanceof Map) { - Map identity = (Map) identityObj; - sourceIp = (String) identity.get("sourceIp"); - } - - return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.API_GATEWAY_V1_REST, pathParameters, body); - } - - /** - * Extracts data from API Gateway v2 (HTTP API) or Lambda URL event - */ - private static LambdaEventData extractApiGatewayV2HttpData(Map event, LambdaTriggerType triggerType) { - Map headers = extractHeadersWithCookies(event); - Map pathParameters = extractPathParameters(event.get("pathParameters")); - Object body = extractBody(event); - - Map requestContext = (Map) event.get("requestContext"); - Map http = (Map) requestContext.get("http"); - - String method = (String) http.get("method"); - String path = (String) http.get("path"); - String sourceIp = (String) http.get("sourceIp"); - - // Extract port if available - Integer sourcePort = null; - Object portObj = http.get("sourcePort"); - if (portObj instanceof Number) { - sourcePort = ((Number) portObj).intValue(); - } - - return new LambdaEventData(headers, method, path, sourceIp, sourcePort, triggerType, pathParameters, body); - } - - /** - * Extracts data from API Gateway v2 WebSocket event - */ - private static LambdaEventData extractApiGatewayV2WebSocketData(Map event) { - Map headers = extractHeadersWithCookies(event); - Map pathParameters = extractPathParameters(event.get("pathParameters")); - Object body = extractBody(event); - - Map requestContext = (Map) event.get("requestContext"); - - String method = "WEBSOCKET"; - String routeKey = (String) requestContext.get("routeKey"); - String path = routeKey != null ? routeKey : "/"; - - String sourceIp = null; - Object identityObj = requestContext.get("identity"); - if (identityObj instanceof Map) { - Map identity = (Map) identityObj; - sourceIp = (String) identity.get("sourceIp"); - } - - return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET, pathParameters, body); - } - - /** - * Extracts data from ALB event (with or without multi-value headers) - */ - private static LambdaEventData extractAlbData(Map event, LambdaTriggerType triggerType) { - Map headers; - - if (triggerType == LambdaTriggerType.ALB_MULTI_VALUE) { - // Handle multi-value headers (combine multiple values with comma) - headers = new java.util.HashMap<>(); - Object multiValueHeadersObj = event.get("multiValueHeaders"); - if (multiValueHeadersObj instanceof Map) { - Map rawHeaders = (Map) multiValueHeadersObj; - for (Map.Entry entry : rawHeaders.entrySet()) { - if (entry.getKey() != null && entry.getValue() != null) { - String key = String.valueOf(entry.getKey()); - if (entry.getValue() instanceof java.util.List) { - java.util.List values = (java.util.List) entry.getValue(); - // Join multiple values with comma - String joinedValue = values.stream() - .map(String::valueOf) - .collect(java.util.stream.Collectors.joining(", ")); - headers.put(key, joinedValue); - } else { - headers.put(key, String.valueOf(entry.getValue())); - } - } - } - } - } else { - headers = extractHeaders(event.get("headers")); - } - - Map pathParameters = extractPathParameters(event.get("pathParameters")); - Object body = extractBody(event); - - String method = (String) event.get("httpMethod"); - String path = (String) event.get("path"); - String sourceIp = headers.get("x-forwarded-for"); - - return new LambdaEventData(headers, method, path, sourceIp, null, triggerType, pathParameters, body); - } - - /** - * Generic data extraction for unknown trigger types (fallback) - */ - private static LambdaEventData extractGenericData(Map event) { - Map headers = extractHeadersWithCookies(event); - Map pathParameters = extractPathParameters(event.get("pathParameters")); - Object body = extractBody(event); - - String method = null; - String path = null; - String sourceIp = null; - - // Try to extract from requestContext if available - Object requestContextObj = event.get("requestContext"); - if (requestContextObj instanceof Map) { - Map requestContext = (Map) requestContextObj; - - Object httpObj = requestContext.get("http"); - if (httpObj instanceof Map) { - Map http = (Map) httpObj; - method = (String) http.get("method"); - path = (String) http.get("path"); - sourceIp = (String) http.get("sourceIp"); - } else { - Object methodObj = requestContext.get("httpMethod"); - if (methodObj != null) { - method = String.valueOf(methodObj); - } - - Object identityObj = requestContext.get("identity"); - if (identityObj instanceof Map) { - Map identity = (Map) identityObj; - sourceIp = (String) identity.get("sourceIp"); - } - } - } - - // Try root level fields - if (method == null) { - Object methodObj = event.get("httpMethod"); - if (methodObj != null) { - method = String.valueOf(methodObj); - } - } - if (path == null) { - Object pathObj = event.get("path"); - if (pathObj != null) { - path = String.valueOf(pathObj); - } - } - - return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.UNKNOWN, pathParameters, body); - } - - /** - * Generic helper method to extract string key-value pairs from an object. - * Converts all keys and values to strings, filtering out null entries. - */ - private static Map extractStringMap(Object mapObj) { - Map result = new java.util.HashMap<>(); - if (mapObj instanceof Map) { - Map rawMap = (Map) mapObj; - for (Map.Entry entry : rawMap.entrySet()) { - if (entry.getKey() != null && entry.getValue() != null) { - String key = String.valueOf(entry.getKey()); - String value = String.valueOf(entry.getValue()); - result.put(key, value); - } - } - } - return result; - } - - /** - * Helper method to extract headers from event - */ - private static Map extractHeaders(Object headersObj) { - Map headers = extractStringMap(headersObj); - log.debug("Extracted {} headers", headers.size()); - if (headers.containsKey("cookie")) { - log.debug("Cookie header found with value length: {}", headers.get("cookie").length()); - } - return headers; - } - - /** - * Helper method to extract path parameters from event - */ - private static Map extractPathParameters(Object pathParamsObj) { - Map pathParams = extractStringMap(pathParamsObj); - log.debug("Extracted {} path parameters", pathParams.size()); - return pathParams; - } - - /** - * Helper method to extract and merge headers with cookies array from event. - * API Gateway v2 provides a separate 'cookies' array that should be merged with headers. - */ - private static Map extractHeadersWithCookies(Map event) { - Map headers = extractHeaders(event.get("headers")); - - // API Gateway v2 provides a pre-parsed cookies array - Object cookiesObj = event.get("cookies"); - if (cookiesObj instanceof java.util.List) { - java.util.List cookiesList = (java.util.List) cookiesObj; - if (!cookiesList.isEmpty()) { - // Join cookies with "; " separator per RFC 6265 - String cookieValue = cookiesList.stream() - .map(String::valueOf) - .collect(java.util.stream.Collectors.joining("; ")); - - // Merge with existing cookie header if present - String existingCookie = headers.get("cookie"); - if (existingCookie != null && !existingCookie.isEmpty()) { - headers.put("cookie", existingCookie + "; " + cookieValue); - } else { - headers.put("cookie", cookieValue); - } - } - } - - return headers; - } - - /** - * Helper method to extract and parse body from event - */ - private static Object extractBody(Map event) { - Object bodyObj = event.get("body"); - if (bodyObj == null) { - return null; - } - - String bodyString = String.valueOf(bodyObj); - - // Check if body is base64 encoded (API Gateway feature) - Boolean isBase64Encoded = (Boolean) event.get("isBase64Encoded"); - if (Boolean.TRUE.equals(isBase64Encoded)) { - try { - bodyString = new String(Base64.getDecoder().decode(bodyString), StandardCharsets.UTF_8); - } catch (Exception e) { - log.debug("Failed to decode base64 body", e); - return null; - } - } - - // Try to parse as JSON - Object parsedBody = parseBodyAsJson(bodyString); - if (parsedBody != null) { - log.debug("Body parsed as JSON successfully"); - return parsedBody; - } - - // If not JSON, return the raw string - log.debug("Body is not JSON, returning raw string"); - return bodyString; - } - - /** - * Helper method to parse body as JSON - */ - private static Object parseBodyAsJson(String body) { - if (body == null || body.isEmpty() || "null".equals(body)) { - return null; - } - - try { - JsonAdapter adapter = new Moshi.Builder().build().adapter(Object.class); - Object parsed = adapter.fromJson(body); - return parsed; - } catch (Exception e) { - return null; - } - } - - /** - * Temporary RequestContext implementation to hold AppSecRequestContext - * before a span is created. - */ - private static class TemporaryRequestContext implements RequestContext { - private final Object appSecRequestContext; - - TemporaryRequestContext(Object appSecRequestContext) { - this.appSecRequestContext = appSecRequestContext; - } - - @Override - public T getData(RequestContextSlot slot) { - if (slot == RequestContextSlot.APPSEC) { - return (T) appSecRequestContext; - } - return null; - } - - @Override - public TraceSegment getTraceSegment() { - return TraceSegment.NoOp.INSTANCE; - } - - @Override - public void setBlockResponseFunction(BlockResponseFunction blockResponseFunction) { - // No-op for temporary context - } - - @Override - public BlockResponseFunction getBlockResponseFunction() { - return null; - } - - @Override - public T getOrCreateMetaStructTop(String key, Function defaultValue) { - return null; - } - - @Override - public void close() { - // No-op for temporary context - } - } - - /** - * Enum representing different AWS Lambda trigger types - */ - private enum LambdaTriggerType { - API_GATEWAY_V1_REST, // API Gateway REST API (v1) - API_GATEWAY_V2_HTTP, // API Gateway HTTP API (v2) - API_GATEWAY_V2_WEBSOCKET, // API Gateway WebSocket - ALB, // Application Load Balancer - ALB_MULTI_VALUE, // ALB with multi-value headers - LAMBDA_URL, // Lambda Function URL - UNKNOWN // Unknown or unsupported trigger - } - - /** - * Object for Lambda event data needed for AppSec processing - */ - private static class LambdaEventData { - final Map headers; - final String method; - final String path; - final String sourceIp; - final Integer sourcePort; - final LambdaTriggerType triggerType; - final Map pathParameters; - final Object body; - - LambdaEventData(Map headers, String method, String path, String sourceIp, Integer sourcePort, LambdaTriggerType triggerType, Map pathParameters, Object body) { - this.headers = headers; - this.method = method; - this.path = path; - this.sourceIp = sourceIp; - this.sourcePort = sourcePort; - this.triggerType = triggerType; - this.pathParameters = pathParameters; - this.body = body; - } - } - - /** - * URIDataAdapter implementation for Lambda events. - */ - private static class LambdaURIDataAdapter extends URIDataAdapterBase { - private final String path; - private final String query; - - LambdaURIDataAdapter(String pathWithQuery) { - if (pathWithQuery != null) { - int queryIndex = pathWithQuery.indexOf('?'); - if (queryIndex != -1) { - this.path = pathWithQuery.substring(0, queryIndex); - this.query = pathWithQuery.substring(queryIndex + 1); - } else { - this.path = pathWithQuery; - this.query = null; - } - } else { - this.path = "/"; - this.query = null; - } - } - - @Override - public String scheme() { - return "https"; - } - - @Override - public String host() { - return null; - } - - @Override - public int port() { - return 443; - } - - @Override - public String path() { - return path; - } - - @Override - public String fragment() { - return null; - } - - @Override - public String query() { - return query; - } - - @Override - public boolean supportsRaw() { - return true; - } - - @Override - public String rawPath() { - return path; - } - - @Override - public String rawQuery() { - return query; - } - } } diff --git a/internal-api/src/main/java/datadog/trace/bootstrap/instrumentation/api/AgentTracer.java b/internal-api/src/main/java/datadog/trace/bootstrap/instrumentation/api/AgentTracer.java index fb3eb9f853b..23043f80e68 100644 --- a/internal-api/src/main/java/datadog/trace/bootstrap/instrumentation/api/AgentTracer.java +++ b/internal-api/src/main/java/datadog/trace/bootstrap/instrumentation/api/AgentTracer.java @@ -415,10 +415,12 @@ default SpanBuilder singleSpanBuilder(CharSequence spanName) { CallbackProvider getUniversalCallbackProvider(); - AgentSpanContext notifyExtensionStart(Object event, String lambdaRequestId); + AgentSpanContext notifyLambdaStart(Object event, String lambdaRequestId); void notifyExtensionEnd(AgentSpan span, Object result, boolean isError, String lambdaRequestId); + void notifyAppSecEnd(AgentSpan span); + AgentDataStreamsMonitoring getDataStreamsMonitoring(); String getTraceId(AgentSpan span); @@ -662,7 +664,7 @@ public EndpointTracker onRootSpanStarted(AgentSpan root) { } @Override - public AgentSpanContext notifyExtensionStart(Object event, String lambdaRequestId) { + public AgentSpanContext notifyLambdaStart(Object event, String lambdaRequestId) { return null; } @@ -670,6 +672,9 @@ public AgentSpanContext notifyExtensionStart(Object event, String lambdaRequestI public void notifyExtensionEnd( AgentSpan span, Object result, boolean isError, String lambdaRequestId) {} + @Override + public void notifyAppSecEnd(AgentSpan span) {} + @Override public AgentDataStreamsMonitoring getDataStreamsMonitoring() { return NoopDataStreamsMonitoring.INSTANCE; From 8b3857fc8cfe6b67e660412c8195d0fc72be5935 Mon Sep 17 00:00:00 2001 From: "clara.poncet" Date: Thu, 5 Feb 2026 15:12:59 +0100 Subject: [PATCH 3/5] unit tests --- .../trace/lambda/LambdaAppSecHandler.java | 37 +- .../lambda/LambdaAppSecHandlerTest.groovy | 1339 +++++++++++++++++ 2 files changed, 1362 insertions(+), 14 deletions(-) create mode 100644 dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java index 08bf8950817..610f7b2c518 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java @@ -4,6 +4,7 @@ import com.squareup.moshi.JsonAdapter; import com.squareup.moshi.Moshi; +import datadog.trace.api.Config; import datadog.trace.api.function.TriConsumer; import datadog.trace.api.gateway.BlockResponseFunction; import datadog.trace.api.gateway.Flow; @@ -40,6 +41,12 @@ public class LambdaAppSecHandler { private static final Logger log = LoggerFactory.getLogger(LambdaAppSecHandler.class); + private static final Moshi MOSHI = new Moshi.Builder().build(); + private static final JsonAdapter MAP_ADAPTER = MOSHI.adapter(Map.class); + private static final JsonAdapter OBJECT_ADAPTER = MOSHI.adapter(Object.class); + + private static final int MAX_EVENT_SIZE = Config.get().getAppSecBodyParsingSizeLimit(); + /** * Process AppSec request data at the start of a Lambda invocation. * Extract event data and invokes all relevant AppSec gateway callbacks. @@ -110,16 +117,12 @@ public static AgentSpanContext mergeContexts( if (appSecContext instanceof TagContext) { TagContext extracted = (TagContext) appSecContext; Object appSecData = extracted.getRequestContextDataAppSec(); - Object iastData = extracted.getRequestContextDataIast(); if (extensionContext instanceof TagContext) { TagContext merged = (TagContext) extensionContext; if (appSecData != null) { merged.withRequestContextDataAppSec(appSecData); } - if (iastData != null) { - merged.withRequestContextDataIast(iastData); - } return merged; } @@ -224,8 +227,18 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa private static LambdaEventData extractEventData(ByteArrayInputStream inputStream) throws IOException { + inputStream.mark(0); try { - StringBuilder jsonBuilder = new StringBuilder(inputStream.available()); + int availableBytes = inputStream.available(); + + if (availableBytes <= 0 || availableBytes > MAX_EVENT_SIZE) { + log.warn("Event size {} exceeds limit {} or is invalid, skipping AppSec processing", + availableBytes, MAX_EVENT_SIZE); + return new LambdaEventData(Collections.emptyMap(), null, null, null, null, + LambdaTriggerType.UNKNOWN, Collections.emptyMap(), null); + } + + StringBuilder jsonBuilder = new StringBuilder(availableBytes); try (Reader reader = new InputStreamReader(inputStream, StandardCharsets.UTF_8)) { char[] buffer = new char[1024]; int charsRead; @@ -242,10 +255,7 @@ private static LambdaEventData extractEventData(ByteArrayInputStream inputStream private static LambdaEventData extractEventDataFromJson(String json) { try { // Parse JSON into a Map - JsonAdapter adapter = - new Moshi.Builder().build().adapter(Map.class); - - Map event = adapter.fromJson(json); + Map event = MAP_ADAPTER.fromJson(json); log.debug("Event JSON parsed successfully"); if (event == null) { @@ -278,7 +288,7 @@ private static LambdaEventData extractEventDataFromJson(String json) { } } - private static LambdaTriggerType detectTriggerType(Map event) { + static LambdaTriggerType detectTriggerType(Map event) { Object requestContextObj = event.get("requestContext"); if (requestContextObj instanceof Map) { @@ -603,8 +613,7 @@ private static Object parseBodyAsJson(String body) { } try { - JsonAdapter adapter = new Moshi.Builder().build().adapter(Object.class); - Object parsed = adapter.fromJson(body); + Object parsed = OBJECT_ADAPTER.fromJson(body); return parsed; } catch (Exception e) { return null; @@ -659,7 +668,7 @@ public void close() { /** * Enum representing different AWS Lambda trigger types */ - private enum LambdaTriggerType { + enum LambdaTriggerType { API_GATEWAY_V1_REST, // API Gateway REST API (v1) API_GATEWAY_V2_HTTP, // API Gateway HTTP API (v2) API_GATEWAY_V2_WEBSOCKET, // API Gateway WebSocket @@ -672,7 +681,7 @@ private enum LambdaTriggerType { /** * Object for Lambda event data needed for AppSec processing */ - private static class LambdaEventData { + static class LambdaEventData { final Map headers; final String method; final String path; diff --git a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy new file mode 100644 index 00000000000..eea29fe68c8 --- /dev/null +++ b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy @@ -0,0 +1,1339 @@ +package datadog.trace.lambda + +import datadog.trace.api.Config +import datadog.trace.api.function.TriConsumer +import datadog.trace.api.gateway.CallbackProvider +import datadog.trace.api.gateway.Flow +import datadog.trace.api.gateway.RequestContext +import datadog.trace.api.gateway.RequestContextSlot +import datadog.trace.bootstrap.ActiveSubsystems +import datadog.trace.bootstrap.instrumentation.api.AgentSpan +import datadog.trace.bootstrap.instrumentation.api.AgentSpanContext +import datadog.trace.bootstrap.instrumentation.api.AgentTracer +import datadog.trace.bootstrap.instrumentation.api.TagContext +import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter +import datadog.trace.core.test.DDCoreSpecification +import spock.lang.Shared + +import java.nio.charset.StandardCharsets +import java.util.function.BiFunction +import java.util.function.Function +import java.util.function.Supplier + +import static datadog.trace.api.gateway.Events.EVENTS + +class LambdaAppSecHandlerTest extends DDCoreSpecification { + + @Shared + def originalAppSecActive + + def setupSpec() { + originalAppSecActive = ActiveSubsystems.APPSEC_ACTIVE + } + + def cleanupSpec() { + ActiveSubsystems.APPSEC_ACTIVE = originalAppSecActive + } + + def setup() { + ActiveSubsystems.APPSEC_ACTIVE = true + } + + def "processRequestStart returns null when AppSec is disabled"() { + given: + ActiveSubsystems.APPSEC_ACTIVE = false + def event = createInputStream('{"test": "data"}') + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result == null + } + + def "processRequestStart returns null for non-ByteArrayInputStream"() { + when: + def result = LambdaAppSecHandler.processRequestStart("not a stream") + + then: + result == null + } + + def "processRequestStart returns null for null event"() { + when: + def result = LambdaAppSecHandler.processRequestStart(null) + + then: + result == null + } + + def "processRequestStart returns null for oversized event"() { + given: + def maxSize = Config.get().getAppSecBodyParsingSizeLimit() + def largeBody = "x" * (maxSize + 1) + def event = createInputStream(largeBody) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result == null + } + + def "processRequestStart returns null for zero-size event"() { + given: + def event = createInputStream('') + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result == null + } + + def "processRequestStart returns null for malformed JSON"() { + given: + def event = createInputStream('{invalid json') + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result == null + } + + def "stream can be read multiple times after processing"() { + given: + def jsonData = '{"test": "data", "requestContext": {"httpMethod": "GET"}}' + def event = createInputStream(jsonData) + + when: + LambdaAppSecHandler.processRequestStart(event) + event.reset() + def content = new String(event.readAllBytes(), StandardCharsets.UTF_8) + + then: + content == jsonData + } + + + // ============================================================================ + // Trigger Type Detection Tests + // ============================================================================ + + def "detects API Gateway v1 REST trigger type"() { + given: + def event = [ + requestContext: [ + httpMethod: "GET", + requestId: "abc123" + ] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V1_REST + } + + def "detects API Gateway v2 HTTP trigger type"() { + given: + def event = [ + requestContext: [ + http: [ + method: "POST", + path: "/api" + ], + domainName: "api.example.com" + ] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_HTTP + } + + def "detects Lambda Function URL trigger type"() { + given: + def event = [ + requestContext: [ + http: [ + method: "GET", + path: "/" + ], + domainName: "xyz123.lambda-url.us-east-1.on.aws" + ] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL + } + + def "detects ALB trigger type without multi-value headers"() { + given: + def event = [ + httpMethod: "GET", + path: "/", + requestContext: [ + elb: [ + targetGroupArn: "arn:aws:..." + ] + ] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.ALB + } + + def "detects ALB trigger type with multi-value headers"() { + given: + def event = [ + httpMethod: "GET", + path: "/", + multiValueHeaders: [ + accept: ["text/html", "application/json"] + ], + requestContext: [ + elb: [ + targetGroupArn: "arn:aws:..." + ] + ] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.ALB_MULTI_VALUE + } + + def "detects WebSocket trigger type with routeKey"() { + given: + def event = [ + requestContext: [ + connectionId: "conn-123", + routeKey: "\$connect" + ] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET + } + + def "detects WebSocket trigger type with eventType"() { + given: + def event = [ + requestContext: [ + connectionId: "conn-456", + eventType: "CONNECT" + ] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET + } + + def "detects unknown trigger type for unrecognized events"() { + given: + def event = [ + someUnknownField: "value" + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.UNKNOWN + } + + def "detects unknown trigger type for empty requestContext"() { + given: + def event = [ + requestContext: [:] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.UNKNOWN + } + + def "detects Lambda URL when http present but no domainName"() { + given: + def event = [ + requestContext: [ + http: [ + method: "GET", + path: "/ambiguous" + ] + ] + ] + + when: + def triggerType = LambdaAppSecHandler.detectTriggerType(event) + + then: + triggerType == LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL + } + + // ============================================================================ + // Data Extraction Tests with Mocked Callbacks + // ============================================================================ + + def "extracts API Gateway v1 REST data correctly"() { + given: + def eventJson = ''' + { + "path": "/api/users/123", + "httpMethod": "POST", + "headers": { + "Content-Type": "application/json", + "Authorization": "Bearer token123" + }, + "pathParameters": { + "userId": "123" + }, + "body": "{\\"name\\": \\"John\\"}", + "requestContext": { + "httpMethod": "POST", + "requestId": "req-123", + "identity": { + "sourceIp": "192.168.1.100" + } + } + } + ''' + def event = createInputStream(eventJson) + + // Track callback invocations + def capturedMethod = null + def capturedPath = null + def capturedHeaders = [:] + def capturedSourceIp = null + def capturedSourcePort = null + def capturedPathParams = null + def capturedBody = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onHeader: { name, value -> + capturedHeaders[name] = value + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + capturedSourcePort = port + }, + onPathParams: { params -> + capturedPathParams = params + }, + onBody: { body -> + capturedBody = body + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + result instanceof TagContext + + capturedMethod == "POST" + capturedPath == "/api/users/123" + capturedHeaders["Content-Type"] == "application/json" + capturedHeaders["Authorization"] == "Bearer token123" + capturedSourceIp == "192.168.1.100" + capturedSourcePort == 0 + capturedPathParams == ["userId": "123"] + capturedBody instanceof Map + capturedBody.name == "John" + } + + def "extracts API Gateway v2 HTTP data correctly"() { + given: + def eventJson = ''' + { + "version": "2.0", + "headers": { + "content-type": "application/json", + "x-custom-header": "custom-value" + }, + "cookies": ["session=abc123", "user=john"], + "pathParameters": { + "id": "456" + }, + "body": "test body", + "requestContext": { + "http": { + "method": "PUT", + "path": "/api/items/456", + "sourceIp": "10.0.0.50", + "sourcePort": 54321 + }, + "domainName": "api.example.com" + } + } + ''' + def event = createInputStream(eventJson) + + def capturedMethod = null + def capturedPath = null + def capturedHeaders = [:] + def capturedSourceIp = null + def capturedSourcePort = null + def capturedPathParams = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onHeader: { name, value -> + capturedHeaders[name] = value + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + capturedSourcePort = port + }, + onPathParams: { params -> + capturedPathParams = params + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedMethod == "PUT" + capturedPath == "/api/items/456" + capturedHeaders["content-type"] == "application/json" + capturedHeaders["x-custom-header"] == "custom-value" + capturedHeaders["cookie"] == "session=abc123; user=john" + capturedSourceIp == "10.0.0.50" + capturedSourcePort == 54321 + capturedPathParams == ["id": "456"] + } + + def "extracts Lambda Function URL data correctly"() { + given: + def eventJson = ''' + { + "version": "2.0", + "headers": { + "host": "xyz.lambda-url.us-east-1.on.aws" + }, + "requestContext": { + "http": { + "method": "GET", + "path": "/function/path", + "sourceIp": "1.2.3.4" + }, + "domainName": "xyz.lambda-url.us-east-1.on.aws" + } + } + ''' + def event = createInputStream(eventJson) + + def capturedMethod = null + def capturedPath = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedMethod == "GET" + capturedPath == "/function/path" + } + + def "extracts ALB data correctly"() { + given: + def eventJson = ''' + { + "path": "/alb/test", + "httpMethod": "DELETE", + "headers": { + "x-forwarded-for": "203.0.113.42", + "user-agent": "curl/7.64.1" + }, + "requestContext": { + "elb": { + "targetGroupArn": "arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/my-target-group/50dc6c495c0c9188" + } + } + } + ''' + def event = createInputStream(eventJson) + + def capturedMethod = null + def capturedPath = null + def capturedSourceIp = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedMethod == "DELETE" + capturedPath == "/alb/test" + capturedSourceIp == "203.0.113.42" + } + + def "extracts ALB multi-value headers correctly"() { + given: + def eventJson = ''' + { + "path": "/test", + "httpMethod": "GET", + "multiValueHeaders": { + "accept": ["text/html", "application/json"], + "x-custom": ["value1", "value2"] + }, + "requestContext": { + "elb": { + "targetGroupArn": "arn:aws:..." + } + } + } + ''' + def event = createInputStream(eventJson) + + def capturedHeaders = [:] + + setupMockCallbacks( + onHeader: { name, value -> + capturedHeaders[name] = value + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedHeaders["accept"] == "text/html, application/json" + capturedHeaders["x-custom"] == "value1, value2" + } + + def "handles multi-value headers with empty list"() { + given: + def eventJson = ''' + { + "path": "/test", + "httpMethod": "GET", + "multiValueHeaders": { + "accept": [], + "x-custom": ["value1"] + }, + "requestContext": { + "elb": { + "targetGroupArn": "arn:aws:..." + } + } + } + ''' + def event = createInputStream(eventJson) + + def capturedHeaders = [:] + + setupMockCallbacks( + onHeader: { name, value -> + capturedHeaders[name] = value + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedHeaders["accept"] == "" // Empty list should result in empty string + capturedHeaders["x-custom"] == "value1" + } + + def "extracts WebSocket data correctly"() { + given: + def eventJson = ''' + { + "requestContext": { + "routeKey": "$connect", + "connectionId": "conn-abc123", + "identity": { + "sourceIp": "192.168.0.100" + } + } + } + ''' + def event = createInputStream(eventJson) + + def capturedMethod = null + def capturedPath = null + def capturedSourceIp = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedMethod == "WEBSOCKET" + capturedPath == "\$connect" + capturedSourceIp == "192.168.0.100" + } + + def "handles base64 encoded body correctly"() { + given: + def originalBody = "This is test data" + def base64Body = Base64.getEncoder().encodeToString(originalBody.getBytes()) + def eventJson = """ + { + "body": "${base64Body}", + "isBase64Encoded": true, + "requestContext": { + "httpMethod": "POST" + } + } + """ + def event = createInputStream(eventJson) + + def capturedBody = null + + setupMockCallbacks( + onBody: { body -> + capturedBody = body + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedBody == originalBody + } + + def "handles null body correctly"() { + given: + def event = createInputStream('{"body": null, "requestContext": {"httpMethod": "GET"}}') + + def capturedBody = "NOT_CALLED" + + setupMockCallbacks( + onBody: { body -> + capturedBody = body + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedBody == "NOT_CALLED" // Callback should not be invoked for null body + } + + def "handles empty body correctly"() { + given: + def event = createInputStream('{"body": "", "requestContext": {"httpMethod": "POST"}}') + + def capturedBody = null + + setupMockCallbacks( + onBody: { body -> + capturedBody = body + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedBody == "" // Empty body is passed as empty string to WAF + } + + def "handles path with query string correctly"() { + given: + def eventJson = ''' + { + "path": "/api/users?id=123&filter=active", + "requestContext": { + "httpMethod": "GET" + } + } + ''' + def event = createInputStream(eventJson) + + def capturedPath = null + def capturedQuery = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedPath = uri.path() + capturedQuery = uri.query() + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedPath == "/api/users" + capturedQuery == "id=123&filter=active" + } + + def "handles invalid base64 body gracefully"() { + given: + def eventJson = ''' + { + "body": "not-valid-base64", + "isBase64Encoded": true, + "requestContext": { + "httpMethod": "POST" + } + } + ''' + def event = createInputStream(eventJson) + + def capturedBody = "NOT_CALLED" + + setupMockCallbacks( + onBody: { body -> + capturedBody = body + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedBody == "NOT_CALLED" // Should not call body callback when decode fails + } + + def "handles base64 decoded empty string body"() { + given: + def base64Empty = Base64.getEncoder().encodeToString("".getBytes()) + def eventJson = """ + { + "body": "${base64Empty}", + "isBase64Encoded": true, + "requestContext": { + "httpMethod": "POST" + } + } + """ + def event = createInputStream(eventJson) + + def capturedBody = "NOT_CALLED" + + setupMockCallbacks( + onBody: { body -> + capturedBody = body + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedBody == "" // Should pass empty string after decoding + } + + def "handles body with special characters"() { + given: + def eventJson = ''' + { + "body": "{\\"text\\": \\"Hello 世界 🌍\\"}", + "requestContext": { + "httpMethod": "POST" + } + } + ''' + def event = createInputStream(eventJson) + + def capturedBody = null + + setupMockCallbacks( + onBody: { body -> + capturedBody = body + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedBody instanceof Map + capturedBody.text == "Hello 世界 🌍" + } + + // ============================================================================ + // Generic Data Extraction Tests + // ============================================================================ + + def "extracts data from unknown trigger type using generic extraction"() { + given: + def eventJson = ''' + { + "path": "/generic/path", + "httpMethod": "PATCH", + "headers": { + "x-custom-header": "generic-value" + }, + "unknownField": "should be ignored", + "requestContext": { + "identity": { + "sourceIp": "203.0.113.1" + } + } + } + ''' + def event = createInputStream(eventJson) + + def capturedMethod = null + def capturedPath = null + def capturedHeaders = [:] + def capturedSourceIp = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onHeader: { name, value -> + capturedHeaders[name] = value + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedMethod == "PATCH" + capturedPath == "/generic/path" + capturedHeaders["x-custom-header"] == "generic-value" + capturedSourceIp == "203.0.113.1" + } + + def "extracts data from unknown trigger with http in requestContext"() { + given: + def eventJson = ''' + { + "requestContext": { + "http": { + "method": "OPTIONS", + "path": "/options/path", + "sourceIp": "198.51.100.50" + } + } + } + ''' + def event = createInputStream(eventJson) + + def capturedMethod = null + def capturedPath = null + def capturedSourceIp = null + + setupMockCallbacks( + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedMethod == "OPTIONS" + capturedPath == "/options/path" + capturedSourceIp == "198.51.100.50" + } + + def "handles cookies merging with existing cookie header"() { + given: + def eventJson = ''' + { + "headers": { + "cookie": "existing=value" + }, + "cookies": ["new=cookie1", "another=cookie2"], + "requestContext": { + "http": { + "method": "GET", + "path": "/" + } + } + } + ''' + def event = createInputStream(eventJson) + + def capturedHeaders = [:] + + setupMockCallbacks( + onHeader: { name, value -> + capturedHeaders[name] = value + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + capturedHeaders["cookie"] == "existing=value; new=cookie1; another=cookie2" + } + + def "handles empty cookies array correctly"() { + given: + def eventJson = ''' + { + "headers": { + "content-type": "application/json" + }, + "cookies": [], + "requestContext": { + "http": { + "method": "GET", + "path": "/" + } + } + } + ''' + def event = createInputStream(eventJson) + + def capturedHeaders = [:] + + setupMockCallbacks( + onHeader: { name, value -> + capturedHeaders[name] = value + } + ) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null + !capturedHeaders.containsKey("cookie") // Empty array should not add cookie header + } + + // ============================================================================ + // processRequestEnd Tests + // ============================================================================ + + def "processRequestEnd does nothing when span is null"() { + when: + LambdaAppSecHandler.processRequestEnd(null) + + then: + noExceptionThrown() + } + + def "processRequestEnd does nothing when AppSec is disabled"() { + given: + ActiveSubsystems.APPSEC_ACTIVE = false + def span = Mock(AgentSpan) + + when: + LambdaAppSecHandler.processRequestEnd(span) + + then: + 0 * span._ + } + + def "processRequestEnd does nothing when span has no RequestContext"() { + given: + def span = Mock(AgentSpan) { + getRequestContext() >> null + } + + when: + LambdaAppSecHandler.processRequestEnd(span) + + then: + noExceptionThrown() + } + + def "processRequestEnd invokes requestEnded callback with RequestContext"() { + given: + def mockAppSecContext = new Object() + def mockRequestContext = Mock(RequestContext) { + getData(RequestContextSlot.APPSEC) >> mockAppSecContext + } + def span = Mock(AgentSpan) { + getRequestContext() >> mockRequestContext + } + + def callbackInvoked = false + def capturedContext = null + def capturedSpan = null + + def mockRequestEndedCallback = Mock(BiFunction) { + apply(_ as RequestContext, _ as AgentSpan) >> { RequestContext ctx, AgentSpan s -> + callbackInvoked = true + capturedContext = ctx + capturedSpan = s + return new Flow.ResultFlow<>(null) + } + } + + def mockCallbackProvider = Mock(CallbackProvider) { + getCallback(EVENTS.requestEnded()) >> mockRequestEndedCallback + } + + def mockTracer = Mock(AgentTracer.TracerAPI) { + getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider + } + + AgentTracer.forceRegister(mockTracer) + + when: + LambdaAppSecHandler.processRequestEnd(span) + + then: + callbackInvoked + capturedContext == mockRequestContext + capturedSpan == span + } + + def "processRequestEnd handles null requestEnded callback gracefully"() { + given: + def mockRequestContext = Mock(RequestContext) + def span = Mock(AgentSpan) { + getRequestContext() >> mockRequestContext + } + + def mockCallbackProvider = Mock(CallbackProvider) { + getCallback(EVENTS.requestEnded()) >> null + } + + def mockTracer = Mock(AgentTracer.TracerAPI) { + getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider + } + + AgentTracer.forceRegister(mockTracer) + + when: + LambdaAppSecHandler.processRequestEnd(span) + + then: + noExceptionThrown() // Should log warning but not throw + } + + // ============================================================================ + // mergeContexts Tests + // ============================================================================ + + def "mergeContexts returns null when both contexts are null"() { + when: + def result = LambdaAppSecHandler.mergeContexts(null, null) + + then: + result == null + } + + def "mergeContexts returns extensionContext when appSecContext is null"() { + given: + def extensionContext = Mock(TagContext) + + when: + def result = LambdaAppSecHandler.mergeContexts(extensionContext, null) + + then: + result == extensionContext + } + + def "mergeContexts returns appSecContext when extensionContext is null"() { + given: + def appSecContext = Mock(TagContext) + + when: + def result = LambdaAppSecHandler.mergeContexts(null, appSecContext) + + then: + result == appSecContext + } + + def "mergeContexts merges AppSec data into TagContext"() { + given: + def appSecData = new Object() + + // Create real TagContext instances since methods are final + def appSecContext = new TagContext() + appSecContext.withRequestContextDataAppSec(appSecData) + + def extensionContext = new TagContext() + + when: + def result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext) + + then: + result == extensionContext + result.getRequestContextDataAppSec() == appSecData + } + + def "mergeContexts returns extensionContext when appSecContext is not TagContext"() { + given: + def extensionContext = Mock(TagContext) + def appSecContext = Mock(AgentSpanContext) + + when: + def result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext) + + then: + result == extensionContext + } + + def "mergeContexts returns extensionContext when it is not TagContext"() { + given: + def extensionContext = Mock(AgentSpanContext) + def appSecContext = Mock(TagContext) + + when: + def result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext) + + then: + result == extensionContext + } + + // ============================================================================ + // Error Handling and Null Callback Tests + // ============================================================================ + + def "processRequestStart handles null requestStarted callback gracefully"() { + given: + def eventJson = '{"requestContext": {"httpMethod": "GET"}}' + def event = createInputStream(eventJson) + + def mockCallbackProvider = Mock(CallbackProvider) { + getCallback(EVENTS.requestStarted()) >> null + } + + def mockTracer = Mock(AgentTracer.TracerAPI) { + getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider + } + + AgentTracer.forceRegister(mockTracer) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result == null // Should return null when requestStarted callback is missing + } + + def "processRequestStart handles null methodUri callback gracefully"() { + given: + def eventJson = ''' + { + "path": "/test", + "requestContext": { + "httpMethod": "GET" + } + } + ''' + def event = createInputStream(eventJson) + + def mockAppSecContext = new Object() + + def mockRequestStartedCallback = Mock(Supplier) { + get() >> new Flow.ResultFlow<>(mockAppSecContext) + } + + def mockCallbackProvider = Mock(CallbackProvider) { + getCallback(EVENTS.requestStarted()) >> mockRequestStartedCallback + getCallback(EVENTS.requestMethodUriRaw()) >> null // Null callback + getCallback(EVENTS.requestHeader()) >> null + getCallback(EVENTS.requestClientSocketAddress()) >> null + getCallback(EVENTS.requestHeaderDone()) >> Mock(Function) { + apply(_ as RequestContext) >> new Flow.ResultFlow<>(null) + } + getCallback(EVENTS.requestPathParams()) >> null + getCallback(EVENTS.requestBodyProcessed()) >> null + } + + def mockTracer = Mock(AgentTracer.TracerAPI) { + getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider + } + + AgentTracer.forceRegister(mockTracer) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result != null // Should continue processing even if methodUri callback is null + result instanceof TagContext + } + + def "processRequestStart handles exception during JSON parsing"() { + given: + def invalidJson = '{this is not valid JSON at all' + def event = createInputStream(invalidJson) + + when: + def result = LambdaAppSecHandler.processRequestStart(event) + + then: + result == null // Should return null on parse error + } + + def "processRequestStart handles exception during stream reading"() { + given: + def mockStream = Mock(ByteArrayInputStream) { + available() >> { throw new IOException("Stream error") } + } + + when: + def result = LambdaAppSecHandler.processRequestStart(mockStream) + + then: + result == null // Should return null on IO error + } + + // ============================================================================ + // Helper Methods + // ============================================================================ + + private ByteArrayInputStream createInputStream(String json) { + return new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)) + } + + /** + * Set up mock callbacks to capture invocations and verify data extraction. + * This mocks the AgentTracer and callback provider to intercept gateway calls. + */ + private void setupMockCallbacks(Map callbacks) { + def mockAppSecContext = new Object() + + def mockRequestStartedCallback = Mock(Supplier) { + get() >> new Flow.ResultFlow<>(mockAppSecContext) + } + + def mockMethodUriCallback = callbacks.onMethodUri ? Mock(datadog.trace.api.function.TriFunction) { + apply(_ as RequestContext, _ as String, _ as URIDataAdapter) >> { RequestContext ctx, String method, URIDataAdapter uri -> + callbacks.onMethodUri(method, uri) + return new Flow.ResultFlow<>(null) + } + } : null + + def mockHeaderCallback = callbacks.onHeader ? Mock(TriConsumer) { + accept(_ as RequestContext, _ as String, _ as String) >> { RequestContext ctx, String name, String value -> + callbacks.onHeader(name, value) + } + } : null + + def mockSocketAddressCallback = callbacks.onSocketAddress ? Mock(TriFunction) { + apply(_ as RequestContext, _ as String, _ as Integer) >> { RequestContext ctx, String ip, Integer port -> + callbacks.onSocketAddress(ip, port) + return new Flow.ResultFlow<>(null) + } + } : null + + def mockHeaderDoneCallback = Mock(Function) { + apply(_ as RequestContext) >> new Flow.ResultFlow<>(null) + } + + def mockPathParamsCallback = callbacks.onPathParams ? Mock(BiFunction) { + apply(_ as RequestContext, _ as Map) >> { RequestContext ctx, Map params -> + callbacks.onPathParams(params) + return new Flow.ResultFlow<>(null) + } + } : null + + def mockQueryParamsCallback = callbacks.onQueryParams ? Mock(BiFunction) { + apply(_ as RequestContext, _ as Map) >> { RequestContext ctx, Map params -> + callbacks.onQueryParams(params) + return new Flow.ResultFlow<>(null) + } + } : null + + def mockBodyCallback = callbacks.onBody ? Mock(BiFunction) { + apply(_ as RequestContext, _ as Object) >> { RequestContext ctx, Object body -> + callbacks.onBody(body) + return new Flow.ResultFlow<>(null) + } + } : null + + def mockCallbackProvider = Mock(CallbackProvider) { + getCallback(EVENTS.requestStarted()) >> mockRequestStartedCallback + getCallback(EVENTS.requestMethodUriRaw()) >> mockMethodUriCallback + getCallback(EVENTS.requestHeader()) >> mockHeaderCallback + getCallback(EVENTS.requestClientSocketAddress()) >> mockSocketAddressCallback + getCallback(EVENTS.requestHeaderDone()) >> mockHeaderDoneCallback + getCallback(EVENTS.requestPathParams()) >> mockPathParamsCallback + getCallback(EVENTS.requestBodyProcessed()) >> mockBodyCallback + } + + def mockTracer = Mock(AgentTracer.TracerAPI) { + getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider + } + + // Install the mock tracer + AgentTracer.forceRegister(mockTracer) + } + + def cleanup() { + // Reset tracer after each test + AgentTracer.forceRegister(null) + } +} From 452065a6e5e284b2b365916ce15e78b58f439ec3 Mon Sep 17 00:00:00 2001 From: "clara.poncet" Date: Wed, 11 Feb 2026 16:59:13 +0100 Subject: [PATCH 4/5] add better support for query parameters --- .../lambda/LambdaHandlerInstrumentation.java | 2 + .../LambdaHandlerInstrumentationTest.groovy | 3 + .../trace/lambda/LambdaAppSecHandler.java | 119 ++++++++++++++++-- .../lambda/LambdaAppSecHandlerTest.groovy | 1 + 4 files changed, 115 insertions(+), 10 deletions(-) diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java index 0f020c6623a..208a1e5675a 100644 --- a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/main/java/datadog/trace/instrumentation/aws/v1/lambda/LambdaHandlerInstrumentation.java @@ -23,6 +23,7 @@ import datadog.trace.bootstrap.instrumentation.api.AgentSpan; import datadog.trace.bootstrap.instrumentation.api.AgentSpanContext; import datadog.trace.bootstrap.instrumentation.api.AgentTracer; +import datadog.trace.bootstrap.instrumentation.api.InternalSpanTypes; import datadog.trace.config.inversion.ConfigHelper; import net.bytebuddy.asm.Advice; import net.bytebuddy.description.type.TypeDescription; @@ -96,6 +97,7 @@ static AgentScope enter( } else { span = startSpan(INVOCATION_SPAN_NAME, lambdaContext); } + span.setSpanType(InternalSpanTypes.SERVERLESS); span.setTag("request_id", lambdaRequestId); final AgentScope scope = activateSpan(span); diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/LambdaHandlerInstrumentationTest.groovy b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/LambdaHandlerInstrumentationTest.groovy index 78314ada2da..6c705fede05 100644 --- a/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/LambdaHandlerInstrumentationTest.groovy +++ b/dd-java-agent/instrumentation/aws-java/aws-java-lambda-handler-1.2/src/test/groovy/LambdaHandlerInstrumentationTest.groovy @@ -1,4 +1,5 @@ import datadog.trace.agent.test.naming.VersionedNamingTestBase +import datadog.trace.api.DDSpanTypes import java.nio.charset.StandardCharsets import com.amazonaws.services.lambda.runtime.Context @@ -30,6 +31,7 @@ abstract class LambdaHandlerInstrumentationTest extends VersionedNamingTestBase trace(1) { span { operationName operation() + spanType DDSpanTypes.SERVERLESS errored false } } @@ -51,6 +53,7 @@ abstract class LambdaHandlerInstrumentationTest extends VersionedNamingTestBase trace(1) { span { operationName operation() + spanType DDSpanTypes.SERVERLESS errored true tags { tag "request_id", requestId diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java index 610f7b2c518..c4cfea9dd9c 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java @@ -26,6 +26,7 @@ import java.nio.charset.StandardCharsets; import java.util.Base64; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.function.BiFunction; import java.util.function.Function; @@ -157,7 +158,9 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa datadog.trace.api.function.TriFunction> methodUriCallback = tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestMethodUriRaw()); if (methodUriCallback != null) { - LambdaURIDataAdapter uriAdapter = new LambdaURIDataAdapter(eventData.path); + // Reconstruct full path with query string for AppSec analysis + String fullPath = buildFullPath(eventData.path, eventData.queryParameters); + LambdaURIDataAdapter uriAdapter = new LambdaURIDataAdapter(fullPath); methodUriCallback.apply(requestContext, eventData.method, uriAdapter); } else { log.warn("requestMethodUriRaw callback is null"); @@ -235,7 +238,7 @@ private static LambdaEventData extractEventData(ByteArrayInputStream inputStream log.warn("Event size {} exceeds limit {} or is invalid, skipping AppSec processing", availableBytes, MAX_EVENT_SIZE); return new LambdaEventData(Collections.emptyMap(), null, null, null, null, - LambdaTriggerType.UNKNOWN, Collections.emptyMap(), null); + LambdaTriggerType.UNKNOWN, Collections.emptyMap(), Collections.emptyMap(), null); } StringBuilder jsonBuilder = new StringBuilder(availableBytes); @@ -259,7 +262,7 @@ private static LambdaEventData extractEventDataFromJson(String json) { log.debug("Event JSON parsed successfully"); if (event == null) { - return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), null); + return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), Collections.emptyMap(), null); } // Detect trigger type @@ -284,7 +287,7 @@ private static LambdaEventData extractEventDataFromJson(String json) { } } catch (Exception e) { log.error("Failed to parse event data from JSON", e); - return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), null); + return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), Collections.emptyMap(), null); } } @@ -339,6 +342,7 @@ static LambdaTriggerType detectTriggerType(Map event) { private static LambdaEventData extractApiGatewayV1Data(Map event) { Map headers = extractHeaders(event.get("headers")); Map pathParameters = extractPathParameters(event.get("pathParameters")); + Map> queryParameters = extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); Map requestContext = (Map) event.get("requestContext"); @@ -352,7 +356,7 @@ private static LambdaEventData extractApiGatewayV1Data(Map event sourceIp = (String) identity.get("sourceIp"); } - return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.API_GATEWAY_V1_REST, pathParameters, body); + return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.API_GATEWAY_V1_REST, pathParameters, queryParameters, body); } /** @@ -361,6 +365,7 @@ private static LambdaEventData extractApiGatewayV1Data(Map event private static LambdaEventData extractApiGatewayV2HttpData(Map event, LambdaTriggerType triggerType) { Map headers = extractHeadersWithCookies(event); Map pathParameters = extractPathParameters(event.get("pathParameters")); + Map> queryParameters = extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); Map requestContext = (Map) event.get("requestContext"); @@ -377,7 +382,7 @@ private static LambdaEventData extractApiGatewayV2HttpData(Map e sourcePort = ((Number) portObj).intValue(); } - return new LambdaEventData(headers, method, path, sourceIp, sourcePort, triggerType, pathParameters, body); + return new LambdaEventData(headers, method, path, sourceIp, sourcePort, triggerType, pathParameters, queryParameters, body); } /** @@ -386,6 +391,7 @@ private static LambdaEventData extractApiGatewayV2HttpData(Map e private static LambdaEventData extractApiGatewayV2WebSocketData(Map event) { Map headers = extractHeadersWithCookies(event); Map pathParameters = extractPathParameters(event.get("pathParameters")); + Map> queryParameters = extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); Map requestContext = (Map) event.get("requestContext"); @@ -401,7 +407,7 @@ private static LambdaEventData extractApiGatewayV2WebSocketData(Map event, LambdaT } Map pathParameters = extractPathParameters(event.get("pathParameters")); + + // ALB can have both queryStringParameters and multiValueQueryStringParameters + Map> queryParameters; + if (triggerType == LambdaTriggerType.ALB_MULTI_VALUE) { + queryParameters = extractMultiValueQueryParameters(event.get("multiValueQueryStringParameters")); + } else { + queryParameters = extractQueryParameters(event.get("queryStringParameters")); + } + Object body = extractBody(event); String method = (String) event.get("httpMethod"); String path = (String) event.get("path"); String sourceIp = headers.get("x-forwarded-for"); - return new LambdaEventData(headers, method, path, sourceIp, null, triggerType, pathParameters, body); + return new LambdaEventData(headers, method, path, sourceIp, null, triggerType, pathParameters, queryParameters, body); } /** @@ -452,6 +467,7 @@ private static LambdaEventData extractAlbData(Map event, LambdaT private static LambdaEventData extractGenericData(Map event) { Map headers = extractHeadersWithCookies(event); Map pathParameters = extractPathParameters(event.get("pathParameters")); + Map> queryParameters = extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); String method = null; @@ -497,7 +513,7 @@ private static LambdaEventData extractGenericData(Map event) { } } - return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.UNKNOWN, pathParameters, body); + return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.UNKNOWN, pathParameters, queryParameters, body); } /** @@ -540,6 +556,87 @@ private static Map extractPathParameters(Object pathParamsObj) { return pathParams; } + /** + * Helper method to extract query parameters from event. + * Converts Map to Map> format expected by AppSec. + */ + private static Map> extractQueryParameters(Object queryParamsObj) { + Map> result = new java.util.HashMap<>(); + if (queryParamsObj instanceof Map) { + Map rawMap = (Map) queryParamsObj; + for (Map.Entry entry : rawMap.entrySet()) { + if (entry.getKey() != null && entry.getValue() != null) { + String key = String.valueOf(entry.getKey()); + String value = String.valueOf(entry.getValue()); + result.put(key, java.util.Collections.singletonList(value)); + } + } + } + log.debug("Extracted {} query parameters", result.size()); + return result; + } + + /** + * Helper method to extract multi-value query parameters (used by ALB). + * Handles Map> format directly. + */ + private static Map> extractMultiValueQueryParameters(Object queryParamsObj) { + Map> result = new java.util.HashMap<>(); + if (queryParamsObj instanceof Map) { + Map rawMap = (Map) queryParamsObj; + for (Map.Entry entry : rawMap.entrySet()) { + if (entry.getKey() != null && entry.getValue() != null) { + String key = String.valueOf(entry.getKey()); + if (entry.getValue() instanceof java.util.List) { + java.util.List values = (java.util.List) entry.getValue(); + java.util.List stringValues = new java.util.ArrayList<>(); + for (Object value : values) { + if (value != null) { + stringValues.add(String.valueOf(value)); + } + } + result.put(key, stringValues); + } else { + result.put(key, java.util.Collections.singletonList(String.valueOf(entry.getValue()))); + } + } + } + } + log.debug("Extracted {} multi-value query parameters", result.size()); + return result; + } + + /** + * Helper method to build full path including query string. + * Lambda events provide path and query parameters separately, so we need to reconstruct + * the full URI for AppSec to parse. + */ + private static String buildFullPath(String path, Map> queryParameters) { + if (queryParameters == null || queryParameters.isEmpty()) { + return path; + } + + StringBuilder fullPath = new StringBuilder(path); + fullPath.append('?'); + + boolean first = true; + for (Map.Entry> entry : queryParameters.entrySet()) { + String key = entry.getKey(); + for (String value : entry.getValue()) { + if (!first) { + fullPath.append('&'); + } + first = false; + fullPath.append(key); + if (value != null) { + fullPath.append('=').append(value); + } + } + } + + return fullPath.toString(); + } + /** * Helper method to extract and merge headers with cookies array from event. * API Gateway v2 provides a separate 'cookies' array that should be merged with headers. @@ -689,9 +786,10 @@ static class LambdaEventData { final Integer sourcePort; final LambdaTriggerType triggerType; final Map pathParameters; + final Map> queryParameters; final Object body; - LambdaEventData(Map headers, String method, String path, String sourceIp, Integer sourcePort, LambdaTriggerType triggerType, Map pathParameters, Object body) { + LambdaEventData(Map headers, String method, String path, String sourceIp, Integer sourcePort, LambdaTriggerType triggerType, Map pathParameters, Map> queryParameters, Object body) { this.headers = headers; this.method = method; this.path = path; @@ -699,6 +797,7 @@ static class LambdaEventData { this.sourcePort = sourcePort; this.triggerType = triggerType; this.pathParameters = pathParameters; + this.queryParameters = queryParameters; this.body = body; } } diff --git a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy index eea29fe68c8..59f72c0a98e 100644 --- a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy +++ b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy @@ -2,6 +2,7 @@ package datadog.trace.lambda import datadog.trace.api.Config import datadog.trace.api.function.TriConsumer +import datadog.trace.api.function.TriFunction import datadog.trace.api.gateway.CallbackProvider import datadog.trace.api.gateway.Flow import datadog.trace.api.gateway.RequestContext From 305fba24a6c763a45c0447417e5c0b7372ba825c Mon Sep 17 00:00:00 2001 From: "clara.poncet" Date: Wed, 11 Feb 2026 17:16:00 +0100 Subject: [PATCH 5/5] apply spotless --- .../trace/lambda/LambdaAppSecHandler.java | 265 +++++++++++------- .../datadog/trace/lambda/LambdaHandler.java | 3 +- .../lambda/LambdaAppSecHandlerTest.groovy | 221 ++++++++------- 3 files changed, 284 insertions(+), 205 deletions(-) diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java index c4cfea9dd9c..e396655acb0 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaAppSecHandler.java @@ -35,8 +35,8 @@ import org.slf4j.LoggerFactory; /** - * Handles AppSec processing for AWS Lambda invocations. - * Extracts Lambda event data and invokes AppSec gateway callbacks. + * Handles AppSec processing for AWS Lambda invocations. Extracts Lambda event data and invokes + * AppSec gateway callbacks. */ public class LambdaAppSecHandler { @@ -49,11 +49,12 @@ public class LambdaAppSecHandler { private static final int MAX_EVENT_SIZE = Config.get().getAppSecBodyParsingSizeLimit(); /** - * Process AppSec request data at the start of a Lambda invocation. - * Extract event data and invokes all relevant AppSec gateway callbacks. + * Process AppSec request data at the start of a Lambda invocation. Extract event data and invokes + * all relevant AppSec gateway callbacks. * * @param event the Lambda event object - * @return AgentSpanContext containing AppSec data, or null if AppSec is disabled or processing fails + * @return AgentSpanContext containing AppSec data, or null if AppSec is disabled or processing + * fails */ public static AgentSpanContext processRequestStart(Object event) { if (!ActiveSubsystems.APPSEC_ACTIVE) { @@ -62,7 +63,9 @@ public static AgentSpanContext processRequestStart(Object event) { } if (!(event instanceof ByteArrayInputStream)) { - log.debug("Event is not a ByteArrayInputStream, type: {}", event != null ? event.getClass().getName() : "null"); + log.debug( + "Event is not a ByteArrayInputStream, type: {}", + event != null ? event.getClass().getName() : "null"); return null; } @@ -89,8 +92,7 @@ public static void processRequestEnd(AgentSpan span) { if (requestContext != null) { AgentTracer.TracerAPI tracer = AgentTracer.get(); BiFunction> requestEndedCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC) - .getCallback(EVENTS.requestEnded()); + tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestEnded()); if (requestEndedCallback != null) { requestEndedCallback.apply(requestContext, span); } else { @@ -155,8 +157,11 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa // Call requestMethodUriRaw if (eventData.method != null && eventData.path != null) { - datadog.trace.api.function.TriFunction> methodUriCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestMethodUriRaw()); + datadog.trace.api.function.TriFunction> + methodUriCallback = + tracer + .getCallbackProvider(RequestContextSlot.APPSEC) + .getCallback(EVENTS.requestMethodUriRaw()); if (methodUriCallback != null) { // Reconstruct full path with query string for AppSec analysis String fullPath = buildFullPath(eventData.path, eventData.queryParameters); @@ -170,7 +175,9 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa // Call requestHeader for each header if (eventData.headers != null && !eventData.headers.isEmpty()) { TriConsumer headerCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestHeader()); + tracer + .getCallbackProvider(RequestContextSlot.APPSEC) + .getCallback(EVENTS.requestHeader()); if (headerCallback != null) { for (Map.Entry header : eventData.headers.entrySet()) { headerCallback.accept(requestContext, header.getKey(), header.getValue()); @@ -182,8 +189,11 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa // Call requestClientSocketAddress if (eventData.sourceIp != null) { - datadog.trace.api.function.TriFunction> socketAddrCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestClientSocketAddress()); + datadog.trace.api.function.TriFunction> + socketAddrCallback = + tracer + .getCallbackProvider(RequestContextSlot.APPSEC) + .getCallback(EVENTS.requestClientSocketAddress()); if (socketAddrCallback != null) { Integer port = eventData.sourcePort != null ? eventData.sourcePort : 0; socketAddrCallback.apply(requestContext, eventData.sourceIp, port); @@ -206,7 +216,9 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa // Call requestPathParams if (eventData.pathParameters != null && !eventData.pathParameters.isEmpty()) { BiFunction, Flow> pathParamsCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestPathParams()); + tracer + .getCallbackProvider(RequestContextSlot.APPSEC) + .getCallback(EVENTS.requestPathParams()); if (pathParamsCallback != null) { pathParamsCallback.apply(requestContext, eventData.pathParameters); } else { @@ -217,7 +229,9 @@ private static AgentSpanContext processAppSecRequestData(LambdaEventData eventDa // Call requestBodyProcessed if (eventData.body != null) { BiFunction> bodyCallback = - tracer.getCallbackProvider(RequestContextSlot.APPSEC).getCallback(EVENTS.requestBodyProcessed()); + tracer + .getCallbackProvider(RequestContextSlot.APPSEC) + .getCallback(EVENTS.requestBodyProcessed()); if (bodyCallback != null) { bodyCallback.apply(requestContext, eventData.body); } else { @@ -235,10 +249,20 @@ private static LambdaEventData extractEventData(ByteArrayInputStream inputStream int availableBytes = inputStream.available(); if (availableBytes <= 0 || availableBytes > MAX_EVENT_SIZE) { - log.warn("Event size {} exceeds limit {} or is invalid, skipping AppSec processing", - availableBytes, MAX_EVENT_SIZE); - return new LambdaEventData(Collections.emptyMap(), null, null, null, null, - LambdaTriggerType.UNKNOWN, Collections.emptyMap(), Collections.emptyMap(), null); + log.warn( + "Event size {} exceeds limit {} or is invalid, skipping AppSec processing", + availableBytes, + MAX_EVENT_SIZE); + return new LambdaEventData( + Collections.emptyMap(), + null, + null, + null, + null, + LambdaTriggerType.UNKNOWN, + Collections.emptyMap(), + Collections.emptyMap(), + null); } StringBuilder jsonBuilder = new StringBuilder(availableBytes); @@ -262,7 +286,16 @@ private static LambdaEventData extractEventDataFromJson(String json) { log.debug("Event JSON parsed successfully"); if (event == null) { - return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), Collections.emptyMap(), null); + return new LambdaEventData( + Collections.emptyMap(), + null, + null, + null, + null, + LambdaTriggerType.UNKNOWN, + Collections.emptyMap(), + Collections.emptyMap(), + null); } // Detect trigger type @@ -287,7 +320,16 @@ private static LambdaEventData extractEventDataFromJson(String json) { } } catch (Exception e) { log.error("Failed to parse event data from JSON", e); - return new LambdaEventData(Collections.emptyMap(), null, null, null, null, LambdaTriggerType.UNKNOWN, Collections.emptyMap(), Collections.emptyMap(), null); + return new LambdaEventData( + Collections.emptyMap(), + null, + null, + null, + null, + LambdaTriggerType.UNKNOWN, + Collections.emptyMap(), + Collections.emptyMap(), + null); } } @@ -307,8 +349,8 @@ static LambdaTriggerType detectTriggerType(Map event) { } // Check for WebSocket - if (requestContext.containsKey("connectionId") && - (requestContext.containsKey("eventType") || requestContext.containsKey("routeKey"))) { + if (requestContext.containsKey("connectionId") + && (requestContext.containsKey("eventType") || requestContext.containsKey("routeKey"))) { return LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET; } @@ -336,13 +378,12 @@ static LambdaTriggerType detectTriggerType(Map event) { return LambdaTriggerType.UNKNOWN; } - /** - * Extracts data from API Gateway v1 (REST API) event - */ + /** Extracts data from API Gateway v1 (REST API) event */ private static LambdaEventData extractApiGatewayV1Data(Map event) { Map headers = extractHeaders(event.get("headers")); Map pathParameters = extractPathParameters(event.get("pathParameters")); - Map> queryParameters = extractQueryParameters(event.get("queryStringParameters")); + Map> queryParameters = + extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); Map requestContext = (Map) event.get("requestContext"); @@ -356,16 +397,25 @@ private static LambdaEventData extractApiGatewayV1Data(Map event sourceIp = (String) identity.get("sourceIp"); } - return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.API_GATEWAY_V1_REST, pathParameters, queryParameters, body); + return new LambdaEventData( + headers, + method, + path, + sourceIp, + null, + LambdaTriggerType.API_GATEWAY_V1_REST, + pathParameters, + queryParameters, + body); } - /** - * Extracts data from API Gateway v2 (HTTP API) or Lambda URL event - */ - private static LambdaEventData extractApiGatewayV2HttpData(Map event, LambdaTriggerType triggerType) { + /** Extracts data from API Gateway v2 (HTTP API) or Lambda URL event */ + private static LambdaEventData extractApiGatewayV2HttpData( + Map event, LambdaTriggerType triggerType) { Map headers = extractHeadersWithCookies(event); Map pathParameters = extractPathParameters(event.get("pathParameters")); - Map> queryParameters = extractQueryParameters(event.get("queryStringParameters")); + Map> queryParameters = + extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); Map requestContext = (Map) event.get("requestContext"); @@ -382,16 +432,24 @@ private static LambdaEventData extractApiGatewayV2HttpData(Map e sourcePort = ((Number) portObj).intValue(); } - return new LambdaEventData(headers, method, path, sourceIp, sourcePort, triggerType, pathParameters, queryParameters, body); + return new LambdaEventData( + headers, + method, + path, + sourceIp, + sourcePort, + triggerType, + pathParameters, + queryParameters, + body); } - /** - * Extracts data from API Gateway v2 WebSocket event - */ + /** Extracts data from API Gateway v2 WebSocket event */ private static LambdaEventData extractApiGatewayV2WebSocketData(Map event) { Map headers = extractHeadersWithCookies(event); Map pathParameters = extractPathParameters(event.get("pathParameters")); - Map> queryParameters = extractQueryParameters(event.get("queryStringParameters")); + Map> queryParameters = + extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); Map requestContext = (Map) event.get("requestContext"); @@ -407,13 +465,21 @@ private static LambdaEventData extractApiGatewayV2WebSocketData(Map event, LambdaTriggerType triggerType) { + /** Extracts data from ALB event (with or without multi-value headers) */ + private static LambdaEventData extractAlbData( + Map event, LambdaTriggerType triggerType) { Map headers; if (triggerType == LambdaTriggerType.ALB_MULTI_VALUE) { @@ -428,9 +494,10 @@ private static LambdaEventData extractAlbData(Map event, LambdaT if (entry.getValue() instanceof java.util.List) { java.util.List values = (java.util.List) entry.getValue(); // Join multiple values with comma - String joinedValue = values.stream() - .map(String::valueOf) - .collect(java.util.stream.Collectors.joining(", ")); + String joinedValue = + values.stream() + .map(String::valueOf) + .collect(java.util.stream.Collectors.joining(", ")); headers.put(key, joinedValue); } else { headers.put(key, String.valueOf(entry.getValue())); @@ -447,7 +514,8 @@ private static LambdaEventData extractAlbData(Map event, LambdaT // ALB can have both queryStringParameters and multiValueQueryStringParameters Map> queryParameters; if (triggerType == LambdaTriggerType.ALB_MULTI_VALUE) { - queryParameters = extractMultiValueQueryParameters(event.get("multiValueQueryStringParameters")); + queryParameters = + extractMultiValueQueryParameters(event.get("multiValueQueryStringParameters")); } else { queryParameters = extractQueryParameters(event.get("queryStringParameters")); } @@ -458,16 +526,16 @@ private static LambdaEventData extractAlbData(Map event, LambdaT String path = (String) event.get("path"); String sourceIp = headers.get("x-forwarded-for"); - return new LambdaEventData(headers, method, path, sourceIp, null, triggerType, pathParameters, queryParameters, body); + return new LambdaEventData( + headers, method, path, sourceIp, null, triggerType, pathParameters, queryParameters, body); } - /** - * Generic data extraction for unknown trigger types (fallback) - */ + /** Generic data extraction for unknown trigger types (fallback) */ private static LambdaEventData extractGenericData(Map event) { Map headers = extractHeadersWithCookies(event); Map pathParameters = extractPathParameters(event.get("pathParameters")); - Map> queryParameters = extractQueryParameters(event.get("queryStringParameters")); + Map> queryParameters = + extractQueryParameters(event.get("queryStringParameters")); Object body = extractBody(event); String method = null; @@ -513,12 +581,21 @@ private static LambdaEventData extractGenericData(Map event) { } } - return new LambdaEventData(headers, method, path, sourceIp, null, LambdaTriggerType.UNKNOWN, pathParameters, queryParameters, body); + return new LambdaEventData( + headers, + method, + path, + sourceIp, + null, + LambdaTriggerType.UNKNOWN, + pathParameters, + queryParameters, + body); } /** - * Generic helper method to extract string key-value pairs from an object. - * Converts all keys and values to strings, filtering out null entries. + * Generic helper method to extract string key-value pairs from an object. Converts all keys and + * values to strings, filtering out null entries. */ private static Map extractStringMap(Object mapObj) { Map result = new java.util.HashMap<>(); @@ -535,9 +612,7 @@ private static Map extractStringMap(Object mapObj) { return result; } - /** - * Helper method to extract headers from event - */ + /** Helper method to extract headers from event */ private static Map extractHeaders(Object headersObj) { Map headers = extractStringMap(headersObj); log.debug("Extracted {} headers", headers.size()); @@ -547,9 +622,7 @@ private static Map extractHeaders(Object headersObj) { return headers; } - /** - * Helper method to extract path parameters from event - */ + /** Helper method to extract path parameters from event */ private static Map extractPathParameters(Object pathParamsObj) { Map pathParams = extractStringMap(pathParamsObj); log.debug("Extracted {} path parameters", pathParams.size()); @@ -557,8 +630,8 @@ private static Map extractPathParameters(Object pathParamsObj) { } /** - * Helper method to extract query parameters from event. - * Converts Map to Map> format expected by AppSec. + * Helper method to extract query parameters from event. Converts Map to + * Map> format expected by AppSec. */ private static Map> extractQueryParameters(Object queryParamsObj) { Map> result = new java.util.HashMap<>(); @@ -577,8 +650,8 @@ private static Map> extractQueryParameters(Object queryPara } /** - * Helper method to extract multi-value query parameters (used by ALB). - * Handles Map> format directly. + * Helper method to extract multi-value query parameters (used by ALB). Handles Map> format directly. */ private static Map> extractMultiValueQueryParameters(Object queryParamsObj) { Map> result = new java.util.HashMap<>(); @@ -607,9 +680,8 @@ private static Map> extractMultiValueQueryParameters(Object } /** - * Helper method to build full path including query string. - * Lambda events provide path and query parameters separately, so we need to reconstruct - * the full URI for AppSec to parse. + * Helper method to build full path including query string. Lambda events provide path and query + * parameters separately, so we need to reconstruct the full URI for AppSec to parse. */ private static String buildFullPath(String path, Map> queryParameters) { if (queryParameters == null || queryParameters.isEmpty()) { @@ -638,8 +710,8 @@ private static String buildFullPath(String path, Map> query } /** - * Helper method to extract and merge headers with cookies array from event. - * API Gateway v2 provides a separate 'cookies' array that should be merged with headers. + * Helper method to extract and merge headers with cookies array from event. API Gateway v2 + * provides a separate 'cookies' array that should be merged with headers. */ private static Map extractHeadersWithCookies(Map event) { Map headers = extractHeaders(event.get("headers")); @@ -650,9 +722,10 @@ private static Map extractHeadersWithCookies(Map java.util.List cookiesList = (java.util.List) cookiesObj; if (!cookiesList.isEmpty()) { // Join cookies with "; " separator per RFC 6265 - String cookieValue = cookiesList.stream() - .map(String::valueOf) - .collect(java.util.stream.Collectors.joining("; ")); + String cookieValue = + cookiesList.stream() + .map(String::valueOf) + .collect(java.util.stream.Collectors.joining("; ")); // Merge with existing cookie header if present String existingCookie = headers.get("cookie"); @@ -667,9 +740,7 @@ private static Map extractHeadersWithCookies(Map return headers; } - /** - * Helper method to extract and parse body from event - */ + /** Helper method to extract and parse body from event */ private static Object extractBody(Map event) { Object bodyObj = event.get("body"); if (bodyObj == null) { @@ -701,9 +772,7 @@ private static Object extractBody(Map event) { return bodyString; } - /** - * Helper method to parse body as JSON - */ + /** Helper method to parse body as JSON */ private static Object parseBodyAsJson(String body) { if (body == null || body.isEmpty() || "null".equals(body)) { return null; @@ -718,8 +787,7 @@ private static Object parseBodyAsJson(String body) { } /** - * Temporary RequestContext implementation to hold AppSecRequestContext - * before a span is created. + * Temporary RequestContext implementation to hold AppSecRequestContext before a span is created. */ private static class TemporaryRequestContext implements RequestContext { private final Object appSecRequestContext; @@ -762,22 +830,18 @@ public void close() { } } - /** - * Enum representing different AWS Lambda trigger types - */ + /** Enum representing different AWS Lambda trigger types */ enum LambdaTriggerType { - API_GATEWAY_V1_REST, // API Gateway REST API (v1) - API_GATEWAY_V2_HTTP, // API Gateway HTTP API (v2) + API_GATEWAY_V1_REST, // API Gateway REST API (v1) + API_GATEWAY_V2_HTTP, // API Gateway HTTP API (v2) API_GATEWAY_V2_WEBSOCKET, // API Gateway WebSocket - ALB, // Application Load Balancer - ALB_MULTI_VALUE, // ALB with multi-value headers - LAMBDA_URL, // Lambda Function URL - UNKNOWN // Unknown or unsupported trigger + ALB, // Application Load Balancer + ALB_MULTI_VALUE, // ALB with multi-value headers + LAMBDA_URL, // Lambda Function URL + UNKNOWN // Unknown or unsupported trigger } - /** - * Object for Lambda event data needed for AppSec processing - */ + /** Object for Lambda event data needed for AppSec processing */ static class LambdaEventData { final Map headers; final String method; @@ -789,7 +853,16 @@ static class LambdaEventData { final Map> queryParameters; final Object body; - LambdaEventData(Map headers, String method, String path, String sourceIp, Integer sourcePort, LambdaTriggerType triggerType, Map pathParameters, Map> queryParameters, Object body) { + LambdaEventData( + Map headers, + String method, + String path, + String sourceIp, + Integer sourcePort, + LambdaTriggerType triggerType, + Map pathParameters, + Map> queryParameters, + Object body) { this.headers = headers; this.method = method; this.path = path; @@ -802,9 +875,7 @@ static class LambdaEventData { } } - /** - * URIDataAdapter implementation for Lambda events. - */ + /** URIDataAdapter implementation for Lambda events. */ private static class LambdaURIDataAdapter extends URIDataAdapterBase { private final String path; private final String query; diff --git a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaHandler.java b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaHandler.java index 5e2eed69469..c5f47f6554e 100644 --- a/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaHandler.java +++ b/dd-trace-core/src/main/java/datadog/trace/lambda/LambdaHandler.java @@ -104,7 +104,8 @@ public static AgentSpanContext notifyStartInvocation(Object event, String lambda public static boolean notifyEndInvocation( AgentSpan span, Object result, boolean isError, String lambdaRequestId) { if (null == span || null == span.getSamplingPriority()) { - log.error("could not notify the extension as the lambda span is null or no sampling priority has been found"); + log.error( + "could not notify the extension as the lambda span is null or no sampling priority has been found"); return false; } diff --git a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy index 59f72c0a98e..00f58c7ef36 100644 --- a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy +++ b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy @@ -332,23 +332,23 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedBody = null setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onHeader: { name, value -> - capturedHeaders[name] = value - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - capturedSourcePort = port - }, - onPathParams: { params -> - capturedPathParams = params - }, - onBody: { body -> - capturedBody = body - } + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onHeader: { name, value -> + capturedHeaders[name] = value + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + capturedSourcePort = port + }, + onPathParams: { params -> + capturedPathParams = params + }, + onBody: { body -> + capturedBody = body + } ) when: @@ -404,20 +404,20 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedPathParams = null setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onHeader: { name, value -> - capturedHeaders[name] = value - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - capturedSourcePort = port - }, - onPathParams: { params -> - capturedPathParams = params - } + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onHeader: { name, value -> + capturedHeaders[name] = value + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + capturedSourcePort = port + }, + onPathParams: { params -> + capturedPathParams = params + } ) when: @@ -459,10 +459,10 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedPath = null setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - } + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + } ) when: @@ -498,13 +498,13 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedSourceIp = null setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - } + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + } ) when: @@ -539,9 +539,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedHeaders = [:] setupMockCallbacks( - onHeader: { name, value -> - capturedHeaders[name] = value - } + onHeader: { name, value -> + capturedHeaders[name] = value + } ) when: @@ -575,9 +575,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedHeaders = [:] setupMockCallbacks( - onHeader: { name, value -> - capturedHeaders[name] = value - } + onHeader: { name, value -> + capturedHeaders[name] = value + } ) when: @@ -609,13 +609,13 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedSourceIp = null setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - } + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + } ) when: @@ -646,9 +646,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedBody = null setupMockCallbacks( - onBody: { body -> - capturedBody = body - } + onBody: { body -> + capturedBody = body + } ) when: @@ -666,9 +666,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedBody = "NOT_CALLED" setupMockCallbacks( - onBody: { body -> - capturedBody = body - } + onBody: { body -> + capturedBody = body + } ) when: @@ -686,9 +686,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedBody = null setupMockCallbacks( - onBody: { body -> - capturedBody = body - } + onBody: { body -> + capturedBody = body + } ) when: @@ -715,10 +715,10 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedQuery = null setupMockCallbacks( - onMethodUri: { method, uri -> - capturedPath = uri.path() - capturedQuery = uri.query() - } + onMethodUri: { method, uri -> + capturedPath = uri.path() + capturedQuery = uri.query() + } ) when: @@ -746,9 +746,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedBody = "NOT_CALLED" setupMockCallbacks( - onBody: { body -> - capturedBody = body - } + onBody: { body -> + capturedBody = body + } ) when: @@ -776,9 +776,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedBody = "NOT_CALLED" setupMockCallbacks( - onBody: { body -> - capturedBody = body - } + onBody: { body -> + capturedBody = body + } ) when: @@ -804,9 +804,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedBody = null setupMockCallbacks( - onBody: { body -> - capturedBody = body - } + onBody: { body -> + capturedBody = body + } ) when: @@ -847,16 +847,16 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedSourceIp = null setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onHeader: { name, value -> - capturedHeaders[name] = value - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - } + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onHeader: { name, value -> + capturedHeaders[name] = value + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + } ) when: @@ -890,13 +890,13 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedSourceIp = null setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - } + onMethodUri: { method, uri -> + capturedMethod = method + capturedPath = uri.path() + }, + onSocketAddress: { ip, port -> + capturedSourceIp = ip + } ) when: @@ -930,9 +930,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedHeaders = [:] setupMockCallbacks( - onHeader: { name, value -> - capturedHeaders[name] = value - } + onHeader: { name, value -> + capturedHeaders[name] = value + } ) when: @@ -964,9 +964,9 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedHeaders = [:] setupMockCallbacks( - onHeader: { name, value -> - capturedHeaders[name] = value - } + onHeader: { name, value -> + capturedHeaders[name] = value + } ) when: @@ -1029,7 +1029,8 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { def capturedSpan = null def mockRequestEndedCallback = Mock(BiFunction) { - apply(_ as RequestContext, _ as AgentSpan) >> { RequestContext ctx, AgentSpan s -> + apply(_ as RequestContext, _ as AgentSpan) >> { + RequestContext ctx, AgentSpan s -> callbackInvoked = true capturedContext = ctx capturedSpan = s @@ -1271,20 +1272,23 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { } def mockMethodUriCallback = callbacks.onMethodUri ? Mock(datadog.trace.api.function.TriFunction) { - apply(_ as RequestContext, _ as String, _ as URIDataAdapter) >> { RequestContext ctx, String method, URIDataAdapter uri -> + apply(_ as RequestContext, _ as String, _ as URIDataAdapter) >> { + RequestContext ctx, String method, URIDataAdapter uri -> callbacks.onMethodUri(method, uri) return new Flow.ResultFlow<>(null) } } : null def mockHeaderCallback = callbacks.onHeader ? Mock(TriConsumer) { - accept(_ as RequestContext, _ as String, _ as String) >> { RequestContext ctx, String name, String value -> + accept(_ as RequestContext, _ as String, _ as String) >> { + RequestContext ctx, String name, String value -> callbacks.onHeader(name, value) } } : null def mockSocketAddressCallback = callbacks.onSocketAddress ? Mock(TriFunction) { - apply(_ as RequestContext, _ as String, _ as Integer) >> { RequestContext ctx, String ip, Integer port -> + apply(_ as RequestContext, _ as String, _ as Integer) >> { + RequestContext ctx, String ip, Integer port -> callbacks.onSocketAddress(ip, port) return new Flow.ResultFlow<>(null) } @@ -1295,21 +1299,24 @@ class LambdaAppSecHandlerTest extends DDCoreSpecification { } def mockPathParamsCallback = callbacks.onPathParams ? Mock(BiFunction) { - apply(_ as RequestContext, _ as Map) >> { RequestContext ctx, Map params -> + apply(_ as RequestContext, _ as Map) >> { + RequestContext ctx, Map params -> callbacks.onPathParams(params) return new Flow.ResultFlow<>(null) } } : null def mockQueryParamsCallback = callbacks.onQueryParams ? Mock(BiFunction) { - apply(_ as RequestContext, _ as Map) >> { RequestContext ctx, Map params -> + apply(_ as RequestContext, _ as Map) >> { + RequestContext ctx, Map params -> callbacks.onQueryParams(params) return new Flow.ResultFlow<>(null) } } : null def mockBodyCallback = callbacks.onBody ? Mock(BiFunction) { - apply(_ as RequestContext, _ as Object) >> { RequestContext ctx, Object body -> + apply(_ as RequestContext, _ as Object) >> { + RequestContext ctx, Object body -> callbacks.onBody(body) return new Flow.ResultFlow<>(null) }