diff --git a/spring-social-web/src/main/java/org/springframework/social/connect/web/ConnectController.java b/spring-social-web/src/main/java/org/springframework/social/connect/web/ConnectController.java index 703df9f65..fb82edd10 100644 --- a/spring-social-web/src/main/java/org/springframework/social/connect/web/ConnectController.java +++ b/spring-social-web/src/main/java/org/springframework/social/connect/web/ConnectController.java @@ -32,6 +32,7 @@ import org.springframework.social.connect.ConnectionKey; import org.springframework.social.connect.ConnectionRepository; import org.springframework.social.connect.DuplicateConnectionException; +import org.springframework.social.connect.NoSuchConnectionException; import org.springframework.social.connect.support.OAuth1ConnectionFactory; import org.springframework.social.connect.support.OAuth2ConnectionFactory; import org.springframework.stereotype.Controller; @@ -216,7 +217,7 @@ public RedirectView oauth1Callback(@PathVariable String providerId, NativeWebReq try { OAuth1ConnectionFactory connectionFactory = (OAuth1ConnectionFactory) connectionFactoryLocator.getConnectionFactory(providerId); Connection connection = webSupport.completeConnection(connectionFactory, request); - addConnection(connection, connectionFactory, request); + addOrUpdateConnection(connection, connectionFactory, request); } catch (Exception e) { request.setAttribute(PROVIDER_ERROR_ATTRIBUTE, e, RequestAttributes.SCOPE_SESSION); logger.warn("Exception while handling OAuth1 callback (" + e.getMessage() + "). Redirecting to " + providerId +" connection status page."); @@ -234,7 +235,7 @@ public RedirectView oauth2Callback(@PathVariable String providerId, NativeWebReq try { OAuth2ConnectionFactory connectionFactory = (OAuth2ConnectionFactory) connectionFactoryLocator.getConnectionFactory(providerId); Connection connection = webSupport.completeConnection(connectionFactory, request); - addConnection(connection, connectionFactory, request); + addOrUpdateConnection(connection, connectionFactory, request); } catch (Exception e) { request.setAttribute(PROVIDER_ERROR_ATTRIBUTE, e, RequestAttributes.SCOPE_SESSION); logger.warn("Exception while handling OAuth2 callback (" + e.getMessage() + "). Redirecting to " + providerId +" connection status page."); @@ -336,13 +337,19 @@ private String getViewPath() { return "connect/"; } - private void addConnection(Connection connection, ConnectionFactory connectionFactory, WebRequest request) { - try { - connectionRepository.addConnection(connection); - postConnect(connectionFactory, connection, request); - } catch (DuplicateConnectionException e) { - request.setAttribute(DUPLICATE_CONNECTION_ATTRIBUTE, e, RequestAttributes.SCOPE_SESSION); - } + private void addOrUpdateConnection(Connection connection, ConnectionFactory connectionFactory, WebRequest request) { + try { + connectionRepository.getConnection(connection.getKey()); + connectionRepository.updateConnection(connection); + postConnect(connectionFactory, connection, request); + } catch (NoSuchConnectionException ex) { + try { + connectionRepository.addConnection(connection); + postConnect(connectionFactory, connection, request); + } catch (DuplicateConnectionException e) { + request.setAttribute(DUPLICATE_CONNECTION_ATTRIBUTE, e, RequestAttributes.SCOPE_SESSION); + } + } } @SuppressWarnings({ "rawtypes", "unchecked" }) diff --git a/spring-social-web/src/test/java/org/springframework/social/connect/web/test/StubConnectionRepository.java b/spring-social-web/src/test/java/org/springframework/social/connect/web/test/StubConnectionRepository.java index 8ee3ed1b9..7aeab3c16 100644 --- a/spring-social-web/src/test/java/org/springframework/social/connect/web/test/StubConnectionRepository.java +++ b/spring-social-web/src/test/java/org/springframework/social/connect/web/test/StubConnectionRepository.java @@ -21,6 +21,7 @@ import org.springframework.social.connect.Connection; import org.springframework.social.connect.ConnectionKey; import org.springframework.social.connect.ConnectionRepository; +import org.springframework.social.connect.NoSuchConnectionException; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -47,7 +48,7 @@ public MultiValueMap> findConnectionsToUsers(MultiValueMap } public Connection getConnection(ConnectionKey connectionKey) { - return null; + throw new NoSuchConnectionException(connectionKey); } public Connection getConnection(Class apiType, String providerUserId) {