diff --git a/services/press-fault-data-simulator-service/app/config/settings.py b/services/press-fault-data-simulator-service/app/config/settings.py index 8a4fce7..ad99853 100644 --- a/services/press-fault-data-simulator-service/app/config/settings.py +++ b/services/press-fault-data-simulator-service/app/config/settings.py @@ -22,6 +22,8 @@ class Settings(BaseSettings): # 모델 API 서버 설정 PRESS_FAULT_MODEL_BASE_URL: str = "http://127.0.0.1:8004" PREDICT_API_ENDPOINT: str = "/predict" + # Spring Boot 서버 설정 + SPRING_BOOT_BASE_URL: str = "http://localhost:8088" SIMULATOR_INTERVAL_MINUTES: int = 1 diff --git a/services/press-fault-data-simulator-service/app/services/azure_storage_service.py b/services/press-fault-data-simulator-service/app/services/azure_storage_service.py index ebc4ca1..1b70078 100644 --- a/services/press-fault-data-simulator-service/app/services/azure_storage_service.py +++ b/services/press-fault-data-simulator-service/app/services/azure_storage_service.py @@ -205,16 +205,29 @@ def _clear_cache(self): self.current_file_name = None self.current_row_index = 0 - async def close(self): + async def reset_connection(self): + """연결 상태 초기화 (재시작 시 사용)""" + try: + if hasattr(self, '_current_client'): + await self._current_client.close() + delattr(self, '_current_client') + + self.current_file_name = None + self.current_row_index = 0 + + system_log.info("Azure Storage 연결 상태 초기화 완료") + except Exception as e: + system_log.warning(f"연결 초기화 중 오류 (무시됨): {str(e)}") + + async def close(self): """연결 종료 (선택적)""" try: - if hasattr(self, "blob_service_client") and self.blob_service_client: - await self.blob_service_client.close() - system_log.info("Azure Storage 연결 종료 완료") + if hasattr(self, '_current_client'): + await self._current_client.close() + delattr(self, '_current_client') + system_log.info("Azure Storage 연결 종료 완료") except Exception as e: - system_log.error(f"Azure Storage 연결 종료 중 오류: {str(e)}") - finally: - self.is_connected = False + system_log.warning(f"연결 종료 중 오류 (무시됨): {str(e)}") azure_storage = AzureStorageService() diff --git a/services/press-fault-data-simulator-service/app/services/prediction_api_service.py b/services/press-fault-data-simulator-service/app/services/prediction_api_service.py deleted file mode 100644 index beb049d..0000000 --- a/services/press-fault-data-simulator-service/app/services/prediction_api_service.py +++ /dev/null @@ -1,142 +0,0 @@ -import time -from typing import Optional -import aiohttp -import asyncio - -from app.config.settings import settings -from app.models.data_models import PredictionRequest, PredictionResult -from app.utils.logger import system_log - - -class PredictAPIService: - """Press Fault Detection Model API 호출 서비스""" - - def __init__(self): - self.api_url = settings.PREDICTION_API_FULL_URL - self.timeout = 30 - self.max_retries = 3 - - system_log.info(f"Predict API Service 초기화 완료 - URL: {self.api_url}") - - async def call_predict_api( - self, request: PredictionRequest - ) -> Optional[PredictionResult]: - """ - /predict API를 호출하여 예측 결과를 받아옴 - Args: - request 예측 요청 데이터 (PredictionRequest) - Returns: - Optional[PredictionResult]: 예측 결과 또는 None (실패 시) - """ - - request_data = request.model_dump() - system_log.debug( - f"API 요청 데이터 크기: AI0({len(request_data['AI0_Vibration'])}) " - f"AI1({len(request_data['AI1_Vibration'])}) " - f"AI2({len(request_data['AI2_Current'])})" - ) - - for attempt in range(1, self.max_retries + 1): - try: - start_time = time.time() - - timeout = aiohttp.ClientTimeout(total=self.timeout) - async with aiohttp.ClientSession(timeout=timeout) as session: - # POST 요청 전송 - async with session.post( - url=self.api_url, - json=request_data, - headers={"Content-Type": "application/json"}, - ) as response: - - response_time = time.time() - start_time - - # 응답 상태 확인 - if response.status == 200: - response_json = await response.json() - try: - result = PredictionResult(**response_json) - system_log.info( - f"API 호출 성공 ({response_time:.3f}s) - " - f"Prediction: {result.prediction}, " - f"Is_fault: {result.is_fault}" - ) - - return result - except Exception as e: - system_log.error(f"응답 데이터 검증 실패: {str(e)}") - return None - else: - response_text = await response.text() - system_log.error( - f"API 호출 실패 (시도 {attempt}/{self.max_retries}) - " - f"Status: {response.status}, " - f"Response: {response_text[:200]}" - ) - # 마지막 시도가 아니면 재시도 - if attempt < self.max_retries: - await asyncio.sleep(2**attempt) - continue - else: - return None - - except asyncio.TimeoutError: - system_log.error( - f"API 호출 타임아웃 (시도 {attempt}/{self.max_retries}) - " - f"URL: {self.api_url}" - ) - if attempt < self.max_retries: - await asyncio.sleep(2**attempt) - continue - else: - return None - - except aiohttp.ClientConnectionError: - system_log.error( - f"API 서버 연결 실패 (시도 {attempt}/{self.max_retries}) - " - f"URL: {self.api_url}" - ) - if attempt < self.max_retries: - await asyncio.sleep(2**attempt) - continue - else: - return None - - except Exception as e: - system_log.error( - f"예상치 못한 오류 (시도 {attempt}/{self.max_retries}): {str(e)}" - ) - if attempt < self.max_retries: - await asyncio.sleep(2**attempt) - continue - else: - return None - return None - - async def health_check(self) -> bool: - """ - API 서버 상태 확인 - - Returns: - bool: 서버 상태 (True: 정상, False: 비정상) - """ - try: - base_url = str(settings.PRESS_FAULT_MODEL_BASE_URL) - health_url = f"{base_url}/health" - - timeout = aiohttp.ClientTimeout(total=10) - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.get(health_url) as response: - - if response.status == 200: - system_log.info("API 서버 상태: 정상") - return True - else: - system_log.warning( - f"API 서버 상태: 비정상 (Status: {response.status})" - ) - return False - - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - system_log.error(f"API 서버 상태 확인 실패: {str(e)}") - return False diff --git a/services/press-fault-data-simulator-service/app/services/scheduler_service.py b/services/press-fault-data-simulator-service/app/services/scheduler_service.py index 10e9ffb..6962c3e 100644 --- a/services/press-fault-data-simulator-service/app/services/scheduler_service.py +++ b/services/press-fault-data-simulator-service/app/services/scheduler_service.py @@ -3,10 +3,11 @@ from datetime import datetime from app.services.azure_storage_service import AzureStorageService -from app.services.prediction_api_service import PredictAPIService +# from app.services.prediction_api_service import PredictAPIService +from app.services.spring_boot_service import SpringBootService from app.models.data_models import PredictionRequest from app.config.settings import settings -from app.utils.logger import system_log, prediction_log +from app.utils.logger import system_log class SchedulerService: @@ -18,13 +19,17 @@ def __init__(self): # 서비스 인스턴스들 self.storage_service = AzureStorageService() - self.api_service = PredictAPIService() + # self.api_service = PredictAPIService() + self.spring_boot_service = SpringBootService() # 통계 - self.total_predictions = 0 - self.fault_detections = 0 + # self.total_predictions = 0 + # self.fault_detections = 0 self.start_time: Optional[datetime] = None + self.total_transmissions = 0 + self.successful_transmissions = 0 + async def start_simulation(self) -> bool: """시뮬레이션 시작!""" @@ -34,11 +39,12 @@ async def start_simulation(self) -> bool: system_log.info("🚀 시뮬레이션 시작 중...") - # API 서버 상태 확인 - if not await self.api_service.health_check(): - system_log.error("API 서버 연결 실패. 시뮬레이션을 시작할 수 없습니다.") + await self.storage_service.reset_connection() + # Spring Boot 서버 상태 확인 + if not await self.spring_boot_service.health_check(): + system_log.error("Spring Boot 서버 연결 실패. 시뮬레이션을 시작할 수 없습니다.") return False - + # Azure Storage 연결 확인 if not await self.storage_service.connect(): system_log.error("Azure Storage 연결 실패.") @@ -46,8 +52,10 @@ async def start_simulation(self) -> bool: self.is_running = True self.start_time = datetime.now() - self.total_predictions = 0 - self.fault_detections = 0 + # self.total_predictions = 0 + # self.fault_detections = 0 + self.total_transmissions = 0 + self.successful_transmissions = 0 self.loop = asyncio.get_event_loop() self.task = self.loop.create_task(self._run_simulation_loop()) @@ -55,7 +63,6 @@ async def start_simulation(self) -> bool: system_log.info( f"✅ 시뮬레이션 시작됨 - 간격: {settings.SIMULATOR_INTERVAL_MINUTES}분" ) - system_log.info(f"📊 API URL: {settings.PREDICTION_API_FULL_URL}") return True @@ -83,7 +90,7 @@ async def stop_simulation(self) -> bool: if self.start_time: duration = datetime.now() - self.start_time system_log.info(f"실행 시간: {duration}") - system_log.info(f"총 예측 횟수: {self.total_predictions}") + system_log.info(f" └─ 총 전송 횟수: {self.total_transmissions}") system_log.info("✅ 시뮬레이션 종료 완료") return True @@ -131,24 +138,25 @@ async def _run_single_simulation(self) -> bool: minute_data, file_name, is_end_of_file = data_result - # 2. 예측 요청 데이터 생성 - prediction_request = PredictionRequest.from_csv_data(minute_data) - # 3. API 호출 - prediction_result = await self.api_service.call_predict_api( - prediction_request + # 2. Spring Boot 전송용 데이터 생성 + sensor_data_request = PredictionRequest.from_csv_data(minute_data) + # 3. Spring Boot로 데이터 전송 + transmission_success = await self.spring_boot_service.send_sensor_data( + sensor_data_request, + data_source=file_name ) - if prediction_result is None: - system_log.error("API 호출 실패") + if not transmission_success: + system_log.error("Spring Boot 데이터 전송 실패") return False - # 4. 결과 로그 처리 - self._handle_prediction_result(prediction_result, file_name) - + # 4. 전송 성공 로그 + system_log.info(f"✅ 데이터 전송 성공 - Source: {file_name}, " + f"Size: {len(minute_data)}행") # 5. 통계 업데이트 - self.total_predictions += 1 - if prediction_result.is_fault: - self.fault_detections += 1 + self.total_transmissions += 1 + if transmission_success: + self.successful_transmissions += 1 if is_end_of_file: system_log.info(f"파일 '{file_name}' 처리 완료") @@ -159,7 +167,7 @@ async def _run_single_simulation(self) -> bool: system_log.error(f"시뮬레이션 실행 오류: {str(e)}") return False - def _handle_prediction_result(self, result, data_source: str): + # def _handle_prediction_result(self, result, data_source: str): """예측 결과 처리 (로그 기록)""" status = "FAULT DETECTED" if result.is_fault else "✅ NORMAL" @@ -200,10 +208,9 @@ def get_simulation_status(self) -> dict: "is_running": True, "start_time": self.start_time.isoformat() if self.start_time else None, "runtime": runtime, - "total_predictions": self.total_predictions, - "fault_detections": self.fault_detections, + "total_transmissions": self.total_transmissions, + "successful_transmissions": self.successful_transmissions, "interval_minutes": settings.SIMULATOR_INTERVAL_MINUTES, - "api_url": settings.PREDICTION_API_FULL_URL, "storage_status": storage_status, } diff --git a/services/press-fault-data-simulator-service/app/services/spring_boot_service.py b/services/press-fault-data-simulator-service/app/services/spring_boot_service.py new file mode 100644 index 0000000..4850e48 --- /dev/null +++ b/services/press-fault-data-simulator-service/app/services/spring_boot_service.py @@ -0,0 +1,113 @@ +import aiohttp +import asyncio +import time +from datetime import datetime +from typing import Optional + +from app.config.settings import settings +from app.models.data_models import PredictionRequest +from app.utils.logger import system_log + + +class SpringBootService: + """Spring Boot API 호출 서비스 (비동기)""" + + def __init__(self): + self.api_url = f"{settings.SPRING_BOOT_BASE_URL}/pressFaultDetectionLogs/data" + self.timeout = 30 # 30초 타임아웃 + self.max_retries = 3 # 최대 재시도 횟수 + + system_log.info(f"Spring Boot Service 초기화 완료 - URL: {self.api_url}") + + async def send_sensor_data(self, request: PredictionRequest, data_source: str = None) -> bool: + """ + Spring Boot로 센서 데이터를 전송 + + Args: + request: 센서 데이터 (PredictionRequest 객체) + data_source: 데이터 소스 (파일명 등) + + Returns: + bool: 전송 성공 여부 + """ + + # 요청 데이터를 JSON 형태로 변환 + request_data = { + "AI0_Vibration": request.AI0_Vibration, + "AI1_Vibration": request.AI1_Vibration, + "AI2_Current": request.AI2_Current, + "timestamp": datetime.now().isoformat(), + "source": data_source or "simulator", + "data_length": len(request.AI0_Vibration) + } + + system_log.debug(f"Spring Boot 전송 데이터 크기: AI0({len(request_data['AI0_Vibration'])}) " + f"AI1({len(request_data['AI1_Vibration'])}) " + f"AI2({len(request_data['AI2_Current'])})") + + # max_retries 횟수(3)만큼 시도 + for attempt in range(1, self.max_retries + 1): + try: + start_time = time.time() + timeout = aiohttp.ClientTimeout(total=self.timeout) + + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post( + url=self.api_url, + json=request_data, + headers={'Content-Type': 'application/json'} + ) as response: + response_time = time.time() - start_time + + if response.status in [200, 201, 202]: + system_log.info(f"Spring Boot 전송 성공 ({response_time:.3f}s) - " + f"Status: {response.status}, Source: {data_source}") + return True + + response_text = await response.text() + system_log.error(f"Spring Boot 전송 실패 (시도 {attempt}/{self.max_retries}) - " + f"Status: {response.status}, Response: {response_text[:200]}") + + except asyncio.TimeoutError: + system_log.error(f"Spring Boot 전송 타임아웃 (시도 {attempt}/{self.max_retries}) - URL: {self.api_url}") + + except aiohttp.ClientConnectionError: + system_log.error(f"Spring Boot 서버 연결 실패 (시도 {attempt}/{self.max_retries}) - URL: {self.api_url}") + + except Exception as e: + system_log.error(f"예상치 못한 오류 (시도 {attempt}/{self.max_retries}): {str(e)}") + + # 마지막 시도가 아니면 지수 백오프 후 재시도 + if attempt < self.max_retries: + await asyncio.sleep(2 ** attempt) + + return False + + return False + + async def health_check(self) -> bool: + """ + Spring Boot 서버 상태 확인 + + Returns: + bool: 서버 상태 (True: 정상, False: 비정상) + """ + try: + # Spring Boot 기본 루트 경로로 간단한 연결 테스트 + base_url = f"{settings.SPRING_BOOT_BASE_URL}/pressFaultDetectionLogs" + + timeout = aiohttp.ClientTimeout(total=10) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(base_url) as response: + + # 연결만 되면 OK (404여도 서버는 살아있음) + if response.status in [200, 404]: + system_log.info("Spring Boot 서버 상태: 정상") + return True + else: + system_log.warning(f"Spring Boot 서버 상태: 비정상 (Status: {response.status})") + return False + + except Exception as e: + system_log.error(f"Spring Boot 서버 상태 확인 실패: {str(e)}") + return False \ No newline at end of file diff --git a/services/press-fault-data-simulator-service/tests/services/test_prediction_api_service.py b/services/press-fault-data-simulator-service/tests/services/test_prediction_api_service.py deleted file mode 100644 index 25d42ba..0000000 --- a/services/press-fault-data-simulator-service/tests/services/test_prediction_api_service.py +++ /dev/null @@ -1,414 +0,0 @@ -import asyncio -from typing import Any, Dict -import aiohttp -import pytest -# Note on test framework: -# These tests use pytest with pytest.mark.asyncio to validate async behavior without adding new dependencies. - -# Import the service under test and its dependencies -from app.services.prediction_api_service import ( - PredictAPIService, -) # assuming service path; adjust if different -from app.models.data_models import PredictionRequest, PredictionResult -from app.config import settings as settings_module - - -def make_prediction_request( - n_ai0: int = 10, n_ai1: int = 12, n_ai2: int = 8 -) -> PredictionRequest: - # Build a minimal valid PredictionRequest object. - # The exact model fields are unknown from the diff; we infer from debug log keys used: - # - AI0_Vibration: sequence-like - # - AI1_Vibration: sequence-like - # - AI2_Current: sequence-like - # If additional required fields exist, please update accordingly to match the model. - data = { - "AI0_Vibration": [0.1] * n_ai0, - "AI1_Vibration": [0.2] * n_ai1, - "AI2_Current": [0.3] * n_ai2, - } - # PredictionRequest is likely a Pydantic model; construct it via its constructor. - return PredictionRequest(**data) - - -class DummyResponse: - def __init__( - self, - status: int, - json_payload: Dict[str, Any] | None = None, - text_body: str = "", - ): - self.status = status - self._json_payload = json_payload - self._text_body = text_body - - async def json(self): - # Simulate aiohttp response.json() - if self._json_payload is None: - raise ValueError("No JSON payload") - return self._json_payload - - async def text(self): - return self._text_body - - # Context manager protocol for "async with" - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - return False - - -class DummySession: - def __init__(self, timeout=None, post_side_effect=None, get_side_effect=None): - self._timeout = timeout - self._post_side_effect = post_side_effect - self._get_side_effect = get_side_effect - self.post_calls = [] - self.get_calls = [] - - # Context manager protocol for "async with aiohttp.ClientSession(...) as session" - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - return False - - def post(self, url, json=None, headers=None): - # aiohttp returns an async context manager; we emulate any side effect or return a DummyResponse - self.post_calls.append({"url": url, "json": json, "headers": headers}) - if callable(self._post_side_effect): - result = self._post_side_effect(url, json, headers) - return result - # Default: return 200 ok with reasonable payload - return DummyResponse( - 200, - { - "prediction": "OK", - "is_fault": False, - }, - ) - - def get(self, url): - self.get_calls.append({"url": url}) - if callable(self._get_side_effect): - return self._get_side_effect(url) - return DummyResponse(200, None, "healthy") - - -class DummyTimeoutError(Exception): - pass - - -@pytest.fixture -def patch_sleep(monkeypatch): - # Speed up retries by replacing asyncio.sleep with a no-op that still yields control - async def fast_sleep(_): - await asyncio.sleep(0) # one loop to maintain async semantics - # To avoid recursion, patch to a coroutine that just returns None without calling itself - async def no_sleep(_): - return None - - monkeypatch.setattr(asyncio, "sleep", no_sleep) - yield - # No explicit unpatch needed as monkeypatch reverts automatically - - -@pytest.fixture -def patch_aiohttp_session(monkeypatch): - # Patch aiohttp.ClientSession to return our DummySession - - def _apply(session_factory): - class FactorySession: - def __init__(self, timeout=None): - self._timeout = timeout - self._session = session_factory(timeout) - - async def __aenter__(self): - return self._session - - async def __aexit__(self, exc_type, exc, tb): - return False - - monkeypatch.setattr(aiohttp, "ClientSession", FactorySession) - return FactorySession - - return _apply - - -@pytest.fixture -def patch_settings(monkeypatch): - # Patch URLs used by the service - - monkeypatch.setattr( - settings_module.settings, - "PRESS_FAULT_MODEL_BASE_URL", - "http://example.test", - raising=False, - ) - - -@pytest.mark.asyncio -async def test_call_predict_api_success_happy_path( - patch_settings, patch_aiohttp_session -): - # Arrange: 200 response with valid JSON that matches PredictionResult - def session_factory(_timeout): - def post_side_effect(url, payload, headers): - # Validate that JSON matches request.model_dump() keys - assert ( - "AI0_Vibration" in payload - and "AI1_Vibration" in payload - and "AI2_Current" in payload - ) - # Return a context manager which yields a DummyResponse - return DummyResponse( - 200, - { - # Minimal plausible fields; adjust if PredictionResult requires more - "prediction": "NORMAL", - "is_fault": False, - "reconstruction_error": 0.0, - }, - ) - - return DummySession(timeout=_timeout, post_side_effect=post_side_effect) - - patch_aiohttp_session(session_factory) - - svc = PredictAPIService() - req = make_prediction_request() - - # Act - result = await svc.call_predict_api(req) - - # Assert - assert isinstance(result, PredictionResult) - assert result.prediction in ("NORMAL", "OK") - assert result.is_fault is False - assert isinstance(result.reconstruction_error, float) - - -@pytest.mark.asyncio -async def test_call_predict_api_invalid_response_returns_none( - patch_settings, patch_aiohttp_session -): - # Arrange: 200 but invalid JSON payload that breaks PredictionResult construction - def session_factory(_timeout): - def post_side_effect(url, payload, headers): - # Missing required fields to force a validation error - return DummyResponse(200, {"unexpected": "data"}) - return DummySession(timeout=_timeout, post_side_effect=post_side_effect) - - patch_aiohttp_session(session_factory) - - svc = PredictAPIService() - req = make_prediction_request() - - # Act - result = await svc.call_predict_api(req) - - # Assert - assert result is None - - -@pytest.mark.asyncio -async def test_call_predict_api_non_200_retries_then_returns_none( - patch_settings, patch_aiohttp_session, patch_sleep, monkeypatch -): - attempts = [] - - def session_factory(_timeout): - def post_side_effect(url, payload, headers): - attempts.append(1) - # Always return 500 with some text body - return DummyResponse(500, None, "Internal Server Error") - return DummySession(timeout=_timeout, post_side_effect=post_side_effect) - - patch_aiohttp_session(session_factory) - - svc = PredictAPIService() - # Reduce max_retries to 2 for faster test while still exercising retry logic - monkeypatch.setattr(svc, "max_retries", 2) - - req = make_prediction_request() - - result = await svc.call_predict_api(req) - - assert result is None - # Should have attempted exactly max_retries times - assert len(attempts) == 2 - - -@pytest.mark.asyncio -async def test_call_predict_api_timeout_error_retries_then_none( - patch_settings, patch_aiohttp_session, patch_sleep, monkeypatch -): - attempts = [] - - class RaisingPost: - def __init__(self): - self._count = 0 - - # This is returned by session.post and acts as async context manager - async def __aenter__(self): - raise asyncio.TimeoutError("timed out") - - async def __aexit__(self, exc_type, exc, tb): - return False - - def session_factory(_timeout): - def post_side_effect(url, payload, headers): - attempts.append(1) - return RaisingPost() - return DummySession(timeout=_timeout, post_side_effect=post_side_effect) - - patch_aiohttp_session(session_factory) - - svc = PredictAPIService() - monkeypatch.setattr(svc, "max_retries", 3) - - req = make_prediction_request() - - result = await svc.call_predict_api(req) - - assert result is None - assert len(attempts) == 3 - - -@pytest.mark.asyncio -async def test_call_predict_api_client_connection_error_retries_then_none( - patch_settings, patch_aiohttp_session, patch_sleep, monkeypatch -): - attempts = [] - - class RaisingPost: - async def __aenter__(self): - raise aiohttp.ClientConnectionError("connection failed") - - async def __aexit__(self, exc_type, exc, tb): - return False - - def session_factory(_timeout): - def post_side_effect(url, payload, headers): - attempts.append(1) - return RaisingPost() - return DummySession(timeout=_timeout, post_side_effect=post_side_effect) - - patch_aiohttp_session(session_factory) - - svc = PredictAPIService() - monkeypatch.setattr(svc, "max_retries", 2) - - req = make_prediction_request() - - result = await svc.call_predict_api(req) - - assert result is None - assert len(attempts) == 2 - - -@pytest.mark.asyncio -async def test_call_predict_api_unexpected_exception_retries_then_none( - patch_settings, patch_aiohttp_session, patch_sleep, monkeypatch -): - attempts = [] - - class RaisingPost: - async def __aenter__(self): - raise RuntimeError("unexpected") - - async def __aexit__(self, exc_type, exc, tb): - return False - - def session_factory(_timeout): - def post_side_effect(url, payload, headers): - attempts.append(1) - return RaisingPost() - return DummySession(timeout=_timeout, post_side_effect=post_side_effect) - - patch_aiohttp_session(session_factory) - - svc = PredictAPIService() - monkeypatch.setattr(svc, "max_retries", 2) - - req = make_prediction_request() - - result = await svc.call_predict_api(req) - - assert result is None - assert len(attempts) == 2 - - -@pytest.mark.asyncio -async def test_health_check_success_returns_true(patch_settings, patch_aiohttp_session): - # Arrange: GET /health returns 200 - def session_factory(_timeout): - def get_side_effect(url): - assert url.endswith("/health") - return DummyResponse(200, None, "healthy") - return DummySession(timeout=_timeout, get_side_effect=get_side_effect) - - patch_aiohttp_session(session_factory) - - svc = PredictAPIService() - - ok = await svc.health_check() - assert ok is True - - -@pytest.mark.asyncio -async def test_health_check_non_200_returns_false(patch_settings, patch_aiohttp_session): - # Arrange: GET /health returns 503 - def session_factory(_timeout): - def get_side_effect(url): - return DummyResponse(503, None, "unhealthy") - return DummySession(timeout=_timeout, get_side_effect=get_side_effect) - - patch_aiohttp_session(session_factory) - - svc = PredictAPIService() - - ok = await svc.health_check() - assert ok is False - - -@pytest.mark.asyncio -async def test_health_check_requests_exception_caught_returns_false( - patch_settings, monkeypatch -): - # The code catches requests.exceptions.RequestException despite using aiohttp. - # We simulate that specific exception to ensure the except block is executed. - - class RaisingGetSession: - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - return False - - def get(self, url): - class Ctx: - async def __aenter__(self_inner): - raise aiohttp.ClientError("simulated aiohttp exception") - - async def __aexit__(self_inner, exc_type, exc, tb): - return False - - return Ctx() - - class SessionFactory: - def __init__(self, timeout=None): - self.timeout = timeout - - async def __aenter__(self): - return RaisingGetSession() - - async def __aexit__(self, exc_type, exc, tb): - return False - - monkeypatch.setattr(aiohttp, "ClientSession", SessionFactory) - - svc = PredictAPIService() - ok = await svc.health_check() - assert ok is False diff --git a/services/press-fault-detection-model-service/app/services/predict_service.py b/services/press-fault-detection-model-service/app/services/predict_service.py index 40578ed..7d095da 100644 --- a/services/press-fault-detection-model-service/app/services/predict_service.py +++ b/services/press-fault-detection-model-service/app/services/predict_service.py @@ -74,7 +74,7 @@ def predict_press_fault(data: SensorData) -> dict: is_fault = fault_probability > 0.05 prediction_result = "고장" if is_fault else "정상" - max_error = np.max(errors) if total_sequences > 0 else 0.0 + mean_error = np.mean(errors) if total_sequences > 0 else 0.0 # 9. 원인 분석 (확률이 일정 수준 이상일 때만) attribute_errors_dict = None @@ -88,7 +88,7 @@ def predict_press_fault(data: SensorData) -> dict: # 10. 최종 응답 데이터 구성 response_data = { "prediction": prediction_result, - "reconstruction_error": float(max_error), + "reconstruction_error": float(mean_error), "is_fault": is_fault, "fault_probability": float(fault_probability), # 계산된 확률 추가 "attribute_errors": attribute_errors_dict,