Changes from 93 commits
Commits
96 commits
d9082c0
Setting up GitHub Classroom Feedback
github-classroom[bot] Nov 8, 2024
d8f9da9
feat: Modularized Baseline
ocean010315 Nov 13, 2024
90ec293
docs: Update README.md
ocean010315 Nov 14, 2024
ed60b89
feat: Quantization
ocean010315 Nov 14, 2024
65a7728
fix: 8-bit quantization config
ocean010315 Nov 14, 2024
d9f5d41
feat: Add function to create uniform answer distribution
gwaksital Nov 15, 2024
3d5c35f
docs: Add template files
minjijeong98 Nov 15, 2024
1c0fe95
fix: Align modularized code with the original baseline
ocean010315 Nov 16, 2024
8b20078
fix: Add torch_dtype in config
ocean010315 Nov 16, 2024
18d2417
docs: Update README.md
ocean010315 Nov 16, 2024
887d585
Merge pull request #4 from boostcampaitech7/fix/baseline_template
jin-jae Nov 16, 2024
0201918
Merge pull request #3 from boostcampaitech7/docs
jin-jae Nov 16, 2024
873340b
Merge pull request #2 from boostcampaitech7/feat/uniform_answer_distr…
jin-jae Nov 16, 2024
079493a
fix: train.csv
ocean010315 Nov 16, 2024
9213a8b
Merge pull request #5 from boostcampaitech7/develop
jin-jae Nov 16, 2024
8875263
fix: src/utils.py path
ocean010315 Nov 16, 2024
dcc5c29
feat: Add wikipedia preprocessing function
minjijeong98 Nov 17, 2024
828b6c9
fix: Change context length filtering threshold (100 -> 50)
minjijeong98 Nov 17, 2024
2a09669
feat: Add crawling openstax textbook function
minjijeong98 Nov 18, 2024
bb77e2b
GENNLP-9: feat: Add EDA script for data analysis
gwaksital Nov 19, 2024
92914ed
feat: Add crawling code for korean history textbook
minjijeong98 Nov 19, 2024
089e304
feat: Add crawling code for korean history terms
minjijeong98 Nov 19, 2024
75cdd88
fix: Fix section name extraction error
minjijeong98 Nov 19, 2024
3388227
Merge pull request #7 from boostcampaitech7/GENNLP-9-Exploratory-Data…
jin-jae Nov 20, 2024
01c64de
GENNLP-29: refactor: Enable configurations changes via config.yaml an…
gwaksital Nov 21, 2024
a29ea56
Merge pull request #8 from boostcampaitech7/refactor/GENNLP-29-Unsloth
gwaksital Nov 21, 2024
f6978c9
feat: Add gpt_api_template, enables data augmentation and other tasks…
gwaksital Nov 21, 2024
e3c7669
refactor: Change prompt configuration
gwaksital Nov 21, 2024
2ce92fe
fix: Make dir checkpoint folder before assigning config.yaml
gwaksital Nov 21, 2024
8456d99
add prompt_templates and modify config
ssunbear Nov 21, 2024
4ffa105
fix: Question 4,5 prompt style unified. This is final code refactor
gwaksital Nov 22, 2024
7820cde
Merge pull request #9 from boostcampaitech7/feat/GPT-API-template
ssunbear Nov 22, 2024
2ccaad7
feat: Add RAG inference code
minjijeong98 Nov 25, 2024
80d838d
feat: Add retrieval evaluation metric for RAG
minjijeong98 Nov 25, 2024
e45f891
fix: Fix testset path
minjijeong98 Nov 25, 2024
9b4bb2e
fix: Change data example for prompt checking
minjijeong98 Nov 25, 2024
fb3b207
Add rag.ipynb for experiment
gwaksital Nov 25, 2024
04a80b9
Download rag_w_eval.ipynb
gwaksital Nov 25, 2024
f690874
Move .ipynbs to notebooks/
gwaksital Nov 25, 2024
70ec368
feat: Add majority ensemble, weighted ensemble
ssunbear Nov 26, 2024
86e372f
GENNLP-32: refactor: Merge SOTA code and integrate FastLanguageModel …
gwaksital Nov 26, 2024
36cd222
GENNLP-32: feat: Integrate RAG functionality
gwaksital Nov 29, 2024
7d3c17c
GENNLP-32: chore: Relocate files
gwaksital Nov 29, 2024
a6b306e
GENNLP-32: chore: Relocate files
gwaksital Nov 29, 2024
cdd4bfd
GENNLP-32: chore: Annotate in English
gwaksital Nov 29, 2024
2b43e52
upload data_processing
wjddms4299 Nov 29, 2024
94c394f
Merge pull request #13 from boostcampaitech7/GENNLP-33-Ensemble
ssunbear Nov 29, 2024
3e7ee15
Merge pull request #14 from boostcampaitech7/feat/GENNLP-24-data-proc…
ssunbear Nov 29, 2024
83ff067
Merge pull request #15 from boostcampaitech7/develop
gwaksital Nov 29, 2024
0ad57b8
feat: demo with streamlit
ocean010315 Nov 29, 2024
26e3947
Merge pull request #17 from boostcampaitech7/feat/GENNLP-16-demo
ssunbear Nov 30, 2024
a8d6010
Revert "[Feat] Streamlit으로 만든 demo 페이지"
ssunbear Dec 1, 2024
06a62da
Merge pull request #18 from boostcampaitech7/revert-17-feat/GENNLP-16…
ssunbear Dec 2, 2024
d93e8fc
feat: demo page
ocean010315 Dec 2, 2024
a50a83f
Merge branch 'main' of https://github.com/boostcampaitech7/level2-nlp…
ocean010315 Dec 2, 2024
4b132af
Update README.md
gwaksital Dec 2, 2024
c5ec054
Update README.md
gwaksital Dec 2, 2024
b8193a2
Update README.md
gwaksital Dec 2, 2024
b8545b0
Update README.md
gwaksital Dec 2, 2024
c66ffc8
Update README.md
gwaksital Dec 2, 2024
9407a3e
Delete notebooks/gpt_api_template.ipynb
gwaksital Dec 2, 2024
dcf8c02
Upload Wrap Up Report
wjddms4299 Dec 2, 2024
7c9a9ef
Update README.md
ssunbear Dec 2, 2024
e5bc325
Upload 발표자료.pdf
wjddms4299 Dec 2, 2024
fc0d078
make a folder
ssunbear Dec 2, 2024
854022f
Add report and presentation works
ssunbear Dec 2, 2024
799dbd8
Delete assets/[5조]Lv2_수능문제풀이_프로젝트_발표자료.pdf
ssunbear Dec 2, 2024
4d7fd13
Delete NLP기초대회_NLP_팀 리포트(05조).pdf
ssunbear Dec 2, 2024
5e9fb1d
Upload image
wjddms4299 Dec 2, 2024
77e6515
Upload image
wjddms4299 Dec 2, 2024
e54cde7
Update README.md
ssunbear Dec 2, 2024
3132dba
Update README.md
gwaksital Dec 2, 2024
0f7ff07
Update README.md
gwaksital Dec 2, 2024
bf0dddc
Update README.md
gwaksital Dec 2, 2024
2ba9348
Update README.md
ssunbear Dec 2, 2024
ebeeb26
Update README.md
gwaksital Dec 2, 2024
492ea8a
Update README.md
gwaksital Dec 2, 2024
eeeb1eb
Update README.md
ssunbear Dec 2, 2024
54a338b
Update README.md
gwaksital Dec 2, 2024
a4d4ad2
Add Demo
gwaksital Dec 2, 2024
c25d0c2
Add timeline image for README
minjijeong98 Dec 2, 2024
f977d10
docs: Add timeline image to README
minjijeong98 Dec 2, 2024
ae8943e
Merge branch 'main' of https://github.com/boostcampaitech7/level2-nlp…
minjijeong98 Dec 3, 2024
2e68c60
feat: Add wikipedia preprocessing
minjijeong98 Dec 3, 2024
9064add
Merge branch 'main' of https://github.com/boostcampaitech7/level2-nlp…
minjijeong98 Dec 3, 2024
eda2f1b
Merge branch 'main' of https://github.com/boostcampaitech7/level2-nlp…
minjijeong98 Dec 3, 2024
6cfaa82
Merge branch 'feat/GENNLP-26-korean-history' of https://github.com/bo…
minjijeong98 Dec 3, 2024
29e354b
Merge branch 'feat/GENNLP-27-openstax' of https://github.com/boostcam…
minjijeong98 Dec 3, 2024
353a21a
feat: Combine preprocessing and crawling codes
minjijeong98 Dec 3, 2024
e4d4b58
Merge pull request #19 from boostcampaitech7/feat/preprocessing
minjijeong98 Dec 3, 2024
87ce95e
Delete duplicated file
minjijeong98 Dec 3, 2024
79dc136
Delete duplicated file
minjijeong98 Dec 3, 2024
f2cbc36
Add preprocessing code information in code structure
minjijeong98 Dec 3, 2024
225ee53
add: translation for openstax vectorstore
jin-jae Dec 4, 2024
503592c
Merge pull request #20 from boostcampaitech7/feat/GENNLP-27-openstax
jin-jae Dec 29, 2024
51111e6
fix: requirements dependency
jin-jae Jan 5, 2025
29 changes: 29 additions & 0 deletions .github/ISSUE_TEMPLATE/bug_report.md
@@ -0,0 +1,29 @@
---
name: Bug report
about: Create a report to help us improve
title: "[BUG]"
labels: bug
assignees: ''

---

## Describe the bug
-

## To Reproduce
-

## Expected behavior
-

## Screenshots
-

## Additional context
-

## Possible Solution
-

## Your Environment
-
18 changes: 18 additions & 0 deletions .github/ISSUE_TEMPLATE/feature_request.md
@@ -0,0 +1,18 @@
---
name: Feature request
about: Suggest an idea for this project
title: "[FEAT]"
labels: enhancement
assignees: ''

---

## Background
-

## Todo
- [ ] Todo 1
- [ ] Todo 2

## See also
- #
12 changes: 12 additions & 0 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,12 @@
## Overview
-

## Change Log
-

## To Reviewer
-

## Issue Tags
- Closed | Fixed: #
- See also: #
7 changes: 7 additions & 0 deletions .gitignore
@@ -0,0 +1,7 @@
__pycache__
*.csv
wandb/*
!wandb/.gitkeep
checkpoints/*
!checkpoints/.gitkeep
!streamlit/assets/*.csv
189 changes: 189 additions & 0 deletions README.md
@@ -0,0 +1,189 @@
<div align='center'>

# 🏆 LV.2 NLP Project: Building a KSAT-Style Question-Answering Model

</div>
<br><br>

## ✏️ Competition Overview

| Item | Description |
|:------:| --- |
| Topic | 'Building a KSAT-style question-answering model', the Level 2 domain competition of the Naver BoostCamp AI Tech 7th cohort NLP Track. |
| Description | The goal is to solve Korean KSAT (College Scholastic Ability Test) Korean language and social studies questions with an AI model and outperform large language models. |
| Data | Questions similar to the KSAT Korean language and social studies sections, including data from KMMLU (Korean History), MMMLU (high school history, economics, politics, etc.), and KLUE MRC (economy, international affairs, society, etc.). |
| Metric | Accuracy: the number of questions the model answers correctly divided by the total number of questions. |
| Deliverables | [WrapUp Report](https://github.com/boostcampaitech7/level2-nlp-generationfornlp-nlp-05-lv3/blob/main/assets/NLP%E1%84%80%E1%85%B5%E1%84%8E%E1%85%A9%E1%84%83%E1%85%A2%E1%84%92%E1%85%AC_NLP_%E1%84%90%E1%85%B5%E1%86%B7%20%E1%84%85%E1%85%B5%E1%84%91%E1%85%A9%E1%84%90%E1%85%B3(05%E1%84%8C%E1%85%A9).pdf), [Presentation Material](https://github.com/boostcampaitech7/level2-nlp-generationfornlp-nlp-05-lv3/blob/main/assets/%5B5%E1%84%8C%E1%85%A9%5DLv2_%E1%84%89%E1%85%AE%E1%84%82%E1%85%B3%E1%86%BC%E1%84%86%E1%85%AE%E1%86%AB%E1%84%8C%E1%85%A6%E1%84%91%E1%85%AE%E1%86%AF%E1%84%8B%E1%85%B5_%E1%84%91%E1%85%B3%E1%84%85%E1%85%A9%E1%84%8C%E1%85%A6%E1%86%A8%E1%84%90%E1%85%B3_%E1%84%87%E1%85%A1%E1%86%AF%E1%84%91%E1%85%AD%E1%84%8C%E1%85%A1%E1%84%85%E1%85%AD.pdf) |

<br><br>

## 🎖️ Leader Board

The project finished 1st on both the public and the private leaderboard.

### 🥇 Public Leader Board (1st place)

![image](https://github.com/user-attachments/assets/778831bc-2ed6-4090-a1a0-49ce38c71bc6)

### 🥇 Private Leader Board (1st place)

![image](https://github.com/user-attachments/assets/8757896e-8e93-4bb2-9798-14bf764259ae)

<br><br>

## 👨‍💻 Team '나야, 자, 연어' Members
<div align='center'>

| 곽희준 [<img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg" width=20 style="vertical-align:middle;" />](https://github.com/gwaksital) | 김정은 [<img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg" width=20 style="vertical-align:middle;" />](https://github.com/wjddms4299) | 김진재 [<img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg" width=20 style="vertical-align:middle;" />](https://github.com/jin-jae) | 오수현 [<img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg" width=20 style="vertical-align:middle;" />](https://github.com/ocean010315) | 윤선웅 [<img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg" width=20 style="vertical-align:middle;" />](https://github.com/ssunbear) | 정민지 [<img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg" width=20 style="vertical-align:middle;" />](https://github.com/minjijeong98)
|:-:|:-:|:-:|:-:|:-:|:-:|
| ![곽희준](https://avatars.githubusercontent.com/u/80732503) | ![김정은](https://avatars.githubusercontent.com/u/121777522) | ![김진재](https://avatars.githubusercontent.com/u/97018331) | ![오수현](https://avatars.githubusercontent.com/u/91974779) | ![윤선웅](https://avatars.githubusercontent.com/u/117508164) | ![정민지](https://avatars.githubusercontent.com/u/162319450) |

</div>

<br><br>

## 👼 Roles

<div align='center'>

| Member | Role |
|:--------:| -------------- |
|곽희준| Dataset labeling, EDA, external dataset exploration, dataset augmentation experiments with GPT, code refactoring, LLM training method design, fine-tuning, RAG pipeline construction, final code cleanup |
|김정은| Dataset labeling, model exploration and fine-tuning, dataset crawling and preprocessing (civil service exam questions, khan), dataset quality testing, RAG system construction and experiments, prompt engineering |
|김진재| Initial team environment setup and dashboard creation, dataset labeling, data exploration, RAG data preprocessing (translation), retrieval construction and experiments (sparse) |
|오수현| Initial baseline code construction, dataset labeling, demo page development |
|윤선웅| Dataset labeling, model exploration and fine-tuning, Unsloth setup, prompt engineering, dataset crawling (civil service exam questions, khan), dataset quality testing, LoRA tuning, ensembling |
|정민지| Dataset labeling, vector store data crawling and preprocessing (OpenStax, Wikipedia, 우리역사넷), retrieval evaluation dataset and metric design, RAG system construction and experiments (chunk size, dense retrieval, reranking) |

</div>

<br><br>

## 🏃 Project Description

### 🖥️ Project Overview

| Item | Description |
|:--------:| --- |
| Topic | Generation for NLP - building a KSAT-style question-answering model |
| Architecture | Fine-tuned LLM foundation model + RAG |
| Metric | Accuracy = correct / total |
| Development environment | `GPU`: 4 Tesla V100 servers, `IDE`: VSCode, Jupyter Notebook |
| Collaboration | Jira & Confluence (progress sharing), GitHub (code and data sharing), Zoom & Slack (real-time communication) |

<br>

### 📅 Project Timeline

- The project ran from 2024-11-11 to 2024-11-28.

![image](./assets/timeline.png)

<br>

### 🕵️ Project Workflow

- The techniques experimented with and applied at each stage are summarized below.

| Stage | Description |
|:--------:| --- |
| Data | EDA, fine-tuning dataset construction (data quality improvement, data augmentation), RAG data construction (vector store data collection and preprocessing, chunking) |
| Modeling | Model selection and tuning, LoRA tuning, prompt tuning |
| RAG | Vector store construction, retriever evaluation dataset construction, retriever parameter tuning, RAFT (Retrieval Augmented Fine-Tuning); a minimal sketch of the retrieval-augmented prompting step follows this table |
| Ensemble | Weighted voting ensemble |
<br>

### 🤖 Ensemble

Using datasets with various refinement and augmentation strategies together with LoRA tuning, we ran inference with `itsmenlp/unsloth_qwen_2.5_32B_bnb_4bit_finetuned`, took the five outputs with the highest accuracy, and combined them with a weighted voting ensemble, reaching a final public accuracy of **0.8341**. A minimal sketch of the voting step follows the table below.

<div align='center'>

| Output | Accuracy | Weight |
|:--------:| --- | --- |
| Top 5 | 0.8180 | 0.1 |
| Top 4 | 0.8180 | 0.1 |
| Top 3 | 0.8203 | 0.2 |
| Top 2 | 0.8272 | 0.2 |
| Top 1 | 0.8295 | 0.4 |

</div>
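
As a rough illustration of the weighted voting step (a sketch only, not the exact code in `src/ensemble.py`): each output votes for its predicted choice with its assigned weight, and the choice with the largest total weight wins.

```python
from collections import defaultdict


def weighted_vote(predictions: list[list[int]], weights: list[float]) -> list[int]:
    """Combine per-output answer predictions by weighted voting (sketch only).

    predictions[i][j] is output i's predicted choice for question j;
    weights[i] is that output's vote weight (e.g. 0.4 for the Top 1 output above).
    """
    final = []
    for j in range(len(predictions[0])):
        scores: dict[int, float] = defaultdict(float)
        for pred, weight in zip(predictions, weights):
            scores[pred[j]] += weight
        # Pick the choice with the largest accumulated weight.
        final.append(max(scores, key=scores.get))
    return final
```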

<br>

### 📃 KSAT Results

Below are the results of solving the Korean language (Speech and Writing), Korean History, and social studies sections of the 2025 KSAT with the sLLM developed in this project.

![image](https://github.com/user-attachments/assets/ca280ffb-8598-4112-81f6-8f5fd04fb4dd)

<br><br>

## 🎥 Demo Video: Solving the 2025 KSAT

https://github.com/user-attachments/assets/4448f058-6571-4037-9fb9-dfd8f86d5291

<br><br>

## 📁 Project Structure

The project folder structure is as follows.

```
level2-nlp-generationfornlp-nlp-05-lv3/
├── checkpoints/                          # Saved model checkpoints
│   └── (experiment_name)/                # Experiment name
│       ├── checkpoint-1111               # Model checkpoints
│       └── checkpoint-2222
├── config/
│   └── config.yaml                       # Configuration file
├── notebooks/
│   ├── eda.ipynb                         # EDA
│   ├── demo_data_preprocessing.ipynb     # Demo data preprocessing
│   └── ft_data_processing.ipynb          # Fine-tuning data preprocessing
├── prompt/
│   └── prompt_templates.yaml             # Prompt templates
├── src/
│   ├── dataset.py                        # Data loading and preprocessing
│   ├── ensemble.py                       # Ensemble methods
│   ├── model.py                          # Model definition and training
│   ├── preprocessing.py                  # Data collection and preprocessing for the vector store
│   ├── retrieval_dense.py                # Dense retrieval
│   ├── retrieval_sparse.py               # Sparse retrieval
│   └── utils.py                          # Helper functions and utilities
├── streamlit/                            # Streamlit web application
├── main.py                               # Main entry point
├── .gitignore
├── README.md
└── requirements.txt

```

<br>

### 💾 Installation and Usage

- OS: Ubuntu 20.04.6 LTS
- Python: 3.11 or later
- Required libraries: see `requirements.txt`
- **GPU**: NVIDIA V100 32GB

```bash
git clone https://github.com/boostcampaitech7/level2-nlp-generationfornlp-nlp-05-lv3.git
pip install -r requirements.txt

python src/retrieval_dense.py  # or: python src/retrieval_sparse.py
python main.py --config {config_path} --mode {train/test}
```
**config.yaml**
See the full file on [GitHub](https://github.com/boostcampaitech7/level2-nlp-generationfornlp-nlp-05-lv3/blob/main/config/config.yaml).

<br>

### 💾 Running the Demo

```bash
cd streamlit
streamlit run home.py
```
Binary file added assets/KSAT.jpg
Binary file not shown.
Binary file added assets/Private Leader Board.png
Binary file added assets/Public Leader Board.png
Binary file added assets/demo_ksat.mp4
Binary file not shown.
Binary file added assets/timeline.png
Empty file added checkpoints/.gitkeep
Empty file.
61 changes: 61 additions & 0 deletions config/config.yaml
@@ -0,0 +1,61 @@
model:
  experiment_name: &experiment_name "Experiment_Sample" # Define common experiment name
  train:
    train_model_name: "unsloth/Qwen2.5-32B-Instruct-bnb-4bit" # Model name for training
    train_csv_path: "data/rag_results/train_rag_rerank3_v2_list.csv" # Path to train CSV file
    train_checkpoint_path: "checkpoints/{experiment_name}" # Path to save training checkpoints
  test:
    test_checkpoint_path: "checkpoints/{experiment_name}/checkpoint-298" # Path to inference checkpoint
    test_csv_path: "data/rag_results/test_rag_rerank3_v2_list.csv" # Path to test CSV file
    test_output_csv_path: "data/outputs/{experiment_name}.csv" # Path for leaderboard submission CSV file

  max_seq_length: 4096 # Maximum sequence length for the model
  prompt_name: "BASE_PROMPT" # Name of the prompt template in the prompt file
  rag: True # Enable retrieval-augmented generation
  uniform_answer_distribution: True # Ensure uniform answer distribution
  train_valid_split: True # If True, split train and validation datasets (0.9/0.1)

seed: 3407 # Seed for reproducibility

FastLanguageModel:
  # model_name -> Set to 'train_model_name'
  # max_seq_length -> Set to 'max_seq_length'
  # dtype -> Hardcoded to None
  # load_in_4bit -> Hardcoded to True

peft:
  # model_name -> Set to 'train_model_name'
  r: 64
  lora_alpha: 32
  lora_dropout: 0
  target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj",]
  bias: "none"
  use_gradient_checkpointing: "unsloth"
  # random_state -> Set to 'seed'
  use_rslora: True
  # loftq_config -> Hardcoded to None

UnslothTrainingArguments:
  # do_train -> Hardcoded to True
  # do_eval -> Automatically set based on 'train_valid_split'
  per_device_train_batch_size: 2
  per_device_eval_batch_size: 2
  gradient_accumulation_steps: 8
  warmup_ratio: 0.1
  num_train_epochs: 2
  learning_rate: 5e-5
  embedding_learning_rate: 1e-6
  # fp16 -> Hardcoded to not is_bfloat16_supported()
  # bf16 -> Hardcoded to is_bfloat16_supported()
  # logging_steps -> Hardcoded to 1
  optim: "adamw_8bit"
  weight_decay: 0.01
  lr_scheduler_type: "linear"
  # seed -> Set to 'seed'
  # max_seq_length -> Set to 'max_seq_length'
  # output_dir -> Set to 'train_checkpoint_path'
  save_strategy: "epoch"
  # eval_strategy: "no" -> Automatically set based on 'do_eval'
  save_total_limit: 2
  save_only_model: True
  # report_to -> Hardcoded to 'wandb'
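
`main.py` calls `update_paths(config)` (imported from `src/utils.py`) right after loading this file, which presumably fills in the `{experiment_name}` placeholders in the path fields above. That implementation is not part of this diff; the following is only a sketch of what such a substitution could look like, assuming the nesting shown in `config.yaml`.

```python
def update_paths(config: dict) -> dict:
    """Resolve '{experiment_name}' placeholders in path fields (sketch only;
    the real function lives in src/utils.py and may differ)."""
    name = config["model"]["experiment_name"]
    for section in ("train", "test"):
        for key, value in config["model"][section].items():
            if isinstance(value, str):
                config["model"][section][key] = value.format(experiment_name=name)
    return config
```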
58 changes: 58 additions & 0 deletions main.py
@@ -0,0 +1,58 @@
import os
import yaml
import argparse
import shutil
import pandas as pd

from src.model import MyModel
from src.dataset import MyDataset
from src.utils import set_seed, reset_token, update_paths


if __name__ == "__main__":
    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", "-c", type=str, default="config/config.yaml")
    parser.add_argument("--mode", "-m", type=str, default="train")
    args = parser.parse_args()

    # Load YAML configuration file
    with open(args.config) as f:
        config = yaml.full_load(f)

    # Update paths based on the experiment name
    config = update_paths(config)

    # Set random seed for reproducibility
    set_seed(config["seed"])

    # Initialize dataset and model
    dataset = MyDataset(config["model"])
    model = MyModel(config, args.mode)

    base_path = "../contest_baseline_code"

    if args.mode == "train":
        # Training mode
        checkpoint_dir = config["model"]["train"]["train_checkpoint_path"]
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Save configuration file in the checkpoint directory
        shutil.copy(args.config, os.path.join(checkpoint_dir, "config.yaml"))

        # Process training data and train the model
        train_df = pd.read_csv(os.path.join(base_path, config["model"]["train"]["train_csv_path"]))
        processed_train = dataset.process(train_df, "train")
        model.train(processed_train)

        # Reset tokenizer token configurations
        reset_token(config["model"]["experiment_name"])

    elif args.mode == "test":
        # Testing mode
        test_df = pd.read_csv(os.path.join(base_path, config["model"]["test"]["test_csv_path"]))
        processed_test = dataset.process(test_df, "test")

        # Run inference and save results
        model.inference(
            processed_test,
            output_dir=os.path.join(base_path, config["model"]["test"]["test_output_csv_path"]),
        )
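
The `uniform_answer_distribution: True` flag in `config.yaml` (and the commit "feat: Add function to create uniform answer distribution") suggests that the dataset-processing step rebalances the gold answers so they appear roughly evenly across choice positions, avoiding positional bias during fine-tuning. The actual function is not shown in this diff; the sketch below, with a hypothetical `uniformize_answers` helper assuming five choices and 1-based answers, only illustrates the idea.

```python
import random


def uniformize_answers(rows: list[dict], n_choices: int = 5, seed: int = 3407) -> list[dict]:
    """Shuffle each question's choices so gold answers spread evenly across
    positions 1..n_choices (hypothetical helper; the project's real logic may differ).

    Each row is assumed to look like {"choices": [...], "answer": <1-based int>}.
    """
    rng = random.Random(seed)
    out = []
    for i, row in enumerate(rows):
        target = (i % n_choices) + 1           # desired 1-based answer position
        choices = list(row["choices"])
        gold = choices.pop(row["answer"] - 1)  # take out the gold choice
        rng.shuffle(choices)                   # shuffle the distractors
        choices.insert(target - 1, gold)       # re-insert gold at the target slot
        out.append({**row, "choices": choices, "answer": target})
    return out
```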