Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 57 additions & 14 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
import dataclasses
import functools
import inspect
import io
import json
import multiprocessing
import platform
import sys
import tempfile
import traceback
from datetime import datetime
from pathlib import Path
Expand All @@ -29,6 +29,7 @@
List,
MutableMapping,
Optional,
Sequence,
Set,
Type,
Union,
Expand Down Expand Up @@ -1485,6 +1486,48 @@ async def health() -> JSONResponse:
return JSONResponse(content=health_status, status_code=status_code)


def _handle_temporary_upload_file(
upload_file: UploadFile, temp_root: tempfile.TemporaryDirectory
) -> tempfile.SpooledTemporaryFile:
temp_file = tempfile.SpooledTemporaryFile(max_size=1024 * 1024, dir=temp_root.name)
temp_file.write(upload_file.file.read())
temp_file.seek(0)
return temp_file


async def temporary_upload_tree(
token: str, files: List[UploadFile]
) -> AsyncIterator[Sequence[tempfile.SpooledTemporaryFile]]:
"""Write the uploaded files to a temporary directory structure.

Args:
token: The token to use for the temporary directory.
files: The files to write to the temporary directory.

Yields:
A list of the temporary files.
"""
upload_dir = get_upload_dir()
upload_dir.mkdir(parents=True, exist_ok=True)
temp_root = tempfile.TemporaryDirectory(prefix=token, dir=upload_dir)
temp_files = []
loop = asyncio.get_running_loop()
temp_files = [
await loop.run_in_executor(None, _handle_temporary_upload_file, f, temp_root)
for f in files
]
try:
yield temp_files
finally:

def _cleanup():
for temp_file in temp_files:
temp_file.close()
temp_root.cleanup()

await loop.run_in_executor(None, _cleanup)


def upload(app: App):
"""Upload a file.

Expand Down Expand Up @@ -1563,24 +1606,21 @@ async def upload_file(request: Request, files: List[UploadFile]):
# AsyncExitStack was removed from the request scope and is now
# part of the routing function which closes this before the
# event is handled.
file_copies = []
for file in files:
content_copy = io.BytesIO()
content_copy.write(await file.read())
content_copy.seek(0)
file_copies.append(
UploadFile(
file=content_copy,
filename=file.filename,
size=file.size,
headers=file.headers,
)
file_ctx = temporary_upload_tree(token, files)
temp_files = [
UploadFile(
file=tmp, # pyright: ignore[reportArgumentType]
filename=file.filename,
size=file.size,
headers=file.headers,
)
for file, tmp in zip(files, await anext(file_ctx), strict=True)
]

event = Event(
token=token,
name=handler,
payload={handler_upload_param[0]: file_copies},
payload={handler_upload_param[0]: temp_files},
)

async def _ndjson_updates():
Expand All @@ -1595,6 +1635,9 @@ async def _ndjson_updates():
# Postprocess the event.
update = await app._postprocess(state, event, update)
yield update.json() + "\n"
# Clean up the temporary files.
async for _ in file_ctx:
pass

# Stream updates to client
return StreamingResponse(
Expand Down