diff --git a/src/mdformat/_cli.py b/src/mdformat/_cli.py index 74a6e28..ad1ab98 100644 --- a/src/mdformat/_cli.py +++ b/src/mdformat/_cli.py @@ -47,8 +47,16 @@ def run(cli_args: Sequence[str], cache_toml: bool = True) -> int: # noqa: C901 renderer_warning_printer = RendererWarningPrinter() for path in file_paths: read_toml = read_toml_opts if cache_toml else read_toml_opts.__wrapped__ + + conf_dir = Path.cwd() + if path: + conf_dir = path.parent + if cli_opts.get("toml_file"): + cli_toml_file = resolve_cli_toml_file(cli_opts["toml_file"]) + conf_dir = cli_toml_file.parent + try: - toml_opts, toml_path = read_toml(path.parent if path else Path.cwd()) + toml_opts, toml_path = read_toml(conf_dir) except InvalidConfError as e: print_error(str(e)) return 1 @@ -249,6 +257,7 @@ def make_arg_parser( help="exclude files that match the Unix-style glob pattern " "(multiple allowed)", ) + parser.add_argument("--toml_file", help="path to desired toml config file") extensions_group = parser.add_mutually_exclusive_group() extensions_group.add_argument( "--extensions", @@ -503,3 +512,17 @@ def get_source_file_and_line(obj: object) -> tuple[str, int]: except (OSError, TypeError): # pragma: no cover lineno = 0 return filename, lineno + + +def resolve_cli_toml_file(cli_toml_arg: str) -> Path: + toml_file = Path(cli_toml_arg).absolute() + + if cli_toml_arg[0].isalnum() or (cli_toml_arg[0] in ["."]): + toml_file = (Path.cwd() / cli_toml_arg).absolute() + + if cli_toml_arg[0] == "~": + toml_file = (Path.home() / cli_toml_arg.replace("~", "")).absolute() + + if not toml_file.is_file(): + raise InvalidPath(toml_file) + return toml_file diff --git a/tests/test_cli.py b/tests/test_cli.py index 28c3622..61b8164 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,11 +1,12 @@ import os +from pathlib import Path import sys from unittest.mock import patch import pytest import mdformat -from mdformat._cli import get_plugin_info_str, run, wrap_paragraphs +from mdformat._cli import InvalidPath, get_plugin_info_str, run, wrap_paragraphs from mdformat.plugins import CODEFORMATTERS, PARSER_EXTENSIONS from tests.utils import ( FORMATTED_MARKDOWN, @@ -512,3 +513,56 @@ def test_no_extensions(tmp_path, monkeypatch): file_path.write_text(original_md) assert run((str(file_path), "--no-extensions")) == 0 assert file_path.read_text() == original_md + + +def test_cli_toml(tmp_path): + _wrap_num = 5 + config_path = tmp_path / ".mdformat.toml" + config_path.write_text(f"wrap = {_wrap_num}") + + file_path = tmp_path / "test_markdown.md" + file_path.write_text( + " ".join(["x" * _wrap_num, "o" * _wrap_num, "w" * _wrap_num, "p" * _wrap_num]) + ) + + assert run([str(file_path), f"--toml_file={config_path}"]) == 0 + assert file_path.read_text() == "xxxxx\nooooo\nwwwww\nppppp\n" + + +def test_cli_toml_alphanum(tmp_path): + config_path = "1234/.mdformat.toml" + + file_path = tmp_path / "test_markdown.md" + file_path.write_text("text") + + with pytest.raises(InvalidPath) as except_info: + run([str(file_path), f"--toml_file={config_path}"]) + assert except_info.typename == "InvalidPath" + + err_value = except_info.value + assert config_path in err_value.path.__str__() + + +def test_cli_toml_home(tmp_path): + file_path = tmp_path / "test_markdown.md" + file_path.write_text("text") + + with pytest.raises(InvalidPath) as except_info: + run([str(file_path), "--toml_file=~"]) + assert except_info.typename == "InvalidPath" + + err_value = except_info.value + assert Path.home() == err_value.path + + +def test_cli_toml_not_exists(tmp_path, capsys): + config_path = tmp_path / ".mdformat.toml" + + file_path = tmp_path / "test_markdown.md" + + with pytest.raises(SystemExit) as exc_info: + run([str(file_path), f"--toml_file={config_path}"]) + assert exc_info.value.code == 2 + + captured = capsys.readouterr() + assert "does not exist" in captured.err