mirror of
1
Fork 0

Auth flow fixes (#82)

* preliminary fixes to broken auth flow

* fix some auth/cookie weirdness

* fmt
This commit is contained in:
Tobi Smethurst 2021-07-08 11:32:31 +02:00 committed by GitHub
parent c71e55ecc4
commit 5460271bb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 79 additions and 49 deletions

View File

@ -36,8 +36,26 @@ const (
OauthTokenPath = "/oauth/token" OauthTokenPath = "/oauth/token"
// OauthAuthorizePath is the API path for authorization requests (eg., authorize this app to act on my behalf as a user) // OauthAuthorizePath is the API path for authorization requests (eg., authorize this app to act on my behalf as a user)
OauthAuthorizePath = "/oauth/authorize" OauthAuthorizePath = "/oauth/authorize"
sessionUserID = "userid"
sessionClientID = "client_id"
sessionRedirectURI = "redirect_uri"
sessionForceLogin = "force_login"
sessionResponseType = "response_type"
sessionCode = "code"
sessionScope = "scope"
) )
var sessionKeys []string = []string{
sessionUserID,
sessionClientID,
sessionRedirectURI,
sessionForceLogin,
sessionResponseType,
sessionCode,
sessionScope,
}
// Module implements the ClientAPIModule interface for // Module implements the ClientAPIModule interface for
type Module struct { type Module struct {
config *config.Config config *config.Config

View File

@ -26,7 +26,6 @@ import (
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -38,25 +37,30 @@ import (
func (m *Module) AuthorizeGETHandler(c *gin.Context) { func (m *Module) AuthorizeGETHandler(c *gin.Context) {
l := m.log.WithField("func", "AuthorizeGETHandler") l := m.log.WithField("func", "AuthorizeGETHandler")
s := sessions.Default(c) s := sessions.Default(c)
s.Options(sessions.Options{
MaxAge: 120, // give the user 2 minutes to sign in before expiring their session
})
// UserID will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow // UserID will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow
// If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page. // If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page.
userID, ok := s.Get("userid").(string) userID, ok := s.Get(sessionUserID).(string)
if !ok || userID == "" { if !ok || userID == "" {
l.Trace("userid was empty, parsing form then redirecting to sign in page") l.Trace("userid was empty, parsing form then redirecting to sign in page")
if err := parseAuthForm(c, l); err != nil { form := &model.OAuthAuthorize{}
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) if err := c.Bind(form); err != nil {
} else { l.Debugf("invalid auth form: %s", err)
c.Redirect(http.StatusFound, AuthSignInPath) return
} }
l.Tracef("parsed auth form: %+v", form)
if err := extractAuthForm(s, form); err != nil {
l.Debugf(fmt.Sprintf("error parsing form at /oauth/authorize: %s", err))
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.Redirect(http.StatusFound, AuthSignInPath)
return return
} }
// We can use the client_id on the session to retrieve info about the app associated with the client_id // We can use the client_id on the session to retrieve info about the app associated with the client_id
clientID, ok := s.Get("client_id").(string) clientID, ok := s.Get(sessionClientID).(string)
if !ok || clientID == "" { if !ok || clientID == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"})
return return
@ -64,7 +68,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
app := &gtsmodel.Application{ app := &gtsmodel.Application{
ClientID: clientID, ClientID: clientID,
} }
if err := m.db.GetWhere([]db.Where{{Key: "client_id", Value: app.ClientID}}, app); err != nil { if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)}) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)})
return return
} }
@ -88,12 +92,12 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
} }
// Finally we should also get the redirect and scope of this particular request, as stored in the session. // Finally we should also get the redirect and scope of this particular request, as stored in the session.
redirect, ok := s.Get("redirect_uri").(string) redirect, ok := s.Get(sessionRedirectURI).(string)
if !ok || redirect == "" { if !ok || redirect == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "no redirect_uri found in session"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "no redirect_uri found in session"})
return return
} }
scope, ok := s.Get("scope").(string) scope, ok := s.Get(sessionScope).(string)
if !ok || scope == "" { if !ok || scope == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "no scope found in session"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "no scope found in session"})
return return
@ -107,7 +111,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
"appname": app.Name, "appname": app.Name,
"appwebsite": app.Website, "appwebsite": app.Website,
"redirect": redirect, "redirect": redirect,
"scope": scope, sessionScope: scope,
"user": acct.Username, "user": acct.Username,
}) })
} }
@ -123,39 +127,47 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
// We need to retrieve the original form submitted to the authorizeGEThandler, and // We need to retrieve the original form submitted to the authorizeGEThandler, and
// recreate it on the request so that it can be used further by the oauth2 library. // recreate it on the request so that it can be used further by the oauth2 library.
// So first fetch all the values from the session. // So first fetch all the values from the session.
forceLogin, ok := s.Get("force_login").(string)
forceLogin, ok := s.Get(sessionForceLogin).(string)
if !ok { if !ok {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing force_login"}) c.JSON(http.StatusBadRequest, gin.H{"error": "session missing force_login"})
return return
} }
responseType, ok := s.Get("response_type").(string)
responseType, ok := s.Get(sessionResponseType).(string)
if !ok || responseType == "" { if !ok || responseType == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing response_type"}) c.JSON(http.StatusBadRequest, gin.H{"error": "session missing response_type"})
return return
} }
clientID, ok := s.Get("client_id").(string)
clientID, ok := s.Get(sessionClientID).(string)
if !ok || clientID == "" { if !ok || clientID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing client_id"}) c.JSON(http.StatusBadRequest, gin.H{"error": "session missing client_id"})
return return
} }
redirectURI, ok := s.Get("redirect_uri").(string)
redirectURI, ok := s.Get(sessionRedirectURI).(string)
if !ok || redirectURI == "" { if !ok || redirectURI == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing redirect_uri"}) c.JSON(http.StatusBadRequest, gin.H{"error": "session missing redirect_uri"})
return return
} }
scope, ok := s.Get("scope").(string)
scope, ok := s.Get(sessionScope).(string)
if !ok { if !ok {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing scope"}) c.JSON(http.StatusBadRequest, gin.H{"error": "session missing scope"})
return return
} }
userID, ok := s.Get("userid").(string)
userID, ok := s.Get(sessionUserID).(string)
if !ok { if !ok {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing userid"}) c.JSON(http.StatusBadRequest, gin.H{"error": "session missing userid"})
return return
} }
// we're done with the session so we can clear it now // we're done with the session so we can clear it now
s.Clear() for _, key := range sessionKeys {
s.Delete(key)
}
if err := s.Save(); err != nil { if err := s.Save(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -163,12 +175,12 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
// now set the values on the request // now set the values on the request
values := url.Values{} values := url.Values{}
values.Set("force_login", forceLogin) values.Set(sessionForceLogin, forceLogin)
values.Set("response_type", responseType) values.Set(sessionResponseType, responseType)
values.Set("client_id", clientID) values.Set(sessionClientID, clientID)
values.Set("redirect_uri", redirectURI) values.Set(sessionRedirectURI, redirectURI)
values.Set("scope", scope) values.Set(sessionScope, scope)
values.Set("userid", userID) values.Set(sessionUserID, userID)
c.Request.Form = values c.Request.Form = values
l.Tracef("values on request set to %+v", c.Request.Form) l.Tracef("values on request set to %+v", c.Request.Form)
@ -178,18 +190,9 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
} }
} }
// parseAuthForm parses the OAuthAuthorize form in the gin context, and stores // extractAuthForm checks the given OAuthAuthorize form, and stores
// the values in the form into the session. // the values in the form into the session.
func parseAuthForm(c *gin.Context, l *logrus.Entry) error { func extractAuthForm(s sessions.Session, form *model.OAuthAuthorize) error {
s := sessions.Default(c)
// first make sure they've filled out the authorize form with the required values
form := &model.OAuthAuthorize{}
if err := c.ShouldBind(form); err != nil {
return err
}
l.Tracef("parsed form: %+v", form)
// these fields are *required* so check 'em // these fields are *required* so check 'em
if form.ResponseType == "" || form.ClientID == "" || form.RedirectURI == "" { if form.ResponseType == "" || form.ClientID == "" || form.RedirectURI == "" {
return errors.New("missing one of: response_type, client_id or redirect_uri") return errors.New("missing one of: response_type, client_id or redirect_uri")
@ -201,10 +204,10 @@ func parseAuthForm(c *gin.Context, l *logrus.Entry) error {
} }
// save these values from the form so we can use them elsewhere in the session // save these values from the form so we can use them elsewhere in the session
s.Set("force_login", form.ForceLogin) s.Set(sessionForceLogin, form.ForceLogin)
s.Set("response_type", form.ResponseType) s.Set(sessionResponseType, form.ResponseType)
s.Set("client_id", form.ClientID) s.Set(sessionClientID, form.ClientID)
s.Set("redirect_uri", form.RedirectURI) s.Set(sessionRedirectURI, form.RedirectURI)
s.Set("scope", form.Scope) s.Set(sessionScope, form.Scope)
return s.Save() return s.Save()
} }

View File

@ -62,7 +62,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) {
return return
} }
s.Set("userid", userid) s.Set(sessionUserID, userid)
if err := s.Save(); err != nil { if err := s.Save(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return

View File

@ -22,16 +22,16 @@ package model
// See here: https://docs.joinmastodon.org/methods/apps/oauth/ // See here: https://docs.joinmastodon.org/methods/apps/oauth/
type OAuthAuthorize struct { type OAuthAuthorize struct {
// Forces the user to re-login, which is necessary for authorizing with multiple accounts from the same instance. // Forces the user to re-login, which is necessary for authorizing with multiple accounts from the same instance.
ForceLogin string `form:"force_login,omitempty"` ForceLogin string `form:"force_login" json:"force_login"`
// Should be set equal to `code`. // Should be set equal to `code`.
ResponseType string `form:"response_type"` ResponseType string `form:"response_type" json:"response_type"`
// Client ID, obtained during app registration. // Client ID, obtained during app registration.
ClientID string `form:"client_id"` ClientID string `form:"client_id" json:"client_id"`
// Set a URI to redirect the user to. // Set a URI to redirect the user to.
// If this parameter is set to urn:ietf:wg:oauth:2.0:oob then the authorization code will be shown instead. // If this parameter is set to urn:ietf:wg:oauth:2.0:oob then the authorization code will be shown instead.
// Must match one of the redirect URIs declared during app registration. // Must match one of the redirect URIs declared during app registration.
RedirectURI string `form:"redirect_uri"` RedirectURI string `form:"redirect_uri" json:"redirect_uri"`
// List of requested OAuth scopes, separated by spaces (or by pluses, if using query parameters). // List of requested OAuth scopes, separated by spaces (or by pluses, if using query parameters).
// Must be a subset of scopes declared during app registration. If not provided, defaults to read. // Must be a subset of scopes declared during app registration. If not provided, defaults to read.
Scope string `form:"scope,omitempty"` Scope string `form:"scope" json:"scope"`
} }

View File

@ -22,6 +22,7 @@ import (
"crypto/rand" "crypto/rand"
"errors" "errors"
"fmt" "fmt"
"net/http"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/memstore" "github.com/gin-contrib/sessions/memstore"
@ -63,6 +64,14 @@ func useSession(cfg *config.Config, dbService db.DB, engine *gin.Engine) error {
} }
store := memstore.NewStore(rs.Auth, rs.Crypt) store := memstore.NewStore(rs.Auth, rs.Crypt)
store.Options(sessions.Options{
Path: "/",
Domain: cfg.Host,
MaxAge: 120, // 2 minutes
Secure: true, // only use cookie over https
HttpOnly: true, // exclude javascript from inspecting cookie
SameSite: http.SameSiteStrictMode, // https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-cookie-same-site-00#section-4.1.1
})
sessionName := fmt.Sprintf("gotosocial-%s", cfg.Host) sessionName := fmt.Sprintf("gotosocial-%s", cfg.Host)
engine.Use(sessions.Sessions(sessionName, store)) engine.Use(sessions.Sessions(sessionName, store))
return nil return nil