-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
314 lines (251 loc) · 12.1 KB
/
app.py
File metadata and controls
314 lines (251 loc) · 12.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
from flask import Flask, request, jsonify
from flask_cors import CORS
import io
from PIL import Image
import torch
from watermark_anything.data.metrics import msg_predict_inference
import os, re
from torchvision import transforms
from datetime import datetime
import base64
import requests
import numpy as np, random
from notebooks.inference_utils import (
load_model_from_checkpoint,
create_random_mask,
unnormalize_img,
)
# PyTorch 스레드 수를 4로 강제 설정
torch.set_num_threads(4)
num_threads = torch.get_num_threads()
print(f"현재 PyTorch 기본 스레드 수: {num_threads}")
cpu_count = os.cpu_count()
print(f"CPU 코어 수: {cpu_count}")
# 정규화 파라미터 (ImageNet 기준)
image_mean = torch.tensor([0.485, 0.456, 0.406])
image_std = torch.tensor([0.229, 0.224, 0.225])
# 원본 크기 유지 + 정규화만 적용하는 transform
default_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=image_mean, std=image_std),
])
# SEED 고정
SEED = 42
torch.manual_seed(SEED)
torch.use_deterministic_algorithms(True)
np.random.seed(SEED)
random.seed(SEED)
os.environ['PYTHONHASHSEED'] = str(SEED)
# base64 변환 함수
def pil_to_base64(pil_img, fmt="PNG") -> str:
buf = io.BytesIO()
pil_img.save(buf, format=fmt)
buf.seek(0)
return base64.b64encode(buf.getvalue()).decode("utf-8")
# 이미지 전송 진행률 함수
spring_ip = os.getenv('SPRING_SERVER_IP')
SPRING_SERVER_URL = f'http://{spring_ip}:8080/progress'
def send_progress_to_spring(task_id, percent, login_id):
try:
payload = {
'taskId': task_id,
'progress': percent,
'loginId': login_id
}
headers = {
'Content-Type': 'application/json'
}
print(f"Flask에서 Spring으로 POST 요청 보내는 중: {payload}", flush=True)
requests.post(SPRING_SERVER_URL, json=payload, headers=headers, timeout=1)
except Exception as e:
print(f"[WARN] 진행률 전송 실패: {e}")
class ProgressSender:
def __init__(self, task_id, login_id):
self.task_id = task_id
self.login_id = login_id
def send(self, percent):
send_progress_to_spring(self.task_id, percent, self.login_id)
# 파일명(한글 포함)
def safe_filename(filename: str) -> str:
# 확장자 분리
name, ext = os.path.splitext(filename)
# 한글, 영문, 숫자, 일부 특수문자만 허용 → 나머지는 제거
name = re.sub(r'[^가-힣a-zA-Z0-9_\- ]', '', name)
# 공백을 _ 로 변환
name = name.replace(" ", "_")
return name + ext
# 비트를 문자열로 반환
def bits_to_message(bit_tensor, num_chars=4):
"""
32개의 예측 비트 텐서를 4글자 문자열로 변환합니다.
"""
# 텐서를 CPU로 이동하고 리스트로 변환 (0 또는 1)
bits = bit_tensor[0].cpu().tolist()[:32]
# 8비트씩 묶어 문자로 변환
message_chars = []
for i in range(0, len(bits), 8):
byte_bits = bits[i:i+8]
# 이진수 문자열로 변환
bit_string = "".join(map(str, [int(round(b)) for b in byte_bits]))
try:
# 이진수 문자열을 정수로, 다시 문자로 변환
char_code = int(bit_string, 2)
char = chr(char_code)
message_chars.append(char)
except ValueError:
# 변환 오류 시 안전을 위해 처리 (ex: 유효하지 않은 ASCII 코드)
message_chars.append('?')
# \x00 (null) 문자는 제거하여 원래 메시지만 반환
return "".join(message_chars).replace('\x00', '')
app = Flask(__name__)
CORS(app, origins="*")
# 모델 준비 (서버 시작 시 1회만)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt_path = "checkpoints/wam_mit.pth"
json_path = "checkpoints/params.json"
wam = load_model_from_checkpoint(json_path, ckpt_path).to(device).eval()
# =========================================================
# 💡 디버그 코드 추가 지점: 가중치 로드 성공 여부 확인
# =========================================================
try:
# 모델의 첫 번째 컨볼루션 레이어의 가중치 값을 출력합니다.
# 모델 구조에 맞게 'encoder.conv1.weight'를 수정해야 할 수 있습니다.
first_layer_weights = wam.state_dict()['embedder.encoder.conv_in.weight']
print(f"[MODEL DEBUG] WAM Layer Shape: {first_layer_weights.shape}")
print(f"[MODEL DEBUG] WAM First 5 Weights: {first_layer_weights.flatten()[:5].tolist()}", flush=True)
except KeyError:
print("[MODEL DEBUG] WARNING: Cannot find 'embedder.encoder.conv_in.weight'. Check layer name.", flush=True)
except Exception as e:
# 이 로그가 찍힌다면, 가중치 로드 자체가 실패했을 가능성이 매우 높습니다.
print(f"[MODEL DEBUG] CRITICAL ERROR: Failed to read WAM state_dict: {e}", flush=True)
@app.route('/', methods=['GET'])
def home():
return "서버 구동 완료~"
# 워터마크 삽입
@app.route('/watermark-insert', methods=['POST'])
def watermarkInsert():
task_id = request.form.get('taskId') # taskId 받아오기
login_id = request.form.get('loginId') # loginId 받아오기
send_progress = ProgressSender(task_id, login_id)
# 1. 이미지와 메시지 받기
image_file = request.files.get('image')
message = request.form.get('message', 'ETNL')
assert len(message) <= 4, "메시지는 4자 이하만 가능"
if not image_file or not message:
return jsonify({"error": "image, message 둘 다 필요합니다."}), 400
# 작업 진행 상태 초기화
send_progress.send(0)
# 2. 이미지 로드 및 전처리
image = Image.open(image_file.stream).convert("RGB")
img_pt = default_transform(image).unsqueeze(0).to(device)
# 3. 메시지 전처리
wm_bits = ''.join(f"{ord(c):08b}" for c in message)
wm_bits = wm_bits.ljust(32, '0')[:32]
wm_msg = torch.tensor([[int(bit) for bit in wm_bits]], dtype=torch.float32).to(device)
# 진행 상태 25%로 업데이트
send_progress.send(25)
# 3. 워터마크 삽입
outputs = wam.embed(img_pt, wm_msg)
mask = create_random_mask(img_pt, num_masks=1, mask_percentage=0.5)
img_w = outputs['imgs_w'] * mask + img_pt * (1 - mask)
# 진행 상태 50%로 업데이트
send_progress.send(50)
# 4. 이미지 후처리
out_img = unnormalize_img(img_w).squeeze(0).detach().clamp_(0, 1) # 1. 정규화 해제 + 값 범위 제한 (0~1)
out_img_np = out_img.permute(1, 2, 0).cpu().numpy() # 2. CPU로 이동 후 numpy 변환 (HWC 형태)
out_img_np = (out_img_np * 255).round().astype('uint8') # 3. 0~255 범위로 변환 (소수점 처리 개선)
out_img_pil = Image.fromarray(out_img_np) # 4. PIL 이미지 생성
# 진행 상태 75%로 업데이트
send_progress.send(75)
# 파일명 처리
original_name = os.path.splitext(safe_filename(image_file.filename))[0] # example.jpg → example
ext = os.path.splitext(safe_filename(image_file.filename))[1] # 확장자 (jpg, png 등)
watermarked_name = f"{original_name}_deeptruth_watermark{ext}" # 파일명 (확장자 포함)
# 이미지 전송 완료
send_progress.send(100)
response = jsonify({
'image_base64': pil_to_base64(out_img_pil), # 삽입 이미지
'message': message, # 워터마크 메세지
'filename': watermarked_name, # 다운로드 시 사용 될 파일 이름
'taskId': task_id
})
return response
# 워터마크 탐지
@app.route('/watermark-detection', methods=['POST'])
def watermarkDetection():
try:
task_id = request.form.get('taskId') # taskId 받아오기
login_id = request.form.get('loginId') # loginId 받아오기
send_progress = ProgressSender(task_id, login_id)
# 1. 이미지 수신 및 기본 정보 추출
image_file = request.files.get('image')
message = request.form.get('message', '') # 삽입 당시 메시지 (db에서 가져오는 값)
if not image_file or not message:
return jsonify({"error": "image, message 둘 다 필요합니다."}), 400
# 작업 진행 상태 초기화
send_progress.send(0)
# 2. 이미지 전처리
image = Image.open(image_file.stream).convert("RGB")
img_pt = default_transform(image).unsqueeze(0).to(device)
# 진행 상태 25%로 업데이트
send_progress.send(25)
# 3. 워터마크 탐지 (모델 추론)
with torch.no_grad():
detect_outputs = wam.detect(img_pt)
preds = detect_outputs['preds'] # shape: [B, 1+nbits, H, W]
mask_preds = preds[:, 0:1, :, :] # 예측된 마스크
bit_preds = preds[:, 1:, :, :] # 예측된 메시지 비트
# 4. 예측된 비트로부터 메시지 추출
pred_message = msg_predict_inference(bit_preds, mask_preds)
pred_message_float = pred_message.float() # float32로
# 📌 예측된 메시지를 문자열로 변환
predicted_message_str = bits_to_message(pred_message)
# 📌 [ACCURACY DEBUG] 예측 메시지 로그
print(f"[ACCURACY DEBUG] 4. 예측 메시지 (pred_message) shape: {pred_message.shape}, device: {pred_message.device}")
print(f"[ACCURACY DEBUG] 예측 비트(첫 8개): {pred_message[0, :8].tolist()}")
print(f"[PREDICTION RESULT] 원본 메시지: {message}")
print(f"[PREDICTION RESULT] 예측된 메시지: '{predicted_message_str}'", flush=True) # 로그 추가
# 진행 상태 50%로 업데이트
send_progress.send(50)
# 5. 원본 메시지 텐서 변환
wm_bits = ''.join(f"{ord(c):08b}" for c in message.ljust(4, '\x00'))[:32]
wm_tensor = torch.tensor([int(b) for b in wm_bits], dtype=torch.float32).to(device)
# 📌 [ACCURACY DEBUG] 원본 메시지 로그
print(f"[ACCURACY DEBUG] 5. 원본 메시지: '{message}' -> 비트 문자열 길이: {len(wm_bits)}")
print(f"[ACCURACY DEBUG] 원본 비트(wm_tensor) shape: {wm_tensor.shape}, device: {wm_tensor.device}")
print(f"[ACCURACY DEBUG] 원본 비트(첫 8개): {wm_tensor[:8].tolist()}", flush=True)
# comparison_tensor = (pred_message == wm_tensor.unsqueeze(0)).float()
comparison_tensor = (pred_message_float == wm_tensor.unsqueeze(0)).float()
# 📌 [ACCURACY DEBUG] 비교 로그
num_correct_bits = comparison_tensor.sum().item()
print(f"[ACCURACY DEBUG] 일치하는 비트 수: {num_correct_bits} / 32", flush=True)
# 6. 비트 정확도 계산
# bit_acc = (pred_message == wm_tensor.unsqueeze(0)).float().mean().item()
bit_acc = (pred_message_float == wm_tensor.unsqueeze(0)).float().mean().item()
bit_acc_pct = round(bit_acc * 100, 1)
# 📌 [ACCURACY DEBUG] 최종 정확도 로그
print(f"[ACCURACY DEBUG] 최종 비트 정확도 (bit_acc): {bit_acc_pct}%", flush=True)
# 진행 상태 75%로 업데이트
send_progress.send(75)
# 10. 응답
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
original_name = os.path.splitext(safe_filename(image_file.filename))[0] # example.jpg → example
ext = os.path.splitext(safe_filename(image_file.filename))[1]
base_name = f"{original_name}{ext}"
# 이미지 전송 완료
send_progress.send(100)
# 기본 결과값 (정확도 90이상 시)
result = {
"basename": base_name,
"bit_accuracy": bit_acc_pct,
"detected_at": timestamp,
'taskId': task_id
}
# 정확도 < 90이면 삽입 이미지 포함
if result['bit_accuracy'] < 90:
result['image_base64'] = pil_to_base64(image) # 삽입 이미지
return jsonify(result)
except Exception as e:
return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000, debug=True)