mirror of
1
Fork 0

[feature] `oob` oauth token support (#889)

* move helpful advice into oauth server

* rewrite HandleAuthorizeRequest to allow oob
This commit is contained in:
tobi 2022-10-08 13:49:56 +02:00 committed by GitHub
parent 5cf0f9950a
commit 3bb45b7179
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 252 additions and 42 deletions

View File

@ -23,6 +23,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/api" "github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/oidc" "github.com/superseriousbusiness/gotosocial/internal/oidc"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router" "github.com/superseriousbusiness/gotosocial/internal/router"
@ -92,5 +93,7 @@ func (m *Module) Route(s router.Router) error {
s.AttachHandler(http.MethodPost, OauthAuthorizePath, m.AuthorizePOSTHandler) s.AttachHandler(http.MethodPost, OauthAuthorizePath, m.AuthorizePOSTHandler)
s.AttachHandler(http.MethodGet, CallbackPath, m.CallbackGETHandler) s.AttachHandler(http.MethodGet, CallbackPath, m.CallbackGETHandler)
s.AttachHandler(http.MethodGet, oauth.OOBTokenPath, m.OobHandler)
return nil return nil
} }

View File

@ -33,12 +33,9 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
// helpfulAdvice is a handy hint to users;
// particularly important during the login flow
var helpfulAdvice = "If you arrived at this error during a login/oauth flow, please try clearing your session cookies and logging in again; if problems persist, make sure you're using the correct credentials"
// AuthorizeGETHandler should be served as GET at https://example.org/oauth/authorize // AuthorizeGETHandler should be served as GET at https://example.org/oauth/authorize
// The idea here is to present an oauth authorize page to the user, with a button // The idea here is to present an oauth authorize page to the user, with a button
// that they have to click to accept. // that they have to click to accept.
@ -57,7 +54,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
form := &model.OAuthAuthorize{} form := &model.OAuthAuthorize{}
if err := c.ShouldBind(form); err != nil { if err := c.ShouldBind(form); err != nil {
m.clearSession(s) m.clearSession(s)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, helpfulAdvice), m.processor.InstanceGet) api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
return return
} }
@ -76,7 +73,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
if !ok || clientID == "" { if !ok || clientID == "" {
m.clearSession(s) m.clearSession(s)
err := fmt.Errorf("key %s was not found in session", sessionClientID) err := fmt.Errorf("key %s was not found in session", sessionClientID)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, helpfulAdvice), m.processor.InstanceGet) api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
return return
} }
@ -86,9 +83,9 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID) safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID)
var errWithCode gtserror.WithCode var errWithCode gtserror.WithCode
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
errWithCode = gtserror.NewErrorBadRequest(err, safe, helpfulAdvice) errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice)
} else { } else {
errWithCode = gtserror.NewErrorInternalError(err, safe, helpfulAdvice) errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
} }
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) api.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
@ -100,9 +97,9 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
safe := fmt.Sprintf("user with id %s could not be retrieved", userID) safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
var errWithCode gtserror.WithCode var errWithCode gtserror.WithCode
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
errWithCode = gtserror.NewErrorBadRequest(err, safe, helpfulAdvice) errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice)
} else { } else {
errWithCode = gtserror.NewErrorInternalError(err, safe, helpfulAdvice) errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
} }
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) api.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
@ -114,9 +111,9 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
safe := fmt.Sprintf("account with id %s could not be retrieved", user.AccountID) safe := fmt.Sprintf("account with id %s could not be retrieved", user.AccountID)
var errWithCode gtserror.WithCode var errWithCode gtserror.WithCode
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
errWithCode = gtserror.NewErrorBadRequest(err, safe, helpfulAdvice) errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice)
} else { } else {
errWithCode = gtserror.NewErrorInternalError(err, safe, helpfulAdvice) errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
} }
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) api.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
@ -131,7 +128,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
if !ok || redirect == "" { if !ok || redirect == "" {
m.clearSession(s) m.clearSession(s)
err := fmt.Errorf("key %s was not found in session", sessionRedirectURI) err := fmt.Errorf("key %s was not found in session", sessionRedirectURI)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, helpfulAdvice), m.processor.InstanceGet) api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
return return
} }
@ -139,7 +136,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
if !ok || scope == "" { if !ok || scope == "" {
m.clearSession(s) m.clearSession(s)
err := fmt.Errorf("key %s was not found in session", sessionScope) err := fmt.Errorf("key %s was not found in session", sessionScope)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, helpfulAdvice), m.processor.InstanceGet) api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
return return
} }
@ -208,7 +205,7 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
} }
if len(errs) != 0 { if len(errs) != 0 {
errs = append(errs, helpfulAdvice) errs = append(errs, oauth.HelpfulAdvice)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(errors.New("one or more missing keys on session during AuthorizePOSTHandler"), errs...), m.processor.InstanceGet) api.ErrorHandler(c, gtserror.NewErrorBadRequest(errors.New("one or more missing keys on session during AuthorizePOSTHandler"), errs...), m.processor.InstanceGet)
return return
} }
@ -219,9 +216,9 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
safe := fmt.Sprintf("user with id %s could not be retrieved", userID) safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
var errWithCode gtserror.WithCode var errWithCode gtserror.WithCode
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
errWithCode = gtserror.NewErrorBadRequest(err, safe, helpfulAdvice) errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice)
} else { } else {
errWithCode = gtserror.NewErrorInternalError(err, safe, helpfulAdvice) errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
} }
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) api.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
@ -233,9 +230,9 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
safe := fmt.Sprintf("account with id %s could not be retrieved", user.AccountID) safe := fmt.Sprintf("account with id %s could not be retrieved", user.AccountID)
var errWithCode gtserror.WithCode var errWithCode gtserror.WithCode
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
errWithCode = gtserror.NewErrorBadRequest(err, safe, helpfulAdvice) errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice)
} else { } else {
errWithCode = gtserror.NewErrorInternalError(err, safe, helpfulAdvice) errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
} }
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) api.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
@ -245,8 +242,10 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
return return
} }
// we're done with the session now, so just clear it out if redirectURI != oauth.OOBURI {
m.clearSession(s) // we're done with the session now, so just clear it out
m.clearSession(s)
}
// we have to set the values on the request form // we have to set the values on the request form
// so that they're picked up by the oauth server // so that they're picked up by the oauth server
@ -263,8 +262,8 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
c.Request.Form.Set("state", clientState) c.Request.Form.Set("state", clientState)
} }
if err := m.processor.OAuthHandleAuthorizeRequest(c.Writer, c.Request); err != nil { if errWithCode := m.processor.OAuthHandleAuthorizeRequest(c.Writer, c.Request); errWithCode != nil {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error(), helpfulAdvice), m.processor.InstanceGet) api.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
} }
} }
@ -273,22 +272,22 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
func saveAuthFormToSession(s sessions.Session, form *model.OAuthAuthorize) gtserror.WithCode { func saveAuthFormToSession(s sessions.Session, form *model.OAuthAuthorize) gtserror.WithCode {
if form == nil { if form == nil {
err := errors.New("OAuthAuthorize form was nil") err := errors.New("OAuthAuthorize form was nil")
return gtserror.NewErrorBadRequest(err, err.Error(), helpfulAdvice) return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice)
} }
if form.ResponseType == "" { if form.ResponseType == "" {
err := errors.New("field response_type was not set on OAuthAuthorize form") err := errors.New("field response_type was not set on OAuthAuthorize form")
return gtserror.NewErrorBadRequest(err, err.Error(), helpfulAdvice) return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice)
} }
if form.ClientID == "" { if form.ClientID == "" {
err := errors.New("field client_id was not set on OAuthAuthorize form") err := errors.New("field client_id was not set on OAuthAuthorize form")
return gtserror.NewErrorBadRequest(err, err.Error(), helpfulAdvice) return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice)
} }
if form.RedirectURI == "" { if form.RedirectURI == "" {
err := errors.New("field redirect_uri was not set on OAuthAuthorize form") err := errors.New("field redirect_uri was not set on OAuthAuthorize form")
return gtserror.NewErrorBadRequest(err, err.Error(), helpfulAdvice) return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice)
} }
// set default scope to read // set default scope to read
@ -307,7 +306,7 @@ func saveAuthFormToSession(s sessions.Session, form *model.OAuthAuthorize) gtser
if err := s.Save(); err != nil { if err := s.Save(); err != nil {
err := fmt.Errorf("error saving form values onto session: %s", err) err := fmt.Errorf("error saving form values onto session: %s", err)
return gtserror.NewErrorInternalError(err, helpfulAdvice) return gtserror.NewErrorInternalError(err, oauth.HelpfulAdvice)
} }
return nil return nil

View File

@ -34,6 +34,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/oidc" "github.com/superseriousbusiness/gotosocial/internal/oidc"
"github.com/superseriousbusiness/gotosocial/internal/validate" "github.com/superseriousbusiness/gotosocial/internal/validate"
) )
@ -91,7 +92,7 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
if !ok || clientID == "" { if !ok || clientID == "" {
m.clearSession(s) m.clearSession(s)
err := fmt.Errorf("key %s was not found in session", sessionClientID) err := fmt.Errorf("key %s was not found in session", sessionClientID)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, helpfulAdvice), m.processor.InstanceGet) api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
return return
} }
@ -101,9 +102,9 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID) safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID)
var errWithCode gtserror.WithCode var errWithCode gtserror.WithCode
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
errWithCode = gtserror.NewErrorBadRequest(err, safe, helpfulAdvice) errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice)
} else { } else {
errWithCode = gtserror.NewErrorInternalError(err, safe, helpfulAdvice) errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
} }
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) api.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return

View File

@ -0,0 +1,111 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package auth
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
func (m *Module) OobHandler(c *gin.Context) {
host := config.GetHost()
instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), host)
if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return
}
instanceGet := func(ctx context.Context, domain string) (*model.Instance, gtserror.WithCode) { return instance, nil }
oobToken := c.Query("code")
if oobToken == "" {
err := errors.New("no 'code' query value provided in callback redirect")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice), instanceGet)
return
}
s := sessions.Default(c)
errs := []string{}
scope, ok := s.Get(sessionScope).(string)
if !ok {
errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionScope))
}
userID, ok := s.Get(sessionUserID).(string)
if !ok {
errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionUserID))
}
if len(errs) != 0 {
errs = append(errs, oauth.HelpfulAdvice)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(errors.New("one or more missing keys on session during OobHandler"), errs...), m.processor.InstanceGet)
return
}
user, err := m.db.GetUserByID(c.Request.Context(), userID)
if err != nil {
m.clearSession(s)
safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
var errWithCode gtserror.WithCode
if err == db.ErrNoEntries {
errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice)
} else {
errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
}
api.ErrorHandler(c, errWithCode, instanceGet)
return
}
acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
if err != nil {
m.clearSession(s)
safe := fmt.Sprintf("account with id %s could not be retrieved", user.AccountID)
var errWithCode gtserror.WithCode
if err == db.ErrNoEntries {
errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice)
} else {
errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
}
api.ErrorHandler(c, errWithCode, instanceGet)
return
}
// we're done with the session now, so just clear it out
m.clearSession(s)
c.HTML(http.StatusOK, "oob.tmpl", gin.H{
"instance": instance,
"user": acct.Username,
"oobToken": oobToken,
"scope": scope,
})
}

View File

@ -29,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/api" "github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@ -86,7 +87,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) {
form := &login{} form := &login{}
if err := c.ShouldBind(form); err != nil { if err := c.ShouldBind(form); err != nil {
m.clearSession(s) m.clearSession(s)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, helpfulAdvice), m.processor.InstanceGet) api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
return return
} }
@ -101,7 +102,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) {
s.Set(sessionUserID, userid) s.Set(sessionUserID, userid)
if err := s.Save(); err != nil { if err := s.Save(); err != nil {
err := fmt.Errorf("error saving user id onto session: %s", err) err := fmt.Errorf("error saving user id onto session: %s", err)
api.ErrorHandler(c, gtserror.NewErrorInternalError(err, helpfulAdvice), m.processor.InstanceGet) api.ErrorHandler(c, gtserror.NewErrorInternalError(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
} }
c.Redirect(http.StatusFound, OauthAuthorizePath) c.Redirect(http.StatusFound, OauthAuthorizePath)
@ -140,5 +141,5 @@ func (m *Module) ValidatePassword(ctx context.Context, email string, password st
// only a generic 'safe' error message to the user, to not give any info away. // only a generic 'safe' error message to the user, to not give any info away.
func incorrectPassword(err error) (string, gtserror.WithCode) { func incorrectPassword(err error) (string, gtserror.WithCode) {
safeErr := fmt.Errorf("password/email combination was incorrect") safeErr := fmt.Errorf("password/email combination was incorrect")
return "", gtserror.NewErrorUnauthorized(err, safeErr.Error(), helpfulAdvice) return "", gtserror.NewErrorUnauthorized(err, safeErr.Error(), oauth.HelpfulAdvice)
} }

View File

@ -22,6 +22,7 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
@ -49,12 +50,19 @@ const (
// of a Client who has successfully passed Bearer token authorization. // of a Client who has successfully passed Bearer token authorization.
// The interface returned from grabbing this key should be parsed as a *gtsmodel.Application // The interface returned from grabbing this key should be parsed as a *gtsmodel.Application
SessionAuthorizedApplication = "authorized_app" SessionAuthorizedApplication = "authorized_app"
// OOBURI is the out-of-band oauth token uri
OOBURI = "urn:ietf:wg:oauth:2.0:oob"
// OOBTokenPath is the path to redirect out-of-band token requests to.
OOBTokenPath = "/oob"
// HelpfulAdvice is a handy hint to users;
// particularly important during the login flow
HelpfulAdvice = "If you arrived at this error during a login/oauth flow, please try clearing your session cookies and logging in again; if problems persist, make sure you're using the correct credentials"
) )
// Server wraps some oauth2 server functions in an interface, exposing only what is needed // Server wraps some oauth2 server functions in an interface, exposing only what is needed
type Server interface { type Server interface {
HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode)
HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtserror.WithCode
ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error)
GenerateUserAccessToken(ctx context.Context, ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error) GenerateUserAccessToken(ctx context.Context, ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error)
LoadAccessToken(ctx context.Context, access string) (accessToken oauth2.TokenInfo, err error) LoadAccessToken(ctx context.Context, access string) (accessToken oauth2.TokenInfo, err error)
@ -123,13 +131,13 @@ func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserro
gt, tgr, err := s.server.ValidationTokenRequest(r) gt, tgr, err := s.server.ValidationTokenRequest(r)
if err != nil { if err != nil {
help := fmt.Sprintf("could not validate token request: %s", err) help := fmt.Sprintf("could not validate token request: %s", err)
return nil, gtserror.NewErrorBadRequest(err, help) return nil, gtserror.NewErrorBadRequest(err, help, HelpfulAdvice)
} }
ti, err := s.server.GetAccessToken(ctx, gt, tgr) ti, err := s.server.GetAccessToken(ctx, gt, tgr)
if err != nil { if err != nil {
help := fmt.Sprintf("could not get access token: %s", err) help := fmt.Sprintf("could not get access token: %s", err)
return nil, gtserror.NewErrorBadRequest(err, help) return nil, gtserror.NewErrorBadRequest(err, help, HelpfulAdvice)
} }
data := s.server.GetTokenData(ti) data := s.server.GetTokenData(ti)
@ -145,7 +153,7 @@ func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserro
} }
default: default:
err := errors.New("expires_in was set on token response, but was not an int64") err := errors.New("expires_in was set on token response, but was not an int64")
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err, HelpfulAdvice)
} }
} }
@ -155,9 +163,88 @@ func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserro
return data, nil return data, nil
} }
func (s *s) errorOrRedirect(err error, w http.ResponseWriter, req *server.AuthorizeRequest) gtserror.WithCode {
if req == nil {
return gtserror.NewErrorUnauthorized(err, HelpfulAdvice)
}
data, _, _ := s.server.GetErrorData(err)
uri, err := s.server.GetRedirectURI(req, data)
if err != nil {
return gtserror.NewErrorInternalError(err, HelpfulAdvice)
}
w.Header().Set("Location", uri)
w.WriteHeader(http.StatusFound)
return nil
}
// HandleAuthorizeRequest wraps the oauth2 library's HandleAuthorizeRequest function // HandleAuthorizeRequest wraps the oauth2 library's HandleAuthorizeRequest function
func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtserror.WithCode {
return s.server.HandleAuthorizeRequest(w, r) ctx := r.Context()
req, err := s.server.ValidationAuthorizeRequest(r)
if err != nil {
return s.errorOrRedirect(err, w, req)
}
// user authorization
userID, err := s.server.UserAuthorizationHandler(w, r)
if err != nil {
return s.errorOrRedirect(err, w, req)
}
if userID == "" {
help := "userID was empty"
return gtserror.NewErrorUnauthorized(err, help, HelpfulAdvice)
}
req.UserID = userID
// specify the scope of authorization
if fn := s.server.AuthorizeScopeHandler; fn != nil {
scope, err := fn(w, r)
if err != nil {
return s.errorOrRedirect(err, w, req)
} else if scope != "" {
req.Scope = scope
}
}
// specify the expiration time of access token
if fn := s.server.AccessTokenExpHandler; fn != nil {
exp, err := fn(w, r)
if err != nil {
return s.errorOrRedirect(err, w, req)
}
req.AccessTokenExp = exp
}
ti, err := s.server.GetAuthorizeToken(ctx, req)
if err != nil {
return s.errorOrRedirect(err, w, req)
}
// If the redirect URI is empty, the default domain provided by the client is used.
if req.RedirectURI == "" {
client, err := s.server.Manager.GetClient(ctx, req.ClientID)
if err != nil {
return gtserror.NewErrorUnauthorized(err, HelpfulAdvice)
}
req.RedirectURI = client.GetDomain()
}
uri, err := s.server.GetRedirectURI(req, s.server.GetAuthorizeData(req.ResponseType, ti))
if err != nil {
return gtserror.NewErrorUnauthorized(err, HelpfulAdvice)
}
if strings.Contains(uri, OOBURI) {
w.Header().Set("Location", strings.ReplaceAll(uri, OOBURI, OOBTokenPath))
} else {
w.Header().Set("Location", uri)
}
w.WriteHeader(http.StatusFound)
return nil
} }
// ValidationBearerToken wraps the oauth2 library's ValidationBearerToken function // ValidationBearerToken wraps the oauth2 library's ValidationBearerToken function

View File

@ -24,7 +24,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
) )
func (p *processor) OAuthHandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { func (p *processor) OAuthHandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtserror.WithCode {
// todo: some kind of metrics stuff here // todo: some kind of metrics stuff here
return p.oauthServer.HandleAuthorizeRequest(w, r) return p.oauthServer.HandleAuthorizeRequest(w, r)
} }

View File

@ -160,7 +160,7 @@ type Processor interface {
NotificationsClear(ctx context.Context, authed *oauth.Auth) gtserror.WithCode NotificationsClear(ctx context.Context, authed *oauth.Auth) gtserror.WithCode
OAuthHandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode) OAuthHandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode)
OAuthHandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error OAuthHandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtserror.WithCode
// SearchGet performs a search with the given params, resolving/dereferencing remotely as desired // SearchGet performs a search with the given params, resolving/dereferencing remotely as desired
SearchGet(ctx context.Context, authed *oauth.Auth, searchQuery *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode) SearchGet(ctx context.Context, authed *oauth.Auth, searchQuery *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode)

8
web/template/oob.tmpl Normal file
View File

@ -0,0 +1,8 @@
{{ template "header.tmpl" .}}
<main>
<h1>Hi {{ .user }}!</h1>
<p>Here's your out-of-band token with scope <em>{{.scope}}</em>:</p>
<p><code>{{ .oobToken }}</code><p>
<p>Use it wisely!</p>
</main>
{{ template "footer.tmpl" .}}