diff --git a/tests/unit/rate_limiting/test_core.py b/tests/unit/rate_limiting/test_core.py index 658a5e6ce8cc..7b4422ea39c6 100644 --- a/tests/unit/rate_limiting/test_core.py +++ b/tests/unit/rate_limiting/test_core.py @@ -16,9 +16,17 @@ import redis from limits import storage +from pyramid.httpexceptions import HTTPTooManyRequests from warehouse import rate_limiting -from warehouse.rate_limiting import DummyRateLimiter, RateLimit, RateLimiter +from warehouse.metrics.interfaces import IMetricsService +from warehouse.rate_limiting import ( + DummyRateLimiter, + IRateLimiter, + RateLimit, + RateLimiter, + ratelimit_tween_factory, +) class TestRateLimiter: @@ -175,14 +183,117 @@ def test_eq(self): assert RateLimit("1 per 5 minutes", identifiers=["foo"]) != object() +class TestRateLimiterTween: + def test_ratelimit_tween(self): + response = pretend.stub(headers={}) + handler = pretend.call_recorder(lambda request: response) + registry = pretend.stub() + tween = ratelimit_tween_factory(handler, registry) + + metrics_service = pretend.stub( + increment=pretend.call_recorder(lambda *a, **kw: None) + ) + ratelimiter_service = pretend.stub( + test=pretend.call_recorder(lambda a: True), + hit=pretend.call_recorder(lambda a: None), + ) + + request = pretend.stub( + remote_addr="192.0.2.1", + path="/project/foobar/", + find_service=pretend.call_recorder( + lambda *a, **kw: { + IMetricsService: metrics_service, + IRateLimiter: ratelimiter_service, + }[a[0]] + ), + ) + + assert tween(request) is response + + assert metrics_service.increment.calls == [] + assert ratelimiter_service.hit.calls == [pretend.call("192.0.2.1")] + assert ratelimiter_service.test.calls == [pretend.call("192.0.2.1")] + + def test_ratelimiter_tween_blocking(self): + response = pretend.stub(headers={}) + handler = pretend.call_recorder(lambda request: response) + registry = pretend.stub() + tween = ratelimit_tween_factory(handler, registry) + + metrics_service = pretend.stub( + increment=pretend.call_recorder(lambda *a, **kw: None) + ) + ratelimiter_service = pretend.stub( + test=pretend.call_recorder(lambda a: False), + hit=pretend.call_recorder(lambda a: None), + ) + + request = pretend.stub( + remote_addr="192.0.2.1", + path="/project/foobar/", + find_service=pretend.call_recorder( + lambda *a, **kw: { + IMetricsService: metrics_service, + IRateLimiter: ratelimiter_service, + }[a[0]] + ), + ) + + response = tween(request) + assert isinstance(response, HTTPTooManyRequests) + assert ( + response.message + == "Your IP has issued too many requests reaching the PyPI backends." + ) + + assert metrics_service.increment.calls == [ + pretend.call("warehouse.ratelimited", tags=["ratelimiter:ip.requests"]) + ] + assert ratelimiter_service.test.calls == [pretend.call("192.0.2.1")] + assert ratelimiter_service.hit.calls == [] + + def test_includeme(): registry = {} config = pretend.stub( registry=pretend.stub( - settings={"ratelimit.url": "memory://"}, __setitem__=registry.__setitem__ - ) + settings={ + "ratelimit.url": "memory://", + "warehouse.ip_requests_ratelimit_string": "1000 per second", + }, + __setitem__=registry.__setitem__, + ), + register_service_factory=pretend.call_recorder(lambda *a, **kw: None), + add_tween=pretend.call_recorder(lambda *a, **kw: None), + ) + + rate_limiting.includeme(config) + + assert config.register_service_factory.calls == [ + pretend.call(RateLimit("1000 per second"), IRateLimiter, name="ip.requests") + ] + assert config.add_tween.calls == [ + pretend.call("warehouse.rate_limiting.ratelimit_tween_factory") + ] + assert isinstance(registry["ratelimiter.storage"], storage.MemoryStorage) + + +def test_includeme_no_ip_requests_ratelimit(): + registry = {} + config = pretend.stub( + registry=pretend.stub( + settings={ + "ratelimit.url": "memory://", + }, + __setitem__=registry.__setitem__, + ), + register_service_factory=pretend.call_recorder(lambda *a, **kw: None), + add_tween=pretend.call_recorder(lambda *a, **kw: None), ) rate_limiting.includeme(config) + assert config.register_service_factory.calls == [] + assert config.add_tween.calls == [] assert isinstance(registry["ratelimiter.storage"], storage.MemoryStorage) diff --git a/warehouse/config.py b/warehouse/config.py index 1014a1ca6dbe..13baae49ebc4 100644 --- a/warehouse/config.py +++ b/warehouse/config.py @@ -482,6 +482,11 @@ def configure(settings=None): maybe_set(settings, "helpscout.mailbox_id", "HELPSCOUT_WAREHOUSE_MAILBOX_ID") # Configure our ratelimiters + maybe_set( + settings, + "warehouse.ip_requests_ratelimit_string", + "IP_REQUESTS_RATE_LIMIT_STRING", + ) maybe_set( settings, "warehouse.account.user_login_ratelimit_string", diff --git a/warehouse/rate_limiting/__init__.py b/warehouse/rate_limiting/__init__.py index d40dbe4bac1f..5b38270af3a4 100644 --- a/warehouse/rate_limiting/__init__.py +++ b/warehouse/rate_limiting/__init__.py @@ -21,6 +21,7 @@ from limits.storage import storage_from_string from limits.strategies import MovingWindowRateLimiter from more_itertools import first_true +from pyramid.httpexceptions import HTTPTooManyRequests from zope.interface import implementer from warehouse.metrics import IMetricsService @@ -162,7 +163,37 @@ def __eq__(self, other): ) +def ratelimit_tween_factory(handler, registry): + def ratelimit_tween(request): + ratelimiter = request.find_service( + IRateLimiter, name="ip.requests", context=None + ) + metrics = request.find_service(IMetricsService, context=None) + + if not ratelimiter.test(request.remote_addr): + metrics.increment("warehouse.ratelimited", tags=["ratelimiter:ip.requests"]) + return HTTPTooManyRequests( + "Your IP has issued too many requests reaching the PyPI backends." + ) + + ratelimiter.hit(request.remote_addr) + + response = handler(request) + return response + + return ratelimit_tween + + def includeme(config): config.registry["ratelimiter.storage"] = storage_from_string( config.registry.settings["ratelimit.url"] ) + + ip_requests_ratelimit_string = config.registry.settings.get( + "warehouse.ip_requests_ratelimit_string" + ) + if ip_requests_ratelimit_string is not None: + config.register_service_factory( + RateLimit(ip_requests_ratelimit_string), IRateLimiter, name="ip.requests" + ) + config.add_tween("warehouse.rate_limiting.ratelimit_tween_factory")