// Package auth provides an nf middleware that resolves the current user from
// an access token and stores it in the request locals under the key "user".
package auth

import (
	"errors"
	"net/http"
	"time"

	"github.com/loveuer/nf"
	"github.com/loveuer/nf/nft/log"
	"github.com/loveuer/nf/nft/resp"

	"uauth/internal/store/cache"
	"uauth/internal/tool"
	"uauth/model"
)

// Config customizes the middleware returned by New. Any nil function field
// is replaced by a package default.
type Config struct {
	// IgnoreFn reports whether authentication should be skipped entirely for
	// the request. Defaults to never skipping.
	IgnoreFn func(c *nf.Ctx) bool
	// TokenFn extracts the access token from the request; the bool reports
	// whether a non-empty token was found.
	TokenFn func(c *nf.Ctx) (string, bool)
	// GetUserFn resolves the user that owns token. Returning an error that
	// matches cache.ErrorKeyNotFound yields a 401; any other error yields a 500.
	GetUserFn func(c *nf.Ctx, token string) (*model.User, error)
	// NextOnError makes the middleware call the next handler (without setting
	// the "user" local) instead of responding with an error when no token is
	// found or the user cannot be resolved.
	NextOnError bool
}

var (
	// defaultIgnoreFn never skips authentication.
	defaultIgnoreFn = func(c *nf.Ctx) bool {
		return false
	}

	// defaultTokenFn returns the first non-empty token found, checking in order
	// the Authorization header, the access_token query parameter and the
	// access_token cookie.
	// NOTE(review): the Authorization header value is used verbatim, so a
	// "Bearer " scheme prefix is not stripped — confirm this matches how
	// clients send tokens and how they are keyed in the cache.
	defaultTokenFn = func(c *nf.Ctx) (string, bool) {
		if token := c.Request.Header.Get("Authorization"); token != "" {
			return token, true
		}

		if token := c.Query("access_token"); token != "" {
			return token, true
		}

		if token := c.Cookies("access_token"); token != "" {
			return token, true
		}

		return "", false
	}

	// defaultGetUserFn loads the user cached under cache.Prefix+"token:"+token.
	// A missing key is returned as-is so callers can detect it with
	// errors.Is(err, cache.ErrorKeyNotFound); any other cache failure is
	// logged and replaced by a generic error.
	defaultGetUserFn = func(c *nf.Ctx, token string) (*model.User, error) {
		var (
			op  = new(model.User)
			key = cache.Prefix + "token:" + token
		)

		// The 24h duration is passed to GetExScan; presumably it refreshes the
		// key's expiry on read (sliding session) — confirm against the cache
		// implementation. tool.Timeout(3) is assumed to be a 3s-bounded context.
		if err := cache.Client.GetExScan(tool.Timeout(3), key, 24*time.Hour).Scan(op); err != nil {
			if errors.Is(err, cache.ErrorKeyNotFound) {
				return nil, err
			}

			// NOTE(review): key embeds the raw access token, so it ends up in
			// the logs; consider redacting it.
			log.Error("[M] cache client get user by token key = %s, err = %s", key, err.Error())
			return nil, errors.New("Internal Server Error")
		}

		return op, nil
	}
)

// New returns a middleware that authenticates each request.
//
// Only the first element of cfgs is used. A nil or missing config, and any
// nil function fields, fall back to the package defaults. The caller's Config
// is never modified.
//
// On success the resolved *model.User is stored with c.Locals("user", ...)
// and the next handler is called. Otherwise:
//   - no token: 401 via resp.Resp401;
//   - token unknown or expired (cache.ErrorKeyNotFound): 401 JSON response;
//   - any other lookup error: 500 "Internal Server Error".
//
// When NextOnError is set, both failure cases call the next handler instead.
func New(cfgs ...*Config) nf.HandlerFunc {
	// Copy the caller's config so filling in defaults below does not mutate
	// their struct, and later changes to it cannot affect this middleware.
	cfg := Config{}
	if len(cfgs) > 0 && cfgs[0] != nil {
		cfg = *cfgs[0]
	}

	if cfg.IgnoreFn == nil {
		cfg.IgnoreFn = defaultIgnoreFn
	}

	if cfg.TokenFn == nil {
		cfg.TokenFn = defaultTokenFn
	}

	if cfg.GetUserFn == nil {
		cfg.GetUserFn = defaultGetUserFn
	}

	return func(c *nf.Ctx) error {
		if cfg.IgnoreFn(c) {
			return c.Next()
		}

		token, ok := cfg.TokenFn(c)
		if !ok {
			if cfg.NextOnError {
				return c.Next()
			}

			return resp.Resp401(c, nil, "请登录")
		}

		op, err := cfg.GetUserFn(c, token)
		if err != nil {
			if cfg.NextOnError {
				return c.Next()
			}

			if errors.Is(err, cache.ErrorKeyNotFound) {
				// The body's status must agree with the HTTP status code
				// (previously it reported 500 on a 401 response).
				return c.Status(http.StatusUnauthorized).JSON(map[string]any{
					"status": http.StatusUnauthorized,
					"msg":    "用户认证信息不存在或已过期, 请重新登录",
				})
			}

			return c.Status(http.StatusInternalServerError).SendString("Internal Server Error")
		}

		c.Locals("user", op)

		return c.Next()
	}
}