diff --git a/src/requests/sessions.py b/src/requests/sessions.py index 731550de88..4d3188696a 100644 --- a/src/requests/sessions.py +++ b/src/requests/sessions.py @@ -16,7 +16,7 @@ from .auth import _basic_auth_str from .compat import Mapping, cookielib, urljoin, urlparse from .cookies import ( - RequestsCookieJar, + _copy_cookie_jar, cookiejar_from_dict, extract_cookies_to_jar, merge_cookies, @@ -471,9 +471,8 @@ def prepare_request(self, request): cookies = cookiejar_from_dict(cookies) # Merge with session cookies - merged_cookies = merge_cookies( - merge_cookies(RequestsCookieJar(), self.cookies), cookies - ) + session_cookies = _copy_cookie_jar(self.cookies) + merged_cookies = merge_cookies(session_cookies, cookies) # Set environment's basic authentication if not explicitly set. auth = request.auth diff --git a/tests/test_requests.py b/tests/test_requests.py index 75d2deff2e..5aa9ee8c7f 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -469,6 +469,33 @@ def test_cookielib_cookiejar_on_redirect(self, httpbin): assert cookies["foo"] == "bar" assert cookies["cookie"] == "tasty" + @pytest.mark.parametrize("jar", ( + requests.cookies.RequestsCookieJar(), + cookielib.CookieJar() + )) + def test_custom_cookie_policy_is_persisted(self, httpbin, jar): + """Verify a custom CookiePolicy is propagated on each session request.""" + + class TestCookiePolicy(cookielib.DefaultCookiePolicy): + """Policy to restrict all cookies from localhost (127.0.0.1).""" + def __init__(self): + cookielib.DefaultCookiePolicy.__init__(self, blocked_domains=['127.0.0.1']) + + # Establish session with jar and set some cookies. + s = requests.session() + s.cookies = jar + s.get(httpbin("cookies/set?k1=v1&k2=v2")) + assert len(s.cookies) == 2 + + # Set policy + s.cookies.set_policy(TestCookiePolicy()) + + # No cookies were sent to our blocked domain and none were set. + resp = s.get(httpbin("cookies/set?k3=v3")) + assert "Cookie" not in resp.request.headers + assert len(s.cookies) == 2 + assert "k3" not in s.cookies + def test_requests_in_history_are_not_overridden(self, httpbin): resp = requests.get(httpbin("redirect/3")) urls = [r.url for r in resp.history]