Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 87 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ cartopy = "^0.25.0"
matplotlib = "^3.10.7"
numpy = "^2.3.5"
pandas = "^2.3.3"
boto3 = "^1.42.65"

[tool.poetry.group.dev.dependencies]
mypy = "^1.18.1"
Expand Down
108 changes: 107 additions & 1 deletion pytrajplot/main.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,103 @@
"""Command line interface of pytrajplot."""
from typing import Tuple, Dict
from typing import Tuple, Dict, Optional
import logging
import os
from pathlib import Path

# Third-party
import click
import boto3

# First-party
from pytrajplot import __version__
from pytrajplot.generate_pdf import generate_pdf
from pytrajplot.parse_data import check_input_dir
from pytrajplot.utils import count_to_log_level

# Setup logging
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(level=log_level)
logger = logging.getLogger(__name__)

def print_version(ctx: click.Context, _param: click.Parameter, value: bool) -> None:
"""Print the version number and exit."""
if value:
click.echo(__version__)
ctx.exit(0)

def replace_variables(template_content: str) -> str:
"""
Replace $VAR with actual environment variable values.
Args:
template_content: Template string with $VARIABLE placeholders
Returns:
String with variables replaced by environment values
"""
result = template_content
# Get all environment variables as dict
env_vars = dict(os.environ)

# Replace variables found in the template
for env_key, env_value in env_vars.items():
placeholder = f'${env_key}'
if placeholder in result:
result = result.replace(placeholder, env_value)
logger.info(f"Replaced {placeholder} with {env_value}")
return result


def check_plot_info_file(input_dir: str, info_name: str, ssm_parameter_path: str | None = None) -> bool:
"""
Check if plot_info file exists in input directory.
If not found, fetch from SSM parameter and create it replacing variables.
Args:
input_dir: Input directory path
info_name: Name of the plot info file
ssm_parameter_path: SSM parameter path (optional, uses env var if not provided)
Returns:
bool: True if file exists or was created successfully, False otherwise
"""
input_path = Path(input_dir)
plot_info_file = input_path / info_name

# Check if plot_info file already exists
if plot_info_file.exists():
logger.info(f"Plot info file already exists: {plot_info_file}")
return True

# File doesn't exist, try to create it from SSM parameter
logger.info(f"Plot info file not found: {plot_info_file}")

try:
# Get SSM parameter path from argument or environment
ssm_param_path = ssm_parameter_path or os.environ.get('SSM_PARAMETER_PATH', '/pytrajplot/icon/plot_info')
logger.info(f"Fetching SSM parameter: {ssm_param_path}")

# Fetch template from SSM Parameter
ssm_client = boto3.client('ssm')
response = ssm_client.get_parameter(
Name=ssm_param_path,
WithDecryption=True
)

# Get the template content
template_content = response['Parameter']['Value']
logger.info(f"Template content length: {len(template_content)} chars")

# Replace variables with environment variable values
substituted_content = replace_variables(template_content)

# Create the plot_info file
with open(plot_info_file, 'w') as f:
f.write(substituted_content)

logger.info(f"Successfully created plot info file: {plot_info_file}")
return True

except Exception as e:
logger.error(f"Failed to create plot info file from SSM parameter: {str(e)}")
logger.error(f"SSM parameter path: {ssm_parameter_path or os.environ.get('SSM_PARAMETER_PATH', 'not_set')}")
return False

def interpret_options(start_prefix: str, traj_prefix: str, info_name: str, language: str) -> Tuple[Dict[str, str], str]:
"""Reformat command line inputs.
Expand Down Expand Up @@ -124,6 +205,17 @@ def interpret_options(start_prefix: str, traj_prefix: str, info_name: str, langu
default=["pdf"],
help="Choose data type(s) of final result. Default: pdf",
)
@click.option(
"--ssm-parameter-path",
type=str,
help="SSM parameter path for plot_info template. Uses SSM_PARAMETER_PATH env var if not specified.",
)
@click.option(
"--skip-ssm-fallback",
is_flag=True,
default=False,
help="Skip SSM parameter fallback if plot_info file is missing.",
)
@click.option(
"--version",
"-V",
Expand All @@ -143,7 +235,21 @@ def cli(
language: str,
domain: str,
datatype: str,
ssm_parameter_path: str | None = None,
skip_ssm_fallback: bool = False,
) -> None:
# Check if plot_info file exists (create from SSM if needed)
if not skip_ssm_fallback:
plot_info_created = check_plot_info_file(
input_dir=input_dir,
info_name=info_name,
ssm_parameter_path=ssm_parameter_path
)

if not plot_info_created:
logger.error("Failed to check if plot_info file exists. Use --skip-ssm-fallback to continue anyway.")
raise click.ClickException("Missing plot_info file and failed to create from SSM parameter.")

prefix_dict, language = interpret_options(
start_prefix=start_prefix,
traj_prefix=traj_prefix,
Expand Down
2 changes: 1 addition & 1 deletion test/integration/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def create_args(input_dir: str, output_dir: str, opts: dict) -> list:
# Positional arguments
args.append(input_dir)
args.append(output_dir)
args.append("--skip-ssm-fallback")

# Keyword arguments
for key, value in opts.items():
Expand Down Expand Up @@ -208,4 +209,3 @@ def test_pytrajplot(input_args, input_dir, output_dir):
for rel in expected:
expected_file = Path(output_path) / Path(rel).name
assert expected_file.exists(), f"Expected output not found: {expected_file}"

Loading