diff --git a/.gitignore b/.gitignore index ebf869b..8abc79a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ .idea .vscode -.DS_Store \ No newline at end of file +.DS_Store + +xtest \ No newline at end of file diff --git a/tool/aes_test.go b/tool/aes_test.go index 9ada9dd..096f904 100644 --- a/tool/aes_test.go +++ b/tool/aes_test.go @@ -8,7 +8,7 @@ import ( func TestAes(t *testing.T) { key := os.Getenv("AES_KEY") - name := "admin" + name := "YizhiSEC@123" res, err := AesEncrypt([]byte(name), []byte(key)) if err != nil { t.Fatal(err) @@ -23,3 +23,14 @@ func TestAes(t *testing.T) { t.Logf("raw = %s", string(raw)) } + +func TestDecrypt(t *testing.T) { + key := os.Getenv("AES_KEY") + enc := "2hurNK+0+b9lEO2hNAkc+TzVx7KH7S0/mRt7mWBJiFA=" + + raw, err := AesDecrypt(enc, []byte(key)) + if err != nil { + t.Fatal(err) + } + t.Logf("raw = %s", string(raw)) +} diff --git a/tool/rsa.go b/tool/rsa.go new file mode 100644 index 0000000..e5fd223 --- /dev/null +++ b/tool/rsa.go @@ -0,0 +1,186 @@ +package tool + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "errors" + "fmt" +) + +// 通用RSA加密方法(支持选择填充方式) +func rsaEncrypt(data []byte, publicKey *rsa.PublicKey, useOAEP bool) ([]byte, error) { + if len(data) == 0 { + return nil, errors.New("input data is empty") + } + + if useOAEP { + // 使用OAEP填充(更安全) + hash := sha256.New() + return rsa.EncryptOAEP(hash, rand.Reader, publicKey, data, nil) + } else { + // 使用PKCS1v15填充(兼容旧系统) + return rsa.EncryptPKCS1v15(rand.Reader, publicKey, data) + } +} + +// 方法1:通过tls.Config加密 +func RSAEncrypt(data []byte, cfg *tls.Config) (string, error) { + // 验证参数 + if len(data) == 0 { + return "", errors.New("input data is empty") + } + if cfg == nil { + return "", errors.New("TLS config is nil") + } + + // 获取第一个证书 + if len(cfg.Certificates) == 0 { + return "", errors.New("no certificates found in TLS config") + } + cert := cfg.Certificates[0] + + // 获取RSA公钥 + rsaPublicKey, ok := cert.Leaf.PublicKey.(*rsa.PublicKey) + if !ok { + return "", errors.New("certificate does not contain RSA public key") + } + + // 加密(无分块,适用于小数据) + encrypted, err := rsaEncrypt(data, rsaPublicKey, true) + if err != nil { + return "", fmt.Errorf("encryption failed: %w", err) + } + + return base64.StdEncoding.EncodeToString(encrypted), nil +} + +// 方法2:通过证书PEM加密(支持分块) +func RSAEncryptByCert(data, cert []byte) (string, error) { + // 解析PEM证书 + block, _ := pem.Decode(cert) + if block == nil || block.Type != "CERTIFICATE" { + return "", errors.New("failed to parse certificate PEM") + } + + // 解析X.509证书 + crt, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return "", fmt.Errorf("failed to parse certificate: %w", err) + } + + // 获取RSA公钥 + publicKey, ok := crt.PublicKey.(*rsa.PublicKey) + if !ok { + return "", errors.New("certificate does not contain RSA public key") + } + + // 计算最大分块大小 + keySize := publicKey.Size() + maxChunk := 0 + + maxChunk = keySize - 42 // OAEP填充 + + // 不需要分块的情况 + if len(data) <= maxChunk { + encrypted, err := rsaEncrypt(data, publicKey, true) + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(encrypted), nil + } + + // 需要分块的情况 + var encrypted []byte + for i := 0; i < len(data); i += maxChunk { + end := i + maxChunk + if end > len(data) { + end = len(data) + } + + chunk, err := rsaEncrypt(data[i:end], publicKey, true) + if err != nil { + return "", fmt.Errorf("RSA encryption failed: %w", err) + } + encrypted = append(encrypted, chunk...) + } + + return base64.StdEncoding.EncodeToString(encrypted), nil +} + +func rsaDecryptBlock(block []byte, privateKey *rsa.PrivateKey, useOAEP bool) ([]byte, error) { + if useOAEP { + // Use OAEP with SHA-256 + return rsa.DecryptOAEP( + sha256.New(), + rand.Reader, + privateKey, + block, + nil, + ) + } else { + // Use PKCS1v15 for backward compatibility + return rsa.DecryptPKCS1v15( + rand.Reader, + privateKey, + block, + ) + } +} + +func RSADecrypt(encryptedBase64 string, cfg *tls.Config) ([]byte, error) { + // Validate inputs + if encryptedBase64 == "" { + return nil, errors.New("encrypted data is empty") + } + if cfg == nil { + return nil, errors.New("TLS config is nil") + } + if len(cfg.Certificates) == 0 { + return nil, errors.New("no certificates found in TLS config") + } + + // Decode base64 string + encryptedData, err := base64.StdEncoding.DecodeString(encryptedBase64) + if err != nil { + return nil, fmt.Errorf("base64 decoding failed: %w", err) + } + + // Get private key from first certificate + privateKey, ok := cfg.Certificates[0].PrivateKey.(*rsa.PrivateKey) + if !ok { + return nil, errors.New("failed to get RSA private key from TLS config") + } + + // Determine block size for decryption + blockSize := privateKey.Size() + if len(encryptedData)%blockSize != 0 { + return nil, fmt.Errorf("invalid encrypted data size. Expected multiple of %d, got %d", + blockSize, len(encryptedData)) + } + + // Handle single block decryption + if len(encryptedData) == blockSize { + return rsaDecryptBlock(encryptedData, privateKey, true) + } + + // Handle multi-block decryption + var decrypted []byte + for i := 0; i < len(encryptedData); i += blockSize { + end := i + blockSize + block := encryptedData[i:end] + + decryptedBlock, err := rsaDecryptBlock(block, privateKey, true) + if err != nil { + return nil, fmt.Errorf("block decryption failed at offset %d: %w", i, err) + } + + decrypted = append(decrypted, decryptedBlock...) + } + + return decrypted, nil +} diff --git a/tool/rsa_test.go b/tool/rsa_test.go new file mode 100644 index 0000000..2d59533 --- /dev/null +++ b/tool/rsa_test.go @@ -0,0 +1,43 @@ +package tool + +import "testing" + +func TestRSA(t *testing.T) { + var ( + crt, key string + ) + + cfg, err := LoadTLSConfig(StringToBytes(crt), StringToBytes(key)) + if err != nil { + t.Fatal(err) + } + + raw := []byte("admin") + enc, err := RSAEncrypt(raw, cfg) + if err != nil { + t.Fatal(err) + } + + t.Logf("Encrypted data: %s", enc) + + org, err := RSADecrypt(enc, cfg) + if err != nil { + t.Fatal(err) + } + + t.Logf("Decrypted data: %s", string(org)) + + enc2, err := RSAEncryptByCert(raw, []byte(crt)) + if err != nil { + t.Fatal(err) + } + + org2, err := RSADecrypt(enc2, cfg) + if err != nil { + t.Fatal(err) + } + + if string(org) != string(org2) { + t.Fatalf("Original and decrypted data don't match, org1 = %s, org2 = %s", org, org2) + } +} diff --git a/tool/tls.go b/tool/tls.go new file mode 100644 index 0000000..f2a5f35 --- /dev/null +++ b/tool/tls.go @@ -0,0 +1,131 @@ +package tool + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "time" +) + +func GenerateTlsConfig() (serverTLSConf *tls.Config, clientTLSConf *tls.Config, err error) { + ca := &x509.Certificate{ + SerialNumber: big.NewInt(2019), + Subject: pkix.Name{ + Organization: []string{"Company, INC."}, + Country: []string{"US"}, + Province: []string{""}, + Locality: []string{"San Francisco"}, + StreetAddress: []string{"Golden Gate Bridge"}, + PostalCode: []string{"94016"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(99, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + // create our private and public key + caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, nil, err + } + // create the CA + caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey) + if err != nil { + return nil, nil, err + } + // pem encode + caPEM := new(bytes.Buffer) + pem.Encode(caPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: caBytes, + }) + caPrivKeyPEM := new(bytes.Buffer) + pem.Encode(caPrivKeyPEM, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey), + }) + // set up our server certificate + cert := &x509.Certificate{ + SerialNumber: big.NewInt(2019), + Subject: pkix.Name{ + Organization: []string{"Company, INC."}, + Country: []string{"US"}, + Province: []string{""}, + Locality: []string{"San Francisco"}, + StreetAddress: []string{"Golden Gate Bridge"}, + PostalCode: []string{"94016"}, + }, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + SubjectKeyId: []byte{1, 2, 3, 4, 6}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, nil, err + } + certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey) + if err != nil { + return nil, nil, err + } + certPEM := new(bytes.Buffer) + pem.Encode(certPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }) + certPrivKeyPEM := new(bytes.Buffer) + pem.Encode(certPrivKeyPEM, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey), + }) + serverCert, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes()) + if err != nil { + return nil, nil, err + } + serverTLSConf = &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + } + certpool := x509.NewCertPool() + certpool.AppendCertsFromPEM(caPEM.Bytes()) + clientTLSConf = &tls.Config{ + RootCAs: certpool, + } + return +} + +// LoadTLSConfig 从字节数据加载TLS配置 +func LoadTLSConfig(certPEM, keyPEM []byte, caPEMs ...[]byte) (*tls.Config, error) { + // 加载客户端证书密钥对 + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return nil, err + } + + config := &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, // 设置最低TLS版本 + } + + // 加载CA证书(如果有) + if len(caPEMs) > 0 { + pool := x509.NewCertPool() + for _, caPEM := range caPEMs { + if !pool.AppendCertsFromPEM(caPEM) { + return nil, x509.SystemRootsError{} + } + } + config.RootCAs = pool + } + + return config, nil +}