package handler import ( "encoding/base64" "errors" "github.com/google/uuid" "github.com/loveuer/nf" "github.com/loveuer/nf/nft/log" "gorm.io/gorm" "net/http" "strings" "uauth/model" "uauth/pkg/cache" "uauth/pkg/store" "uauth/tool" ) func verifyClient(c *nf.Ctx) (*model.Client, error) { var ( err error basic string bs []byte strs []string client = new(model.Client) ) if basic = c.Get("Authorization"); basic == "" { return nil, errors.New("authorization header missing") } switch { case strings.HasPrefix(basic, "Basic "): basic = strings.TrimPrefix(basic, "Basic ") default: return nil, errors.New("authorization scheme not supported") } if bs, err = base64.StdEncoding.DecodeString(basic); err != nil { log.Warn("[Token] base64 decode failed, raw = %s, err = %s", basic, err.Error()) return nil, errors.New("invalid basic authorization") } if strs = strings.SplitN(string(bs), ":", 2); len(strs) != 2 { log.Warn("[Token] basic split err, decode = %s", string(bs)) return nil, errors.New("invalid basic authorization") } clientId, clientSecret := strs[0], strs[1] if err = store.Default.Session(tool.TimeoutCtx(c.Context(), 3)). Model(&model.Client{}). Where("client_id", clientId). Take(client). Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("client invalid") } log.Error("[Token] db take client by id = %s, err = %s", clientId, err.Error()) return nil, errors.New("unknown server error") } if client.ClientSecret != clientSecret { log.Warn("[Token] client_secret invalid, want = %s, got = %s", client.ClientSecret, clientSecret) return nil, errors.New("client secret invalid") } return client, nil } func HandleToken(c *nf.Ctx) error { var ( err error opId uint64 op = new(model.User) token string client *model.Client grantType = c.Form("grant_type") ) switch grantType { case "password": if client, err = verifyClient(c); err != nil { return c.Status(http.StatusBadRequest).SendString("Bad Request: " + err.Error()) } username := c.Form("username") password := c.Form("password") if err = store.Default.Session(tool.TimeoutCtx(c.Context(), 3)). Model(&model.User{}). Where("username = ?", username). Take(op).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid username or password") } log.Error("[Token] db take user by username = %s, err = %s", username, err.Error()) return c.Status(http.StatusInternalServerError).SendString("Internal Server Error") } if !tool.ComparePassword(password, op.Password) { return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid username or password") } case "authorization_code": if client, err = verifyClient(c); err != nil { return c.Status(http.StatusBadRequest).SendString("Bad Request: " + err.Error()) } code := c.Form("code") if code == "" { return c.Status(http.StatusBadRequest).SendString("Bad Request: no code provided") } if err = cache.Client.GetScan(tool.Timeout(2), cache.Prefix+"auth_code:"+code).Scan(&opId); err != nil { if errors.Is(err, cache.ErrorKeyNotFound) { return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid code") } log.Error("[S] handleToken: get code from cache err = %s", err.Error()) return c.Status(http.StatusInternalServerError).SendString("Internal Server Error") } op.Id = opId if err = store.Default.Session(tool.TimeoutCtx(c.Context(), 3)). Take(op).Error; err != nil { log.Error("[S] handleToken: get op by id err, id = %d, err = %s", opId, err.Error()) return c.Status(http.StatusInternalServerError).SendString("Internal Server Error") } default: return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid grant type") } if token, err = op.JwtEncode(); err != nil { log.Error("[S] handleToken: encode token err, id = %d, err = %s", opId, err.Error()) return c.Status(http.StatusInternalServerError).SendString("Internal Server Error") } refreshToken := uuid.New().String() _ = client return c.JSON(map[string]any{ "access_token": token, "refresh_token": refreshToken, "token_type": "Bearer", "expires_in": 24 * 3600, }) }