diff --git a/tool/aes.go b/tool/aes.go new file mode 100644 index 0000000..b4bd4eb --- /dev/null +++ b/tool/aes.go @@ -0,0 +1,95 @@ +package tool + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "errors" + "io" +) + +// AES加密(CBC模式,PKCS7填充) +func AesEncrypt(data []byte, key []byte) (string, error) { + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + // 添加PKCS7填充 + data = pkcs7Pad(data, block.BlockSize()) + + // 创建存储密文的buffer,前aes.BlockSize字节存储IV + ciphertext := make([]byte, aes.BlockSize+len(data)) + iv := ciphertext[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return "", err + } + + // CBC加密 + mode := cipher.NewCBCEncrypter(block, iv) + mode.CryptBlocks(ciphertext[aes.BlockSize:], data) + + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +// AES解密(CBC模式,PKCS7填充) +func AesDecrypt(encrypted string, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + ciphertext, err := base64.StdEncoding.DecodeString(encrypted) + if err != nil { + return nil, err + } + + if len(ciphertext) < aes.BlockSize { + return nil, errors.New("ciphertext too short") + } + + iv := ciphertext[:aes.BlockSize] + ciphertext = ciphertext[aes.BlockSize:] + + if len(ciphertext)%aes.BlockSize != 0 { + return nil, errors.New("ciphertext is not a multiple of the block size") + } + + // CBC解密 + mode := cipher.NewCBCDecrypter(block, iv) + mode.CryptBlocks(ciphertext, ciphertext) + + // 去除PKCS7填充 + return pkcs7Unpad(ciphertext) +} + +// PKCS7填充 +func pkcs7Pad(data []byte, blockSize int) []byte { + padding := blockSize - (len(data) % blockSize) + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(data, padText...) +} + +// PKCS7去除填充 +func pkcs7Unpad(data []byte) ([]byte, error) { + length := len(data) + if length == 0 { + return nil, errors.New("empty input") + } + + padding := int(data[length-1]) + if padding > length || padding == 0 { + return nil, errors.New("invalid padding") + } + + // 验证填充字节是否正确 + for i := length - padding; i < length; i++ { + if int(data[i]) != padding { + return nil, errors.New("invalid padding") + } + } + + return data[:length-padding], nil +} diff --git a/tool/aes_test.go b/tool/aes_test.go new file mode 100644 index 0000000..9ada9dd --- /dev/null +++ b/tool/aes_test.go @@ -0,0 +1,25 @@ +package tool + +import ( + "os" + "testing" +) + +func TestAes(t *testing.T) { + key := os.Getenv("AES_KEY") + + name := "admin" + res, err := AesEncrypt([]byte(name), []byte(key)) + if err != nil { + t.Fatal(err) + } + + t.Logf("res = %s", string(res)) + + raw, err := AesDecrypt(res, []byte(key)) + if err != nil { + t.Fatal(err) + } + + t.Logf("raw = %s", string(raw)) +}