Skip to content

Commit 651296b

Browse files
committed
caterpillar, download: improve SIGINT and SIGTERM handling
1 parent cd36378 commit 651296b

File tree

2 files changed

+129
-92
lines changed

2 files changed

+129
-92
lines changed

src/caterpillar/caterpillar.py

+73-57
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import sys
88
import time
99
import urllib.parse
10-
from typing import List, Optional
10+
from typing import Any, List, Optional
1111

1212
import peewee
1313

@@ -255,6 +255,7 @@ def rmdir_p(path: pathlib.Path, *, root: pathlib.Path = None) -> None:
255255
def process_entry(
256256
m3u8_url: str,
257257
output: pathlib.Path,
258+
*,
258259
force: bool = False,
259260
exist_ok: bool = False,
260261
workdir: pathlib.Path = None,
@@ -376,6 +377,65 @@ def process_entry(
376377
return 0
377378

378379

380+
def process_batch(
381+
manifest: pathlib.Path,
382+
remove_manifest_on_success: bool = False,
383+
debug: bool = False,
384+
**processing_kwargs: Any,
385+
) -> int:
386+
target_dir = manifest.parent
387+
try:
388+
entries = []
389+
with manifest.open(encoding="utf-8") as fp:
390+
manifest_content = fp.read()
391+
except OSError:
392+
logger.critical("cannot open batch mode manifest", exc_info=debug)
393+
if debug:
394+
raise
395+
return 1
396+
except UnicodeDecodeError:
397+
logger.critical(
398+
"cannot decode batch mode manifest as utf-8; "
399+
"see https://git.io/caterpillar-encoding",
400+
exc_info=debug,
401+
)
402+
if debug:
403+
raise
404+
return 1
405+
406+
for line in manifest_content.splitlines():
407+
try:
408+
m3u8_url, filename = line.strip().split("\t")
409+
output = target_dir.joinpath(filename)
410+
entries.append((m3u8_url, output))
411+
except Exception:
412+
logger.critical(
413+
"malformed line in batch mode manifest: %s", line, exc_info=debug
414+
)
415+
if debug:
416+
raise
417+
return 1
418+
419+
retvals = []
420+
count = len(entries)
421+
for i, (m3u8_url, output) in enumerate(entries):
422+
sys.stderr.write(
423+
f'[{i + 1}/{count}] Downloading {m3u8_url} into "{output}"...\n'
424+
)
425+
retvals.append(process_entry(m3u8_url, output, **processing_kwargs))
426+
sys.stderr.write("\n")
427+
retval = int(any(retvals))
428+
if retval == 0 and remove_manifest_on_success:
429+
try:
430+
manifest.unlink()
431+
except OSError:
432+
logger.error("cannot remove batch mode manifest")
433+
if debug:
434+
raise
435+
return 1
436+
return retval
437+
438+
379439
def main() -> int:
380440
user_config_options = [] if USER_CONFIG_DISABLED else load_user_config()
381441

@@ -545,63 +605,19 @@ def main() -> int:
545605
logger.critical("ffmpeg not found")
546606
return 1
547607

548-
if not args.batch:
549-
return process_entry(args.m3u8_url, args.output, **kwargs)
550-
else:
551-
manifest = pathlib.Path(args.m3u8_url).resolve()
552-
target_dir = manifest.parent
553-
try:
554-
entries = []
555-
with manifest.open(encoding="utf-8") as fp:
556-
manifest_content = fp.read()
557-
except OSError:
558-
logger.critical("cannot open batch mode manifest", exc_info=args.debug)
559-
if args.debug:
560-
raise
561-
return 1
562-
except UnicodeDecodeError:
563-
logger.critical(
564-
"cannot decode batch mode manifest as utf-8; "
565-
"see https://git.io/caterpillar-encoding",
566-
exc_info=args.debug,
567-
)
568-
if args.debug:
569-
raise
570-
return 1
571-
572-
for line in manifest_content.splitlines():
573-
try:
574-
m3u8_url, filename = line.strip().split("\t")
575-
output = target_dir.joinpath(filename)
576-
entries.append((m3u8_url, output))
577-
except Exception:
578-
logger.critical(
579-
"malformed line in batch mode manifest: %s",
580-
line,
581-
exc_info=args.debug,
582-
)
583-
if args.debug:
584-
raise
585-
return 1
586-
587-
retvals = []
588-
count = len(entries)
589-
for i, (m3u8_url, output) in enumerate(entries):
590-
sys.stderr.write(
591-
f'[{i + 1}/{count}] Downloading {m3u8_url} into "{output}"...\n'
608+
try:
609+
if not args.batch:
610+
return process_entry(args.m3u8_url, args.output, **kwargs)
611+
else:
612+
manifest = pathlib.Path(args.m3u8_url).resolve()
613+
return process_batch(
614+
manifest,
615+
remove_manifest_on_success=args.remove_manifest_on_success,
616+
debug=args.debug,
617+
**kwargs,
592618
)
593-
retvals.append(process_entry(m3u8_url, output, **kwargs))
594-
sys.stderr.write("\n")
595-
retval = int(any(retvals))
596-
if retval == 0 and args.remove_manifest_on_success:
597-
try:
598-
manifest.unlink()
599-
except OSError:
600-
logger.error("cannot remove batch mode manifest")
601-
if args.debug:
602-
raise
603-
return 1
604-
return retval
619+
except KeyboardInterrupt:
620+
return 1
605621

606622

607623
if __name__ == "__main__":

src/caterpillar/download.py

+56-35
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,20 @@ def download_segment(
128128

129129
# download_segment wrapper that takes all arguments as a single tuple,
130130
# so that we can use it with multiprocessing.pool.Pool.map and company.
131+
# It also gracefully consumes KeyboardInterrupt.
131132
def _download_segment_mappable(args: Tuple[str, int, pathlib.Path]) -> bool:
132-
return download_segment(*args)
133+
try:
134+
return download_segment(*args)
135+
except KeyboardInterrupt:
136+
url, *_ = args
137+
logger.debug(f"download of {url} has been interrupted")
138+
return False
133139

134140

135-
def _init_worker():
136-
# Ignore SIGINT in worker processes to disable traceback from
137-
# each worker on keyboard interrupt.
138-
signal.signal(signal.SIGINT, signal.SIG_IGN)
141+
def _raise_keyboard_interrupt(signum, _):
142+
pid = os.getpid()
143+
logger.debug(f"pid {pid} received signal {signum}; transforming into SIGINT")
144+
raise KeyboardInterrupt
139145

140146

141147
# Download all segments in remote_m3u8_file (downloaded from
@@ -179,33 +185,48 @@ def download_m3u8_segments(
179185
logger.error(f"{remote_m3u8_file}: empty playlist")
180186
return False
181187
jobs = min(jobs, total)
182-
with multiprocessing.Pool(jobs, _init_worker) as pool:
183-
num_success = 0
184-
num_failure = 0
185-
logger.info(f"downloading {total} segments with {jobs} workers...")
186-
progress_bar_generator = (
187-
click.progressbar if should_log_warning() else stub_context_manager
188-
)
189-
progress_bar_props = dict(
190-
width=0, # Full width
191-
bar_template="[%(bar)s] %(info)s",
192-
show_pos=True,
193-
length=total,
194-
)
195-
with progress_bar_generator(**progress_bar_props) as bar: # type: ignore
196-
for success in pool.imap_unordered(
197-
_download_segment_mappable, download_args
198-
):
199-
if success:
200-
num_success += 1
201-
else:
202-
num_failure += 1
203-
logger.debug(f"progress: {num_success}/{num_failure}/{total}")
204-
bar.update(1)
205-
206-
if num_failure > 0:
207-
logger.error(f"failed to download {num_failure} segments")
208-
return False
209-
else:
210-
logger.info(f"finished downloading all {total} segments")
211-
return True
188+
with multiprocessing.Pool(jobs) as pool:
189+
# For the duration of the worker pool, map SIGTERM to SIGINT on
190+
# the main process. We only do this after the fork, and restore
191+
# the original SIGTERM handler (usually SIG_DFL) at the end of
192+
# the pool, because using _raise_keyboard_interrupt as the
193+
# SIGTERM handler on workers could somehow lead to deadlocks.
194+
old_sigterm_handler = signal.signal(signal.SIGTERM, _raise_keyboard_interrupt)
195+
try:
196+
num_success = 0
197+
num_failure = 0
198+
logger.info(f"downloading {total} segments with {jobs} workers...")
199+
progress_bar_generator = (
200+
click.progressbar if should_log_warning() else stub_context_manager
201+
)
202+
progress_bar_props = dict(
203+
width=0, # Full width
204+
bar_template="[%(bar)s] %(info)s",
205+
show_pos=True,
206+
length=total,
207+
)
208+
with progress_bar_generator(**progress_bar_props) as bar: # type: ignore
209+
for success in pool.imap_unordered(
210+
_download_segment_mappable, download_args
211+
):
212+
if success:
213+
num_success += 1
214+
else:
215+
num_failure += 1
216+
logger.debug(f"progress: {num_success}/{num_failure}/{total}")
217+
bar.update(1)
218+
219+
if num_failure > 0:
220+
logger.error(f"failed to download {num_failure} segments")
221+
return False
222+
else:
223+
logger.info(f"finished downloading all {total} segments")
224+
return True
225+
except KeyboardInterrupt:
226+
pool.terminate()
227+
pool.join()
228+
logger.critical("interrupted")
229+
# Bubble up KeyboardInterrupt to stop retries.
230+
raise KeyboardInterrupt
231+
finally:
232+
signal.signal(signal.SIGTERM, old_sigterm_handler)

0 commit comments

Comments
 (0)