Files
packages/tool/rsa.go
2025-07-17 16:17:36 +08:00

187 lines
4.5 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package tool
import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
)
// rsaEncrypt encrypts data as a single RSA block with publicKey.
//
// When useOAEP is true it uses RSA-OAEP with SHA-256 and an empty label;
// otherwise it uses PKCS #1 v1.5 padding for compatibility with older systems.
// data must be non-empty and small enough to fit in one block for the chosen
// padding; otherwise the underlying crypto/rsa call returns an error.
func rsaEncrypt(data []byte, publicKey *rsa.PublicKey, useOAEP bool) ([]byte, error) {
	if len(data) == 0 {
		return nil, errors.New("input data is empty")
	}
	if !useOAEP {
		// PKCS #1 v1.5 padding, kept only for interoperability with legacy peers.
		return rsa.EncryptPKCS1v15(rand.Reader, publicKey, data)
	}
	return rsa.EncryptOAEP(sha256.New(), rand.Reader, publicKey, data, nil)
}

// encryptBlocks encrypts data with RSA-OAEP/SHA-256, splitting it into chunks
// that each fit in a single RSA block and concatenating the ciphertexts.
// Each ciphertext block is publicKey.Size() bytes long, so the result can be
// split back into blocks by size alone.
func encryptBlocks(data []byte, publicKey *rsa.PublicKey) ([]byte, error) {
	if len(data) == 0 {
		return nil, errors.New("input data is empty")
	}
	keySize := publicKey.Size()
	// OAEP adds 2*hLen+2 bytes of padding (RFC 8017 §7.1.1). rsaEncrypt hashes
	// with SHA-256 (hLen = 32), so one block carries at most keySize-66 bytes.
	maxChunk := keySize - 2*sha256.Size - 2
	if maxChunk <= 0 {
		// Without this guard the chunking loop below would never advance.
		return nil, fmt.Errorf("RSA key of %d bytes is too small for OAEP with SHA-256", keySize)
	}
	numChunks := (len(data) + maxChunk - 1) / maxChunk
	encrypted := make([]byte, 0, numChunks*keySize)
	for start := 0; start < len(data); start += maxChunk {
		end := start + maxChunk
		if end > len(data) {
			end = len(data)
		}
		block, err := rsaEncrypt(data[start:end], publicKey, true)
		if err != nil {
			return nil, fmt.Errorf("chunk at offset %d: %w", start, err)
		}
		encrypted = append(encrypted, block...)
	}
	return encrypted, nil
}

// leafCertificate returns the parsed leaf certificate of cert.
//
// tls.Certificate only carries a parsed Leaf when whoever built it filled the
// field in, so when Leaf is nil the first DER certificate in the chain is
// parsed instead. cert itself is not modified.
func leafCertificate(cert *tls.Certificate) (*x509.Certificate, error) {
	if cert.Leaf != nil {
		return cert.Leaf, nil
	}
	if len(cert.Certificate) == 0 {
		return nil, errors.New("TLS certificate has no DER data")
	}
	return x509.ParseCertificate(cert.Certificate[0])
}

// RSAEncrypt encrypts data with the RSA public key of the first certificate in
// cfg using RSA-OAEP with SHA-256, and returns the ciphertext as standard
// base64.
//
// Data larger than one RSA block is split into chunks. The per-chunk
// ciphertexts (each exactly one key size long) are concatenated before
// encoding.
//
// It returns an error if data is empty, cfg is nil or has no certificates,
// the certificate cannot be parsed or does not hold an RSA key, or encryption
// fails.
func RSAEncrypt(data []byte, cfg *tls.Config) (string, error) {
	if len(data) == 0 {
		return "", errors.New("input data is empty")
	}
	if cfg == nil {
		return "", errors.New("TLS config is nil")
	}
	if len(cfg.Certificates) == 0 {
		return "", errors.New("no certificates found in TLS config")
	}
	// Leaf may be nil (e.g. hand-built certificates or older tls.X509KeyPair
	// versions), so fall back to parsing the DER bytes instead of dereferencing it.
	leaf, err := leafCertificate(&cfg.Certificates[0])
	if err != nil {
		return "", fmt.Errorf("failed to parse certificate: %w", err)
	}
	rsaPublicKey, ok := leaf.PublicKey.(*rsa.PublicKey)
	if !ok {
		return "", errors.New("certificate does not contain RSA public key")
	}
	encrypted, err := encryptBlocks(data, rsaPublicKey)
	if err != nil {
		return "", fmt.Errorf("encryption failed: %w", err)
	}
	return base64.StdEncoding.EncodeToString(encrypted), nil
}

// RSAEncryptByCert encrypts data with the RSA public key of a PEM-encoded
// X.509 certificate using RSA-OAEP with SHA-256, and returns the ciphertext
// as standard base64.
//
// Data larger than one RSA block is split into chunks of at most
// keySize-66 bytes. The per-chunk ciphertexts are concatenated before
// encoding.
//
// It returns an error if cert is not a PEM "CERTIFICATE" block, cannot be
// parsed, does not hold an RSA key, or if data is empty or encryption fails.
func RSAEncryptByCert(data, cert []byte) (string, error) {
	block, _ := pem.Decode(cert)
	if block == nil || block.Type != "CERTIFICATE" {
		return "", errors.New("failed to parse certificate PEM")
	}
	crt, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		return "", fmt.Errorf("failed to parse certificate: %w", err)
	}
	publicKey, ok := crt.PublicKey.(*rsa.PublicKey)
	if !ok {
		return "", errors.New("certificate does not contain RSA public key")
	}
	encrypted, err := encryptBlocks(data, publicKey)
	if err != nil {
		return "", fmt.Errorf("RSA encryption failed: %w", err)
	}
	return base64.StdEncoding.EncodeToString(encrypted), nil
}
// rsaDecryptBlock decrypts a single RSA ciphertext block with privateKey.
//
// When useOAEP is true it uses RSA-OAEP with SHA-256 and an empty label;
// otherwise it uses PKCS #1 v1.5 padding for backward compatibility.
func rsaDecryptBlock(block []byte, privateKey *rsa.PrivateKey, useOAEP bool) ([]byte, error) {
	if !useOAEP {
		// PKCS #1 v1.5, kept only for compatibility with legacy ciphertexts.
		return rsa.DecryptPKCS1v15(rand.Reader, privateKey, block)
	}
	return rsa.DecryptOAEP(sha256.New(), rand.Reader, privateKey, block, nil)
}

// RSADecrypt decodes a standard-base64 ciphertext and decrypts it with the
// RSA private key of the first certificate in cfg, using RSA-OAEP with
// SHA-256.
//
// The decoded ciphertext must be a non-zero multiple of the key size. Each
// key-sized block is decrypted on its own, and the plaintexts are
// concatenated in order.
//
// It returns an error if the input is empty (before or after decoding), is
// not valid base64, has the wrong length, cfg is nil or has no certificates,
// the private key is not RSA, or any block fails to decrypt.
func RSADecrypt(encryptedBase64 string, cfg *tls.Config) ([]byte, error) {
	if encryptedBase64 == "" {
		return nil, errors.New("encrypted data is empty")
	}
	if cfg == nil {
		return nil, errors.New("TLS config is nil")
	}
	if len(cfg.Certificates) == 0 {
		return nil, errors.New("no certificates found in TLS config")
	}
	encryptedData, err := base64.StdEncoding.DecodeString(encryptedBase64)
	if err != nil {
		return nil, fmt.Errorf("base64 decoding failed: %w", err)
	}
	// The base64 decoder skips '\r' and '\n', so a non-empty string can still
	// decode to zero bytes. Reject that instead of returning (nil, nil).
	if len(encryptedData) == 0 {
		return nil, errors.New("encrypted data is empty")
	}
	privateKey, ok := cfg.Certificates[0].PrivateKey.(*rsa.PrivateKey)
	if !ok {
		return nil, errors.New("failed to get RSA private key from TLS config")
	}
	// Every RSA ciphertext block is exactly the key size in bytes.
	blockSize := privateKey.Size()
	if len(encryptedData)%blockSize != 0 {
		return nil, fmt.Errorf("invalid encrypted data size. Expected multiple of %d, got %d",
			blockSize, len(encryptedData))
	}
	if len(encryptedData) == blockSize {
		return rsaDecryptBlock(encryptedData, privateKey, true)
	}
	// Plaintext is always shorter than the ciphertext, so this capacity suffices.
	decrypted := make([]byte, 0, len(encryptedData))
	for i := 0; i < len(encryptedData); i += blockSize {
		decryptedBlock, err := rsaDecryptBlock(encryptedData[i:i+blockSize], privateKey, true)
		if err != nil {
			return nil, fmt.Errorf("block decryption failed at offset %d: %w", i, err)
		}
		decrypted = append(decrypted, decryptedBlock...)
	}
	return decrypted, nil
}