diff --git a/internal/recognition/classifier.go b/internal/recognition/classifier.go index e448bf6..69ad659 100644 --- a/internal/recognition/classifier.go +++ b/internal/recognition/classifier.go @@ -395,9 +395,11 @@ func (c *HuggingFaceClassifier) ClassifyFromBytes(imageData []byte) (*Classifica startTime := time.Now() log.Infof("Starting image classification using HuggingFace model %s", c.Model) - // Encode image to base64 - encoded := base64.StdEncoding.EncodeToString(imageData) - log.Debugf("Encoded image size: %d bytes (base64: %d chars)", len(imageData), len(encoded)) + log.Debugf("Image size: %d bytes", len(imageData)) + + // Detect image format for correct Content-Type header + contentType := detectImageFormat(imageData) + log.Debugf("Detected image format: %s", contentType) // Build the API URL - HuggingFace Serverless Inference API // Note: Some models may not be available on the free Serverless tier @@ -410,7 +412,7 @@ func (c *HuggingFaceClassifier) ClassifyFromBytes(imageData []byte) (*Classifica } req.Header.Set("Authorization", "Bearer "+c.APIKey) - req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("Content-Type", contentType) // Make API request log.Infof("Sending classification request to HuggingFace API...") @@ -566,13 +568,16 @@ func (c *HuggingFaceClassifier) detectNSFWHF(imageData []byte) (float64, string, nsfwModel := "Falconsai/nsfw_image_detection" apiURL := fmt.Sprintf("https://api-inference.huggingface.co/models/%s", nsfwModel) + // Detect image format for correct Content-Type header + contentType := detectImageFormat(imageData) + req, err := http.NewRequest("POST", apiURL, bytes.NewReader(imageData)) if err != nil { return 0, "", fmt.Errorf("failed to create NSFW request: %w", err) } req.Header.Set("Authorization", "Bearer "+c.APIKey) - req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("Content-Type", contentType) client := &http.Client{Timeout: c.Timeout} resp, err := client.Do(req) @@ 
// Helper functions

// detectImageFormat inspects the leading magic bytes of data and returns the
// MIME type of the image format they identify.
//
// Recognized signatures, checked in this order: JPEG, PNG, GIF, WebP (a RIFF
// container whose form type at bytes 8-11 is "WEBP"), BMP, and TIFF in either
// byte order. Any other input, including nil or empty data, yields the
// fallback "image/jpeg".
//
// Each signature is length-checked on its own, so an input that is shorter
// than the longest signature (12 bytes, WebP) is still classified by
// whatever prefix it does carry instead of being forced to the fallback.
func detectImageFormat(data []byte) string {
	const fallback = "image/jpeg"

	// startsWith reports whether data begins with the raw bytes of sig.
	// The string conversion in a comparison does not copy data.
	startsWith := func(sig string) bool {
		return len(data) >= len(sig) && string(data[:len(sig)]) == sig
	}

	switch {
	case startsWith("\xFF\xD8\xFF"):
		return "image/jpeg"
	case startsWith("\x89PNG"):
		return "image/png"
	case startsWith("GIF"):
		return "image/gif"
	case startsWith("RIFF") && len(data) >= 12 && string(data[8:12]) == "WEBP":
		// RIFF is shared with other containers (WAV, AVI); only the form
		// type at offset 8 distinguishes WebP.
		return "image/webp"
	case startsWith("BM"):
		return "image/bmp"
	case startsWith("II*\x00"), startsWith("MM\x00*"):
		// Little-endian ("II") and big-endian ("MM") TIFF headers.
		return "image/tiff"
	default:
		return fallback
	}
}