diff --git a/web/src/main/java/org/apache/shiro/web/filter/InvalidRequestFilter.java b/web/src/main/java/org/apache/shiro/web/filter/InvalidRequestFilter.java index e332dd0510..cf47180523 100644 --- a/web/src/main/java/org/apache/shiro/web/filter/InvalidRequestFilter.java +++ b/web/src/main/java/org/apache/shiro/web/filter/InvalidRequestFilter.java @@ -45,6 +45,12 @@ */ public class InvalidRequestFilter extends AccessControlFilter { + static enum PathTraversalBlockMode { + STRICT, + NORMAL, + NO_BLOCK; + } + private static final List SEMICOLON = Collections.unmodifiableList(Arrays.asList(";", "%3b", "%3B")); private static final List BACKSLASH = Collections.unmodifiableList(Arrays.asList("\\", "%5c", "%5C")); @@ -59,7 +65,7 @@ public class InvalidRequestFilter extends AccessControlFilter { private boolean blockNonAscii = true; - private boolean blockTraversal = true; + private PathTraversalBlockMode pathTraversalBlockMode = PathTraversalBlockMode.NORMAL; @Override protected boolean isAccessAllowed(ServletRequest req, ServletResponse response, Object mappedValue) throws Exception { @@ -117,7 +123,10 @@ private static boolean containsOnlyPrintableAsciiCharacters(String uri) { } private boolean containsTraversal(String uri) { - if (isBlockTraversal()) { + if (isBlockTraversalNormal()) { + return !(isNormalized(uri)); + } + if (isBlockTraversalStrict()) { return !(isNormalized(uri) && PERIOD.stream().noneMatch(uri::contains) && FORWARDSLASH.stream().noneMatch(uri::contains)); @@ -173,11 +182,24 @@ public void setBlockNonAscii(boolean blockNonAscii) { this.blockNonAscii = blockNonAscii; } - public boolean isBlockTraversal() { - return blockTraversal; + public boolean isBlockTraversalNormal() { + return pathTraversalBlockMode == PathTraversalBlockMode.NORMAL; + } + + public boolean isBlockTraversalStrict() { + return pathTraversalBlockMode == PathTraversalBlockMode.STRICT; } + public void setPathTraversalBlockMode(PathTraversalBlockMode mode) { + this.pathTraversalBlockMode = mode; + } + + /** + * + * @deprecated Use {@link #setPathTraversalBlockMode(PathTraversalBlockMode)} + */ + @Deprecated public void setBlockTraversal(boolean blockTraversal) { - this.blockTraversal = blockTraversal; + this.pathTraversalBlockMode = blockTraversal ? PathTraversalBlockMode.NORMAL : PathTraversalBlockMode.NO_BLOCK; } } diff --git a/web/src/test/groovy/org/apache/shiro/web/filter/InvalidRequestFilterTest.groovy b/web/src/test/groovy/org/apache/shiro/web/filter/InvalidRequestFilterTest.groovy index fc61a7230f..19f7ab1a57 100644 --- a/web/src/test/groovy/org/apache/shiro/web/filter/InvalidRequestFilterTest.groovy +++ b/web/src/test/groovy/org/apache/shiro/web/filter/InvalidRequestFilterTest.groovy @@ -37,7 +37,7 @@ class InvalidRequestFilterTest { assertThat "filter.blockBackslash expected to be true", filter.isBlockBackslash() assertThat "filter.blockNonAscii expected to be true", filter.isBlockNonAscii() assertThat "filter.blockSemicolon expected to be true", filter.isBlockSemicolon() - assertThat "filter.blockTraversal expected to be true", filter.isBlockTraversal() + assertThat "filter.blockTraversal expected to be NORMAL", filter.isBlockTraversalNormal() } @Test @@ -75,29 +75,63 @@ class InvalidRequestFilterTest { } @Test - void testBlocksTraversal() { + void testBlocksTraversalNormal() { InvalidRequestFilter filter = new InvalidRequestFilter() assertPathBlocked(filter, "/something/../") assertPathBlocked(filter, "/something/../bar") assertPathBlocked(filter, "/something/../bar/") - assertPathBlocked(filter, "/something/%2e%2E/bar/") assertPathBlocked(filter, "/something/..") assertPathBlocked(filter, "/..") assertPathBlocked(filter, "..") assertPathBlocked(filter, "../") - assertPathBlocked(filter, "%2E./") - assertPathBlocked(filter, "%2F./") assertPathBlocked(filter, "/something/./") assertPathBlocked(filter, "/something/./bar") assertPathBlocked(filter, "/something/\u002e/bar") assertPathBlocked(filter, "/something/./bar/") - assertPathBlocked(filter, "/something/%2e/bar/") - assertPathBlocked(filter, "/something/%2f/bar/") assertPathBlocked(filter, "/something/.") assertPathBlocked(filter, "/.") assertPathBlocked(filter, "/something/../something/.") assertPathBlocked(filter, "/something/../something/.") + + assertPathAllowed(filter, "%2E./") + assertPathAllowed(filter, "%2F./") + assertPathAllowed(filter, "/something/%2e/bar/") + assertPathAllowed(filter, "/something/%2f/bar/") + assertPathAllowed(filter, "/something/http:%2f%2fmydomain.example.com%2foidc/bar/") + assertPathAllowed(filter, "/something/http:%2f%2fmydomain.example.com%2foidc/bar/") + assertPathAllowed(filter, "/something/%2e%2E/bar/") + assertPathAllowed(filter, "/something/http:%2f%2fmydomain%2eexample%2ecom%2foidc/bar/") } + + @Test + void testBlocksTraversalStrict() { + InvalidRequestFilter filter = new InvalidRequestFilter() + filter.setPathTraversalBlockMode(InvalidRequestFilter.PathTraversalBlockMode.STRICT) + assertPathBlocked(filter, "/something/../") + assertPathBlocked(filter, "/something/../bar") + assertPathBlocked(filter, "/something/../bar/") + assertPathBlocked(filter, "/something/..") + assertPathBlocked(filter, "/..") + assertPathBlocked(filter, "..") + assertPathBlocked(filter, "../") + assertPathBlocked(filter, "/something/./") + assertPathBlocked(filter, "/something/./bar") + assertPathBlocked(filter, "/something/\u002e/bar") + assertPathBlocked(filter, "/something/./bar/") + assertPathBlocked(filter, "/something/.") + assertPathBlocked(filter, "/.") + assertPathBlocked(filter, "/something/../something/.") + assertPathBlocked(filter, "/something/../something/.") + + assertPathBlocked(filter, "%2E./") + assertPathBlocked(filter, "%2F./") + assertPathBlocked(filter, "/something/%2e/bar/") + assertPathBlocked(filter, "/something/%2f/bar/") + assertPathBlocked(filter, "/something/http:%2f%2fmydomain.example.com%2foidc/bar/") + assertPathBlocked(filter, "/something/http:%2f%2fmydomain.example.com%2foidc/bar/") + assertPathBlocked(filter, "/something/%2e%2E/bar/") + assertPathBlocked(filter, "/something/http:%2f%2fmydomain%2eexample%2ecom%2foidc/bar/") + } @Test void testFilterAllowsBackslash() { @@ -149,7 +183,7 @@ class InvalidRequestFilterTest { @Test void testAllowTraversal() { InvalidRequestFilter filter = new InvalidRequestFilter() - filter.setBlockTraversal(false) + filter.setPathTraversalBlockMode(InvalidRequestFilter.PathTraversalBlockMode.NO_BLOCK); assertPathAllowed(filter, "/something/../") assertPathAllowed(filter, "/something/../bar") @@ -158,18 +192,23 @@ class InvalidRequestFilterTest { assertPathAllowed(filter, "/..") assertPathAllowed(filter, "..") assertPathAllowed(filter, "../") - assertPathAllowed(filter, "%2E./") - assertPathAllowed(filter, "%2F./") assertPathAllowed(filter, "/something/./") assertPathAllowed(filter, "/something/./bar") assertPathAllowed(filter, "/something/\u002e/bar") assertPathAllowed(filter, "/something\u002fbar") assertPathAllowed(filter, "/something/./bar/") - assertPathAllowed(filter, "/something/%2e/bar/") - assertPathAllowed(filter, "/something/%2f/bar/") assertPathAllowed(filter, "/something/.") assertPathAllowed(filter, "/.") assertPathAllowed(filter, "/something/../something/.") + + assertPathAllowed(filter, "%2E./") + assertPathAllowed(filter, "%2F./") + assertPathAllowed(filter, "/something/%2e/bar/") + assertPathAllowed(filter, "/something/%2f/bar/") + assertPathAllowed(filter, "/something/http:%2f%2fmydomain.example.com%2foidc/bar/") + assertPathAllowed(filter, "/something/http:%2f%2fmydomain.example.com%2foidc/bar/") + assertPathAllowed(filter, "/something/%2e%2E/bar/") + assertPathAllowed(filter, "/something/http:%2f%2fmydomain%2eexample%2ecom%2foidc/bar/") } static void assertPathBlocked(InvalidRequestFilter filter, String requestUri, String servletPath = requestUri, String pathInfo = null) {