Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion src/mdformat/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,16 @@ def run(cli_args: Sequence[str], cache_toml: bool = True) -> int: # noqa: C901
renderer_warning_printer = RendererWarningPrinter()
for path in file_paths:
read_toml = read_toml_opts if cache_toml else read_toml_opts.__wrapped__

conf_dir = Path.cwd()
if path:
conf_dir = path.parent
if cli_opts.get("toml_file"):
cli_toml_file = resolve_cli_toml_file(cli_opts["toml_file"])
conf_dir = cli_toml_file.parent

try:
toml_opts, toml_path = read_toml(path.parent if path else Path.cwd())
toml_opts, toml_path = read_toml(conf_dir)
except InvalidConfError as e:
print_error(str(e))
return 1
Expand Down Expand Up @@ -249,6 +257,7 @@ def make_arg_parser(
help="exclude files that match the Unix-style glob pattern "
"(multiple allowed)",
)
parser.add_argument("--toml_file", help="path to desired toml config file")
extensions_group = parser.add_mutually_exclusive_group()
extensions_group.add_argument(
"--extensions",
Expand Down Expand Up @@ -503,3 +512,17 @@ def get_source_file_and_line(obj: object) -> tuple[str, int]:
except (OSError, TypeError): # pragma: no cover
lineno = 0
return filename, lineno


def resolve_cli_toml_file(cli_toml_arg: str) -> Path:
toml_file = Path(cli_toml_arg).absolute()

if cli_toml_arg[0].isalnum() or (cli_toml_arg[0] in ["."]):
toml_file = (Path.cwd() / cli_toml_arg).absolute()

if cli_toml_arg[0] == "~":
toml_file = (Path.home() / cli_toml_arg.replace("~", "")).absolute()

if not toml_file.is_file():
raise InvalidPath(toml_file)
return toml_file
56 changes: 55 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
from pathlib import Path
import sys
from unittest.mock import patch

import pytest

import mdformat
from mdformat._cli import get_plugin_info_str, run, wrap_paragraphs
from mdformat._cli import InvalidPath, get_plugin_info_str, run, wrap_paragraphs
from mdformat.plugins import CODEFORMATTERS, PARSER_EXTENSIONS
from tests.utils import (
FORMATTED_MARKDOWN,
Expand Down Expand Up @@ -512,3 +513,56 @@ def test_no_extensions(tmp_path, monkeypatch):
file_path.write_text(original_md)
assert run((str(file_path), "--no-extensions")) == 0
assert file_path.read_text() == original_md


def test_cli_toml(tmp_path):
_wrap_num = 5
config_path = tmp_path / ".mdformat.toml"
config_path.write_text(f"wrap = {_wrap_num}")

file_path = tmp_path / "test_markdown.md"
file_path.write_text(
" ".join(["x" * _wrap_num, "o" * _wrap_num, "w" * _wrap_num, "p" * _wrap_num])
)

assert run([str(file_path), f"--toml_file={config_path}"]) == 0
assert file_path.read_text() == "xxxxx\nooooo\nwwwww\nppppp\n"


def test_cli_toml_alphanum(tmp_path):
config_path = "1234/.mdformat.toml"

file_path = tmp_path / "test_markdown.md"
file_path.write_text("text")

with pytest.raises(InvalidPath) as except_info:
run([str(file_path), f"--toml_file={config_path}"])
assert except_info.typename == "InvalidPath"

err_value = except_info.value
assert config_path in err_value.path.__str__()


def test_cli_toml_home(tmp_path):
file_path = tmp_path / "test_markdown.md"
file_path.write_text("text")

with pytest.raises(InvalidPath) as except_info:
run([str(file_path), "--toml_file=~"])
assert except_info.typename == "InvalidPath"

err_value = except_info.value
assert Path.home() == err_value.path


def test_cli_toml_not_exists(tmp_path, capsys):
config_path = tmp_path / ".mdformat.toml"

file_path = tmp_path / "test_markdown.md"

with pytest.raises(SystemExit) as exc_info:
run([str(file_path), f"--toml_file={config_path}"])
assert exc_info.value.code == 2

captured = capsys.readouterr()
assert "does not exist" in captured.err