Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions internal/recognition/classifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,11 @@ func (c *HuggingFaceClassifier) ClassifyFromBytes(imageData []byte) (*Classifica
startTime := time.Now()
log.Infof("Starting image classification using HuggingFace model %s", c.Model)

// Encode image to base64
encoded := base64.StdEncoding.EncodeToString(imageData)
log.Debugf("Encoded image size: %d bytes (base64: %d chars)", len(imageData), len(encoded))
log.Debugf("Image size: %d bytes", len(imageData))

// Detect image format for correct Content-Type header
contentType := detectImageFormat(imageData)
log.Debugf("Detected image format: %s", contentType)

// Build the API URL - HuggingFace Serverless Inference API
// Note: Some models may not be available on the free Serverless tier
Expand All @@ -410,7 +412,7 @@ func (c *HuggingFaceClassifier) ClassifyFromBytes(imageData []byte) (*Classifica
}

req.Header.Set("Authorization", "Bearer "+c.APIKey)
req.Header.Set("Content-Type", "application/octet-stream")
req.Header.Set("Content-Type", contentType)

// Make API request
log.Infof("Sending classification request to HuggingFace API...")
Expand Down Expand Up @@ -566,13 +568,16 @@ func (c *HuggingFaceClassifier) detectNSFWHF(imageData []byte) (float64, string,
nsfwModel := "Falconsai/nsfw_image_detection"
apiURL := fmt.Sprintf("https://api-inference.huggingface.co/models/%s", nsfwModel)

// Detect image format for correct Content-Type header
contentType := detectImageFormat(imageData)

req, err := http.NewRequest("POST", apiURL, bytes.NewReader(imageData))
if err != nil {
return 0, "", fmt.Errorf("failed to create NSFW request: %w", err)
}

req.Header.Set("Authorization", "Bearer "+c.APIKey)
req.Header.Set("Content-Type", "application/octet-stream")
req.Header.Set("Content-Type", contentType)

client := &http.Client{Timeout: c.Timeout}
resp, err := client.Do(req)
Expand Down Expand Up @@ -627,6 +632,33 @@ func (c *HuggingFaceClassifier) detectNSFWHF(imageData []byte) (float64, string,

// Helper functions

// detectImageFormat detects the image format from the binary data
func detectImageFormat(data []byte) string {
if len(data) < 12 {
return "image/jpeg" // Default fallback
}

// Check magic bytes for different image formats
switch {
case data[0] == 0xFF && data[1] == 0xD8 && data[2] == 0xFF:
return "image/jpeg"
case data[0] == 0x89 && data[1] == 0x50 && data[2] == 0x4E && data[3] == 0x47:
return "image/png"
case data[0] == 0x47 && data[1] == 0x49 && data[2] == 0x46:
return "image/gif"
case data[0] == 0x52 && data[1] == 0x49 && data[2] == 0x46 && data[3] == 0x46 &&
data[8] == 0x57 && data[9] == 0x45 && data[10] == 0x42 && data[11] == 0x50:
return "image/webp"
case data[0] == 0x42 && data[1] == 0x4D:
return "image/bmp"
case (data[0] == 0x49 && data[1] == 0x49 && data[2] == 0x2A && data[3] == 0x00) ||
(data[0] == 0x4D && data[1] == 0x4D && data[2] == 0x00 && data[3] == 0x2A):
return "image/tiff"
default:
return "image/jpeg" // Default fallback
}
}

func isLikelyLabel(word string) bool {
// Common objects, simple heuristic
commonWords := map[string]bool{
Expand Down
Loading