diff --git a/spring-social-security/src/main/java/org/springframework/social/security/provider/OAuth1AuthenticationService.java b/spring-social-security/src/main/java/org/springframework/social/security/provider/OAuth1AuthenticationService.java index 6527785e3..b68d0d9a4 100644 --- a/spring-social-security/src/main/java/org/springframework/social/security/provider/OAuth1AuthenticationService.java +++ b/spring-social-security/src/main/java/org/springframework/social/security/provider/OAuth1AuthenticationService.java @@ -115,7 +115,7 @@ public SocialAuthenticationToken getAuthToken(HttpServletRequest request, HttpSe } protected String buildReturnToUrl(HttpServletRequest request) { - StringBuffer sb = request.getRequestURL(); + StringBuffer sb = getProxyHeaderAwareRequestURL(request); sb.append("?"); for (String name : getReturnToUrlParameters()) { @@ -134,6 +134,33 @@ protected String buildReturnToUrl(HttpServletRequest request) { return sb.toString(); } + protected StringBuffer getProxyHeaderAwareRequestURL(HttpServletRequest request) { + String host = request.getHeader("Host"); + if (StringUtils.isEmpty(host)) { + return request.getRequestURL(); + } + StringBuffer sb = new StringBuffer(); + String schemeHeader = request.getHeader("X-Forwarded-Proto"); + String portHeader = request.getHeader("X-Forwarded-Port"); + String scheme = StringUtils.isEmpty(schemeHeader) ? "http" : schemeHeader; + String port = StringUtils.isEmpty(portHeader) ? "80" : portHeader; + if (scheme.equals("http") && port.equals("80")){ + port = ""; + } + if (scheme.equals("https") && port.equals("443")){ + port = ""; + } + sb.append(scheme); + sb.append("://"); + sb.append(host); + if (StringUtils.hasLength(port)){ + sb.append(":"); + sb.append(port); + } + sb.append(request.getRequestURI()); + return sb; + } + private OAuthToken extractCachedRequestToken(HttpServletRequest request) { OAuthToken requestToken = (OAuthToken) request.getSession().getAttribute(OAUTH_TOKEN_ATTRIBUTE); request.getSession().removeAttribute(OAUTH_TOKEN_ATTRIBUTE);