diff --git a/cmd/api/handlers/settings.go b/cmd/api/handlers/settings.go index 0fd942ec..5c9569f1 100644 --- a/cmd/api/handlers/settings.go +++ b/cmd/api/handlers/settings.go @@ -49,7 +49,7 @@ func (h *HandlersApi) SettingsServiceHandler(w http.ResponseWriter, r *http.Requ return } // Make sure service is valid - if !h.Settings.VerifyType(service) { + if !h.Settings.VerifyService(service) { apiErrorResponse(w, "invalid service", http.StatusInternalServerError, nil) return } @@ -84,7 +84,7 @@ func (h *HandlersApi) SettingsServiceEnvHandler(w http.ResponseWriter, r *http.R return } // Make sure service is valid - if !h.Settings.VerifyType(service) { + if !h.Settings.VerifyService(service) { apiErrorResponse(w, "invalid service", http.StatusInternalServerError, nil) return } @@ -135,7 +135,7 @@ func (h *HandlersApi) SettingsServiceJSONHandler(w http.ResponseWriter, r *http. return } // Make sure service is valid - if !h.Settings.VerifyType(service) { + if !h.Settings.VerifyService(service) { apiErrorResponse(w, "invalid service", http.StatusInternalServerError, nil) return } @@ -170,7 +170,7 @@ func (h *HandlersApi) SettingsServiceEnvJSONHandler(w http.ResponseWriter, r *ht return } // Make sure service is valid - if !h.Settings.VerifyType(service) { + if !h.Settings.VerifyService(service) { apiErrorResponse(w, "invalid service", http.StatusInternalServerError, nil) return } diff --git a/cmd/api/main.go b/cmd/api/main.go index ed69c274..d31b94de 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -270,8 +270,9 @@ func osctrlAPIService() { "POST "+_apiPath(apiLoginPath)+"/{env}", handlerAuthCheck(http.HandlerFunc(handlersApi.LoginHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) // ///////////////////////// AUTHENTICATED - // API: check status - muxAPI.HandleFunc("GET "+_apiPath(checksAuthPath), handlersApi.CheckHandlerAuth) + // API: check auth + muxAPI.Handle( + "GET "+_apiPath(checksAuthPath), handlerAuthCheck(http.HandlerFunc(handlersApi.CheckHandlerAuth), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) // API: nodes by environment muxAPI.Handle( "GET "+_apiPath(apiNodesPath)+"/{env}/all", @@ -322,7 +323,7 @@ func osctrlAPIService() { if flagParams.Osquery.Carve { muxAPI.Handle( "GET "+_apiPath(apiCarvesPath)+"/{env}", - handlerAuthCheck(http.HandlerFunc(handlersApi.CarveShowHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + handlerAuthCheck(http.HandlerFunc(handlersApi.CarveListHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) muxAPI.Handle( "GET "+_apiPath(apiCarvesPath)+"/{env}/queries/{target}", handlerAuthCheck(http.HandlerFunc(handlersApi.CarveQueriesHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) diff --git a/tools/api_tester.py b/tools/api_tester.py index 997d89b4..3daf0821 100755 --- a/tools/api_tester.py +++ b/tools/api_tester.py @@ -23,7 +23,7 @@ import json import argparse import requests -from typing import Optional, Dict, Any, Tuple +from typing import Optional, Dict, Any, Tuple, Union, List from urllib.parse import urljoin # Disable SSL warnings if insecure flag is used @@ -77,9 +77,20 @@ def log_verbose(self, message: str): def make_request(self, method: str, endpoint: str, headers: Optional[Dict] = None, data: Optional[Dict] = None, - expected_status: Optional[int] = None) -> Tuple[bool, Optional[requests.Response], str]: + expected_status: Optional[Union[int, List[int]]] = None) -> Tuple[bool, Optional[requests.Response], str]: """ Make HTTP request and return (success, response, message) + + Args: + method: HTTP method (GET, POST, etc.) + endpoint: API endpoint path + headers: Optional HTTP headers + data: Optional request data (for POST requests) + expected_status: Expected HTTP status code(s). Can be a single int or list of ints. + If None, any status code is considered valid. + + Returns: + Tuple of (success, response, message) """ url = urljoin(self.base_url, endpoint) if headers is None: @@ -106,9 +117,19 @@ def make_request(self, method: str, endpoint: str, success = True message = f"HTTP {response.status_code}" - if expected_status and response.status_code != expected_status: - success = False - message = f"Expected {expected_status}, got {response.status_code}" + # Normalize expected_status to a list for easier handling + if expected_status is not None: + if isinstance(expected_status, int): + expected_statuses = [expected_status] + else: + expected_statuses = expected_status + + if response.status_code not in expected_statuses: + success = False + if len(expected_statuses) == 1: + message = f"Expected {expected_statuses[0]}, got {response.status_code}" + else: + message = f"Expected one of {expected_statuses}, got {response.status_code}" if self.verbose: try: @@ -124,9 +145,22 @@ def make_request(self, method: str, endpoint: str, def test(self, name: str, method: str, endpoint: str, headers: Optional[Dict] = None, data: Optional[Dict] = None, - expected_status: int = 200, skip_if_no_token: bool = False) -> bool: + expected_status: Union[int, List[int]] = 200, skip_if_no_token: bool = False) -> bool: """ Run a single test and record results + + Args: + name: Test name/description + method: HTTP method (GET, POST, etc.) + endpoint: API endpoint path + headers: Optional HTTP headers + data: Optional request data (for POST requests) + expected_status: Expected HTTP status code(s). Can be a single int or list of ints. + Defaults to 200. Test passes if response status matches any of the expected statuses. + skip_if_no_token: If True, skip test when no token is available + + Returns: + True if test passed, False otherwise """ # Print request type and URI full_url = urljoin(self.base_url, endpoint) @@ -289,8 +323,7 @@ def run_all_tests(self, skip_auth: bool = False): self.test("Root endpoint", "GET", "/") self.test("Health check", "GET", "/health") self.test("Check (no auth)", "GET", f"{API_PREFIX}/checks-no-auth") - self.test("Check (auth)", "GET", f"{API_PREFIX}/checks-auth", - skip_if_no_token=True) + self.test("Check (auth)", "GET", f"{API_PREFIX}/checks-auth", skip_if_no_token=True) print() # Environments @@ -322,17 +355,21 @@ def run_all_tests(self, skip_auth: bool = False): if self.env_uuid: self.test("Get all nodes", "GET", f"{API_PREFIX}/nodes/{self.env_uuid}/all", + expected_status=[200, 404], skip_if_no_token=True) self.test("Get active nodes", "GET", f"{API_PREFIX}/nodes/{self.env_uuid}/active", + expected_status=[200, 404], skip_if_no_token=True) self.test("Get inactive nodes", "GET", f"{API_PREFIX}/nodes/{self.env_uuid}/inactive", + expected_status=[200, 404], skip_if_no_token=True) # Note: These require actual node identifiers, so they may fail self.test("Lookup node (test)", "POST", f"{API_PREFIX}/nodes/lookup", - data={"target": "test-node-identifier"}, + data={"identifier": "test-node-identifier"}, + expected_status=[200, 404], skip_if_no_token=True) print() @@ -391,9 +428,11 @@ def run_all_tests(self, skip_auth: bool = False): if self.env_uuid: self.test("Get carves", "GET", f"{API_PREFIX}/carves/{self.env_uuid}", + expected_status=[200, 404], skip_if_no_token=True) self.test("List carves", "GET", f"{API_PREFIX}/carves/{self.env_uuid}/list", + expected_status=[200, 404], skip_if_no_token=True) print()