Compare commits
5 Commits
Author | SHA1 | Date | |
---|---|---|---|
2172a39a20 | |||
104745e260 | |||
868b959c6f | |||
127c57dc3a | |||
d6b0b8ea36 |
4
.gitignore
vendored
4
.gitignore
vendored
@ -1,3 +1,5 @@
|
|||||||
.idea
|
.idea
|
||||||
.vscode
|
.vscode
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
|
xtest
|
@ -23,7 +23,7 @@ type option struct {
|
|||||||
|
|
||||||
func WithName(name string) Option {
|
func WithName(name string) Option {
|
||||||
return func(o *option) {
|
return func(o *option) {
|
||||||
if name == "" {
|
if name != "" {
|
||||||
o.name = name
|
o.name = name
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -31,7 +31,7 @@ func WithName(name string) Option {
|
|||||||
|
|
||||||
func WithVersion(version string) Option {
|
func WithVersion(version string) Option {
|
||||||
return func(o *option) {
|
return func(o *option) {
|
||||||
if version == "" {
|
if version != "" {
|
||||||
o.version = version
|
o.version = version
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -39,7 +39,7 @@ func WithVersion(version string) Option {
|
|||||||
|
|
||||||
func WithAddress(address string) Option {
|
func WithAddress(address string) Option {
|
||||||
return func(o *option) {
|
return func(o *option) {
|
||||||
if address == "" {
|
if address != "" {
|
||||||
o.address = address
|
o.address = address
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
1
database/cache/cache.go
vendored
1
database/cache/cache.go
vendored
@ -62,6 +62,7 @@ var (
|
|||||||
marshaler func(data any) ([]byte, error) = json.Marshal
|
marshaler func(data any) ([]byte, error) = json.Marshal
|
||||||
unmarshaler func(data []byte, model any) error = json.Unmarshal
|
unmarshaler func(data []byte, model any) error = json.Unmarshal
|
||||||
ErrorKeyNotFound = errors.New("key not found")
|
ErrorKeyNotFound = errors.New("key not found")
|
||||||
|
ErrorStoreFailed = errors.New("store failed")
|
||||||
Default Cache
|
Default Cache
|
||||||
)
|
)
|
||||||
|
|
||||||
|
155
database/cache/memory.go
vendored
Normal file
155
database/cache/memory.go
vendored
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/dgraph-io/ristretto/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ Cache = (*_mem)(nil)
|
||||||
|
|
||||||
|
type _mem struct {
|
||||||
|
ctx context.Context
|
||||||
|
cache *ristretto.Cache[string, []byte]
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMemory(ctx context.Context, ins *ristretto.Cache[string, []byte]) Cache {
|
||||||
|
return &_mem{
|
||||||
|
ctx: ctx,
|
||||||
|
cache: ins,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *_mem) Client() any {
|
||||||
|
return m.cache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *_mem) Close() {
|
||||||
|
c.cache.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *_mem) Del(ctx context.Context, keys ...string) error {
|
||||||
|
for _, key := range keys {
|
||||||
|
c.cache.Del(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *_mem) Get(ctx context.Context, key string) ([]byte, error) {
|
||||||
|
val, ok := c.cache.Get(key)
|
||||||
|
if !ok {
|
||||||
|
return val, ErrorKeyNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *_mem) GetDel(ctx context.Context, key string) ([]byte, error) {
|
||||||
|
val, err := c.Get(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
return val, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.cache.Del(key)
|
||||||
|
|
||||||
|
return val, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *_mem) GetDelScan(ctx context.Context, key string) Scanner {
|
||||||
|
val, err := c.GetDel(ctx, key)
|
||||||
|
return newScan(val, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *_mem) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
|
||||||
|
val, err := c.Get(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
return val, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.cache.SetWithTTL(key, val, 1, duration)
|
||||||
|
|
||||||
|
return val, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *_mem) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner {
|
||||||
|
val, err := m.GetEx(ctx, key, duration)
|
||||||
|
return newScan(val, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *_mem) GetScan(ctx context.Context, key string) Scanner {
|
||||||
|
val, err := m.Get(ctx, key)
|
||||||
|
return newScan(val, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *_mem) Gets(ctx context.Context, keys ...string) ([][]byte, error) {
|
||||||
|
vals := make([][]byte, 0, len(keys))
|
||||||
|
|
||||||
|
for _, key := range keys {
|
||||||
|
val, err := m.Get(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrorKeyNotFound) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return vals, err
|
||||||
|
}
|
||||||
|
|
||||||
|
vals = append(vals, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(vals) != len(keys) {
|
||||||
|
return vals, ErrorKeyNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return vals, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *_mem) Set(ctx context.Context, key string, value any) error {
|
||||||
|
val, err := handleValue(value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok := m.cache.Set(key, val, 1); !ok {
|
||||||
|
return ErrorStoreFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
m.cache.Wait()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *_mem) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
|
||||||
|
val, err := handleValue(value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok := m.cache.SetWithTTL(key, val, 1, duration); !ok {
|
||||||
|
return ErrorStoreFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
m.cache.Wait()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *_mem) Sets(ctx context.Context, vm map[string]any) error {
|
||||||
|
for key, value := range vm {
|
||||||
|
val, err := handleValue(value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok := m.cache.Set(key, val, 1); !ok {
|
||||||
|
return ErrorStoreFailed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.cache.Wait()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
42
database/cache/new.go
vendored
42
database/cache/new.go
vendored
@ -4,35 +4,31 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"gitea.loveuer.com/yizhisec/packages/tool"
|
"gitea.loveuer.com/yizhisec/packages/tool"
|
||||||
|
"github.com/dgraph-io/ristretto/v2"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
_ "github.com/go-redis/redis/v8"
|
_ "github.com/go-redis/redis/v8"
|
||||||
"net/url"
|
"net/url"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
func New(opts ...Option) (Cache, error) {
|
||||||
defaultRedis = "redis://127.0.0.1:6379"
|
|
||||||
)
|
|
||||||
|
|
||||||
func New(opts ...OptionFn) (Cache, error) {
|
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
cfg = &config{
|
opt = &option{
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
redis: &defaultRedis,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, opt := range opts {
|
for _, fn := range opts {
|
||||||
opt(cfg)
|
fn(opt)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.redis != nil {
|
if opt.redis != nil {
|
||||||
var (
|
var (
|
||||||
ins *url.URL
|
ins *url.URL
|
||||||
client *redis.Client
|
client *redis.Client
|
||||||
)
|
)
|
||||||
|
|
||||||
if ins, err = url.Parse(*cfg.redis); err != nil {
|
if ins, err = url.Parse(*opt.redis); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,17 +41,33 @@ func New(opts ...OptionFn) (Cache, error) {
|
|||||||
Password: password,
|
Password: password,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err = client.Ping(tool.CtxTimeout(cfg.ctx, 5)).Err(); err != nil {
|
if err = client.Ping(tool.TimeoutCtx(opt.ctx, 5)).Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return newRedis(cfg.ctx, client), nil
|
return newRedis(opt.ctx, client), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if opt.memory {
|
||||||
|
var (
|
||||||
|
ins *ristretto.Cache[string, []byte]
|
||||||
|
)
|
||||||
|
|
||||||
|
if ins, err = ristretto.NewCache(&ristretto.Config[string, []byte]{
|
||||||
|
NumCounters: 1e7,
|
||||||
|
MaxCost: 1 << 30,
|
||||||
|
BufferItems: 64,
|
||||||
|
}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return newMemory(opt.ctx, ins), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("invalid cache config")
|
return nil, fmt.Errorf("invalid cache config")
|
||||||
}
|
}
|
||||||
|
|
||||||
func Init(opts ...OptionFn) (err error) {
|
func Init(opts ...Option) (err error) {
|
||||||
Default, err = New(opts...)
|
Default, err = New(opts...)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
29
database/cache/new_test.go
vendored
29
database/cache/new_test.go
vendored
@ -1,6 +1,7 @@
|
|||||||
package cache
|
package cache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -71,3 +72,31 @@ func TestNoAuth(t *testing.T) {
|
|||||||
// t.Fatal(err)
|
// t.Fatal(err)
|
||||||
//}
|
//}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMemory(t *testing.T) {
|
||||||
|
c, err := New(WithMemory())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bs, err := c.Get(t.Context(), "haha")
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, ErrorKeyNotFound) {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("key not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("haha = %s", string(bs))
|
||||||
|
|
||||||
|
if err = c.Set(t.Context(), "haha", "haha"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if bs, err = c.Get(t.Context(), "haha"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("haha = %s", string(bs))
|
||||||
|
}
|
||||||
|
23
database/cache/option.go
vendored
23
database/cache/option.go
vendored
@ -5,23 +5,24 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
type config struct {
|
type option struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
redis *string
|
redis *string
|
||||||
|
memory bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type OptionFn func(*config)
|
type Option func(*option)
|
||||||
|
|
||||||
func WithCtx(ctx context.Context) OptionFn {
|
func WithCtx(ctx context.Context) Option {
|
||||||
return func(c *config) {
|
return func(c *option) {
|
||||||
if ctx != nil {
|
if ctx != nil {
|
||||||
c.ctx = ctx
|
c.ctx = ctx
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithRedis(host string, port int, username, password string) OptionFn {
|
func WithRedis(host string, port int, username, password string) Option {
|
||||||
return func(c *config) {
|
return func(c *option) {
|
||||||
uri := fmt.Sprintf("redis://%s:%d", host, port)
|
uri := fmt.Sprintf("redis://%s:%d", host, port)
|
||||||
if username != "" || password != "" {
|
if username != "" || password != "" {
|
||||||
uri = fmt.Sprintf("redis://%s:%s@%s:%d", username, password, host, port)
|
uri = fmt.Sprintf("redis://%s:%s@%s:%d", username, password, host, port)
|
||||||
@ -30,3 +31,9 @@ func WithRedis(host string, port int, username, password string) OptionFn {
|
|||||||
c.redis = &uri
|
c.redis = &uri
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithMemory() Option {
|
||||||
|
return func(c *option) {
|
||||||
|
c.memory = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
7
go.mod
7
go.mod
@ -1,10 +1,11 @@
|
|||||||
module gitea.loveuer.com/yizhisec/packages
|
module gitea.loveuer.com/yizhisec/packages
|
||||||
|
|
||||||
go 1.23
|
go 1.23.0
|
||||||
|
|
||||||
toolchain go1.24.3
|
toolchain go1.24.3
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/dgraph-io/ristretto/v2 v2.2.0
|
||||||
github.com/fatih/color v1.18.0
|
github.com/fatih/color v1.18.0
|
||||||
github.com/gin-gonic/gin v1.10.1
|
github.com/gin-gonic/gin v1.10.1
|
||||||
github.com/glebarez/sqlite v1.11.0
|
github.com/glebarez/sqlite v1.11.0
|
||||||
@ -23,7 +24,7 @@ require (
|
|||||||
filippo.io/edwards25519 v1.1.0 // indirect
|
filippo.io/edwards25519 v1.1.0 // indirect
|
||||||
github.com/bytedance/sonic v1.11.6 // indirect
|
github.com/bytedance/sonic v1.11.6 // indirect
|
||||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.1.1 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
@ -60,7 +61,7 @@ require (
|
|||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
golang.org/x/net v0.27.0 // indirect
|
golang.org/x/net v0.27.0 // indirect
|
||||||
golang.org/x/sync v0.11.0 // indirect
|
golang.org/x/sync v0.11.0 // indirect
|
||||||
golang.org/x/sys v0.30.0 // indirect
|
golang.org/x/sys v0.31.0 // indirect
|
||||||
golang.org/x/text v0.22.0 // indirect
|
golang.org/x/text v0.22.0 // indirect
|
||||||
google.golang.org/protobuf v1.34.1 // indirect
|
google.golang.org/protobuf v1.34.1 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
11
go.sum
11
go.sum
@ -6,8 +6,9 @@ github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc
|
|||||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||||
github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY=
|
|
||||||
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||||
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||||
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
|
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
|
||||||
@ -19,6 +20,10 @@ github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7Do
|
|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dgraph-io/ristretto/v2 v2.2.0 h1:bkY3XzJcXoMuELV8F+vS8kzNgicwQFAaGINAEJdWGOM=
|
||||||
|
github.com/dgraph-io/ristretto/v2 v2.2.0/go.mod h1:RZrm63UmcBAaYWC1DotLYBmTvgkrs0+XhBd7Npn7/zI=
|
||||||
|
github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa512G+w+Pxci9hJPB8oMnkcP3iZF38=
|
||||||
|
github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
@ -321,8 +326,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
|
95
tool/aes.go
Normal file
95
tool/aes.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
package tool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AES加密(CBC模式,PKCS7填充)
|
||||||
|
func AesEncrypt(data []byte, key []byte) (string, error) {
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加PKCS7填充
|
||||||
|
data = pkcs7Pad(data, block.BlockSize())
|
||||||
|
|
||||||
|
// 创建存储密文的buffer,前aes.BlockSize字节存储IV
|
||||||
|
ciphertext := make([]byte, aes.BlockSize+len(data))
|
||||||
|
iv := ciphertext[:aes.BlockSize]
|
||||||
|
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CBC加密
|
||||||
|
mode := cipher.NewCBCEncrypter(block, iv)
|
||||||
|
mode.CryptBlocks(ciphertext[aes.BlockSize:], data)
|
||||||
|
|
||||||
|
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AES解密(CBC模式,PKCS7填充)
|
||||||
|
func AesDecrypt(encrypted string, key []byte) ([]byte, error) {
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ciphertext, err := base64.StdEncoding.DecodeString(encrypted)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ciphertext) < aes.BlockSize {
|
||||||
|
return nil, errors.New("ciphertext too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
iv := ciphertext[:aes.BlockSize]
|
||||||
|
ciphertext = ciphertext[aes.BlockSize:]
|
||||||
|
|
||||||
|
if len(ciphertext)%aes.BlockSize != 0 {
|
||||||
|
return nil, errors.New("ciphertext is not a multiple of the block size")
|
||||||
|
}
|
||||||
|
|
||||||
|
// CBC解密
|
||||||
|
mode := cipher.NewCBCDecrypter(block, iv)
|
||||||
|
mode.CryptBlocks(ciphertext, ciphertext)
|
||||||
|
|
||||||
|
// 去除PKCS7填充
|
||||||
|
return pkcs7Unpad(ciphertext)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PKCS7填充
|
||||||
|
func pkcs7Pad(data []byte, blockSize int) []byte {
|
||||||
|
padding := blockSize - (len(data) % blockSize)
|
||||||
|
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||||
|
return append(data, padText...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PKCS7去除填充
|
||||||
|
func pkcs7Unpad(data []byte) ([]byte, error) {
|
||||||
|
length := len(data)
|
||||||
|
if length == 0 {
|
||||||
|
return nil, errors.New("empty input")
|
||||||
|
}
|
||||||
|
|
||||||
|
padding := int(data[length-1])
|
||||||
|
if padding > length || padding == 0 {
|
||||||
|
return nil, errors.New("invalid padding")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证填充字节是否正确
|
||||||
|
for i := length - padding; i < length; i++ {
|
||||||
|
if int(data[i]) != padding {
|
||||||
|
return nil, errors.New("invalid padding")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return data[:length-padding], nil
|
||||||
|
}
|
36
tool/aes_test.go
Normal file
36
tool/aes_test.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
package tool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAes(t *testing.T) {
|
||||||
|
key := os.Getenv("AES_KEY")
|
||||||
|
|
||||||
|
name := "YizhiSEC@123"
|
||||||
|
res, err := AesEncrypt([]byte(name), []byte(key))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("res = %s", string(res))
|
||||||
|
|
||||||
|
raw, err := AesDecrypt(res, []byte(key))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("raw = %s", string(raw))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecrypt(t *testing.T) {
|
||||||
|
key := os.Getenv("AES_KEY")
|
||||||
|
enc := "2hurNK+0+b9lEO2hNAkc+TzVx7KH7S0/mRt7mWBJiFA="
|
||||||
|
|
||||||
|
raw, err := AesDecrypt(enc, []byte(key))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Logf("raw = %s", string(raw))
|
||||||
|
}
|
@ -23,7 +23,7 @@ func Timeout(seconds ...int) (ctx context.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func CtxTimeout(ctx context.Context, seconds ...int) context.Context {
|
func TimeoutCtx(ctx context.Context, seconds ...int) context.Context {
|
||||||
var (
|
var (
|
||||||
duration time.Duration
|
duration time.Duration
|
||||||
)
|
)
|
||||||
|
186
tool/rsa.go
Normal file
186
tool/rsa.go
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
package tool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 通用RSA加密方法(支持选择填充方式)
|
||||||
|
func rsaEncrypt(data []byte, publicKey *rsa.PublicKey, useOAEP bool) ([]byte, error) {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil, errors.New("input data is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if useOAEP {
|
||||||
|
// 使用OAEP填充(更安全)
|
||||||
|
hash := sha256.New()
|
||||||
|
return rsa.EncryptOAEP(hash, rand.Reader, publicKey, data, nil)
|
||||||
|
} else {
|
||||||
|
// 使用PKCS1v15填充(兼容旧系统)
|
||||||
|
return rsa.EncryptPKCS1v15(rand.Reader, publicKey, data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 方法1:通过tls.Config加密
|
||||||
|
func RSAEncrypt(data []byte, cfg *tls.Config) (string, error) {
|
||||||
|
// 验证参数
|
||||||
|
if len(data) == 0 {
|
||||||
|
return "", errors.New("input data is empty")
|
||||||
|
}
|
||||||
|
if cfg == nil {
|
||||||
|
return "", errors.New("TLS config is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取第一个证书
|
||||||
|
if len(cfg.Certificates) == 0 {
|
||||||
|
return "", errors.New("no certificates found in TLS config")
|
||||||
|
}
|
||||||
|
cert := cfg.Certificates[0]
|
||||||
|
|
||||||
|
// 获取RSA公钥
|
||||||
|
rsaPublicKey, ok := cert.Leaf.PublicKey.(*rsa.PublicKey)
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("certificate does not contain RSA public key")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加密(无分块,适用于小数据)
|
||||||
|
encrypted, err := rsaEncrypt(data, rsaPublicKey, true)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("encryption failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return base64.StdEncoding.EncodeToString(encrypted), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 方法2:通过证书PEM加密(支持分块)
|
||||||
|
func RSAEncryptByCert(data, cert []byte) (string, error) {
|
||||||
|
// 解析PEM证书
|
||||||
|
block, _ := pem.Decode(cert)
|
||||||
|
if block == nil || block.Type != "CERTIFICATE" {
|
||||||
|
return "", errors.New("failed to parse certificate PEM")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析X.509证书
|
||||||
|
crt, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to parse certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取RSA公钥
|
||||||
|
publicKey, ok := crt.PublicKey.(*rsa.PublicKey)
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("certificate does not contain RSA public key")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算最大分块大小
|
||||||
|
keySize := publicKey.Size()
|
||||||
|
maxChunk := 0
|
||||||
|
|
||||||
|
maxChunk = keySize - 42 // OAEP填充
|
||||||
|
|
||||||
|
// 不需要分块的情况
|
||||||
|
if len(data) <= maxChunk {
|
||||||
|
encrypted, err := rsaEncrypt(data, publicKey, true)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.StdEncoding.EncodeToString(encrypted), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 需要分块的情况
|
||||||
|
var encrypted []byte
|
||||||
|
for i := 0; i < len(data); i += maxChunk {
|
||||||
|
end := i + maxChunk
|
||||||
|
if end > len(data) {
|
||||||
|
end = len(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk, err := rsaEncrypt(data[i:end], publicKey, true)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("RSA encryption failed: %w", err)
|
||||||
|
}
|
||||||
|
encrypted = append(encrypted, chunk...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return base64.StdEncoding.EncodeToString(encrypted), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func rsaDecryptBlock(block []byte, privateKey *rsa.PrivateKey, useOAEP bool) ([]byte, error) {
|
||||||
|
if useOAEP {
|
||||||
|
// Use OAEP with SHA-256
|
||||||
|
return rsa.DecryptOAEP(
|
||||||
|
sha256.New(),
|
||||||
|
rand.Reader,
|
||||||
|
privateKey,
|
||||||
|
block,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// Use PKCS1v15 for backward compatibility
|
||||||
|
return rsa.DecryptPKCS1v15(
|
||||||
|
rand.Reader,
|
||||||
|
privateKey,
|
||||||
|
block,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RSADecrypt(encryptedBase64 string, cfg *tls.Config) ([]byte, error) {
|
||||||
|
// Validate inputs
|
||||||
|
if encryptedBase64 == "" {
|
||||||
|
return nil, errors.New("encrypted data is empty")
|
||||||
|
}
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, errors.New("TLS config is nil")
|
||||||
|
}
|
||||||
|
if len(cfg.Certificates) == 0 {
|
||||||
|
return nil, errors.New("no certificates found in TLS config")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode base64 string
|
||||||
|
encryptedData, err := base64.StdEncoding.DecodeString(encryptedBase64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("base64 decoding failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get private key from first certificate
|
||||||
|
privateKey, ok := cfg.Certificates[0].PrivateKey.(*rsa.PrivateKey)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("failed to get RSA private key from TLS config")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine block size for decryption
|
||||||
|
blockSize := privateKey.Size()
|
||||||
|
if len(encryptedData)%blockSize != 0 {
|
||||||
|
return nil, fmt.Errorf("invalid encrypted data size. Expected multiple of %d, got %d",
|
||||||
|
blockSize, len(encryptedData))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle single block decryption
|
||||||
|
if len(encryptedData) == blockSize {
|
||||||
|
return rsaDecryptBlock(encryptedData, privateKey, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle multi-block decryption
|
||||||
|
var decrypted []byte
|
||||||
|
for i := 0; i < len(encryptedData); i += blockSize {
|
||||||
|
end := i + blockSize
|
||||||
|
block := encryptedData[i:end]
|
||||||
|
|
||||||
|
decryptedBlock, err := rsaDecryptBlock(block, privateKey, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("block decryption failed at offset %d: %w", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decrypted = append(decrypted, decryptedBlock...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return decrypted, nil
|
||||||
|
}
|
43
tool/rsa_test.go
Normal file
43
tool/rsa_test.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package tool
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestRSA(t *testing.T) {
|
||||||
|
var (
|
||||||
|
crt, key string
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg, err := LoadTLSConfig(StringToBytes(crt), StringToBytes(key))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := []byte("admin")
|
||||||
|
enc, err := RSAEncrypt(raw, cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Encrypted data: %s", enc)
|
||||||
|
|
||||||
|
org, err := RSADecrypt(enc, cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Decrypted data: %s", string(org))
|
||||||
|
|
||||||
|
enc2, err := RSAEncryptByCert(raw, []byte(crt))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
org2, err := RSADecrypt(enc2, cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(org) != string(org2) {
|
||||||
|
t.Fatalf("Original and decrypted data don't match, org1 = %s, org2 = %s", org, org2)
|
||||||
|
}
|
||||||
|
}
|
131
tool/tls.go
Normal file
131
tool/tls.go
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
package tool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GenerateTlsConfig() (serverTLSConf *tls.Config, clientTLSConf *tls.Config, err error) {
|
||||||
|
ca := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(2019),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
Organization: []string{"Company, INC."},
|
||||||
|
Country: []string{"US"},
|
||||||
|
Province: []string{""},
|
||||||
|
Locality: []string{"San Francisco"},
|
||||||
|
StreetAddress: []string{"Golden Gate Bridge"},
|
||||||
|
PostalCode: []string{"94016"},
|
||||||
|
},
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().AddDate(99, 0, 0),
|
||||||
|
IsCA: true,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
}
|
||||||
|
// create our private and public key
|
||||||
|
caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
// create the CA
|
||||||
|
caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
// pem encode
|
||||||
|
caPEM := new(bytes.Buffer)
|
||||||
|
pem.Encode(caPEM, &pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: caBytes,
|
||||||
|
})
|
||||||
|
caPrivKeyPEM := new(bytes.Buffer)
|
||||||
|
pem.Encode(caPrivKeyPEM, &pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey),
|
||||||
|
})
|
||||||
|
// set up our server certificate
|
||||||
|
cert := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(2019),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
Organization: []string{"Company, INC."},
|
||||||
|
Country: []string{"US"},
|
||||||
|
Province: []string{""},
|
||||||
|
Locality: []string{"San Francisco"},
|
||||||
|
StreetAddress: []string{"Golden Gate Bridge"},
|
||||||
|
PostalCode: []string{"94016"},
|
||||||
|
},
|
||||||
|
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||||
|
SubjectKeyId: []byte{1, 2, 3, 4, 6},
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||||
|
}
|
||||||
|
certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
certPEM := new(bytes.Buffer)
|
||||||
|
pem.Encode(certPEM, &pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: certBytes,
|
||||||
|
})
|
||||||
|
certPrivKeyPEM := new(bytes.Buffer)
|
||||||
|
pem.Encode(certPrivKeyPEM, &pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
|
||||||
|
})
|
||||||
|
serverCert, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
serverTLSConf = &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
}
|
||||||
|
certpool := x509.NewCertPool()
|
||||||
|
certpool.AppendCertsFromPEM(caPEM.Bytes())
|
||||||
|
clientTLSConf = &tls.Config{
|
||||||
|
RootCAs: certpool,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadTLSConfig 从字节数据加载TLS配置
|
||||||
|
func LoadTLSConfig(certPEM, keyPEM []byte, caPEMs ...[]byte) (*tls.Config, error) {
|
||||||
|
// 加载客户端证书密钥对
|
||||||
|
cert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
MinVersion: tls.VersionTLS12, // 设置最低TLS版本
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载CA证书(如果有)
|
||||||
|
if len(caPEMs) > 0 {
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
for _, caPEM := range caPEMs {
|
||||||
|
if !pool.AppendCertsFromPEM(caPEM) {
|
||||||
|
return nil, x509.SystemRootsError{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config.RootCAs = pool
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
Reference in New Issue
Block a user