feat: add tool rsa
This commit is contained in:
186
tool/rsa.go
Normal file
186
tool/rsa.go
Normal file
@ -0,0 +1,186 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// 通用RSA加密方法(支持选择填充方式)
|
||||
func rsaEncrypt(data []byte, publicKey *rsa.PublicKey, useOAEP bool) ([]byte, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("input data is empty")
|
||||
}
|
||||
|
||||
if useOAEP {
|
||||
// 使用OAEP填充(更安全)
|
||||
hash := sha256.New()
|
||||
return rsa.EncryptOAEP(hash, rand.Reader, publicKey, data, nil)
|
||||
} else {
|
||||
// 使用PKCS1v15填充(兼容旧系统)
|
||||
return rsa.EncryptPKCS1v15(rand.Reader, publicKey, data)
|
||||
}
|
||||
}
|
||||
|
||||
// 方法1:通过tls.Config加密
|
||||
func RSAEncrypt(data []byte, cfg *tls.Config) (string, error) {
|
||||
// 验证参数
|
||||
if len(data) == 0 {
|
||||
return "", errors.New("input data is empty")
|
||||
}
|
||||
if cfg == nil {
|
||||
return "", errors.New("TLS config is nil")
|
||||
}
|
||||
|
||||
// 获取第一个证书
|
||||
if len(cfg.Certificates) == 0 {
|
||||
return "", errors.New("no certificates found in TLS config")
|
||||
}
|
||||
cert := cfg.Certificates[0]
|
||||
|
||||
// 获取RSA公钥
|
||||
rsaPublicKey, ok := cert.Leaf.PublicKey.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return "", errors.New("certificate does not contain RSA public key")
|
||||
}
|
||||
|
||||
// 加密(无分块,适用于小数据)
|
||||
encrypted, err := rsaEncrypt(data, rsaPublicKey, true)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("encryption failed: %w", err)
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(encrypted), nil
|
||||
}
|
||||
|
||||
// 方法2:通过证书PEM加密(支持分块)
|
||||
func RSAEncryptByCert(data, cert []byte) (string, error) {
|
||||
// 解析PEM证书
|
||||
block, _ := pem.Decode(cert)
|
||||
if block == nil || block.Type != "CERTIFICATE" {
|
||||
return "", errors.New("failed to parse certificate PEM")
|
||||
}
|
||||
|
||||
// 解析X.509证书
|
||||
crt, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse certificate: %w", err)
|
||||
}
|
||||
|
||||
// 获取RSA公钥
|
||||
publicKey, ok := crt.PublicKey.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return "", errors.New("certificate does not contain RSA public key")
|
||||
}
|
||||
|
||||
// 计算最大分块大小
|
||||
keySize := publicKey.Size()
|
||||
maxChunk := 0
|
||||
|
||||
maxChunk = keySize - 42 // OAEP填充
|
||||
|
||||
// 不需要分块的情况
|
||||
if len(data) <= maxChunk {
|
||||
encrypted, err := rsaEncrypt(data, publicKey, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(encrypted), nil
|
||||
}
|
||||
|
||||
// 需要分块的情况
|
||||
var encrypted []byte
|
||||
for i := 0; i < len(data); i += maxChunk {
|
||||
end := i + maxChunk
|
||||
if end > len(data) {
|
||||
end = len(data)
|
||||
}
|
||||
|
||||
chunk, err := rsaEncrypt(data[i:end], publicKey, true)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("RSA encryption failed: %w", err)
|
||||
}
|
||||
encrypted = append(encrypted, chunk...)
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(encrypted), nil
|
||||
}
|
||||
|
||||
func rsaDecryptBlock(block []byte, privateKey *rsa.PrivateKey, useOAEP bool) ([]byte, error) {
|
||||
if useOAEP {
|
||||
// Use OAEP with SHA-256
|
||||
return rsa.DecryptOAEP(
|
||||
sha256.New(),
|
||||
rand.Reader,
|
||||
privateKey,
|
||||
block,
|
||||
nil,
|
||||
)
|
||||
} else {
|
||||
// Use PKCS1v15 for backward compatibility
|
||||
return rsa.DecryptPKCS1v15(
|
||||
rand.Reader,
|
||||
privateKey,
|
||||
block,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func RSADecrypt(encryptedBase64 string, cfg *tls.Config) ([]byte, error) {
|
||||
// Validate inputs
|
||||
if encryptedBase64 == "" {
|
||||
return nil, errors.New("encrypted data is empty")
|
||||
}
|
||||
if cfg == nil {
|
||||
return nil, errors.New("TLS config is nil")
|
||||
}
|
||||
if len(cfg.Certificates) == 0 {
|
||||
return nil, errors.New("no certificates found in TLS config")
|
||||
}
|
||||
|
||||
// Decode base64 string
|
||||
encryptedData, err := base64.StdEncoding.DecodeString(encryptedBase64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("base64 decoding failed: %w", err)
|
||||
}
|
||||
|
||||
// Get private key from first certificate
|
||||
privateKey, ok := cfg.Certificates[0].PrivateKey.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, errors.New("failed to get RSA private key from TLS config")
|
||||
}
|
||||
|
||||
// Determine block size for decryption
|
||||
blockSize := privateKey.Size()
|
||||
if len(encryptedData)%blockSize != 0 {
|
||||
return nil, fmt.Errorf("invalid encrypted data size. Expected multiple of %d, got %d",
|
||||
blockSize, len(encryptedData))
|
||||
}
|
||||
|
||||
// Handle single block decryption
|
||||
if len(encryptedData) == blockSize {
|
||||
return rsaDecryptBlock(encryptedData, privateKey, true)
|
||||
}
|
||||
|
||||
// Handle multi-block decryption
|
||||
var decrypted []byte
|
||||
for i := 0; i < len(encryptedData); i += blockSize {
|
||||
end := i + blockSize
|
||||
block := encryptedData[i:end]
|
||||
|
||||
decryptedBlock, err := rsaDecryptBlock(block, privateKey, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("block decryption failed at offset %d: %w", i, err)
|
||||
}
|
||||
|
||||
decrypted = append(decrypted, decryptedBlock...)
|
||||
}
|
||||
|
||||
return decrypted, nil
|
||||
}
|
Reference in New Issue
Block a user