diff --git a/src/main/java/bootstrap/WebServer.java b/src/main/java/bootstrap/WebServer.java index 9c9926c84..95f524e1f 100644 --- a/src/main/java/bootstrap/WebServer.java +++ b/src/main/java/bootstrap/WebServer.java @@ -5,7 +5,9 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import config.DependencyLoader; +import config.AppConfig; +import config.FilterConfig; +import config.SecurityConfig; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import web.dispatch.ConnectionHandler; @@ -13,7 +15,9 @@ public class WebServer { private static final Logger logger = LoggerFactory.getLogger(WebServer.class); private static final int DEFAULT_PORT = 8080; - private static final DependencyLoader LOADER = new DependencyLoader(); + private static final AppConfig LOADER = new AppConfig(); + private static final SecurityConfig securityConfig = new SecurityConfig(); + private static final FilterConfig filterConfig = new FilterConfig(); private static final ExecutorService executor = Executors.newFixedThreadPool(32); public static void main(String args[]) throws Exception { @@ -23,6 +27,7 @@ public static void main(String args[]) throws Exception { } else { port = Integer.parseInt(args[0]); } + config(); // 서버소켓을 생성한다. 웹서버는 기본적으로 8080번 포트를 사용한다. try (ServerSocket listenSocket = new ServerSocket(port)) { @@ -33,14 +38,20 @@ public static void main(String args[]) throws Exception { while ((connection = listenSocket.accept()) != null) { Socket singleConnection = connection; executor.submit(() -> { - ConnectionHandler connectionHandler = new ConnectionHandler(LOADER.dispatcher, - LOADER.exceptionHandlerMapping, - LOADER.httpResponseConverter, - LOADER.httpRequestConverter, + ConnectionHandler connectionHandler = new ConnectionHandler( + LOADER.filterChainContainer(), + LOADER.exceptionHandlerMapping(), + LOADER.httpResponseConverter(), + LOADER.httpRequestConverter(), singleConnection); connectionHandler.run(); }); } } } + + private static void config(){ + securityConfig.config(); + filterConfig.config(); + } } diff --git a/src/main/java/config/AppConfig.java b/src/main/java/config/AppConfig.java index 1541f42b7..540260975 100644 --- a/src/main/java/config/AppConfig.java +++ b/src/main/java/config/AppConfig.java @@ -18,6 +18,9 @@ import web.dispatch.argument.ArgumentResolver; import web.dispatch.argument.resolver.HttpRequestResolver; import web.dispatch.argument.resolver.QueryParamsResolver; +import web.filter.AccessLogFilter; +import web.filter.FilterChainContainer; +import web.filter.RestrictedFilter; import web.handler.StaticContentHandler; import web.handler.WebHandler; import web.renderer.HttpResponseRenderer; @@ -209,4 +212,20 @@ public ErrorExceptionHandler errorExceptionHandler() { ErrorExceptionHandler::new ); } + + /** + * ===== Filter ===== + */ + public FilterChainContainer filterChainContainer(){ + return getOrCreate("filterChainContainer", + () -> new FilterChainContainer(dispatcher())); + } + + public AccessLogFilter accessLogFilter(){ + return getOrCreate("accessLogFilter", AccessLogFilter::new); + } + + public RestrictedFilter restrictedFilter(){ + return getOrCreate("restrictedFilter", RestrictedFilter::new); + } } diff --git a/src/main/java/config/DependencyLoader.java b/src/main/java/config/DependencyLoader.java deleted file mode 100644 index 5db9b30bf..000000000 --- a/src/main/java/config/DependencyLoader.java +++ /dev/null @@ -1,23 +0,0 @@ -package config; - -import exception.ExceptionHandlerMapping; -import http.request.HttpRequestConverter; -import http.response.HttpResponseConverter; -import web.dispatch.Dispatcher; - -public class DependencyLoader { - private final AppConfig appConfig; - - public final HttpRequestConverter httpRequestConverter; - public final HttpResponseConverter httpResponseConverter; - public final ExceptionHandlerMapping exceptionHandlerMapping; - public final Dispatcher dispatcher; - - public DependencyLoader(){ - this.appConfig = new AppConfig(); - this.httpRequestConverter = appConfig.httpRequestConverter(); - this.httpResponseConverter = appConfig.httpResponseConverter(); - this.exceptionHandlerMapping = appConfig.exceptionHandlerMapping(); - this.dispatcher = appConfig.dispatcher(); - } -} diff --git a/src/main/java/config/FilterConfig.java b/src/main/java/config/FilterConfig.java new file mode 100644 index 000000000..4f7005056 --- /dev/null +++ b/src/main/java/config/FilterConfig.java @@ -0,0 +1,56 @@ +package config; + +import exception.ErrorException; +import web.filter.ServletFilter; + +import java.util.ArrayList; +import java.util.List; + +public class FilterConfig extends SingletonContainer { + private final AppConfig appConfig = new AppConfig(); + private int callCount = 0; + + public void config(){ + if(callCount>0) throw new ErrorException("FilterConfig::set: Duplicated call"); + setFilterChains(); + callCount++; + } + + private void setFilterChains(){ + appConfig.filterChainContainer() + .addFilterList(FilterType.ALL, getFilterListByAuthorityType(FilterType.ALL)) + .addFilterList(FilterType.PUBLIC, getFilterListByAuthorityType(FilterType.PUBLIC)) + .addFilterList(FilterType.AUTHENTICATED, getFilterListByAuthorityType(FilterType.AUTHENTICATED)) + .addFilterList(FilterType.RESTRICT, getFilterListByAuthorityType(FilterType.RESTRICT)); + } + + private List commonFrontFilter(){ + return getOrCreate("commonFrontFilter", + () -> List.of( + appConfig.accessLogFilter() + )); + } + + private List commonBackFilter(){ + return getOrCreate("commonBackFilter", + () -> List.of()); + } + + private List getFilterListByAuthorityType(FilterType type) { + List servletFilterList = new ArrayList<>(); + servletFilterList.addAll(commonFrontFilter()); + servletFilterList.addAll(authorizedFilterList(type)); + servletFilterList.addAll(commonBackFilter()); + return servletFilterList; + } + + private List authorizedFilterList(FilterType type) { + return switch (type) { + case ALL -> List.of(); + case PUBLIC -> List.of(); + case AUTHENTICATED -> List.of(); + case RESTRICT -> List.of(appConfig.restrictedFilter()); + case LOG_IN -> List.of(); + }; + } +} diff --git a/src/main/java/config/FilterType.java b/src/main/java/config/FilterType.java new file mode 100644 index 000000000..49c433b25 --- /dev/null +++ b/src/main/java/config/FilterType.java @@ -0,0 +1,9 @@ +package config; + +public enum FilterType { + ALL, + PUBLIC, + AUTHENTICATED, + RESTRICT, + LOG_IN +} diff --git a/src/main/java/config/SecurityConfig.java b/src/main/java/config/SecurityConfig.java new file mode 100644 index 000000000..514218be2 --- /dev/null +++ b/src/main/java/config/SecurityConfig.java @@ -0,0 +1,23 @@ +package config; + +import exception.ErrorException; + +public class SecurityConfig extends SingletonContainer { + private final AppConfig appConfig = new AppConfig(); + private int callCount; + + public void config(){ + if(callCount>0) throw new ErrorException("SecurityConfig::setPaths: Duplicated call"); + setPaths(); + callCount++; + } + + public void setPaths(){ + appConfig.filterChainContainer() + .addPath(FilterType.AUTHENTICATED, "/mypage/**") + .addPath(FilterType.ALL, "/user/**") + .addPath(FilterType.PUBLIC, "/**"); + } + + +} diff --git a/src/main/java/http/HttpStatus.java b/src/main/java/http/HttpStatus.java index af8b95af0..78836466f 100644 --- a/src/main/java/http/HttpStatus.java +++ b/src/main/java/http/HttpStatus.java @@ -1,6 +1,8 @@ package http; public enum HttpStatus { + NONE(0), + OK(200), CREATED(201), ACCEPTED(202), diff --git a/src/main/java/http/request/HttpRequest.java b/src/main/java/http/request/HttpRequest.java index 824551de3..cf0c025bf 100644 --- a/src/main/java/http/request/HttpRequest.java +++ b/src/main/java/http/request/HttpRequest.java @@ -7,10 +7,7 @@ import java.net.InetAddress; import java.net.URI; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; +import java.util.*; public class HttpRequest { private final HttpMethod method; @@ -20,6 +17,7 @@ public class HttpRequest { private final URI uri; private String contentType; private byte[] body; + private UUID rid; private InetAddress requestAddress; @@ -102,4 +100,14 @@ public InetAddress getRequestAddress() { public void setBody(byte[] body){ this.body = body; } + + public String getOrGenerateRid(){ + if(this.rid == null) + this.rid = UUID.randomUUID(); + return this.rid.toString(); + } + + public UUID getRid() { + return rid; + } } diff --git a/src/main/java/http/response/HttpResponse.java b/src/main/java/http/response/HttpResponse.java index 9e11f9620..cd088ab0f 100644 --- a/src/main/java/http/response/HttpResponse.java +++ b/src/main/java/http/response/HttpResponse.java @@ -12,8 +12,8 @@ import java.util.Map; public class HttpResponse { - private final HttpStatus status; private final Map> headers; + private HttpStatus status; private byte[] body; private HttpResponse (HttpStatus status){ @@ -28,6 +28,14 @@ public static HttpResponse of (HttpStatus status){ return new HttpResponse(status); } + public static HttpResponse of (){ + return new HttpResponse(HttpStatus.NONE); + } + + public void setStatus(HttpStatus status) { + this.status = status; + } + public HttpStatus getStatus() { return status; } diff --git a/src/main/java/web/dispatch/ConnectionHandler.java b/src/main/java/web/dispatch/ConnectionHandler.java index 0ea2ae576..c29558c8b 100644 --- a/src/main/java/web/dispatch/ConnectionHandler.java +++ b/src/main/java/web/dispatch/ConnectionHandler.java @@ -5,22 +5,23 @@ import http.request.HttpRequestConverter; import http.request.HttpRequest; import http.response.HttpResponse; +import web.filter.FilterChainContainer; import java.net.Socket; public class ConnectionHandler implements Runnable{ private final Socket connection; + private final FilterChainContainer filterChainContainer; private final HttpRequestConverter requestConverter; private final HttpResponseConverter responseConverter; private final ExceptionHandlerMapping exceptionHandlerMapping; - private final Dispatcher dispatcher; - public ConnectionHandler(Dispatcher dispatcher, + public ConnectionHandler(FilterChainContainer filterChainContainer, ExceptionHandlerMapping exceptionHandlerMapping, HttpResponseConverter responseConverter, HttpRequestConverter requestConverter, Socket connection) { - this.dispatcher = dispatcher; + this.filterChainContainer = filterChainContainer; this.exceptionHandlerMapping = exceptionHandlerMapping; this.responseConverter = responseConverter; this.requestConverter = requestConverter; @@ -30,22 +31,23 @@ public ConnectionHandler(Dispatcher dispatcher, @Override public void run() { + HttpResponse response = HttpResponse.of(); try { HttpRequest request = requestConverter.parseRequest(connection); - HttpResponse response = dispatcher.handle(request); + filterChainContainer.runFilterChain(request,response); responseConverter.sendResponse(response, connection); - } catch (Exception e){ + } catch (Throwable t){ /** * TODO: * ExceptionHandler 또한 HttpResponse를 반환하게 하고 * finally에 `responseConverter.sendResponse(response, connection);` 를 넣어 * socket에 write를 하는 포인트를 단일 포인트로 관리 */ - exceptionHandlerMapping.handle(e, connection); + exceptionHandlerMapping.handle(t, connection); } finally { - try { connection.close(); } catch (Exception ignore) {} + try { connection.close(); } catch (Throwable ignore) {} } } } diff --git a/src/main/java/web/dispatch/Dispatcher.java b/src/main/java/web/dispatch/Dispatcher.java index 0d55440b0..1a6c136ae 100644 --- a/src/main/java/web/dispatch/Dispatcher.java +++ b/src/main/java/web/dispatch/Dispatcher.java @@ -28,7 +28,7 @@ public Dispatcher(List handlerMapping, List adapterL handlerMapping.forEach(hm -> this.handlerMapping.get(hm.getMethod()).add(hm)); } - public HttpResponse handle(HttpRequest request){ + public HttpResponse handle(HttpRequest request, HttpResponse response){ logger.debug("{}: {} - {} from {}", request.getMethod(), request.getPath(), request.getQueryString(), request.getRequestAddress()); @@ -41,11 +41,11 @@ public HttpResponse handle(HttpRequest request){ HandlerAdapter adapter = adapterList.stream().filter(ha -> ha.support(handler)) .findFirst().orElseThrow(() -> new ErrorException("DispatcherError: No adapter matched")); - HandlerResponse response = adapter.handle(request, handler); + HandlerResponse handlerResponse = adapter.handle(request, handler); HttpResponseRenderer responseHandler = responseHandlerList.stream() - .filter(rh -> rh.supports(response)) - .findFirst().orElseThrow(()-> new ErrorException("Post handler not exists")); - return responseHandler.handle(response); + .filter(rh -> rh.supports(handlerResponse)) + .findFirst().orElseThrow(()-> new ErrorException("Renderer not exists")); + return responseHandler.handle(response, handlerResponse); } } diff --git a/src/main/java/web/filter/AccessLogFilter.java b/src/main/java/web/filter/AccessLogFilter.java new file mode 100644 index 000000000..2cd4e2922 --- /dev/null +++ b/src/main/java/web/filter/AccessLogFilter.java @@ -0,0 +1,21 @@ +package web.filter; + +import http.request.HttpRequest; +import http.response.HttpResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class AccessLogFilter implements ServletFilter { + + private static final Logger log = LoggerFactory.getLogger(AccessLogFilter.class); + + @Override + public void runFilter(HttpRequest request, HttpResponse response, FilterChainContainer.FilterChainEngine chain) { + chain.doFilter(); + log.info("rid-{}: {} {} from {}", + request.getOrGenerateRid(), + request.getMethod(), + request.getPath(), + request.getRequestAddress()); + } +} diff --git a/src/main/java/web/filter/FilterChainContainer.java b/src/main/java/web/filter/FilterChainContainer.java new file mode 100644 index 000000000..826a9fc3c --- /dev/null +++ b/src/main/java/web/filter/FilterChainContainer.java @@ -0,0 +1,131 @@ +package web.filter; + +import config.FilterType; +import exception.ErrorCode; +import exception.ErrorException; +import exception.ServiceException; +import http.request.HttpRequest; +import http.response.HttpResponse; +import web.dispatch.Dispatcher; + +import java.util.*; + +public class FilterChainContainer { + private final List registeredPaths; + private final Map> filterChainMap; + private final Dispatcher dispatcher; + + public FilterChainContainer(Dispatcher dispatcher) { + this.registeredPaths = new ArrayList<>(); + this.filterChainMap = new HashMap<>(); + this.dispatcher = dispatcher; + } + + public void runFilterChain(HttpRequest request, HttpResponse response) { + String requestedPath = request.getPath(); + + FilterTypePaths matched = findMatchedChain(requestedPath).orElseThrow(()-> new ServiceException(ErrorCode.FORBIDDEN)); + List filters = filterChainMap.get(matched.type); + + FilterChainEngine engine = new FilterChainEngine(dispatcher, request, response, filters); + engine.doFilter(); + } + + public FilterChainContainer addFilterList(FilterType type, List filterList) { + if(filterChainMap.containsKey(type)) + throw new ErrorException("FilterChain Construction: Duplicate filter list per type"); + filterChainMap.put(type, filterList); + return this; + } + + public FilterChainContainer addPaths(FilterType type, List paths) { + registeredPaths.add(FilterTypePaths.of(type, paths)); + return this; + } + + public FilterChainContainer addPath(FilterType type, String path) { + registeredPaths.add(FilterTypePaths.of(type, List.of(path))); + return this; + } + + private Optional findMatchedChain(String requestedPath) { + for (FilterTypePaths filterTypePaths : registeredPaths) { + for (String path : filterTypePaths.paths) { + if (isMatched(requestedPath, path)) { + return Optional.of(filterTypePaths); + } + } + } + return Optional.empty(); + } + + private boolean isMatched(String requestedPath, String filterPath) { + if (filterPath == null || filterPath.isBlank()) return false; + if (requestedPath == null) return false; + + if (!requestedPath.startsWith("/")) requestedPath = "/" + requestedPath; + if (!filterPath.startsWith("/")) filterPath = "/" + filterPath; + + if (!filterPath.contains("*")) { + return requestedPath.equals(filterPath); + } + + if (filterPath.endsWith("/**")) { + String prefix = filterPath.substring(0, filterPath.length() - 3); + if (prefix.isEmpty()) return true; + return requestedPath.equals(prefix) || requestedPath.startsWith(prefix + "/"); + } + + if (filterPath.endsWith("/*")) { + String prefix = filterPath.substring(0, filterPath.length() - 2); + if (!(requestedPath.equals(prefix) || requestedPath.startsWith(prefix + "/"))) return false; + + String rest = requestedPath.substring(prefix.length()); + if (rest.isEmpty()) return false; + if (!rest.startsWith("/")) return false; + + String afterSlash = rest.substring(1); + return !afterSlash.isEmpty() && !afterSlash.contains("/"); + } + + return false; + } + + private static class FilterTypePaths { + private final FilterType type; + private final List paths; + + public FilterTypePaths(FilterType type, List paths) { + this.type = type; + this.paths = paths; + } + + public static FilterTypePaths of(FilterType type, List paths) { + return new FilterTypePaths(type, paths); + } + } + + public static class FilterChainEngine { + private final Dispatcher dispatcher; + private final HttpRequest request; + private final HttpResponse response; + private final List filterList; + private int position = 0; + + public FilterChainEngine(Dispatcher dispatcher, HttpRequest request, HttpResponse response, List filterList) { + this.dispatcher = dispatcher; + this.request = request; + this.response = response; + this.filterList = filterList; + } + + public void doFilter() { + if (position >= filterList.size()) { + dispatcher.handle(request, response); + return; + } + ServletFilter next = filterList.get(position++); + next.runFilter(request, response, this); + } + } +} diff --git a/src/main/java/web/filter/RestrictedFilter.java b/src/main/java/web/filter/RestrictedFilter.java new file mode 100644 index 000000000..ec5b07420 --- /dev/null +++ b/src/main/java/web/filter/RestrictedFilter.java @@ -0,0 +1,18 @@ +package web.filter; + +import exception.ErrorCode; +import exception.ServiceException; +import http.request.HttpRequest; +import http.response.HttpResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RestrictedFilter implements ServletFilter{ + private static final Logger log = LoggerFactory.getLogger(RestrictedFilter.class); + + @Override + public void runFilter(HttpRequest request, HttpResponse response, FilterChainContainer.FilterChainEngine chain) { + log.info("rid:{} - Request to restricted path:{}", request.getRid(), request.getPath()); + throw new ServiceException(ErrorCode.FORBIDDEN); + } +} diff --git a/src/main/java/web/filter/ServletFilter.java b/src/main/java/web/filter/ServletFilter.java new file mode 100644 index 000000000..0b5ba3641 --- /dev/null +++ b/src/main/java/web/filter/ServletFilter.java @@ -0,0 +1,8 @@ +package web.filter; + +import http.request.HttpRequest; +import http.response.HttpResponse; + +public interface ServletFilter { + void runFilter(HttpRequest request, HttpResponse response, FilterChainContainer.FilterChainEngine chain); +} diff --git a/src/main/java/web/renderer/HttpResponseRenderer.java b/src/main/java/web/renderer/HttpResponseRenderer.java index 5ff61fe80..dd6926827 100644 --- a/src/main/java/web/renderer/HttpResponseRenderer.java +++ b/src/main/java/web/renderer/HttpResponseRenderer.java @@ -5,5 +5,5 @@ public interface HttpResponseRenderer { boolean supports(HandlerResponse response); - HttpResponse handle(HandlerResponse response); + HttpResponse handle(HttpResponse httpResponse, HandlerResponse handlerResponse); } diff --git a/src/main/java/web/renderer/StaticViewRenderer.java b/src/main/java/web/renderer/StaticViewRenderer.java index d61dc0d01..299dba15b 100644 --- a/src/main/java/web/renderer/StaticViewRenderer.java +++ b/src/main/java/web/renderer/StaticViewRenderer.java @@ -2,6 +2,7 @@ import config.VariableConfig; import exception.ErrorException; +import http.HttpStatus; import http.response.HttpResponse; import web.response.HandlerResponse; import web.response.StaticViewResponse; @@ -18,7 +19,7 @@ public boolean supports(HandlerResponse response) { } @Override - public HttpResponse handle(HandlerResponse handlerResponse) { + public HttpResponse handle(HttpResponse httpResponse, HandlerResponse handlerResponse) { StaticViewResponse staticResponse = (StaticViewResponse) handlerResponse; String path = staticResponse.getPath(); @@ -28,7 +29,7 @@ public HttpResponse handle(HandlerResponse handlerResponse) { try (BufferedInputStream in = new BufferedInputStream(new FileInputStream(file))) { byte[] body = in.readAllBytes(); - HttpResponse httpResponse = HttpResponse.of(handlerResponse.getStatus()); + httpResponse.setStatus(handlerResponse.getStatus()); httpResponse.setBody(file, body); handlerResponse.getCookies().forEach(cookie->httpResponse.addHeader("Set-Cookie", cookie)); return httpResponse;