feat: add tool aes

This commit is contained in:
zhaoyupeng
2025-07-16 18:51:52 +08:00
parent 868b959c6f
commit 104745e260
2 changed files with 120 additions and 0 deletions

95
tool/aes.go Normal file
View File

@ -0,0 +1,95 @@
package tool
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"io"
)
// AES加密CBC模式PKCS7填充
func AesEncrypt(data []byte, key []byte) (string, error) {
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
// 添加PKCS7填充
data = pkcs7Pad(data, block.BlockSize())
// 创建存储密文的buffer前aes.BlockSize字节存储IV
ciphertext := make([]byte, aes.BlockSize+len(data))
iv := ciphertext[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return "", err
}
// CBC加密
mode := cipher.NewCBCEncrypter(block, iv)
mode.CryptBlocks(ciphertext[aes.BlockSize:], data)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// AES解密CBC模式PKCS7填充
func AesDecrypt(encrypted string, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
ciphertext, err := base64.StdEncoding.DecodeString(encrypted)
if err != nil {
return nil, err
}
if len(ciphertext) < aes.BlockSize {
return nil, errors.New("ciphertext too short")
}
iv := ciphertext[:aes.BlockSize]
ciphertext = ciphertext[aes.BlockSize:]
if len(ciphertext)%aes.BlockSize != 0 {
return nil, errors.New("ciphertext is not a multiple of the block size")
}
// CBC解密
mode := cipher.NewCBCDecrypter(block, iv)
mode.CryptBlocks(ciphertext, ciphertext)
// 去除PKCS7填充
return pkcs7Unpad(ciphertext)
}
// PKCS7填充
func pkcs7Pad(data []byte, blockSize int) []byte {
padding := blockSize - (len(data) % blockSize)
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(data, padText...)
}
// PKCS7去除填充
func pkcs7Unpad(data []byte) ([]byte, error) {
length := len(data)
if length == 0 {
return nil, errors.New("empty input")
}
padding := int(data[length-1])
if padding > length || padding == 0 {
return nil, errors.New("invalid padding")
}
// 验证填充字节是否正确
for i := length - padding; i < length; i++ {
if int(data[i]) != padding {
return nil, errors.New("invalid padding")
}
}
return data[:length-padding], nil
}

25
tool/aes_test.go Normal file
View File

@ -0,0 +1,25 @@
package tool
import (
"os"
"testing"
)
func TestAes(t *testing.T) {
key := os.Getenv("AES_KEY")
name := "admin"
res, err := AesEncrypt([]byte(name), []byte(key))
if err != nil {
t.Fatal(err)
}
t.Logf("res = %s", string(res))
raw, err := AesDecrypt(res, []byte(key))
if err != nil {
t.Fatal(err)
}
t.Logf("raw = %s", string(raw))
}