diff --git a/tests/test_injection.py b/tests/test_injection.py index 696dadd5..88ffa4bd 100644 --- a/tests/test_injection.py +++ b/tests/test_injection.py @@ -7,8 +7,8 @@ import parsel import pytest from andi.typeutils import strip_annotated -from pytest_twisted import inlineCallbacks -from scrapy import Request +from pytest_twisted import ensureDeferred, inlineCallbacks +from scrapy import Request, Spider, signals from scrapy.http import Response from url_matcher import Patterns from url_matcher.util import get_domain @@ -36,6 +36,7 @@ NonCallableProviderError, UndeclaredProvidedTypeError, ) +from scrapy_poet.utils.testing import make_crawler from .test_providers import Name, Price @@ -1005,3 +1006,162 @@ def test_dynamic_deps_factory_bad_input(): match=re.escape(r"Expected a dynamic dependency type, got (,)"), ): Injector._get_dynamic_deps_factory([(int,)]) + + +class BaseCbSpider(Spider): + + def start_requests(self): + kwargs = {} + if cb_kwargs := getattr(self, "cb_kwargs", None): + kwargs["cb_kwargs"] = cb_kwargs + yield Request("data:,", **kwargs) + + +class CbSpider1(BaseCbSpider): + + def parse(self, response): + yield {"success": True} + + +class CbSpider2(BaseCbSpider): + + cb_kwargs = {"foo": "bar"} + + def parse(self, response, foo): + yield {"success": foo == "bar"} + + +class CbSpider3(BaseCbSpider): + + def parse(self, response, foo=None): + yield {"success": foo is None} + + +class CbSpider4(BaseCbSpider): + + cb_kwargs = {"foo": "bar"} + + def parse(self, response, foo=None): + yield {"success": foo == "bar"} + + +VALUE = object() + + +class CbSpider5(BaseCbSpider): + + def parse(self, response, foo=VALUE): + yield {"success": foo is VALUE} + + +class CbSpider6(BaseCbSpider): + + cb_kwargs = {"foo": "bar"} + + def parse(self, response, foo=VALUE): + yield {"success": foo == "bar"} + + +class Injected(str): + pass + + +INJECTED = Injected("baz") + + +def expected_injected(injected): + return isinstance(injected, Injected) and str(injected) == str(INJECTED) + + +class CbSpider7(BaseCbSpider): + + cb_kwargs = {"foo": "bar"} + + def parse(self, response, injected: Injected, foo): + yield {"success": expected_injected(injected) and foo == "bar"} + + +class CbSpider8(BaseCbSpider): + + def parse(self, response, injected: Injected, foo=None): + yield {"success": expected_injected(injected) and foo is None} + + +class CbSpider9(BaseCbSpider): + + cb_kwargs = {"foo": "bar"} + + def parse(self, response, injected: Injected, foo=None): + yield {"success": expected_injected(injected) and foo == "bar"} + + +class CbSpider10(BaseCbSpider): + + def parse(self, response, injected: Injected, foo=VALUE): + yield {"success": expected_injected(injected) and foo is VALUE} + + +class CbSpider11(BaseCbSpider): + + cb_kwargs = {"foo": "bar"} + + def parse(self, response, injected: Injected, foo=VALUE): + yield {"success": expected_injected(injected) and foo == "bar"} + + +class CbSpider12(BaseCbSpider): + + cb_kwargs = {"foo": "bar"} + + def parse(self, response, foo, injected: Injected): + yield {"success": expected_injected(injected) and foo == "bar"} + + +class CbSpider13(BaseCbSpider): + + def parse(self, response, foo=None, injected: Injected = None): + yield {"success": expected_injected(injected) and foo is None} + + +class CbSpider14(BaseCbSpider): + + cb_kwargs = {"foo": "bar"} + + def parse(self, response, foo=None, injected: Injected = None): + yield {"success": expected_injected(injected) and foo == "bar"} + + +class CbSpider15(BaseCbSpider): + + def parse(self, response, foo=VALUE, injected: Injected = None): + yield {"success": expected_injected(injected) and foo is VALUE} + + +class CbSpider16(BaseCbSpider): + + cb_kwargs = {"foo": "bar"} + + def parse(self, response, foo=VALUE, injected: Injected = None): + yield {"success": expected_injected(injected) and foo == "bar"} + + +@pytest.mark.parametrize( + ("spider_cls",), + ((cls,) for cls in BaseCbSpider.__subclasses__()), +) +@ensureDeferred +async def test_callback_arg_mapping(spider_cls): + provider = get_provider({Injected}, str(INJECTED)) + settings = {"SCRAPY_POET_PROVIDERS": {provider: 500}} + crawler = make_crawler(spider_cls, settings) + + success_list = [] + + def track_item(item, response, spider): + success_list.append(item["success"]) + + crawler.signals.connect(track_item, signal=signals.item_scraped) + await crawler.crawl() + + assert len(success_list) == 1 + assert success_list[0]