187 lines
4.5 KiB
Go
187 lines
4.5 KiB
Go
package tool
|
||
|
||
import (
|
||
"crypto/rand"
|
||
"crypto/rsa"
|
||
"crypto/sha256"
|
||
"crypto/tls"
|
||
"crypto/x509"
|
||
"encoding/base64"
|
||
"encoding/pem"
|
||
"errors"
|
||
"fmt"
|
||
)
|
||
|
||
// rsaEncrypt encrypts data with publicKey as a single RSA block, using
// RSA-OAEP (SHA-256, empty label) when useOAEP is true and PKCS #1 v1.5
// padding otherwise.
//
// OAEP is the recommended scheme; PKCS #1 v1.5 is kept only for compatibility
// with older systems. Because no chunking is done here, data must fit in one
// block: at most publicKey.Size()-2*sha256.Size-2 bytes for OAEP, or
// publicKey.Size()-11 bytes for PKCS #1 v1.5. Longer input makes the rsa
// package return an error.
//
// It returns an error if data is empty or publicKey is nil.
func rsaEncrypt(data []byte, publicKey *rsa.PublicKey, useOAEP bool) ([]byte, error) {
	if len(data) == 0 {
		return nil, errors.New("input data is empty")
	}
	// The rsa package dereferences the key without a nil check, so a nil key
	// would panic instead of producing an error.
	if publicKey == nil {
		return nil, errors.New("RSA public key is nil")
	}

	if !useOAEP {
		// PKCS #1 v1.5 padding, for compatibility with legacy systems.
		return rsa.EncryptPKCS1v15(rand.Reader, publicKey, data)
	}

	// OAEP padding with SHA-256 and no label (the more secure option).
	return rsa.EncryptOAEP(sha256.New(), rand.Reader, publicKey, data, nil)
}
|
||
|
||
// 方法1:通过tls.Config加密
|
||
func RSAEncrypt(data []byte, cfg *tls.Config) (string, error) {
|
||
// 验证参数
|
||
if len(data) == 0 {
|
||
return "", errors.New("input data is empty")
|
||
}
|
||
if cfg == nil {
|
||
return "", errors.New("TLS config is nil")
|
||
}
|
||
|
||
// 获取第一个证书
|
||
if len(cfg.Certificates) == 0 {
|
||
return "", errors.New("no certificates found in TLS config")
|
||
}
|
||
cert := cfg.Certificates[0]
|
||
|
||
// 获取RSA公钥
|
||
rsaPublicKey, ok := cert.Leaf.PublicKey.(*rsa.PublicKey)
|
||
if !ok {
|
||
return "", errors.New("certificate does not contain RSA public key")
|
||
}
|
||
|
||
// 加密(无分块,适用于小数据)
|
||
encrypted, err := rsaEncrypt(data, rsaPublicKey, true)
|
||
if err != nil {
|
||
return "", fmt.Errorf("encryption failed: %w", err)
|
||
}
|
||
|
||
return base64.StdEncoding.EncodeToString(encrypted), nil
|
||
}
|
||
|
||
// 方法2:通过证书PEM加密(支持分块)
|
||
func RSAEncryptByCert(data, cert []byte) (string, error) {
|
||
// 解析PEM证书
|
||
block, _ := pem.Decode(cert)
|
||
if block == nil || block.Type != "CERTIFICATE" {
|
||
return "", errors.New("failed to parse certificate PEM")
|
||
}
|
||
|
||
// 解析X.509证书
|
||
crt, err := x509.ParseCertificate(block.Bytes)
|
||
if err != nil {
|
||
return "", fmt.Errorf("failed to parse certificate: %w", err)
|
||
}
|
||
|
||
// 获取RSA公钥
|
||
publicKey, ok := crt.PublicKey.(*rsa.PublicKey)
|
||
if !ok {
|
||
return "", errors.New("certificate does not contain RSA public key")
|
||
}
|
||
|
||
// 计算最大分块大小
|
||
keySize := publicKey.Size()
|
||
maxChunk := 0
|
||
|
||
maxChunk = keySize - 42 // OAEP填充
|
||
|
||
// 不需要分块的情况
|
||
if len(data) <= maxChunk {
|
||
encrypted, err := rsaEncrypt(data, publicKey, true)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return base64.StdEncoding.EncodeToString(encrypted), nil
|
||
}
|
||
|
||
// 需要分块的情况
|
||
var encrypted []byte
|
||
for i := 0; i < len(data); i += maxChunk {
|
||
end := i + maxChunk
|
||
if end > len(data) {
|
||
end = len(data)
|
||
}
|
||
|
||
chunk, err := rsaEncrypt(data[i:end], publicKey, true)
|
||
if err != nil {
|
||
return "", fmt.Errorf("RSA encryption failed: %w", err)
|
||
}
|
||
encrypted = append(encrypted, chunk...)
|
||
}
|
||
|
||
return base64.StdEncoding.EncodeToString(encrypted), nil
|
||
}
|
||
|
||
// rsaDecryptBlock decrypts a single RSA ciphertext block with privateKey,
// using RSA-OAEP (SHA-256, empty label) when useOAEP is true and PKCS #1 v1.5
// otherwise. The padding mode must match the one used for encryption.
//
// It returns an error if privateKey is nil or the block cannot be decrypted.
func rsaDecryptBlock(block []byte, privateKey *rsa.PrivateKey, useOAEP bool) ([]byte, error) {
	// The rsa package dereferences the key without a nil check, so a nil key
	// would panic instead of producing an error.
	if privateKey == nil {
		return nil, errors.New("RSA private key is nil")
	}

	if !useOAEP {
		// PKCS #1 v1.5 is kept for backward compatibility only. Distinguishable
		// decryption failures can act as a padding oracle (Bleichenbacher), so
		// avoid surfacing these errors verbatim to untrusted callers.
		return rsa.DecryptPKCS1v15(rand.Reader, privateKey, block)
	}

	// OAEP with SHA-256 and no label.
	return rsa.DecryptOAEP(sha256.New(), rand.Reader, privateKey, block, nil)
}
|
||
|
||
func RSADecrypt(encryptedBase64 string, cfg *tls.Config) ([]byte, error) {
|
||
// Validate inputs
|
||
if encryptedBase64 == "" {
|
||
return nil, errors.New("encrypted data is empty")
|
||
}
|
||
if cfg == nil {
|
||
return nil, errors.New("TLS config is nil")
|
||
}
|
||
if len(cfg.Certificates) == 0 {
|
||
return nil, errors.New("no certificates found in TLS config")
|
||
}
|
||
|
||
// Decode base64 string
|
||
encryptedData, err := base64.StdEncoding.DecodeString(encryptedBase64)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("base64 decoding failed: %w", err)
|
||
}
|
||
|
||
// Get private key from first certificate
|
||
privateKey, ok := cfg.Certificates[0].PrivateKey.(*rsa.PrivateKey)
|
||
if !ok {
|
||
return nil, errors.New("failed to get RSA private key from TLS config")
|
||
}
|
||
|
||
// Determine block size for decryption
|
||
blockSize := privateKey.Size()
|
||
if len(encryptedData)%blockSize != 0 {
|
||
return nil, fmt.Errorf("invalid encrypted data size. Expected multiple of %d, got %d",
|
||
blockSize, len(encryptedData))
|
||
}
|
||
|
||
// Handle single block decryption
|
||
if len(encryptedData) == blockSize {
|
||
return rsaDecryptBlock(encryptedData, privateKey, true)
|
||
}
|
||
|
||
// Handle multi-block decryption
|
||
var decrypted []byte
|
||
for i := 0; i < len(encryptedData); i += blockSize {
|
||
end := i + blockSize
|
||
block := encryptedData[i:end]
|
||
|
||
decryptedBlock, err := rsaDecryptBlock(block, privateKey, true)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("block decryption failed at offset %d: %w", i, err)
|
||
}
|
||
|
||
decrypted = append(decrypted, decryptedBlock...)
|
||
}
|
||
|
||
return decrypted, nil
|
||
}
|