Skip to content

Commit

Permalink
Fix s3 image loading memory leak (#910)
Browse files Browse the repository at this point in the history
  • Loading branch information
InfiniteStash authored Jan 12, 2025
1 parent 7a97232 commit c9fc7ec
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 25 deletions.
1 change: 1 addition & 0 deletions pkg/api/routes_image.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func (rs imageRoutes) image(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusNotFound)
return
}
defer reader.Close()

if databaseImage.Width == -1 {
w.Header().Add("Content-Type", "image/svg+xml")
Expand Down
12 changes: 3 additions & 9 deletions pkg/image/file.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package image

import (
"bytes"
"io"
"os"

Expand All @@ -12,7 +11,7 @@ import (

type FileBackend struct{}

func (s *FileBackend) WriteFile(file *bytes.Reader, image *models.Image) error {
func (s *FileBackend) WriteFile(file []byte, image *models.Image) error {
if err := config.ValidateImageLocation(); err != nil {
return err
}
Expand All @@ -26,14 +25,9 @@ func (s *FileBackend) WriteFile(file *bytes.Reader, image *models.Image) error {
return nil
}

buf := new(bytes.Buffer)
if _, err := buf.ReadFrom(file); err != nil {
return err
}

// write the file
path := GetImagePath(fileDir, image.Checksum)
if err := os.WriteFile(path, buf.Bytes(), os.FileMode(0644)); err != nil {
if err := os.WriteFile(path, file, os.FileMode(0644)); err != nil {
_ = os.Remove(path)
return err
}
Expand All @@ -45,7 +39,7 @@ func (s *FileBackend) DestroyFile(image *models.Image) error {
return os.Remove(GetImagePath(config.GetImageLocation(), image.Checksum))
}

func (s *FileBackend) ReadFile(image models.Image) (io.Reader, error) {
func (s *FileBackend) ReadFile(image models.Image) (io.ReadCloser, error) {
fileDir := config.GetImageLocation()
path := GetImagePath(fileDir, image.Checksum)
return os.Open(path)
Expand Down
5 changes: 2 additions & 3 deletions pkg/image/image_backend.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
package image

import (
"bytes"
"io"

"github.com/stashapp/stash-box/pkg/models"
)

type Backend interface {
WriteFile(file *bytes.Reader, image *models.Image) error
WriteFile(file []byte, image *models.Image) error
DestroyFile(image *models.Image) error
ReadFile(image models.Image) (io.Reader, error)
ReadFile(image models.Image) (io.ReadCloser, error)
}
2 changes: 1 addition & 1 deletion pkg/image/image_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type BackendService interface {
Destroy(input models.ImageDestroyInput) error
DestroyUnusedImages() error
DestroyUnusedImage(imageID uuid.UUID) error
Read(image models.Image) (io.Reader, error)
Read(image models.Image) (io.ReadCloser, error)
}

func GetService(repo models.ImageRepo) BackendService {
Expand Down
10 changes: 3 additions & 7 deletions pkg/image/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

type S3Backend struct{}

func (s *S3Backend) WriteFile(file *bytes.Reader, image *models.Image) error {
func (s *S3Backend) WriteFile(file []byte, image *models.Image) error {
s3config := config.GetS3Config()
headers := s3config.UploadHeaders

Expand All @@ -27,11 +27,7 @@ func (s *S3Backend) WriteFile(file *bytes.Reader, image *models.Image) error {
return fmt.Errorf("creating minio client: %w", err)
}

buf := new(bytes.Buffer)
if _, err = buf.ReadFrom(file); err != nil {
return fmt.Errorf("reading from file: %w", err)
}
if err := uploadS3File(*minioClient, buf.Bytes(), s3config.Bucket, image.ID.String(), headers); err != nil {
if err := uploadS3File(*minioClient, file, s3config.Bucket, image.ID.String(), headers); err != nil {
return fmt.Errorf("uploading to s3: %w", err)
}

Expand Down Expand Up @@ -84,7 +80,7 @@ func uploadS3File(client minio.Client, file []byte, bucket string, id string, he
return err
}

func (s *S3Backend) ReadFile(image models.Image) (io.Reader, error) {
func (s *S3Backend) ReadFile(image models.Image) (io.ReadCloser, error) {
ctx := context.TODO()

s3config := config.GetS3Config()
Expand Down
7 changes: 2 additions & 5 deletions pkg/image/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,7 @@ func (s *Service) Create(input models.ImageCreateInput) (*models.Image, error) {
return nil, err
}

if _, err = fileReader.Seek(0, 0); err != nil {
return nil, err
}
if err := s.Backend.WriteFile(fileReader, &newImage); err != nil {
if err := s.Backend.WriteFile(file, &newImage); err != nil {
return nil, err
}
} else if input.URL != nil {
Expand Down Expand Up @@ -169,6 +166,6 @@ func (s *Service) DestroyUnusedImage(imageID uuid.UUID) error {
return nil
}

func (s *Service) Read(image models.Image) (io.Reader, error) {
func (s *Service) Read(image models.Image) (io.ReadCloser, error) {
return s.Backend.ReadFile(image)
}

0 comments on commit c9fc7ec

Please sign in to comment.