diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..54a79a8 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,4 @@ +[run] +omit = + reddit_get/types/__init__.py + diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000..428d292 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,15 @@ +coverage: + status: + project: + default: + target: 90% + threshold: 5% + patch: + target: 90% + threshold: 5% + +ignore: + - "reddit_get/__init__.py" + - "reddit_get/types/__init__.py" + - "tests/" + diff --git a/poetry.lock b/poetry.lock index 8e9fc8d..1c82b4c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -597,11 +597,19 @@ category = "dev" optional = false python-versions = ">=3.6" +[[package]] +name = "types-toml" +version = "0.10.7" +description = "Typing stubs for toml" +category = "dev" +optional = false +python-versions = "*" + [[package]] name = "typing-extensions" version = "4.2.0" description = "Backported and Experimental Type Hints for Python 3.7+" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" @@ -681,7 +689,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = '>=3.7.0, <4.0.0' -content-hash = "fd06cbea5c3b712878e7123ea3fe1ad8097792cdc0a0a3fd980dc992b4739a63" +content-hash = "af390b04932c1eaf0673fb37fe48574525c2c3557d15f762ae72d650e5819823" [metadata.files] atomicwrites = [ @@ -1003,6 +1011,10 @@ typed-ast = [ {file = "typed_ast-1.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:20d5118e494478ef2d3a2702d964dae830aedd7b4d3b626d003eea526be18718"}, {file = "typed_ast-1.5.3.tar.gz", hash = "sha256:27f25232e2dd0edfe1f019d6bfaaf11e86e657d9bdb7b0956db95f560cceb2b3"}, ] +types-toml = [ + {file = "types-toml-0.10.7.tar.gz", hash = "sha256:a567fe2614b177d537ad99a661adc9bfc8c55a46f95e66370a4ed2dd171335f9"}, + {file = "types_toml-0.10.7-py3-none-any.whl", hash = "sha256:05a8da4bfde2f1ee60e90c7071c063b461f74c63a9c3c1099470c08d6fa58615"}, +] typing-extensions = [ {file = "typing_extensions-4.2.0-py3-none-any.whl", hash = "sha256:6657594ee297170d19f67d55c05852a874e7eb634f4f753dbd667855e07c1708"}, {file = "typing_extensions-4.2.0.tar.gz", hash = "sha256:f1c24655a0da0d1b67f07e17a5e6b2a105894e6824b92096378bb3668ef02376"}, diff --git a/pyproject.toml b/pyproject.toml index ce25c08..a1a00b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ fire = '>=0.3.1,<0.5.0' praw = "^7.6.0" toml = '^0.10.2' titlecase = "^2.3.0" +typing-extensions = "^4.2.0" [tool.poetry.dev-dependencies] black = { version = '*', allow-prereleases = true } @@ -42,6 +43,7 @@ isort = "*" pytest-isort = "*" pydantic = "^1.9.0" pytest-mypy = {version = "*", allow-prereleases = true} +types-toml = "^0.10.7" [tool.pytest.ini_options] minversion = '6.0' @@ -73,6 +75,7 @@ exclude = ''' [tool.isort] profile = "black" multi_line_output = 3 +force_grid_wrap = 2 [build-system] requires = ['poetry-core>=1.0.0'] diff --git a/reddit_get/cli.py b/reddit_get/cli.py index 3d5b13b..1262653 100644 --- a/reddit_get/cli.py +++ b/reddit_get/cli.py @@ -1,17 +1,27 @@ -import functools import sys -from pathlib import Path -from string import Formatter -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Union +from typing import ( + Dict, + List, + Union, +) import fire import praw -import toml from praw.exceptions import MissingRequiredAttributeException -from praw.models import Submission -from praw.models.reddit.subreddit import Subreddit -from reddit_get.types import SortingOption, TimeFilterOption +from .types import ( + SortingOption, + TimeFilterOption, +) +from .utils import ( + create_post_output, + get_post_sorting_option, + get_reddit_query_function, + get_response, + get_template_keys, + get_time_filter_option, + load_configs, +) class RedditCli: @@ -38,19 +48,9 @@ class RedditCli: """ def __init__(self, config: str = '~/.redditgetrc'): - self.config_path: Path = Path(config).expanduser() - try: - self.configs = toml.load(self.config_path) - except (FileNotFoundError, toml.TomlDecodeError): - raise fire.core.FireError(f'No valid TOML config found at {self.config_path}') - try: - self.reddit = praw.Reddit(**self.configs['reddit-get']) - except MissingRequiredAttributeException as e: # pragma: no cover - fire.core.FireError(e) - if not self.reddit.user.me(): - raise fire.core.FireError( # pragma: no cover - 'Failed to authenticate with Reddit. Did you remember your username and password?' - ) + self.config_path, self.configs = load_configs(config) + self.reddit = self.get_authenticated_reddit_instance() + self.valid_header_variables: Dict[str, Dict[Union[SortingOption, TimeFilterOption], str]] = { 'sorting': { SortingOption.CONTROVERSIAL: 'Most Controversial', @@ -71,6 +71,17 @@ def __init__(self, config: str = '~/.redditgetrc'): }, } + def get_authenticated_reddit_instance(self): + try: + reddit = praw.Reddit(**self.configs['reddit-get']) + if not reddit.user.me(): + raise fire.core.FireError( # pragma: no cover + 'Failed to authenticate with Reddit. Did you remember your username and password?' + ) + return reddit + except MissingRequiredAttributeException as e: # pragma: no cover + fire.core.FireError(e) + def config_location(self): """Get the path of the reddit-get config. @@ -81,11 +92,11 @@ def config_location(self): else: raise fire.core.FireError(f'No config_path has been set!') - def _create_header( + def create_header( self, template: str, sorting: SortingOption, time: TimeFilterOption, subreddit: str ) -> str: valid_keys = {'sorting', 'time', 'subreddit'} - keys = self._get_template_keys(template) + keys = get_template_keys(template) if keys and not keys.issubset(valid_keys): raise fire.core.FireError( f'Invalid keys passed into header template: {", ".join(keys - valid_keys)}' @@ -97,24 +108,6 @@ def _create_header( } return template.format(**format_params) - def _create_post_output(self, template: str, posts: Iterator[Submission]) -> List[str]: - template_vars = self._get_template_keys(template) - if not template_vars: - raise fire.core.FireError('Your post output template did not have any items to be printed') - results = [] - for post in posts: - try: - format_params = {key: getattr(post, key) for key in template_vars} - results.append(template.format(**format_params)) - except AttributeError as e: - raise fire.core.FireError(e) - return results - - @staticmethod - def _get_template_keys(template: str) -> Optional[Set[str]]: - template_vars = {tup[1] for tup in Formatter().parse(template) if tup[1] and isinstance(tup[1], str)} - return template_vars or None - def post( self, subreddit: str, @@ -178,45 +171,20 @@ def post( The number of post titles from the specified subreddit formatted as specified """ - try: - post_sorting = SortingOption(post_sorting) - except ValueError: - raise fire.core.FireError(f'{post_sorting} is not a valid sorting option.') - try: - time_filter = TimeFilterOption(time_filter) - except ValueError: - raise fire.core.FireError(f'{time_filter} is not a valid time filter option') if not 0 < limit <= 25: raise fire.core.FireError('You may only get between 1 and 25 submissions') - - praw_subreddit: Subreddit = self.reddit.subreddit(subreddit) - - call_map: Dict[SortingOption, Callable[[Optional[int]], Iterator[Any]]] = { - SortingOption.CONTROVERSIAL: functools.partial( - praw_subreddit.controversial, time_filter=time_filter + sorting = get_post_sorting_option(post_sorting) + query_fn = get_reddit_query_function(self.reddit.subreddit(subreddit), time_filter, sorting) + return get_response( + self.create_header( + template=custom_header, + sorting=sorting, + time=get_time_filter_option(time_filter), + subreddit=subreddit, ), - SortingOption.GILDED: praw_subreddit.gilded, - SortingOption.HOT: praw_subreddit.hot, - SortingOption.NEW: praw_subreddit.new, - SortingOption.RANDOM_RISING: praw_subreddit.random_rising, - SortingOption.RISING: praw_subreddit.rising, - SortingOption.TOP: functools.partial(praw_subreddit.top, time_filter=time_filter), - } - - response_header = ( - [ - self._create_header( - template=custom_header, sorting=post_sorting, time=time_filter, subreddit=subreddit - ) - ] - if header - else [] + create_post_output(output_format, query_fn(limit=limit)), ) - posts: List[str] = self._create_post_output(output_format, call_map[post_sorting](limit=limit)) # type: ignore - - return response_header + posts - def main(): # pragma: no cover try: diff --git a/reddit_get/types/__init__.py b/reddit_get/types/__init__.py index b02681c..5b733f3 100644 --- a/reddit_get/types/__init__.py +++ b/reddit_get/types/__init__.py @@ -1 +1,23 @@ +from typing import ( + Any, + Dict, + Iterator, + List, + Optional, +) + +try: + from typing import Protocol +except ImportError: # pragma: no cover + from typing_extensions import Protocol # type: ignore + from .enums import * + + +class PrawQuery(Protocol): # pragma: no cover + def __call__(self, limit: Optional[int]) -> Iterator[Any]: + ... + + +CallMap = Dict[SortingOption, PrawQuery] +Posts = List[str] diff --git a/reddit_get/types/enums.py b/reddit_get/types/enums.py index a980ec7..ff7f644 100644 --- a/reddit_get/types/enums.py +++ b/reddit_get/types/enums.py @@ -1,4 +1,7 @@ -from enum import Enum, EnumMeta +from enum import ( + Enum, + EnumMeta, +) class MetaEnum(EnumMeta): diff --git a/reddit_get/utils.py b/reddit_get/utils.py new file mode 100644 index 0000000..9a7444a --- /dev/null +++ b/reddit_get/utils.py @@ -0,0 +1,89 @@ +import functools +from pathlib import Path +from string import Formatter +from typing import ( + Iterator, + List, + Optional, + Set, +) + +import fire +import toml +from praw.models import ( + Submission, + Subreddit, +) + +from .types import ( + CallMap, + PrawQuery, + SortingOption, + TimeFilterOption, +) + + +def load_configs(config): + config_path: Path = Path(config).expanduser() + try: + configs = toml.load(config_path) + except (FileNotFoundError, toml.TomlDecodeError): + raise fire.core.FireError(f'No valid TOML config found at {config_path}') + return config_path, configs + + +def get_reddit_query_function( + subreddit: Subreddit, time_filter: str = 'all', post_sorting: SortingOption = SortingOption.TOP +) -> PrawQuery: + call_map: CallMap = { + SortingOption.CONTROVERSIAL: functools.partial(subreddit.controversial, time_filter=time_filter), + SortingOption.GILDED: subreddit.gilded, + SortingOption.HOT: subreddit.hot, + SortingOption.NEW: subreddit.new, + SortingOption.RANDOM_RISING: subreddit.random_rising, + SortingOption.RISING: subreddit.rising, + SortingOption.TOP: functools.partial(subreddit.top, time_filter=time_filter), + } + try: + return call_map[post_sorting] + except KeyError: + raise fire.core.FireError(f'Invalid sorting option: {post_sorting}') + + +def get_response(header: str, posts: List[str]) -> List[str]: + response_header = [header] if header else [] + return response_header + posts + + +def get_time_filter_option(time_filter): + try: + time_filter = TimeFilterOption(time_filter) + except ValueError: + raise fire.core.FireError(f'{time_filter} is not a valid time filter option') + return time_filter + + +def get_post_sorting_option(post_sorting: str) -> SortingOption: + try: + return SortingOption(post_sorting) + except ValueError: + raise fire.core.FireError(f'{post_sorting} is not a valid sorting option.') + + +def get_template_keys(template: str) -> Optional[Set[str]]: + template_vars = {tup[1] for tup in Formatter().parse(template) if tup[1] and isinstance(tup[1], str)} + return template_vars or None + + +def create_post_output(template: str, posts: Iterator[Submission]) -> List[str]: + template_vars = get_template_keys(template) + if not template_vars: + raise fire.core.FireError('Your post output template did not have any items to be printed') + results = [] + for post in posts: + try: + format_params = {key: getattr(post, key) for key in template_vars} + results.append(template.format(**format_params)) + except AttributeError as e: + raise fire.core.FireError(e) + return results diff --git a/tests/conftest.py b/tests/conftest.py index 26b7414..28b84a3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,7 +29,7 @@ class MockSubreddit: def __init__(self, display_name: str, *args, **kwargs): self.display_name = display_name - def __repr__(self): + def __repr__(self): # pragma: nocover return self.display_name def controversial(self, *args, **kwargs): diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..ffe502b --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,22 @@ +import fire +import pytest + +from reddit_get import ( + RedditCli, + get_post_sorting_option, + get_reddit_query_function, +) + + +class TestUtils: + class TestErrors: + class TestGetPostSortingOption: + def it_raises_a_fireerror_with_invalid_post_sorting(self): + with pytest.raises(fire.core.FireError): + get_post_sorting_option('invalid') + + class TestGetRedditQueryFunction: + def it_raises_a_fireerror_with_invalid_post_sorting(self, mock_reddit): + with pytest.raises(fire.core.FireError): + cli = RedditCli('tests/.exampleconfig') + get_reddit_query_function(subreddit=cli.reddit.subreddit, post_sorting='invalid')