Skip to content

Commit b3fad1d

Browse files
authored
InboundActivityInterceptor now has an access to header (temporalio#375)
1 parent 6267972 commit b3fad1d

File tree

5 files changed

+125
-22
lines changed

5 files changed

+125
-22
lines changed

temporal-sdk/src/main/java/io/temporal/common/interceptors/ActivityInboundCallsInterceptor.java

+36-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,42 @@
2222
import io.temporal.activity.ActivityExecutionContext;
2323

2424
public interface ActivityInboundCallsInterceptor {
25+
final class ActivityInput {
26+
private final Header header;
27+
private final Object[] arguments;
28+
29+
public ActivityInput(Header header, Object[] arguments) {
30+
this.header = header;
31+
this.arguments = arguments;
32+
}
33+
34+
public Header getHeader() {
35+
return header;
36+
}
37+
38+
public Object[] getArguments() {
39+
return arguments;
40+
}
41+
}
42+
43+
final class ActivityOutput {
44+
private final Object result;
45+
46+
public ActivityOutput(Object result) {
47+
this.result = result;
48+
}
49+
50+
public Object getResult() {
51+
return result;
52+
}
53+
}
54+
2555
void init(ActivityExecutionContext context);
2656

27-
Object execute(Object[] arguments);
57+
/**
58+
* Intercepts a call to the main activity entry method.
59+
*
60+
* @return result of the activity execution.
61+
*/
62+
ActivityOutput execute(ActivityInput input);
2863
}

temporal-sdk/src/main/java/io/temporal/internal/sync/ActivityInfoImpl.java

+10-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
package io.temporal.internal.sync;
2121

2222
import com.google.protobuf.util.Timestamps;
23-
import io.temporal.activity.ActivityInfo;
23+
import io.temporal.api.common.v1.Header;
2424
import io.temporal.api.common.v1.Payloads;
2525
import io.temporal.api.workflowservice.v1.PollActivityTaskQueueResponse;
2626
import io.temporal.internal.common.ProtobufTimeUtils;
@@ -29,7 +29,7 @@
2929
import java.util.Objects;
3030
import java.util.Optional;
3131

32-
final class ActivityInfoImpl implements ActivityInfo {
32+
final class ActivityInfoImpl implements ActivityInfoInternal {
3333
private final PollActivityTaskQueueResponse response;
3434
private final String activityNamespace;
3535
private final boolean local;
@@ -46,6 +46,7 @@ final class ActivityInfoImpl implements ActivityInfo {
4646
this.completionHandle = completionHandle;
4747
}
4848

49+
@Override
4950
public byte[] getTaskToken() {
5051
return response.getTaskToken().toByteArray();
5152
}
@@ -124,14 +125,21 @@ public boolean isLocal() {
124125
return local;
125126
}
126127

128+
@Override
127129
public Functions.Proc getCompletionHandle() {
128130
return completionHandle;
129131
}
130132

133+
@Override
131134
public Optional<Payloads> getInput() {
132135
if (response.hasInput()) {
133136
return Optional.of(response.getInput());
134137
}
135138
return Optional.empty();
136139
}
140+
141+
@Override
142+
public Header getHeader() {
143+
return response.getHeader();
144+
}
137145
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved.
3+
*
4+
* Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
5+
*
6+
* Modifications copyright (C) 2017 Uber Technologies, Inc.
7+
*
8+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
9+
* use this file except in compliance with the License. A copy of the License is
10+
* located at
11+
*
12+
* http://aws.amazon.com/apache2.0
13+
*
14+
* or in the "license" file accompanying this file. This file is distributed on
15+
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
16+
* express or implied. See the License for the specific language governing
17+
* permissions and limitations under the License.
18+
*/
19+
20+
package io.temporal.internal.sync;
21+
22+
import io.temporal.activity.ActivityInfo;
23+
import io.temporal.api.common.v1.Header;
24+
import io.temporal.api.common.v1.Payloads;
25+
import io.temporal.workflow.Functions;
26+
import java.util.Optional;
27+
28+
/**
29+
* An extension for {@link ActivityInfo} with information about the activity task that the current
30+
* activity is handling that should be available for Temporal SDK, but shouldn't be available or
31+
* exposed to Activity implementation code.
32+
*/
33+
public interface ActivityInfoInternal extends ActivityInfo {
34+
/**
35+
* @return function shat should be triggered after activity completion with any outcome (success,
36+
* failure, cancelling)
37+
*/
38+
Functions.Proc getCompletionHandle();
39+
40+
/** @return input parameters of the activity execution */
41+
Optional<Payloads> getInput();
42+
43+
/** @return header that is passed with the activity execution */
44+
Header getHeader();
45+
}

temporal-sdk/src/main/java/io/temporal/internal/sync/POJOActivityTaskHandler.java

+32-17
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import io.temporal.common.converter.DataConverter;
3737
import io.temporal.common.converter.EncodedValues;
3838
import io.temporal.common.interceptors.ActivityInboundCallsInterceptor;
39+
import io.temporal.common.interceptors.Header;
3940
import io.temporal.common.interceptors.WorkerInterceptor;
4041
import io.temporal.common.metadata.POJOActivityImplMetadata;
4142
import io.temporal.common.metadata.POJOActivityInterfaceMetadata;
@@ -189,7 +190,7 @@ void registerLocalActivityImplementations(Object[] activitiesImplementation) {
189190
public Result handle(ActivityTask activityTask, Scope metricsScope, boolean localActivity) {
190191
PollActivityTaskQueueResponse pollResponse = activityTask.getResponse();
191192
String activityType = pollResponse.getActivityType().getName();
192-
ActivityInfoImpl activityInfo =
193+
ActivityInfoInternal activityInfo =
193194
new ActivityInfoImpl(
194195
pollResponse, this.namespace, localActivity, activityTask.getCompletionHandle());
195196
ActivityTaskExecutor activity = activities.get(activityType);
@@ -212,7 +213,7 @@ public Result handle(ActivityTask activityTask, Scope metricsScope, boolean loca
212213
}
213214

214215
private interface ActivityTaskExecutor {
215-
ActivityTaskHandler.Result execute(ActivityInfoImpl task, Scope metricsScope);
216+
ActivityTaskHandler.Result execute(ActivityInfoInternal task, Scope metricsScope);
216217
}
217218

218219
private class POJOActivityImplementation implements ActivityTaskExecutor {
@@ -225,7 +226,7 @@ private class POJOActivityImplementation implements ActivityTaskExecutor {
225226
}
226227

227228
@Override
228-
public ActivityTaskHandler.Result execute(ActivityInfoImpl info, Scope metricsScope) {
229+
public ActivityTaskHandler.Result execute(ActivityInfoInternal info, Scope metricsScope) {
229230
ActivityExecutionContext context =
230231
new ActivityExecutionContextImpl(
231232
service,
@@ -249,15 +250,18 @@ public ActivityTaskHandler.Result execute(ActivityInfoImpl info, Scope metricsSc
249250
input,
250251
method.getParameterTypes(),
251252
method.getGenericParameterTypes());
252-
Object result = inboundCallsInterceptor.execute(args);
253+
ActivityInboundCallsInterceptor.ActivityOutput result =
254+
inboundCallsInterceptor.execute(
255+
new ActivityInboundCallsInterceptor.ActivityInput(
256+
new Header(info.getHeader()), args));
253257
if (context.isDoNotCompleteOnReturn()) {
254258
return new ActivityTaskHandler.Result(
255259
info.getActivityId(), null, null, null, null, context.isUseLocalManualCompletion());
256260
}
257261
RespondActivityTaskCompletedRequest.Builder request =
258262
RespondActivityTaskCompletedRequest.newBuilder();
259263
if (method.getReturnType() != Void.TYPE) {
260-
Optional<Payloads> serialized = dataConverter.toPayloads(result);
264+
Optional<Payloads> serialized = dataConverter.toPayloads(result.getResult());
261265
if (serialized.isPresent()) {
262266
request.setResult(serialized.get());
263267
}
@@ -287,10 +291,11 @@ public void init(ActivityExecutionContext context) {
287291
}
288292

289293
@Override
290-
public Object execute(Object[] arguments) {
294+
public ActivityOutput execute(ActivityInput input) {
291295
CurrentActivityExecutionContext.set(context);
292296
try {
293-
return method.invoke(activity, arguments);
297+
Object result = method.invoke(activity, input.getArguments());
298+
return new ActivityOutput(result);
294299
} catch (InvocationTargetException e) {
295300
throw Activity.wrap(e.getTargetException());
296301
} catch (Exception e) {
@@ -310,7 +315,7 @@ public DynamicActivityImplementation(DynamicActivity activity) {
310315
}
311316

312317
@Override
313-
public ActivityTaskHandler.Result execute(ActivityInfoImpl info, Scope metricsScope) {
318+
public ActivityTaskHandler.Result execute(ActivityInfoInternal info, Scope metricsScope) {
314319
ActivityExecutionContext context =
315320
new ActivityExecutionContextImpl(
316321
service,
@@ -328,15 +333,20 @@ public ActivityTaskHandler.Result execute(ActivityInfoImpl info, Scope metricsSc
328333
}
329334
inboundCallsInterceptor.init(context);
330335
try {
331-
EncodedValues args = new EncodedValues(input, dataConverter);
332-
Object result = inboundCallsInterceptor.execute(new Object[] {args});
336+
EncodedValues encodedValues = new EncodedValues(input, dataConverter);
337+
Object[] args = new Object[] {encodedValues};
338+
339+
ActivityInboundCallsInterceptor.ActivityOutput result =
340+
inboundCallsInterceptor.execute(
341+
new ActivityInboundCallsInterceptor.ActivityInput(
342+
new Header(info.getHeader()), args));
333343
if (context.isDoNotCompleteOnReturn()) {
334344
return new ActivityTaskHandler.Result(
335345
info.getActivityId(), null, null, null, null, context.isUseLocalManualCompletion());
336346
}
337347
RespondActivityTaskCompletedRequest.Builder request =
338348
RespondActivityTaskCompletedRequest.newBuilder();
339-
Optional<Payloads> serialized = dataConverter.toPayloads(result);
349+
Optional<Payloads> serialized = dataConverter.toPayloads(result.getResult());
340350
if (serialized.isPresent()) {
341351
request.setResult(serialized.get());
342352
}
@@ -348,7 +358,8 @@ public ActivityTaskHandler.Result execute(ActivityInfoImpl info, Scope metricsSc
348358
}
349359
}
350360

351-
private Result activityFailureToResult(ActivityInfoImpl info, Scope metricsScope, Throwable e) {
361+
private Result activityFailureToResult(
362+
ActivityInfoInternal info, Scope metricsScope, Throwable e) {
352363
e = CheckedExceptionWrapper.unwrap(e);
353364
if (e instanceof ActivityCanceledException) {
354365
if (log.isInfoEnabled()) {
@@ -388,10 +399,11 @@ public void init(ActivityExecutionContext context) {
388399
}
389400

390401
@Override
391-
public Object execute(Object[] arguments) {
402+
public ActivityOutput execute(ActivityInput input) {
392403
CurrentActivityExecutionContext.set(context);
393404
try {
394-
return activity.execute((EncodedValues) arguments[0]);
405+
Object result = activity.execute((EncodedValues) input.getArguments()[0]);
406+
return new ActivityOutput(result);
395407
} catch (Exception e) {
396408
throw Activity.wrap(e);
397409
} finally {
@@ -410,7 +422,7 @@ private class POJOLocalActivityImplementation implements ActivityTaskExecutor {
410422
}
411423

412424
@Override
413-
public ActivityTaskHandler.Result execute(ActivityInfoImpl info, Scope metricsScope) {
425+
public ActivityTaskHandler.Result execute(ActivityInfoInternal info, Scope metricsScope) {
414426
ActivityExecutionContext context = new LocalActivityExecutionContextImpl(info, metricsScope);
415427
Optional<Payloads> input = info.getInput();
416428
ActivityInboundCallsInterceptor inboundCallsInterceptor =
@@ -426,11 +438,14 @@ public ActivityTaskHandler.Result execute(ActivityInfoImpl info, Scope metricsSc
426438
input,
427439
method.getParameterTypes(),
428440
method.getGenericParameterTypes());
429-
Object result = inboundCallsInterceptor.execute(args);
441+
ActivityInboundCallsInterceptor.ActivityOutput result =
442+
inboundCallsInterceptor.execute(
443+
new ActivityInboundCallsInterceptor.ActivityInput(
444+
new Header(info.getHeader()), args));
430445
RespondActivityTaskCompletedRequest.Builder request =
431446
RespondActivityTaskCompletedRequest.newBuilder();
432447
if (method.getReturnType() != Void.TYPE) {
433-
Optional<Payloads> serialized = dataConverter.toPayloads(result);
448+
Optional<Payloads> serialized = dataConverter.toPayloads(result.getResult());
434449
if (serialized.isPresent()) {
435450
request.setResult(serialized.get());
436451
}

temporal-testing-junit4/src/main/java/io/temporal/testing/TracingWorkerInterceptor.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -382,9 +382,9 @@ public <V> void heartbeat(V details) throws ActivityCompletionException {
382382
}
383383

384384
@Override
385-
public Object execute(Object[] arguments) {
385+
public ActivityOutput execute(ActivityInput input) {
386386
trace.add((local ? "local " : "") + "activity " + type);
387-
return next.execute(arguments);
387+
return next.execute(input);
388388
}
389389
}
390390
}

0 commit comments

Comments
 (0)