diff --git a/scrapy_poet/overrides.py b/scrapy_poet/overrides.py
index 72893e51..27d3170d 100644
--- a/scrapy_poet/overrides.py
+++ b/scrapy_poet/overrides.py
@@ -1,8 +1,12 @@
+import re
+from re import Pattern
 from abc import ABC, abstractmethod
-from typing import Dict, Mapping, Callable, Optional, List
+import bisect
+from collections import defaultdict
+from typing import Dict, Mapping, Callable, Optional, List, Union, Tuple
 
+import attr
 from marisa_trie import Trie
-
 from scrapy import Request
 from scrapy.crawler import Crawler
 from scrapy_poet.utils import get_domain, url_hierarchical_str
@@ -142,3 +146,94 @@ def overrides_for(self, request: Request) -> Mapping[Callable, Callable]:
             return self.overrides[self.trie[max_prefix]].overrides
         else:
             return {}
+
+
+
+@attr.s(auto_attribs=True, order=False)
+class RegexOverridesRecord:
+    """
+    Keep a reverse ordering on hurl. This is required to prioritize the more
+    specific rules over the less specific ones using the hierarchy determined
+    by the hierarchical url.
+ """ + hurl: str = attr.ib(eq=False) + regex: str + overrides: Mapping[Callable, Callable] = attr.ib(eq=False) + re: Pattern = attr.ib(init=False, eq=False) + + def __attrs_post_init__(self): + self.re = re.compile(self.regex) + + def __gt__(self, other): + return self.hurl < other.hurl + + def __lt__(self, other): + return self.hurl > other.hurl + + def __ge__(self, other): + return self.hurl <= other.hurl + + def __le__(self, other): + return self.hurl >= other.hurl + + +RuleType = Union[str, Tuple[str, str]] + + +class RegexOverridesRegistry(OverridesRegistryBase): + def __init__(self, all_overrides: Optional[Mapping[RuleType, Mapping[Callable, Callable]]] = None) -> None: + super().__init__() + self.rules = defaultdict(list) + for rule, overrides in (all_overrides or {}).items(): + if isinstance(rule, tuple): + domain, regex = rule + self.register_regex(domain, regex, overrides) + else: + self.register(rule, overrides) + + def register_regex(self, domain: str, regex: str, overrides: Mapping[Callable, Callable]): + record = RegexOverridesRecord("\ue83a", regex, overrides) + self._insert(domain, record) + + def register(self, domain_or_more: str, overrides: Mapping[Callable, Callable]): + if domain_or_more.strip() == "": + self.register_regex("", r".*", overrides) + return + + url = f"http://{domain_or_more}" + domain = get_domain(url) + hurl = url_hierarchical_str(url) + regex = domain_or_more_regex(domain_or_more) + record = RegexOverridesRecord(hurl, regex, overrides) + self._insert(domain, record) + + def _insert(self, domain: str, record: RegexOverridesRecord): + records = self.rules[domain] + try: + del records[records.index(record)] + except ValueError: + ... 
+ bisect.insort(records, record) + + @classmethod + def from_crawler(cls, crawler: Crawler): + return cls(crawler.settings.getdict("SCRAPY_POET_OVERRIDES", {})) + + def overrides_for(self, request: Request) -> Mapping[Callable, Callable]: + rules = self.rules.get(get_domain(request.url)) or self.rules.get("", {}) + for record in rules: + if record.re.match(request.url): + return record.overrides + return {} + + +def domain_or_more_regex(domain_or_more: str) -> str: + """ + Return a regex that matches urls belonging to the set represented by + the given `domain_or_more` rule + """ + if domain_or_more.endswith("/"): + domain_or_more = domain_or_more[:-1] + if domain_or_more.strip() == "": + return r"https?://.*" + return r"https?://(?:.+\.)?" + re.escape(domain_or_more) + r".*" \ No newline at end of file diff --git a/tests/test_regex_overrides.py b/tests/test_regex_overrides.py new file mode 100644 index 00000000..bc71f950 --- /dev/null +++ b/tests/test_regex_overrides.py @@ -0,0 +1,150 @@ +import re +from typing import Mapping + +import pytest + +from scrapy import Request, Spider +from scrapy.utils.test import get_crawler +from scrapy_poet.overrides import RegexOverridesRegistry, \ + PerDomainOverridesRegistry, domain_or_more_regex + + +class _str(str, Mapping): # type: ignore + """Trick to use strings as overrides dicts for testing""" + ... 
+ + +def _r(url: str): + return Request(url) + + +@pytest.fixture +def reg(): + return RegexOverridesRegistry() + + +class TestRegexOverridesRegistry: + + def test_replace(self, reg): + reg.register("toscrape.com", _str("ORIGINAL")) + assert reg.overrides_for(_r("http://toscrape.com:442/path")) == "ORIGINAL" + reg.register("toscrape.com", _str("REPLACED")) + assert reg.overrides_for(_r("http://www.toscrape.com/path")) == "REPLACED" + assert len(reg.rules) == 1 + + def test_init_and_global(self): + overrides = { + "": _str("GLOBAL"), + "toscrape.com": _str("TOSCRAPE") + } + reg = RegexOverridesRegistry(overrides) + assert reg.overrides_for(_r("http://example.com/blabla")) == "GLOBAL" + assert reg.overrides_for(_r("http://toscrape.com/blabla")) == "TOSCRAPE" + + def test_register(self, reg): + assert reg.overrides_for(_r("http://books.toscrape.com/")) == {} + + reg.register("books.toscrape.com", _str("BOOKS_TO_SCRAPE")) + assert reg.overrides_for(_r("http://books.toscrape.com/")) == "BOOKS_TO_SCRAPE" + assert reg.overrides_for(_r("http://books.toscrape.com/path")) == "BOOKS_TO_SCRAPE" + assert reg.overrides_for(_r("http://toscrape.com/")) == {} + + reg.register("toscrape.com", _str("TO_SCRAPE")) + assert reg.overrides_for(_r("http://books.toscrape.com/")) == "BOOKS_TO_SCRAPE" + assert reg.overrides_for(_r("http://books.toscrape.com/path")) == "BOOKS_TO_SCRAPE" + assert reg.overrides_for(_r("http://toscrape.com/")) == "TO_SCRAPE" + assert reg.overrides_for(_r("http://www.toscrape.com/")) == "TO_SCRAPE" + assert reg.overrides_for(_r("http://toscrape.com/path")) == "TO_SCRAPE" + assert reg.overrides_for(_r("http://zz.com")) == {} + + reg.register("books.toscrape.com/category/books/classics_6/", _str("CLASSICS")) + assert reg.overrides_for(_r("http://books.toscrape.com/path?arg=1")) == "BOOKS_TO_SCRAPE" + assert reg.overrides_for(_r("http://toscrape.com")) == "TO_SCRAPE" + assert reg.overrides_for(_r("http://aa.com")) == {} + assert reg.overrides_for( + 
_r("https://books.toscrape.com/category/books/classics_6")) == "CLASSICS" + assert reg.overrides_for( + _r("http://books.toscrape.com/category/books/classics_6/path")) == "CLASSICS" + assert reg.overrides_for( + _r("http://books.toscrape.com/category/books/")) == "BOOKS_TO_SCRAPE" + + def test_from_crawler(self): + crawler = get_crawler(Spider) + reg = RegexOverridesRegistry.from_crawler(crawler) + assert len(reg.rules) == 0 + + settings = { + "SCRAPY_POET_OVERRIDES": { + "toscrape.com": _str("TOSCRAPE") + } + } + crawler = get_crawler(Spider, settings) + reg = RegexOverridesRegistry.from_crawler(crawler) + assert len(reg.rules) == 1 + assert reg.overrides_for(_r("http://toscrape.com/path")) == "TOSCRAPE" + + def test_domain_subdomain_case(self, reg): + reg.register("toscrape.com", _str("DOMAIN")) + reg.register("books.toscrape.com", _str("SUBDOMAIN")) + assert reg.overrides_for(_r("http://toscrape.com/blabla")) == "DOMAIN" + assert reg.overrides_for(_r("http://cars.toscrape.com/")) == "DOMAIN" + assert reg.overrides_for(_r("http://books2.toscrape.com:123/blabla")) == "DOMAIN" + assert reg.overrides_for(_r("https://mybooks.toscrape.com/blabla")) == "DOMAIN" + assert reg.overrides_for(_r("http://books.toscrape.com/blabla")) == "SUBDOMAIN" + assert reg.overrides_for(_r("http://www.books.toscrape.com")) == "SUBDOMAIN" + assert reg.overrides_for(_r("http://uk.books.toscrape.com/blabla")) == "SUBDOMAIN" + + def test_common_prefix_domains(self, reg): + reg.register("toscrape.com", _str("TOSCRAPE")) + reg.register("toscrape2.com", _str("TOSCRAPE2")) + assert reg.overrides_for(_r("http://toscrape.com/blabla")) == "TOSCRAPE" + assert reg.overrides_for(_r("http://toscrape2.com")) == "TOSCRAPE2" + + +class TestPerDomainOverridesRegistry: + + def test(self): + settings = { + "SCRAPY_POET_OVERRIDES": { + "toscrape.com": _str("TOSCRAPE") + } + } + crawler = get_crawler(Spider, settings) + reg = PerDomainOverridesRegistry.from_crawler(crawler) + assert 
reg.overrides_for(_r("http://toscrape.com/path")) == "TOSCRAPE" + assert reg.overrides_for(_r("http://books.toscrape.com/path")) == "TOSCRAPE" + assert reg.overrides_for(_r("http://toscrape2.com/path")) == {} + + +@pytest.mark.parametrize("domain_or_more", + [ + "", "example.com:343", "example.com:343/", + "WWW.example.com:343/", + "www.EXAMPLE.com:343/?id=23", + "www.example.com:343/page?id=23", + "www.example.com:343/page?id=23;params#fragment", + "127.0.0.1:80/page?id=23;params#fragment", + "127.0.0.1:443/page?id=23;params#fragment", + "127.0.0.1:333/page?id=23;params#fragment" + ]) +def test_domain_or_more_regex(domain_or_more): + url = f"http://{domain_or_more}" + regex = domain_or_more_regex(domain_or_more) + + assert re.match(regex, url) + assert re.match(regex, f"https://{domain_or_more}") + assert re.match(regex, url + "a") + assert re.match(regex, url + "/") + assert re.match(regex, url + "/some_text") + assert re.match(regex, url + "some_other_text") + + if url[-1] == '/' and domain_or_more: + assert re.match(regex, url[:-1]) + assert not re.match(regex, url[:-2]) + url = url[:-1] # The polluted test requires the url without a last slash + else: + assert not re.match(regex, url[:-1]) + for i in range(len(url)): + # Modify a single character + polluted = url[:i] + chr(ord(url[i]) - 1) + url[i+1:] + assert not re.match(regex, polluted) \ No newline at end of file