Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions src/wxflow/file_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from logging import getLogger
from multiprocessing import Pool
from pathlib import Path

from .fsutils import cp, mkdir
Expand All @@ -9,6 +10,28 @@
logger = getLogger(__name__.split('.')[-1])


def _copy_single_file(src, dest):
    """Copy one file and report the outcome. Used by multiprocessing.Pool.

    Parameters
    ----------
    src : str
        Source file path
    dest : str
        Destination file path

    Returns
    -------
    tuple
        (success: bool, src: str, dest: str, error: Exception or None)
    """
    # An exception raised in a worker process cannot be allowed to escape
    # unrecorded; capture it and hand it back to the parent for handling.
    success, error = True, None
    try:
        cp(src, dest)
    except Exception as exc:
        success, error = False, exc
    return (success, src, dest, error)


class FileHandler:
"""Class to manipulate files in bulk for a given configuration

Expand Down Expand Up @@ -106,6 +129,72 @@ def _copy_files(filelist, required=True):
else:
logger.warning(f"Source file '{src}' does not exist, skipping!")

@staticmethod
def copy_parallel(filelist, num_processes=None):
    """Copy a batch of [src, dest] pairs concurrently.

    Thin convenience wrapper around ``_copy_files_parallel`` with
    ``required=True``: every source file must exist.

    Parameters
    ----------
    filelist : list
        List of lists of [src, dest]
    num_processes : int, optional
        Number of processes to use for parallel copying.
        If None, uses the number of CPUs on the machine.
    """
    FileHandler._copy_files_parallel(
        filelist, required=True, num_processes=num_processes)

@staticmethod
def _copy_files_parallel(filelist, required=True, num_processes=None):
    """Function to copy files in parallel using multiprocessing.Pool

    `filelist` should be in the form:
    - [src, dest]

    Parameters
    ----------
    filelist : list
        List of lists of [src, dest]
    required : bool, optional
        Flag to indicate if the src file is required to exist. Default is True
    num_processes : int, optional
        Number of processes to use for parallel copying.
        If None, uses the number of CPUs on the machine.

    Raises
    ------
    IndexError
        If an entry of `filelist` is not a [src, dest] pair
    FileNotFoundError
        If a required source file does not exist
    Exception
        The first error raised by a worker while copying, re-raised
        in the parent process
    """
    # Validate filelist format before doing any work
    for sublist in filelist:
        if len(sublist) != 2:
            raise IndexError(
                f"List must be of the form ['src', 'dest'], not {sublist}")

    # Check that all required source files exist before starting any copies
    for src, _ in filelist:
        if not os.path.exists(src):
            if required:
                # logger.error, not logger.exception: there is no active
                # exception context here, so .exception would log a
                # spurious "NoneType: None" traceback
                logger.error(f"Source file '{src}' does not exist and is required, ABORT!")
                raise FileNotFoundError(f"Source file '{src}' does not exist")
            logger.warning(f"Source file '{src}' does not exist, skipping!")

    # Filter out files where source doesn't exist (for optional copies)
    valid_files = [sublist for sublist in filelist if os.path.exists(sublist[0])]

    if not valid_files:
        logger.warning("No valid files to copy")
        return

    # Use multiprocessing.Pool to copy files in parallel
    with Pool(processes=num_processes) as pool:
        results = pool.starmap(_copy_single_file, valid_files)

    # Log every outcome first so successful copies are still reported,
    # then re-raise the first failure (if any) in the parent process
    first_error = None
    for success, src, dest, error in results:
        if success:
            logger.info(f'Copied {src} to {dest}')
        else:
            # The exception was caught in the worker, so there is no
            # active exception context here either; use logger.error
            logger.error(f"Error copying {src} to {dest}: {error}")
            if first_error is None:
                first_error = error
    if first_error is not None:
        raise first_error

@staticmethod
def _make_dirs(dirlist):
"""Function to make all directories specified in the list
Expand Down
214 changes: 214 additions & 0 deletions tests/test_file_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import logging
import os
import re

Expand Down Expand Up @@ -263,3 +264,216 @@ def test_link_file_bad(tmp_path, create_dirs_and_files_for_test_link):
bad_config = {'link_req': bad_link_list}
with pytest.raises(FileNotFoundError):
FileHandler(bad_config).sync()


def test_copy_parallel_basic(tmp_path):
    """
    Test basic parallel copy functionality
    Parameters
    ----------
    tmp_path - pytest fixture
    """
    # Set up source and destination directories
    src_dir = tmp_path / 'parallel_input'
    dst_dir = tmp_path / 'parallel_output'
    FileHandler({'mkdir': [src_dir]}).sync()
    FileHandler({'mkdir': [dst_dir]}).sync()

    # Populate the source directory and build the [src, dest] pairs
    pairs = []
    for idx in range(10):
        src = src_dir / f'file_{idx}.txt'
        src.write_text(f'Content of file {idx}\n' * 100)
        pairs.append([src, dst_dir / f'file_{idx}.txt'])

    # Copy everything in parallel
    FileHandler.copy_parallel(pairs)

    # Every destination must exist and match its source byte-for-byte
    for src, dest in pairs:
        assert os.path.isfile(dest), f"Destination file {dest} does not exist"
        assert src.read_text() == dest.read_text(), f"Content mismatch for {dest}"


def test_copy_parallel_with_num_processes(tmp_path):
    """
    Test parallel copy with specific number of processes
    Parameters
    ----------
    tmp_path - pytest fixture
    """
    # Build source and destination directories
    in_dir = tmp_path / 'parallel_input'
    out_dir = tmp_path / 'parallel_output'
    FileHandler({'mkdir': [in_dir]}).sync()
    FileHandler({'mkdir': [out_dir]}).sync()

    # Write five small source files
    names = [f'file_{n}.txt' for n in range(5)]
    for n, name in enumerate(names):
        (in_dir / name).write_text(f'Content {n}')

    # Copy with an explicit two-process pool
    FileHandler.copy_parallel(
        [[in_dir / name, out_dir / name] for name in names],
        num_processes=2)

    # Each destination must exist and mirror its source
    for name in names:
        assert os.path.isfile(out_dir / name)
        assert (in_dir / name).read_text() == (out_dir / name).read_text()


def test_copy_parallel_error_propagation(tmp_path):
    """
    Test that errors in one copy cause the parent call to fail
    Parameters
    ----------
    tmp_path - pytest fixture
    """
    # Create input directory and files
    input_dir_path = tmp_path / 'parallel_input'
    config = {'mkdir': [input_dir_path]}
    FileHandler(config).sync()

    # Create valid source files
    src_files = []
    for i in range(3):
        src_file = input_dir_path / f'file_{i}.txt'
        src_file.write_text(f'Content {i}')
        src_files.append(src_file)

    # Every copy targets the same impossible destination: /dev/null is
    # not a directory, so each worker must fail with an OSError.
    # (The destination is loop-invariant, so hoist it and build the list
    # with a comprehension instead of an unused enumerate index.)
    bad_dest = "/dev/null/invalid_path.txt"
    copy_list = [[src, bad_dest] for src in src_files]

    # Attempt parallel copy - should fail
    with pytest.raises(OSError):
        FileHandler.copy_parallel(copy_list)


def test_copy_parallel_missing_required_file(tmp_path):
    """
    Test that missing required source files cause the parent call to fail
    Parameters
    ----------
    tmp_path - pytest fixture
    """
    # Create input directory
    input_dir_path = tmp_path / 'parallel_input'
    config = {'mkdir': [input_dir_path]}
    FileHandler(config).sync()

    # Create one valid file and reference one that doesn't exist
    valid_file = input_dir_path / 'valid.txt'
    valid_file.write_text('Valid content')
    missing_file = input_dir_path / 'missing.txt'

    # Create output directory
    output_dir_path = tmp_path / 'parallel_output'
    config = {'mkdir': [output_dir_path]}
    FileHandler(config).sync()

    # Create copy list with missing file
    copy_list = [
        [valid_file, output_dir_path / 'valid.txt'],
        [missing_file, output_dir_path / 'missing.txt']
    ]

    # Attempt parallel copy - should fail due to missing file.
    # pytest's `match` is a regex search, and the interpolated tmp_path
    # may contain regex metacharacters, so escape the expected message.
    expected = re.escape(f"Source file '{missing_file}' does not exist")
    with pytest.raises(FileNotFoundError, match=expected):
        FileHandler.copy_parallel(copy_list)


def test_copy_parallel_file_integrity(tmp_path):
    """
    Test that parallel copies are identical to their sources
    Parameters
    ----------
    tmp_path - pytest fixture
    """
    # Create the input directory
    in_dir = tmp_path / 'parallel_input'
    FileHandler({'mkdir': [in_dir]}).sync()

    # Write five sizable files, remembering each one's SHA-256 digest
    digests = {}
    for n in range(5):
        path = in_dir / f'file_{n}.txt'
        body = f'Line {n}\n' * 10000
        path.write_text(body)
        digests[path] = hashlib.sha256(body.encode()).hexdigest()

    # Create the output directory
    out_dir = tmp_path / 'parallel_output'
    FileHandler({'mkdir': [out_dir]}).sync()

    # Build the copy list and run the parallel copy
    copies = [[src, out_dir / src.name] for src in digests]
    FileHandler.copy_parallel(copies)

    # Each copy must hash to the same digest as its source
    for src, dest in copies:
        assert os.path.isfile(dest)
        dest_digest = hashlib.sha256(dest.read_text().encode()).hexdigest()
        assert dest_digest == digests[src], f"Hash mismatch for {dest}"


def test_copy_parallel_invalid_format(tmp_path):
    """
    Test that invalid copy list format raises appropriate error
    Parameters
    ----------
    tmp_path - pytest fixture
    """
    # A single-element entry violates the required [src, dest] shape
    malformed = [['only_one_item']]

    # The validation step must reject it before any copying happens
    with pytest.raises(IndexError, match="List must be of the form"):
        FileHandler.copy_parallel(malformed)