// Package downloader fetches a remote file over HTTP(S) and stores it on the
// local filesystem.
package downloader

import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"

	"yizhisec.com/hsv2/forge/pkg/logger"
)

// Options defines options for downloading files.
type Options struct {
	// InsecureSkipVerify skips TLS certificate verification. It only affects
	// the client that Download builds itself; it is ignored when HTTPClient
	// is set.
	InsecureSkipVerify bool

	// HTTPClient allows providing a custom HTTP client.
	HTTPClient *http.Client

	// OnProgress is called during download with bytes downloaded and total
	// size. It is only invoked when the response reports a positive
	// Content-Length.
	OnProgress func(downloaded, total int64)

	// CreateDirs automatically creates parent directories if they don't exist.
	CreateDirs bool

	// Overwrite allows overwriting existing files.
	Overwrite bool

	// FileMode is the permission mode used when creating the destination
	// file. When non-zero it is also applied with chmod after the download
	// completes.
	FileMode os.FileMode
}

// Option is a functional option for configuring the downloader.
type Option func(*Options)

// WithInsecureSkipVerify skips TLS certificate verification.
func WithInsecureSkipVerify() Option {
	return func(o *Options) {
		o.InsecureSkipVerify = true
	}
}

// WithHTTPClient sets a custom HTTP client.
func WithHTTPClient(client *http.Client) Option {
	return func(o *Options) {
		o.HTTPClient = client
	}
}

// WithProgress sets a progress callback.
func WithProgress(callback func(downloaded, total int64)) Option {
	return func(o *Options) {
		o.OnProgress = callback
	}
}

// WithoutCreateDirs disables automatic creation of parent directories.
func WithoutCreateDirs() Option {
	return func(o *Options) {
		o.CreateDirs = false
	}
}

// WithoutOverwrite prevents overwriting existing files.
func WithoutOverwrite() Option {
	return func(o *Options) {
		o.Overwrite = false
	}
}

// WithFileMode sets the permission mode of the destination file.
func WithFileMode(mode os.FileMode) Option {
	return func(o *Options) {
		o.FileMode = mode
	}
}

// Download downloads a file from url to the specified destination.
//
// Defaults: TLS verification enabled, parent directories created, existing
// files overwritten, file mode 0644. Only an HTTP 200 response is accepted.
// If the transfer fails after the destination file has been opened, the
// partially written file is removed so that no truncated file is left at dest.
//
// It returns an error when the destination exists and overwriting is
// disabled, when the parent directory or file cannot be created, when the
// request fails or returns a non-200 status, or when writing the body fails.
func Download(ctx context.Context, url, dest string, opts ...Option) error {
	options := &Options{
		InsecureSkipVerify: false,
		CreateDirs:         true,
		Overwrite:          true,
		FileMode:           0644,
	}
	for _, opt := range opts {
		// Tolerate nil options instead of panicking on the call.
		if opt != nil {
			opt(options)
		}
	}

	logger.Debug("开始下载文件: %s -> %s", url, dest)

	// Fail fast before touching the network when the file is already there.
	// The open in writeFile repeats this check atomically.
	if !options.Overwrite {
		if _, err := os.Stat(dest); err == nil {
			logger.Debug("文件已存在且不允许覆盖: %s", dest)
			return fmt.Errorf("file already exists: %s", dest)
		}
	}

	// Create parent directories if needed.
	if options.CreateDirs {
		dir := filepath.Dir(dest)
		if err := os.MkdirAll(dir, 0755); err != nil {
			logger.Debug("创建目录失败 %s: %v", dir, err)
			return fmt.Errorf("failed to create directory: %w", err)
		}
	}

	// Registered before the body close below, so it runs after the body has
	// been closed.
	client, release := newHTTPClient(options)
	defer release()

	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
	if err != nil {
		logger.Debug("创建请求失败: %v", err)
		return fmt.Errorf("failed to create request: %w", err)
	}

	logger.Debug("发起 HTTP 请求: %s", url)
	resp, err := client.Do(req)
	if err != nil {
		logger.Debug("下载失败: %v", err)
		return fmt.Errorf("failed to download: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		logger.Debug("HTTP 状态码异常: %d", resp.StatusCode)
		return fmt.Errorf("bad status: %s", resp.Status)
	}

	// Content length drives progress reporting; it is not positive when the
	// server does not announce a size, in which case no progress is reported.
	contentLength := resp.ContentLength
	logger.Debug("文件大小: %d bytes", contentLength)

	var src io.Reader = resp.Body
	if options.OnProgress != nil && contentLength > 0 {
		src = &progressReader{
			reader:   resp.Body,
			callback: options.OnProgress,
			total:    contentLength,
		}
	}

	written, err := writeFile(dest, src, options)
	if err != nil {
		return err
	}

	logger.Debug("文件下载成功: %s (%d bytes)", dest, written)
	return nil
}

// newHTTPClient returns the client to use for the request together with a
// release function that frees any resources the client owns.
//
// A caller-supplied client is returned untouched. Otherwise a fresh client is
// built; when TLS verification is disabled it gets its own transport, cloned
// from http.DefaultTransport so proxy and timeout settings are kept, and the
// release function closes that transport's idle connections.
func newHTTPClient(options *Options) (*http.Client, func()) {
	noop := func() {}

	if options.HTTPClient != nil {
		return options.HTTPClient, noop
	}

	client := &http.Client{}
	if !options.InsecureSkipVerify {
		return client, noop
	}

	logger.Debug("TLS 证书验证已禁用")

	var transport *http.Transport
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		// Clone deep-copies the TLS config, so the shared default transport
		// is never modified.
		transport = base.Clone()
	} else {
		// DefaultTransport was replaced by a non-*http.Transport value.
		transport = &http.Transport{}
	}
	if transport.TLSClientConfig == nil {
		transport.TLSClientConfig = &tls.Config{}
	}
	transport.TLSClientConfig.InsecureSkipVerify = true
	client.Transport = transport

	// The transport lives for a single download; drop its pooled connections
	// so they do not outlive the call.
	return client, transport.CloseIdleConnections
}

// writeFile streams src into dest and returns the number of bytes written.
//
// The file is opened with options.FileMode and, when overwriting is disabled,
// with O_EXCL so an existing file is never truncated. On a copy or close
// failure the partial file is removed. When options.FileMode is non-zero the
// mode is applied again with chmod after a successful write.
func writeFile(dest string, src io.Reader, options *Options) (int64, error) {
	flags := os.O_CREATE | os.O_WRONLY | os.O_TRUNC
	if !options.Overwrite {
		flags |= os.O_EXCL
	}

	outFile, err := os.OpenFile(dest, flags, options.FileMode)
	if err != nil {
		logger.Debug("创建文件失败 %s: %v", dest, err)
		return 0, fmt.Errorf("failed to create file: %w", err)
	}

	written, err := io.Copy(outFile, src)
	// Close explicitly: a close error on a written file can mean data was not
	// flushed, so it must not be dropped.
	if closeErr := outFile.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		logger.Debug("写入文件失败 %s: %v", dest, err)
		removePartial(dest)
		return 0, fmt.Errorf("failed to write file: %w", err)
	}

	if options.FileMode != 0 {
		if err := os.Chmod(dest, options.FileMode); err != nil {
			logger.Debug("设置文件权限失败 %s: %v", dest, err)
			return 0, fmt.Errorf("failed to set file mode: %w", err)
		}
	}

	return written, nil
}

// removePartial deletes an incompletely written file. Removal is best effort:
// a failure is logged and otherwise ignored because the caller is already
// returning the original write error.
func removePartial(path string) {
	if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
		logger.Debug("清理未完成的文件失败 %s: %v", path, err)
	}
}

// progressReader wraps an io.Reader and reports the cumulative number of
// bytes read, together with the expected total, after every Read call.
type progressReader struct {
	reader     io.Reader
	callback   func(downloaded, total int64)
	total      int64
	downloaded int64
}

// Read reads from the wrapped reader, adds the bytes read to the running
// count, and invokes the callback (if set) before returning the result of the
// underlying Read unchanged.
func (pr *progressReader) Read(p []byte) (int, error) {
	n, err := pr.reader.Read(p)
	pr.downloaded += int64(n)
	if pr.callback != nil {
		pr.callback(pr.downloaded, pr.total)
	}
	return n, err
}