package file_manager

import (
	"bytes"
	"context"
	"crypto/md5"
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"io"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/credentials"
	"github.com/aws/aws-sdk-go-v2/service/s3"
)

// Keys of the user metadata entries attached to every stored object.
//
// NOTE(review): these strings are passed as-is in the SDK Metadata map. The
// SDK presumably adds its own "x-amz-meta-" prefix on the wire, which would
// make the prefix here redundant — confirm before changing, because objects
// already stored are keyed by these exact values.
const (
	s3MetadataKeyFilename   = "x-amz-meta-filename"
	s3MetadataKeySize       = "x-amz-meta-size"
	s3MetadataKeySha256     = "x-amz-meta-sha256"
	s3MetadataKeyCreateTime = "x-amz-meta-create-time"
	s3MetadataKeyComplete   = "x-amz-meta-complete"
	s3MetadataKeyCode       = "x-amz-meta-code"
)

// S3FileManager stores files as objects in a single S3 bucket. Each object is
// keyed by a random code and carries its bookkeeping (filename, declared size,
// digest, creation time, completion flag) as object metadata.
type S3FileManager struct {
	client  *s3.Client
	bucket  string
	timeout time.Duration // maximum age of an object that is not marked complete
	expire  time.Duration // maximum age of an object that is marked complete
	ctx     context.Context
	cancel  context.CancelFunc
	mu      sync.RWMutex // not used by any method in this file
}

// NewS3FileManager builds an S3FileManager from opts.
//
// Region, custom endpoint and static credentials are applied only when they
// are set in opts.s3. The function panics when the AWS configuration cannot be
// loaded, because its signature offers no way to report the error.
// It does not start the cleaner loop.
func NewS3FileManager(opts *option) *S3FileManager {
	cfg := opts.s3

	var loadOpts []func(*config.LoadOptions) error
	if cfg.region != "" {
		loadOpts = append(loadOpts, config.WithRegion(cfg.region))
	}
	if cfg.endpoint != "" {
		// Route every service call to the configured endpoint, signing with
		// whatever region the SDK resolved.
		resolver := aws.EndpointResolverWithOptionsFunc(
			func(service, region string, options ...interface{}) (aws.Endpoint, error) {
				return aws.Endpoint{
					PartitionID:   "aws",
					URL:           cfg.endpoint,
					SigningRegion: region,
				}, nil
			},
		)
		loadOpts = append(loadOpts, config.WithEndpointResolverWithOptions(resolver))
	}
	if cfg.accessKey != "" && cfg.secretKey != "" {
		loadOpts = append(loadOpts, config.WithCredentialsProvider(
			credentials.NewStaticCredentialsProvider(cfg.accessKey, cfg.secretKey, ""),
		))
	}

	loadedConfig, err := config.LoadDefaultConfig(context.Background(), loadOpts...)
	if err != nil {
		panic(fmt.Sprintf("failed to load aws config: %v", err))
	}

	client := s3.NewFromConfig(loadedConfig, func(o *s3.Options) {
		o.UsePathStyle = cfg.usePathStyle
	})

	ctx, cancel := context.WithCancel(context.Background())
	return &S3FileManager{
		client:  client,
		bucket:  cfg.bucket,
		timeout: opts.timeout,
		expire:  opts.expire,
		ctx:     ctx,
		cancel:  cancel,
	}
}

// lifecycleContext returns the context cancelled by CloseManager, or a
// background context when the manager was built without the constructor.
func (s *S3FileManager) lifecycleContext() context.Context {
	if s.ctx == nil {
		return context.Background()
	}
	return s.ctx
}

// startCleaner runs cleanup once a minute until the manager is closed.
// It blocks, so callers are expected to run it in its own goroutine.
func (s *S3FileManager) startCleaner() {
	ctx := s.lifecycleContext()
	ticker := time.NewTicker(time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			s.cleanup()
		}
	}
}

// cleanup walks every object in the bucket and deletes those older than their
// allowed lifetime. It is best-effort: a listing error ends the sweep, and the
// sweep also stops as soon as the manager is closed.
func (s *S3FileManager) cleanup() {
	ctx := s.lifecycleContext()
	now := time.Now()

	paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{
		Bucket: aws.String(s.bucket),
	})
	for paginator.HasMorePages() {
		page, err := paginator.NextPage(ctx)
		if err != nil {
			// Nothing to report the error to; the next tick retries.
			return
		}
		for _, obj := range page.Contents {
			if ctx.Err() != nil {
				return
			}
			s.cleanupObject(ctx, now, aws.ToString(obj.Key))
		}
	}
}

// cleanupObject deletes the object stored under code when its creation time
// is further in the past than s.expire (complete objects) or s.timeout
// (incomplete ones). Objects whose metadata cannot be read, or that carry no
// parseable creation time, are left untouched.
func (s *S3FileManager) cleanupObject(ctx context.Context, now time.Time, code string) {
	head, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(code),
	})
	if err != nil {
		return
	}

	// A missing creation time yields "", which fails to parse as well.
	createdAt, err := time.Parse(time.RFC3339, head.Metadata[s3MetadataKeyCreateTime])
	if err != nil {
		return
	}

	maxAge := s.timeout
	if head.Metadata[s3MetadataKeyComplete] == "true" {
		maxAge = s.expire
	}
	if now.Sub(createdAt) > maxAge {
		s.deleteObject(code)
	}
}

// deleteObject removes the object stored under code on a best-effort basis.
// It uses a background context so the removal is attempted even when the
// caller's request context has already been cancelled.
func (s *S3FileManager) deleteObject(code string) {
	// The error is dropped on purpose: callers have no use for it and the
	// periodic cleanup eventually removes anything left behind.
	_, _ = s.client.DeleteObject(context.Background(), &s3.DeleteObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(code),
	})
}

// generateRandomCode returns a hex string of exactly length characters drawn
// from crypto/rand. A negative length is reported as an error.
func (s *S3FileManager) generateRandomCode(length int) (string, error) {
	if length < 0 {
		return "", fmt.Errorf("invalid code length %d", length)
	}
	// One byte encodes to two hex characters; round up so that odd lengths
	// are honoured, then trim the surplus character.
	raw := make([]byte, (length+1)/2)
	if _, err := rand.Read(raw); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw)[:length], nil
}

// CloseManager cancels the manager's lifecycle context, which stops the
// cleaner loop and any sweep in progress.
func (s *S3FileManager) CloseManager() {
	if s.cancel != nil {
		s.cancel()
	}
}

// Create registers a new, empty, not-yet-complete object under a fresh random
// 16-character code and records filename, size and sha256 in its metadata.
// It returns the code together with the declared digest.
func (s *S3FileManager) Create(ctx context.Context, filename string, size int64, sha256 string) (*CreateResult, error) {
	code, err := s.generateRandomCode(16)
	if err != nil {
		return nil, err
	}

	_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(code),
		Metadata: map[string]string{
			s3MetadataKeyFilename:   filename,
			s3MetadataKeySize:       strconv.FormatInt(size, 10),
			s3MetadataKeySha256:     sha256,
			s3MetadataKeyCreateTime: time.Now().Format(time.RFC3339),
			s3MetadataKeyComplete:   "false",
			s3MetadataKeyCode:       code,
		},
	})
	if err != nil {
		return nil, err
	}

	return &CreateResult{
		Code:   code,
		SHA256: sha256,
	}, nil
}

// Upload reads reader to the end and stores the bytes as the full content of
// the object registered under code, marking it complete. It returns the number
// of bytes stored twice.
//
// start and end are accepted for interface compatibility but are not used:
// every call replaces the whole object.
//
// When the byte count equals the declared size, the content is verified
// against the declared digest. On mismatch the object is deleted and a hash
// mismatch error is returned. If no digest was declared, the SHA-256 of the
// content is recorded instead.
//
// NOTE(review): the object is marked complete even when the byte count differs
// from the declared size — confirm this is intended.
func (s *S3FileManager) Upload(ctx context.Context, code string, start int64, end int64, reader io.Reader) (int64, int64, error) {
	head, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(code),
	})
	if err != nil {
		return 0, 0, fmt.Errorf("file not found: %w", err)
	}
	meta := head.Metadata // may be nil; reading a nil map is safe

	data, err := io.ReadAll(reader)
	if err != nil {
		return 0, 0, err
	}
	received := int64(len(data))
	sizeMatches := received == s3MetadataInt64(meta, s3MetadataKeySize)
	declared := meta[s3MetadataKeySha256]

	if declared != "" && sizeMatches {
		if actual := s3ContentDigest(data, declared); !strings.EqualFold(actual, declared) {
			s.deleteObject(code)
			return 0, 0, NewHashMismatchError(declared, actual)
		}
	}

	// PutObject replaces the metadata wholesale, so carry every existing
	// entry over and flip only the completion flag.
	updated := map[string]string{
		s3MetadataKeyFilename:   meta[s3MetadataKeyFilename],
		s3MetadataKeySize:       meta[s3MetadataKeySize],
		s3MetadataKeySha256:     declared,
		s3MetadataKeyCreateTime: meta[s3MetadataKeyCreateTime],
		s3MetadataKeyComplete:   "true",
		s3MetadataKeyCode:       meta[s3MetadataKeyCode],
	}
	if declared == "" && sizeMatches {
		updated[s3MetadataKeySha256] = s3Sha256Hex(data)
	}

	_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
		Bucket:   aws.String(s.bucket),
		Key:      aws.String(code),
		Body:     bytes.NewReader(data),
		Metadata: updated,
	})
	if err != nil {
		return 0, 0, err
	}

	return received, received, nil
}

// md5Hash returns the hex-encoded MD5 digest of data. It is retained to verify
// objects whose stored digest was produced by earlier versions of this file.
func md5Hash(data []byte) string {
	sum := md5.Sum(data)
	return hex.EncodeToString(sum[:])
}

// s3Sha256Hex returns the hex-encoded SHA-256 digest of data.
func s3Sha256Hex(data []byte) string {
	sum := sha256.Sum256(data)
	return hex.EncodeToString(sum[:])
}

// s3ContentDigest returns the digest of data in the algorithm that expected
// appears to use: MD5 when expected has the length of a hex MD5 digest (as
// written by earlier versions of this file), SHA-256 otherwise.
func s3ContentDigest(data []byte, expected string) string {
	if len(expected) == 2*md5.Size {
		return md5Hash(data)
	}
	return s3Sha256Hex(data)
}

// s3MetadataInt64 returns the metadata entry key parsed as a base-10 integer,
// or 0 when the entry is missing or not a valid integer.
func s3MetadataInt64(meta map[string]string, key string) int64 {
	n, err := strconv.ParseInt(strings.TrimSpace(meta[key]), 10, 64)
	if err != nil {
		return 0
	}
	return n
}

// Get downloads and returns the content stored under code.
//
// An object not marked complete is served only if its stored length equals the
// declared size; otherwise a file-not-ready error is returned. When a digest
// is recorded the content is verified against it; on mismatch the object is
// deleted and a hash mismatch error is returned.
func (s *S3FileManager) Get(ctx context.Context, code string) ([]byte, error) {
	output, err := s.client.GetObject(ctx, &s3.GetObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(code),
	})
	if err != nil {
		return nil, fmt.Errorf("file not found: %w", err)
	}
	defer output.Body.Close()
	meta := output.Metadata // may be nil; reading a nil map is safe

	if meta[s3MetadataKeyComplete] != "true" {
		expectedSize := s3MetadataInt64(meta, s3MetadataKeySize)
		currentSize := aws.ToInt64(output.ContentLength)
		if currentSize != expectedSize {
			return nil, NewFileNotReadyError(code, currentSize, expectedSize)
		}
	}

	data, err := io.ReadAll(output.Body)
	if err != nil {
		return nil, err
	}

	if expected := meta[s3MetadataKeySha256]; expected != "" {
		if actual := s3ContentDigest(data, expected); !strings.EqualFold(actual, expected) {
			s.deleteObject(code)
			return nil, NewHashMismatchError(expected, actual)
		}
	}
	return data, nil
}

// GetInfo returns the bookkeeping recorded in the metadata of the object
// stored under code, without downloading its content. Missing or malformed
// entries are reported as zero values.
func (s *S3FileManager) GetInfo(ctx context.Context, code string) (*FileInfo, error) {
	output, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(code),
	})
	if err != nil {
		return nil, fmt.Errorf("file not found: %w", err)
	}
	meta := output.Metadata // may be nil; reading a nil map is safe

	// A missing or malformed timestamp is reported as the zero time rather
	// than failing the whole lookup.
	var createTime time.Time
	if raw := meta[s3MetadataKeyCreateTime]; raw != "" {
		createTime, _ = time.Parse(time.RFC3339, raw)
	}

	return &FileInfo{
		Filename:   meta[s3MetadataKeyFilename],
		Size:       s3MetadataInt64(meta, s3MetadataKeySize),
		SHA256:     meta[s3MetadataKeySha256],
		Path:       code,
		CreateTime: createTime,
		Complete:   meta[s3MetadataKeyComplete] == "true",
	}, nil
}

// Delete removes the object stored under code and returns the error reported
// by S3, if any.
func (s *S3FileManager) Delete(ctx context.Context, code string) error {
	_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(code),
	})
	return err
}

// Close is a no-op for the S3 backend: this manager holds no per-file
// resources. It always returns nil.
func (s *S3FileManager) Close(code string) error {
	return nil
}