Skip to content

Commit 1d6fb56

Browse files
committed
Fix requirements.txt unit tests
1 parent bc35d5d commit 1d6fb56

File tree

2 files changed

+38
-11
lines changed

2 files changed

+38
-11
lines changed

src/datacustomcode/scan.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,10 @@ def write_requirements_file(file_path: str) -> str:
243243
"""
244244
imports = scan_file_for_imports(file_path)
245245

246-
# Use the parent directory rather than same directory as the file
246+
# Write requirements.txt in the parent directory of the Python file
247247
file_dir = os.path.dirname(file_path)
248-
output_dir = os.path.dirname(file_dir) if file_dir else "."
249-
250-
requirements_path = os.path.join(output_dir, "requirements.txt")
248+
parent_dir = os.path.dirname(file_dir) if file_dir else "."
249+
requirements_path = os.path.join(parent_dir, "requirements.txt")
251250

252251
# If the file exists, read existing requirements and merge with new ones
253252
existing_requirements = set()

tests/test_scan.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -453,31 +453,50 @@ def test_scan_file_for_imports(self):
453453

454454
def test_write_requirements_file_new(self):
455455
"""Test writing a new requirements.txt file."""
456+
# Create a temporary directory structure
457+
temp_dir = tempfile.mkdtemp()
458+
script_dir = os.path.join(temp_dir, "script_dir")
459+
os.makedirs(script_dir)
460+
456461
content = textwrap.dedent(
457462
"""
458463
import pandas as pd
459464
import numpy as np
460465
"""
461466
)
462-
temp_path = create_test_script(content)
467+
temp_path = os.path.join(script_dir, "test_script.py")
468+
with open(temp_path, "w") as f:
469+
f.write(content)
470+
471+
requirements_path = None
463472
try:
464473
requirements_path = write_requirements_file(temp_path)
465474
assert os.path.exists(requirements_path)
475+
assert (
476+
os.path.dirname(requirements_path) == temp_dir
477+
) # Should be in parent directory
466478

467479
with open(requirements_path, "r") as f:
468480
requirements = {line.strip() for line in f}
469481

470482
assert "pandas" in requirements
471483
assert "numpy" in requirements
472484
finally:
473-
os.unlink(temp_path)
474-
if os.path.exists(requirements_path):
485+
if os.path.exists(temp_path):
486+
os.unlink(temp_path)
487+
if requirements_path and os.path.exists(requirements_path):
475488
os.unlink(requirements_path)
489+
os.rmdir(script_dir)
490+
os.rmdir(temp_dir)
476491

477492
def test_write_requirements_file_merge(self):
478493
"""Test merging with existing requirements.txt file."""
479-
# First create an existing requirements.txt
494+
# Create a temporary directory structure
480495
temp_dir = tempfile.mkdtemp()
496+
script_dir = os.path.join(temp_dir, "script_dir")
497+
os.makedirs(script_dir)
498+
499+
# Create existing requirements.txt in parent directory
481500
existing_requirements = os.path.join(temp_dir, "requirements.txt")
482501
with open(existing_requirements, "w") as f:
483502
f.write("pandas\nnumpy\n")
@@ -491,10 +510,17 @@ def test_write_requirements_file_merge(self):
491510
import matplotlib
492511
"""
493512
)
494-
temp_path = create_test_script(content)
513+
temp_path = os.path.join(script_dir, "test_script.py")
514+
with open(temp_path, "w") as f:
515+
f.write(content)
516+
517+
requirements_path = None
495518
try:
496519
requirements_path = write_requirements_file(temp_path)
497520
assert os.path.exists(requirements_path)
521+
assert (
522+
os.path.dirname(requirements_path) == temp_dir
523+
) # Should be in parent directory
498524

499525
with open(requirements_path, "r") as f:
500526
requirements = {line.strip() for line in f}
@@ -505,11 +531,13 @@ def test_write_requirements_file_merge(self):
505531
assert "scipy" in requirements
506532
assert "matplotlib" in requirements
507533
finally:
508-
os.unlink(temp_path)
509-
if os.path.exists(requirements_path):
534+
if os.path.exists(temp_path):
535+
os.unlink(temp_path)
536+
if requirements_path and os.path.exists(requirements_path):
510537
os.unlink(requirements_path)
511538
if os.path.exists(existing_requirements):
512539
os.unlink(existing_requirements)
540+
os.rmdir(script_dir)
513541
os.rmdir(temp_dir)
514542

515543
def test_standard_library_exclusion(self):

0 commit comments

Comments
 (0)