Skip to content

Commit a60065e

Browse files
fix(arkitect): trace span add context support
1 parent 79b3321 commit a60065e

File tree

1 file changed

+39
-6
lines changed

1 file changed

+39
-6
lines changed

arkitect/telemetry/trace/wrapper.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import contextvars
1516
import inspect
1617
import time
1718
from functools import wraps
@@ -31,6 +32,9 @@
3132

3233
T = TypeVar("T", covariant=True)
3334
tracer = trace.get_tracer(__name__)
35+
_current_span_context: contextvars.ContextVar = contextvars.ContextVar(
36+
"current_span_context"
37+
)
3438

3539

3640
def get_remote_func(func): # type: ignore
@@ -74,7 +78,12 @@ def task(
7478

7579
def task_wrapper(func): # type: ignore
7680
async def async_exec(*args: Any, **kwargs: Any) -> Any:
77-
with tracer.start_as_current_span(name=func.__qualname__) as span:
81+
parent_ctx = _current_span_context.get(None)
82+
with tracer.start_as_current_span(
83+
name=func.__qualname__, context=parent_ctx
84+
) as span:
85+
_current_span_context.set(trace.set_span_in_context(span))
86+
7887
input = _update_kwargs(args, kwargs, func)
7988
try:
8089
result = await (get_remote_func(func) if distributed else func)(
@@ -98,7 +107,12 @@ async def async_exec(*args: Any, **kwargs: Any) -> Any:
98107
raise e
99108

100109
def sync_exec(*args: Any, **kwargs: Any) -> Any:
101-
with tracer.start_as_current_span(name=func.__qualname__) as span:
110+
parent_ctx = _current_span_context.get(None)
111+
with tracer.start_as_current_span(
112+
name=func.__qualname__, context=parent_ctx
113+
) as span:
114+
_current_span_context.set(trace.set_span_in_context(span))
115+
102116
input = _update_kwargs(args, kwargs, func)
103117
try:
104118
result = func(*args, **kwargs)
@@ -121,9 +135,13 @@ def sync_exec(*args: Any, **kwargs: Any) -> Any:
121135

122136
@wraps(func)
123137
async def async_iter_task(*args: Any, **kwargs: Any) -> AsyncGenerator[T, None]:
138+
parent_ctx = _current_span_context.get(None)
124139
span = tracer.start_span(
125-
name=func.__qualname__ + ".first_iter", start_time=time.time_ns()
140+
name=func.__qualname__ + ".first_iter",
141+
start_time=time.time_ns(),
142+
context=parent_ctx,
126143
)
144+
_current_span_context.set(trace.set_span_in_context(span))
127145
input = _update_kwargs(args, kwargs, func)
128146
try:
129147
async for i, resp in aenumerate(func(*args, **kwargs)): # type: ignore
@@ -142,12 +160,17 @@ async def async_iter_task(*args: Any, **kwargs: Any) -> AsyncGenerator[T, None]:
142160
custom_attributes=custom_attributes,
143161
)
144162
span.end(end_time=time.time_ns())
163+
_current_span_context.set(parent_ctx)
145164
yield resp
146165

147166
if trace_all:
167+
parent_ctx = _current_span_context.get()
148168
span = tracer.start_span(
149-
name=func.__qualname__, start_time=time.time_ns()
169+
name=func.__qualname__,
170+
start_time=time.time_ns(),
171+
context=parent_ctx,
150172
)
173+
_current_span_context.set(trace.set_span_in_context(span))
151174
except Exception as e:
152175
if not trace_all:
153176
span = tracer.start_span(
@@ -160,7 +183,12 @@ async def async_iter_task(*args: Any, **kwargs: Any) -> AsyncGenerator[T, None]:
160183

161184
@wraps(func)
162185
def iter_task(*args: Any, **kwargs: Any) -> Iterable[T]:
163-
span = tracer.start_span(name=func.__qualname__, start_time=time.time_ns())
186+
parent_ctx = _current_span_context.get(None)
187+
span = tracer.start_span(
188+
name=func.__qualname__, start_time=time.time_ns(), context=parent_ctx
189+
)
190+
_current_span_context.set(trace.set_span_in_context(span))
191+
164192
input = _update_kwargs(args, kwargs, func)
165193
try:
166194
for i, resp in enumerate(func(*args, **kwargs)):
@@ -179,11 +207,16 @@ def iter_task(*args: Any, **kwargs: Any) -> Iterable[T]:
179207
custom_attributes=custom_attributes,
180208
)
181209
span.end(end_time=time.time_ns())
210+
_current_span_context.set(parent_ctx)
182211
yield resp
183212
if trace_all:
213+
parent_ctx = _current_span_context.get()
184214
span = tracer.start_span(
185-
name=func.__qualname__, start_time=time.time_ns()
215+
name=func.__qualname__,
216+
start_time=time.time_ns(),
217+
context=parent_ctx,
186218
)
219+
_current_span_context.set(trace.set_span_in_context(span))
187220
except Exception as e:
188221
if not trace_all:
189222
span = tracer.start_span(

0 commit comments

Comments
 (0)