Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/portkeydrop/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ def __init__(self) -> None:
self.build_tag = os.environ.get("PORTKEYDROP_BUILD_TAG")
self._auto_update_check_timer: wx.Timer | None = None
self._site_manager = SiteManager()
self._transfer_service = TransferService(notify_window=self)
self._transfer_service = TransferService(
notify_window=self,
max_workers=self._settings.transfer.concurrent_transfers,
)
self._transfer_state_by_id: dict[str, str] = {}
self._last_failed_transfer: str | None = None
self._announcer = ScreenReaderAnnouncer()
Expand Down Expand Up @@ -1574,6 +1577,9 @@ def _on_settings(self, event: wx.CommandEvent) -> None:
self._settings = dlg.get_settings()
update_last_local_folder(self._settings, self._local_cwd)
save_settings(self._settings)
self._transfer_service.set_max_workers(
self._settings.transfer.concurrent_transfers,
)
self.update_check_updates_menu_label()
self._start_auto_update_checks()
self._populate_file_list(
Expand Down
37 changes: 32 additions & 5 deletions src/portkeydrop/services/transfer_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,19 @@ def from_dict(cls, data: dict) -> TransferJob:


class TransferService:
"""Owns the transfer queue and a single daemon worker thread."""
"""Owns the transfer queue and a pool of daemon worker threads."""

def __init__(self, notify_window: Any | None = None) -> None:
def __init__(self, notify_window: Any | None = None, max_workers: int = 1) -> None:
self._notify_window = notify_window
self._queue: queue.Queue[TransferJob] = queue.Queue()
self._queue: queue.Queue[TransferJob | None] = queue.Queue()
self._jobs: list[TransferJob] = []
self._lock = threading.Lock()
self._worker = threading.Thread(target=self._worker_loop, daemon=True)
self._worker.start()
self._max_workers = max(1, max_workers)
self._workers: list[threading.Thread] = []
for _ in range(self._max_workers):
t = threading.Thread(target=self._worker_loop, daemon=True)
t.start()
self._workers.append(t)

# ------------------------------------------------------------------
# Public API
Expand Down Expand Up @@ -205,6 +209,27 @@ def cancel(self, job_id: str) -> None:
break
self._post_event()

def set_max_workers(self, n: int) -> None:
"""Resize the worker pool to *n* threads.

Extra workers are drained via a ``None`` sentinel on the queue;
missing workers are spawned immediately.
"""
n = max(1, n)
with self._lock:
# Prune threads that have already exited
self._workers = [t for t in self._workers if t.is_alive()]
current = len(self._workers)
if n > current:
for _ in range(n - current):
t = threading.Thread(target=self._worker_loop, daemon=True)
t.start()
self._workers.append(t)
elif n < current:
for _ in range(current - n):
self._queue.put(None) # sentinel to stop one worker
self._max_workers = n

# ------------------------------------------------------------------
# Internal
# ------------------------------------------------------------------
Expand All @@ -218,6 +243,8 @@ def _enqueue(self, job: TransferJob) -> None:
def _worker_loop(self) -> None:
while True:
job = self._queue.get()
if job is None: # shutdown sentinel
break
if job.cancel_event.is_set():
job.status = TransferStatus.CANCELLED
self._post_event()
Expand Down
8 changes: 6 additions & 2 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def _build_frame(module, tmp_path):
sort_by="name",
sort_ascending=True,
)
settings = SimpleNamespace(display=display)
transfer = SimpleNamespace(concurrent_transfers=2)
settings = SimpleNamespace(display=display, transfer=transfer)
fake_manager = MagicMock(jobs=[])
fake_site_manager = MagicMock()

Expand Down Expand Up @@ -85,7 +86,7 @@ def _hydrate_frame(module):
def test_main_frame_init_sets_transfer_state(tmp_path, app_module):
frame, _, transfer_service_cls = _build_frame(app_module, tmp_path)
assert frame._transfer_state_by_id == {}
transfer_service_cls.assert_called_once_with(notify_window=frame)
transfer_service_cls.assert_called_once_with(notify_window=frame, max_workers=2)


def test_bind_events_hooks_transfer_update(app_module):
Expand Down Expand Up @@ -1177,6 +1178,7 @@ def test_on_settings_reconfigures_update_menu_and_timer(app_module):
frame._settings = SimpleNamespace(
app=SimpleNamespace(update_channel="stable"),
display=SimpleNamespace(show_hidden_files=True),
transfer=SimpleNamespace(concurrent_transfers=2),
)
frame._local_cwd = "/tmp"
frame.remote_file_list = MagicMock()
Expand All @@ -1192,6 +1194,7 @@ def test_on_settings_reconfigures_update_menu_and_timer(app_module):
updated_settings = SimpleNamespace(
app=SimpleNamespace(update_channel="nightly"),
display=SimpleNamespace(show_hidden_files=True),
transfer=SimpleNamespace(concurrent_transfers=4),
)
dialog = MagicMock(
ShowModal=MagicMock(return_value=fake_wx.ID_OK),
Expand All @@ -1215,6 +1218,7 @@ def test_on_settings_passes_check_updates_callback(app_module):
frame._settings = SimpleNamespace(
app=SimpleNamespace(update_channel="stable"),
display=SimpleNamespace(show_hidden_files=True),
transfer=SimpleNamespace(concurrent_transfers=2),
)
frame._local_cwd = "/tmp"
frame.remote_file_list = MagicMock()
Expand Down
74 changes: 71 additions & 3 deletions tests/test_transfer_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,20 @@ def test_cancel_event_independent(self):


class TestTransferServiceInit:
def test_starts_daemon_worker_thread(self):
def test_starts_daemon_worker_threads(self):
svc = TransferService(notify_window=None)
assert svc._worker.is_alive()
assert svc._worker.daemon is True
assert len(svc._workers) == 1
assert all(t.is_alive() for t in svc._workers)
assert all(t.daemon for t in svc._workers)

def test_starts_multiple_workers(self):
svc = TransferService(notify_window=None, max_workers=3)
assert len(svc._workers) == 3
assert all(t.is_alive() for t in svc._workers)

def test_max_workers_clamped_to_one(self):
svc = TransferService(notify_window=None, max_workers=0)
assert len(svc._workers) == 1

def test_jobs_returns_snapshot(self):
svc = TransferService(notify_window=None)
Expand Down Expand Up @@ -446,3 +456,61 @@ def test_status_values(self):
assert TransferStatus.COMPLETE.value == "complete"
assert TransferStatus.FAILED.value == "failed"
assert TransferStatus.CANCELLED.value == "cancelled"


# ---------------------------------------------------------------------------
# Concurrent worker pool
# ---------------------------------------------------------------------------


class TestConcurrentWorkers:
def test_jobs_run_concurrently_with_multiple_workers(self):
"""Two slow jobs should overlap when max_workers >= 2."""
barrier = threading.Barrier(2, timeout=5)
completed_order: list[str] = []
lock = threading.Lock()

mock_client = MagicMock()

def slow_download(src, fh, callback=None, offset=0):
name = PurePosixPath(src).name
barrier.wait() # both workers must reach here before either proceeds
with lock:
completed_order.append(name)

mock_client.download.side_effect = slow_download

svc = TransferService(notify_window=None, max_workers=2)
with patch("builtins.open", return_value=MagicMock(spec=io.BufferedWriter)):
j1 = svc.submit_download(mock_client, "/r/a.txt", "/tmp/a.txt")
j2 = svc.submit_download(mock_client, "/r/b.txt", "/tmp/b.txt")
_wait_for_terminal(j1)
_wait_for_terminal(j2)

assert j1.status == TransferStatus.COMPLETE
assert j2.status == TransferStatus.COMPLETE
assert len(completed_order) == 2

def test_set_max_workers_increases_pool(self):
svc = TransferService(notify_window=None, max_workers=1)
assert len([t for t in svc._workers if t.is_alive()]) == 1

svc.set_max_workers(3)
time.sleep(0.1)
alive = [t for t in svc._workers if t.is_alive()]
assert len(alive) == 3

def test_set_max_workers_decreases_pool(self):
svc = TransferService(notify_window=None, max_workers=3)
assert len(svc._workers) == 3

svc.set_max_workers(1)
# Give sentinels time to be consumed
time.sleep(0.5)
alive = [t for t in svc._workers if t.is_alive()]
assert len(alive) == 1

def test_set_max_workers_clamps_to_one(self):
svc = TransferService(notify_window=None, max_workers=2)
svc.set_max_workers(0)
assert svc._max_workers == 1
Loading