Skip to content

Commit

Permalink
Improve cloud environment handling and test output validation
Browse files Browse the repository at this point in the history
- Add explicit pyautogui disabling in cloud environments
- Update test_act_train_e2e.py to use more robust output parsing
- Add detailed error messages for policy output normalization checks
  • Loading branch information
beduffy committed Feb 23, 2025
1 parent 5a9089f commit 32d6345
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
4 changes: 4 additions & 0 deletions imitate_mouse/run_mouse_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def run_policy_eval(args, num_steps=100):
# Skip GUI interaction in cloud environments
in_cloud = os.environ.get('CI') or not os.environ.get('DISPLAY')

if in_cloud and pyautogui is not None:
pyautogui = None # Disable GUI functionality
print("Running in cloud environment, disabling mouse control")

device = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"

# Load checkpoint with proper error handling
Expand Down
12 changes: 8 additions & 4 deletions tests/integration/test_act_train_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,15 @@ class Args:
with patch('builtins.print') as mock_print:
run_policy_eval(Args(), num_steps=5)
# Verify policy is producing normalized outputs
outputs = [ast.literal_eval(call[0][0].split(": ")[1]) for call in mock_print.call_args_list if 'Policy output' in call[0][0]]
assert len(outputs) > 3, "Insufficient policy predictions"
outputs = [
ast.literal_eval(call.args[0].split(": ")[1])
for call in mock_print.call_args_list
if 'Policy output' in call.args[0]
]
assert len(outputs) > 3, f"Insufficient policy predictions, got {len(outputs)}"
for out in outputs:
assert 0 <= out[0] <= 1, "X output not normalized"
assert 0 <= out[1] <= 1, "Y output not normalized"
assert 0 <= out[0] <= 1, f"X output {out[0]} not normalized"
assert 0 <= out[1] <= 1, f"Y output {out[1]} not normalized"


@pytest.mark.integration
Expand Down

0 comments on commit 32d6345

Please sign in to comment.