Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,21 @@
import com.umc.yeogi_gal_lae.api.aiCourse.domain.AICourse;
import com.umc.yeogi_gal_lae.api.aiCourse.repository.AICourseRepository;
import com.umc.yeogi_gal_lae.api.place.domain.Place;
import com.umc.yeogi_gal_lae.api.place.repository.PlaceRepository;
import com.umc.yeogi_gal_lae.api.tripPlan.domain.TripPlan;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
Expand All @@ -24,17 +29,20 @@

@Service
public class AICourseService {
private static final Logger logger = LoggerFactory.getLogger(AICourseService.class);

private final AICourseRepository aiCourseRepository;
private final PlaceRepository placeRepository;
private final WebClient webClient;
private final ObjectMapper objectMapper;

@Value("${openai.api.key}")
private String apiKey;

public AICourseService(AICourseRepository aiCourseRepository,
public AICourseService(AICourseRepository aiCourseRepository, PlaceRepository placeRepository,
WebClient.Builder webClientBuilder) {
this.aiCourseRepository = aiCourseRepository;
this.placeRepository = placeRepository;
this.webClient = webClientBuilder.baseUrl("https://api.openai.com/v1").build();
this.objectMapper = new ObjectMapper();
}
Expand All @@ -47,8 +55,17 @@ public AICourseService(AICourseRepository aiCourseRepository,
*/
@Transactional
public AICourse generateAndStoreAICourse(TripPlan tripPlan) {
// TripPlan에 직접 연결된 Place들을 사용
List<Place> places = tripPlan.getPlaces(); // TripPlan에 places 컬렉션이 있다고 가정
// TripPlan에 직접 연결된 Place들을 DB에서 명시적으로 조회
List<Place> places = placeRepository.findAllByTripPlanId(tripPlan.getId());

if (places.isEmpty()) {
logger.warn("TripPlan id {}에 등록된 장소가 없습니다.", tripPlan.getId());
} else {
List<String> placeNames = places.stream()
.map(Place::getPlaceName)
.collect(Collectors.toList());
logger.info("TripPlan id {}에 등록된 장소 목록: {}", tripPlan.getId(), placeNames);
}
if (places.isEmpty()) {
return null;
}
Expand All @@ -60,13 +77,25 @@ public AICourse generateAndStoreAICourse(TripPlan tripPlan) {
Map<String, List<String>> courseByDay = parseGptResponse(gptApiResponse);
// GPT 결과를 실제 Place 객체와 매핑
Map<String, List<Place>> course = new LinkedHashMap<>();
Set<Long> usedPlaceIds = new HashSet<>();

for (Map.Entry<String, List<String>> entry : courseByDay.entrySet()) {
String dayLabel = entry.getKey();
List<Place> dayPlaces = entry.getValue().stream()
// 해당 일차의 추천 장소들을 실제 Place 객체로 매핑
List<Place> recommendedPlaces = entry.getValue().stream()
.map(name -> findPlaceByName(places, name))
.filter(Objects::nonNull)
.collect(Collectors.toList());
course.put(dayLabel, dayPlaces);

// 이미 사용된 장소를 제외하여 유니크한 장소만 할당 (글로벌 중복 제거)
List<Place> uniqueForDay = recommendedPlaces.stream()
.filter(place -> !usedPlaceIds.contains(place.getId()))
.collect(Collectors.toList());

// 선택된 장소들을 전역 사용 목록에 추가
uniqueForDay.forEach(place -> usedPlaceIds.add(place.getId()));

course.put(dayLabel, uniqueForDay);
}
// 저장할 데이터를 위해 각 일차별 Place 이름 목록으로 변환
Map<String, List<String>> courseByName = new LinkedHashMap<>();
Expand Down Expand Up @@ -132,7 +161,8 @@ private String buildPrompt(TripPlan tripPlan, List<Place> places) {
.append("여행 종료일: ").append(tripPlan.getEndDate()).append("\n")
.append("총 여행 일수: ").append(totalDays).append("일\n")
.append("여행 지역: ").append(tripPlan.getLocation()).append("\n\n")
.append("다음은 방문 가능한 장소 목록 (이름, 주소, 위도, 경도)입니다:\n");
.append("다음은 해당 여행 계획에 등록된 방문 가능한 장소 목록입니다.\n")
.append("※ 일정 생성 시 반드시 아래 목록에 있는 장소 이름만 사용하며, 동일한 장소가 중복되지 않도록 해 주세요.\n");
for (Place p : places) {
promptBuilder.append("- ").append(p.getPlaceName())
.append(" (주소: ").append(p.getAddress())
Expand All @@ -143,27 +173,30 @@ private String buildPrompt(TripPlan tripPlan, List<Place> places) {
.append("위 정보를 바탕으로, 총 ").append(totalDays)
.append("일의 여행 일정(각 일차에 방문할 장소 추천)을 생성해줘.\n")
.append("일정은 반드시 '1일차', '2일차', ... '").append(totalDays)
.append("일차' 형식의 키를 가지며, 각 키의 값은 해당 일차에 추천할 장소들의 이름 목록이어야 합니다.\n")
.append("일차' 형식의 키를 가지며, 각 키의 값은 위 목록에 있는 장소들의 이름만 포함해야 합니다.\n")
.append("예시:\n")
.append("{\"1일차\": [\"장소 A\", \"장소 B\"], \"2일차\": [\"장소 C\", \"장소 D\"], ...}");
return promptBuilder.toString();
}


private String callGptApi(String prompt) {
Map<String, Object> requestBody = new HashMap<>();
requestBody.put("model", "gpt-4o-mini");
Map<String, String> message = new HashMap<>();
message.put("role", "user");
message.put("content", prompt);
requestBody.put("messages", List.of(message));
return webClient.post()
String response = webClient.post()
.uri("/chat/completions")
.header("Authorization", "Bearer " + apiKey)
.contentType(MediaType.APPLICATION_JSON)
.bodyValue(requestBody)
.retrieve()
.bodyToMono(String.class)
.block();
logger.info("GPT API 응답: {}", response);
return response;
}

private Map<String, List<String>> parseGptResponse(String gptResponse) {
Expand All @@ -183,8 +216,9 @@ private Map<String, List<String>> parseGptResponse(String gptResponse) {
}

private Place findPlaceByName(List<Place> places, String name) {
String normalizedName = name.trim().toLowerCase();
return places.stream()
.filter(p -> p.getPlaceName().equalsIgnoreCase(name))
.filter(p -> p.getPlaceName().trim().toLowerCase().equals(normalizedName))
.findFirst()
.orElse(null);
}
Expand Down