diff --git a/pkg/files/files.go b/pkg/files/files.go index 383c2ccd4..9b777f2dd 100644 --- a/pkg/files/files.go +++ b/pkg/files/files.go @@ -85,12 +85,12 @@ func (f *File) LoadFileMetaByID() (err error) { } // Create creates a new file from an FileHeader -func Create(f io.Reader, realname string, realsize uint64, a web.Auth) (file *File, err error) { +func Create(f io.ReadSeeker, realname string, realsize uint64, a web.Auth) (file *File, err error) { return CreateWithMime(f, realname, realsize, a, "") } // CreateWithMime creates a new file from an FileHeader and sets its mime type -func CreateWithMime(f io.Reader, realname string, realsize uint64, a web.Auth, mime string) (file *File, err error) { +func CreateWithMime(f io.ReadSeeker, realname string, realsize uint64, a web.Auth, mime string) (file *File, err error) { s := db.NewSession() defer s.Close() @@ -102,7 +102,7 @@ func CreateWithMime(f io.Reader, realname string, realsize uint64, a web.Auth, m return } -func CreateWithMimeAndSession(s *xorm.Session, f io.Reader, realname string, realsize uint64, a web.Auth, mime string, checkFileSizeLimit bool) (file *File, err error) { +func CreateWithMimeAndSession(s *xorm.Session, f io.ReadSeeker, realname string, realsize uint64, a web.Auth, mime string, checkFileSizeLimit bool) (file *File, err error) { if realsize > config.GetMaxFileSizeInMBytes()*uint64(datasize.MB) && checkFileSizeLimit { return nil, ErrFileIsTooLarge{Size: realsize} } @@ -154,23 +154,25 @@ func (f *File) Delete(s *xorm.Session) (err error) { } // writeToStorage writes content to the given path, handling both local and S3 backends -func writeToStorage(path string, content io.Reader, size uint64) error { +func writeToStorage(path string, content io.ReadSeeker, size uint64) error { if s3Client == nil { return afs.WriteReader(path, content) } - body, contentLength, cleanup, err := prepareS3UploadBody(content, size) + contentLength, err := contentLengthFromReadSeeker(content, size) if err != nil { - return err + return fmt.Errorf("failed to determine S3 upload content length: %w", err) } - if cleanup != nil { - defer cleanup() + + _, err = content.Seek(0, io.SeekStart) + if err != nil { + return fmt.Errorf("failed to seek S3 upload body to start: %w", err) } _, err = s3Client.PutObject(context.Background(), &s3.PutObjectInput{ Bucket: aws.String(s3Bucket), Key: aws.String(path), - Body: body, + Body: content, ContentLength: aws.Int64(contentLength), }) if err != nil { @@ -180,7 +182,7 @@ func writeToStorage(path string, content io.Reader, size uint64) error { } // Save saves a file to storage -func (f *File) Save(fcontent io.Reader) error { +func (f *File) Save(fcontent io.ReadSeeker) error { err := writeToStorage(f.getAbsoluteFilePath(), fcontent, f.Size) if err != nil { return fmt.Errorf("failed to save file: %w", err) @@ -188,54 +190,6 @@ func (f *File) Save(fcontent io.Reader) error { return keyvalue.IncrBy(metrics.FilesCountKey, 1) } -func prepareS3UploadBody(fcontent io.Reader, expectedSize uint64) (body io.ReadSeeker, contentLength int64, cleanup func(), err error) { - if seeker, ok := fcontent.(io.ReadSeeker); ok { - contentLength, err = contentLengthFromReadSeeker(seeker, expectedSize) - if err != nil { - return nil, 0, nil, fmt.Errorf("failed to determine S3 upload content length: %w", err) - } - - _, err = seeker.Seek(0, io.SeekStart) - if err != nil { - return nil, 0, nil, fmt.Errorf("failed to seek S3 upload body to start: %w", err) - } - - return seeker, contentLength, nil, nil - } - - tempFile, tempPath, err := createS3TempFile() - if err != nil { - return nil, 0, nil, fmt.Errorf("failed to create temp file for S3 upload: %w", err) - } - - cleanup = func() { - _ = tempFile.Close() - _ = os.Remove(tempPath) - } - - written, err := io.Copy(tempFile, fcontent) - if err != nil { - cleanup() - return nil, 0, nil, fmt.Errorf("failed to buffer S3 upload to temp file: %w", err) - } - - if expectedSize > 0 { - if expectedSize > uint64(math.MaxInt64) { - log.Warningf("File size mismatch for S3 upload: expected size %d bytes does not fit into int64", expectedSize) - } else if written != int64(expectedSize) { - log.Warningf("File size mismatch for S3 upload: expected %d bytes but buffered %d bytes", expectedSize, written) - } - } - - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - cleanup() - return nil, 0, nil, fmt.Errorf("failed to seek temp file for S3 upload: %w", err) - } - - return tempFile, written, cleanup, nil -} - func contentLengthFromReadSeeker(seeker io.ReadSeeker, expectedSize uint64) (int64, error) { currentOffset, err := seeker.Seek(0, io.SeekCurrent) if err != nil { @@ -258,30 +212,3 @@ func contentLengthFromReadSeeker(seeker io.ReadSeeker, expectedSize uint64) (int return endOffset, nil } - -func createS3TempFile() (file *os.File, path string, err error) { - dir := config.FilesS3TempDir.GetString() - - tryCreate := func(tempDir string) (*os.File, error) { - return os.CreateTemp(tempDir, "vikunja-s3-upload-*") - } - - if dir != "" { - file, err = tryCreate(dir) - if err == nil { - return file, file.Name(), nil - } - } - - file, err = tryCreate("") - if err == nil { - return file, file.Name(), nil - } - - file, err = tryCreate(".") - if err != nil { - return nil, "", err - } - - return file, file.Name(), nil -}