diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..8f2da54 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,3 @@ +# Configuration: https://help.github.com/en/github/building-a-strong-community/configuring-issue-templates-for-your-repository + +blank_issues_enabled: false diff --git "a/.github/ISSUE_TEMPLATE/\342\235\223-question.md" "b/.github/ISSUE_TEMPLATE/\342\235\223-question.md" new file mode 100644 index 0000000..9741509 --- /dev/null +++ "b/.github/ISSUE_TEMPLATE/\342\235\223-question.md" @@ -0,0 +1,26 @@ +--- +name: "โ“ Question" +about: "Ask a question about this project \U0001F393" +title: '' +labels: '' +assignees: '' + +--- + +## Checklist + + + +- [ ] I've searched the project's [`issues`] + +## โ“ Question + + + +How can I [...]? + +Is it possible to [...]? + +## ๐Ÿ“Ž Additional context + + diff --git "a/.github/ISSUE_TEMPLATE/\360\237\220\233-bug-report.md" "b/.github/ISSUE_TEMPLATE/\360\237\220\233-bug-report.md" new file mode 100644 index 0000000..e858c6e --- /dev/null +++ "b/.github/ISSUE_TEMPLATE/\360\237\220\233-bug-report.md" @@ -0,0 +1,31 @@ +--- +name: "\U0001F41B Bug Report" +about: "If something isn't working \U0001F527" +title: '' +labels: '' +assignees: '' + +--- + +## ๐Ÿ› Bug Report + + + +## ๐Ÿ”ฌ How To Reproduce + +Steps to reproduce the behavior: + +1. ... + +### Environment + +- OS: [e.g. Linux / Windows / macOS] +- Python version, get it with: + +```bash +python --version +``` + +## ๐Ÿ“Ž Additional context + + diff --git "a/.github/ISSUE_TEMPLATE/\360\237\232\200-feature-request.md" "b/.github/ISSUE_TEMPLATE/\360\237\232\200-feature-request.md" new file mode 100644 index 0000000..ba330a3 --- /dev/null +++ "b/.github/ISSUE_TEMPLATE/\360\237\232\200-feature-request.md" @@ -0,0 +1,16 @@ +--- +name: "\U0001F680 Feature Request" +about: "Suggest an idea for this project \U0001F3D6" +title: '' +labels: '' +assignees: '' + +--- + +## ๐Ÿš€ Feature Request + + + +## ๐Ÿ“Ž Additional context + + diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..4dab74c --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,7 @@ +## Description + + + +## Related Issue + + diff --git a/.github/workflows/check-lint.yml b/.github/workflows/check-lint.yml new file mode 100644 index 0000000..ac24a0b --- /dev/null +++ b/.github/workflows/check-lint.yml @@ -0,0 +1,32 @@ +name: check-lint + +on: [pull_request] + +jobs: + check-lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - name: Cache pip + uses: actions/cache@v2 + with: + # This path is specific to Ubuntu + path: ~/.cache/pip + # Look to see if there is a cache hit for the corresponding requirements file + key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + ${{ runner.os }}- + - name: Install dependencies + run: | + python3 -m pip install --upgrade pip + - name: Check Lint (black, flake8, isort) + run: | + make quality diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..345ab82 --- /dev/null +++ b/.gitignore @@ -0,0 +1,40 @@ +.env +*.log +*.csv +*.parquet +.ipynb_checkpoints/ +.deepeval +.deepeval_telemtry.txt + + +.DS_Store +*.json +*.csv +*.pdf +*.png +*.jpg +*.jpeg +*.gif +*.bmp +*.tiff +*.ico +*.webp +*.pdf + +/vector_db +*.bin +*.sqlite3 +__pycache__/ + +.hydra/ +*.pickle + 
+test.sh + +outputs/ +dist/ +passage_encoder/ +query_encoder/ + +!images/ +!images/**/* \ No newline at end of file diff --git a/.gitmessage b/.gitmessage new file mode 100644 index 0000000..98db672 --- /dev/null +++ b/.gitmessage @@ -0,0 +1,15 @@ +# Title: Summary, imperative, start upper case, don't end with a period +# No more than 50 chars. #### 50 chars is here: # + +# Remember blank line between title and body. + +# Body: Explain *what* and *why* (not *how*). Include task ID (Jira issue). +# Wrap at 72 chars. ################################## which is here: # + +# feat : feature (new functionality) +# fix : bug fix +# refactor: refactoring +# style : style (code formatting, semicolons: no change to business logic) +# docs : documentation (add, update, or remove docs) +# test : tests (add, update, or remove test code: no change to business logic) +# chore : other changes (build scripts, package manager settings, etc.) diff --git a/FE/.gitignore b/FE/.gitignore new file mode 100644 index 0000000..ac3b64f --- /dev/null +++ b/FE/.gitignore @@ -0,0 +1,25 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +node_modules +dist +dist-ssr +*.local + +# Editor directories and files +.vscode/* +!.vscode/extensions.json +.idea +.DS_Store +.env +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? diff --git a/FE/README.md b/FE/README.md new file mode 100644 index 0000000..7344166 --- /dev/null +++ b/FE/README.md @@ -0,0 +1,57 @@ +# Front-End + +## How to Run + +### 💽 Install packages +``` bash +npm install +``` + +### 🏃🏻 Run +``` bash +npm run dev +``` + +### ⚒️ Build +``` bash +npm run build +``` + +## 🔢 Version +React `18.3.1`, react-router-dom `7.1.5`, mui `6.4.2`, axios `1.7.9` + +## 📂 File Structure +``` +FE/ +├── index.html +├── proxy-server.js +├── tailwind.config.js +├── vite.config.js +├── package-lock.json +├── package.json +├── README.md +├── node_modules/ +├── public/ +│ ├── favicon.ico +├── src/ +│ ├── App.css +│ ├── index.css +│ ├── App.jsx +│ ├── main.jsx +│ ├── api/ +│ │ ├── api.jsx +│ │ ├── query.jsx +│ ├── assets/ +│ │ ├── icon/ +│ ├── components/ +│ │ ├── atom/ +│ │ │ ├── ... +│ │ ├── module/ +│ │ │ ├── ... +│ │ ├── page/ +│ │ │ ├── ChatPage.jsx +│ │ │ ├── MainPage.jsx +│ ├── configs/ +│ │ ├── router.jsx +│ │ ├── theme.jsx +``` diff --git a/FE/eslint.config.js b/FE/eslint.config.js new file mode 100644 index 0000000..bb72d11 --- /dev/null +++ b/FE/eslint.config.js @@ -0,0 +1,12 @@ +import globals from "globals"; +import pluginJs from "@eslint/js"; +import pluginReact from "eslint-plugin-react"; + + +/** @type {import('eslint').Linter.Config[]} */ +export default [ + {files: ["**/*.{js,mjs,cjs,jsx}"]}, + {languageOptions: { globals: globals.browser }}, + pluginJs.configs.recommended, + pluginReact.configs.flat.recommended, +]; \ No newline at end of file diff --git a/FE/index.html b/FE/index.html new file mode 100644 index 0000000..467c3e4 --- /dev/null +++ b/FE/index.html @@ -0,0 +1,13 @@ + + + + + + + FRAG + + +
+ + + diff --git a/FE/postcss.config.js b/FE/postcss.config.js new file mode 100644 index 0000000..2aa7205 --- /dev/null +++ b/FE/postcss.config.js @@ -0,0 +1,6 @@ +export default { + plugins: { + tailwindcss: {}, + autoprefixer: {}, + }, +}; diff --git a/FE/proxy-server.js b/FE/proxy-server.js new file mode 100644 index 0000000..21e483c --- /dev/null +++ b/FE/proxy-server.js @@ -0,0 +1,34 @@ +import express from 'express'; +import axios from 'axios'; +import cors from 'cors'; + +const app = express(); +const PORT = 3001; + +const URL = 'https://api-v2.deepsearch.com/v1/articles/economy'; +// import.meta.env is only populated by Vite; this proxy runs under plain Node, so read the key from process.env (loaded via dotenv or node --env-file). +const apiKey = process.env.VITE_NEWS_API_KEY; + +app.use(cors()); + +app.get('/api/news', async (req, res) => { + const { keyword, date_from, date_to, page_size } = req.query; + + try { + const response = await axios.get(URL, { + headers: { + Authorization: `Bearer ${apiKey}`, + }, + params: { keyword, date_from, date_to, page_size }, + }); + + res.json(response.data); + } catch (error) { + console.error('Error fetching news:', error); + res.status(500).json({ error: 'Failed to fetch news' }); + } +}); + +app.listen(PORT, () => { + console.log(`Proxy server is running on http://localhost:${PORT}`); +}); diff --git a/FE/src/App.css b/FE/src/App.css new file mode 100644 index 0000000..a6e3cec --- /dev/null +++ b/FE/src/App.css @@ -0,0 +1,9 @@ +.body-style { + background: #101010; + background-size: cover; + width: 100vw; + height: 100vh; + position: fixed; + top: 0; + left: 0; +} \ No newline at end of file diff --git a/FE/src/App.jsx b/FE/src/App.jsx new file mode 100644 index 0000000..b7d0bb8 --- /dev/null +++ b/FE/src/App.jsx @@ -0,0 +1,14 @@ +import React from 'react'; +import RouterConfiguration from './configs/router'; + +import './App.css'; + +function App() { + return ( +
+ +
+ ); +} + +export default App; diff --git a/FE/src/api/api.jsx b/FE/src/api/api.jsx new file mode 100644 index 0000000..b61953e --- /dev/null +++ b/FE/src/api/api.jsx @@ -0,0 +1,13 @@ +import axios from 'axios' + +const authorization = btoa('test@email.com:1234'); + +const api = axios.create({ + baseURL: 'http://10.28.224.90:30685/api/', + headers: { + 'Authorization': `Basic ${authorization}`, + 'Content-Type': 'application/json', + } +}); + +export default api; \ No newline at end of file diff --git a/FE/src/api/query.jsx b/FE/src/api/query.jsx new file mode 100644 index 0000000..59ec70a --- /dev/null +++ b/FE/src/api/query.jsx @@ -0,0 +1,31 @@ +import api from './api' + +function requestQuery(sessionId, query, company, model, max_tokens, temperature, chatHistory, requestQuerySuccess, requestQueryFail) { + api + .post('v1/chatting', { + session_id: sessionId, + query: query, + // llm_model: model, + max_tokens: max_tokens, + temperature: temperature, + company: company, + chat_history: chatHistory + }) + .then(requestQuerySuccess) + .catch(requestQueryFail) +} + +function uploadFile(file, uploadFileSuccess, uploadFileFail) { + api + .post('v1/documents/upload', + file, { + headers: { + 'Content-Type': 'multipart/form-data', + }, + }, + ) + .then(uploadFileSuccess) + .catch(uploadFileFail) +} + +export { requestQuery, uploadFile }; \ No newline at end of file diff --git a/FE/src/components/atom/CustomContainer.jsx b/FE/src/components/atom/CustomContainer.jsx new file mode 100644 index 0000000..c6a4132 --- /dev/null +++ b/FE/src/components/atom/CustomContainer.jsx @@ -0,0 +1,36 @@ +import React from 'react' +import { styled, Box } from '@mui/system' + +const CustomBox = styled(Box)( + ({ color, radius, width, height, flexDirection, justifyContent, border, padding, my }) => ` + background-color: #${color}; + color: #ffffff; + font-family: Pretendard-Regular; + width: ${width}; + height: ${height}px; + border-radius: ${radius}px; + display: flex; + flex-direction: ${flexDirection}; + justify-content: ${getJustifyContent(justifyContent)}; + align-items: center; + margin: ${my} 0 0 0; + padding: 0 ${padding}px; + border: ${getBorder(border)}; + ` +) + +function getBorder(border) { + if (!!border) return '1px solid #4a4a4a'; + else return 'none'; +} + +function getJustifyContent(justifyContent) { + if (!!justifyContent) return justifyContent; + else return 'center'; +} + +export default function CustomContainer({ children, type, placeholder, onChange, color, radius, width, height, flexDirection, justifyContent, border, padding, my }) { + return ( + {children} + ) +} \ No newline at end of file diff --git a/FE/src/components/atom/CustomIcon.jsx b/FE/src/components/atom/CustomIcon.jsx new file mode 100644 index 0000000..6f29986 --- /dev/null +++ b/FE/src/components/atom/CustomIcon.jsx @@ -0,0 +1,20 @@ +import React from 'react' +import { styled, Box } from '@mui/system' + +const CustomBox = styled(Box)( + ({ size }) => ` + width: ${size}px; + height: ${size}px; + display: flex; + justify-content: center; + align-items: center; + ` +) + +export default function CustomIcon({ src, size }) { + return ( + + + + ) +} \ No newline at end of file diff --git a/FE/src/components/atom/IconBox.jsx b/FE/src/components/atom/IconBox.jsx new file mode 100644 index 0000000..4d9736b --- /dev/null +++ b/FE/src/components/atom/IconBox.jsx @@ -0,0 +1,9 @@ +import React from 'react'; + +import { Box } from '@mui/material'; + +export default function IconBox({ children, onClick }) { + return ( + {children} + ) +} \ No newline
at end of file diff --git a/FE/src/components/atom/InputText.jsx b/FE/src/components/atom/InputText.jsx new file mode 100644 index 0000000..d08255a --- /dev/null +++ b/FE/src/components/atom/InputText.jsx @@ -0,0 +1,23 @@ +import React from 'react' +import { styled } from '@mui/system' + +const InputBox = styled('input')( + () => ` + background-color: #212222; + color: #ffffff; + font-family: Pretendard-Regular; + font-size: 18px; + outline: none; + border: none; + border-radius: 25px; + display: flex; + width: 500%; + padding: 0 20px; + ` +) + +export default function InputText({ placeholder, onChange, autoFocus, onKeyUp, value }) { + return ( + + ) +} \ No newline at end of file diff --git a/FE/src/components/atom/SideBar.jsx b/FE/src/components/atom/SideBar.jsx new file mode 100644 index 0000000..959d01a --- /dev/null +++ b/FE/src/components/atom/SideBar.jsx @@ -0,0 +1,18 @@ +import React from 'react'; +import { styled, Box } from '@mui/system' + +const SideBarBox = styled(Box)( + () => ` + background-color: #212222; + width: 290px; + height: 100vh; + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + ` +) + +export default function SideBar({ children }) { + return {children}; +} \ No newline at end of file diff --git a/FE/src/components/atom/customText.jsx b/FE/src/components/atom/customText.jsx new file mode 100644 index 0000000..e33a731 --- /dev/null +++ b/FE/src/components/atom/customText.jsx @@ -0,0 +1,90 @@ +import React from 'react'; +import { Box, styled } from '@mui/system'; + +const CustomBox = styled(Box)( + ({ size, color, weight, my, mx, justifyContent, height, textAlign }) => ` + color: ${getColor(color)}; + font-family: ${getWeight(weight)}; + font-size: ${getSize(size)}; + height: ${height}; + margin: ${getMargin(my)}px ${getMargin(mx)}px; + display: flex; + align-items: center; + justify-content: ${getJustifyContent(justifyContent)}; + text-align: ${getTextAlign(textAlign)}; + ` +); + +function getSize(size) { + switch (size) { + case 'xxl': + return '35px'; + case 'xl': + return '33px'; + case 'l': + return '30px'; + case 'm': + return '23px'; + case 's': + return '18px'; + case 'xs': + return '15px'; + case 'xxs': + return '13px'; + default: + return '20px'; + } +} + +function getColor(color) { + switch (color) { + case 'primary': + return '#ffffff'; + case 'second': + return '#7A7A7C'; + case 'blur': + return '#A1A1A1'; + case 'up': + return '#E43332'; + case 'down': + return '#3871CA'; + case 'green': + return '#37824A'; + default: + return '#ffffff'; + } +} + +function getWeight(weight) { + switch (weight) { + case 'bold': + return 'Pretendard-Bold'; + case 'light': + return 'Pretendard-Light'; + default: + return 'Pretendard-Regular'; + } +} + +function getMargin(px) { + if (!!px) return px; + else return 0; +} + +function getJustifyContent(justifyContent) { + if (!!justifyContent) return justifyContent; + else return 'center'; +} + +function getTextAlign(textAlign) { + if (!!textAlign) return textAlign; + else return 'center'; +} + +export default function CustomText({ children, size, color, weight, my, mx, justifyContent, textAlign, height }) { + return ( + + {children} + + ) +} \ No newline at end of file diff --git a/FE/src/components/module/ExchangeRateBox.jsx b/FE/src/components/module/ExchangeRateBox.jsx new file mode 100644 index 0000000..42186b5 --- /dev/null +++ b/FE/src/components/module/ExchangeRateBox.jsx @@ -0,0 +1,24 @@ +import React from 'react'; + +import CustomContainer from '../atom/CustomContainer'; 
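+// ExchangeRateBox renders one currency: the current rate, the absolute change against yesterday's rate (vs), and the percentage change (fltRt).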
+import CustomText from '../atom/CustomText'; + +export default function ExchangeRateBox({ children, rate, yesterdayRate }) { + const currentRate = parseFloat(rate); + const previousRate = parseFloat(yesterdayRate); + const vs = currentRate - previousRate; + const fltRt = ((currentRate / previousRate) - 1) * 100; + + return ( + + {children} + {currentRate?.toFixed(2)} + + {vs < 0 ? `โ–ผ${Math.abs(vs).toFixed(2)}` : `โ–ฒ${Math.abs(vs).toFixed(2)}`} + + + {`${fltRt.toFixed(2)}%`} + + + ); +} diff --git a/FE/src/components/module/ExchangeRateWidget.jsx b/FE/src/components/module/ExchangeRateWidget.jsx new file mode 100644 index 0000000..986e4df --- /dev/null +++ b/FE/src/components/module/ExchangeRateWidget.jsx @@ -0,0 +1,114 @@ +import React, { useState, useEffect } from 'react'; +import axios from 'axios'; +import moment from 'moment'; + +import { Box } from '@mui/system'; +import CustomContainer from '../atom/CustomContainer'; +import ExchangeRateBox from './ExchangeRateBox'; +import LoadingIcon from '../../assets/icon/spinner_widget.gif' + +const LIVE_URL = 'https://api.currencylayer.com/live'; +const HISTORICAL_URL = 'https://api.currencylayer.com/historical'; +const apiKey = import.meta.env.VITE_EXCHANGERATE_API_KEY; +const historicalApiKey = import.meta.env.VITE_HISTORICAL_EXCHANGERATE_API_KEY; + +export default function ExchangeRateWidget() { + const [rateData, setRateData] = useState({}); + const [yesterdayRateData, setYesterdayRateData] = useState({}); + const [error, setError] = useState(null); + const [todayLoading, setTodayLoading] = useState(true); + const [yesterdayloading, setYesterdayLoading] = useState(true); + + async function fetchTodayRate() { + try { + const res = await axios.get(LIVE_URL, { + params: { + access_key: apiKey, + currencies: 'KRW,JPY,EUR,CNY', + source: 'USD', + } + }) + + if (res.data.success) { + const usdToKrw = res.data.quotes.USDKRW; + const jpyToKrw = res.data.quotes.USDJPY ? usdToKrw / res.data.quotes.USDJPY : null; + const eurToKrw = res.data.quotes.USDEUR ? usdToKrw / res.data.quotes.USDEUR : null; + const cnyToKrw = res.data.quotes.USDCNY ? usdToKrw / res.data.quotes.USDCNY : null; + + setRateData({ + USD: usdToKrw, + JPY: jpyToKrw * 100, + EUR: eurToKrw, + CNY: cnyToKrw, + }); + } else { + throw new Error(res.data.error.info); + } + } catch (err) { + setError(err.message); + } finally { + setTodayLoading(false); + } + } + + async function fetchYesterdayRate() { + try { + const res = await axios.get(HISTORICAL_URL, { + params: { + access_key: historicalApiKey, + currencies: 'KRW,JPY,EUR,CNY', + source: 'USD', + date: moment().subtract(1, 'days').format('YYYY-MM-DD'), + } + }) + + if (res.data.success) { + const usdToKrw = res.data.quotes.USDKRW; + const jpyToKrw = res.data.quotes.USDJPY ? usdToKrw / res.data.quotes.USDJPY : null; + const eurToKrw = res.data.quotes.USDEUR ? usdToKrw / res.data.quotes.USDEUR : null; + const cnyToKrw = res.data.quotes.USDCNY ? 
usdToKrw / res.data.quotes.USDCNY : null; + + setYesterdayRateData({ + USD: usdToKrw, + JPY: jpyToKrw * 100, + EUR: eurToKrw, + CNY: cnyToKrw, + }); + } else { + throw new Error(res.data.error.info); + } + } catch (err) { + setError(err.message); + } finally { + setYesterdayLoading(false); + } + } + + useEffect(() => { + fetchTodayRate(); + fetchYesterdayRate(); + }, []); + + if (error) { + return ( + + + + + + ); + } + + return ( + + + USD + JPY 100 + + + EUR + CNY + + + ); +} diff --git a/FE/src/components/module/IndexWidget.jsx b/FE/src/components/module/IndexWidget.jsx new file mode 100644 index 0000000..bc6f004 --- /dev/null +++ b/FE/src/components/module/IndexWidget.jsx @@ -0,0 +1,82 @@ +import React, { useState, useEffect } from 'react' +import axios from 'axios' +import moment from 'moment' + +import { Box } from '@mui/system' +import CustomText from '../atom/CustomText' +import CustomContainer from '../atom/CustomContainer' +import LoadingIcon from '../../assets/icon/spinner_widget.gif' + +const URL = 'https://apis.data.go.kr/1160100/service/GetMarketIndexInfoService/getStockMarketIndex'; +const apiKey = import.meta.env.VITE_INDEX_API_KEY; +const currentDate = moment().format('YYYYMMDD'); + +export default function IndexWidget() { + const [indexData, setIndexData] = useState(); + + function getIndexSuccess(res) { + const data = res.data.response.body.items.item[0]; + + if (!data || data.length === 0) { + console.error('No data found in API response'); + return; + } + + setIndexData({ + basDt: data.basDt, // ๋‚ ์งœ + clpr: data.clpr, // ์ข…๊ฐ€ + hipr: data.hipr, // ๊ณ ๊ฐ€ + lopr: data.lopr, // ์ €๊ฐ€ + vs: parseFloat(data.vs), // ์ „์ผ ๋Œ€๋น„ + fltRt: parseFloat(data.fltRt), // ๋“ฑ๋ฝ๋ฅ  + }); + } + + function getIndex() { + axios + .get( + URL, + { + params: { + serviceKey: apiKey, + resultType: 'json', + endBasDt: currentDate, + idxNm: '์ฝ”์Šคํ”ผ' + } + } + ) + .then(getIndexSuccess); + } + + useEffect(() => { + const timer = setTimeout(() => { + getIndex(); + }, 100); + + return () => clearTimeout(timer); + }, []) + + return ( + + {indexData ? ( + + + KOSPI + {`${indexData.basDt.slice(4, 6)}.${indexData.basDt.slice(6)}`} + + {`${indexData.clpr}`} + + {indexData.vs < 0 ? 
`โ–ผ ${-indexData.vs}` : `โ–ฒ ${indexData.vs}`} + {`${indexData.fltRt}%`} + + + ) : ( + + + + + + )} + + ); +} \ No newline at end of file diff --git a/FE/src/components/module/NewsWidget.jsx b/FE/src/components/module/NewsWidget.jsx new file mode 100644 index 0000000..eddc85f --- /dev/null +++ b/FE/src/components/module/NewsWidget.jsx @@ -0,0 +1,57 @@ +import React, { useState, useEffect } from 'react'; +import axios from 'axios'; +import moment from 'moment'; + +import CustomText from '../atom/CustomText'; +import CustomContainer from '../atom/CustomContainer'; + +const URL = import.meta.env.VITE_NEWS_API_URL; +const currentDate = moment().format('YYYY-MM-DD'); +const oneWeekAgoDate = moment().subtract(7, 'days').format('YYYY-MM-DD'); + +export default function NewsWidget() { + const [newsData, setNewsData] = useState([]); + const [currentIndex, setCurrentIndex] = useState(0); + + function getNewsSuccess(res) { + if (res.data.data) { + setNewsData(res.data.data); + } + } + + function getNews() { + axios + .get(URL, { + params: { + date_from: oneWeekAgoDate, + date_to: currentDate, + page_size: '30', + }, + }) + .then(getNewsSuccess); + } + + useEffect(() => { + getNews(); + }, []); + + useEffect(() => { + if (newsData.length > 0) { + const interval = setInterval(() => { + setCurrentIndex((prevIndex) => (prevIndex + 1) % newsData.length); + }, 4000); + + return () => clearInterval(interval); + } + }, [newsData]); + + return ( + + {newsData.length > 0 ? ( + + {'๐Ÿ“ข' + ' ' + newsData[currentIndex].title} + + ) : null} + + ); +} diff --git a/FE/src/components/module/QueryInput.jsx b/FE/src/components/module/QueryInput.jsx new file mode 100644 index 0000000..c95ff10 --- /dev/null +++ b/FE/src/components/module/QueryInput.jsx @@ -0,0 +1,115 @@ +import React, { useState, useRef, useEffect } from 'react' +import { useNavigate } from 'react-router-dom'; + +import { styled, Box } from '@mui/system' +import CustomContainer from '../atom/CustomContainer' +import IconBox from '../atom/IconBox'; +import InputText from '../atom/InputText' + +import FileIcon from '../../assets/icon/addFile.png' +import SearchIcon from '../../assets/icon/search.png' + +import { uploadFile } from '../../api/query'; + +export default function QueryInput({ height, model, mode, onQuerySubmit, onCompanySubmit, onFileUpload, uploadMessage }) { + const navigate = useNavigate(); + const fileInputRef = useRef(null); + + const [file, setFile] = useState(); + const [query, setQuery] = useState(''); + + // Company Mapping + + const dictionary = [ + { keywords: ['๋„ค์ด๋ฒ„', 'naver', 'NAVER'], mappedValue: 'NAVER' }, + { keywords: ['๋กฏ๋ฐ๋ Œํƒˆ', '๋กฏ๋ฐ ๋ Œํƒˆ'], mappedValue: '๋กฏ๋ฐ๋ Œํƒˆ' }, + { keywords: ['์—˜์•ค์—ํ”„'], mappedValue: '์—˜์•ค์—ํ”„' }, + { keywords: ['์นด๋ฑ…', '์นด์นด์˜ค ๋ฑ…ํฌ', '์นด์นด์˜ค๋ฑ…ํฌ'], mappedValue: '์นด์นด์˜ค๋ฑ…ํฌ' }, + { keywords: ['ํฌ๋ž˜ํ”„ํ†ค'], mappedValue: 'ํฌ๋ž˜ํ”„ํ†ค' }, + { keywords: ['ํ•œํ™”์†”๋ฃจ์…˜', 'ํ•œํ™” ์†”๋ฃจ์…˜'], mappedValue: 'ํ•œํ™”์†”๋ฃจ์…˜' }, + { keywords: ['์ œ์ผ์ œ๋‹น', 'CJ์ œ์ผ์ œ๋‹น', 'CJ ์ œ์ผ์ œ๋‹น', 'cj์ œ์ผ์ œ๋‹น', 'cj ์ œ์ผ์ œ๋‹น'], mappedValue: 'CJ์ œ์ผ์ œ๋‹น' }, + { keywords: ['LGํ™”ํ•™', 'LG ํ™”ํ•™', 'lgํ™”ํ•™', 'lg ํ™”ํ•™', '์—˜์ง€ํ™”ํ•™', '์—˜์ง€ ํ™”ํ•™'], mappedValue: 'LGํ™”ํ•™' }, + { keywords: ['SK์ผ€๋ฏธ์นผ', 'SK ์ผ€๋ฏธ์นผ', 'sk์ผ€๋ฏธ์นผ', 'sk ์ผ€๋ฏธ์นผ', '์ผ€๋ฏธ์นผ'], mappedValue: 'SK์ผ€๋ฏธ์นผ' }, + { keywords: ['SKํ•˜์ด๋‹‰์Šค', 'SK ํ•˜์ด๋‹‰์Šค', 'skํ•˜์ด๋‹‰์Šค', 'sk ํ•˜์ด๋‹‰์Šค', 'ํ•˜์ด๋‹‰์Šค'], mappedValue: 'SKํ•˜์ด๋‹‰์Šค' }, + ] + + 
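+  // mapCompany scans the keyword dictionary above and returns the first mappedValue whose keyword appears in the user's query, or an empty string when nothing matches.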
function mapCompany(input) { + for (const mapping of dictionary) { + if (mapping.keywords.some((keyword) => input.includes(keyword))) { + return mapping.mappedValue; + } + } + return ''; + } + + // Query Input + + function onKeyUp(e) { + if (e.key == 'Enter' && query.trim()) { + onClickSearch(); + } + } + + function onClickSearch() { + if (query.trim()) { + const company = mapCompany(query); + if (onCompanySubmit) onCompanySubmit(company); + + if (mode === 'main') { + navigate('/chat', { state: { query, model, company } }); + } + else if (onQuerySubmit) { + onCompanySubmit(company); + onQuerySubmit(query); + setQuery(''); + } + } + } + + // File Upload + + function onClickFile() { + if (fileInputRef.current) { + fileInputRef.current.click(); + } + } + + function handleFileChange(e) { + setFile(e.target.files[0]); + // onFileUpload('PDF ๋ฐ›์•„๋ผ ~'); + } + + function uploadFileSuccess(res) { + onFileUpload('PDF ์ „์†ก ์™„๋ฃŒ !'); + } + + function uploadFileFail(res) { + onFileUpload('์‹คํŒจ !'); + } + + useEffect(() => { + if (file) { + const formData = new FormData(); + formData.append('file', file); + + uploadFile(formData, uploadFileSuccess, uploadFileFail); + } + }, [file]) + + return ( + + setQuery(e.target.value)} value={query}/> + + +
+ + +
+
+ +
+
+ ) +} \ No newline at end of file diff --git a/FE/src/components/module/QueryOutput.jsx b/FE/src/components/module/QueryOutput.jsx new file mode 100644 index 0000000..c9407d4 --- /dev/null +++ b/FE/src/components/module/QueryOutput.jsx @@ -0,0 +1,67 @@ +import React, { useState, useEffect } from 'react'; + +import { styled, Box } from '@mui/system' +import CustomContainer from '../atom/CustomContainer'; +import CustomText from '../atom/CustomText'; +import IconBox from '../atom/IconBox'; + +import CopyIcon from '../../assets/icon/copy.png' +import ReloadIcon from '../../assets/icon/reload.png' +import LoadingIcon from '../../assets/icon/spinner.gif' + +export default function QueryOutput({ children, answer }) { + const [displayedText, setDisplayedText] = useState(''); + + function onClickCopy() { + navigator.clipboard.writeText(answer); + } + + function onClickReload() { + window.location.reload(); + } + + useEffect(() => { + let index = 0; + + const interval = setInterval(() => { + if (index < answer?.length) { + const char = answer[index]; + + setDisplayedText((prev) => prev + char); + index++; + } else { + clearInterval(interval); + } + }, 20); + + return () => clearInterval(interval); + }, [answer]); + + return ( + + + {children} + + + {displayedText ? + + ๐Ÿ’ฌ + ๋‹ต๋ณ€ + + : + + + hmm .. ์ƒ๊ฐ์ค‘ + + } + + + {displayedText} + + + + + + + ); +} diff --git a/FE/src/components/module/SelectModel.jsx b/FE/src/components/module/SelectModel.jsx new file mode 100644 index 0000000..41c32ec --- /dev/null +++ b/FE/src/components/module/SelectModel.jsx @@ -0,0 +1,24 @@ +import React from 'react'; + +import { Box } from '@mui/material'; + +export default function SelectModel({ onModelChange, selectedValue }) { + function handleChange(e) { + if (onModelChange) { + onModelChange(e.target.value); + } + } + + return ( + + + + + + ); +} diff --git a/FE/src/components/module/StockInfoBox.jsx b/FE/src/components/module/StockInfoBox.jsx new file mode 100644 index 0000000..09d2cb9 --- /dev/null +++ b/FE/src/components/module/StockInfoBox.jsx @@ -0,0 +1,13 @@ +import React from 'react' + +import { Box } from '@mui/material' +import CustomText from '../atom/CustomText' + +export default function StockInfoBox({ text, value, color }) { + return ( + + {text} + {value} + + ) +} \ No newline at end of file diff --git a/FE/src/components/module/StockNewsWidget.jsx b/FE/src/components/module/StockNewsWidget.jsx new file mode 100644 index 0000000..b027c50 --- /dev/null +++ b/FE/src/components/module/StockNewsWidget.jsx @@ -0,0 +1,58 @@ +import React, { useState, useEffect } from 'react'; +import axios from 'axios'; +import moment from 'moment'; + +import CustomText from '../atom/CustomText'; +import CustomContainer from '../atom/CustomContainer'; + +const URL = import.meta.env.VITE_NEWS_API_URL; +const currentDate = moment().format('YYYY-MM-DD'); +const oneWeekAgoDate = moment().subtract(7, 'days').format('YYYY-MM-DD'); + +export default function StockNewsWidget({ company }) { + const [newsData, setNewsData] = useState([]); + const [currentIndex, setCurrentIndex] = useState(0); + + function getNewsSuccess(res) { + if (res.data.data) { + setNewsData(res.data.data); + } + } + + function getNews() { + axios + .get(URL, { + params: { + keyword: `title:${company}`, + date_from: oneWeekAgoDate, + date_to: currentDate, + page_size: '30', + }, + }) + .then(getNewsSuccess); + } + + useEffect(() => { + getNews(); + }, [company]); + + useEffect(() => { + if (newsData.length > 0) { + const interval = setInterval(() => { + 
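+        // Advance to the next headline every 4 seconds, wrapping with modulo so the ticker loops.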
setCurrentIndex((prevIndex) => (prevIndex + 1) % newsData.length); + }, 4000); + + return () => clearInterval(interval); + } + }, [newsData]); + + return ( + + {newsData.length > 0 ? ( + + {'๐Ÿ“ข' + ' ' + newsData[currentIndex].title} + + ) : null} + + ); +} diff --git a/FE/src/components/module/StockWidget.jsx b/FE/src/components/module/StockWidget.jsx new file mode 100644 index 0000000..723d9f4 --- /dev/null +++ b/FE/src/components/module/StockWidget.jsx @@ -0,0 +1,104 @@ +import React, { useState, useEffect } from 'react' +import axios from 'axios' +import moment from 'moment' + +import { Box } from '@mui/system' +import CustomText from '../atom/CustomText' +import CustomContainer from '../atom/CustomContainer' +import StockInfoBox from './StockInfoBox' +import LoadingIcon from '../../assets/icon/spinner_widget.gif' + +const URL = 'https://apis.data.go.kr/1160100/service/GetStockSecuritiesInfoService/getStockPriceInfo'; +const apiKey = import.meta.env.VITE_STOCK_API_KEY; +const currentDate = moment().format('YYYYMMDD'); + +export default function StockWidget( { company }) { + const [stockData, setStockData] = useState(); + + function getStockSuccess(res) { + const data = res.data.response.body.items.item[0]; + + if (!data || data.length === 0) { + console.error('No data found in API response'); + return; + } + + setStockData({ + itmsNm: data.itmsNm, // ์ข…๋ชฉ๋ช… + basDt: data.basDt, // ๋‚ ์งœ + srtnCd: data.srtnCd, // ์ข…๋ชฉ์ฝ”๋“œ + mrktCtg: data.mrktCtg, // ์‹œ์žฅ๊ตฌ๋ถ„ + mkp: data.mkp, // ์‹œ๊ฐ€ + clpr: data.clpr, // ์ข…๊ฐ€ + hipr: data.hipr, // ๊ณ ๊ฐ€ + lopr: data.lopr, // ์ €๊ฐ€ + vs: data.vs, // ์ „์ผ ๋Œ€๋น„ + fltRt: data.fltRt, // ๋“ฑ๋ฝ๋ฅ  + mrktTotAmt: data.mrktTotAmt, // ์‹œ๊ฐ€์ด์•ก + trqu: data.trqu, // ๊ฑฐ๋ž˜๋Ÿ‰ + yesterdayClpr: res.data.response.body.items.item[1].clpr, + }); + } + + function getStock() { + axios.get(URL, { + params: { + serviceKey: apiKey, + resultType: 'json', + endBasDt: currentDate, + likeItmsNm: company, + numOfRows: 2, + pageNo: 1, + } + }) + .then(getStockSuccess); + } + + useEffect(() => { + getStock(); + }, [company]) + + return ( + + {stockData ? ( + + + {stockData.itmsNm} + + {stockData.srtnCd} + {stockData.mrktCtg} + + {`${new Intl.NumberFormat().format(stockData.clpr)}`} + + {parseFloat(stockData.vs) < 0 ? 
`โ–ผ ${new Intl.NumberFormat().format(stockData.vs.slice(1,))}` : `โ–ฒ ${new Intl.NumberFormat().format(stockData.vs)}`} + {`${stockData.fltRt}%`} + + + + + ์‹œ์„ธ์ •๋ณด + + + + ์ข…๋ชฉ์ •๋ณด + + + + + ) : ( + + + + + + + + + + + + + )} + + ); +} \ No newline at end of file diff --git a/FE/src/components/page/ChatPage.jsx b/FE/src/components/page/ChatPage.jsx new file mode 100644 index 0000000..d3a06fe --- /dev/null +++ b/FE/src/components/page/ChatPage.jsx @@ -0,0 +1,98 @@ +import React, { useState, useEffect, useRef } from 'react'; +import { useNavigate, useLocation } from 'react-router-dom'; + +import { styled, Box } from '@mui/system'; +import SideBar from '../atom/SideBar'; +import IconBox from '../atom/IconBox'; + +import StockWidget from '../module/StockWidget'; +import StockNewsWidget from '../module/StockNewsWidget'; +import QueryInput from '../module/QueryInput'; +import QueryOutput from '../module/QueryOutput'; + +import HomeIcon from '../../assets/icon/home.png'; +import Logo from '../../assets/logo.png' + +import { requestQuery } from '../../api/query'; + +export default function ChatPage() { + const navigate = useNavigate(); + const location = useLocation(); + const hasFetched = useRef(false); + + const query = location.state?.query; + const model = location.state?.model || 'GPT-4o-mini'; + + const [sessionId, setSessionId] = useState(''); + const [queries, setQueries] = useState([query]); + const [answers, setAnswers] = useState([]); + const [company, setCompany] = useState(location.state?.company); + const [chatHistory, setChatHistory] = useState([]); + + + const max_tokens = 1000; + const temperature = 0.7; + + function onClickHome() { + navigate('/'); + } + + function handleQuerySubmit(newQuery) { + setQueries((prev) => [...prev, newQuery]); + requestApi(newQuery); + } + + function handleCompanySubmit(newCompany) { + setCompany(newCompany); + } + + function requestApi(query) { + requestQuery( + sessionId, + query, + company === 'NAVER' ? '๋„ค์ด๋ฒ„' : company, + model, + max_tokens, + temperature, + chatHistory, + (res) => { + setSessionId(res.data.session_id); + if (res.data.company === '') setCompany(res.data.company === '๋„ค์ด๋ฒ„' ? 
'NAVER' : res.data.company); + setAnswers((prev) => [...prev, res.data.answer]); + setChatHistory(res.data.chat_history); + }, + (err) => console.log('requestQueryFail:', err) + ); + } + + useEffect(() => { + if (!hasFetched.current && query) { + hasFetched.current = true; + requestApi(query); + } + }, [query]); + + return ( + + + + + + + {company && } + + {company && } + + + + + + {queries.map((q, idx) => ( + {q} + ))} + + + + + ); +} \ No newline at end of file diff --git a/FE/src/components/page/MainPage.jsx b/FE/src/components/page/MainPage.jsx new file mode 100644 index 0000000..a53d154 --- /dev/null +++ b/FE/src/components/page/MainPage.jsx @@ -0,0 +1,70 @@ +import React, { useEffect, useState } from 'react'; + +import { styled, Box } from '@mui/system' +import CustomText from '../atom/CustomText'; +import SideBar from '../atom/SideBar' + +import IndexWidget from '../module/IndexWidget'; +import ExchangeRateWidget from '../module/ExchangeRateWidget'; +import NewsWidget from '../module/NewsWidget'; +import QueryInput from '../module/QueryInput'; +import SelectModel from '../module/SelectModel'; + +import Logo from '../../assets/logo.png' + +export default function MainPage() { + const [model, setModel] = useState(''); + const [message, setMessage] = useState(''); + const [visibleMessage, setVisibleMessage] = useState(''); + const [visibleIcon, setVisibleIcon] = useState(''); + + function handleChange(value) { + setModel(value); + }; + + function handleUpload(value) { + setMessage(value); + } + + useEffect(() => { + if (message) { + setVisibleMessage(message); + + if (message === 'PDF ๋ฐ›์•„๋ผ ~') { + setVisibleIcon(LoadingIcon); + } + else if (message === 'PDF ์ „์†ก ์™„๋ฃŒ !') { + setVisibleIcon(''); + const timer = setTimeout(() => { + setVisibleMessage(''); + }, 3000); + + return () => clearTimeout(timer); + } + + } + }, [message]) + + return ( + + + + + + + + + + + + {'์›ํ•˜๋Š” ๊ธˆ์œต์ •๋ณด๋ฅผ ๊ฒ€์ƒ‰ํ•ด๋ณด์„ธ์š” ' + '๐Ÿ”Ž'} + + + + + {visibleMessage} + + + + ); +} diff --git a/FE/src/configs/router.jsx b/FE/src/configs/router.jsx new file mode 100644 index 0000000..e0e2cc4 --- /dev/null +++ b/FE/src/configs/router.jsx @@ -0,0 +1,14 @@ +import React from 'react'; +import { BrowserRouter, Routes, Route } from 'react-router-dom'; + +import MainPage from '../components/page/MainPage'; +import ChatPage from '../components/page/ChatPage'; + +export default function RouterConfiguration() { + return ( + + } /> + } /> + + ); +} \ No newline at end of file diff --git a/FE/src/configs/theme.jsx b/FE/src/configs/theme.jsx new file mode 100644 index 0000000..3a564cf --- /dev/null +++ b/FE/src/configs/theme.jsx @@ -0,0 +1,9 @@ +import { createTheme } from '@mui/material/styles'; + +const theme = createTheme({ + typography: { + fontFamily: "'Pretendard-Regular', sans-serif", + }, +}); + +export default theme; \ No newline at end of file diff --git a/FE/src/index.css b/FE/src/index.css new file mode 100644 index 0000000..58bb9ee --- /dev/null +++ b/FE/src/index.css @@ -0,0 +1,45 @@ +@font-face { + font-family: 'GmarketSansLight'; + src: url('static/fonts/GmarketSansTTFLight.woff') format('woff'); + font-weight: normal; + font-style: normal; +} + +@font-face { + font-family: 'GmarketSansMedium'; + src: url('static/fonts/GmarketSansTTFMedium.woff') format('woff'); + font-weight: normal; + font-style: normal; +} + +@font-face { + font-family: 'GmarketSansBold'; + src: url('static/fonts/GmarketSansTTFBold.woff') format('woff'); + font-weight: normal; + font-style: normal; +} + +@font-face { + 
font-family: 'Pretendard-Light'; + src: url('https://fastly.jsdelivr.net/gh/Project-Noonnu/noonfonts_2107@1.1/Pretendard-Light.woff') format('woff'); + font-weight: 400; + font-style: normal; +} + +@font-face { + font-family: 'Pretendard-Regular'; + src: url('https://fastly.jsdelivr.net/gh/Project-Noonnu/noonfonts_2107@1.1/Pretendard-Regular.woff') format('woff'); + font-weight: 400; + font-style: normal; +} + +@font-face { + font-family: 'Pretendard-Bold'; + src: url('https://fastly.jsdelivr.net/gh/Project-Noonnu/noonfonts_2107@1.1/Pretendard-SemiBold.woff') format('woff'); + font-weight: 400; + font-style: normal; +} + +@tailwind base; +@tailwind components; +@tailwind utilities; diff --git a/FE/src/main.jsx b/FE/src/main.jsx new file mode 100644 index 0000000..4f47a9c --- /dev/null +++ b/FE/src/main.jsx @@ -0,0 +1,21 @@ +import { StrictMode } from 'react'; +import { createRoot } from 'react-dom/client'; +import { BrowserRouter } from 'react-router-dom'; + +import App from './App.jsx'; +import theme from './configs/theme'; +import './index.css'; + +import { ThemeProvider } from '@mui/material/styles'; +import CssBaseline from '@mui/material/CssBaseline'; + +createRoot(document.getElementById('root')).render( + + + + + + + + , +); diff --git a/FE/tailwind.config.js b/FE/tailwind.config.js new file mode 100644 index 0000000..2051635 --- /dev/null +++ b/FE/tailwind.config.js @@ -0,0 +1,9 @@ +/** @type {import('tailwindcss').Config} */ + +module.exports = { + content: ["./src/**/*.{js,jsx,ts,tsx}"], + theme: { + extend: {}, + }, + plugins: [require("daisyui")], +}; diff --git a/FE/vite.config.js b/FE/vite.config.js new file mode 100644 index 0000000..8b0f57b --- /dev/null +++ b/FE/vite.config.js @@ -0,0 +1,7 @@ +import { defineConfig } from 'vite' +import react from '@vitejs/plugin-react' + +// https://vite.dev/config/ +export default defineConfig({ + plugins: [react()], +}) diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..0ad25db --- /dev/null +++ b/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. 
+ + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. 
+ + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 
+ + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. 
+ + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. 
+ + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. 
+ + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. 
+ + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..9cf8224 --- /dev/null +++ b/Makefile @@ -0,0 +1,39 @@ +clean: clean-pyc clean-test +quality: set-style-dep check-quality +style: set-style-dep set-style + +##### basic ##### +set-git: + git config --local commit.template .gitmessage + git update-index --skip-worktree ./config/config.yaml + +set-style-dep: + pip3 install click==8.0.4 isort==5.13.2 black==24.8.0 flake8==7.1.1 + +set-style: + black --config pyproject.toml . 
+ isort --settings-path pyproject.toml . + flake8 . --max-line-length=120 + +check-quality: + black --config pyproject.toml --check . + isort --settings-path pyproject.toml --check-only . + +##### clean ##### +clean-pyc: + find . -name '*.pyc' -exec rm -f {} + + find . -name '*.pyo' -exec rm -f {} + + find . -name '*~' -exec rm -f {} + + find . -name '__pycache__' -exec rm -fr {} + + +clean-test: + rm -f .coverage + rm -f .coverage.* + rm -rf .pytest_cache + rm -rf .mypy_cache + +clean-all: clean-pyc clean-test clean-build + +clean-build: + rm -rf build/ + rm -rf dist/ \ No newline at end of file diff --git a/PDF_OCR/README.MD b/PDF_OCR/README.MD new file mode 100644 index 0000000..40275b7 --- /dev/null +++ b/PDF_OCR/README.MD @@ -0,0 +1,57 @@ +# PDF_OCR ํŒŒ์ดํ”„๋ผ์ธ + +## ํŒŒ์ดํ”„๋ผ์ธ ์‹คํ–‰ ๋ฐฉ๋ฒ• + +### ํŒจํ‚ค์ง€ ์„ค์น˜ +``` bash +pip install -r requirements.txt +``` + +### ์‹คํ–‰ +``` bash +python pdf_parser.py -i "./pdf/input_pdf_folder" -r +python data_postprocessor.py +``` + +--- + +## ํŒŒ์ผ ๊ตฌ์กฐ + +``` +PDF_OCR/ +โ”œโ”€โ”€ config.py +โ”œโ”€โ”€ ocr_api.py +โ”œโ”€โ”€ pdf_parser.py +โ”œโ”€โ”€ ocr_processor.py +โ”œโ”€โ”€ table_converter.py +โ”œโ”€โ”€ data_postprocessor.py +โ”œโ”€โ”€ requirements.txt +โ”œโ”€โ”€ README.MD +โ”œโ”€โ”€ pdf/ +โ”‚ โ”œโ”€โ”€ input_pdf_folder/ +โ”‚ โ”‚ โ”œโ”€โ”€ pdf_file1.pdf +โ”‚ โ”‚ โ”œโ”€โ”€ pdf_file2.pdf +โ”‚ โ”‚ โ””โ”€โ”€ ... +โ”œโ”€โ”€ ocr_results/ +โ”‚ โ”œโ”€โ”€ input_pdf_folder/ +โ”‚ โ”‚ โ”œโ”€โ”€ pdf_file1/ +โ”‚ โ”‚ โ”‚ โ”œโ”€โ”€ page_1/ +โ”‚ โ”‚ โ”‚ โ”‚ โ”œโ”€โ”€ 1_plain text_3_result.json +โ”‚ โ”‚ โ”‚ โ”‚ โ””โ”€โ”€ ... +โ”‚ โ”‚ โ”‚ โ””โ”€โ”€ ... +โ”‚ โ”‚ โ””โ”€โ”€ ... +``` +--- + +## ํŒŒ์ดํ”„๋ผ์ธ ์„ค๋ช… + +1. PDF ํŒŒ์ผ์„ ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜ (PDF -> ์ด๋ฏธ์ง€) +2. ์ด๋ฏธ์ง€๋ฅผ ๋ฐ”์šด๋”ฉ ๋ฐ•์Šค๋กœ ์ถ”์ถœ (DocLayout-YOLO) (์ด๋ฏธ์ง€ -> ๋ฐ”์šด๋”ฉ ๋ฐ•์Šค) +3. OCR ์ฒ˜๋ฆฌ (Clova OCR/Upstage Parser API) (๋ฐ”์šด๋”ฉ ๋ฐ•์Šค -> OCR ๊ฒฐ๊ณผ(json)) +4. OCR ๊ฒฐ๊ณผ๋ฅผ ์ •์ œ (json -> json) +4.1. description ๋‹ฌ๊ธฐ +4.2 ํ…Œ์ด๋ธ” description์€ LLM์—๊ฒŒ query +4.2.1 ํ…Œ์ด๋ธ”์€ csv๋„ ์ €์žฅ (json -> csv) +5. ์ •์ œ๋œ ๊ฒฐ๊ณผ๋ฅผ ์ทจํ•ฉ (json -> json) +6. 
์ทจํ•ฉ๋œ ๊ฒฐ๊ณผ๋ฅผ vector DB์— ์ €์žฅ (json -> vector DB) + diff --git a/PDF_OCR/__init__.py b/PDF_OCR/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/PDF_OCR/config.py b/PDF_OCR/config.py new file mode 100644 index 0000000..623ef33 --- /dev/null +++ b/PDF_OCR/config.py @@ -0,0 +1,59 @@ +import os +from pathlib import Path + +# ํ”„๋กœ์ ํŠธ ๋ฃจํŠธ ๋””๋ ‰ํ† ๋ฆฌ ์„ค์ • +PROJECT_ROOT = Path(__file__).parent # .parent ์ œ๊ฑฐํ•˜์—ฌ PDF_OCR ๋””๋ ‰ํ† ๋ฆฌ๋ฅผ ๋ฃจํŠธ๋กœ ์„ค์ • + +# ๊ธฐ๋ณธ ์„ค์ •๊ฐ’๋“ค +DEFAULT_CONFIG = { + # ๋ชจ๋ธ ๊ด€๋ จ ์„ค์ • + "MODEL": { + "path": os.path.expanduser( + "~/.cache/huggingface/hub/models--juliozhao--DocLayout-YOLO-DocStructBench/snapshots/8c3299a30b8ff29a1503c4431b035b93220f7b11/doclayout_yolo_docstructbench_imgsz1024.pt" + ), + # "path": "doclayout_yolo_docstructbench_imgsz1024.pt", # ๊ฐ„๋‹จํ•œ ๊ธฐ๋ณธ ๊ฒฝ๋กœ + "imgsz": 1024, + "line_width": 5, + "font_size": 20, + "conf": 0.2, + "threshold": 0.05, + }, + # ๋””๋ ‰ํ† ๋ฆฌ ์„ค์ • + "DIRS": { + "input_dir": str(PROJECT_ROOT / "pdf"), # PDF ํŒŒ์ผ์ด ์žˆ๋Š” ๋””๋ ‰ํ† ๋ฆฌ + "output_dir": str(PROJECT_ROOT / "output"), # ์ค‘๊ฐ„ ๊ฒฐ๊ณผ๋ฌผ ์ €์žฅ ๋””๋ ‰ํ† ๋ฆฌ + "database_dir": str(PROJECT_ROOT / "database"), # ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์ €์žฅ ๋””๋ ‰ํ† ๋ฆฌ + "ocr_output_dir": str(PROJECT_ROOT / "ocr_results"), # OCR ๊ฒฐ๊ณผ ์ €์žฅ ๋””๋ ‰ํ† ๋ฆฌ + }, + # ํŒŒ์ผ๋ช… ์„ค์ • + "FILES": { + "database": "database.csv", + }, +} + + +def get_config(custom_config=None): + """ + ๊ธฐ๋ณธ ์„ค์ •๊ฐ’๊ณผ ์‚ฌ์šฉ์ž ์ •์˜ ์„ค์ •๊ฐ’์„ ๋ณ‘ํ•ฉํ•˜์—ฌ ๋ฐ˜ํ™˜ + + Args: + custom_config (dict, optional): ์‚ฌ์šฉ์ž ์ •์˜ ์„ค์ •๊ฐ’ + + Returns: + dict: ์ตœ์ข… ์„ค์ •๊ฐ’ + """ + config = DEFAULT_CONFIG.copy() + + if custom_config: + # ์ค‘์ฒฉ๋œ ๋”•์…”๋„ˆ๋ฆฌ ์—…๋ฐ์ดํŠธ + for key, value in custom_config.items(): + if isinstance(value, dict) and key in config: + config[key].update(value) + else: + config[key] = value + + # ๋””๋ ‰ํ† ๋ฆฌ๋“ค ์ƒ์„ฑ + for dir_path in config["DIRS"].values(): + os.makedirs(dir_path, exist_ok=True) + + return config diff --git a/PDF_OCR/data_postprocess.py b/PDF_OCR/data_postprocess.py new file mode 100644 index 0000000..f5dc096 --- /dev/null +++ b/PDF_OCR/data_postprocess.py @@ -0,0 +1,362 @@ +import json +import os +import sys +import time +import warnings + +import pandas as pd +import requests +from bs4 import BeautifulSoup +from dotenv import load_dotenv + +load_dotenv() +warnings.filterwarnings("ignore") + +""" + {//text 3 + "title":"24๋…„ ์˜์—…์ด์ต", + "description":"{์›๋ฌธ}", + "category":"text", + "company":"naver", + "securities":"hana", + "page":"1", + "date":"24.10.17", + "path":"/cation/naver/kybo/1017/1/1_plain text_3.png" + } + api๋ฅผ ์ด์šฉํ•ด html์„ queryํ•ด html์— ๋Œ€ํ•œ ์„ค๋ช…์„ description์— ๋„ฃ์–ด์ค€๋‹ค. + ์ด๋ ‡๊ฒŒ ๋งŒ๋“ค์–ด์ง„ ๋ฐ์ดํ„ฐ๋ฅผ ๋ชจ๋‘ ๋ชจ์•„์„œ ํ•˜๋‚˜์˜ ํŒŒ์ผ๋กœ ์ €์žฅํ•œ๋‹ค. 
+ +""" + + +class MakeData: + def __init__(self): + self.base_folder = "ocr_results" + self.output_folder = "ocr_results" + self.error_cnt = 0 + # ๊ฒฐ๊ณผ ์ €์žฅํ•  ํด๋” ์ƒ์„ฑ + self.existing_data = self.load_existing_data() + self.failed_logs = self.load_failed_logs() + + if not os.path.exists(self.output_folder): + os.makedirs(self.output_folder) + + def load_existing_data(self): + try: + with open("new_data/All_data/table_data.json", "r", encoding="utf-8") as f: + return json.load(f) + except FileNotFoundError: + return [] + + def load_failed_logs(self): + try: + with open("fail_logs.json", "r", encoding="utf-8") as f: + return json.load(f) + except FileNotFoundError: + return [] + + def process_folders(self): + data = self.existing_data + try: + # ์ฒซ ๋ฒˆ์งธ ์ฒ˜๋ฆฌ ์‹œ๋„ + self._process_all_folders(data) + + # rate limit ์˜ค๋ฅ˜๊ฐ€ ์žˆ๋Š” ์ผ€์ด์Šค ์žฌ์ฒ˜๋ฆฌ + retry_count = 0 + while retry_count < 3: + rate_limit_files = [] + for log in self.failed_logs: + if log.get("status_code") == "42901": # rate limit ์˜ค๋ฅ˜ + rate_limit_files.append(log["file_path"]) + + if not rate_limit_files: + break + + print(f"\n์žฌ์‹œ๋„ {retry_count + 1}: Rate limit ์˜ค๋ฅ˜ ํŒŒ์ผ {len(rate_limit_files)}๊ฐœ ์žฌ์ฒ˜๋ฆฌ ์‹œ์ž‘") + time.sleep(60) # rate limit ํ•ด์ œ๋ฅผ ์œ„ํ•ด 1๋ถ„ ๋Œ€๊ธฐ + + for file_path in rate_limit_files: + description = self.process_table_json_files(file_path) + if description: # ์„ฑ๊ณต์ ์œผ๋กœ ์ฒ˜๋ฆฌ๋œ ๊ฒฝ์šฐ + # ์„ฑ๊ณตํ•œ ํŒŒ์ผ์˜ ๋ฐ์ดํ„ฐ ์ถ”๊ฐ€ + path_parts = file_path.split(os.sep) + company = path_parts[1] + broker = path_parts[2] + page = path_parts[3] + broker_date = broker.split("_")[-1] + broker_name = broker_date.split("(")[0] + broker_date = broker_date.split("(")[1].replace(")", "") + + data.append( + { + "title": "", + "description": description, + "category": "table", + "company": company, + "securities": broker_name, + "page": page, + "date": broker_date, + "path": file_path, + } + ) + # ์ฒ˜๋ฆฌ๋œ ํŒŒ์ผ ๋กœ๊ทธ ํŒŒ์ผ ์ œ๊ฑฐ + + # + retry_count += 1 + + # ์ตœ์ข… ๋ฐ์ดํ„ฐ ์ €์žฅ + with open("new_data/All_data/table_data.json", "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + except Exception as e: + print(f"์ „์ฒด ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}") + self.save_failed_log("process_folders", str(e)) + + def _process_all_folders(self, data): + # ๊ธฐ์กด์˜ process_folders ๋กœ์ง์„ ์—ฌ๊ธฐ๋กœ ์ด๋™ + for company in os.listdir(self.base_folder): + company_path = os.path.join(self.base_folder, company) + if not os.path.isdir(company_path): + continue + + # ํšŒ์‚ฌ๋ณ„ ๊ฒฐ๊ณผ ํด๋” ์ƒ์„ฑ + company_output = os.path.join(self.output_folder, company) + if not os.path.exists(company_output): + os.makedirs(company_output) + + # ์ฆ๊ถŒ์‚ฌ๋ณ„ ํด๋” ์ˆœํšŒ + for broker in os.listdir(company_path): + broker_path = os.path.join(company_path, broker) + if not os.path.isdir(broker_path): + continue + print(f"์ฆ๊ถŒ์‚ฌ๋ณ„ ํด๋” ์ˆœํšŒ: {broker}") + # ์ฆ๊ถŒ์‚ฌ๋ณ„ ๊ฒฐ๊ณผ ํด๋” ์ƒ์„ฑ + broker_output = os.path.join(company_output, broker) + if not os.path.exists(broker_output): + os.makedirs(broker_output) + + # ํŽ˜์ด์ง€๋ณ„ ํด๋” ์ˆœํšŒ + for page in os.listdir(broker_path): + page_path = os.path.join(broker_path, page) + if not os.path.isdir(page_path): + continue + # ํŽ˜์ด์ง€๋ณ„ ๊ฒฐ๊ณผ ํด๋” ์ƒ์„ฑ ํ—ท๊ฐˆ๋ ค์ฃฝ๊ฒ ๋„ค + page_output = os.path.join(broker_output, page) + if not os.path.exists(page_output): + os.makedirs(page_output) + + # html ํŒŒ์ผ ์ฒ˜๋ฆฌ + + for file in os.listdir(page_path): + if not file.lower().endswith((".json")): + continue + 
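+                        # Only table result JSONs are summarized here; plain-text
+                        # results are handled by TextDataPostprocess below.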
if not "table" in file: + continue + description = self.process_table_json_files(os.path.join(page_path, file)) + broker_date = broker.split("_")[-1] + broker_name = broker_date.split("(")[0] + broker_date = broker_date.split("(")[1].replace(")", "") + data_category = file.split("_")[1] + data.append( + { + "title": "", + "description": description, + "category": "table", + "company": company, + "securities": broker_name, + "page": page, + "date": broker_date, + "path": f"./ocr_results/{company}/{page}/{file}", + } + ) + + def process_table_json_files(self, input_path): + + try: + with open(input_path, "r", encoding="utf-8-sig") as f: + json_data = json.load(f) + html = json_data["content"]["html"] + + # api๋ฅผ ์ด์šฉํ•ด html์„ queryํ•ด html์— ๋Œ€ํ•œ ์„ค๋ช…์„ description์— ๋„ฃ์–ด์ค€๋‹ค. + # ์ด๋ ‡๊ฒŒ ๋งŒ๋“ค์–ด์ง„ ๋ฐ์ดํ„ฐ๋ฅผ ๋ชจ๋‘ ๋ชจ์•„์„œ ํ•˜๋‚˜์˜ ํŒŒ์ผ๋กœ ์ €์žฅํ•œ๋‹ค. + + api_url = "https://clovastudio.stream.ntruss.com/testapp/v1/chat-completions/HCX-003" + studio_key = os.getenv("clova_studio_api_key") + request_id = os.getenv("clova_request_id") + headers = { + "Authorization": "Bearer " + studio_key, + "X-NCP-CLOVASTUDIO-REQUEST-ID": request_id, + "Content-Type": "application/json; charset=utf-8", + } + # print(f"์ฒ˜๋ฆฌ ์™„๋ฃŒ: {output_base} : {file}") + preset_text = [ + { + "role": "system", + "content": "์ฃผ์–ด์ง„ html์€ table์„ html๋กœ ํ‘œํ˜„ํ•œ ๊ฒƒ์ž…๋‹ˆ๋‹ค. ํ•ด๋‹น ํ‘œ์—์„œ ์ˆ˜์น˜๋ฅผ ์ œ์™ธํ•œ ๋ชจ๋“  ํ•ญ๋ชฉ์˜ ์ •๋ณด๋ฅผ ๋ฌธ์žฅ์œผ๋กœ ์š”์•ฝํ•ด์„œ ์•Œ๋ ค์ฃผ์„ธ์š”. ์„ธ๋ถ€ํ•ญ๋ชฉ์˜ ์ •๋ณด๋„ ํฌํ•จํ•ด์ฃผ์„ธ์š”\n์˜ˆ์‹œ: ํ•ด๋‹น ํ‘œ๋Š” 2022A๋ถ€ํ„ฐ 2026F๊นŒ์ง€์˜ ๋งค์ถœ์•ก, ๋งค์ถœ์›๊ฐ€, ๋งค์ถœ์ด์ด์ต, ํŒ๋งค๋น„์™€๊ด€๋ฆฌ๋น„, ์˜์—…์ด์ต, ...(์ „๋ถ€๋‹ค) ์žฌ๋ฌด์ •๋ณด๋ฅผ ์ œ๊ณตํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.", + }, + {"role": "user", "content": html}, + {"role": "assistant", "content": ""}, + ] + request_data = { + "messages": preset_text, + "topP": 0.8, + "topK": 0, + "maxTokens": 400, + "temperature": 0.5, + "repeatPenalty": 5.0, + "stopBefore": [], + "includeAiFilters": True, + "seed": 0, + } + # Query Per Minute 60ํšŒ ์ดํ•˜๋กœ ๊ณ ์ • + time.sleep(2) + response = requests.post(api_url, headers=headers, json=request_data) + response_json = response.json() + if response_json["status"]["code"] != "20000": + error_message = response_json["status"]["message"] + print(f"FAILED : {input_path} - {error_message}") + self.save_failed_log(input_path, error_message, response_json["status"]["code"]) + return "" + else: + respon_msg = response_json["status"]["code"] + print(f"{input_path} SUCCESS : {respon_msg} ") + return response_json["result"]["message"]["content"] + + except Exception as e: + + print(f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}") + self.save_failed_log(input_path, str(e)) + return "" + + def save_failed_log(self, file_path, error_message, status_code=None): + log_entry = { + "file_path": file_path, + "error_message": error_message, + "status_code": status_code, + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), + } + self.failed_logs.append(log_entry) + with open("fail_logs.json", "w", encoding="utf-8") as f: + json.dump(self.failed_logs, f, ensure_ascii=False, indent=2) + + +class TextDataPostprocess: + def __init__(self): + self.base_folder = "ocr_results" + self.output_folder = "ocr_results" + self.existing_data = self.load_existing_data() + self.failed_logs = self.load_failed_logs() + + if not os.path.exists(self.output_folder): + os.makedirs(self.output_folder) + + def load_existing_data(self): + try: + with open("new_data/All_data/text_data.json", "r", encoding="utf-8") as f: + return 
json.load(f) + except FileNotFoundError: + return [] + + def load_failed_logs(self): + try: + with open("fail_logs.json", "r", encoding="utf-8") as f: + return json.load(f) + except FileNotFoundError: + return [] + + def process_folders(self): + data = [] + for company in os.listdir(self.base_folder): + company_path = os.path.join(self.base_folder, company) + if not os.path.isdir(company_path): + continue + + # ํšŒ์‚ฌ๋ณ„ ๊ฒฐ๊ณผ ํด๋” ์ƒ์„ฑ + company_output = os.path.join(self.output_folder, company) + if not os.path.exists(company_output): + os.makedirs(company_output) + + # # ์ฆ๊ถŒ์‚ฌ๋ณ„ ํด๋” ์ˆœํšŒ + for broker in os.listdir(company_path): + broker_path = os.path.join(company_path, broker) + if not os.path.isdir(broker_path): + continue + + # # ์ฆ๊ถŒ์‚ฌ๋ณ„ ๊ฒฐ๊ณผ ํด๋” ์ƒ์„ฑ + broker_output = os.path.join(company_output, broker) + if not os.path.exists(broker_output): + os.makedirs(broker_output) + + # ํŽ˜์ด์ง€๋ณ„ ํด๋” ์ˆœํšŒ + for page in os.listdir(broker_path): + page_path = os.path.join(broker_path, page) + if not os.path.isdir(page_path): + continue + + # ํŽ˜์ด์ง€๋ณ„ ๊ฒฐ๊ณผ ํด๋” ์ƒ์„ฑ ํ—ท๊ฐˆ๋ ค์ฃฝ๊ฒ ๋„ค + page_output = os.path.join(broker_output, page) + + if not os.path.exists(page_output): + os.makedirs(page_output) + + for file in os.listdir(page_path): + + if not file.lower().endswith((".json")): + continue + if not "text" in file: + continue + print(f"text ์ฒ˜๋ฆฌ ์ค‘: {file}") + description = self.process_text_json_files(os.path.join(page_path, file)) + broker_date = broker.split("_")[-1] + broker_name = broker_date.split("(")[0] + broker_date = broker_date.split("(")[1].replace(")", "") + data_category = file.split("_")[1] + print(description) + data.append( + { + "title": "", + "description": description, + "category": "text", + "company": company, + "securities": "All_data", + "page": page, + "date": "All_data", + "path": f"./ocr_results/{company}/{page}/{file}", + } + ) + with open("new_data/All_data/text_data.json", "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + def process_text_json_files(self, input_path): + try: + with open(input_path, "r", encoding="utf-8") as f: + json_data = json.load(f) + + # images ๋ฐฐ์—ด์˜ ๊ฐ ์ด๋ฏธ์ง€์—์„œ fields ๋ฐฐ์—ด์„ ์ˆœํšŒํ•˜๋ฉฐ inferText ์ถ”์ถœ + all_text = [] + for image in json_data.get("images", []): + for field in image.get("fields", []): + if "inferText" in field: + all_text.append(field["inferText"]) + + # ๋ชจ๋“  ํ…์ŠคํŠธ๋ฅผ ๊ณต๋ฐฑ์œผ๋กœ ์—ฐ๊ฒฐ + return " ".join(all_text) + + except Exception as e: + print(f"ํ…์ŠคํŠธ ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}") + self.save_failed_log(input_path, str(e)) + return "" + + +def main(): + + # processor = MakeData() + # processor.process_folders() + + processor2 = TextDataPostprocess() + processor2.process_folders() + sys.exit() + + +if __name__ == "__main__": + main() diff --git a/PDF_OCR/ocr_api.py b/PDF_OCR/ocr_api.py new file mode 100644 index 0000000..12fec81 --- /dev/null +++ b/PDF_OCR/ocr_api.py @@ -0,0 +1,60 @@ +import json +import os +import time +import uuid + +import pandas as pd +import requests +from dotenv import load_dotenv + +load_dotenv() + + +def process_image_ocr(image_file, is_table=False): + """ + ์ด๋ฏธ์ง€ ํŒŒ์ผ์— ๋Œ€ํ•ด OCR์„ ์ˆ˜ํ–‰ํ•˜๋Š” ํ•จ์ˆ˜ + + Args: + image_file (str): OCR์„ ์ˆ˜ํ–‰ํ•  ์ด๋ฏธ์ง€ ํŒŒ์ผ ๊ฒฝ๋กœ + + Returns: + dict: OCR ๊ฒฐ๊ณผ + """ + api_url = os.getenv("clova_api_url") + secret_key = os.getenv("clova_secret_key") + """ + naver clova ocr api ์‚ฌ์šฉ + version : model version + requestId : ์š”์ฒญ ๊ณ 
์œ  ์‹๋ณ„์ž + timestamp : ์š”์ฒญ ์‹œ๊ฐ„ + enableTableDetection: ํ…Œ์ด๋ธ” ์—ฌ๋ถ€ + """ + request_json = { + "images": [{"format": "png", "name": "demo"}], + "requestId": str(uuid.uuid4()), + "version": "V2", + "timestamp": int(round(time.time() * 1000)), + "enableTableDetection": is_table, + } + + payload = {"message": json.dumps(request_json).encode("UTF-8")} + + with open(image_file, "rb") as f: + files = [("file", f)] + headers = {"X-OCR-SECRET": secret_key} + response = requests.request("POST", api_url, headers=headers, data=payload, files=files) + + return response.json() + + +def upstage_ocr(image_file): + api_url = os.getenv("up_stage_url") + secret_key = os.getenv("up_stage_api_key") + + with open(image_file, "rb") as f: + files = {"document": open(image_file, "rb")} + data = {"ocr": "force", "base64_encoding": "['table']", "model": "document-parse"} + headers = {"Authorization": f"Bearer {secret_key}"} + response = requests.request("POST", api_url, headers=headers, files=files, data=data) + + return response.json() diff --git a/PDF_OCR/ocr_processor.py b/PDF_OCR/ocr_processor.py new file mode 100644 index 0000000..497ffeb --- /dev/null +++ b/PDF_OCR/ocr_processor.py @@ -0,0 +1,104 @@ +import json +import os + +import pandas as pd +from ocr_api import process_image_ocr, upstage_ocr +from table_converter import json_to_table + + +class OCRProcessor: + def __init__(self, base_folder="pdf", output_folder="./ocr_results"): + self.base_folder = base_folder + self.output_folder = output_folder + + # ๊ฒฐ๊ณผ ์ €์žฅํ•  ํด๋” ์ƒ์„ฑ + if not os.path.exists(self.output_folder): + os.makedirs(self.output_folder) + + def process_folders(self): + # PDF ํŒŒ์ผ๋ช…์œผ๋กœ ์ƒ์„ฑ๋œ ํด๋” ์ˆœํšŒ + for pdf_folder in os.listdir(self.base_folder): + pdf_path = os.path.join(self.base_folder, pdf_folder) + if not os.path.isdir(pdf_path): + continue + + # PDF๋ณ„ ๊ฒฐ๊ณผ ํด๋” ์ƒ์„ฑ + pdf_output = os.path.join(self.output_folder, pdf_folder) + if not os.path.exists(pdf_output): + os.makedirs(pdf_output) + + # images ํด๋” ๊ฒฝ๋กœ + images_path = os.path.join(pdf_path, "images") + if not os.path.exists(images_path): + continue + + # ํŽ˜์ด์ง€๋ณ„ ํด๋” ์ˆœํšŒ + for page in os.listdir(images_path): + page_path = os.path.join(images_path, page) + if not os.path.isdir(page_path): + continue + + # split_images ํด๋” ๊ฒฝ๋กœ + split_images_path = os.path.join(page_path, "split_images") + if not os.path.exists(split_images_path): + continue + + # ํŽ˜์ด์ง€๋ณ„ ๊ฒฐ๊ณผ ํด๋” ์ƒ์„ฑ + page_output = os.path.join(pdf_output, page) + if not os.path.exists(page_output): + os.makedirs(page_output) + + # ์ด๋ฏธ์ง€ ํŒŒ์ผ ์ฒ˜๋ฆฌ + self.process_image_files(split_images_path, page_output) + + def process_image_files(self, input_path, output_path): + for file in os.listdir(input_path): + # plain text๋‚˜ table์ด ํฌํ•จ๋œ ํŒŒ์ผ๋งŒ ์ฒ˜๋ฆฌ + if not ("plain text" in file.lower() or "table" in file.lower()): + continue + # ํ…Œ์ด๋ธ” ํŒŒ์ผ๋งŒ ์ฒ˜๋ฆฌ + if "table" in file.lower(): + if "caption" in file.lower() or "footnote" in file.lower() or "caption" in file.lower(): + continue + if not file.lower().endswith((".png", ".jpg", ".jpeg")): + continue + + input_file = os.path.join(input_path, file) + output_base = os.path.join(output_path, os.path.splitext(file)[0]) + + try: + # ๊ฒฐ๊ณผ๊ฐ€ ํ…Œ์ด๋ธ”์ธ ๊ฒฝ์šฐ + if "table" in file.lower(): + result = upstage_ocr(input_file) + # JSON ๊ฒฐ๊ณผ ์ €์žฅ + json_path = f"{output_base}_result.json" + with open(json_path, "w", encoding="utf-8") as f: + json.dump(result, f, indent=2, 
ensure_ascii=False) + + # ํ…Œ์ด๋ธ” ๋ฐ์ดํ„ฐ ์ถ”์ถœ ๋ฐ CSV ์ €์žฅ + try: + table_df = json_to_table(result) + table_df.to_csv(f"{output_base}.csv", encoding="utf-8-sig") + except Exception as e: + print(f"ํ…Œ์ด๋ธ” ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ ({file}): {str(e)}") + + # ์ผ๋ฐ˜ ํ…์ŠคํŠธ์ธ ๊ฒฝ์šฐ + else: + result = process_image_ocr(input_file, is_table=False) + # JSON ๊ฒฐ๊ณผ๋งŒ ์ €์žฅ + with open(f"{output_base}_result.json", "w", encoding="utf-8") as f: + json.dump(result, f, indent=2, ensure_ascii=False) + + print(f"์ฒ˜๋ฆฌ ์™„๋ฃŒ: {file}") + + except Exception as e: + print(f"ํŒŒ์ผ ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ ({file}): {str(e)}") + + +def main(): + processor = OCRProcessor() + processor.process_folders() + + +if __name__ == "__main__": + main() diff --git a/PDF_OCR/pdf_parser.py b/PDF_OCR/pdf_parser.py new file mode 100644 index 0000000..707085b --- /dev/null +++ b/PDF_OCR/pdf_parser.py @@ -0,0 +1,659 @@ +from typing import Any, Callable, Dict, List, Tuple + +import argparse +import os +import re +import shutil +import sys +from collections import defaultdict +from functools import cmp_to_key +from pathlib import Path + +import cv2 +import huggingface_hub +import numpy as np +import pandas as pd +import torch +from config import get_config +from doclayout_yolo import YOLOv10 +from huggingface_hub import hf_hub_download # ์ƒ๋‹จ์— import ์ถ”๊ฐ€ +from pdf2image import convert_from_path +from tqdm import tqdm + + +def pdf_to_image(pdf_path: str, save_path: str, db_path: str, verbose: bool = False) -> None: + """ + ์ฃผ์–ด์ง„ PDF ํŒŒ์ผ์„ ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜ํ•˜๊ณ , PDF ํŒŒ์ผ์„ ์ง€์ •๋œ ๋””๋ ‰ํ† ๋ฆฌ๋กœ ์ด๋™ํ•˜๋ฉฐ, ๋ณ€ํ™˜๋œ ์ด๋ฏธ์ง€๋ฅผ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค. + ๋˜ํ•œ ๋ณ€ํ™˜ํ•œ ์ •๋ณด๋ฅผ `database.csv` ํŒŒ์ผ์— ๊ธฐ๋กํ•ฉ๋‹ˆ๋‹ค. + + Args: + pdf_path (str): ๋ณ€ํ™˜ํ•  PDF ํŒŒ์ผ์˜ ๊ฒฝ๋กœ. + save_path (str): ๋ณ€ํ™˜๋œ ์ด๋ฏธ์ง€์™€ PDF ํŒŒ์ผ์„ ์ €์žฅํ•  ํด๋” ๊ฒฝ๋กœ. + db_path (str): ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ๊ฒฝ๋กœ + verbose (bool, optional): ์ด๋ฏธ์ง€ ์ €์žฅ ์ง„ํ–‰ ์ƒํ™ฉ์„ ์ถœ๋ ฅํ• ์ง€ ์—ฌ๋ถ€ (๊ธฐ๋ณธ๊ฐ’์€ False). + + Returns: + None: ํ•จ์ˆ˜๋Š” ๋ฐ˜ํ™˜๊ฐ’์ด ์—†์Šต๋‹ˆ๋‹ค. 
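+
+    Example layout (illustrative file name "sample.pdf"):
+        <save_path>/sample/sample.pdf          # original PDF, moved here
+        <save_path>/sample/images/page_1.png   # one 300-dpi PNG per page
+        <save_path>/sample/images/page_2.png
+    One row per page (company_name, file_name, page) is appended to the CSV at db_path.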
+ """ + + # ์ข…๋ชฉ ์ด๋ฆ„ + company_name = os.path.basename(save_path) + + # PDF ํŒŒ์ผ ์ด๋ฆ„์„ ๊ธฐ๋ฐ˜์œผ๋กœ ํด๋” ์ด๋ฆ„ ์ƒ์„ฑ (ํ™•์žฅ์ž ์ œ์™ธ) + file_name = os.path.splitext(os.path.basename(pdf_path))[0] + + # ํด๋” ๊ฒฝ๋กœ ์ƒ์„ฑ + output_dir = os.path.join(save_path, file_name) # ํ˜„์žฌ ์ž‘์—… ๋””๋ ‰ํ† ๋ฆฌ ๋‚ด์— ์ƒ์„ฑ + os.makedirs(output_dir, exist_ok=True) + + # PDF ํŒŒ์ผ ์ด๋™ + new_pdf_path = os.path.join(output_dir, os.path.basename(pdf_path)) + shutil.move(pdf_path, new_pdf_path) + + # images ์ €์žฅ ํด๋” ์ƒ์„ฑ + output_dir = os.path.join(output_dir, "images") # ํ˜„์žฌ ์ž‘์—… ๋””๋ ‰ํ† ๋ฆฌ ๋‚ด์— ์ƒ์„ฑ + os.makedirs(output_dir, exist_ok=True) + + # PDF๋ฅผ ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜ + images = convert_from_path(new_pdf_path, dpi=300) + + # PDF ํŽ˜์ด์ง€ ์ˆ˜ + num_pages = len(images) + + # ๊ฐ ํŽ˜์ด์ง€๋ฅผ ์ด๋ฏธ์ง€๋กœ ์ €์žฅ + for page_num, image in enumerate(images, start=1): + # ์ด๋ฏธ์ง€ ํŒŒ์ผ๋ช… ์„ค์ • + output_image_path = os.path.join(output_dir, f"page_{page_num}.png") + + # ์ด๋ฏธ์ง€ ์ €์žฅ + image.save(output_image_path, "PNG") + if verbose: + print(f"Page {page_num} saved as {output_image_path}") + + # ํŒŒ์ผ์— ๋Œ€ํ•œ ๋ฉ”ํƒ€ ๋ฐ์ดํ„ฐ ๊ธฐ๋ก + new_data = pd.DataFrame( + { + "company_name": [company_name] * num_pages, + "file_name": [file_name] * num_pages, + "page": [i for i in range(1, num_pages + 1)], + } + ) + + # ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์—…๋ฐ์ดํŠธ + if os.path.exists(db_path): + database = pd.read_csv(db_path, encoding="utf-8") + else: + database = pd.DataFrame(columns=["company_name", "file_name", "page"]) + + # concat์œผ๋กœ ๋‘ DataFrame์„ ๋ณ‘ํ•ฉ + database = pd.concat([database, new_data], ignore_index=True) + + # 'page' ์—ด์„ ์ •์ˆ˜ํ˜•์œผ๋กœ ๋ณ€ํ™˜ + database["page"] = database["page"].astype("int") + + # company_name -> file_name -> page ์ˆœ์œผ๋กœ ์˜ค๋ฆ„์ฐจ์ˆœ ์ •๋ ฌ + database = database.sort_values(by=["company_name", "file_name", "page"], ascending=[True, True, True]) + + # database csv๋กœ ์ €์žฅ + database.to_csv(db_path, index=False, encoding="utf-8") + + +def multi_pdf_to_image(root_dir: str, db_path: str) -> None: + """ + ์ฃผ์–ด์ง„ ๋ฃจํŠธ ๋””๋ ‰ํ† ๋ฆฌ ๋‚ด ๋ชจ๋“  ํ•˜์œ„ ๋””๋ ‰ํ† ๋ฆฌ์—์„œ PDF ํŒŒ์ผ์„ ์ฐพ์•„ ๋ณ€ํ™˜ํ•˜๋Š” ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค. + ๊ฐ PDF ํŒŒ์ผ์€ `pdf_to_image` ํ•จ์ˆ˜๋กœ ์ฒ˜๋ฆฌ๋˜์–ด ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜๋ฉ๋‹ˆ๋‹ค. + + Args: + root_dir (str): PDF ํŒŒ์ผ์ด ์ €์žฅ๋œ ๋ฃจํŠธ ๋””๋ ‰ํ† ๋ฆฌ ๊ฒฝ๋กœ. + db_path (str): ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ๊ฒฝ๋กœ + + Returns: + None: ํ•จ์ˆ˜๋Š” ๋ฐ˜ํ™˜๊ฐ’์ด ์—†์Šต๋‹ˆ๋‹ค. + """ + + # ๋ฃจํŠธ ๋””๋ ‰ํ† ๋ฆฌ ๋‚ด ๋ชจ๋“  ํ•˜์œ„ ๋””๋ ‰ํ† ๋ฆฌ์™€ ํŒŒ์ผ์„ ์ˆœํšŒ + for dirpath, _, filenames in os.walk(root_dir): + for filename in filenames: + # PDF ํŒŒ์ผ๋งŒ ์ฒ˜๋ฆฌ + if filename.lower().endswith(".pdf"): + pdf_path = os.path.join(dirpath, filename) + print(f"Converting {pdf_path} to images...") + + # ๋™์ผํ•œ ๋””๋ ‰ํ† ๋ฆฌ ๊ตฌ์กฐ๋ฅผ ์œ ์ง€ํ•˜๋ฉฐ ์ด๋ฏธ์ง€ ์ €์žฅ + pdf_to_image(pdf_path, dirpath, db_path=db_path, verbose=False) + + +def sort_bounding_boxes(output_data, image_width): + def get_columns(data, image_width, threshold_x=0.085, threshold_diff=1, threshold_column=0.1): + """ + Group bounding boxes into columns based on their x_min values. 
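+        Returns a list of (col_min, col_max) x-ranges, one per detected column.
+        threshold_x is a fraction of the image width used to cluster x_min values;
+        threshold_column is a fraction of the overall x_min spread used as the
+        minimum gap that opens a new column in the fallback path.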
+ """ + # ๋ฐ์ดํ„ฐ๋ฅผ ์ •๋ ฌ + x_mins = np.array([item["coordinates"][0] for item in data]) + sorted_x = np.sort(x_mins) + + # ๊ทธ๋ฃน์„ ์ €์žฅํ•  ๋ฆฌ์ŠคํŠธ + grouped = [] + + # ์ฒซ ๋ฒˆ์งธ ๊ฐ’์„ ์‹œ์ž‘์œผ๋กœ ๊ทธ๋ฃน ์ดˆ๊ธฐํ™” + current_group = [sorted_x[0]] + + # ์ •๋ ฌ๋œ ๋ฐ์ดํ„ฐ๋ฅผ ์ˆœํšŒ + for i in range(1, len(sorted_x)): + if abs(sorted_x[i] - current_group[-1]) <= image_width * threshold_x: + # threshold ์ด๋‚ด์˜ ๊ฐ’์€ ๊ฐ™์€ ๊ทธ๋ฃน์œผ๋กœ ์ถ”๊ฐ€ + current_group.append(sorted_x[i]) + else: + # ๊ทธ๋ฃน์„ ์ €์žฅํ•˜๊ณ  ์ƒˆ ๊ทธ๋ฃน ์‹œ์ž‘ + grouped.append(current_group) + current_group = [sorted_x[i]] + + # ๋งˆ์ง€๋ง‰ ๊ทธ๋ฃน ์ถ”๊ฐ€ + grouped.append(current_group) + + grouped_count = list(map(len, grouped)) + # 1. grouped_count์˜ ์˜ค๋ฆ„์ฐจ์ˆœ ์ •๋ ฌ (์›๋ž˜ ์ธ๋ฑ์Šค ์ถ”์ ) + sorted_indices = np.argsort(grouped_count) # ์ •๋ ฌ๋œ ์ธ๋ฑ์Šค + sorted_grouped_count = [grouped_count[i] for i in sorted_indices] # ์ •๋ ฌ๋œ grouped_count + + # 2. diff ๊ณ„์‚ฐ + diffs = np.diff(sorted_grouped_count) + + # 3. diff๊ฐ€ ํŠน์ • ์ž„๊ณ„๊ฐ’ ์ด์ƒ์œผ๋กœ ์ฆ๊ฐ€ํ•œ ์ง€์  ์ฐพ๊ธฐ + sudden_increase_indices = np.where(diffs >= threshold_diff)[0] + 1 # +1์€ diff์˜ ๊ฒฐ๊ณผ๊ฐ€ n-1 ๊ธธ์ด์ด๊ธฐ ๋•Œ๋ฌธ + + if len(sudden_increase_indices) != 0: + # 4. ๊ฐ‘์ž‘์Šค๋Ÿฌ์šด ๋ณ€ํ™” ์ดํ›„์˜ ์›๋ž˜ ์ธ๋ฑ์Šค ์ฐพ๊ธฐ + original_indices = sorted_indices[sudden_increase_indices[0] :] + mode_components_list = [grouped[i] for i in original_indices] + x_column_boundary = [min(mode_components) for mode_components in mode_components_list] + x_column_boundary.sort() + column_bounds = [(0, x_column_boundary[0])] + for i in range(len(x_column_boundary) - 1): + column_bounds.append((x_column_boundary[i], x_column_boundary[i + 1])) + column_bounds.append((x_column_boundary[-1], float("inf"))) + else: # ๋‹ค๋‹จ์€ ๋‚˜๋ˆ„์–ด์ ธ ์žˆ๋Š”๋ฐ ๋‹ค๋‹จ ์ž์ฒด๊ฐ€ ๋ฐ”์šด๋”ฉ ๋ฐ•์Šค ํ•˜๋‚˜๋กœ ํฌ๊ฒŒ ์ด๋ฃจ์–ด์ ธ์žˆ์œผ๋ฉด + # ์ตœ๋นˆ๊ฐ’์ด 1๋กœ ๋™๋ฅ ์ผ ๊ฒฝ์šฐ x_min ์ขŒํ‘œ ์‚ฌ์ด ๊ฐ„๊ฒฉ์„ ๋ถ„์„ํ•ด์„œ ์ขŒํ‘œ ์‚ฌ์ด ๊ฐ„๊ฒฉ์ด ๊ฐ‘์ž๊ธฐ ์ปค์ง€๋Š” ๊ณณ์„ ๋‹ค๋‹จ์œผ๋กœ ์ธ์‹ํ•˜๊ฒŒ ํ•œ๋‹ค. 
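+            # Fallback: the x_min clusters are all about the same size, so inspect the
+            # gaps between consecutive x_min values and start a new column wherever a
+            # gap exceeds threshold_column times the overall x_min spread.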
+ gaps = np.diff(sorted_x) + column_threshold = threshold_column * (sorted_x[-1] - sorted_x[0]) + column_indices = np.where(gaps > column_threshold)[0] + + columns = [] + start = 0 + for idx in column_indices: + columns.append(sorted_x[start : idx + 1]) + start = idx + 1 + columns.append(sorted_x[start:]) + + column_bounds = [[col.min(), col.max()] for col in map(np.array, columns)] + column_bounds.insert(0, (0, column_bounds[0][0])) + for i in range(1, len(column_bounds) - 1): + column_bounds[i][1] = column_bounds[i + 1][0] + column_bounds.append((column_bounds[-1][1], float("inf"))) + return column_bounds + + def assign_column(box, column_bounds): + """Assign a bounding box to its column.""" + x_min = box["coordinates"][0] # bounding box์˜ x_min ๊ฐ’์„ ๊ฐ€์ ธ์˜ด + for idx, (col_min, col_max) in enumerate(column_bounds): # ๊ฐ ์ปฌ๋Ÿผ์˜ ๊ฒฝ๊ณ„ ํ™•์ธ + if col_min <= x_min < col_max: # x_min์ด ์ปฌ๋Ÿผ ๊ฒฝ๊ณ„ ์•ˆ์— ์žˆ์œผ๋ฉด + return idx # ํ•ด๋‹น ์ปฌ๋Ÿผ์˜ ์ธ๋ฑ์Šค๋ฅผ ๋ฐ˜ํ™˜ + return len(column_bounds) # ์ปฌ๋Ÿผ ๊ฒฝ๊ณ„์— ์†ํ•˜์ง€ ์•Š์œผ๋ฉด ๋งˆ์ง€๋ง‰ ์ธ๋ฑ์Šค๋ฅผ ๋ฐ˜ํ™˜ + + def fuzzy_comparator(box1, box2): + # ๋‘ ๋ฐ•์Šค์˜ x_min, y_min ์ขŒํ‘œ ์ถ”์ถœ + x1, y1, _, _ = box1["coordinates"] + x2, y2, _, _ = box2["coordinates"] + + y_threshold = 32 + + # y์ขŒํ‘œ๊ฐ€ ๋น„์Šทํ•˜๋ฉด x์ขŒํ‘œ ๊ธฐ์ค€์œผ๋กœ ๋น„๊ต + if abs(y1 - y2) <= y_threshold: + return x1 - x2 + # ๊ทธ๋ ‡์ง€ ์•Š์œผ๋ฉด y์ขŒํ‘œ ๊ธฐ์ค€์œผ๋กœ ๋น„๊ต + return y1 - y2 + + def sort_within_column(boxes): + """Sort boxes within a column by y_min, then x_min.""" + return sorted(boxes, key=cmp_to_key(fuzzy_comparator)) + # return sorted(boxes, key=lambda b: (b['coordinates'][1], b['coordinates'][0])) + + # Step 1: Detect columns based on x_min values + column_bounds = get_columns(output_data, image_width) + if not column_bounds: + tolerance = 0.05 + sorted_boxes = sorted( + output_data, key=lambda b: ((b["coordinates"][1] // tolerance) * tolerance, b["coordinates"][0]) + ) + return sorted_boxes + else: + column_data = defaultdict(list) + + for box in output_data: + column_idx = assign_column(box, column_bounds) + column_data[column_idx].append(box) + + # Step 2: Sort columns based on width (wide to narrow or left to right if similar) + sorted_columns = sorted( + column_data.items(), + key=lambda c: ( + -(max(box["coordinates"][2] for box in c[1]) - min(box["coordinates"][0] for box in c[1])), + c[0], + ), + ) + + # Step 3: Sort boxes within each column + sorted_boxes = [] + for _, boxes in sorted_columns: + sorted_boxes.extend(sort_within_column(boxes)) + + return sorted_boxes + + +def extract_and_save_bounding_boxes( + image_path, + database_path, + model_path="~/.cache/huggingface/hub/models--juliozhao--DocLayout-YOLO-DocStructBench/snapshots/8c3299a30b8ff29a1503c4431b035b93220f7b11/doclayout_yolo_docstructbench_imgsz1024.pt", + res_path="outputs", + imgsz=1024, + line_width=5, + font_size=20, + conf=0.2, + split_images_foler_name="split_images", + threshold=0.05, + verbose=False, +): + # Automatically select device + device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" + + try: + model = YOLOv10(model_path) + + except Exception as e: + print(f"์ง€์ •๋œ ๊ฒฝ๋กœ์—์„œ ๋ชจ๋ธ์„ ๋กœ๋“œํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {e}") + print("Hugging Face์—์„œ ๋ชจ๋ธ์„ ๋‹ค์šด๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค...") + try: + model_path = hf_hub_download( + repo_id="juliozhao/DocLayout-YOLO-DocStructBench", filename="doclayout_yolo_docstructbench_imgsz1024.pt" + ) + model = YOLOv10(model_path) + print(f"๋ชจ๋ธ์„ ์„ฑ๊ณต์ ์œผ๋กœ 
๋‹ค์šด๋กœ๋“œํ–ˆ์Šต๋‹ˆ๋‹ค: {model_path}") + except Exception as e: + raise Exception(f"๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ์‹คํŒจ: {e}") + + det_res = model.predict( + image_path, + imgsz=imgsz, + conf=conf, + device=device, + ) + annotated_frame = det_res[0].plot(pil=True, line_width=line_width, font_size=font_size) + if not os.path.exists(res_path): + os.makedirs(res_path) + output_path = os.path.join(res_path, image_path.split("/")[-1].replace(".png", "_annotated.png")) + cv2.imwrite(output_path, annotated_frame) + print(f'The result was saved to "{output_path}"') + + # ํด๋ž˜์Šค ID์™€ ์ด๋ฆ„ ๋งคํ•‘ + CLASS_LABELS = { + 0: "title", + 1: "plain text", + 2: "abandon", + 3: "figure", + 4: "figure_caption", + 5: "table", + 6: "table_caption", + 7: "table_footnote", + 8: "isolate_formula", + 9: "formula_caption", + } + + image = cv2.imread(image_path) + + # ๊ฒฐ๊ณผ ์ €์žฅ ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ + output_dir = os.path.join(res_path, f"{split_images_foler_name}") + print(f'Split images were saved to "{output_dir}"') + os.makedirs(output_dir, exist_ok=True) + + # ํด๋ž˜์Šค๋ณ„ ๊ณ ์œ  ์ธ๋ฑ์Šค ๊ด€๋ฆฌ + class_indices = defaultdict(int) # ๊ฐ ํด๋ž˜์Šค๋ณ„ ์ €์žฅ ์ธ๋ฑ์Šค + + output_data = [] + unique_boxes = {} # ์ค‘๋ณต๋œ ๋ฐ•์Šค๋ฅผ ๋ฐฉ์ง€ํ•˜๊ณ  ์ตœ๊ณ  ํ™•๋ฅ ๋กœ ์ €์žฅํ•˜๊ธฐ ์œ„ํ•œ ๋”•์…”๋„ˆ๋ฆฌ + + for box in det_res[0].boxes.data: + # Bounding Box ์ •๋ณด ์ถ”์ถœ + x_min, y_min, x_max, y_max = map(int, box[:4].cpu().numpy()) # ์ขŒํ‘œ + confidence = box[4].cpu().numpy() # ์‹ ๋ขฐ๋„ ์ ์ˆ˜ + class_id = int(box[5].cpu().numpy()) # ํด๋ž˜์Šค ID + class_name = CLASS_LABELS.get(class_id, "Unknown") # ํด๋ž˜์Šค ์ด๋ฆ„ ๋งคํ•‘ + + # ์ขŒํ‘œ๋ฅผ ๊ธฐ์ค€์œผ๋กœ ์ค‘๋ณต ์ฒดํฌ ๋ฐ ์ตœ๊ณ  ํ™•๋ฅ  ์œ ์ง€ + box_tuple = (x_min, y_min, x_max, y_max) + + # ์ค‘๋ณต ๋ฐ•์Šค๋ฅผ ์ฒดํฌ + overlap_found = False + for existing_key, existing_box in list(unique_boxes.items()): + existing_coordinates = existing_box["coordinates"] + + x_min1, y_min1, x_max1, y_max1 = x_min, y_min, x_max, y_max + x_min2, y_min2, x_max2, y_max2 = existing_coordinates + + # ๊ต์ง‘ํ•ฉ ์˜์—ญ์˜ ์ขŒํ‘œ ๊ณ„์‚ฐ + x_min_inter = max(x_min1, x_min2) + y_min_inter = max(y_min1, y_min2) + x_max_inter = min(x_max1, x_max2) + y_max_inter = min(y_max1, y_max2) + + # ๊ต์ง‘ํ•ฉ ๋ฉด์  + if x_max_inter - x_min_inter > 0 and y_max_inter - y_min_inter > 0: + intersection_area = (x_max_inter - x_min_inter) * (y_max_inter - y_min_inter) + else: + intersection_area = 0 + + # ๋‘ ๋ฐ•์Šค์˜ ๋ฉด์  ๊ณ„์‚ฐ + area1 = (x_max1 - x_min1) * (y_max1 - y_min1) + area2 = (x_max2 - x_min2) * (y_max2 - y_min2) + + if area1 - intersection_area < threshold * area1 and area2 - intersection_area < threshold * area2: + # ๋‘ ๋ฐ•์Šค๊ฐ€ ๊ฑฐ์˜ ์ผ์น˜ํ•˜๋ฉด, ํ™•๋ฅ ์ด ๋” ๋†’์€ ๋ฐ•์Šค๋กœ ๊ต์ฒด + if confidence > existing_box["confidence"]: + del unique_boxes[existing_key] + if box_tuple not in unique_boxes.keys(): + unique_boxes[box_tuple] = { + "class_name": class_name, + "confidence": confidence, + "coordinates": [x_min, y_min, x_max, y_max], + } + overlap_found = True + elif area1 < area2 and area1 - intersection_area < threshold * area1: + # ํ˜„์žฌ ๋ฐ•์Šค๊ฐ€ ๋” ์ž‘์€ ๊ฒฝ์šฐ, ๊ธฐ์กด ๋ฐ•์Šค๋ฅผ ์ œ๊ฑฐ + del unique_boxes[existing_key] + unique_boxes[box_tuple] = { + "class_name": class_name, + "confidence": confidence, + "coordinates": [x_min, y_min, x_max, y_max], + } + overlap_found = True + elif area2 < area1 and area2 - intersection_area < threshold * area2: + # ๊ธฐ์กด ๋ฐ•์Šค๊ฐ€ ๋” ์ž‘์€ ๊ฒฝ์šฐ, ํ˜„์žฌ ๋ฐ•์Šค๋ฅผ ์ถ”๊ฐ€ํ•˜์ง€ ์•Š์Œ + overlap_found = True + + # ์ค‘๋ณต์ด ์—†์œผ๋ฉด ์ƒˆ๋กœ์šด ๋ฐ•์Šค๋ฅผ ์ถ”๊ฐ€ 
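+            # De-duplication policy above: near-identical boxes keep the higher-confidence
+            # detection; when one box lies almost entirely inside the other, the smaller
+            # (inner) box is kept.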
+ if not overlap_found: + unique_boxes[box_tuple] = { + "class_name": class_name, + "confidence": confidence, + "coordinates": [x_min, y_min, x_max, y_max], + } + + print("num_split_images: {num_split_images}".format(num_split_images=len(unique_boxes))) + + # ๊ฒฐ๊ณผ ์ €์žฅ ๋ฐ ์ด๋ฏธ์ง€ ์ž๋ฅด๊ธฐ + for _, box_info in unique_boxes.items(): + x_min, y_min, x_max, y_max = box_info["coordinates"] + class_name = box_info["class_name"] + confidence = box_info["confidence"] + + # ํด๋ž˜์Šค๋ณ„ ๊ณ ์œ  ์ธ๋ฑ์Šค ์ถ”๊ฐ€ + class_index = class_indices[class_name] + 1 + class_indices[class_name] += 1 + + # ์ •๋ณด ์ €์žฅ + output_data.append( + { + "box_id": class_index, + "class_name": class_name, + "confidence": float(confidence), + "coordinates": [x_min, y_min, x_max, y_max], + } + ) + + # ๋ฉ”ํƒ€ ๋ฐ์ดํ„ฐ ์ƒ์„ฑ + dir_path = os.path.dirname(image_path) + path_parts = dir_path.split("/") + company_name = path_parts[-3] + file_name = path_parts[-2] + page = os.path.splitext(os.path.basename(image_path))[0] + page = int(page.split("_")[-1]) + + # ouput_data๋ฅผ ๋‹ค๋‹จ์„ ๋”ฐ๋ผ ์œ„์—์„œ ์•„๋ž˜๋กœ ์ฝ๊ณ  ๋‹ค๋ฅธ ๋‹ค๋‹จ์„ ์œ„์—์„œ ์•„๋ž˜๋กœ ์ฝ๋Š” ์ˆœ์„œ๋กœ ์ •๋ ฌ + output_data = sort_bounding_boxes(output_data, image.shape[1]) + + # ์ €์žฅ๋œ ๋ฐ์ดํ„ฐ ํ™•์ธ + if verbose: + for data in output_data: + print(data) + + # ํŒŒ์ผ์— ๋Œ€ํ•œ ๋ฉ”ํƒ€ ๋ฐ์ดํ„ฐ ๊ธฐ๋ก + num_page_components = len(unique_boxes) + new_data = pd.DataFrame( + { + "company_name": [company_name] * num_page_components, + "file_name": [file_name] * num_page_components, + "page": [page] * num_page_components, + "component_index": [i for i in range(1, len(output_data) + 1)], + "component_type": [component["class_name"] for component in output_data], + "component_type_sub_index": [component["box_id"] for component in output_data], + "coordinates-x_min,y_min,x_max,y_max": [ + component["coordinates"] for component in output_data + ], # left, top, right, bottom + "component_type_confidence": [component["confidence"] for component in output_data], + } + ) + + # ๊ฐ component_type์— ๋Œ€ํ•ด ๋ณ„๋„๋กœ 'box_id' ๋งค๊ธฐ๊ธฐ + new_data["component_type_sub_index"] = new_data.groupby("component_type").cumcount() + 1 + new_data["component_index"] = range(1, len(new_data) + 1) + + for _, row in new_data.iterrows(): + # ๋ฐ•์Šค ์˜์—ญ ์ž˜๋ผ๋‚ด๊ธฐ + x_min, y_min, x_max, y_max = row["coordinates-x_min,y_min,x_max,y_max"] + cropped_image = image[y_min:y_max, x_min:x_max] + + # ์ž˜๋ผ๋‚ธ ์ด๋ฏธ์ง€ ์ €์žฅ + cropped_image_path = os.path.join( + output_dir, f"{row['component_index']}_{row['component_type']}_{row['component_type_sub_index']}.png" + ) + cv2.imwrite(cropped_image_path, cropped_image) + database = pd.read_csv(database_path, encoding="utf-8") + + # ์กฐ๊ฑด์— ๋งž๋Š” ํ–‰ ์ธ๋ฑ์Šค๋ฅผ ์ฐพ๊ธฐ + matching_indices = database.loc[ + (database["company_name"] == company_name) & (database["file_name"] == file_name) & (database["page"] == page) + ].index + matching_indices = matching_indices[0] + + # ๊ธฐ์กด DataFrame์—์„œ ํ˜„์žฌ ์ž…๋ ฅ ์ด๋ฏธ์ง€์˜ company_name, file_name, page์— ๋Œ€์‘ํ•˜๋Š” ํ–‰์„ ์‚ญ์ œํ•˜๊ณ  new_data๋ฅผ ์‚ฝ์ž…ํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ + # ๋ฌธ์„œ ํŽ˜์ด์ง€ ์ด๋ฏธ์ง€๊ฐ€ ์—ฌ๋Ÿฌ components๋กœ ๋‚˜๋ˆ„์–ด์กŒ์œผ๋ฏ€๋กœ components์— ๋Œ€์‘ํ•˜๋Š” ์—ฌ๋Ÿฌ ํ–‰์œผ๋กœ ๊ธฐ์กด ํ•˜๋‚˜์˜ ํ–‰์„ ๊ต์ฒด + database = pd.concat( + [database.iloc[:matching_indices], new_data, database.iloc[matching_indices + 1 :]] + ).reset_index(drop=True) + + # database csv๋กœ ์ €์žฅ + database.to_csv(database_path, index=False, encoding="utf-8") + + 
print(f"{company_name}|{file_name}|{page} conversion completed.\n") + + return det_res, output_data + + +def multi_extract_and_save_bounding_boxes( + root_dir: str, + extract_and_save_bounding_boxes: Callable[ + [str, str, str, str, int, int, int, float, str, float, bool], Tuple[Dict, List] + ], + **kwargs: Any, +) -> None: + """ + ๋ฃจํŠธ ํด๋” ๋‚ด์—์„œ ํŠน์ • ํ˜•์‹์˜ ์ด๋ฏธ์ง€ ํŒŒ์ผ์„ ์ฒ˜๋ฆฌํ•˜๊ณ , + ๊ฒฐ๊ณผ๋ฅผ ์ด๋ฏธ์ง€ ์ด๋ฆ„(ํ™•์žฅ์ž ์ œ๊ฑฐ)๊ณผ ๋™์ผํ•œ ํ•˜์œ„ ํด๋”์— ์ €์žฅํ•˜๋Š” ํ•จ์ˆ˜. + + ์ด ํ•จ์ˆ˜๋Š” ํŒŒ์ผ๋ช…์ด "page_์ˆซ์ž" ํ˜•์‹์ธ ์ด๋ฏธ์ง€ ํŒŒ์ผ์„ ์‹๋ณ„ํ•œ ํ›„, + ์ฃผ์–ด์ง„ extract_and_save_bounding_boxes๋ฅผ ์‚ฌ์šฉํ•ด ๊ฐ ์ด๋ฏธ์ง€๋ฅผ ์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค. + + Args: + root_dir (str): ์ด๋ฏธ์ง€ ํŒŒ์ผ์ด ์ €์žฅ๋œ ๋ฃจํŠธ ํด๋” ๊ฒฝ๋กœ. + extract_and_save_bounding_boxes (Callable[..., None]): + ๋‹จ์ผ ์ด๋ฏธ์ง€๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” ํ•จ์ˆ˜. ๋‹ค์Œ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ๊ฐ€์ ธ์•ผ ํ•ฉ๋‹ˆ๋‹ค: + - image_path (str): ์ฒ˜๋ฆฌํ•  ์ด๋ฏธ์ง€ ํŒŒ์ผ์˜ ๊ฒฝ๋กœ. + - res_path (str): ์ฒ˜๋ฆฌ๋œ ๊ฒฐ๊ณผ๋ฅผ ์ €์žฅํ•  ํด๋” ๊ฒฝ๋กœ. + - ์ถ”๊ฐ€์ ์ธ ํ‚ค์›Œ๋“œ ์ธ์ž (**kwargs). + **kwargs (Any): extract_and_save_bounding_boxes์— ์ „๋‹ฌ๋  ์ถ”๊ฐ€ ๋งค๊ฐœ๋ณ€์ˆ˜. + + Returns: + None + """ + # ์ด๋ฏธ์ง€ ํ™•์žฅ์ž ์ •์˜ + valid_extensions = (".jpg", ".jpeg", ".png", ".bmp", ".tiff") + + # ์ •๊ทœ์‹ ํŒจํ„ด: ํŒŒ์ผ๋ช…์ด page_์ˆซ์ž ํ˜•์‹์ธ์ง€ ํ™•์ธ + page_pattern = re.compile(r"^page_\d+$") + + # ๋ฃจํŠธ ํด๋”์—์„œ ํŒŒ์ผ ๊ฒ€์ƒ‰ + all_image_paths = [ + os.path.join(dp, f) + for dp, dn, filenames in os.walk(root_dir) + for f in filenames + if f.lower().endswith(valid_extensions) and page_pattern.match(os.path.splitext(f)[0]) + ] + + for image_path in tqdm(all_image_paths, desc="Processing Images", unit="image"): + # ํ˜„์žฌ ์ด๋ฏธ์ง€ ํŒŒ์ผ์ด ์œ„์น˜ํ•œ ํด๋” ๊ฒฝ๋กœ + current_folder = os.path.dirname(image_path) + + # ํŒŒ์ผ ์ด๋ฆ„์—์„œ ํ™•์žฅ์ž๋ฅผ ์ œ๊ฑฐํ•ด ์ถœ๋ ฅ ํด๋”๋ช… ์ƒ์„ฑ + image_name = os.path.splitext(os.path.basename(image_path))[0] # image_name(ํ™•์žฅ์ž ์ œ๊ฑฐ ํŒŒ์ผ ์ด๋ฆ„) ์˜ˆ์‹œ: page_1 + output_folder = os.path.join(current_folder, image_name) # ์ด๋ฆ„์ด ํŒŒ์ผ ์ด๋ฆ„์ด๋ž‘ ๊ฐ™์€ ํด๋” ๊ฒฝ๋กœ + + # ์ถœ๋ ฅ ํด๋”๊ฐ€ ์ด๋ฏธ ์กด์žฌํ•˜๋ฉด ์‚ญ์ œ + if os.path.exists(output_folder): + print("The output folder already exists. 
It will be deleted and recreated.") + shutil.rmtree(output_folder) + + # ์ถœ๋ ฅ ํด๋” ์ƒ์„ฑ + os.makedirs(output_folder, exist_ok=True) + + # extract_and_save_bounding_boxes ํ˜ธ์ถœ + try: + extract_and_save_bounding_boxes(image_path=image_path, res_path=output_folder, **kwargs) + print("Save completed") + except Exception as e: + print(f"An error occurred while processing {image_path}: {e}") + + print("All images have been processed successfully.") + + +def pdf_parsing_pipeline(config=None) -> None: + + # ์„ค์ • + cfg = get_config(config) + + # ๊ฒฝ๋กœ ์„ค์ • + root_dir = cfg["DIRS"]["input_dir"] + db_path = os.path.join(cfg["DIRS"]["database_dir"], cfg["FILES"]["database"]) + ocr_output_dir = cfg["DIRS"]["ocr_output_dir"] + + # PDF -> ์ด๋ฏธ์ง€ ๋ณ€ํ™˜ + multi_pdf_to_image(root_dir=root_dir, db_path=db_path) + + # ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ๋ฐฑ์—… + shutil.copy(db_path, db_path.replace(".csv", "_temp.csv")) + + # ์ด๋ฏธ์ง€ -> ๋ฐ”์šด๋”ฉ ๋ฐ•์Šค ์ถ”์ถœ + multi_extract_and_save_bounding_boxes( + root_dir=root_dir, + extract_and_save_bounding_boxes=extract_and_save_bounding_boxes, + database_path=db_path, + model_path=cfg["MODEL"]["path"], + imgsz=cfg["MODEL"]["imgsz"], + line_width=cfg["MODEL"]["line_width"], + font_size=cfg["MODEL"]["font_size"], + split_images_foler_name="split_images", + conf=cfg["MODEL"]["conf"], + threshold=cfg["MODEL"]["threshold"], + verbose=False, + ) + + # OCR ์ฒ˜๋ฆฌ + from ocr_processor import OCRProcessor + + processor = OCRProcessor(base_folder=root_dir, output_folder=ocr_output_dir) + processor.process_folders() + + print("์ „์ฒด ํŒŒ์ดํ”„๋ผ์ธ ์ฒ˜๋ฆฌ๊ฐ€ ์™„๋ฃŒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.") + + +def parse_args(): + """ + ์ปค๋งจ๋“œ ๋ผ์ธ ์ธ์ž๋ฅผ ํŒŒ์‹ฑํ•ฉ๋‹ˆ๋‹ค. + """ + parser = argparse.ArgumentParser(description="PDF ํŒŒ์ผ์„ ์ฒ˜๋ฆฌํ•˜์—ฌ OCR์„ ์ˆ˜ํ–‰ํ•˜๋Š” ํŒŒ์ดํ”„๋ผ์ธ") + + parser.add_argument("--input", "-i", type=str, help="์ž…๋ ฅ PDF ํŒŒ์ผ ๋˜๋Š” PDF ํŒŒ์ผ๋“ค์ด ์žˆ๋Š” ๋””๋ ‰ํ† ๋ฆฌ ๊ฒฝ๋กœ") + + parser.add_argument( + "--output-dir", "-o", type=str, default=None, help="๊ฒฐ๊ณผ๋ฌผ์„ ์ €์žฅํ•  ๋””๋ ‰ํ† ๋ฆฌ (๊ธฐ๋ณธ๊ฐ’: ํ”„๋กœ์ ํŠธ ๋ฃจํŠธ ๋””๋ ‰ํ† ๋ฆฌ)" + ) + + parser.add_argument("--recursive", "-r", action="store_true", help="๋””๋ ‰ํ† ๋ฆฌ ์ž…๋ ฅ์‹œ ํ•˜์œ„ ๋””๋ ‰ํ† ๋ฆฌ๋„ ์ฒ˜๋ฆฌ") + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + # ์ž…๋ ฅ ๊ฒฝ๋กœ ์ฒ˜๋ฆฌ + input_path = Path(args.input).resolve() if args.input else None + if not input_path or not input_path.exists(): + raise ValueError(f"์ž…๋ ฅ ๊ฒฝ๋กœ๊ฐ€ ์œ ํšจํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค: {args.input}") + + # ์ถœ๋ ฅ ๋””๋ ‰ํ† ๋ฆฌ ์„ค์ • + if args.output_dir: + output_base = Path(args.output_dir).resolve() + else: + output_base = Path(__file__).parent + + # ์„ค์ • ๊ตฌ์„ฑ + custom_config = { + "DIRS": { + "input_dir": str(input_path.parent if input_path.is_file() else input_path), + "output_dir": str(output_base / "output"), + "database_dir": str(output_base / "database"), + "ocr_output_dir": str(output_base / "ocr_results"), + } + } + + # PDF ํŒŒ์ผ ์ฒ˜๋ฆฌ + if input_path.is_file() and input_path.suffix.lower() == ".pdf": + # ๋‹จ์ผ PDF ํŒŒ์ผ ์ฒ˜๋ฆฌ + if not input_path.parent.samefile(Path(custom_config["DIRS"]["input_dir"])): + # ์ž…๋ ฅ ๋””๋ ‰ํ† ๋ฆฌ๋กœ PDF ํŒŒ์ผ ๋ณต์‚ฌ + os.makedirs(custom_config["DIRS"]["input_dir"], exist_ok=True) + shutil.copy2(input_path, Path(custom_config["DIRS"]["input_dir"]) / input_path.name) + + elif input_path.is_dir(): + # ์ž…๋ ฅ ๋””๋ ‰ํ† ๋ฆฌ๊ฐ€ ์ฒ˜๋ฆฌ ๋””๋ ‰ํ† ๋ฆฌ์™€ ๋‹ค๋ฅธ ๊ฒฝ์šฐ์—๋งŒ ํŒŒ์ผ ๋ณต์‚ฌ + if not 
input_path.samefile(Path(custom_config["DIRS"]["input_dir"])): + # ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ + os.makedirs(custom_config["DIRS"]["input_dir"], exist_ok=True) + + # PDF ํŒŒ์ผ ๋ณต์‚ฌ + if args.recursive: + # ์žฌ๊ท€์ ์œผ๋กœ ๋ชจ๋“  PDF ํŒŒ์ผ ๋ณต์‚ฌ + for pdf_file in input_path.rglob("*.pdf"): + relative_path = pdf_file.relative_to(input_path) + target_path = Path(custom_config["DIRS"]["input_dir"]) / relative_path + os.makedirs(target_path.parent, exist_ok=True) + shutil.copy2(pdf_file, target_path) + else: + # ํ˜„์žฌ ๋””๋ ‰ํ† ๋ฆฌ์˜ PDF ํŒŒ์ผ๋งŒ ๋ณต์‚ฌ + for pdf_file in input_path.glob("*.pdf"): + target_path = Path(custom_config["DIRS"]["input_dir"]) / pdf_file.name + os.makedirs(target_path.parent, exist_ok=True) + shutil.copy2(pdf_file, target_path) + + # ํŒŒ์ดํ”„๋ผ์ธ ์‹คํ–‰ + pdf_parsing_pipeline(custom_config) + sys.exit(0) diff --git a/PDF_OCR/requirements.txt b/PDF_OCR/requirements.txt new file mode 100644 index 0000000..c0ae835 --- /dev/null +++ b/PDF_OCR/requirements.txt @@ -0,0 +1,4 @@ +torch==2.5.1 +torchvision==0.20.0 +doclayout-yolo==0.0.2 +pdf2image==1.16.1 diff --git a/PDF_OCR/table_converter.py b/PDF_OCR/table_converter.py new file mode 100644 index 0000000..20b0e71 --- /dev/null +++ b/PDF_OCR/table_converter.py @@ -0,0 +1,97 @@ +from typing import Dict, List, Union + +import json +import os +import warnings +from pathlib import Path + +import pandas as pd +from bs4 import BeautifulSoup + +warnings.filterwarnings("ignore") + + +def json_to_table(json_data: Union[str, Dict]) -> pd.DataFrame: + + # JSON ๋ฐ์ดํ„ฐ ๋กœ๋“œ + if isinstance(json_data, str): + with open(json_data, "r", encoding="utf-8") as f: + data = json.load(f) + else: + data = json_data + + try: + html = data["content"]["html"] + + # beautifulsoup๋กœ html ํŒŒ์‹ฑ + soup = BeautifulSoup(html, "html.parser") + + # html์—์„œ ํ…Œ์ด๋ธ” ์ถ”์ถœ + df = pd.read_html(str(soup))[0] + + # csv ์ €์žฅ + return df + # print(f"์ฒ˜๋ฆฌ ์™„๋ฃŒ: {output_base} : {file}") + + except Exception as e: + print(f"ํ…Œ์ด๋ธ” ๋ฐ์ดํ„ฐ ๋ณ€ํ™˜ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}") + + +def convert_json_to_csv( + input_path: Union[str, Path], output_path: Union[str, Path] = None, recursive: bool = False +) -> None: + """ + Args: + input_path (Union[str, Path]): JSON ํŒŒ์ผ ๋˜๋Š” ๋””๋ ‰ํ† ๋ฆฌ ๊ฒฝ๋กœ + output_path (Union[str, Path], optional): ์ถœ๋ ฅ ๊ฒฝ๋กœ. 
+ ์ง€์ •ํ•˜์ง€ ์•Š์œผ๋ฉด ์ž…๋ ฅ ํŒŒ์ผ๊ณผ ๋™์ผํ•œ ์œ„์น˜์— ์ €์žฅ + recursive (bool, optional): ๋””๋ ‰ํ† ๋ฆฌ ์ฒ˜๋ฆฌ์‹œ ํ•˜์œ„ ๋””๋ ‰ํ† ๋ฆฌ๋„ ์ฒ˜๋ฆฌํ• ์ง€ ์—ฌ๋ถ€ + """ + input_path = Path(input_path) + + if output_path: + output_path = Path(output_path) + if not output_path.exists(): + output_path.mkdir(parents=True) + + def process_file(json_path: Path) -> None: + try: + # JSON ํŒŒ์ผ์ด ํ…Œ์ด๋ธ” ๊ฒฐ๊ณผ๋ฅผ ํฌํ•จํ•˜๋Š”์ง€ ํ™•์ธ + if not json_path.stem.endswith("_result"): + return + + # ์ถœ๋ ฅ ๊ฒฝ๋กœ ์„ค์ • + if output_path: + csv_path = output_path / f"{json_path.stem.replace('_result', '')}.csv" + else: + csv_path = json_path.parent / f"{json_path.stem.replace('_result', '')}.csv" + + # ๋ณ€ํ™˜ ๋ฐ ์ €์žฅ + table_df = json_to_table(str(json_path)) + table_df.to_csv(csv_path, encoding="utf-8-sig") + print(f"๋ณ€ํ™˜ ์™„๋ฃŒ: {csv_path}") + + except Exception as e: + print(f"ํŒŒ์ผ ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ ({json_path.name}): {str(e)}") + + # ๋‹จ์ผ ํŒŒ์ผ ์ฒ˜๋ฆฌ + if input_path.is_file(): + process_file(input_path) + return + + # ๋””๋ ‰ํ† ๋ฆฌ ์ฒ˜๋ฆฌ + if recursive: + json_files = input_path.rglob("*.json") + else: + json_files = input_path.glob("*.json") + + for json_file in json_files: + process_file(json_file) + + +def main(): + convert_json_to_csv("../../PDF_OCR/new_data/All_data/table.json") + + +if __name__ == "__main__": + main() diff --git a/README.md b/README.md new file mode 100644 index 0000000..0e75f9b --- /dev/null +++ b/README.md @@ -0,0 +1,231 @@ +# Level 4. ์ฆ๊ถŒ์‚ฌ ์ž๋ฃŒ ๊ธฐ๋ฐ˜ ์ฃผ์‹ LLM ์„œ๋น„์Šค + +# **ํ”„๋กœ์ ํŠธ ๊ฐœ์š”** + +### **ํ”„๋กœ์ ํŠธ ์ฃผ์ œ** + +1. ์ฃผ์ œ + - ์ฆ๊ถŒ์‚ฌ ์ž๋ฃŒ ๊ธฐ๋ฐ˜ ์ฃผ์‹ LLM ์„œ๋น„์Šค ๊ฐœ๋ฐœ +2. ์š”๊ตฌ์‚ฌํ•ญ + - PDF ๋ฌธ์„œ๋กœ๋ถ€ํ„ฐ ํ…์ŠคํŠธ, ๊ทธ๋ž˜ํ”„ ๋“ฑ ์ •๋ณด์˜ ์ถ”์ถœ + - ๋ฐ์ดํ„ฐ ๋ ˆํฌ์ง€ํ† ๋ฆฌ ๊ตฌ์ถ•(GraphDB, VectorDB ๋“ฑ) + - ์ฟผ๋ฆฌ์— ๋Œ€ํ•ด ๊ฐ€์žฅ ์ ํ•ฉํ•œ ๋ฐ์ดํ„ฐ๋ฅผ ์ฐพ์•„๋‚ด๋Š” RAG ์‹œ์Šคํ…œ ๊ตฌํ˜„ + - ํ”„๋กฌํ”„ํŠธ ๊ฐœ๋ฐœ + - ๋‹ต๋ณ€ ์ƒ์„ฑ + - Q&A ๊ธฐ๋Šฅ: ์ •๋Ÿ‰ํ‰๊ฐ€ ๋ชฉ์  + - REST API ๋กœ ๊ตฌํ˜„ + - Input: query(์งˆ์˜) + - Output: context(์ฐธ์กฐํ…์ŠคํŠธ), answer(๋‹ต๋ณ€) + +### **๋ฐ์ดํ„ฐ์…‹** + +1. ์ œ๊ณต๋œ ๋ฐ์ดํ„ฐ + - ์ฆ๊ถŒ์‚ฌ ์ž๋ฃŒ ํŒŒ์ผ(PDF) 100๊ฐœ + +### **ํ‰๊ฐ€ ๋ฐฉ๋ฒ•** + +1. ์ •๋Ÿ‰ํ‰๊ฐ€ 50% + - ํ…Œ์ŠคํŠธ์…‹ ์งˆ์˜์— ๋Œ€ํ•œ ๋‹ต๋ณ€ ์„ฑ๋Šฅ โ€“ ์ง€ํ‘œ G-Eval +2. ์ •์„ฑํ‰๊ฐ€ 50% + - ์„œ๋น„์Šค์˜ ์ฐฝ์˜์„ฑ, ์œ ์šฉ์„ฑ, ๊ฐœ๋ฐœ ์™„์„ฑ๋„, ์†Œ์Šค์ฝ”๋“œ ํ’ˆ์งˆ, ๋ฌธ์„œํ™” ์ˆ˜์ค€ + +
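
The Q&A requirement above reduces to a single request/response contract: the client sends a `query` string and receives the retrieved `context` together with the generated `answer`. The sketch below illustrates that contract with FastAPI and Pydantic; the route name, field shapes, and the handler body are illustrative assumptions only, not the project's actual server (see the API section later in this README for how the real server is run).

```python
# Minimal sketch of the Input/Output contract described above.
# NOTE: the "/query" route, the field names, and the handler body are
# assumptions for illustration; the real implementation lives under app/.
from typing import List

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class QueryRequest(BaseModel):
    query: str  # Input: the user's question


class QueryResponse(BaseModel):
    context: List[str]  # Output: retrieved reference passages
    answer: str         # Output: generated answer


@app.post("/query", response_model=QueryResponse)
def answer_question(req: QueryRequest) -> QueryResponse:
    # In the real pipeline this would call the retriever and the generator.
    docs = ["...retrieved passage 1...", "...retrieved passage 2..."]
    return QueryResponse(context=docs, answer="...generated answer...")
```
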
+ +# **๐Ÿ‘จ๐Ÿปโ€๐Ÿ’ปย ํŒ€์› ์†Œ๊ฐœ ๋ฐ ์—ญํ• ** + +
+ +| ์ด๋ฆ„ | ํ”„๋กœํ•„ | ์—ญํ•  | +| :--------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------: | :-----------------------------------------------------------: | +| ๊ถŒ๊ธฐํƒœ [](https://github.com/starlike6617) | | API ์„ค๊ณ„ ๋ฐ ๊ฐœ๋ฐœ, RESTful API ๊ตฌํ˜„, OCR ๋ฐ์ดํ„ฐ ํ›„์ฒ˜๋ฆฌ | +| ๊ถŒ์œ ์ง„ [](https://github.com/0618yujin) | | ํ‰๊ฐ€ ๋ฐ์ดํ„ฐ ์ œ์ž‘, OCR ๋ฐ์ดํ„ฐ ํ›„์ฒ˜๋ฆฌ, Web Design ๋ฐ ๋ฐœํ‘œ ์ž๋ฃŒ | +| ๋ฐ•๋ฌด์žฌ [](https://github.com/Mujae) | | RAG ํŒŒ์ดํ”„๋ผ์ธ ๊ตฌํ˜„, ํ‰๊ฐ€ ์ฝ”๋“œ ๊ตฌํ˜„ ๋ฐ ์‹คํ—˜, ํ‰๊ฐ€ ๋ฐ์ดํ„ฐ ์„ ๋ณ„ | +| ๋ฐ•์ •๋ฏธ [](https://github.com/imJeongmi) | | ํ‰๊ฐ€ ๋ฐ์ดํ„ฐ ์ œ์ž‘, OCR ๋ฐ์ดํ„ฐ ํ›„์ฒ˜๋ฆฌ, Front-end | +| ์ด์šฉ์ค€ [](https://github.com/elwhyjay) | | PM, ๋ฆฌํŒฉํ† ๋ง ๋ฐ ๊ธฐํƒ€ ๊ตฌํ˜„, ์•„ํ‚คํ…์ณ ์„ค๊ณ„ ๋ฐ ์„œ๋น™ | +| ์ •์›์‹ [](https://github.com/wonsjeong) | | DocLayout ๋ชจ๋“ˆ ๊ตฌํ˜„, Embedding Model, Fine Tuning, ๋ฐœํ‘œ | + +
+
+ +# ํ”„๋กœ์ ํŠธ ์ˆ˜ํ–‰ ๋ฐฉ๋ฒ• + +## 1. PDF OCR + +๐Ÿ“‘ **[PDF OCR ์ƒ์„ธ ์„ค๋ช… ๋ณด๊ธฐ](PDF_OCR/README.MD)** + +![pdf-ocr_flowchart](images/pdf-ocr_flowchart.png) +### 1.1 ์‹คํ–‰ +```bash +python pdf_parser.py -i "./pdf/input_pdf_folder" +python data_postprocessor.py +``` + +## 2. RAG + +๐Ÿ“‘ **[RAG ์ƒ์„ธ ์„ค๋ช… ๋ณด๊ธฐ](app/RAG/README.md)** + +### 2.1 ์‹คํ–‰ + +```bash +cd app/RAG + +# retrieval ํ‰๊ฐ€ +python main.py mode=retrieve + +# generator ํ‰๊ฐ€ +python main.py mode=generate + +# vectordb ์ƒ์„ฑ ๋ฐ ์—…๋ฐ์ดํŠธ +python main.py mode=update_vectordb +``` + +### 2.2 ํ‰๊ฐ€ ๋ฐ์ดํ„ฐ ๊ตฌ์ถ• + +- ๋ชฉ์  + - Retriever์˜ Top-K Accuracy ํ‰๊ฐ€ ๋ฐ Retriever, Generator์˜ G-Eval ํ‰๊ฐ€ ์ˆ˜ํ–‰ +- ๋ฐฉ๋ฒ• + - **์งˆ๋ฌธ ์ƒ์„ฑ**: GPT๋ฅผ ํ™œ์šฉํ•˜์—ฌ PDF์—์„œ ๊ฐ ์ข…๋ชฉ์˜ ์ฆ๊ถŒ์‚ฌ๋งˆ๋‹ค text ๊ธฐ๋ฐ˜ ์งˆ๋ฌธ 10๊ฐœ์”ฉ ์ƒ์„ฑ + - **Query ์ •์ œ**: ๊ฐ ์ข…๋ชฉ๋ณ„๋กœ 100๊ฐœ์˜ Query๋ฅผ ์ƒ์„ฑํ•œ ํ›„, ์ค‘๋ณต์„ ์ œ๊ฑฐํ•˜์—ฌ ์ตœ์ข… Query ์„ ์ • + - **๋‹ต๋ณ€ ์ถ”์ถœ**: ์ •์ œ๋œ Query๋ฅผ ๊ฐ ์ฆ๊ถŒ์‚ฌ ๋ฆฌํฌํŠธ์— ์ ์šฉํ•˜์—ฌ answers ๋„์ถœ + - **Ground Truth ๊ฐ•ํ™”**: ์ข…๋ชฉ๋ณ„๋กœ ๋‹ค์–‘ํ•œ ์ฆ๊ถŒ์‚ฌ(5~6๊ฐœ)๋ฅผ ์„ ์ •ํ•˜์—ฌ Ground Truth์˜ ํ’ˆ์งˆ ํ–ฅ์ƒ + - **ํ‘œ&๊ทธ๋ฆผ ์งˆ๋ฌธ ์ถ”๊ฐ€**: ํ‘œ์™€ ๊ทธ๋ฆผ์„ ๊ธฐ๋ฐ˜์œผ๋กœ ํ•œ ์งˆ๋ฌธ์„ 10๊ฐœ ์ถ”๊ฐ€ ์ƒ์„ฑ +- ํ™œ์šฉ + - Retrieval Top-K Accuracy์—๋Š” ์ „์ฒด 1,843๊ฐœ ํ™œ์šฉ + - G-Eval ํ‰๊ฐ€๋Š” 1,843๊ฐœ ์ค‘ 75๊ฐœ ์ƒ˜ํ”Œ ์‚ฌ์šฉ + +### 2.3 Embedding Model ํ‰๊ฐ€ + +| | Top_1 | Top_5 | Top_10 | Top_20 | Top_30 | Top_50 | +| --- | --- | --- | --- | --- | --- | --- | +| TF-IDF | 9.80 | 22.55 | 37.52 | 59.89 | 72.64 | 90.94 | +| BM25 | 12.20 | 28.84 | 42.33 | 63.59 | 79.85 | 96.12 | +| klue/roberta-large | 2.40 | 11.46 | 20.89 | 38.26 | 59.15 | 86.88 | +| klue/bert base | 5.73 | 17.38 | 30.50 | 49.35 | 66.73 | 87.62 | +| multilingual-e5-large-instruct | 11.09 | 29.94 | 44.92 | 66.17 | 80.41 | 94.82 | +| nlpai-lab/KoE5 | 15.16 | 38.26 | 53.42 | 71.72 | 81.52 | 93.53 | +| BAAI/bge-m3 | 15.34 | 41.22 | 56.38 | 73.94 | 84.84 | 96.30 | +| nlpai-lab/KURE-v1 | 16.64 | 42.41 | 58.41 | 76.53 | 85.03 | 95.38 | + +nlpai-lab์˜ KoE5์™€ KURE-v1์ด ์šฐ์ˆ˜ํ•œ ์„ฑ๋Šฅ์„ ๋ณด์˜€๋‹ค. ์‹ค์ œ ๋ฌธ์„œ๋ฅผ ๊ฒ€ํ† ํ•œ ๊ฒฐ๊ณผ ํŠน์ • Query์— ๋Œ€ํ•œ ๊ฒ€์ƒ‰ ์„ฑ๋Šฅ์ด ๋” ๋›ฐ์–ด๋‚œ KoE5๋ฅผ ์ตœ์ข… ๋ชจ๋ธ๋กœ ์„ ํƒํ•˜์˜€๋‹ค. +### 2.4 Embedding Model Fine-Tuning + +- Fine-tuning ๋ฐ์ดํ„ฐ: [virattt/financial-qa-10K](https://huggingface.co/datasets/virattt/financial-qa-10K)๋ฅผ ๋ฒˆ์—ญํ•œ ๋ฐ์ดํ„ฐ +- Query Encoder์™€ Passage Encoder๋ฅผ ๋‚˜๋ˆ„์–ด Hard Negative ์—†์ด In-Batch Negatives ๋ฐฉ์‹์œผ๋กœ Multiple Negatives Ranking Loss์„ ์‚ฌ์šฉํ•˜์—ฌ ํ•™์Šต +- ๊ฒฐ๊ณผ(Top-K Accuracy) + + + | | KoE5 | Fine-Tuned Model | + | --- | --- | --- | + | Top-1 | 15.16 | 18.11 | + | Top-5 | 38.26 | 43.07 | + | Top-10 | 53.42 | 58.78 | + | Top-20 | 71.72 | 75.60 | + | Top-30 | 81.52 | 85.40 | + | Top-50 | 93.53 | 95.93 | + +### 2.5 Vector Store + +- ChromaDB: Metadata๋ฅผ ์ €์žฅํ•˜์—ฌ Filtering ๊ธฐ๋Šฅ์„ ์ œ๊ณตํ•˜๊ณ , ํšŒ์‚ฌ๋ณ„ ๊ฒ€์ƒ‰์ด ๊ฐ€๋Šฅํ•ด ์ •๋ณด์˜ ์ •ํ™•์„ฑ์„ ๋†’์ผ ์ˆ˜ ์žˆ๋‹ค. ๋˜ํ•œ, ์„œ๋ฒ„ ์‹คํ–‰ ์ค‘์—๋„ DB๋ฅผ ์—…๋ฐ์ดํŠธํ•  ์ˆ˜ ์žˆ์–ด ์œ ์—ฐ์„ฑ์ด ๋›ฐ์–ด๋‚˜ ์ด๋Ÿฌํ•œ ์  ๋•Œ๋ฌธ์— ์„ ํƒํ–ˆ๋‹ค. 
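
To make the metadata filtering concrete, the minimal sketch below stores documents with a `company` metadata key and restricts a similarity search to a single company through LangChain's Chroma wrapper. The persist directory and the `company` key follow the conventions used in this repository, but the documents and company names are placeholders; the production logic lives in `app/RAG/retrieval/chroma_retrieval.py`.

```python
# Simplified example of company-level metadata filtering with ChromaDB.
# Document contents and company names below are placeholders.
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

embedding = HuggingFaceEmbeddings(model_name="nlpai-lab/KoE5")

docs = [
    Document(page_content="카카오뱅크 4분기 실적 요약 ...", metadata={"company": "카카오뱅크"}),
    Document(page_content="네이버 광고 매출 전망 ...", metadata={"company": "네이버"}),
]

db = Chroma.from_documents(docs, embedding, persist_directory="./vector_db/All_data")

# The metadata filter keeps passages from other issuers out of the context.
hits = db.similarity_search("4분기 실적은 어땠나요?", k=5, filter={"company": "카카오뱅크"})
for doc in hits:
    print(doc.metadata["company"], doc.page_content[:50])
```
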
+ +### 2.6 Reranker + +- Cross Encoder๋กœ ๋ฌธ์„œ์™€ ์งˆ์˜์˜ ์œ ์‚ฌ๋„๋ฅผ ์ธก์ •ํ•˜์—ฌ ๋ฌธ์„œ๋ฅผ ์žฌ์ •๋ ฌ +- ์‹คํ—˜ + +|| Top_1 | Top_5 | Top_10 | Top_20 | Top_30 | Top_50 | +| --- | --- | --- | --- | --- | --- | --- | +| nlpai-lab/KoE5 | 15.16 | 38.26 | 53.42 | 71.72 | 81.52 | 93.53 | +| nlpai-lab/KoE5 + BAAI/bge-reranker-v2-m3 | 19.78 | 43.25 | 61.55 | 77.08 | 85.58 | 95.75 | +| nlpai-lab/KoE5 + Dongjin-kr/ko-reranker | 20.15 | 45.47 | 61.37 | 78.00 | 87.25 | 96.49 | +- Reranker๋ฅผ ์‚ฌ์šฉํ•œ ํ›„ Accuracy๊ฐ€ ์ „๋ฐ˜์ ์œผ๋กœ ์•ฝ 5% ์ด์ƒ ์ฆ๊ฐ€ํ•˜์˜€๊ณ  ๊ทธ ์ค‘ ์„ฑ๋Šฅ์ด ๋” ์ข‹์€ Dongjin-kr/ko-reranker๋ฅผ ์‚ฌ์šฉํ•˜์˜€๋‹ค. + +### 2.7 Generator + +- ํ”„๋กฌํ”„ํŠธ ์—”์ง€๋‹ˆ์–ด๋ง +- ์ฟผ๋ฆฌ ๋ฆฌ๋ผ์ดํŒ… + - 2๊ฐœ ์ด์ƒ์˜ ํšŒ์‚ฌ ์ •๋ณด๊ฐ€ ํ•„์š”ํ•˜๊ฑฐ๋‚˜ ์งˆ๋ฌธ์ด ๋ถ€์ ์ ˆํ•œ ๊ฒฝ์šฐ ๋ฆฌ๋ผ์ดํŒ…์„ ํ†ตํ•ด ๊ฒ€์ƒ‰ ์„ฑ๋Šฅ ํ–ฅ์ƒ + +### 2.8 Evaluation + +- G-Eval(Retrieval, Generator) + - Top-K Accuracy, BLEU ๋“ฑ์€ ์ƒํ™ฉ์— ๋”ฐ๋ผ ์ œ๋Œ€๋กœ ๋œ ํ‰๊ฐ€๊ฐ€ ๋ถˆ๊ฐ€๋Šฅํ•˜๊ณ , ์‚ฌ๋žŒ์ด ์ผ์ผ์ด ๋ฐ์ดํ„ฐ๋ฅผ ์ฑ„์ ํ•  ์ˆ˜ ์—†์–ด์„œ LLM-as-a-Judge ๋ฐฉ์‹์œผ๋กœ G-Eval์„ ์„ ํƒํ•˜์˜€๋‹ค. + - ๋น ๋ฅธ ๊ตฌํ˜„๊ณผ ์›ํ™œํ•œ ํ‰๊ฐ€๋ฅผ ์œ„ํ•ด DeepEval Open Source๋ฅผ ํ™œ์šฉ + - Retrieval G-Eval ๊ฒฐ๊ณผ + +| Retrieval (top5) | ์œ ์‚ฌ์„ฑ | ํ•„์ˆ˜ ์ •๋ณด ํฌํ•จ ์—ฌ๋ถ€ | ์งˆ๋ฌธ ์ถฉ์กฑ๋„ | ๊ด€๋ จ์„ฑ | ๊ฐ„๊ฒฐ์„ฑ | total | +|---------------------------------|--------|------------------|------------|--------|--------|--------| +| BAAI/bge-m3 | 2.52 | 3 | 2.34 | 1.92 | 1 | 10.81 | +| nlpai-lab/KoE5 | 2.62 | 3 | 2.36 | 1.98 | 0.99 | 10.98 | +| fine-tuned/nlpai-lab/KoE5 | 2.68 | 2.87 | 2.41 | 1.8 | 1.3 | 11.08 | + + - Generator G-Eval ๊ฒฐ๊ณผ + +| Generation | ๊ด€๋ จ์„ฑ | ์‚ฌ์‹ค์  ์ •ํ™•์„ฑ | ํ•„์ˆ˜ ์ •๋ณด ํฌํ•จ ์—ฌ๋ถ€ | ๋ช…ํ™•์„ฑ๊ณผ ๊ฐ„๊ฒฐ์„ฑ | ๋…ผ๋ฆฌ์  ๊ตฌ์กฐ | ๊ณผํ•˜์ง€์•Š์€ ์„ธ๋ถ€์ •๋ณด | ์ ์ ˆํ•œ ์ถœ์ฒ˜ ํ‘œ์‹œ | ํ˜•์‹ ์ ์ ˆ์„ฑ | ์ถ”๊ฐ€์  ํ†ตํ•ฉ | total | +|------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|--------| +| Top-5 | 2.7 | 2.7 | 2.8 | 2.4 | 1.6 |1.7 | 1.2 | 0.4 | 0.6 | 16.2 | +| **Top-7** | **3.1** | **3.0** | **3.0** | **2.9** | 1.6 |2.0 | 1.3 | 0.4 | 0.7 | 18.3 | +| Top-10 | 3.0 | 2.9 | 2.8 | 2.6 | **1.7** |1.7 | 1.1 | 0.4 | 0.7 | 17.0 | + + +## 3. API + +๐Ÿ“‘ **[API ์ƒ์„ธ ์„ค๋ช… ๋ณด๊ธฐ](app/README.md)** + +REST API ๊ฐœ๋ฐœ (ํŒŒ์ด์ฌ API, Query API) + +### 3.1 ์‹คํ–‰ + +```bash +cd app +uvicorn main:app --reload --host 0.0.0.0 --port 8000 +``` + +### 3.2 Endpoint + +- query +- documents +- chatting + +## 4. FE + +### 4.1 ์‹คํ–‰ + +```bash +cd FE +npm install +npm run dev +``` + + + + + +### 4.2 ๊ธฐ๋Šฅ + +- AI ๋ชจ๋ธ ์„ ํƒ(GPT-4o, GPT-4o-mini, Clova X) +- ์ฒจ๋ถ€ํ•œ PDF ๋ฌธ์„œ๋ฅผ ๋ฒกํ„ฐ DBํ™”ํ•˜์—ฌ ํšจ์œจ์ ์ธ ๊ฒ€์ƒ‰ ์ง€์› +- ์ด์ „ context๋ฅผ ์œ ์ง€ํ•œ ์‹ค์‹œ๊ฐ„ ๋Œ€ํ™” +- ์œ„์ ฏ: ์ฝ”์Šคํ”ผ ์ง€์ˆ˜, ์‹ค์‹œ๊ฐ„ ํ™˜์œจ, ์ตœ์‹  ๊ฒฝ์ œ ๋‰ด์Šค, ์ข…๋ชฉ ๊ด€๋ จ ์ •๋ณด, ์ข…๋ชฉ๋ณ„ ์ตœ์‹  ๋‰ด์Šค + +
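
For reference, the round trip the front-end performs against the API server started in section 3.1 looks roughly like the Python sketch below. The host and port match the `uvicorn` command above, but the exact route and payload field names are assumptions here; the front-end itself uses axios through its proxy server, and `app/README.md` documents the real endpoint contract.

```python
# Hypothetical client call against the API server from section 3.1.
# The "/query" route and the request/response field names are assumptions;
# see app/README.md for the actual endpoint specification.
import requests

BASE_URL = "http://localhost:8000"

resp = requests.post(
    f"{BASE_URL}/query",
    json={"query": "카카오뱅크의 4분기 영업이익 전망은?"},
    timeout=30,
)
resp.raise_for_status()

result = resp.json()
print("answer :", result.get("answer"))
print("context:", result.get("context"))
```
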
+ +# ๊ฒฐ๊ณผ + +### ์‚ฌ์šฉ ๊ธฐ์ˆ  + +- **OCR**: DocLayout-Yolo, Clova OCR, Upstage Parser API +- **VectorDB**: ChromaDB +- **Retriever**: Langchain +- **Generator**: Langchain, LLM-based Answering Model (gpt-4o, Clova X) +- **Evaluation**: G-Eval, Top-K Accuracy +- **API server**: Fastapi +- **Web Front-end**: React.js, Tailwind CSS + +### ํŒ€์›Œํฌ & ํ˜‘์—… ๊ฒฝํ—˜ + +- ํ˜‘์—… ๋„๊ตฌ : Github issue์™€ discussion์œผ๋กœ task ํ• ๋‹น ๋ฐ ํ† ์˜ ๐Ÿค +- Commit ๊ด€๋ฆฌ : Github commit message template์œผ๋กœ ์ผ๊ด€์„ฑ ์œ ์ง€, ํ˜‘์—… ํšจ์œจ ์ฆ๋Œ€ ๐Ÿ“š + +### ํ”„๋กœ์ ํŠธ ์ง„ํ–‰ ๋ฐฉ์‹ + +- ํ”„๋กœ์ ํŠธ ๊ด€๋ฆฌ : Notion์— ์™„๋ฃŒ๋œ ์ผ ๊ณต์œ , Zoom meeting์„ ํ†ตํ•ด ์ง„ํ–‰ ์ƒํ™ฉ ํ† ์˜ diff --git a/app/RAG/README.md b/app/RAG/README.md new file mode 100644 index 0000000..dfa7e64 --- /dev/null +++ b/app/RAG/README.md @@ -0,0 +1,64 @@ +# RAG(์‹คํ—˜) ์‚ฌ์šฉ ๊ฐ€์ด๋“œ + +## 1. ํ™˜๊ฒฝ ์„ค์ • + +### 1.1 ํ™˜๊ฒฝ ๋ณ€์ˆ˜ ์„ค์ • +`.env` ํŒŒ์ผ์„ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค: +```bash +env๋กœ openai API_Key, naverClova API_key ์„ค์ • +``` + +# 2. ์‹คํ—˜์„ธํŒ…(config ์ˆ˜์ •) + +#### ๋งค๊ฐœ๋ณ€์ˆ˜ ์„ค๋ช… +- `passage_embedding_model_name` : vectorDB ๋งŒ๋“ค ๋•Œ ์‚ฌ์šฉํ•˜๋Š” embedding_model +- `query_embedding_model_name` : question embeddingํ•˜๋Š” ๋ชจ๋ธ +- `llm_model_name` : ์ƒ์„ฑํ˜• ๋ชจ๋ธ +- `chat_template` : ํ”„๋กฌํ”„ํŠธ ํ…œํ”Œ๋ฆฟ + + +## 3. ์‹คํ—˜ + +### 3.1 retrieve (ํ‰๊ฐ€) +```bash +cd RAG +python main.py mode=retrieve +``` + +### 3.2 generate (ํ‰๊ฐ€) +```bash +cd RAG +python main.py mode=generate +``` + +### 3.3 update_vectordb +```bash +cd RAG +python main.py mode=update_vectordb +``` + +## 4. ์‘๋‹ต ํ˜•์‹ + +## 4.1 retriever G-eval(5๊ฐ€์ง€ criteria, ์ด์  20) +```json +{ + "question": "question", + "docs": "retrieved_docs", + "ground_truth": "ground_truth", + "criteria1": "criteria1_score", + "criteria2": "criteria2_score", + "final_score": "total_score" +} +``` + +## 4.2 generator G-eval(9๊ฐ€์ง€ criteria, ์ด์  30) +```json +{ + "question": "question", + "generated_answer": "generated_answer", + "ground_truth": "ground_truth", + "criteria1": "criteria1_score", + "criteria2": "criteria2_score", + "final_score": "total_score" +} +``` \ No newline at end of file diff --git a/app/RAG/__init__.py b/app/RAG/__init__.py new file mode 100644 index 0000000..6046b5d --- /dev/null +++ b/app/RAG/__init__.py @@ -0,0 +1,20 @@ +""" +RAG (Retrieval Augmented Generation) ํŒจํ‚ค์ง€ +""" + +import sys +from pathlib import Path + +# ํŒจํ‚ค์ง€ ๋ฃจํŠธ ๊ฒฝ๋กœ +ROOT_PATH = Path(__file__).parent.absolute() + +# Python ๊ฒฝ๋กœ์— RAG ๋ฃจํŠธ ์ถ”๊ฐ€ +if str(ROOT_PATH) not in sys.path: + sys.path.append(str(ROOT_PATH)) + +# ๋ฐ์ดํ„ฐ ๊ฒฝ๋กœ +DATA_PATH = ROOT_PATH / "data" +VECTOR_STORE_PATH = ROOT_PATH / "vector_store" + +# ์„ค์ • ํŒŒ์ผ ๊ฒฝ๋กœ +CONFIG_PATH = ROOT_PATH / "configs" diff --git a/app/RAG/configs/config.yaml b/app/RAG/configs/config.yaml new file mode 100644 index 0000000..43cd188 --- /dev/null +++ b/app/RAG/configs/config.yaml @@ -0,0 +1,54 @@ +defaults: + - inference: generator + - ret_eval: ret_test + - ret_finetune: ret_finetune + - _self_ + +datapath: ../data/ +vector_store_path: "../vector_db" +chunk_size: 300 +chunk_overlap: 15 +dataset: ODQA +openai_key: '' +seed: 42 + +# Retrieval ์„ค์ • +retrieval: + top_k: 5 + model_name: "nlpai-lab/KoE5" + rerank: True + reranker_model_name: "BAAI/bge-reranker-v2-m3" + use_mmr: true # MMR ์‚ฌ์šฉ ์—ฌ๋ถ€ + lambda_mult: 0.5 # MMR ๋‹ค์–‘์„ฑ ๊ฐ€์ค‘์น˜ + batch_size: 32 # ๋ฐฐ์น˜ ์ฒ˜๋ฆฌ ํฌ๊ธฐ + timeout: 30 # ๊ฒ€์ƒ‰ ํƒ€์ž„์•„์›ƒ + cache_size: 1000 # ์บ์‹œ ํฌ๊ธฐ + parallel_workers: 4 # 
๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ ์›Œ์ปค ์ˆ˜ + +eval_data_path: "data/ephemeral/data/LabQ/selected_eval.csv" +retriever_type: "dense" +embedding_model_source: "huggingface" +passage_embedding_model_name: "nlpai-lab/KoE5" +query_embedding_model_name: "nlpai-lab/KoE5" +# query_embedding_model_name: "RAG/retrieval/embedding_model/query_encoder" +# passage_embedding_model_name: "RAG/retrieval/embedding_model/passage_encoder" +llm_model_source: "openai" +llm_model_name: "gpt-4o-mini" +chat_template: | + ์ฃผ์–ด์ง„ ๋ฌธ์„œ๋“ค์„ ๋ฐ”ํƒ•์œผ๋กœ ์งˆ๋ฌธ์— ๋‹ตํ•ด์ฃผ์„ธ์š”. + ์ฃผ์–ด์ง„ ๋ฌธ์„œ์ค‘ table์ด ์žˆ์„ ๊ฒฝ์šฐ ์ด๋ฅผ ํ•ด์„ํ•ด์„œ ๋Œ€๋‹ตํ•ด์ฃผ์„ธ์š”. + ์—ฌ๋Ÿฌ ๋ฌธ์„œ๋“ค์—์„œ ์ •๋‹ต์„ ์ฐพ์„ ์ˆ˜ ์žˆ๋Š” ๊ฒฝ์šฐ ์ถœ์ฒ˜๋ฅผ ํฌํ•จํ•ด ์—ฌ๋Ÿฌ ์ถœ์ฒ˜์—์„œ ๋‹ต์„ ์ฐพ์•„ ์ถœ๋ ฅํ•ด์ฃผ์„ธ์š”. + ๋งŒ์•ฝ ์ฃผ์–ด์ง„ ๋ฌธ์„œ๋“ค ์ „์ฒด์—์„œ ๋‹ต์„ ์ฐพ์„์ˆ˜ ์—†๋Š”๊ฒฝ์šฐ ๋‹ตํ•  ์ˆ˜ ์—†๋‹ค๊ณ  ๋Œ€๋‹ตํ•ด์ฃผ์„ธ์š”. + {docs} + +chatting_template: | + ๋‹น์‹ ์€ ๊ธˆ์œต chatbot์ž…๋‹ˆ๋‹ค. + ์ฃผ์–ด์ง„ ๋ฌธ์„œ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ์งˆ๋ฌธ์— ๋‹ตํ•ด์ฃผ์„ธ์š”. + ์ฃผ์–ด์ง„ ๋ฌธ์„œ๊ฐ€ table์ผ ๊ฒฝ์šฐ ์ด๋ฅผ ํ•ด์„ํ•ด์„œ ๋Œ€๋‹ตํ•ด์ฃผ์„ธ์š”. + ์—ฌ๋Ÿฌ ๋ฌธ์„œ์—์„œ ์ •๋‹ต์ด ๋‚˜์˜ค๋Š” ๊ฒฝ์šฐ ์ถœ์ฒ˜๋ฅผ ํฌํ•จํ•ด ์—ฌ๋Ÿฌ ์ถœ์ฒ˜์—์„œ ๋‹ต์„ ์ฐพ์•„ ์ถœ๋ ฅํ•ด์ฃผ์„ธ์š”. + ๋งŒ์•ฝ ์ฃผ์–ด์ง„ ๋ฌธ์„œ์—์„œ ๋‹ต์„ ์ฐพ์„์ˆ˜ ์—†๋Š”๊ฒฝ์šฐ ๋‹ตํ•  ์ˆ˜ ์—†๋‹ค๊ณ  ๋Œ€๋‹ตํ•ด์ฃผ์„ธ์š”. + ๋˜ ์ด์ „ ๋Œ€ํ™” ๊ธฐ๋ก์„ ์ฐธ๊ณ ํ•ด์„œ ๋‹ต๋ณ€ํ•ด์ฃผ์„ธ์š”. + ์ฃผ์–ด์ง„ ๋ฌธ์„œ: {docs} +batch_size: 16 +g_eval: True +mode: retrieve \ No newline at end of file diff --git a/app/RAG/configs/inference/generator.yaml b/app/RAG/configs/inference/generator.yaml new file mode 100644 index 0000000..33e2374 --- /dev/null +++ b/app/RAG/configs/inference/generator.yaml @@ -0,0 +1,10 @@ +defaults: [] +seed: 42 +llm_model_source: "openai" +llm_model_name: "gpt-3.5-turbo" +retriever_type: "dense" +embedding_model_source: "huggingface" +embedding_model_name: "BAAI/bge-m3" +chat_template: | + Answer the user's question using only the provided information below: + {docs} diff --git a/app/RAG/configs/ret_eval/ret_test.yaml b/app/RAG/configs/ret_eval/ret_test.yaml new file mode 100644 index 0000000..6fb8576 --- /dev/null +++ b/app/RAG/configs/ret_eval/ret_test.yaml @@ -0,0 +1,5 @@ +defaults: [] +seed: 42 +retriever_type: "dense" +embedding_model_source: "huggingface" +embedding_model_name: "nlpai-lab/KURE-v1" diff --git a/app/RAG/configs/ret_finetune/ret_finetune.yaml b/app/RAG/configs/ret_finetune/ret_finetune.yaml new file mode 100644 index 0000000..926b31d --- /dev/null +++ b/app/RAG/configs/ret_finetune/ret_finetune.yaml @@ -0,0 +1,35 @@ +model: + model_name_or_path: 'klue/roberta-large' + dense_model_name_or_path: 'klue/roberta-large' + config_name: null + tokenizer_name: null + + +data: + data_path: '/data/ephemeral/data' + overwrite_cache: false + preprocessing_num_workers: null + eval_retrieval: true + top_k_retrieval: 10 + use_faiss: false + num_neg: 2 + +train: + output_dir: './models/train_dataset' + do_train: true + do_eval: true + overwrite_output_dir: true + report_to: 'wandb' + per_device_train_batch_size: 3 + per_device_eval_batch_size: 2 + logging_strategy: 'steps' + logging_steps: 50 + evaluation_strategy: 'epoch' + save_strategy: 'epoch' + save_total_limit: 2 + num_train_epochs: 1 + warmup_steps: 300 + seed: 42 + dataloader_num_workers: 4 + logging_first_step: true + \ No newline at end of file diff --git a/app/RAG/data/__init__.py b/app/RAG/data/__init__.py new file mode 100644 index 0000000..8be8768 --- /dev/null +++ b/app/RAG/data/__init__.py @@ -0,0 +1,22 @@ +def get_docs(cfg): + if 
cfg.dataset == "ODQA": + import json + + from langchain.docstore.document import Document + + with open(cfg.datapath, "r", encoding="utf-8") as f: + data = json.load(f) + + documents = [ + Document( + page_content=content.get("text", ""), + metadata={ + "title": content.get("title", "No Title"), + "document_id": doc_id, + "source": content.get("corpus_source", "Unknown Source"), + }, + ) + for doc_id, content in data.items() + ] + return documents + # elif dataset=="ours": ์ถ”ํ›„ ์ž‘์„ฑ diff --git a/app/RAG/generator/ClovaStudioExcecutor.py b/app/RAG/generator/ClovaStudioExcecutor.py new file mode 100644 index 0000000..6e36ffa --- /dev/null +++ b/app/RAG/generator/ClovaStudioExcecutor.py @@ -0,0 +1,45 @@ +import http +import json +import os +from http import HTTPStatus + +import backoff +import dotenv + +dotenv.load_dotenv() + + +class RateLimitException(Exception): + pass + + +class ClovaStudioExecutor: + def __init__(self, host="https://clovastudio.stream.ntruss.com/serviceapp/v1/chat-completions/HCX-003"): + self.host = host + self.api_key = os.getenv("NCP_CLOVASTUDIO_API_KEY") + self.request_id = os.getenv("NCP_CLOVASTUDIO_REQUEST_ID") + + def _send_request(self, completion_request, endpoint): + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + "X-NCP-CLOVASTUDIO-REQUEST-ID": self.request_id, + "Accept": "text/event-stream", + } + conn = http.client.HTTPSConnection(self.host) + conn.request("POST", endpoint, json.dumps(completion_request), headers) + response = conn.getresponse() + status = response.status + result = json.loads(response.read().decode(encoding="utf-8")) + conn.close() + return result, status + + @backoff.on_exception(backoff.expo, RateLimitException, max_tries=5, max_time=120, base=10) + def execute(self, completion_request, endpoint): + res, status = self._send_request(completion_request, endpoint) + if status == HTTPStatus.OK: + return res, status + elif status == HTTPStatus.TOO_MANY_REQUESTS: + raise RateLimitException + else: + raise Exception(f"API Error: {res}, {status}") diff --git a/app/RAG/generator/__init__.py b/app/RAG/generator/__init__.py new file mode 100644 index 0000000..5efde68 --- /dev/null +++ b/app/RAG/generator/__init__.py @@ -0,0 +1,33 @@ +from typing import Optional + +import os + +import dotenv + +dotenv.load_dotenv() + + +def get_llm_api(cfg, temperature: Optional[int] = 0.5): + if cfg.llm_model_source == "openai": + from langchain.chat_models import ChatOpenAI + + return ChatOpenAI( + model=cfg.llm_model_name, + api_key=os.getenv("OPENAI_API_KEY"), + temperature=temperature, + ) # temperature=0.5, max_tokens=1024 + + elif cfg.llm_model_source == "naver": + from langchain_community.chat_models import ChatClovaX + + from .ClovaStudioExcecutor import ClovaStudioExecutor + + os.environ["NCP_CLOVASTUDIO_API_KEY"] = os.getenv("NCP_CLOVASTUDIO_API_KEY") + os.environ["NCP_CLOVASTUDIO_REQUEST_ID"] = os.getenv("NCP_CLOVASTUDIO_REQUEST_ID") + os.environ["NCP_APIGW_API_KEY"] = os.getenv("NCP_APIGW_API_KEY") + return ChatClovaX( + model="HCX-003", + ) + + elif cfg.llm_model_source == "huggingface": + return diff --git a/app/RAG/main.py b/app/RAG/main.py new file mode 100644 index 0000000..be22946 --- /dev/null +++ b/app/RAG/main.py @@ -0,0 +1,67 @@ +import os +import sys + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +import hydra +from dotenv import load_dotenv +from omegaconf import DictConfig + +# from source.finetune_ret import finetune +# from source.generate import generate +# from 
source.retrieve import retrieve +from utils.vector_store import VectorStore + + +@hydra.main(version_base="1.3", config_path="configs", config_name="config") +def main(cfg: DictConfig): + print(sys.path) + load_dotenv() + cfg.openai_key = os.getenv("OPENAI_API_KEY") + + if cfg.mode == "retrieve": + print("test retrieval") + # retrieve(cfg) + + elif cfg.mode == "generate": + print("test inference") + # generate(cfg) + + elif cfg.mode == "update_vectordb": + print("๋ฒกํ„ฐ DB ์—…๋ฐ์ดํŠธ ์‹œ์ž‘") + + # ๋””๋ ‰ํ† ๋ฆฌ ์„ค์ • + vector_db_dir = "vector_db" + old_data_dir = "old_data" + + if not os.path.exists(vector_db_dir): + os.makedirs(vector_db_dir) + if not os.path.exists(old_data_dir): + os.makedirs(old_data_dir) + + # JSON ํŒŒ์ผ ๊ฒฝ๋กœ ์„ค์ • + text_json_path = "../../PDF_OCR/new_data/All_data/data.json" + table_json_path = "../../PDF_OCR/new_data/All_data/data.json" + + # ํŒŒ์ผ์ด ์กด์žฌํ•˜๋Š”์ง€ ํ™•์ธ + if not (os.path.exists(text_json_path) and os.path.exists(table_json_path)): + print("์ƒˆ๋กœ์šด ๋ฐ์ดํ„ฐ ํŒŒ์ผ์ด ์—†์Šต๋‹ˆ๋‹ค.") + return + + try: + # ๋ฒกํ„ฐ ์Šคํ† ์–ด ์ดˆ๊ธฐํ™” ๋ฐ ์—…๋ฐ์ดํŠธ + vector_store = VectorStore(cfg=cfg, persist_directory=vector_db_dir) + #vector_store.update_company_vector_stores(text_json_path, table_json_path) + vector_store.update_all_vector_stores(text_json_path, table_json_path) + # ์ฒ˜๋ฆฌ๋œ ํŒŒ์ผ ์ด๋™ + vector_store.move_to_old_data( + [text_json_path, table_json_path], old_data_dir="../../PDF_OCR/old_data", user_name="All_data" + ) + print("๋ฒกํ„ฐ DB ์—…๋ฐ์ดํŠธ ์™„๋ฃŒ") + + except Exception as e: + print(f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}") + + +if __name__ == "__main__": + main() diff --git a/app/RAG/requirements.txt b/app/RAG/requirements.txt new file mode 100644 index 0000000..d3d07f9 --- /dev/null +++ b/app/RAG/requirements.txt @@ -0,0 +1,15 @@ +langchain==0.3.14 +langchain-community==0.3.14 +langchain-huggingface==0.1.2 +langchain-openai==0.3.2 +langchain-text-splitters==0.3.5 +langchain-unstructured==0.1.6 +omegaconf==2.3.0 +openai==1.59.9 +rank-bm25==0.2.2 +hydra-core==1.3.2 +datasets==3.2.0 +deepeval==2.2.6 +isort==5.13.2 +black==24.8.0 +flake8==7.1.1 \ No newline at end of file diff --git a/app/RAG/retrieval/__init__.py b/app/RAG/retrieval/__init__.py new file mode 100644 index 0000000..1351b94 --- /dev/null +++ b/app/RAG/retrieval/__init__.py @@ -0,0 +1,17 @@ +from .bm25_retrieval import BM25Retrieval +from .chroma_retrieval import ChromaRetrieval +from .dense_retrieval import DenseRetrieval +from .ensemble_retrieval import EnsembleRetrieval + +__all__ = ["DenseRetrieval", "BM25Retrieval", "EnsembleRetrieval"] + + +def get_retriever(cfg): + if cfg.retriever_type == "dense": + return DenseRetrieval(cfg).retriever + elif cfg.retriever_type == "bm25": + return BM25Retrieval(cfg).retriever + elif cfg.retriever_type == "ensemble": + return EnsembleRetrieval(retrievers=[DenseRetrieval(cfg), BM25Retrieval(cfg)], weights=[0.7, 0.3]).retriever + else: + raise ValueError(f"Unknown retriever type: {cfg.retriever_type}") diff --git a/app/RAG/retrieval/base.py b/app/RAG/retrieval/base.py new file mode 100644 index 0000000..e201755 --- /dev/null +++ b/app/RAG/retrieval/base.py @@ -0,0 +1,11 @@ +from typing import List + +from abc import ABC, abstractmethod + +from langchain.docstore.document import Document + + +class BaseRetriever(ABC): + @abstractmethod + def get_relevant_documents(self, query: str, k: int = 50) -> List[Document]: + pass diff --git a/app/RAG/retrieval/bm25_retrieval.py b/app/RAG/retrieval/bm25_retrieval.py new file mode 100644 index 
0000000..698c428 --- /dev/null +++ b/app/RAG/retrieval/bm25_retrieval.py @@ -0,0 +1,38 @@ +from typing import List + +import os +import pickle + +import numpy as np +from data import get_docs +from langchain.docstore.document import Document +from rank_bm25 import BM25Okapi +from retrieval.base import BaseRetriever +from retrieval.reranking import get_reranker_model + + +class BM25Retrieval(BaseRetriever): + def __init__(self, cfg): + self.pickle_path = cfg.vector_store_path + self.documents = get_docs(cfg) + self.bm25_index = self._load_or_create_bm25() + + def _load_or_create_bm25(self): + if os.path.isfile(self.pickle_path): + with open(self.pickle_path, "rb") as f: + bm25_index = pickle.load(f) + else: + doc_texts = [doc.page_content for doc in self.documents] + tokenized_docs = [text.split() for text in doc_texts] + bm25_index = BM25Okapi(tokenized_docs) + with open(self.pickle_path, "wb") as f: + pickle.dump(bm25_index, f) + return bm25_index + + def get_relevant_documents(self, query: str, k: int = 5) -> List[Document]: + tokenized_query = query.split() + scores = self.bm25_index.get_scores(tokenized_query) + + top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k] + + return [self.documents[i] for i in top_indices] diff --git a/app/RAG/retrieval/chroma_retrieval.py b/app/RAG/retrieval/chroma_retrieval.py new file mode 100644 index 0000000..f104796 --- /dev/null +++ b/app/RAG/retrieval/chroma_retrieval.py @@ -0,0 +1,181 @@ +from typing import List, Optional + +import os +import time +from concurrent.futures import ThreadPoolExecutor +from functools import lru_cache + +import numpy as np +from langchain.docstore.document import Document +from langchain.embeddings import HuggingFaceEmbeddings +from langchain.retrievers import ContextualCompressionRetriever +from langchain.retrievers.document_compressors import CrossEncoderReranker +from langchain.vectorstores import Chroma +from langchain_community.cross_encoders import HuggingFaceCrossEncoder +from loguru import logger +from retrieval.base import BaseRetriever +from utils.query_rewriter import QueryRewriter + + +class ChromaRetrieval(BaseRetriever): + def __init__(self, cfg): + self.base_path = "./RAG/vector_db" + self.embedding_model = HuggingFaceEmbeddings( + model_name=cfg.query_embedding_model_name, + model_kwargs={"device": "cuda"}, + encode_kwargs={"normalize_embeddings": True, "batch_size": 32}, # ๋ฐฐ์น˜ ์ฒ˜๋ฆฌ ํฌ๊ธฐ ์„ค์ • + ) + self.query_rewriter = QueryRewriter() + self.db_cache = {} + self.k = cfg.retrieval.get("top_k", 5) + self.use_mmr = cfg.retrieval.get("use_mmr", True) # MMR ์‚ฌ์šฉ ์—ฌ๋ถ€ + self.lambda_mult = cfg.retrieval.get("lambda_mult", 0.5) # MMR ๋‹ค์–‘์„ฑ ๊ฐ€์ค‘์น˜ + self.executor = ThreadPoolExecutor(max_workers=4) # ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ๋ฅผ ์œ„ํ•œ ์Šค๋ ˆ๋“œ ํ’€ + self.reranker = HuggingFaceCrossEncoder(model_name=cfg.retrieval.reranker_model_name) + self.compressor = CrossEncoderReranker(model=self.reranker, top_n=15) + + @lru_cache(maxsize=1000) + def _get_db(self, company: Optional[str] = None) -> Chroma: + """ํŠน์ • ํšŒ์‚ฌ ๋˜๋Š” ์ „์ฒด ๋ฐ์ดํ„ฐ์˜ ChromaDB ์ธ์Šคํ„ด์Šค๋ฅผ ๋ฐ˜ํ™˜ (์บ์‹ฑ ์ ์šฉ)""" + db_path = os.path.join(self.base_path, company if company else "All_data") + + if db_path not in self.db_cache: + self.db_cache[db_path] = Chroma(persist_directory=db_path, embedding_function=self.embedding_model) + + return self.db_cache[db_path] + + def _search_with_mmr(self, db: Chroma, query: str, k: int, company: str) -> List[Document]: + """MMR์„ ์‚ฌ์šฉํ•œ ๋‹ค์–‘์„ฑ ์žˆ๋Š” ๊ฒ€์ƒ‰ 
์ˆ˜ํ–‰""" + if company and company.lower() != "none": + return db.max_marginal_relevance_search(query, k=k, filter={"company": company}) + else: + return db.max_marginal_relevance_search(query, k=k) + + def _search_with_similarity(self, db: Chroma, query: str, k: int, company: str) -> List[Document]: + """ + ๊ธฐ๋ณธ ์œ ์‚ฌ๋„ ๊ฒ€์ƒ‰ ์ˆ˜ํ–‰ + company๊ฐ€ none์ด๋ฉด ์ „์ฒด ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค์—์„œ ๊ฒ€์ƒ‰ + """ + logger.info(f"Performing similarity search with query: {query}, company: {company}") + if company and company.lower() != "none": + logger.info(f"Applying company filter: {company}") + return db.similarity_search(query, k=k, filter={"company": company}) + logger.info("No company filter applied, searching entire database") + return db.similarity_search(query, k=k) + + def get_relevant_documents_without_query_rewritten(self, query: str, k: int = None) -> List[Document]: + start_time = time.time() + all_docs = [] + query_text, company = self.query_rewriter.extract_company(query) + if not isinstance(query_text, list): + query_text = [query_text] + + db = self._get_db("All_data") + retriever = db.as_retriever() + + compression_retriever = ContextualCompressionRetriever( + base_compressor=self.compressor, base_retriever=retriever + ) + + def search_for_query(q, company): + if company: + return compression_retriever.get_relevant_documents(q, k=k, filter={"company": company}) + else: + return compression_retriever.get_relevant_documents(q, k=k) + + futures = [self.executor.submit(search_for_query, q, company) for q in query_text] + for future in futures: + all_docs.extend(future.result()) + logger.info(f"Retrieval processed without query rewritten in {time.time() - start_time:.2f} seconds") + return all_docs + + def get_relevant_documents_with_query_rewritten(self, query: str, k: int = None) -> List[Document]: + if k is None: + k = self.k + + all_docs = [] + + # ์ฟผ๋ฆฌ ๋ฆฌ๋ผ์ดํ„ฐ๋ฅผ ํ†ตํ•ด ์ฟผ๋ฆฌ ์ˆ˜์ • + rewritten_query = self.query_rewriter.rewrite_query(query) + print(rewritten_query) + # OUTPUT: ๋ถ€๋ถ„ ์ถ”์ถœ + clean_query = rewritten_query.split("OUTPUT:")[-1].strip() + start_time = time.time() + # None์ธ ๊ฒฝ์šฐ ์ฒ˜๋ฆฌ ์šฐ์„ ์ „์ฒด์—์„œ ๊ฒ€์ƒ‰. 
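        # According to the rewriting prompt, the rewriter returns the literal
        # string "None" when the question names no known company, and joins
        # multi-company questions with "|". In the "None" branch below, the
        # original query is searched over the whole "All_data" collection
        # without a company filter.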
+ if clean_query == "None": + retrieval_time = time.time() - start_time + ret = self._search_with_similarity(self._get_db("All_data"), query, k, None) + logger.info(f"Retrieval processed in {retrieval_time:.2f} seconds") + return ret + + # ์—ฌ๋Ÿฌ ํšŒ์‚ฌ์— ๋Œ€ํ•œ ์ฟผ๋ฆฌ์ธ ๊ฒฝ์šฐ ํŒŒ์ดํ”„(|)๋กœ ๋ถ„๋ฆฌ + queries = clean_query.split("|") + logger.info(f"Parsed queries: {queries}") + + if len(queries) == 1: + # ๋‹จ์ผ ์ฟผ๋ฆฌ ์ฒ˜๋ฆฌ + if queries[0].strip() == "None": + retrieval_time = time.time() - start_time + ret = self._search_with_similarity(self._get_db("All_data"), query, k, None) + logger.info(f"Retrieval processed in {retrieval_time:.2f} seconds") + return ret + + query_text, company = self.query_rewriter.extract_company(queries[0]) + if not isinstance(query_text, list): + query_text = [query_text] + + db = self._get_db("All_data") + retriever = db.as_retriever() + + compression_retriever = ContextualCompressionRetriever( + base_compressor=self.compressor, base_retriever=retriever + ) + + def search_for_query(q, company): + if company: + logger.info(f"Applying company filter: {company}") + return compression_retriever.get_relevant_documents(q, k=k, filter={"company": company}) + else: + return compression_retriever.get_relevant_documents(q, k=k) + + futures = [self.executor.submit(search_for_query, q, company) for q in query_text] + for future in futures: + all_docs.extend(future.result()) + else: + # ์—ฌ๋Ÿฌ ์ฟผ๋ฆฌ ์ฒ˜๋ฆฌ + db = self._get_db("All_data") + retriever = db.as_retriever() + compression_retriever = ContextualCompressionRetriever( + base_compressor=self.compressor, base_retriever=retriever + ) + + def search_for_query(q, k_per_query, company): + if company: + logger.info(f"Applying company filter: {company}") + return compression_retriever.get_relevant_documents(q, k=k_per_query, filter={"company": company}) + else: + return compression_retriever.get_relevant_documents(q, k=k_per_query) + + for query_part in queries: + if query_part.strip() == "None": + continue + + query_text, company = self.query_rewriter.extract_company(query_part) + if not isinstance(query_text, list): + query_text = [query_text] + + k_per_query = max(1, k // len(queries)) + future = self.executor.submit(search_for_query, query_text[0], k_per_query, company) + # future = self.executor.submit(search_func, db, query_text[0], k_per_query, company) + all_docs.extend(future.result()) + + processing_time = time.time() - start_time + logger.info(f"Retrieval processed in {processing_time:.2f} seconds") + return all_docs + + def get_relevant_documents(self, query: str, k: int = None) -> List[Document]: + """ + BaseRetriever์˜ ์ถ”์ƒ ๋ฉ”์„œ๋“œ ๊ตฌํ˜„ + ๊ธฐ๋ณธ์ ์œผ๋กœ query rewritten์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ฌธ์„œ๋ฅผ ๊ฒ€์ƒ‰ + """ + return self.get_relevant_documents_with_query_rewritten(query, k) diff --git a/app/RAG/retrieval/dense_retrieval.py b/app/RAG/retrieval/dense_retrieval.py new file mode 100644 index 0000000..d064d94 --- /dev/null +++ b/app/RAG/retrieval/dense_retrieval.py @@ -0,0 +1,43 @@ +# retrieval/dense_retriever.py +from typing import List + +import os + +from data import get_docs +from langchain.docstore.document import Document +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_community.vectorstores import FAISS +from retrieval.base import BaseRetriever +from retrieval.embedding_model import get_embedding_model +from retrieval.reranking import get_reranker_model + + +class DenseRetrieval(BaseRetriever): + def __init__(self, cfg): + self.vector_store_path = 
cfg.vector_store_path + self.chunk_size = cfg.chunk_size + self.documents = get_docs(cfg) + self.chuck_overlap = cfg.chunk_overlap + self.embedding_model = get_embedding_model(cfg) + self.vector_store = self._load_or_create_vector_store() + self.retriever = self.vector_store.as_retriever() + if cfg.rerank: + self.retriever = get_reranker_model(cfg, self.retriever) + + def _load_or_create_vector_store(self) -> FAISS: + if os.path.exists(self.vector_store_path): + vector_store = FAISS.load_local( + self.vector_store_path, self.embedding_model, allow_dangerous_deserialization=True + ) + return vector_store + else: + text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chuck_overlap) + split_docs = text_splitter.split_documents(self.documents) + # cossine similarity + vector_store = FAISS.from_documents(split_docs, self.embedding_model, metric="cosine") + os.makedirs(self.vector_store_path, exist_ok=True) + vector_store.save_local(self.vector_store_path) + return vector_store + + def get_relevant_documents(self, query: str, k: int = 5) -> List[Document]: + return self.retriever.get_relevant_documents(query, k=k) diff --git a/app/RAG/retrieval/embedding_model/__init__.py b/app/RAG/retrieval/embedding_model/__init__.py new file mode 100644 index 0000000..b91e8af --- /dev/null +++ b/app/RAG/retrieval/embedding_model/__init__.py @@ -0,0 +1,17 @@ +def get_embedding_model(cfg): + if cfg.embedding_model_source == "huggingface": + from langchain_community.embeddings import HuggingFaceEmbeddings + + embedding_model = HuggingFaceEmbeddings( + model_name=cfg.embedding_model_name, + model_kwargs={"device": "cuda", "trust_remote_code": True}, + encode_kwargs={"batch_size": cfg.batch_size}, # sentence_transformer ๊ธฐ์ค€ 32์ด๊ฐ€ ๊ธฐ๋ณธ๊ฐ’ + ) + return embedding_model + + elif cfg.embedding_model_sourcee == "openai": + from langchain.embeddings import OpenAIEmbeddings + + return OpenAIEmbeddings(openai_api_key=cfg.openai_key, model=cfg.embedding_model_name) + + # elif cfg.embedding_model_source=="naver": diff --git a/app/RAG/retrieval/embedding_model/fine_tuning.py b/app/RAG/retrieval/embedding_model/fine_tuning.py new file mode 100644 index 0000000..dbfc1e7 --- /dev/null +++ b/app/RAG/retrieval/embedding_model/fine_tuning.py @@ -0,0 +1,239 @@ +from typing import Dict, List, Tuple + +import gc +import os + +import pandas as pd +import torch +import torch.nn.functional as F +import wandb +from sentence_transformers import SentenceTransformer +from sklearn.model_selection import train_test_split +from torch.utils.data import DataLoader +from tqdm import tqdm +from torch.optim import AdamW +from transformers import get_scheduler + +# ํ™˜๊ฒฝ ์„ค์ • +WANDB_PROJECT = "retriever_embedding_model_fine-tuning" +MODEL_NAME = "nlpai-lab/KoE5" +EPOCHS = 10 +LR = 0.00007180661859592403 # USE_SWEEP=True๋ฉด ๋ฌด์‹œ +WARMUP_RATIO = 0.1 +BATCH_SIZE = 24 +ACCUMULATION_STEPS = 64 # USE_SWEEP=True๋ฉด ๋ฌด์‹œ +TEMPERATURE = 0.04184381288580703 # USE_SWEEP=True๋ฉด ๋ฌด์‹œ +SAVE_INTERVAL = 3 # ๋ช‡ epoch๋งˆ๋‹ค ์ €์žฅํ• ์ง€ (Sweep ์‚ฌ์šฉ ์•ˆํ•  ๋•Œ๋งŒ) +EARLY_STOPPING_PATIENCE = 3 +USE_SWEEP = False # Sweep ์‚ฌ์šฉ ์—ฌ๋ถ€ ์„ค์ • +COUNT = 30 +CSV_PATH = "/data/ephemeral/home/data/fine-tuning_data.csv" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +if USE_SWEEP: + # WandB Sweep ์„ค์ • + sweep_config = { + "method": "bayes", + "metric": {"name": "eval_loss", "goal": "minimize"}, + "parameters": { + "learning_rate": {"distribution": "log_uniform_values", "min": 1e-7, "max": 
5e-4}, + "temperature": {"distribution": "uniform", "min": 0.01, "max": 0.1999}, + "accumulation_steps": {"values": [4, 8, 16, 32, 64]}, + }, + } + sweep_id = wandb.sweep(sweep_config, project=WANDB_PROJECT) + +df = pd.read_csv(CSV_PATH) + +# ๋ฐ์ดํ„ฐ Train / Eval ๋ถ„ํ•  +train_samples, eval_samples = train_test_split(list(zip(df["question"], df["context"])), test_size=0.2, random_state=42) + + +# DataLoader ์„ค์ • +def dual_collate(batch: List[Tuple[str, str]]) -> Dict[str, List[str]]: + """ + ๋ฐ์ดํ„ฐ Collate ํ•จ์ˆ˜: ๋ฐฐ์น˜ ๋ฐ์ดํ„ฐ๋ฅผ Query์™€ Passage๋กœ ๋ถ„๋ฆฌํ•˜์—ฌ ๋ฐ˜ํ™˜ + + Args: + batch (List[Tuple[str, str]]): (query, passage) ๋ฐ์ดํ„ฐ ์ƒ˜ํ”Œ + + Returns: + Dict[str, List[str]]: {'queries': Query ๋ฆฌ์ŠคํŠธ, 'passages': Passage ๋ฆฌ์ŠคํŠธ} + """ + queries, passages = zip(*batch) + return {"queries": [q for q in queries], "passages": [p for p in passages]} + + +train_dataloader = DataLoader(train_samples, batch_size=BATCH_SIZE, shuffle=True, collate_fn=dual_collate) +eval_dataloader = DataLoader(eval_samples, batch_size=BATCH_SIZE, shuffle=False, collate_fn=dual_collate) + + +def encode_with_grad(model: SentenceTransformer, texts: List[str]) -> torch.Tensor: + """ + ์ฃผ์–ด์ง„ ๋ฌธ์žฅ์„ SentenceTransformer ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ์ž„๋ฒ ๋”ฉ ๋ณ€ํ™˜ + + Args: + model (SentenceTransformer): ์ž„๋ฒ ๋”ฉ์„ ์ƒ์„ฑํ•  ๋ชจ๋ธ + texts (List[str]): ์ž…๋ ฅ ํ…์ŠคํŠธ ๋ฆฌ์ŠคํŠธ + + Returns: + torch.Tensor: ๋ฌธ์žฅ ์ž„๋ฒ ๋”ฉ + """ + features = model.tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=120) + features = {k: v.to(device) for k, v in features.items()} # GPU๋กœ ์ด๋™ + outputs = model.forward(features) + return outputs["sentence_embedding"] + + +def multiple_negatives_ranking_loss( + query_embeds: torch.Tensor, passage_embeds: torch.Tensor, temperature: float +) -> torch.Tensor: + """ + Multiple Negatives Ranking Loss ๊ณ„์‚ฐ + + Args: + query_embeds (torch.Tensor): Query ์ž„๋ฒ ๋”ฉ ํ…์„œ + passage_embeds (torch.Tensor): Passage ์ž„๋ฒ ๋”ฉ ํ…์„œ + temperature (float): Softmax ์Šค์ผ€์ผ๋ง ๊ฐ’ + + Returns: + torch.Tensor: ๊ณ„์‚ฐ๋œ Loss ๊ฐ’ + """ + scores = torch.matmul(query_embeds, passage_embeds.T) / temperature + log_probs = F.log_softmax(scores, dim=1) + labels = torch.arange(scores.shape[0]).to(device) + return F.nll_loss(log_probs, labels) + + +def evaluate( + query_encoder: SentenceTransformer, passage_encoder: SentenceTransformer, dataloader: DataLoader, temperature: float +) -> float: + """ + ๋ชจ๋ธ ํ‰๊ฐ€ ํ•จ์ˆ˜ + + Args: + query_encoder (SentenceTransformer): Query ์ธ์ฝ”๋” ๋ชจ๋ธ + passage_encoder (SentenceTransformer): Passage ์ธ์ฝ”๋” ๋ชจ๋ธ + dataloader (DataLoader): ํ‰๊ฐ€ ๋ฐ์ดํ„ฐ ๋กœ๋” + temperature (float): Loss ๊ณ„์‚ฐ์— ์‚ฌ์šฉํ•  ์˜จ๋„ ๊ฐ’ + + Returns: + float: ํ‰๊ฐ€ ๋ฐ์ดํ„ฐ์˜ ํ‰๊ท  Loss ๊ฐ’ + """ + query_encoder.eval() + passage_encoder.eval() + total_loss = 0.0 + + with torch.no_grad(): + for batch in tqdm(dataloader, desc="Evaluating"): + query_embeds = encode_with_grad(query_encoder, batch["queries"]) + passage_embeds = encode_with_grad(passage_encoder, batch["passages"]) + loss = multiple_negatives_ranking_loss(query_embeds, passage_embeds, temperature) + total_loss += loss.item() + + avg_loss = total_loss / len(dataloader) + return avg_loss + + +def train() -> None: + """ + ๋ชจ๋ธ ํ•™์Šต ํ•จ์ˆ˜: WandB Sweep ๋˜๋Š” ์ˆ˜๋™ ์„ค์ •์„ ๊ธฐ๋ฐ˜์œผ๋กœ ํ•™์Šต ์ˆ˜ํ–‰ + """ + lr = LR + tempurature = TEMPERATURE + accumulation_steps = ACCUMULATION_STEPS + + wandb.init( + project=WANDB_PROJECT, + config={ + "batch_size": BATCH_SIZE, + 
"epochs": EPOCHS, + "learning_rate": LR, + "temperature": TEMPERATURE, + "accumulation_steps": ACCUMULATION_STEPS, + }, + ) + + if USE_SWEEP: + print(f"ํ˜„์žฌ Sweep ์‹คํ–‰: {wandb.run.name} (ID: {wandb.run.id})") + sweep_run_number = len(list(wandb.Api().runs(WANDB_PROJECT))) + print(f"ํ˜„์žฌ Sweep ์‹คํ–‰ ํšŸ์ˆ˜: {sweep_run_number}") + lr = wandb.config.learning_rate + tempurature = wandb.config.temperature + accumulation_steps = wandb.config.accumulation_steps + + query_encoder = SentenceTransformer(MODEL_NAME).to(device) + passage_encoder = SentenceTransformer(MODEL_NAME).to(device) + optimizer = AdamW(list(query_encoder.parameters()) + list(passage_encoder.parameters()), lr=lr) + total_steps = len(train_dataloader) * EPOCHS // accumulation_steps + warmup_steps = int(WARMUP_RATIO * total_steps) + lr_scheduler = get_scheduler( + "linear", optimizer=optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps + ) + + best_eval_loss = float("inf") + patience_counter = 0 + early_stopping_patience = EARLY_STOPPING_PATIENCE + + for epoch in range(1, EPOCHS + 1): + total_loss = 0.0 + query_encoder.train() + passage_encoder.train() + optimizer.zero_grad() + + loop = tqdm(train_dataloader, desc=f"Epoch {epoch}") + + for step, batch in enumerate(loop): + query_embeds = encode_with_grad(query_encoder, batch["queries"]) + passage_embeds = encode_with_grad(passage_encoder, batch["passages"]) + loss = multiple_negatives_ranking_loss(query_embeds, passage_embeds, tempurature) + loss = loss / accumulation_steps + loss.backward() + + if (step + 1) % accumulation_steps == 0 or (step + 1 == len(train_dataloader)): + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + torch.cuda.empty_cache() + total_loss += loss.item() * accumulation_steps + loop.set_postfix(loss=loss.item() * accumulation_steps) # ์ถœ๋ ฅ์€ ์›๋ž˜ loss ๊ฐ’ + wandb.log({"step_loss": loss.item() * accumulation_steps, "learning_rate": lr_scheduler.get_last_lr()[0]}) + + avg_train_loss = total_loss / len(train_dataloader) + eval_loss = evaluate(query_encoder, passage_encoder, eval_dataloader, tempurature) + print(f"Epoch {epoch} | Train_Loss: {avg_train_loss:.4f} | Eval_Loss: {eval_loss:.4f}") + wandb.log({"epoch": epoch, "train_loss": avg_train_loss, "eval_loss": eval_loss}) + + if not USE_SWEEP and (epoch % SAVE_INTERVAL == 0 or epoch == EPOCHS): + save_dir = f"epoch{epoch}" + os.makedirs(save_dir, exist_ok=True) + query_encoder.save(f"{save_dir}/query_encoder") + passage_encoder.save(f"{save_dir}/passage_encoder") + print(f"๋ชจ๋ธ ์ €์žฅ ์™„๋ฃŒ: {save_dir}/") + + if eval_loss < best_eval_loss: + best_eval_loss = eval_loss + patience_counter = 0 + else: + patience_counter += 1 + + if patience_counter >= early_stopping_patience: + print(f"Early stopping at epoch {epoch}") + break + + # ๋ชจ๋ธ ์‚ญ์ œ + del query_encoder + del passage_encoder + + # Python ๊ฐ€๋น„์ง€ ์ปฌ๋ ‰์…˜ ์‹คํ–‰ + gc.collect() + + # CUDA ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + +if USE_SWEEP: + wandb.agent(sweep_id, train, count=COUNT) +else: + train() diff --git a/app/RAG/retrieval/ensemble_retrieval.py b/app/RAG/retrieval/ensemble_retrieval.py new file mode 100644 index 0000000..2e6fe87 --- /dev/null +++ b/app/RAG/retrieval/ensemble_retrieval.py @@ -0,0 +1,23 @@ +from typing import List + +import numpy as np +from langchain.docstore.document import Document +from retrieval.base import BaseRetriever + + +class EnsembleRetrieval(BaseRetriever): + def __init__(self, retrievers: List[BaseRetriever], 
weights: List[float] = None): + self.retrievers = retrievers + if weights is None: + self.weights = [1.0] * len(self.retrievers) + else: + self.weights = weights + + def get_relevant_documents(self, query: str, k: int = 5) -> List[Document]: + all_docs = [] + for retriever in self.retrievers: + docs = retriever.get_relevant_documents(query, k) + all_docs.extend(docs) + + unique_docs = list({doc.page_content: doc for doc in all_docs}.values()) + return unique_docs[:k] diff --git a/app/RAG/retrieval/reranking.py b/app/RAG/retrieval/reranking.py new file mode 100644 index 0000000..aabe878 --- /dev/null +++ b/app/RAG/retrieval/reranking.py @@ -0,0 +1,10 @@ +from langchain.retrievers import ContextualCompressionRetriever +from langchain.retrievers.document_compressors import CrossEncoderReranker +from langchain_community.cross_encoders import HuggingFaceCrossEncoder + + +def get_reranker_model(cfg, retriever): + model = HuggingFaceCrossEncoder(model_name=cfg.reranker_model_name) + compressor = CrossEncoderReranker(model=model, top_n=10) + compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever) + return compression_retriever diff --git a/app/RAG/source/__init__.py b/app/RAG/source/__init__.py new file mode 100644 index 0000000..1b112a7 --- /dev/null +++ b/app/RAG/source/__init__.py @@ -0,0 +1,4 @@ +from source.generate import generate +from source.retrieve import retrieve + +__all__ = ["generate", "retrieve"] diff --git a/app/RAG/source/generate.py b/app/RAG/source/generate.py new file mode 100644 index 0000000..7207a49 --- /dev/null +++ b/app/RAG/source/generate.py @@ -0,0 +1,54 @@ +import pandas as pd +import tqdm +from generator import get_llm_api +from langchain.prompts import ChatPromptTemplate +from omegaconf import DictConfig +from openai import OpenAI +from retrieval import ChromaRetrieval +from utils.generator_evaluate import evaluate_batch +from utils.set_seed import set_seed + +client = OpenAI() + + +async def generate(cfg: DictConfig): + set_seed(cfg.seed) + all_results = [] + + # data + data = pd.read_csv("eval_data_path") + + # retrieval = get_retriever(cfg) + # retrieval - ChromaRetrieval ์‚ฌ์šฉ + retriever = ChromaRetrieval(cfg) + + # llm + system_template = cfg.chat_template + model = get_llm_api(cfg) + + data = pd.read_csv(cfg.eval_data_path) + + all_results = [] + for _, row in tqdm.tqdm(data.iterrows(), desc="Processing Queries"): + # dataset validation ์ˆ˜์ •ํ•„์š” + query_result = {"query": row["question"]} + + docs = retriever.get_relevant_documents(row["question"]) + query_result["retrieved_docs"] = docs + + tem = ChatPromptTemplate.from_messages([("system", system_template), ("user", row["question"])]) + + s = "" + for i in range(len(docs)): + s += docs[i].page_content + + prompt = tem.invoke({"docs": s}) + + answer = model.invoke(prompt) + + query_result["answer"] = answer + query_result["ground_truth"] = row["llm_text"] + + all_results.append(query_result) + + await evaluate_batch(all_results) diff --git a/app/RAG/source/retrieve.py b/app/RAG/source/retrieve.py new file mode 100644 index 0000000..c69e79f --- /dev/null +++ b/app/RAG/source/retrieve.py @@ -0,0 +1,19 @@ +# from langchain.smith import LangSmithSession +from omegaconf import DictConfig +from retrieval import get_retriever +from utils.ret_evaluate import ret_evaluate_acc, ret_evaluate_geval +from utils.set_seed import set_seed + + +def retrieve(cfg: DictConfig): + set_seed(cfg.seed) + + retriever = get_retriever(cfg) + + # dataset ํ™•์ •๋˜๋ฉด llm๊นŒ์ง€ 
์—ฐ๊ฒฐ + ์‹คํ—˜ + # if cfg.mode == "inference": return retriever.get_relevant_documents(query, cfg.k) + + if cfg.g_eval: + ret_evaluate_geval(retriever, cfg) + else: + ret_evaluate_acc(retriever) diff --git a/app/RAG/test_retrieval.py b/app/RAG/test_retrieval.py new file mode 100644 index 0000000..57f4843 --- /dev/null +++ b/app/RAG/test_retrieval.py @@ -0,0 +1,55 @@ +from pathlib import Path + +import hydra +from loguru import logger +from omegaconf import DictConfig +from retrieval import BM25Retrieval, DenseRetrieval, EnsembleRetrieval + + +def test_retrievers(cfg: DictConfig): + """ + ๊ฐ retriever๋ฅผ ํ…Œ์ŠคํŠธํ•ฉ๋‹ˆ๋‹ค. + """ + # Dense Retrieval ํ…Œ์ŠคํŠธ + logger.info("Testing Dense Retrieval...") + dense_retriever = DenseRetrieval(cfg) + dense_results = dense_retriever.get_relevant_documents(query="ํ…Œ์ŠคํŠธ ์ฟผ๋ฆฌ์ž…๋‹ˆ๋‹ค.", k=3) + logger.info(f"Dense Retrieval Results: {len(dense_results)} documents found") + for doc in dense_results: + logger.info(f"Score: {getattr(doc, 'score', 'N/A')}") + logger.info(f"Content: {doc.page_content[:100]}...") + logger.info("---") + + # BM25 Retrieval ํ…Œ์ŠคํŠธ + logger.info("\nTesting BM25 Retrieval...") + bm25_retriever = BM25Retrieval(cfg) + bm25_results = bm25_retriever.get_relevant_documents(query="ํ…Œ์ŠคํŠธ ์ฟผ๋ฆฌ์ž…๋‹ˆ๋‹ค.", k=3) + logger.info(f"BM25 Results: {len(bm25_results)} documents found") + for doc in bm25_results: + logger.info(f"Content: {doc.page_content[:100]}...") + logger.info("---") + + # Ensemble Retrieval ํ…Œ์ŠคํŠธ + logger.info("\nTesting Ensemble Retrieval...") + ensemble_retriever = EnsembleRetrieval(retrievers=[dense_retriever, bm25_retriever], weights=[0.7, 0.3]) + ensemble_results = ensemble_retriever.get_relevant_documents(query="ํ…Œ์ŠคํŠธ ์ฟผ๋ฆฌ์ž…๋‹ˆ๋‹ค.", k=3) + logger.info(f"Ensemble Results: {len(ensemble_results)} documents found") + for doc in ensemble_results: + logger.info(f"Content: {doc.page_content[:100]}...") + logger.info("---") + + +@hydra.main(version_base=None, config_path="configs", config_name="config") +def main(cfg: DictConfig): + """ + ๋ฉ”์ธ ํ•จ์ˆ˜ + """ + try: + test_retrievers(cfg) + except Exception as e: + logger.error(f"Error during testing: {str(e)}") + raise + + +if __name__ == "__main__": + main() diff --git a/app/RAG/utils/generator_evaluate.py b/app/RAG/utils/generator_evaluate.py new file mode 100644 index 0000000..4b56d0a --- /dev/null +++ b/app/RAG/utils/generator_evaluate.py @@ -0,0 +1,195 @@ +import asyncio +import json + +from deepeval.metrics import GEval +from deepeval.test_case import LLMTestCase, LLMTestCaseParams +from langsmith import Client, traceable + +client = Client() +model = "gpt-4o-mini" +Generation_criteria = [ + { + "name": "Relevance", + "description": "Is the final answer clearly relevant to the question and reflective of the userโ€™s intent?", + "weight": 5, + }, + { + "name": "FactualCorrectness", + "description": "Is the answer factually correct and free from unsupported or inaccurate information?", + "weight": 5, + }, + { + "name": "Completeness", + "description": ( + "Does the answer include all essential points " "required by the question and the ground_truth_answer?" 
+ ), + "weight": 5, + }, + { + "name": "ClarityConciseness", + "description": "Is the answer clear and concise, avoiding unnecessary repetition or ambiguity?", + "weight": 5, + }, + { + "name": "LogicalStructure", + "description": "Is the answer logically structured, consistent with the context, and free of contradictions?", + "weight": 3, + }, + { + "name": "DetailwithoutExcessiveness", + "description": "Does the answer provide sufficient detail for the question without being excessive?", + "weight": 3, + }, + { + "name": "ProperCitation", + "description": ( + "Does the answer provide proper citations or " + "indications of the source when claims or data are referenced?" + ), + "weight": 2, + }, + { + "name": "Formatting", + "description": "Is the answer presented in a suitable format (list, table, short text, etc.) for the question?", + "weight": 1, + }, + { + "name": "ExtraInsights", + "description": ( + "Does the answer offer any helpful extra insights or context " + "that enrich the userโ€™s understanding (without deviating from factual correctness)?" + ), + "weight": 1, + }, +] +metric1 = GEval( + name=Generation_criteria[0]["name"], + criteria=Generation_criteria[0]["description"], + evaluation_params=[LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT, LLMTestCaseParams.EXPECTED_OUTPUT], + model=model, + threshold=0.0, +) +metric2 = GEval( + name=Generation_criteria[1]["name"], + criteria=Generation_criteria[1]["description"], + evaluation_params=[LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT, LLMTestCaseParams.EXPECTED_OUTPUT], + model=model, + threshold=0.0, +) +metric3 = GEval( + name=Generation_criteria[2]["name"], + criteria=Generation_criteria[2]["description"], + evaluation_params=[LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT, LLMTestCaseParams.EXPECTED_OUTPUT], + model=model, + threshold=0.0, +) +metric4 = GEval( + name=Generation_criteria[3]["name"], + criteria=Generation_criteria[3]["description"], + evaluation_params=[LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT, LLMTestCaseParams.EXPECTED_OUTPUT], + model=model, + threshold=0.0, +) +metric5 = GEval( + name=Generation_criteria[4]["name"], + criteria=Generation_criteria[4]["description"], + evaluation_params=[LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT, LLMTestCaseParams.EXPECTED_OUTPUT], + model=model, + threshold=0.0, +) +metric6 = GEval( + name=Generation_criteria[5]["name"], + criteria=Generation_criteria[5]["description"], + evaluation_params=[LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT, LLMTestCaseParams.EXPECTED_OUTPUT], + model=model, + threshold=0.0, +) +metric7 = GEval( + name=Generation_criteria[6]["name"], + criteria=Generation_criteria[6]["description"], + evaluation_params=[LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT, LLMTestCaseParams.EXPECTED_OUTPUT], + model=model, + threshold=0.0, +) +metric8 = GEval( + name=Generation_criteria[7]["name"], + criteria=Generation_criteria[7]["description"], + evaluation_params=[LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT, LLMTestCaseParams.EXPECTED_OUTPUT], + model=model, + threshold=0.0, +) +metric9 = GEval( + name=Generation_criteria[8]["name"], + criteria=Generation_criteria[8]["description"], + evaluation_params=[LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT, LLMTestCaseParams.EXPECTED_OUTPUT], + model=model, + threshold=0.0, +) + + +async def get_metric_evaluations(test_case: LLMTestCaseParams) -> list: + return await asyncio.gather( + metric1.a_measure(test_case), 
+ metric2.a_measure(test_case), + metric3.a_measure(test_case), + metric4.a_measure(test_case), + metric5.a_measure(test_case), + metric6.a_measure(test_case), + metric7.a_measure(test_case), + metric8.a_measure(test_case), + metric9.a_measure(test_case), + ) + + +async def evaluate_single_sample(question: str, answer: str, ground_truth: str) -> dict: + test_case = LLMTestCase(input=question, actual_output=answer, expected_output=ground_truth) + + eval_result = await get_metric_evaluations(test_case) + evaluation_result = { + "question": question, + "answer": answer, + "ground_truth": ground_truth, + } + + # deepeval์—์„œ๋Š” 0~10์—์„œ score ๋งค๊ธด ํ›„, 0~1๋กœ ๋งคํ•‘ํ•จ + final_score = 0 + for i in range(len(eval_result)): + final_score += eval_result[i][0] + evaluation_result[Generation_criteria[i]["name"]] = ( + str(round(eval_result[i] * Generation_criteria[i]["weight"], 1)) + "์  " + ) + + evaluation_result["final_score"] = final_score + """ + client.log( + run_type="evaluation", + name="G-Eval Batch Evaluation", + inputs={ + "question": evaluation_result["question"], + "answer": evaluation_result["answer"], + "ground_truth": evaluation_result["ground_truth"], + }, + outputs={ + "final_score": result["final_score"], + #criteria reasons + }, + ) + """ + + return evaluation_result + + +@traceable(run_type="G-eval") +async def evaluate_batch(samples: list) -> list: + results = [] + for item in samples: + res = await evaluate_single_sample( + question=item["question"], answer=item["answer"], ground_truth=item["ground_truth"] + ) + # asyncio.run(log_to_langsmith(res)) + results.append(res) + + with open("evaluation_results.json", "w", encoding="utf-8") as f: # ์ถ”ํ›„ ๊ฒฝ๋กœ ์ˆ˜์ • + json.dump(results, f, indent=2, ensure_ascii=False) + + return results diff --git a/app/RAG/utils/query_rewriter.py b/app/RAG/utils/query_rewriter.py new file mode 100644 index 0000000..e5044a8 --- /dev/null +++ b/app/RAG/utils/query_rewriter.py @@ -0,0 +1,123 @@ +from typing import Any, Dict, List, Optional, Tuple + +import os +import re +import time +import warnings +from pathlib import Path + +import hydra +import numpy as np +from generator import get_llm_api +from langchain.chains import create_retrieval_chain +from langchain.chains.combine_documents import create_stuff_documents_chain +from langchain_community.chat_models import ChatClovaX +from langchain_core.documents import Document +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import ChatPromptTemplate, PromptTemplate +from langchain_core.runnables import RunnablePassthrough +from loguru import logger +from rapidfuzz import process +from sentence_transformers import SentenceTransformer +from sklearn.metrics.pairwise import cosine_similarity +from transformers import pipeline + +warnings.filterwarnings("ignore") + +""" +๊ธฐ๋ณธ์ ์œผ๋กœ gpt-4o-mini ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ์ฟผ๋ฆฌ๋ฅผ ์ˆ˜์ •ํ•ฉ๋‹ˆ๋‹ค. +""" +query_rewriting_prompt = """ +๋‹น์‹ ์€ ์ฟผ๋ฆฌ๋ฅผ ์žฌ์ž‘์„ฑ ํ•ด์ฃผ๋Š” ์ „๋ฌธ๊ฐ€์ž…๋‹ˆ๋‹ค. +๋‹ค์Œ ์ฟผ๋ฆฌ๋ฅผ ์ˆ˜์ •ํ•˜์—ฌ ๋” ์ •ํ™•ํ•œ ๊ฒ€์ƒ‰์„ ์œ„ํ•ด ์กฐ๊ธˆ ๋” ๊ตฌ์ฒด์ ์œผ๋กœ ์ž‘์„ฑํ•˜๊ฑฐ๋‚˜ ๋ถ„๋ฆฌํ•ด ์ฃผ์„ธ์š”. +๋งŒ์•ฝ ์ฟผ๋ฆฌ์— ๋ฆฌ์ŠคํŠธ์— ์žˆ๋Š” ํšŒ์‚ฌ๋ช…๊ณผ ๊ฐ™์€ ํšŒ์‚ฌ๋ช…์ด ์žˆ์œผ๋ฉด ๋ฆฌ์ŠคํŠธ์˜ ์ด๋ฆ„์œผ๋กœ ์ˆ˜์ •ํ•ด์ฃผ์„ธ์š”. +๋งŒ์•ฝ ์—†๋‹ค๋ฉด None ์„ ๋ฐ˜ํ™˜ํ•ด์ฃผ์„ธ์š”. 
+ +List: {list} + +์˜ˆ์‹œ) +INPUT: kakaobank ์ฃผ๊ฐ€ ์˜ˆ์ธก +OUTPUT: ์นด์นด์˜ค๋ฑ…ํฌ ์ฃผ๊ฐ€ ์˜ˆ์ธก + +INPUT: ์นด์นด์˜ค๋ฑ…ํฌ์™€ ๋„ค์ด๋ฒ„ ์ฃผ๊ฐ€ ์˜ˆ์ธก +OUTPUT: ์นด์นด์˜ค๋ฑ…ํฌ ์ฃผ๊ฐ€ ์˜ˆ์ธก|๋„ค์ด๋ฒ„ ์ฃผ๊ฐ€ ์˜ˆ์ธก + +INPUT: "์—†๋Š”ํšŒ์‚ฌ๋ช…"๊ณผ ๋„ค์ด๋ฒ„ ์ฃผ๊ฐ€ ์˜ˆ์ธก +OUTPUT: None|๋„ค์ด๋ฒ„ ์ฃผ๊ฐ€ ์˜ˆ์ธก + +INPUT: "์—†๋Š”ํšŒ์‚ฌ๋ช…"์˜ ์‹œ๊ฐ€์ด์•ก์€? +OUTPUT: None + +""" + +project_root = Path(__file__).parent.parent + + +class QueryRewriter: + def __init__(self): + + self.company_names = os.listdir(project_root / "vector_db") + self.parser = StrOutputParser() + self._load_config() + self.model = get_llm_api(self.cfg, temperature=0.4) + + def _load_config(self): + """Hydra ์„ค์ • ๋กœ๋“œ""" + + # ์ƒ๋Œ€ ๊ฒฝ๋กœ๋กœ config_path ์„ค์ • + with hydra.initialize(version_base=None, config_path="../configs"): + cfg = hydra.compose(config_name="config") + self.cfg = cfg + + def extract_company(self, query: str) -> Tuple[str, Optional[str]]: + """ + Args: + query: ์›๋ณธ ์ฟผ๋ฆฌ ๋ฌธ์ž์—ด + + Returns: + Tuple[str, Optional[str]]: (์ˆ˜์ •๋œ ์ฟผ๋ฆฌ, ํšŒ์‚ฌ๋ช…) ๋˜๋Š” (์›๋ณธ ์ฟผ๋ฆฌ, None) + """ + # query ๋Œ€๋ฌธ์ž๋กœ ๋ณ€๊ฒฝ + query = query.upper() + # ํšŒ์‚ฌ๋ช… ์ถ”์ถœ + print("Company names: ", self.company_names) + for company in self.company_names: + if company in query: + # ํšŒ์‚ฌ๋ช…์„ ์ฟผ๋ฆฌ์—์„œ ์ œ๊ฑฐํ•˜๊ณ  ๊ณต๋ฐฑ ์ •๋ฆฌ + cleaned_query = re.sub(company, "", query).strip() + print("Company extracted: ", company) + return cleaned_query, company + # fuzzy ํšŒ์‚ฌ๋ช… ์ถ”์ถœ + matches = process.extract(query, self.company_names, limit=1) + if matches and matches[0][1] >= 80: # 80% ์ด์ƒ์˜ ์œ ์‚ฌ๋„๋ฅผ ๊ฐ€์ง„ ๊ฒฝ์šฐ์—๋งŒ ๋งค์นญ + company = matches[0][0] + # ํšŒ์‚ฌ๋ช…์„ ์ฟผ๋ฆฌ์—์„œ ์ œ๊ฑฐํ•˜๊ณ  ๊ณต๋ฐฑ ์ •๋ฆฌ + cleaned_query = re.sub(company, "", query).strip() + print("Company extracted: ", company) + return cleaned_query, company + # ner ํšŒ์‚ฌ๋ช… ์ถ”์ถœํ›„ ์œ ์‚ฌ๋„ ๊ธฐ๋ฐ˜ ํšŒ์‚ฌ๋ช… ์ถ”์ถœ + return query, None + + def rewrite_query(self, query: str) -> str: + """ + ์ฟผ๋ฆฌ๋ฅผ ์ˆ˜์ •ํ•˜์—ฌ ๋” ์ •ํ™•ํ•œ ๊ฒ€์ƒ‰์„ ์œ„ํ•ด ์กฐ๊ธˆ ๋” ๊ตฌ์ฒด์ ์œผ๋กœ ์ž‘์„ฑํ•ฉ๋‹ˆ๋‹ค. 
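+
+        Note: the model is expected to return a single rewritten query, several
+        sub-queries joined by "|", or the literal string "None" when a company in
+        the query cannot be matched to the known company list (see the
+        INPUT/OUTPUT examples in query_rewriting_prompt above).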
+ """ + start_time = time.time() + # prompt = PromptTemplate(template=query_rewriting_prompt, input_variables=["query", "list"]) + prompt = ChatPromptTemplate.from_messages( + [ + ("system", query_rewriting_prompt), + ("user", "{query}"), + ] + ) + + chain = prompt | self.model | self.parser + # ํšŒ์‚ฌ๋ช… ๋ฆฌ์ŠคํŠธ๋ฅผ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜ + company_list_str = ", ".join(f'"{company}"' for company in self.company_names) + list_str = f"[{company_list_str}]" + # ๋”•์…”๋„ˆ๋ฆฌ๋กœ ์ž…๋ ฅ๊ฐ’ ์ „๋‹ฌ + result = chain.invoke({"query": query, "list": list_str}) + # ๊ฒฐ๊ณผ๊ฐ€ ๋ฆฌ์ŠคํŠธ์ธ ๊ฒฝ์šฐ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜ + processing_time = time.time() - start_time + logger.info(f"Rewrite Query processed in {processing_time:.2f} seconds") + return result diff --git a/app/RAG/utils/ret_evaluate.py b/app/RAG/utils/ret_evaluate.py new file mode 100644 index 0000000..796fb2c --- /dev/null +++ b/app/RAG/utils/ret_evaluate.py @@ -0,0 +1,186 @@ +import asyncio +import json + +import pandas as pd +from datasets import concatenate_datasets, load_from_disk +from deepeval.metrics import GEval +from deepeval.test_case import LLMTestCase, LLMTestCaseParams +from langsmith import traceable +from tqdm import tqdm + + +def ret_evaluate_acc(retriever): + dataset_dict = load_from_disk("/data/ephemeral/data/train_dataset") + dataset1 = dataset_dict["train"].select(range(1000)) + dataset2 = dataset_dict["validation"] + dataset_combined = concatenate_datasets([dataset1, dataset2]) + + top1_count = 0 + top10_count = 0 + top20_count = 0 + top30_count = 0 + top40_count = 0 + top50_count = 0 + + for i in tqdm(range(len(dataset_combined)), desc="retrieval eval"): + question = dataset_combined[i]["question"] + original_id = dataset_combined[i]["document_id"] + + topk_passages = retriever.get_relevant_documents(question, k=50) + + retrieved_id = [int(doc.metadata["document_id"]) for doc in topk_passages] + + if original_id == retrieved_id[0]: + top1_count += 1 + if original_id in retrieved_id[:10]: + top10_count += 1 + if original_id in retrieved_id[:20]: + top20_count += 1 + if original_id in retrieved_id[:30]: + top30_count += 1 + if original_id in retrieved_id[:40]: + top40_count += 1 + if original_id in retrieved_id[:50]: + top50_count += 1 + + print(f"Top 1 Score: {top1_count / (i+1) * 100:.2f}%") + print(f"Top 10 Score: {top10_count / (i+1) * 100:.2f}%") + print(f"Top 20 Score: {top20_count / (i+1) * 100:.2f}%") + print(f"Top 30 Score: {top30_count / (i+1) * 100:.2f}%") + print(f"Top 40 Score: {top40_count / (i+1) * 100:.2f}%") + print(f"Top 50 Score: {top50_count / (i+1) * 100:.2f}%") + + +def ret_evaluate_geval(retriever, cfg): + model = "gpt-4o-mini" + Generation_criteria = [ + { + "name": "Similarity", + "description": "Do any of the retrieved contexts show strong similarity to the Ground Truth?", + "weight": 5, + }, + { + "name": "Essentiality", + "description": ( + "Do the retrieved contexts collectively capture " "essential information from the Ground Truth?" + ), + "weight": 5, + }, + { + "name": "Coverage", + "description": "Do the retrieved contexts sufficiently address the userโ€™s question?", + "weight": 4, + }, + { + "name": "Relevance", + "description": "Are all retrieved contexts relevant to the Ground Truth or the userโ€™s query?", + "weight": 3, + }, + { + "name": "Conciseness", + "description": ( + "Does the combined length and number of retrieved contexts remain " + "reasonable without overwhelming the user with excessive or irrelevant details?" 
+            ),
+            "weight": 3,
+        },
+    ]
+
+    # All retrieval criteria share the same GEval settings, so the metrics are
+    # built in one pass instead of five copy-pasted GEval(...) blocks.
+    metrics = [
+        GEval(
+            name=criterion["name"],
+            criteria=criterion["description"],
+            evaluation_params=[
+                LLMTestCaseParams.INPUT,
+                LLMTestCaseParams.ACTUAL_OUTPUT,
+                LLMTestCaseParams.EXPECTED_OUTPUT,
+            ],
+            model=model,
+            threshold=0.0,
+        )
+        for criterion in Generation_criteria
+    ]
+
+    async def get_metric_evaluations(test_case: LLMTestCase) -> list:
+        # ๋น„๋™๊ธฐ ์ง€์›: ๋ชจ๋“  ๋ฉ”ํŠธ๋ฆญ์„ ๋™์‹œ์— ํ‰๊ฐ€
+        return await asyncio.gather(*(metric.a_measure(test_case) for metric in metrics))
+
+    async def evaluate_single_sample(question: str, docs: list, ground_truth: list) -> dict:
+        actual_output = ", ".join([f"๋ฌธ์„œ{i+1}: {doc}" for i, doc in enumerate(docs)])
+        test_case = LLMTestCase(input=question, actual_output=actual_output, expected_output=ground_truth)
+
+        eval_result = await get_metric_evaluations(test_case)
+        evaluation_result = {
+            "question": question,
+            "docs": docs,
+            "ground_truth": ground_truth,
+        }
+
+        final_score = 0
+        for i in range(len(eval_result)):
+            # evaluate์œผ๋กœ ํ‰๊ฐ€ํ•˜๋ฉด ์ ์ˆ˜์— ๋Œ€ํ•œ reason๋„ ๋ฐ˜ํ™˜ํ•˜๋Š”๋ฐ ๊ทธ๋Ÿผ eval_step์„ ์ž…๋ ฅํ•ด์ค˜์•ผํ•จ
+            final_score += eval_result[i]
+            evaluation_result[Generation_criteria[i]["name"]] = (
+                str(round(eval_result[i] * Generation_criteria[i]["weight"], 1)) + "์  "
+            )
+
+        evaluation_result["final_score"] = final_score
+
+        return evaluation_result
+
+    @traceable(run_type="G-eval")
+    async def evaluate_batch(samples: list) -> list:
+        results = []
+        for item in samples:
+            res = await evaluate_single_sample(
+                question=item["question"], docs=item["docs"], ground_truth=item["ground_truth"]
+            )
+            results.append(res)
+
+        with open("ret_evaluation_results.json", "w", encoding="utf-8") as f:
+            json.dump(results, f, indent=2, ensure_ascii=False)
+
+        return results
+
+    data = pd.read_csv(cfg.eval_data_path)
+
+    samples = []
+
+    for _, row in data.iterrows():
+        sample = {
+            "question": row["question"],
+            "docs": [],
+            "ground_truth": row["answer"],
+        }
+        sample["docs"] = retriever.get_relevant_documents(row["question"], k=cfg.tok_k)
+
+        samples.append(sample)
+
+    # ret_evaluate_geval is a synchronous function, so run the async batch evaluation explicitly.
+    asyncio.run(evaluate_batch(samples))
diff --git a/app/RAG/utils/set_seed.py b/app/RAG/utils/set_seed.py
new file mode 100644
index 0000000..9633811
--- /dev/null
+++ b/app/RAG/utils/set_seed.py
@@ -0,0 +1,14 @@
+import random
+
+import numpy as np
+import torch
+
+
+def set_seed(random_seed):
+    torch.manual_seed(random_seed)
+    torch.cuda.manual_seed(random_seed)
+
torch.cuda.manual_seed_all(random_seed) # if use multi-GPU + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + np.random.seed(random_seed) + random.seed(random_seed) diff --git a/app/RAG/utils/vector_store.py b/app/RAG/utils/vector_store.py new file mode 100644 index 0000000..ca56d52 --- /dev/null +++ b/app/RAG/utils/vector_store.py @@ -0,0 +1,199 @@ +from typing import Dict, List + +import json +import os +import shutil +import warnings + +from langchain.schema import Document +from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_community.vectorstores import Chroma +from omegaconf import DictConfig +from tqdm import tqdm + +warnings.filterwarnings("ignore") + + +class VectorStore: + def __init__(self, cfg: DictConfig, persist_directory: str = "vector_db"): + """ + ๋ฒกํ„ฐ ์Šคํ† ์–ด ์ดˆ๊ธฐํ™” + Args: + cfg (DictConfig): ์„ค์ • ํŒŒ์ผ + persist_directory (str): ๋ฒกํ„ฐ DB๋ฅผ ์ €์žฅํ•  ๋””๋ ‰ํ† ๋ฆฌ ๊ฒฝ๋กœ + """ + self.persist_directory = persist_directory + self.embeddings = HuggingFaceEmbeddings( + model_name=cfg.passage_embedding_model_name, + model_kwargs={"device": "cuda"}, + encode_kwargs={"normalize_embeddings": True}, + ) + + def load_json_data(self, json_path: str) -> List[Dict]: + """JSON ํŒŒ์ผ์—์„œ ๋ฐ์ดํ„ฐ๋ฅผ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.""" + with open(json_path, "r", encoding="utf-8") as f: + return json.load(f) + + def create_documents(self, data: List[Dict]) -> List[Document]: + """ + ๋ฐ์ดํ„ฐ๋ฅผ Document ๊ฐ์ฒด๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค. + """ + documents = [] + for item in data: + # ๋ฉ”ํƒ€๋ฐ์ดํ„ฐ ์ƒ์„ฑ + metadata = { + "company": item["company"], + "securities": item["securities"], + "category": item["category"], + "page": item["page"], + "date": item["date"], + "path": item["path"], + } + if isinstance(item["page"], int): + page_info = "page_" + str(item["page"]) + else: + page_info = item["page"] + if item["category"] == "figure" and item["title"] != None: + doc = Document( + page_content="<" + + item["company"] + + ">" + + item["title"] + + " " + + item["description"] + + "< ์ถœ์ฒ˜ : " + + item["securities"] + + " " + + page_info + + ">" + + "<๊ธฐ์ค€๋‚ ์งœ : " + + item["date"] + + ">", + metadata=metadata, + ) + else: + # Document ๊ฐ์ฒด ์ƒ์„ฑ + doc = Document( + page_content="<" + + item["company"] + + ">" + + item["description"] + + "< ์ถœ์ฒ˜ : " + + item["securities"] + + " " + + page_info + + ">" + + "<๊ธฐ์ค€๋‚ ์งœ : " + + item["date"] + + ">", + metadata=metadata, + ) + documents.append(doc) + return documents + + def update_company_vector_stores(self, text_json_path: str, table_json_path: str): + """ + ํšŒ์‚ฌ๋ณ„๋กœ ๋ฒกํ„ฐ DB๋ฅผ ์—…๋ฐ์ดํŠธํ•ฉ๋‹ˆ๋‹ค. 
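+
+        Args:
+            text_json_path (str): path to the JSON file with extracted text entries
+            table_json_path (str): path to the JSON file with extracted table entries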
+ """ + # JSON ๋ฐ์ดํ„ฐ ๋กœ๋“œ + text_data = self.load_json_data(text_json_path) + table_data = self.load_json_data(table_json_path) + + # ๋ชจ๋“  ๋ฐ์ดํ„ฐ ํ†ตํ•ฉ + all_data = text_data + table_data + + # ํšŒ์‚ฌ๋ณ„๋กœ ๋ฐ์ดํ„ฐ ๊ทธ๋ฃนํ™” + company_data = {} + for item in all_data: + company = item["company"] + if company not in company_data: + company_data[company] = [] + company_data[company].append(item) + + # ํšŒ์‚ฌ๋ณ„๋กœ ๋ฒกํ„ฐ DB ์—…๋ฐ์ดํŠธ + for company, data in tqdm(company_data.items(), desc="ํšŒ์‚ฌ๋ณ„ ๋ฒกํ„ฐ DB ์—…๋ฐ์ดํŠธ ์ค‘"): + company_persist_dir = os.path.join(self.persist_directory, company) + + # Document ๊ฐ์ฒด ์ƒ์„ฑ + documents = self.create_documents(data) + + # ๊ธฐ์กด ๋ฒกํ„ฐ DB๊ฐ€ ์žˆ์œผ๋ฉด ์ถ”๊ฐ€, ์—†์œผ๋ฉด ์ƒˆ๋กœ ์ƒ์„ฑ + if os.path.exists(company_persist_dir): + vectorstore = Chroma(persist_directory=company_persist_dir, embedding_function=self.embeddings) + vectorstore.add_documents(documents) + else: + vectorstore = Chroma.from_documents( + documents=documents, embedding=self.embeddings, persist_directory=company_persist_dir + ) + + vectorstore.persist() + print(f"{company} ๋ฒกํ„ฐ DB ์—…๋ฐ์ดํŠธ ์™„๋ฃŒ: {len(documents)}๊ฐœ ๋ฌธ์„œ ์ถ”๊ฐ€") + + def update_user_vector_stores(self, user_json_path: str, user_name: str): + """ + ์œ ์ €๋ณ„๋กœ ๋ฒกํ„ฐ DB๋ฅผ ์—…๋ฐ์ดํŠธํ•ฉ๋‹ˆ๋‹ค. + """ + + text_data = self.load_json_data(os.path.join(user_json_path, "text.json")) + table_data = self.load_json_data(os.path.join(user_json_path, "table.json")) + user_data = text_data + table_data + + documents = self.create_documents(user_data) + user_persist_dir = os.path.join(self.persist_directory, user_name) + if os.path.exists(user_persist_dir): + vectorstore = Chroma(persist_directory=user_persist_dir, embedding_function=self.embeddings) + vectorstore.add_documents(documents) + else: + vectorstore = Chroma.from_documents( + documents=documents, embedding=self.embeddings, persist_directory=user_persist_dir + ) + + def update_all_vector_stores(self, text_json_path: str, table_json_path: str): + """ + ๋ชจ๋“  ๋ฐ์ดํ„ฐ๋ฅผ ํ†ตํ•ฉํ•˜์—ฌ ๋ฒกํ„ฐ DB๋ฅผ ์—…๋ฐ์ดํŠธํ•ฉ๋‹ˆ๋‹ค. + """ + # JSON ๋ฐ์ดํ„ฐ ๋กœ๋“œ + if text_json_path == table_json_path: + all_data = self.load_json_data(text_json_path) + else: + text_data = self.load_json_data(text_json_path) + table_data = self.load_json_data(table_json_path) + + # ๋ชจ๋“  ๋ฐ์ดํ„ฐ ํ†ตํ•ฉ + all_data = text_data + table_data + documents = self.create_documents(all_data) + # ๋ชจ๋“  ๋ฐ์ดํ„ฐ๋ฅผ ํ†ตํ•ฉํ•œ ๋ฒกํ„ฐ DB ์—…๋ฐ์ดํŠธ + company_persist_dir = os.path.join(self.persist_directory, "All_data") + if os.path.exists(company_persist_dir): + vectorstore = Chroma(persist_directory=company_persist_dir, embedding_function=self.embeddings) + vectorstore.add_documents(documents) + else: + vectorstore = Chroma.from_documents( + documents=documents, embedding=self.embeddings, persist_directory=company_persist_dir + ) + + vectorstore.persist() + print(f"All_data ๋ฒกํ„ฐ DB ์—…๋ฐ์ดํŠธ ์™„๋ฃŒ: {len(documents)}๊ฐœ ๋ฌธ์„œ ์ถ”๊ฐ€") + + def load_company_vectorstore(self, company: str) -> Chroma: + """ + ํŠน์ • ํšŒ์‚ฌ์˜ ๋ฒกํ„ฐ DB๋ฅผ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค. 
+ """ + company_persist_dir = os.path.join(self.persist_directory, company) + if not os.path.exists(company_persist_dir): + raise ValueError(f"Vector store for company {company} does not exist") + + return Chroma(persist_directory=company_persist_dir, embedding_function=self.embeddings) + + @staticmethod + def move_to_old_data(json_paths: List[str], old_data_dir: str = "old_data", user_name: str = "All_data"): + """์ฒ˜๋ฆฌ๋œ JSON ํŒŒ์ผ์„ old_data ๋””๋ ‰ํ† ๋ฆฌ๋กœ ์ด๋™ํ•ฉ๋‹ˆ๋‹ค.""" + if not os.path.exists(os.path.join(old_data_dir, user_name)): + os.makedirs(os.path.join(old_data_dir, user_name)) + + for json_path in json_paths: + if os.path.exists(json_path): + filename = os.path.basename(json_path) + target_path = os.path.join(old_data_dir, user_name, filename) + shutil.move(json_path, target_path) + print(f"ํŒŒ์ผ ์ด๋™ ์™„๋ฃŒ: {json_path} -> {target_path}") diff --git a/app/README.md b/app/README.md new file mode 100644 index 0000000..351e737 --- /dev/null +++ b/app/README.md @@ -0,0 +1,118 @@ +# RAG API ์„œ๋ฒ„ ์‚ฌ์šฉ ๊ฐ€์ด๋“œ + +## 1. ํ™˜๊ฒฝ ์„ค์ • + +### 1.1 ํ™˜๊ฒฝ ๋ณ€์ˆ˜ ์„ค์ • +`.env` ํŒŒ์ผ์„ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค: +```bash + +``` + +### 1.2 ์˜์กด์„ฑ ์„ค์น˜ +```bash +pip install -r app/requirements.txt +``` + +## 2. ์„œ๋ฒ„ ์‹คํ–‰ + +### 2.1 ๊ฐœ๋ฐœ ๋ชจ๋“œ +```bash +cd app +uvicorn main:app --reload --host 0.0.0.0 --port 8000 +``` + +### 2.2 API ์„œ๋ฒ„ ๋ชจ๋“œ +```bash +cd app +gunicorn main:app -w 2 -k uvicorn.workers.UvicornWorker -b 0.0.0.0:30685 +``` + +## 3. API ์—”๋“œํฌ์ธํŠธ + +### 3.1 ์งˆ๋ฌธํ•˜๊ธฐ (POST `/api/v1/query`) + +#### ์š”์ฒญ ํ˜•์‹ +```json +{ + "query": "์งˆ๋ฌธ ๋‚ด์šฉ", + "max_tokens": 256, + "temperature": 0.7 +} +``` + +#### ๋งค๊ฐœ๋ณ€์ˆ˜ ์„ค๋ช… +- `query` (ํ•„์ˆ˜): ์‚ฌ์šฉ์ž์˜ ์งˆ๋ฌธ +- `max_tokens` (์„ ํƒ, ๊ธฐ๋ณธ๊ฐ’: 256): ์ƒ์„ฑํ•  ์ตœ๋Œ€ ํ† ํฐ ์ˆ˜ +- `temperature` (์„ ํƒ, ๊ธฐ๋ณธ๊ฐ’: 0.7): ์ƒ์„ฑ ํ…์ŠคํŠธ์˜ ๋‹ค์–‘์„ฑ (0.0 ~ 1.0) + +#### ์‘๋‹ต ํ˜•์‹ +```json +{ + "answer": "์ƒ์„ฑ๋œ ๋‹ต๋ณ€", + "retrieved_documents": [ + { + "content": "๊ฒ€์ƒ‰๋œ ๋ฌธ์„œ ๋‚ด์šฉ", + "score": 0.95, + "source": "๋ฌธ์„œ ์ถœ์ฒ˜", + "company": "์นด์นด์˜ค๋ฑ…ํฌ" + } + ], + "processing_time": 1.23, + "company": "์นด์นด์˜ค๋ฑ…ํฌ" +} +``` + +### 3.2 API ํ˜ธ์ถœ ์˜ˆ์‹œ + +#### cURL + +## query +```bash +curl -u test@email.com:1234 \ + -X POST "http://0.0.0.0:8000/api/v1/query/" \ + -H "Content-Type: application/json" \ + -d '{ + "query": "NAVER์˜ 2024๋…„ 3๋ถ„๊ธฐ ์„œ์น˜ํ”Œ๋žซํผ ๋งค์ถœ์€ ์–ผ๋งˆ์ด๋ฉฐ, ์ „๋…„ ๋™๊ธฐ ๋Œ€๋น„ ๋ช‡ % ์ฆ๊ฐ€ํ–ˆ๋‚˜์š”?", + "max_tokens": 1000, + "temperature": 0.7 + }' +``` + + +#### Python +```python +import requests + +url = "http://localhost:8000/api/v1/query/" +auth = ("test@email.com", "1234") +data = { + "query": "์งˆ๋ฌธ ๋‚ด์šฉ", + "max_tokens": 1000, + "temperature": 0.7 +} + +response = requests.post(url, json=data, auth=auth) +result = response.json() + +``` + +## 4. ๋ชจ๋‹ˆํ„ฐ๋ง + +### 4.1 ๋ฉ”ํŠธ๋ฆญ์Šค +- Prometheus ๋ฉ”ํŠธ๋ฆญ์Šค: `http://localhost:8000/metrics` + +### 4.2 ๋กœ๊ทธ +- ๋กœ๊ทธ ํŒŒ์ผ ์œ„์น˜: `app/logs/app.log` +- ๋กœ๊ทธ ๋ ˆ๋ฒจ ์„ค์ •: `.env` ํŒŒ์ผ์˜ `LOG_LEVEL` ๋ณ€์ˆ˜๋กœ ์กฐ์ • + +## 5. 
๋ฌธ์ œ ํ•ด๊ฒฐ + +### 5.1 ์ผ๋ฐ˜์ ์ธ ์˜ค๋ฅ˜ +- 500 ์—๋Ÿฌ: ์„œ๋ฒ„ ๋‚ด๋ถ€ ์˜ค๋ฅ˜, ๋กœ๊ทธ ํŒŒ์ผ ํ™•์ธ +- 404 ์—๋Ÿฌ: ์ž˜๋ชป๋œ ์—”๋“œํฌ์ธํŠธ ์ ‘๊ทผ +- 422 ์—๋Ÿฌ: ์ž˜๋ชป๋œ ์š”์ฒญ ํ˜•์‹ + +### 5.2 ๋กœ๊ทธ ํ™•์ธ +```bash +tail -f app/logs/app.log +``` \ No newline at end of file diff --git a/app/api/v1/endpoints/chatting.py b/app/api/v1/endpoints/chatting.py new file mode 100644 index 0000000..a7fc4ad --- /dev/null +++ b/app/api/v1/endpoints/chatting.py @@ -0,0 +1,34 @@ +from uuid import uuid4 + +from fastapi import APIRouter, HTTPException +from loguru import logger +from schemas.rag import ChatRequest, ChatResponse +from services.rag_service import RAGService + +router = APIRouter() +rag_service = RAGService() + + +@router.post("", response_model=ChatResponse) +async def chatting(request: ChatRequest): + try: + # ์„ธ์…˜ ID๊ฐ€ ์—†์œผ๋ฉด ์ƒˆ๋กœ ์ƒ์„ฑ + session_id = request.session_id or str(uuid4()) + + # ์ฑ„ํŒ… ์ฒ˜๋ฆฌ + answer, retrieval_results, processing_time, company, current_chat_history = await rag_service.process_chat( + session_id=session_id, query=request.query, llm_model=request.llm_model, chat_history=request.chat_history + ) + + return ChatResponse( + session_id=session_id, + answer=answer, + company=company, + retrieved_documents=retrieval_results, + processing_time=processing_time, + chat_history=current_chat_history, + ) + + except Exception as e: + logger.error(f"Error processing chat: {str(e)}") + raise HTTPException(status_code=500, detail=f"Error processing chat: {str(e)}") diff --git a/app/api/v1/endpoints/documents.py b/app/api/v1/endpoints/documents.py new file mode 100644 index 0000000..a4b5a1d --- /dev/null +++ b/app/api/v1/endpoints/documents.py @@ -0,0 +1,75 @@ +from typing import List, Optional + +import os +import shutil +import time +from datetime import datetime + +from fastapi import APIRouter, BackgroundTasks, File, Form, HTTPException, UploadFile +from loguru import logger +from schemas.rag import DocumentResponse +from services.pdf_service import PDFService + +router = APIRouter() +pdf_service = PDFService() + +# PDF ์ €์žฅ ๊ฒฝ๋กœ ์„ค์ • +UPLOAD_DIR = "../PDF_OCR/pdf" +os.makedirs(UPLOAD_DIR, exist_ok=True) + + +def process_pdf_background(file_path: str): + """๋ฐฑ๊ทธ๋ผ์šด๋“œ์—์„œ PDF๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” ํ•จ์ˆ˜""" + try: + pdf_service.process_pdf(file_path) + except Exception as e: + logger.error(f"Background PDF processing error: {str(e)}") + + +@router.post("/upload", response_model=DocumentResponse) +async def upload_document( + background_tasks: BackgroundTasks, file: UploadFile = File(...), company: Optional[str] = Form(None) +): + try: + # ํšŒ์‚ฌ๋ณ„ ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ + save_dir = os.path.join(UPLOAD_DIR) + os.makedirs(save_dir, exist_ok=True) + + # ํŒŒ์ผ ์ €์žฅ + filename = file.filename + file_path = os.path.join(save_dir, filename) + + with open(file_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + + # PDF ์ฒ˜๋ฆฌ๋ฅผ ๋ฐฑ๊ทธ๋ผ์šด๋“œ ์ž‘์—…์œผ๋กœ ์‹คํ–‰ + background_tasks.add_task(process_pdf_background, file_path) + + return DocumentResponse( + message="Document uploaded and processing started", filename=filename, company=company, status="processing" + ) + + except Exception as e: + logger.error(f"Error uploading document: {str(e)}") + raise HTTPException(status_code=500, detail=f"Error uploading document: {str(e)}") + + +@router.get("/", response_model=List[DocumentResponse]) +async def list_documents(): + try: + documents = [] + # ๋””๋ ‰ํ† ๋ฆฌ ์ˆœํšŒํ•˜์—ฌ ๋ฌธ์„œ ๋ชฉ๋ก ์ƒ์„ฑ + for company in os.listdir(UPLOAD_DIR): + 
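+            # Each sub-directory of UPLOAD_DIR is treated as a company folder;
+            # only PDF files inside those folders are reported.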
company_dir = os.path.join(UPLOAD_DIR, company) + if os.path.isdir(company_dir): + for filename in os.listdir(company_dir): + if filename.lower().endswith(".pdf"): + documents.append( + DocumentResponse( + message="Document found", filename=filename, company=company, status="completed" + ) + ) + return documents + except Exception as e: + logger.error(f"Error listing documents: {str(e)}") + raise HTTPException(status_code=500, detail=f"Error listing documents: {str(e)}") diff --git a/app/api/v1/endpoints/query.py b/app/api/v1/endpoints/query.py new file mode 100644 index 0000000..f7132a6 --- /dev/null +++ b/app/api/v1/endpoints/query.py @@ -0,0 +1,32 @@ +import json + +from core.auth import verify_credentials +from fastapi import APIRouter, Depends, HTTPException +from loguru import logger +from schemas.rag import QueryRequest, QueryResponse +from services.rag_service import RAGService + +router = APIRouter() +rag_service = RAGService() + + +@router.post("", response_model=QueryResponse) +async def query(request: QueryRequest): # , username: str = Depends(verify_credentials)): + try: + # logger.info(f"Received query request from {username}: {request.query}") + logger.info(f"Received query request: {request.query}") + answer, retrieved_docs, processing_time, company = await rag_service.process_query(request) + + response = QueryResponse( + answer=answer, context=retrieved_docs, processing_time=processing_time, company=company + ) + logger.info(f"Query response: {response.answer}") + # log ์— ์ €์žฅ + with open("result_log.txt", "a", encoding="utf-8") as f: + f.write(f"{request.query}\n{response}\n--------------------------------\n") + logger.info(f"Query processed successfully in {processing_time:.2f} seconds") + return response + + except Exception as e: + logger.error(f"Error processing query: {str(e)}") + raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}") diff --git a/app/api/v1/router.py b/app/api/v1/router.py new file mode 100644 index 0000000..002568b --- /dev/null +++ b/app/api/v1/router.py @@ -0,0 +1,9 @@ +from api.v1.endpoints import chatting, documents, query +from fastapi import APIRouter + +router = APIRouter() + +# ๊ฐ ์—”๋“œํฌ์ธํŠธ ๋ผ์šฐํ„ฐ ๋“ฑ๋ก +router.include_router(chatting.router, prefix="/chatting", tags=["chatting"]) +router.include_router(query.router, prefix="/query", tags=["query"]) +router.include_router(documents.router, prefix="/documents", tags=["documents"]) diff --git a/app/core/auth.py b/app/core/auth.py new file mode 100644 index 0000000..76f92e4 --- /dev/null +++ b/app/core/auth.py @@ -0,0 +1,24 @@ +import secrets + +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPBasic, HTTPBasicCredentials + +security = HTTPBasic() + + +def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)): + # ์ž„์‹œ. 
์‚ฌ์šฉ์ž db ์—ฐ๋™ ํ›„ ์ˆ˜์ • ํ•„์š” + correct_username = "test@email.com" + correct_password = "1234" + + is_correct_username = secrets.compare_digest(credentials.username, correct_username) + is_correct_password = secrets.compare_digest(credentials.password, correct_password) + + if not (is_correct_username and is_correct_password): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect email or password", + headers={"WWW-Authenticate": "Basic"}, + ) + + return credentials.username diff --git a/app/core/config.py b/app/core/config.py new file mode 100644 index 0000000..082317b --- /dev/null +++ b/app/core/config.py @@ -0,0 +1,29 @@ +from typing import List + +import os +from pathlib import Path + +from dotenv import load_dotenv +from pydantic_settings import BaseSettings + +load_dotenv() + + +class Settings(BaseSettings): + API_V1_STR: str = "/api/v1" + PROJECT_NAME: str = "RAG API Server" + + # CORS + BACKEND_CORS_ORIGINS: List[str] = ["*"] + + # OpenAI + OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "") + + # RAG ์„ค์ • + RAG_CONFIG_PATH: str = os.getenv("RAG_CONFIG_PATH", str(Path(__file__).parent.parent / "RAG/configs/config.yaml")) + + class Config: + case_sensitive = True + + +settings = Settings() diff --git a/app/core/logging.py b/app/core/logging.py new file mode 100644 index 0000000..ffda4a2 --- /dev/null +++ b/app/core/logging.py @@ -0,0 +1,52 @@ +import logging +import sys +from pathlib import Path + +from loguru import logger + + +def setup_logging(): + # ๋กœ๊ทธ ํŒŒ์ผ ๊ฒฝ๋กœ ์„ค์ • + log_path = Path("logs") + log_path.mkdir(parents=True, exist_ok=True) + + # ๋กœ๊ฑฐ ์„ค์ • + config = { + "handlers": [ + { + "sink": sys.stdout, + "format": "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}", + }, + { + "sink": log_path / "app.log", + "rotation": "500 MB", + "retention": "10 days", + "format": "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}", + }, + ], + } + + # ๊ธฐ์กด ๋กœ๊ฑฐ ์ œ๊ฑฐ + logger.remove() + + # ์ƒˆ ์„ค์ • ์ ์šฉ + for handler in config["handlers"]: + logger.add(**handler) + + # FastAPI ๋กœ๊ฑฐ์™€ ํ†ตํ•ฉ + logging.getLogger("uvicorn.access").handlers = [InterceptHandler()] + + +class InterceptHandler(logging.Handler): + def emit(self, record): + try: + level = logger.level(record.levelname).name + except ValueError: + level = record.levelno + + frame, depth = logging.currentframe(), 2 + while frame.f_code.co_filename == logging.__file__: + frame = frame.f_back + depth += 1 + + logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..43205f5 --- /dev/null +++ b/app/main.py @@ -0,0 +1,67 @@ +import os + +import uvicorn +from api.v1.endpoints import documents +from api.v1.router import router as api_v1_router +from core.config import settings +from core.logging import setup_logging +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from prometheus_fastapi_instrumentator import Instrumentator +from starlette.responses import FileResponse +from starlette.staticfiles import StaticFiles + +app = FastAPI( + title="RAG API Server", + description="RAG(Retrieval Augmented Generation) API Server", + version="1.0.0", + docs_url="/docs", + redoc_url="/redoc", +) + +# CORS ์„ค์ • +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# ๋ฉ”ํŠธ๋ฆญ์Šค ์„ค์ • 
+Instrumentator().instrument(app).expose(app) + + +dist_dir = "./dist" +app.mount("/assets", StaticFiles(directory=os.path.join(dist_dir, "assets")), name="assets") + + +@app.get("/") +def serve_index(): + return FileResponse(os.path.join(dist_dir, "index.html")) + + +# ๋กœ๊น… ์„ค์ • +setup_logging() + +# API ๋ผ์šฐํ„ฐ ๋“ฑ๋ก +app.include_router(api_v1_router, prefix="/api/v1") + + +# ํ—ฌ์Šค ์ฒดํฌ ์—”๋“œํฌ์ธํŠธ +@app.get("/health") +async def health_check(): + return {"status": "healthy"} + + +if __name__ == "__main__": + # uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True, workers=4) + uvicorn.run( + "main:app", + host="0.0.0.0", + port=8000, + reload=True, + workers=1, + timeout_keep_alive=300, # ์—ฐ๊ฒฐ ์œ ์ง€ ํƒ€์ž„์•„์›ƒ + timeout=600, # ์›Œ์ปค ํƒ€์ž„์•„์›ƒ + ) diff --git a/app/requirements.txt b/app/requirements.txt new file mode 100644 index 0000000..91a86c8 --- /dev/null +++ b/app/requirements.txt @@ -0,0 +1,12 @@ +fastapi==0.104.1 +uvicorn==0.27.0 +pydantic==2.7.4 +python-dotenv==1.0.0 +loguru==0.7.2 +openai==1.59.9 +langchain==0.3.14 +langchain-community==0.3.14 +langchain-openai==0.3.2 +prometheus-client==0.19.0 +prometheus-fastapi-instrumentator==6.1.0 +gunicorn==21.2.0 \ No newline at end of file diff --git a/app/schemas/rag.py b/app/schemas/rag.py new file mode 100644 index 0000000..6be01f2 --- /dev/null +++ b/app/schemas/rag.py @@ -0,0 +1,61 @@ +from typing import Dict, List, Optional + +from pydantic import BaseModel, Field + + +class ChatMessage(BaseModel): + role: str = Field(..., description="๋ฉ”์‹œ์ง€ ์ž‘์„ฑ์ž ์—ญํ•  (user ๋˜๋Š” assistant)") + content: str = Field(..., description="๋ฉ”์‹œ์ง€ ๋‚ด์šฉ") + + +class QueryRequest(BaseModel): + query: str = Field(..., description="์‚ฌ์šฉ์ž์˜ ์งˆ๋ฌธ") + llm_model: Optional[str] = Field(default="GPT-4o", description="์งˆ๋ฌธ์— ๋‹ต๋ณ€ํ•  LLM") + max_tokens: Optional[int] = Field(default=1000, description="์ƒ์„ฑํ•  ์ตœ๋Œ€ ํ† ํฐ ์ˆ˜") + temperature: Optional[float] = Field(default=0.7, description="์ƒ์„ฑ ํ…์ŠคํŠธ์˜ ๋‹ค์–‘์„ฑ (0.0 ~ 1.0)") + company: Optional[str] = None + + +class RetrievalResult(BaseModel): + content: str = Field(..., description="๊ฒ€์ƒ‰๋œ ๋ฌธ์„œ ๋‚ด์šฉ") + metadata: dict + score: float = Field(..., description="๊ฒ€์ƒ‰ ์ ์ˆ˜") + company: str = Field(..., description="๋ฌธ์„œ ์†Œ์† ๊ธฐ์—…") + source: str = Field(..., description="๋ฌธ์„œ ์ถœ์ฒ˜") + + +class QueryResponse(BaseModel): + answer: str = Field(..., description="์ƒ์„ฑ๋œ ๋‹ต๋ณ€") + context: List[RetrievalResult] = Field(..., description="๊ฒ€์ƒ‰๋œ ๊ด€๋ จ ๋ฌธ์„œ๋“ค") + processing_time: float = Field(..., description="์ฒ˜๋ฆฌ ์‹œ๊ฐ„ (์ดˆ)") + company: Optional[str] = None + + +class ChatRequest(BaseModel): + session_id: Optional[str] = None + query: str = Field(..., description="์‚ฌ์šฉ์ž์˜ ์งˆ๋ฌธ") + llm_model: Optional[str] = Field(default="GPT-4o-mini", description="์งˆ๋ฌธ์— ๋‹ต๋ณ€ํ•  LLM") + max_tokens: Optional[int] = Field(default=1000, description="์ƒ์„ฑํ•  ์ตœ๋Œ€ ํ† ํฐ ์ˆ˜") + temperature: Optional[float] = Field(default=0.7, description="์ƒ์„ฑ ํ…์ŠคํŠธ์˜ ๋‹ค์–‘์„ฑ (0.0 ~ 1.0)") + company: Optional[str] = None + chat_history: Optional[List[ChatMessage]] = Field(default=None, description="์ด์ „ ๋Œ€ํ™” ๊ธฐ๋ก") + + +class ChatResponse(BaseModel): + session_id: str = Field(..., description="์ฑ„ํŒ… ์„ธ์…˜ ID") + answer: str = Field(..., description="์ƒ์„ฑ๋œ ๋‹ต๋ณ€") + retrieved_documents: List[RetrievalResult] = Field(..., description="๊ฒ€์ƒ‰๋œ ๊ด€๋ จ ๋ฌธ์„œ๋“ค") + processing_time: float = Field(..., 
description="์ฒ˜๋ฆฌ ์‹œ๊ฐ„ (์ดˆ)") + company: Optional[str] = None + chat_history: List[ChatMessage] = Field(..., description="ํ˜„์žฌ๊นŒ์ง€์˜ ์ „์ฒด ๋Œ€ํ™” ๊ธฐ๋ก") + + +class DocumentResponse(BaseModel): + message: str + filename: str + company: Optional[str] = None + upload_time: Optional[str] = None + + +class CompanyResponse(BaseModel): + company: str = Field(..., description="๊ธฐ์—… ์ด๋ฆ„") diff --git a/app/services/pdf_service.py b/app/services/pdf_service.py new file mode 100644 index 0000000..91c67b2 --- /dev/null +++ b/app/services/pdf_service.py @@ -0,0 +1,160 @@ +from typing import Optional + +import asyncio +import os +import shutil +import subprocess +import sys +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +# ํ”„๋กœ์ ํŠธ ๋ฃจํŠธ ๋””๋ ‰ํ† ๋ฆฌ๋ฅผ Python ๊ฒฝ๋กœ์— ์ถ”๊ฐ€ +project_root = Path(__file__).parent.parent.parent +sys.path.append(str(project_root)) + +from omegaconf import OmegaConf + +# RAG ๋ชจ๋“ˆ import +from app.RAG.utils.vector_store import VectorStore + + +class PDFService: + def __init__(self): + # ํ”„๋กœ์ ํŠธ ๋ฃจํŠธ ๋””๋ ‰ํ† ๋ฆฌ ์„ค์ • (app์˜ ์ƒ์œ„ ๋””๋ ‰ํ† ๋ฆฌ) + self.base_dir = Path(__file__).parent.parent.parent + self.pdf_ocr_dir = self.base_dir / "PDF_OCR" + self.upload_dir = self.pdf_ocr_dir / "pdf" + self.vector_db_dir = self.base_dir / "app/RAG/vector_db" + self.executor = ThreadPoolExecutor(max_workers=1) + + # ํ•„์š”ํ•œ ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ + self._create_directories() + + # ๊ธฐ๋ณธ ์„ค์ • ๋กœ๋“œ + self.config = { + "DIRS": { + "input_dir": str(self.upload_dir), + "output_dir": str(self.base_dir / "PDF_OCR/output"), + "database_dir": str(self.base_dir / "PDF_OCR/database"), + "ocr_output_dir": str(self.pdf_ocr_dir / "ocr_results"), + }, + "MODEL": { + "path": "~/.cache/huggingface/hub/models--juliozhao--DocLayout-YOLO-DocStructBench/snapshots/8c3299a30b8ff29a1503c4431b035b93220f7b11/doclayout_yolo_docstructbench_imgsz1024.pt", + "imgsz": 1024, + "line_width": 5, + "font_size": 20, + "conf": 0.2, + "threshold": 0.05, + }, + } + + def _create_directories(self): + """ํ•„์š”ํ•œ ๋””๋ ‰ํ† ๋ฆฌ๋“ค์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.""" + directories = [ + self.upload_dir, + self.pdf_ocr_dir / "ocr_results", + self.pdf_ocr_dir / "new_data", + self.vector_db_dir, + self.base_dir / "PDF_OCR/output", + self.base_dir / "PDF_OCR/database", + ] + + for directory in directories: + directory.mkdir(parents=True, exist_ok=True) + + async def process_pdf_async(self, pdf_path: str, company: str) -> bool: + """ + PDF๋ฅผ ๋น„๋™๊ธฐ์ ์œผ๋กœ ์ฒ˜๋ฆฌํ•˜๊ณ  Vector DB์— ์ €์žฅํ•˜๋Š” ์ „์ฒด ํŒŒ์ดํ”„๋ผ์ธ์„ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค. + """ + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self.executor, self.process_pdf, pdf_path, company) + + def process_pdf(self, pdf_path: str) -> bool: + """ + PDF๋ฅผ ์ฒ˜๋ฆฌํ•˜๊ณ  Vector DB์— ์ €์žฅํ•˜๋Š” ์ „์ฒด ํŒŒ์ดํ”„๋ผ์ธ์„ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค. + + Args: + pdf_path (str): ์—…๋กœ๋“œ๋œ PDF ํŒŒ์ผ ๊ฒฝ๋กœ + company (str): ํšŒ์‚ฌ๋ช… + + Returns: + bool: ์ฒ˜๋ฆฌ ์„ฑ๊ณต ์—ฌ๋ถ€ + """ + try: + # 2. 
PDF_OCR ๋””๋ ‰ํ† ๋ฆฌ๋กœ ์ด๋™ํ•˜์—ฌ ํŒŒ์ดํ”„๋ผ์ธ ์‹คํ–‰ + current_dir = os.getcwd() + os.chdir(str(self.pdf_ocr_dir)) + + # PDF ํŒŒ์‹ฑ ํŒŒ์ดํ”„๋ผ์ธ ์‹คํ–‰ + print("PDF ํŒŒ์‹ฑ ์‹œ์ž‘...") + result = subprocess.run( + ["python", "pdf_parser.py", "-i", "./pdf"], + check=True, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + print(result.stdout) + if result.stderr: + print("์˜ค๋ฅ˜:", result.stderr) + print("PDF ํŒŒ์‹ฑ ์™„๋ฃŒ") + + print("Postprocessing ์‹œ์ž‘...") + result = subprocess.run( + ["python", "data_postprocess.py"], check=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + print(result.stdout) + if result.stderr: + print("์˜ค๋ฅ˜:", result.stderr) + print("Postprocessing ์™„๋ฃŒ") + + # ์›๋ž˜ ๋””๋ ‰ํ† ๋ฆฌ๋กœ ๋ณต๊ท€ + os.chdir(current_dir) + + # 3. Vector DB ์ €์žฅ + print("Vector DB ์ €์žฅ ์ค‘...") + vector_store = VectorStore( + OmegaConf.create({"passage_embedding_model_name": "nlpai-lab/KoE5"}), str(self.vector_db_dir) + ) + + # ํšŒ์‚ฌ๋ณ„ ๋ฐ ์ „์ฒด Vector DB ์—…๋ฐ์ดํŠธ + new_data_dir = self.pdf_ocr_dir / "new_data" + vector_store.update_company_vector_stores( + str(new_data_dir / "All_data/text_data.json"), str(new_data_dir / "All_data/table_data.json") + ) + print("Vector DB ์ €์žฅ ์™„๋ฃŒ") + + return True + + except subprocess.CalledProcessError as e: + print(f"PDF ์ฒ˜๋ฆฌ ์ค‘ ๋ช…๋ น์–ด ์‹คํ–‰ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}") + print(f"์˜ค๋ฅ˜ ์ถœ๋ ฅ: {e.stderr}") + return False + except Exception as e: + print(f"PDF ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}") + return False + finally: + # ์ž‘์—… ๋””๋ ‰ํ† ๋ฆฌ๋ฅผ ์›๋ž˜๋Œ€๋กœ ๋ณต์› + if "current_dir" in locals(): + os.chdir(current_dir) + + def clean_up(self): + """์ž„์‹œ ํŒŒ์ผ๋“ค์„ ์ •๋ฆฌํ•ฉ๋‹ˆ๋‹ค.""" + try: + # PDF_OCR ๋””๋ ‰ํ† ๋ฆฌ๋กœ ์ด๋™ + current_dir = os.getcwd() + os.chdir(str(self.pdf_ocr_dir)) + + # ์ž„์‹œ ํŒŒ์ผ๋“ค ์ •๋ฆฌ + if self.upload_dir.exists(): + shutil.rmtree(self.upload_dir) + + # ํ•„์š”ํ•œ ๋””๋ ‰ํ† ๋ฆฌ ์žฌ์ƒ์„ฑ + self._create_directories() + + except Exception as e: + print(f"์ •๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}") + finally: + # ์ž‘์—… ๋””๋ ‰ํ† ๋ฆฌ๋ฅผ ์›๋ž˜๋Œ€๋กœ ๋ณต์› + if "current_dir" in locals(): + os.chdir(current_dir) diff --git a/app/services/rag_service.py b/app/services/rag_service.py new file mode 100644 index 0000000..efb3035 --- /dev/null +++ b/app/services/rag_service.py @@ -0,0 +1,312 @@ +from typing import Dict, List, Optional, Tuple + +import json +import os +import sys +import time +import warnings +from functools import lru_cache +from io import StringIO +from pathlib import Path + +import aiofiles +import hydra +import pandas as pd +from core.config import settings +from langchain.prompts import ChatPromptTemplate +from langchain_community.chat_message_histories import ChatMessageHistory +from langchain_core.messages import AIMessage, HumanMessage +from loguru import logger +from omegaconf import DictConfig +from RAG.generator import get_llm_api +from schemas.rag import QueryRequest, RetrievalResult + +warnings.filterwarnings("ignore") +# RAG ๋ชจ๋“ˆ import๋ฅผ ์œ„ํ•œ ๊ฒฝ๋กœ ์„ค์ • +project_root = Path(__file__).parent.parent +rag_path = project_root / "RAG" +sys.path.append(str(rag_path)) + +# RAG ๋ชจ๋“ˆ import +from RAG.retrieval import ChromaRetrieval + +# from RAG.source.generate import generate + +# ๋ฉ”ํŠธ๋ฆญ ์ •์˜ + + +class RAGService: + def __init__(self): + """RAG ์„œ๋น„์Šค ์ดˆ๊ธฐํ™”""" + self._load_config() + self._init_retrievers() + self._init_generator() + self._init_cache() + self._init_chat_histories() + + def _load_config(self): + 
"""Hydra ์„ค์ • ๋กœ๋“œ""" + # ํ˜„์žฌ ์ž‘์—… ๋””๋ ‰ํ† ๋ฆฌ๋ฅผ ์ €์žฅ + original_cwd = os.getcwd() + + try: + os.chdir(str(project_root)) + + # ์ƒ๋Œ€ ๊ฒฝ๋กœ๋กœ config_path ์„ค์ • + with hydra.initialize(version_base=None, config_path="../RAG/configs"): + cfg = hydra.compose(config_name="config") + self.cfg = cfg + finally: + # ์›๋ž˜ ๋””๋ ‰ํ† ๋ฆฌ๋กœ ๋ณต๊ท€ + os.chdir(original_cwd) + + def _init_retrievers(self): + """๊ฒ€์ƒ‰ ๋ชจ๋ธ ์ดˆ๊ธฐํ™”""" + try: + + self.ensemble_retriever = ChromaRetrieval(self.cfg) + logger.info("Successfully initialized all retrievers") + + except Exception as e: + logger.error(f"Error initializing retrievers: {str(e)}") + raise + + def _init_generator(self): + """์ƒ์„ฑ ๋ชจ๋ธ ์ดˆ๊ธฐํ™”""" + try: + # self.generator = generate + logger.info("Successfully initialized generator") + except Exception as e: + logger.error(f"Error initializing generator: {str(e)}") + raise + + def _init_cache(self): + """์บ์‹œ ์ดˆ๊ธฐํ™”""" + self.query_cache = {} + + def _init_chat_histories(self): + """์ฑ„ํŒ… ๊ธฐ๋ก ์ดˆ๊ธฐํ™”""" + self.chat_histories: Dict[str, ChatMessageHistory] = {} + + @lru_cache(maxsize=1000) + def _get_cached_retrieval_with_query_rewritten(self, query: str) -> List[RetrievalResult]: + """๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ์บ์‹ฑ""" + return self.ensemble_retriever.get_relevant_documents_with_query_rewritten(query=query, k=20) + + @lru_cache(maxsize=1000) + def _get_cached_retrieval_without_query_rewritten(self, query: str) -> List[RetrievalResult]: + """๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ์บ์‹ฑ""" + return self.ensemble_retriever.get_relevant_documents_without_query_rewritten(query=query, k=20) + + async def _retrieve_documents(self, query: str, is_rewritten: bool = True) -> Tuple[str, List[RetrievalResult]]: + """๋ฌธ์„œ ๊ฒ€์ƒ‰ ๋กœ์ง""" + if is_rewritten: + retrieved_docs = self._get_cached_retrieval_with_query_rewritten(query) + else: + retrieved_docs = self._get_cached_retrieval_without_query_rewritten(query) + docs_text = "" + retrieval_results = [] + + async def process_doc(doc): + retrieval_results.append( + RetrievalResult( + content=doc.page_content, + metadata=doc.metadata, + score=float(doc.metadata.get("score", 1.0)), + company=doc.metadata.get("company", "unknown"), + source=f"{doc.metadata.get('company', 'unknown')}_{doc.metadata.get('securities', 'unknown')}_{doc.metadata.get('date', 'unknown')}_page{doc.metadata.get('page', 'unknown')}_{doc.metadata.get('category', 'unknown')}", + ) + ) + + if doc.metadata.get("category") == "table": + try: + doc_path = self._fix_path(doc.metadata.get("path")) + doc_path = "../PDF_OCR/processed_ocr_results" + doc_path + table_path = doc_path.replace(".json", ".csv") + + if os.path.exists(table_path): + async with aiofiles.open(table_path, mode="r") as f: + content = await f.read() + df = pd.read_csv(StringIO(content)) + return f"{doc.metadata.get('company')} ํ…Œ์ด๋ธ” ๋ฐ์ดํ„ฐ :\n {df.to_string(index=False)}\n" + except Exception as e: + logger.error(f"Error processing table document: {str(e)}") + return "ํ…Œ์ด๋ธ” ๋ฐ์ดํ„ฐ๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค." 
+ return doc.page_content + + from asyncio import gather + + if len(retrieved_docs) > 7: + processed_contents = await gather(*[process_doc(doc) for doc in retrieved_docs[:7]]) + else: + processed_contents = await gather(*[process_doc(doc) for doc in retrieved_docs]) + docs_text = "\n".join(processed_contents) + + return docs_text, retrieval_results + + async def _generate_response(self, query: str, docs_text: str, llm_model: Optional[str] = None) -> str: + """LLM ์‘๋‹ต ์ƒ์„ฑ ๋กœ์ง""" + if llm_model == "GPT-4o-mini": + self.cfg.llm_model_name = "gpt-4o-mini" + self.cfg.llm_model_source = "openai" + llm = get_llm_api(self.cfg) + elif llm_model == "GPT-4o" or llm_model == None: + self.cfg.llm_model_name = "gpt-4o" + self.cfg.llm_model_source = "openai" + llm = get_llm_api(self.cfg) + elif llm_model == "CLOVA X": + self.cfg.llm_model_source = "naver" + llm = get_llm_api(self.cfg) + else: + raise ValueError(f"Invalid LLM model: {llm_model}") + + prompt_template = ChatPromptTemplate.from_messages( + [("system", self.cfg.chat_template), ("user", f"์งˆ๋ฌธ: {query}")] + ) + prompt = prompt_template.invoke({"docs": docs_text}) + start_time = time.time() + answer = llm.invoke(prompt) + # LLM response time log + logger.info(f"LLM response time: {time.time() - start_time:.2f} seconds") + return answer.content + + async def process_query(self, request: QueryRequest) -> Tuple[str, List[RetrievalResult], float, str]: + """์ผ๋ฐ˜ ์ฟผ๋ฆฌ ์ฒ˜๋ฆฌ""" + start_time = time.time() + try: + docs_text, retrieval_results = await self._retrieve_documents(request.query, False) + + if not retrieval_results: + logger.warning("No retrieval results found") + company = "unknown" + else: + # docs ์—ํฌํ•จ๋œ company ์ค‘ ๊ฐ€์žฅ ๋งŽ์€ ํšŒ์‚ฌ + company_counts = {} + for result in retrieval_results: + company = result.company + if company in company_counts: + company_counts[company] += 1 + else: + company_counts[company] = 1 + company = max(company_counts, key=company_counts.get) + + answer_text = await self._generate_response(request.query, docs_text, request.llm_model) + + processing_time = time.time() - start_time + logger.info(f"Query processed in {processing_time:.2f} seconds") + + return answer_text, retrieval_results, processing_time, company + + except Exception as e: + logger.error(f"Error processing query: {str(e)}", exc_info=True) + raise + finally: + processing_time = time.time() - start_time + + async def process_chat( + self, session_id: str, query: str, llm_model: str, chat_history: Optional[List[dict]] = None + ) -> Tuple[str, List[RetrievalResult], float, str, List[dict]]: + """์ฑ„ํŒ… ์ฒ˜๋ฆฌ""" + # user query caching + if session_id not in self.query_cache: + self.query_cache[session_id] = [] + + self.query_cache[session_id].append(query) + if len(self.query_cache[session_id]) > 2: + self.query_cache[session_id].pop(0) + + # ์„ธ์…˜ ๊ธฐ๋ก ์ดˆ๊ธฐํ™” ๋˜๋Š” ๊ฐ€์ ธ์˜ค๊ธฐ + if session_id not in self.chat_histories: + self.chat_histories[session_id] = ChatMessageHistory() + # ์ด์ „ ๋Œ€ํ™” ๊ธฐ๋ก์ด ์žˆ๋‹ค๋ฉด ๋ณต์› + if chat_history: + for msg in chat_history: + if isinstance(msg, dict): + role = msg.get("role") + content = msg.get("content") + else: + role = msg.role + content = msg.content + + if role == "user": + self.chat_histories[session_id].add_user_message(content) + elif role == "assistant": + self.chat_histories[session_id].add_ai_message(content) + + chat_history = self.chat_histories[session_id] + + # ์ƒˆ ๋ฉ”์‹œ์ง€ ์ถ”๊ฐ€ + chat_history.add_user_message(query) + + try: + # ๋ฌธ์„œ ๊ฒ€์ƒ‰ + # ์ตœ๊ทผ 
๋‘๊ฐœ์˜ ์งˆ๋ฌธ์„ ํ•ฉ์นœ ๋ฌธ์žฅ์„ ๊ฒ€์ƒ‰ + previous_user_query = " ".join(self.query_cache[session_id]) + docs_text, retrieval_results = await self._retrieve_documents(previous_user_query + "\n" + query, True) + + if not retrieval_results: + company = "unknown" + else: + company_counts = {} + for result in retrieval_results: + company = result.company + if company in company_counts: + company_counts[company] += 1 + else: + company_counts[company] = 1 + company = max(company_counts, key=company_counts.get) + + # ์ฑ„ํŒ… ์ปจํ…์ŠคํŠธ๋ฅผ ํฌํ•จํ•œ ํ”„๋กฌํ”„ํŠธ ์ƒ์„ฑ + chat_context = "\n".join( + [ + f"{'User' if isinstance(msg, HumanMessage) else 'Assistant'}: {msg.content}" + for msg in chat_history.messages[-4:] # ์ตœ๊ทผ 4๊ฐœ ๋ฉ”์‹œ์ง€๋งŒ ์‚ฌ์šฉ + ] + ) + + # ์‘๋‹ต ์ƒ์„ฑ + prompt_template = ChatPromptTemplate.from_messages( + [ + ("system", self.cfg.chatting_template), + ("system", "์ด์ „ ๋Œ€ํ™” ๊ธฐ๋ก:\n{chat_context}"), + ("user", f"์งˆ๋ฌธ: {query}"), + ] + ) + + if llm_model == "GPT-4o" or llm_model == "GPT-4o-mini": + self.cfg.llm_model_name = "gpt-4o-mini" + self.cfg.llm_model_source = "openai" + llm = get_llm_api(self.cfg) + elif llm_model == "CLOVA X": + self.cfg.llm_model_source = "naver" + llm = get_llm_api(self.cfg) + else: + raise ValueError(f"Invalid LLM model: {llm_model}") + + prompt = prompt_template.invoke({"docs": docs_text, "chat_context": chat_context}) + + answer = llm.invoke(prompt) + answer_text = answer.content + + # ์‘๋‹ต ์ €์žฅ + chat_history.add_ai_message(answer_text) + + # ํ˜„์žฌ ๋Œ€ํ™” ๊ธฐ๋ก์„ ChatMessage ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜ + current_chat_history = [ + {"role": "user" if isinstance(msg, HumanMessage) else "assistant", "content": msg.content} + for msg in chat_history.messages + ] + + processing_time = time.time() + + return answer_text, retrieval_results, processing_time, company, current_chat_history + + except Exception as e: + logger.error(f"Error processing chat: {str(e)}", exc_info=True) + raise + + def _fix_path(self, path: str) -> str: + path = path.replace("page_page_", "page_") + if path.endswith(".json.json"): + path = path[:-5] + return path diff --git a/images/demo.mov b/images/demo.mov new file mode 100644 index 0000000..7c649d4 Binary files /dev/null and b/images/demo.mov differ diff --git a/images/github-mark.png b/images/github-mark.png new file mode 100644 index 0000000..6cb3b70 Binary files /dev/null and b/images/github-mark.png differ diff --git a/images/pdf-ocr_flowchart.png b/images/pdf-ocr_flowchart.png new file mode 100644 index 0000000..48249aa Binary files /dev/null and b/images/pdf-ocr_flowchart.png differ diff --git a/images/profile-1.jpeg b/images/profile-1.jpeg new file mode 100644 index 0000000..862116b Binary files /dev/null and b/images/profile-1.jpeg differ diff --git a/images/profile-2.jpeg b/images/profile-2.jpeg new file mode 100644 index 0000000..6eec317 Binary files /dev/null and b/images/profile-2.jpeg differ diff --git a/images/profile-3.jpeg b/images/profile-3.jpeg new file mode 100644 index 0000000..62b3e01 Binary files /dev/null and b/images/profile-3.jpeg differ diff --git a/images/profile-4.jpeg b/images/profile-4.jpeg new file mode 100644 index 0000000..1e9571d Binary files /dev/null and b/images/profile-4.jpeg differ diff --git a/images/profile-5.jpeg b/images/profile-5.jpeg new file mode 100644 index 0000000..bb6ab3c Binary files /dev/null and b/images/profile-5.jpeg differ diff --git a/images/profile-6.jpeg b/images/profile-6.jpeg new file mode 100644 index 0000000..06c72b5 Binary files /dev/null and 
b/images/profile-6.jpeg differ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..33e7633 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,30 @@ +[tool.black] +line-length = 120 +target-version = ['py36', 'py37', 'py38','py39','py310'] +include = '\.py$' +exclude = ''' +/( + \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + | env + | venv +)/ +''' +[tool.isort] +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true +ensure_newline_before_comments = true +line_length = 120 +known_typing = ["typing", "types", "typing_extensions", "mypy", "mypy_extensions"] +known_third_party = ["wandb"] +sections = ["FUTURE", "TYPING", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] +profile = "black" \ No newline at end of file