196 lines
4.9 KiB
Go
196 lines
4.9 KiB
Go
package downloader
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
|
|
"gitea.loveuer.com/yizhisec/pkg3/logger"
|
|
)
|
|
|
|
// Options controls how Download fetches and stores a file. Callers normally
// adjust it through the With*/Without* Option helpers rather than directly.
type Options struct {
	// InsecureSkipVerify skips TLS certificate verification. It is only
	// honoured when HTTPClient is nil; a caller-supplied client is used as-is.
	InsecureSkipVerify bool

	// HTTPClient allows providing a custom HTTP client. When nil, Download
	// builds its own client.
	HTTPClient *http.Client

	// OnProgress is called during download with the cumulative number of
	// bytes read and the total size taken from the response's Content-Length.
	// Download only installs it when the server reports a positive
	// Content-Length.
	OnProgress func(downloaded, total int64)

	// CreateDirs automatically creates parent directories (mode 0755) if they
	// don't exist. Download enables it by default.
	CreateDirs bool

	// Overwrite allows overwriting existing files. Download enables it by
	// default; when false, an existing destination is reported as an error.
	Overwrite bool

	// FileMode holds the permission bits for the downloaded file. Download
	// defaults it to 0644 and, when it is non-zero, applies it with an
	// explicit chmod after writing so the process umask does not narrow it.
	FileMode os.FileMode
}
|
|
|
|
// Option is a functional option for configuring the downloader. Each Option
// mutates the Options value that Download builds from its defaults, in the
// order the options are passed, before the download starts.
type Option func(*Options)
|
|
|
|
// WithInsecureSkipVerify skips TLS certificate verification
|
|
func WithInsecureSkipVerify() Option {
|
|
return func(o *Options) {
|
|
o.InsecureSkipVerify = true
|
|
}
|
|
}
|
|
|
|
// WithHTTPClient sets a custom HTTP client
|
|
func WithHTTPClient(client *http.Client) Option {
|
|
return func(o *Options) {
|
|
o.HTTPClient = client
|
|
}
|
|
}
|
|
|
|
// WithProgress sets a progress callback
|
|
func WithProgress(callback func(downloaded, total int64)) Option {
|
|
return func(o *Options) {
|
|
o.OnProgress = callback
|
|
}
|
|
}
|
|
|
|
// WithoutCreateDirs disables automatic creation of parent directories
|
|
func WithoutCreateDirs() Option {
|
|
return func(o *Options) {
|
|
o.CreateDirs = false
|
|
}
|
|
}
|
|
|
|
// WithoutOverwrite prevents overwriting existing files
|
|
func WithoutOverwrite() Option {
|
|
return func(o *Options) {
|
|
o.Overwrite = false
|
|
}
|
|
}
|
|
|
|
func WithFileMode(mode os.FileMode) Option {
|
|
return func(o *Options) {
|
|
o.FileMode = mode
|
|
}
|
|
}
|
|
|
|
// Download downloads a file from URL to the specified destination
|
|
func Download(ctx context.Context, url, dest string, opts ...Option) error {
|
|
options := &Options{
|
|
InsecureSkipVerify: false,
|
|
CreateDirs: true,
|
|
Overwrite: true,
|
|
FileMode: 0644,
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(options)
|
|
}
|
|
|
|
logger.Debug("开始下载文件: %s -> %s", url, dest)
|
|
|
|
// Check if file exists and overwrite is disabled
|
|
if !options.Overwrite {
|
|
if _, err := os.Stat(dest); err == nil {
|
|
logger.Debug("文件已存在且不允许覆盖: %s", dest)
|
|
return fmt.Errorf("file already exists: %s", dest)
|
|
}
|
|
}
|
|
|
|
// Create parent directories if needed
|
|
if options.CreateDirs {
|
|
dir := filepath.Dir(dest)
|
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
|
logger.Debug("创建目录失败 %s: %v", dir, err)
|
|
return fmt.Errorf("failed to create directory: %w", err)
|
|
}
|
|
}
|
|
|
|
// Create HTTP client
|
|
client := options.HTTPClient
|
|
if client == nil {
|
|
client = &http.Client{}
|
|
if options.InsecureSkipVerify {
|
|
logger.Debug("TLS 证书验证已禁用")
|
|
client.Transport = &http.Transport{
|
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create HTTP request
|
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
|
if err != nil {
|
|
logger.Debug("创建请求失败: %v", err)
|
|
return fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
// Execute request
|
|
logger.Debug("发起 HTTP 请求: %s", url)
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
logger.Debug("下载失败: %v", err)
|
|
return fmt.Errorf("failed to download: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
logger.Debug("HTTP 状态码异常: %d", resp.StatusCode)
|
|
return fmt.Errorf("bad status: %s", resp.Status)
|
|
}
|
|
|
|
// Get content length for progress reporting
|
|
contentLength := resp.ContentLength
|
|
logger.Debug("文件大小: %d bytes", contentLength)
|
|
|
|
outFile, err := os.OpenFile(dest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, options.FileMode)
|
|
if err != nil {
|
|
logger.Debug("创建文件失败 %s: %v", dest, err)
|
|
return fmt.Errorf("failed to create file: %w", err)
|
|
}
|
|
defer outFile.Close()
|
|
|
|
// Copy content with optional progress reporting
|
|
var written int64
|
|
if options.OnProgress != nil && contentLength > 0 {
|
|
// Use progress reader
|
|
reader := &progressReader{
|
|
reader: resp.Body,
|
|
callback: options.OnProgress,
|
|
total: contentLength,
|
|
}
|
|
written, err = io.Copy(outFile, reader)
|
|
} else {
|
|
written, err = io.Copy(outFile, resp.Body)
|
|
}
|
|
|
|
if err != nil {
|
|
logger.Debug("写入文件失败 %s: %v", dest, err)
|
|
return fmt.Errorf("failed to write file: %w", err)
|
|
}
|
|
|
|
if options.FileMode != 0 {
|
|
if err := os.Chmod(dest, options.FileMode); err != nil {
|
|
logger.Debug("设置文件权限失败 %s: %v", dest, err)
|
|
return fmt.Errorf("failed to set file mode: %w", err)
|
|
}
|
|
}
|
|
|
|
logger.Debug("文件下载成功: %s (%d bytes)", dest, written)
|
|
return nil
|
|
}
|
|
|
|
// progressReader decorates an io.Reader and reports cumulative progress to
// callback after every Read call.
type progressReader struct {
	reader     io.Reader                     // source being consumed
	callback   func(downloaded, total int64) // progress sink; may be nil
	total      int64                         // passed through to callback unchanged
	downloaded int64                         // bytes returned by reader so far
}

// Compile-time check that progressReader satisfies io.Reader.
var _ io.Reader = (*progressReader)(nil)

// Read implements io.Reader. It forwards to the wrapped reader, adds the
// number of bytes obtained to the running count and then notifies callback
// (when set). The notification happens on every call, including calls that
// return zero bytes or an error. The wrapped reader's count and error are
// returned unchanged.
func (pr *progressReader) Read(buf []byte) (int, error) {
	count, readErr := pr.reader.Read(buf)
	pr.downloaded += int64(count)

	if notify := pr.callback; notify != nil {
		notify(pr.downloaded, pr.total)
	}
	return count, readErr
}
|