Skip to content

Commit 1febcd5

Browse files
committed
Test reapplied updates
1 parent 092ac9c commit 1febcd5

File tree

1 file changed

+202
-0
lines changed

1 file changed

+202
-0
lines changed

reset/reapply_updates.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import asyncio
2+
import socket
3+
from datetime import datetime, timedelta
4+
from typing import Iterator, List
5+
from uuid import uuid4
6+
7+
import temporalio.api.common.v1
8+
import temporalio.api.enums.v1
9+
import temporalio.api.history.v1
10+
import temporalio.api.workflowservice.v1
11+
import temporalio.common
12+
from temporalio import activity, workflow
13+
from temporalio.client import Client
14+
from temporalio.worker import Worker
15+
16+
try:
17+
from rich import print
18+
except ImportError:
19+
pass
20+
21+
RunId = str
22+
23+
WORKFLOW_ID = uuid4().hex
24+
TASK_QUEUE = __file__
25+
26+
N_UPDATES = 1
27+
REPLAY = True
28+
29+
30+
@activity.defn
31+
async def my_activity(arg: int) -> str:
32+
return f"activity-result-{arg}"
33+
34+
35+
@workflow.defn(sandboxed=False)
36+
class WorkflowWithUpdateHandler:
37+
def __init__(self) -> None:
38+
self.update_results = []
39+
40+
@workflow.update
41+
async def my_update(self, arg: int):
42+
r = await workflow.execute_activity(
43+
my_activity, arg, start_to_close_timeout=timedelta(seconds=10)
44+
)
45+
self.update_results.append(r)
46+
return self.update_results
47+
48+
@workflow.run
49+
async def run(self):
50+
await workflow.wait_condition(lambda: len(self.update_results) == N_UPDATES)
51+
return {"update_results": self.update_results}
52+
53+
54+
async def app(client: Client):
55+
handle = await client.start_workflow(
56+
WorkflowWithUpdateHandler.run,
57+
id=WORKFLOW_ID,
58+
task_queue=TASK_QUEUE,
59+
id_reuse_policy=temporalio.common.WorkflowIDReusePolicy.TERMINATE_IF_RUNNING,
60+
)
61+
62+
log(
63+
f"sent start workflow request http://{server()}/namespaces/default/workflows/{WORKFLOW_ID}"
64+
)
65+
66+
for i in range(N_UPDATES):
67+
if True or input("execute update?") in ["y", ""]:
68+
log("sending update...")
69+
result = await handle.execute_update(
70+
WorkflowWithUpdateHandler.my_update, arg=i
71+
)
72+
log(f"received update result: {result}")
73+
74+
if True or input("reset?") in ["y", ""]:
75+
history = [e async for e in handle.fetch_history_events()]
76+
reset_to = next_event(
77+
history,
78+
temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_TASK_COMPLETED,
79+
)
80+
81+
log(f"sending reset to event {reset_to.event_id}...")
82+
run_id = get_first_execution_run_id(history)
83+
new_run_id = await reset_workflow(run_id, reset_to, client)
84+
log(
85+
f"did reset: http://localhost:8080/namespaces/default/workflows/{WORKFLOW_ID}/{new_run_id}"
86+
)
87+
88+
new_handle = client.get_workflow_handle(WORKFLOW_ID, run_id=new_run_id)
89+
90+
history = [e async for e in new_handle.fetch_history_events()]
91+
92+
log("new history")
93+
for e in history:
94+
log(f"{e.event_id} {e.event_type}")
95+
96+
wf_result = await new_handle.result()
97+
print(f"reset wf result: {wf_result}")
98+
log(f"reset wf result: {wf_result}")
99+
else:
100+
wf_result = await handle.result()
101+
print(f"wf result: {wf_result}")
102+
log(f"wf result: {wf_result}")
103+
104+
105+
async def reset_workflow(
106+
run_id: str,
107+
event: temporalio.api.history.v1.HistoryEvent,
108+
client: Client,
109+
) -> RunId:
110+
resp = await client.workflow_service.reset_workflow_execution(
111+
temporalio.api.workflowservice.v1.ResetWorkflowExecutionRequest(
112+
namespace="default",
113+
workflow_execution=temporalio.api.common.v1.WorkflowExecution(
114+
workflow_id=WORKFLOW_ID,
115+
run_id=run_id,
116+
),
117+
reason="Reset to test update reapply",
118+
request_id="1",
119+
reset_reapply_type=temporalio.api.enums.v1.ResetReapplyType.RESET_REAPPLY_TYPE_UNSPECIFIED, # TODO
120+
workflow_task_finish_event_id=event.event_id,
121+
)
122+
)
123+
assert resp.run_id
124+
return resp.run_id
125+
126+
127+
def next_event(
128+
history: List[temporalio.api.history.v1.HistoryEvent],
129+
event_type: temporalio.api.enums.v1.EventType.ValueType,
130+
) -> temporalio.api.history.v1.HistoryEvent:
131+
return next(e for e in history if e.event_type == event_type)
132+
133+
134+
def get_first_execution_run_id(
135+
history: List[temporalio.api.history.v1.HistoryEvent],
136+
) -> str:
137+
# TODO: correct way to obtain run_id
138+
wf_started_event = next_event(
139+
history, temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED
140+
)
141+
run_id = (
142+
wf_started_event.workflow_execution_started_event_attributes.first_execution_run_id
143+
)
144+
assert run_id
145+
return run_id
146+
147+
148+
async def main():
149+
client = await Client.connect("localhost:7233")
150+
async with Worker(
151+
client,
152+
task_queue=TASK_QUEUE,
153+
workflows=[WorkflowWithUpdateHandler],
154+
activities=[my_activity],
155+
sticky_queue_schedule_to_start_timeout=timedelta(hours=1),
156+
max_cached_workflows=0 if REPLAY else 100,
157+
):
158+
await app(client)
159+
160+
161+
def only(it: Iterator):
162+
t = next(it)
163+
assert (t2 := next(it, it)) == it, f"iterator had multiple items: [{t}, {t2}]"
164+
return t
165+
166+
167+
def is_listening(addr: str) -> bool:
168+
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
169+
h, p = addr.split(":")
170+
try:
171+
s.connect((h, int(p)))
172+
return True
173+
except socket.error:
174+
return False
175+
finally:
176+
s.close()
177+
178+
179+
def server() -> str:
180+
return only(
181+
filter(
182+
is_listening,
183+
["localhost:8080", "localhost:8081", "localhost:8233"],
184+
)
185+
)
186+
187+
188+
def log(s: str):
189+
log_to_file(s, "client", "red")
190+
191+
192+
def log_to_file(msg: str, prefix: str, color: str):
193+
with open("/tmp/log", "a") as f:
194+
time = datetime.now().strftime("%H:%M:%S.%f")[:-3]
195+
print(
196+
f"\n\n======================\n[{color}]{time} : {prefix} : {msg}[/{color}]\n\n",
197+
file=f,
198+
)
199+
200+
201+
if __name__ == "__main__":
202+
asyncio.run(main())

0 commit comments

Comments
 (0)