package handler

import (
	_ "embed"
	"errors"
	"net/http"
	"net/url"
	"time"

	"github.com/google/uuid"
	"github.com/loveuer/nf"
	"github.com/loveuer/nf/nft/resp"
	"gorm.io/gorm"

	"uauth/internal/store/cache"
	"uauth/internal/store/db"
	"uauth/model"
)

var (
	// pageAuthorize is the consent page template rendered by Authorize when
	// the user has not yet authorized the requesting client. It is embedded
	// into the binary at build time.
	//go:embed serve_authorize.html
	pageAuthorize string
)

// Authorize handles the authorization endpoint of the authorization-code flow.
//
// Query parameters: client_id, client_secret, response_type, redirect_uri and
// scope. Only response_type=code is accepted, and redirect_uri must be a
// non-empty, parseable URL. The authenticated user is read from
// c.Locals("user") as *model.User; if it is missing the request is rejected
// with 401.
//
// If no AuthorizationRecord exists for (user, client), a random state value is
// cached for 10 minutes and the consent page is rendered. Otherwise a random
// authorization code, mapped to the user id, is cached for 10 minutes and the
// client is redirected (302) to redirect_uri with the code added as the
// "code" query parameter.
func Authorize(c *nf.Ctx) error {
	type Req struct {
		ClientId     string `query:"client_id"`
		ClientSecret string `query:"client_secret"`
		ResponseType string `query:"response_type"`
		RedirectURI  string `query:"redirect_uri"`
		Scope        string `query:"scope"`
	}

	var (
		ok          bool
		op          *model.User
		req         = new(Req)
		err         error
		redirectURL *url.URL
		client      = &model.Client{}
		authRecord  = &model.AuthorizationRecord{}
	)

	if err = c.QueryParser(req); err != nil {
		return resp.Resp400(c, err)
	}

	if req.ResponseType != "code" {
		return c.Status(http.StatusBadRequest).SendString("Invalid request")
	}

	// Validate redirect_uri up front: an empty value would otherwise produce a
	// relative "?code=..." redirect, and parsing it lets the code be attached
	// as a proper query parameter even when the URI already has a query string.
	// NOTE(review): redirect_uri is not compared against the URI(s) registered
	// for the client, which allows open redirects / code leakage; add that
	// check once the relevant model.Client field is confirmed.
	if req.RedirectURI == "" {
		return c.Status(http.StatusBadRequest).SendString("Invalid request")
	}
	if redirectURL, err = url.Parse(req.RedirectURI); err != nil {
		return resp.Resp400(c, err, "invalid redirect_uri")
	}

	if op, ok = c.Locals("user").(*model.User); !ok {
		return resp.Resp401(c, nil)
	}

	if err = db.Default.Session().Where("client_id", req.ClientId).Take(client).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return resp.Resp400(c, err, "无效的 client id")
		}
		// Any other database failure must not fall through with an empty client.
		return resp.Resp500(c, err)
	}

	if err = db.Default.Session().Model(&model.AuthorizationRecord{}).
		Where("user_id", op.Id).
		Where("client_id", client.ClientId).
		Take(authRecord).
		Error; err != nil {
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return resp.Resp500(c, err)
		}

		// The user is authorizing this client for the first time: cache a
		// one-time state value and render the consent page.
		// A full UUID is used; truncating it would leave only 32 random bits.
		state := uuid.New().String()
		if err = cache.Client.SetEx(c.Context(), cache.Prefix+"state_code:"+state, nil, 10*time.Minute); err != nil {
			return resp.Resp500(c, err)
		}

		// NOTE(review): client_secret is taken from the query string and echoed
		// into the page; confirm this is intended, as it exposes the secret to
		// the browser.
		return c.RenderHTML("authorize", pageAuthorize, map[string]any{
			"user": map[string]any{
				"username": op.Username,
				"avatar":   "https://picsum.photos/200",
			},
			"client_id":     req.ClientId,
			"client_secret": req.ClientSecret,
			"redirect_uri":  req.RedirectURI,
			"scope":         req.Scope,
			"state":         state,
		})
	}

	// The user has already authorized this client: issue an authorization
	// code, cache it (mapped to the user id), and redirect back with it.
	authorizationCode := uuid.New().String()
	if err = cache.Client.SetEx(c.Context(), cache.Prefix+"auth_code:"+authorizationCode, op.Id, 10*time.Minute); err != nil {
		return resp.Resp500(c, err)
	}

	// Merge the code into any query string redirect_uri already carries
	// instead of blindly appending "?code=".
	query := redirectURL.Query()
	query.Set("code", authorizationCode)
	redirectURL.RawQuery = query.Encode()

	return c.Redirect(redirectURL.String(), http.StatusFound)
}