diff --git a/tests/unit/accounts/test_views.py b/tests/unit/accounts/test_views.py index 80182a3a48d5..e21befcba4c3 100644 --- a/tests/unit/accounts/test_views.py +++ b/tests/unit/accounts/test_views.py @@ -3423,12 +3423,14 @@ def test_reauth(self, monkeypatch, pyramid_request, pyramid_services, next_route pyramid_request.matched_route = pretend.stub(name=pretend.stub()) pyramid_request.matchdict = {"foo": "bar"} pyramid_request.GET = pretend.stub(mixed=lambda: {"baz": "bar"}) + pyramid_request.params = {} form_obj = pretend.stub( next_route=pretend.stub(data=next_route), next_route_matchdict=pretend.stub(data="{}"), next_route_query=pretend.stub(data="{}"), validate=lambda: True, + password=pretend.stub(errors=[]), ) form_class = pretend.call_recorder(lambda d, **kw: form_obj) @@ -3460,6 +3462,65 @@ def test_reauth(self, monkeypatch, pyramid_request, pyramid_services, next_route ) ] + @pytest.mark.parametrize("next_route", [None, "/manage/accounts", "/projects/"]) + def test_reauth_with_password_error_in_query( + self, monkeypatch, pyramid_request, pyramid_services, next_route + ): + user_service = pretend.stub(get_password_timestamp=lambda uid: 0) + response = pretend.stub() + + monkeypatch.setattr(views, "HTTPSeeOther", lambda url: response) + + pyramid_services.register_service(user_service, IUserService, None) + + pyramid_request.route_path = lambda *args, **kwargs: pretend.stub() + pyramid_request.session.record_auth_timestamp = pretend.call_recorder( + lambda *args: None + ) + pyramid_request.session.record_password_timestamp = lambda ts: None + pyramid_request.user = pretend.stub(id=pretend.stub(), username=pretend.stub()) + pyramid_request.matched_route = pretend.stub(name=pretend.stub()) + pyramid_request.matchdict = {"foo": "bar"} + pyramid_request.GET = pretend.stub(mixed=lambda: {"baz": "bar"}) + + # Inject password error through query params + password_errors = ["The password is invalid. Try again."] + + form_obj = pretend.stub( + next_route=pretend.stub(data=next_route), + next_route_matchdict=pretend.stub(data="{}"), + next_route_query=pretend.stub(data="{}"), + validate=lambda: False, # Simulate form validation failure + password=pretend.stub(errors=password_errors), + ) + + form_class = pretend.call_recorder(lambda d, **kw: form_obj) + + if next_route is not None: + pyramid_request.method = "POST" + pyramid_request.POST["next_route"] = next_route + pyramid_request.POST["next_route_matchdict"] = "{}" + pyramid_request.POST["next_route_query"] = "{}" + + _ = views.reauthenticate(pyramid_request, _form_class=form_class) + + assert form_class.calls == [ + pretend.call( + pyramid_request.POST, + request=pyramid_request, + username=pyramid_request.user.username, + next_route=pyramid_request.matched_route.name, + next_route_matchdict=json.dumps(pyramid_request.matchdict), + next_route_query=json.dumps(pyramid_request.GET.mixed()), + action="reauthenticate", + user_service=user_service, + check_password_metrics_tags=[ + "method:reauth", + "auth_method:reauthenticate_form", + ], + ) + ] + def test_reauth_no_user(self, monkeypatch, pyramid_request): pyramid_request.user = None pyramid_request.route_path = pretend.call_recorder(lambda a: "/the-redirect") diff --git a/tests/unit/manage/test_init.py b/tests/unit/manage/test_init.py index 18e0395b618e..e478704173ef 100644 --- a/tests/unit/manage/test_init.py +++ b/tests/unit/manage/test_init.py @@ -10,6 +10,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json + import pretend import pytest @@ -66,6 +68,7 @@ def test_reauth(self, monkeypatch, require_reauth, needs_reauth_calls): session=pretend.stub( needs_reauthentication=pretend.call_recorder(lambda *args: True) ), + params={}, user=pretend.stub(username=pretend.stub()), matched_route=pretend.stub(name=pretend.stub()), matchdict={"foo": "bar"}, @@ -94,6 +97,136 @@ def view(context, request): assert view.calls == [] assert request.session.needs_reauthentication.calls == needs_reauth_calls + @pytest.mark.parametrize( + ("require_reauth", "needs_reauth_calls"), + [ + (True, [pretend.call(manage.DEFAULT_TIME_TO_REAUTH)]), + (666, [pretend.call(666)]), + ], + ) + def test_reauth_view_with_malformed_errors( + self, monkeypatch, require_reauth, needs_reauth_calls + ): + mock_user_service = pretend.stub() + response = pretend.stub() + + def mock_response(*args, **kwargs): + return {"mock_key": "mock_response"} + + def mock_form(*args, **kwargs): + return pretend.stub(password=pretend.stub(errors=[])) + + monkeypatch.setattr(manage, "render_to_response", mock_response) + monkeypatch.setattr(manage, "ReAuthenticateForm", mock_form) + + context = pretend.stub() + dummy_request = pretend.stub( + session=pretend.stub( + needs_reauthentication=pretend.call_recorder(lambda *a: True) + ), + params={"errors": "{this is not: valid json"}, + POST={}, + user=pretend.stub(username="fakeuser"), + matched_route=pretend.stub(name="fake.route"), + matchdict={"foo": "bar"}, + GET=pretend.stub(mixed=lambda: {"baz": "qux"}), + find_service=lambda service, context=None: mock_user_service, + ) + + @pretend.call_recorder + def view(context, request): + return response + + info = pretend.stub(options={"require_reauth": True}, exception_only=False) + + derived_view = manage.reauth_view(view, info) + + assert derived_view(context, dummy_request) is not response + assert mock_form().password.errors == [] + + def test_reauth_view_sets_errors(self, monkeypatch): + mock_field = pretend.stub(errors=[]) + form = pretend.stub(password=mock_field) + response = pretend.stub() + + monkeypatch.setattr(manage, "ReAuthenticateForm", lambda *a, **kw: form) + monkeypatch.setattr(manage, "render_to_response", lambda *a, **kw: {}) + + request = pretend.stub( + session=pretend.stub(needs_reauthentication=lambda *a: True), + params={ + "errors": json.dumps({"password": ["Invalid password"]}) + }, # mock errors + POST={}, + GET=pretend.stub(mixed=lambda: {}), + matched_route=pretend.stub(name="reauth"), + matchdict={}, + user=pretend.stub(username="tester"), + find_service=lambda *a, **kw: pretend.stub(), + ) + + context = pretend.stub() + info = pretend.stub(options={"require_reauth": True}) + + @pretend.call_recorder + def view(context, request): + return response + + wrapped = manage.reauth_view(view, info) + + wrapped(context, request) + + assert mock_field.errors == [ + "Invalid password" + ], f"Expected errors to be ['Invalid password'], but got {mock_field.errors}" + + def test_reauth_view_field_missing_or_no_errors(self, monkeypatch): + mock_user_service = pretend.stub() + response = pretend.stub() + + def mock_response(*args, **kwargs): + return {"mock_key": "mock_response"} + + class DummyField: + pass # No `errors` attribute + + class DummyForm: + def __init__(self, *args, **kwargs): + self.existing_field = DummyField() # Has no `.errors` + + monkeypatch.setattr(manage, "render_to_response", mock_response) + monkeypatch.setattr(manage, "ReAuthenticateForm", DummyForm) + + context = pretend.stub() + dummy_request = pretend.stub( + session=pretend.stub( + needs_reauthentication=pretend.call_recorder(lambda *a: True) + ), + params={ + "errors": json.dumps( + {"non_existing_field": ["err1"], "existing_field": ["err2"]} + ) + }, + POST={}, + user=pretend.stub(username="fakeuser"), + matched_route=pretend.stub(name="fake.route"), + matchdict={"foo": "bar"}, + GET=pretend.stub(mixed=lambda: {"baz": "qux"}), + find_service=lambda service, context=None: mock_user_service, + ) + + @pretend.call_recorder + def view(context, request): + return response + + info = pretend.stub(options={"require_reauth": True}, exception_only=False) + + derived_view = manage.reauth_view(view, info) + result = derived_view(context, dummy_request) + + assert isinstance(result, dict) + assert result["mock_key"] == "mock_response" + def test_includeme(monkeypatch): settings = { diff --git a/warehouse/accounts/views.py b/warehouse/accounts/views.py index 38c186dc9c62..b6acc023c913 100644 --- a/warehouse/accounts/views.py +++ b/warehouse/accounts/views.py @@ -1541,24 +1541,34 @@ def reauthenticate(request, _form_class=ReAuthenticateForm): ], ) - if form.next_route.data and form.next_route_matchdict.data: - redirect_to = request.route_path( - form.next_route.data, - **json.loads(form.next_route_matchdict.data) - | dict(_query=json.loads(form.next_route_query.data)), - ) - else: - redirect_to = request.route_path("manage.projects") + next_route = form.next_route.data or "manage.projects" + next_route_matchdict = json.loads(form.next_route_matchdict.data or "{}") + next_route_query = json.loads(form.next_route_query.data or "{}") - resp = HTTPSeeOther(redirect_to) + is_valid = form.validate() - if request.method == "POST" and form.validate(): + # Ensure errors don't persist across successful validations + next_route_query.pop("errors", None) + + if request.method == "POST" and is_valid: request.session.record_auth_timestamp() request.session.record_password_timestamp( user_service.get_password_timestamp(request.user.id) ) + else: + # Inject password errors into query if validation failed + if form.password.errors: + next_route_query["errors"] = json.dumps( + {"password": [str(e) for e in form.password.errors]} + ) - return resp + redirect_to = request.route_path( + next_route, + **next_route_matchdict, + _query=next_route_query, + ) + + return HTTPSeeOther(redirect_to) @view_defaults( diff --git a/warehouse/locale/messages.pot b/warehouse/locale/messages.pot index d374e0a1ec0c..9f1015f0d745 100644 --- a/warehouse/locale/messages.pot +++ b/warehouse/locale/messages.pot @@ -301,28 +301,28 @@ msgstr "" msgid "Please review our updated Terms of Service." msgstr "" -#: warehouse/accounts/views.py:1663 warehouse/accounts/views.py:1905 +#: warehouse/accounts/views.py:1673 warehouse/accounts/views.py:1915 #: warehouse/manage/views/__init__.py:1419 msgid "" "Trusted publishing is temporarily disabled. See https://pypi.org/help" "#admin-intervention for details." msgstr "" -#: warehouse/accounts/views.py:1684 +#: warehouse/accounts/views.py:1694 msgid "disabled. See https://pypi.org/help#admin-intervention for details." msgstr "" -#: warehouse/accounts/views.py:1700 +#: warehouse/accounts/views.py:1710 msgid "" "You must have a verified email in order to register a pending trusted " "publisher. See https://pypi.org/help#openid-connect for details." msgstr "" -#: warehouse/accounts/views.py:1713 +#: warehouse/accounts/views.py:1723 msgid "You can't register more than 3 pending trusted publishers at once." msgstr "" -#: warehouse/accounts/views.py:1728 warehouse/manage/views/__init__.py:1600 +#: warehouse/accounts/views.py:1738 warehouse/manage/views/__init__.py:1600 #: warehouse/manage/views/__init__.py:1715 #: warehouse/manage/views/__init__.py:1829 #: warehouse/manage/views/__init__.py:1941 @@ -331,29 +331,29 @@ msgid "" "again later." msgstr "" -#: warehouse/accounts/views.py:1738 warehouse/manage/views/__init__.py:1613 +#: warehouse/accounts/views.py:1748 warehouse/manage/views/__init__.py:1613 #: warehouse/manage/views/__init__.py:1728 #: warehouse/manage/views/__init__.py:1842 #: warehouse/manage/views/__init__.py:1954 msgid "The trusted publisher could not be registered" msgstr "" -#: warehouse/accounts/views.py:1753 +#: warehouse/accounts/views.py:1763 msgid "" "This trusted publisher has already been registered. Please contact PyPI's" " admins if this wasn't intentional." msgstr "" -#: warehouse/accounts/views.py:1780 +#: warehouse/accounts/views.py:1790 msgid "Registered a new pending publisher to create " msgstr "" -#: warehouse/accounts/views.py:1918 warehouse/accounts/views.py:1931 -#: warehouse/accounts/views.py:1938 +#: warehouse/accounts/views.py:1928 warehouse/accounts/views.py:1941 +#: warehouse/accounts/views.py:1948 msgid "Invalid publisher ID" msgstr "" -#: warehouse/accounts/views.py:1945 +#: warehouse/accounts/views.py:1955 msgid "Removed trusted publisher for project " msgstr "" diff --git a/warehouse/manage/__init__.py b/warehouse/manage/__init__.py index a06201ee7891..ea6353c53663 100644 --- a/warehouse/manage/__init__.py +++ b/warehouse/manage/__init__.py @@ -35,7 +35,6 @@ def reauth_view(view, info): def wrapped(context, request): if request.session.needs_reauthentication(time_to_reauth): user_service = request.find_service(IUserService, context=None) - form = ReAuthenticateForm( request.POST, request=request, @@ -45,6 +44,17 @@ def wrapped(context, request): next_route_query=json.dumps(request.GET.mixed()), user_service=user_service, ) + errors_param = request.params.get("errors") + if errors_param: + try: + parsed_errors = json.loads(errors_param) + for field_name, messages in parsed_errors.items(): + field = getattr(form, field_name, None) + if field is not None and hasattr(field, "errors"): + field.errors = list(messages) + except (ValueError, TypeError): + # log or ignore bad JSON + pass return render_to_response( "re-auth.html",