feat: add tool rsa

This commit is contained in:
zhaoyupeng
2025-07-17 16:17:36 +08:00
parent 104745e260
commit 2172a39a20
5 changed files with 375 additions and 2 deletions

4
.gitignore vendored
View File

@ -1,3 +1,5 @@
.idea
.vscode
.DS_Store
.DS_Store
xtest

View File

@ -8,7 +8,7 @@ import (
func TestAes(t *testing.T) {
key := os.Getenv("AES_KEY")
name := "admin"
name := "YizhiSEC@123"
res, err := AesEncrypt([]byte(name), []byte(key))
if err != nil {
t.Fatal(err)
@ -23,3 +23,14 @@ func TestAes(t *testing.T) {
t.Logf("raw = %s", string(raw))
}
func TestDecrypt(t *testing.T) {
key := os.Getenv("AES_KEY")
enc := "2hurNK+0+b9lEO2hNAkc+TzVx7KH7S0/mRt7mWBJiFA="
raw, err := AesDecrypt(enc, []byte(key))
if err != nil {
t.Fatal(err)
}
t.Logf("raw = %s", string(raw))
}

186
tool/rsa.go Normal file
View File

@ -0,0 +1,186 @@
package tool
import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
)
// 通用RSA加密方法支持选择填充方式
func rsaEncrypt(data []byte, publicKey *rsa.PublicKey, useOAEP bool) ([]byte, error) {
if len(data) == 0 {
return nil, errors.New("input data is empty")
}
if useOAEP {
// 使用OAEP填充更安全
hash := sha256.New()
return rsa.EncryptOAEP(hash, rand.Reader, publicKey, data, nil)
} else {
// 使用PKCS1v15填充兼容旧系统
return rsa.EncryptPKCS1v15(rand.Reader, publicKey, data)
}
}
// 方法1通过tls.Config加密
func RSAEncrypt(data []byte, cfg *tls.Config) (string, error) {
// 验证参数
if len(data) == 0 {
return "", errors.New("input data is empty")
}
if cfg == nil {
return "", errors.New("TLS config is nil")
}
// 获取第一个证书
if len(cfg.Certificates) == 0 {
return "", errors.New("no certificates found in TLS config")
}
cert := cfg.Certificates[0]
// 获取RSA公钥
rsaPublicKey, ok := cert.Leaf.PublicKey.(*rsa.PublicKey)
if !ok {
return "", errors.New("certificate does not contain RSA public key")
}
// 加密(无分块,适用于小数据)
encrypted, err := rsaEncrypt(data, rsaPublicKey, true)
if err != nil {
return "", fmt.Errorf("encryption failed: %w", err)
}
return base64.StdEncoding.EncodeToString(encrypted), nil
}
// 方法2通过证书PEM加密支持分块
func RSAEncryptByCert(data, cert []byte) (string, error) {
// 解析PEM证书
block, _ := pem.Decode(cert)
if block == nil || block.Type != "CERTIFICATE" {
return "", errors.New("failed to parse certificate PEM")
}
// 解析X.509证书
crt, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return "", fmt.Errorf("failed to parse certificate: %w", err)
}
// 获取RSA公钥
publicKey, ok := crt.PublicKey.(*rsa.PublicKey)
if !ok {
return "", errors.New("certificate does not contain RSA public key")
}
// 计算最大分块大小
keySize := publicKey.Size()
maxChunk := 0
maxChunk = keySize - 42 // OAEP填充
// 不需要分块的情况
if len(data) <= maxChunk {
encrypted, err := rsaEncrypt(data, publicKey, true)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(encrypted), nil
}
// 需要分块的情况
var encrypted []byte
for i := 0; i < len(data); i += maxChunk {
end := i + maxChunk
if end > len(data) {
end = len(data)
}
chunk, err := rsaEncrypt(data[i:end], publicKey, true)
if err != nil {
return "", fmt.Errorf("RSA encryption failed: %w", err)
}
encrypted = append(encrypted, chunk...)
}
return base64.StdEncoding.EncodeToString(encrypted), nil
}
func rsaDecryptBlock(block []byte, privateKey *rsa.PrivateKey, useOAEP bool) ([]byte, error) {
if useOAEP {
// Use OAEP with SHA-256
return rsa.DecryptOAEP(
sha256.New(),
rand.Reader,
privateKey,
block,
nil,
)
} else {
// Use PKCS1v15 for backward compatibility
return rsa.DecryptPKCS1v15(
rand.Reader,
privateKey,
block,
)
}
}
func RSADecrypt(encryptedBase64 string, cfg *tls.Config) ([]byte, error) {
// Validate inputs
if encryptedBase64 == "" {
return nil, errors.New("encrypted data is empty")
}
if cfg == nil {
return nil, errors.New("TLS config is nil")
}
if len(cfg.Certificates) == 0 {
return nil, errors.New("no certificates found in TLS config")
}
// Decode base64 string
encryptedData, err := base64.StdEncoding.DecodeString(encryptedBase64)
if err != nil {
return nil, fmt.Errorf("base64 decoding failed: %w", err)
}
// Get private key from first certificate
privateKey, ok := cfg.Certificates[0].PrivateKey.(*rsa.PrivateKey)
if !ok {
return nil, errors.New("failed to get RSA private key from TLS config")
}
// Determine block size for decryption
blockSize := privateKey.Size()
if len(encryptedData)%blockSize != 0 {
return nil, fmt.Errorf("invalid encrypted data size. Expected multiple of %d, got %d",
blockSize, len(encryptedData))
}
// Handle single block decryption
if len(encryptedData) == blockSize {
return rsaDecryptBlock(encryptedData, privateKey, true)
}
// Handle multi-block decryption
var decrypted []byte
for i := 0; i < len(encryptedData); i += blockSize {
end := i + blockSize
block := encryptedData[i:end]
decryptedBlock, err := rsaDecryptBlock(block, privateKey, true)
if err != nil {
return nil, fmt.Errorf("block decryption failed at offset %d: %w", i, err)
}
decrypted = append(decrypted, decryptedBlock...)
}
return decrypted, nil
}

43
tool/rsa_test.go Normal file
View File

@ -0,0 +1,43 @@
package tool
import "testing"
func TestRSA(t *testing.T) {
var (
crt, key string
)
cfg, err := LoadTLSConfig(StringToBytes(crt), StringToBytes(key))
if err != nil {
t.Fatal(err)
}
raw := []byte("admin")
enc, err := RSAEncrypt(raw, cfg)
if err != nil {
t.Fatal(err)
}
t.Logf("Encrypted data: %s", enc)
org, err := RSADecrypt(enc, cfg)
if err != nil {
t.Fatal(err)
}
t.Logf("Decrypted data: %s", string(org))
enc2, err := RSAEncryptByCert(raw, []byte(crt))
if err != nil {
t.Fatal(err)
}
org2, err := RSADecrypt(enc2, cfg)
if err != nil {
t.Fatal(err)
}
if string(org) != string(org2) {
t.Fatalf("Original and decrypted data don't match, org1 = %s, org2 = %s", org, org2)
}
}

131
tool/tls.go Normal file
View File

@ -0,0 +1,131 @@
package tool
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"time"
)
func GenerateTlsConfig() (serverTLSConf *tls.Config, clientTLSConf *tls.Config, err error) {
ca := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{
Organization: []string{"Company, INC."},
Country: []string{"US"},
Province: []string{""},
Locality: []string{"San Francisco"},
StreetAddress: []string{"Golden Gate Bridge"},
PostalCode: []string{"94016"},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(99, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
// create our private and public key
caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}
// create the CA
caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
if err != nil {
return nil, nil, err
}
// pem encode
caPEM := new(bytes.Buffer)
pem.Encode(caPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: caBytes,
})
caPrivKeyPEM := new(bytes.Buffer)
pem.Encode(caPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey),
})
// set up our server certificate
cert := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{
Organization: []string{"Company, INC."},
Country: []string{"US"},
Province: []string{""},
Locality: []string{"San Francisco"},
StreetAddress: []string{"Golden Gate Bridge"},
PostalCode: []string{"94016"},
},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
SubjectKeyId: []byte{1, 2, 3, 4, 6},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
}
certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}
certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey)
if err != nil {
return nil, nil, err
}
certPEM := new(bytes.Buffer)
pem.Encode(certPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})
certPrivKeyPEM := new(bytes.Buffer)
pem.Encode(certPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
})
serverCert, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes())
if err != nil {
return nil, nil, err
}
serverTLSConf = &tls.Config{
Certificates: []tls.Certificate{serverCert},
}
certpool := x509.NewCertPool()
certpool.AppendCertsFromPEM(caPEM.Bytes())
clientTLSConf = &tls.Config{
RootCAs: certpool,
}
return
}
// LoadTLSConfig 从字节数据加载TLS配置
func LoadTLSConfig(certPEM, keyPEM []byte, caPEMs ...[]byte) (*tls.Config, error) {
// 加载客户端证书密钥对
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
return nil, err
}
config := &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12, // 设置最低TLS版本
}
// 加载CA证书如果有
if len(caPEMs) > 0 {
pool := x509.NewCertPool()
for _, caPEM := range caPEMs {
if !pool.AppendCertsFromPEM(caPEM) {
return nil, x509.SystemRootsError{}
}
}
config.RootCAs = pool
}
return config, nil
}