Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cmd/api/handlers/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (h *HandlersApi) SettingsServiceHandler(w http.ResponseWriter, r *http.Requ
return
}
// Make sure service is valid
if !h.Settings.VerifyType(service) {
if !h.Settings.VerifyService(service) {
apiErrorResponse(w, "invalid service", http.StatusInternalServerError, nil)
return
}
Expand Down Expand Up @@ -84,7 +84,7 @@ func (h *HandlersApi) SettingsServiceEnvHandler(w http.ResponseWriter, r *http.R
return
}
// Make sure service is valid
if !h.Settings.VerifyType(service) {
if !h.Settings.VerifyService(service) {
apiErrorResponse(w, "invalid service", http.StatusInternalServerError, nil)
return
}
Expand Down Expand Up @@ -135,7 +135,7 @@ func (h *HandlersApi) SettingsServiceJSONHandler(w http.ResponseWriter, r *http.
return
}
// Make sure service is valid
if !h.Settings.VerifyType(service) {
if !h.Settings.VerifyService(service) {
apiErrorResponse(w, "invalid service", http.StatusInternalServerError, nil)
return
}
Expand Down Expand Up @@ -170,7 +170,7 @@ func (h *HandlersApi) SettingsServiceEnvJSONHandler(w http.ResponseWriter, r *ht
return
}
// Make sure service is valid
if !h.Settings.VerifyType(service) {
if !h.Settings.VerifyService(service) {
apiErrorResponse(w, "invalid service", http.StatusInternalServerError, nil)
return
}
Expand Down
7 changes: 4 additions & 3 deletions cmd/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,9 @@ func osctrlAPIService() {
"POST "+_apiPath(apiLoginPath)+"/{env}",
handlerAuthCheck(http.HandlerFunc(handlersApi.LoginHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret))
// ///////////////////////// AUTHENTICATED
// API: check status
muxAPI.HandleFunc("GET "+_apiPath(checksAuthPath), handlersApi.CheckHandlerAuth)
// API: check auth
muxAPI.Handle(
"GET "+_apiPath(checksAuthPath), handlerAuthCheck(http.HandlerFunc(handlersApi.CheckHandlerAuth), flagParams.Service.Auth, flagParams.JWT.JWTSecret))
// API: nodes by environment
muxAPI.Handle(
"GET "+_apiPath(apiNodesPath)+"/{env}/all",
Expand Down Expand Up @@ -322,7 +323,7 @@ func osctrlAPIService() {
if flagParams.Osquery.Carve {
muxAPI.Handle(
"GET "+_apiPath(apiCarvesPath)+"/{env}",
handlerAuthCheck(http.HandlerFunc(handlersApi.CarveShowHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret))
handlerAuthCheck(http.HandlerFunc(handlersApi.CarveListHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret))
muxAPI.Handle(
"GET "+_apiPath(apiCarvesPath)+"/{env}/queries/{target}",
handlerAuthCheck(http.HandlerFunc(handlersApi.CarveQueriesHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret))
Expand Down
57 changes: 48 additions & 9 deletions tools/api_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import json
import argparse
import requests
from typing import Optional, Dict, Any, Tuple
from typing import Optional, Dict, Any, Tuple, Union, List
from urllib.parse import urljoin

# Disable SSL warnings if insecure flag is used
Expand Down Expand Up @@ -77,9 +77,20 @@ def log_verbose(self, message: str):
def make_request(self, method: str, endpoint: str,
headers: Optional[Dict] = None,
data: Optional[Dict] = None,
expected_status: Optional[int] = None) -> Tuple[bool, Optional[requests.Response], str]:
expected_status: Optional[Union[int, List[int]]] = None) -> Tuple[bool, Optional[requests.Response], str]:
"""
Make HTTP request and return (success, response, message)

Args:
method: HTTP method (GET, POST, etc.)
endpoint: API endpoint path
headers: Optional HTTP headers
data: Optional request data (for POST requests)
expected_status: Expected HTTP status code(s). Can be a single int or list of ints.
If None, any status code is considered valid.

Returns:
Tuple of (success, response, message)
"""
url = urljoin(self.base_url, endpoint)
if headers is None:
Expand All @@ -106,9 +117,19 @@ def make_request(self, method: str, endpoint: str,
success = True
message = f"HTTP {response.status_code}"

if expected_status and response.status_code != expected_status:
success = False
message = f"Expected {expected_status}, got {response.status_code}"
# Normalize expected_status to a list for easier handling
if expected_status is not None:
if isinstance(expected_status, int):
expected_statuses = [expected_status]
else:
expected_statuses = expected_status

if response.status_code not in expected_statuses:
success = False
if len(expected_statuses) == 1:
message = f"Expected {expected_statuses[0]}, got {response.status_code}"
else:
message = f"Expected one of {expected_statuses}, got {response.status_code}"

if self.verbose:
try:
Expand All @@ -124,9 +145,22 @@ def make_request(self, method: str, endpoint: str,

def test(self, name: str, method: str, endpoint: str,
headers: Optional[Dict] = None, data: Optional[Dict] = None,
expected_status: int = 200, skip_if_no_token: bool = False) -> bool:
expected_status: Union[int, List[int]] = 200, skip_if_no_token: bool = False) -> bool:
"""
Run a single test and record results

Args:
name: Test name/description
method: HTTP method (GET, POST, etc.)
endpoint: API endpoint path
headers: Optional HTTP headers
data: Optional request data (for POST requests)
expected_status: Expected HTTP status code(s). Can be a single int or list of ints.
Defaults to 200. Test passes if response status matches any of the expected statuses.
skip_if_no_token: If True, skip test when no token is available

Returns:
True if test passed, False otherwise
"""
# Print request type and URI
full_url = urljoin(self.base_url, endpoint)
Expand Down Expand Up @@ -289,8 +323,7 @@ def run_all_tests(self, skip_auth: bool = False):
self.test("Root endpoint", "GET", "/")
self.test("Health check", "GET", "/health")
self.test("Check (no auth)", "GET", f"{API_PREFIX}/checks-no-auth")
self.test("Check (auth)", "GET", f"{API_PREFIX}/checks-auth",
skip_if_no_token=True)
self.test("Check (auth)", "GET", f"{API_PREFIX}/checks-auth", skip_if_no_token=True)
print()

# Environments
Expand Down Expand Up @@ -322,17 +355,21 @@ def run_all_tests(self, skip_auth: bool = False):
if self.env_uuid:
self.test("Get all nodes", "GET",
f"{API_PREFIX}/nodes/{self.env_uuid}/all",
expected_status=[200, 404],
skip_if_no_token=True)
self.test("Get active nodes", "GET",
f"{API_PREFIX}/nodes/{self.env_uuid}/active",
expected_status=[200, 404],
skip_if_no_token=True)
self.test("Get inactive nodes", "GET",
f"{API_PREFIX}/nodes/{self.env_uuid}/inactive",
expected_status=[200, 404],
skip_if_no_token=True)
# Note: These require actual node identifiers, so they may fail
self.test("Lookup node (test)", "POST",
f"{API_PREFIX}/nodes/lookup",
data={"target": "test-node-identifier"},
data={"identifier": "test-node-identifier"},
expected_status=[200, 404],
skip_if_no_token=True)
print()

Expand Down Expand Up @@ -391,9 +428,11 @@ def run_all_tests(self, skip_auth: bool = False):
if self.env_uuid:
self.test("Get carves", "GET",
f"{API_PREFIX}/carves/{self.env_uuid}",
expected_status=[200, 404],
skip_if_no_token=True)
self.test("List carves", "GET",
f"{API_PREFIX}/carves/{self.env_uuid}/list",
expected_status=[200, 404],
skip_if_no_token=True)
print()

Expand Down
Loading