From f4a0c0ef31b25baa721edb5666867ed38a2e4a51 Mon Sep 17 00:00:00 2001 From: Marc Date: Sun, 2 Mar 2025 15:21:09 +0000 Subject: [PATCH] feat(auth): sso fallback mapping (#3068) Reviewed-on: https://kolaente.dev/vikunja/vikunja/pulls/3068 Reviewed-by: konrad Co-authored-by: Marc Co-committed-by: Marc --- config-raw.json | 10 + pkg/models/error.go | 28 +++ pkg/modules/auth/openid/openid.go | 283 +++++++++++++++---------- pkg/modules/auth/openid/openid_test.go | 89 +++++++- pkg/modules/auth/openid/providers.go | 34 ++- pkg/web/web.go | 5 + 6 files changed, 323 insertions(+), 126 deletions(-) diff --git a/config-raw.json b/config-raw.json index fe94ba034..b90212ebf 100644 --- a/config-raw.json +++ b/config-raw.json @@ -699,6 +699,16 @@ "key": "scope", "default_value": "openid email profile", "comment": "The scope necessary to use oidc.\nIf you want to use the Feature to create and assign to Vikunja teams via oidc, you have to add the custom \"vikunja_scope\" and check [openid.md](https://vikunja.io/docs/openid/).\ne.g. scope: openid email profile vikunja_scope" + }, + { + "key": "usernamefallback", + "default_value": "false", + "comment": "This option allows to look for a local account where the OIDC Issuer match the Vikunja local username. Allowed value is either `true` or `false`. That option can be combined with `emailfallback`.\nUse with caution, this can allow the 3rd party provider to connect to *any* local account and therefore potential account hijaking." + }, + { + "key": "emailfallback", + "default_value": "false", + "comment": "This option allows to look for a local account where the OIDC user's email match the Vikunja local email. Allowed value is either `true` or `false`. That option can be combined with `usernamefallback`.\nUse with caution, this can allow the 3rd party provider to connect to *any* local account and therefore potential account hijaking." } ] } diff --git a/pkg/models/error.go b/pkg/models/error.go index f17199391..8ac854e23 100644 --- a/pkg/models/error.go +++ b/pkg/models/error.go @@ -1967,3 +1967,31 @@ func (err *ErrInvalidAPITokenPermission) HTTPError() web.HTTPError { Message: fmt.Sprintf("The permission %s of group %s is invalid.", err.Permission, err.Group), } } + +// OIDC errors +const ErrCodeOpenIDError = 15001 + +type ErrOpenIDBadRequest struct { + Message string +} + +func (err *ErrOpenIDBadRequest) Error() string { + return err.Message +} + +func (err ErrOpenIDBadRequest) HTTPError() web.HTTPError { + return web.HTTPError{ + HTTPCode: http.StatusBadRequest, + Code: ErrCodeOpenIDError, + Message: err.Message, + } +} + +type ErrOpenIDBadRequestWithDetails struct { + Message string + Details interface{} +} + +func (err *ErrOpenIDBadRequestWithDetails) Error() string { + return err.Message +} diff --git a/pkg/modules/auth/openid/openid.go b/pkg/modules/auth/openid/openid.go index a3219e02a..0c24ee582 100644 --- a/pkg/modules/auth/openid/openid.go +++ b/pkg/modules/auth/openid/openid.go @@ -48,16 +48,18 @@ type Callback struct { // Provider is the structure of an OpenID Connect provider type Provider struct { - Name string `json:"name"` - Key string `json:"key"` - OriginalAuthURL string `json:"-"` - AuthURL string `json:"auth_url"` - LogoutURL string `json:"logout_url"` - ClientID string `json:"client_id"` - Scope string `json:"scope"` - ClientSecret string `json:"-"` - openIDProvider *oidc.Provider - Oauth2Config *oauth2.Config `json:"-"` + Name string `json:"name"` + Key string `json:"key"` + OriginalAuthURL string `json:"-"` + AuthURL string `json:"auth_url"` + LogoutURL string `json:"logout_url"` + ClientID string `json:"client_id"` + Scope string `json:"scope"` + EmailFallback bool `json:"email_fallback"` + UsernameFallback bool `json:"username_fallback"` + ClientSecret string `json:"-"` + openIDProvider *oidc.Provider + Oauth2Config *oauth2.Config `json:"-"` } type claims struct { Email string `json:"email"` @@ -110,112 +112,29 @@ func (p *Provider) Issuer() (issuerURL string, err error) { // @Failure 500 {object} models.Message "Internal error" // @Router /auth/openid/{provider}/callback [post] func HandleCallback(c echo.Context) error { - cb := &Callback{} - if err := c.Bind(cb); err != nil { - return c.JSON(http.StatusBadRequest, models.Message{Message: "Bad data"}) - } - // Check if the provider exists - providerKey := c.Param("provider") - provider, err := GetProvider(providerKey) + provider, oauthToken, idToken, err := getProviderAndOidcTokens(c) if err != nil { - return handler.HandleHTTPError(err) - } - if provider == nil { - return c.JSON(http.StatusBadRequest, models.Message{Message: "Provider does not exist"}) - } - - log.Debugf("Trying to authenticate user using provider: %s", provider.Key) - - provider.Oauth2Config.RedirectURL = cb.RedirectURL - - // Parse the access & ID token - oauth2Token, err := provider.Oauth2Config.Exchange(context.Background(), cb.Code) - if err != nil { - var rerr *oauth2.RetrieveError - if errors.As(err, &rerr) { - - details := make(map[string]interface{}) - if err := json.Unmarshal(rerr.Body, &details); err != nil { - log.Errorf("Error unmarshalling token for provider %s: %v", provider.Name, err) - return handler.HandleHTTPError(err) - } - - log.Error(err) + var detailedErr *models.ErrOpenIDBadRequestWithDetails + if errors.As(err, &detailedErr) { return c.JSON(http.StatusBadRequest, map[string]interface{}{ - "message": "Could not authenticate against third party.", - "details": details, + "message": detailedErr.Message, + "details": detailedErr.Details, }) } - return handler.HandleHTTPError(err) } - // Extract the ID Token from OAuth2 token. - rawIDToken, ok := oauth2Token.Extra("id_token").(string) - if !ok { - return c.JSON(http.StatusBadRequest, models.Message{Message: "Missing token"}) - } - - verifier := provider.openIDProvider.Verifier(&oidc.Config{ClientID: provider.ClientID}) - - // Parse and verify ID Token payload. - idToken, err := verifier.Verify(context.Background(), rawIDToken) + cl, err := getClaims(provider, oauthToken, idToken) if err != nil { - log.Errorf("Error verifying token for provider %s: %v", provider.Name, err) return handler.HandleHTTPError(err) } - // Extract custom claims - cl := &claims{} - - err = idToken.Claims(cl) - if err != nil { - log.Errorf("Error getting token claims for provider %s: %v", provider.Name, err) - return handler.HandleHTTPError(err) - } - - if cl.Email == "" || cl.Name == "" || cl.PreferredUsername == "" { - info, err := provider.openIDProvider.UserInfo(context.Background(), provider.Oauth2Config.TokenSource(context.Background(), oauth2Token)) - if err != nil { - log.Errorf("Error getting userinfo for provider %s: %v", provider.Name, err) - return handler.HandleHTTPError(err) - } - - cl2 := &claims{} - err = info.Claims(cl2) - if err != nil { - log.Errorf("Error parsing userinfo claims for provider %s: %v", provider.Name, err) - return handler.HandleHTTPError(err) - } - - if cl.Email == "" { - cl.Email = cl2.Email - } - - if cl.Name == "" { - cl.Name = cl2.Name - } - - if cl.PreferredUsername == "" { - cl.PreferredUsername = cl2.PreferredUsername - } - - if cl.PreferredUsername == "" && cl2.Nickname != "" { - cl.PreferredUsername = cl2.Nickname - } - - if cl.Email == "" { - log.Errorf("Claim does not contain an email address for provider %s", provider.Name) - return handler.HandleHTTPError(&user.ErrNoOpenIDEmailProvided{}) - } - } - s := db.NewSession() defer s.Close() // Check if we have seen this user before - u, err := getOrCreateUser(s, cl, idToken.Issuer, idToken.Subject) + u, err := getOrCreateUser(s, cl, provider, idToken) if err != nil { _ = s.Rollback() log.Errorf("Error creating new user for provider %s: %v", provider.Name, err) @@ -403,40 +322,71 @@ func GetOrCreateTeamsByOIDC(s *xorm.Session, teamData []*models.OIDCTeam, u *use return te, err } -func getOrCreateUser(s *xorm.Session, cl *claims, issuer, subject string) (u *user.User, err error) { +func getOrCreateUser(s *xorm.Session, cl *claims, provider *Provider, idToken *oidc.IDToken) (u *user.User, err error) { + + // set defaults + fallbackMatchFound := false + alreadyCreatedFromIssuer := false + + // first check if the user already signed up using the provider - // Check if the user exists for that issuer and subject u, err = user.GetUserWithEmail(s, &user.User{ - Issuer: issuer, - Subject: subject, + Issuer: idToken.Issuer, + Subject: idToken.Subject, }) if err != nil && !user.IsErrUserDoesNotExist(err) { return nil, err } + alreadyCreatedFromIssuer = err == nil // found if no error, not found if we reach it here despite an error - // If no user exists, create one with the preferred username if it is not already taken - if user.IsErrUserDoesNotExist(err) { + if !alreadyCreatedFromIssuer && (provider.EmailFallback || provider.UsernameFallback) { + + // try finding the user on fallback mappingproperties + + searchUser := &user.User{ + Issuer: user.IssuerLocal, + } + if provider.UsernameFallback { + // Match oidc subject on username as each is unique identifier in its own referential + // Discouraged if multiple account providers are used. + searchUser.Username = idToken.Subject + } + if provider.EmailFallback { + // Used alone, allow for someone to connect from various provider to the same account + // Discouraged for untrusted provider where someone can set email without verification + // Note : mapping on email prevent from auto-updating user email + searchUser.Email = cl.Email + } + + // Check if the user exists for the given fallback matching options + u, err = user.GetUserWithEmail(s, searchUser) + if err != nil && !user.IsErrUserDoesNotExist(err) { + return nil, err + } + fallbackMatchFound = err == nil // found if no error, not found if we reach it here despite an error + } + + if !alreadyCreatedFromIssuer && !fallbackMatchFound { + + // If no user exists, create one with the preferred username if it is not already taken uu := &user.User{ Username: strings.ReplaceAll(cl.PreferredUsername, " ", "-"), Email: cl.Email, Name: cl.Name, Status: user.StatusActive, - Issuer: issuer, - Subject: subject, + Issuer: idToken.Issuer, + Subject: idToken.Subject, } - return auth.CreateUserWithRandomUsername(s, uu) - } + } else if alreadyCreatedFromIssuer { - // If it exists, check if the email address changed and change it if not - if cl.Email != u.Email || cl.Name != u.Name { + // try updating user.Name and/or user.Email if necessary if cl.Email != u.Email { u.Email = cl.Email } if cl.Name != u.Name { u.Name = cl.Name } - u, err = user.UpdateUser(s, u, false) if err != nil { return nil, err @@ -445,3 +395,110 @@ func getOrCreateUser(s *xorm.Session, cl *claims, issuer, subject string) (u *us return } + +func getClaims(provider *Provider, oauth2Token *oauth2.Token, idToken *oidc.IDToken) (*claims, error) { + + cl := &claims{} + err := idToken.Claims(cl) + if err != nil { + log.Errorf("Error getting token claims for provider %s: %v", provider.Name, err) + return nil, err + } + + if cl.Email == "" || cl.Name == "" || cl.PreferredUsername == "" { + info, err := provider.openIDProvider.UserInfo(context.Background(), provider.Oauth2Config.TokenSource(context.Background(), oauth2Token)) + if err != nil { + log.Errorf("Error getting userinfo for provider %s: %v", provider.Name, err) + return nil, err + } + + cl2 := &claims{} + err = info.Claims(cl2) + if err != nil { + log.Errorf("Error parsing userinfo claims for provider %s: %v", provider.Name, err) + return nil, err + } + + if cl.Email == "" { + cl.Email = cl2.Email + } + + if cl.Name == "" { + cl.Name = cl2.Name + } + + if cl.PreferredUsername == "" { + cl.PreferredUsername = cl2.PreferredUsername + } + + if cl.PreferredUsername == "" && cl2.Nickname != "" { + cl.PreferredUsername = cl2.Nickname + } + + if cl.Email == "" { + log.Errorf("Claim does not contain an email address for provider %s", provider.Name) + return nil, &user.ErrNoOpenIDEmailProvided{} + } + } + return cl, nil +} + +func getProviderAndOidcTokens(c echo.Context) (*Provider, *oauth2.Token, *oidc.IDToken, error) { + + cb := &Callback{} + if err := c.Bind(cb); err != nil { + return nil, nil, nil, &models.ErrOpenIDBadRequest{Message: "Bad data"} + } + + // Check if the provider exists + providerKey := c.Param("provider") + provider, err := GetProvider(providerKey) + if err != nil { + return nil, nil, nil, err + } + if provider == nil { + return nil, nil, nil, &models.ErrOpenIDBadRequest{Message: "Provider does not exist"} + } + + log.Debugf("Trying to authenticate user using provider: %s", provider.Key) + + provider.Oauth2Config.RedirectURL = cb.RedirectURL + // Parse the access & ID token + oauth2Token, err := provider.Oauth2Config.Exchange(context.Background(), cb.Code) + if err != nil { + var rerr *oauth2.RetrieveError + if errors.As(err, &rerr) { + + details := make(map[string]interface{}) + if err := json.Unmarshal(rerr.Body, &details); err != nil { + log.Errorf("Error unmarshalling token for provider %s: %v", provider.Name, err) + return nil, nil, nil, err + } + + log.Error(err) + return nil, nil, nil, &models.ErrOpenIDBadRequestWithDetails{ + Message: "Could not authenticate against third party.", + Details: details, + } + } + + return nil, nil, nil, err + } + + // Extract the ID Token from OAuth2 token. + rawIDToken, ok := oauth2Token.Extra("id_token").(string) + if !ok { + return nil, nil, nil, &models.ErrOpenIDBadRequest{Message: "Missing token"} + } + + verifier := provider.openIDProvider.Verifier(&oidc.Config{ClientID: provider.ClientID}) + + // Parse and verify ID Token payload. + idToken, err := verifier.Verify(context.Background(), rawIDToken) + if err != nil { + log.Errorf("Error verifying token for provider %s: %v", provider.Name, err) + return nil, nil, nil, err + } + + return provider, oauth2Token, idToken, nil +} diff --git a/pkg/modules/auth/openid/openid_test.go b/pkg/modules/auth/openid/openid_test.go index e50f26d81..c4e303e86 100644 --- a/pkg/modules/auth/openid/openid_test.go +++ b/pkg/modules/auth/openid/openid_test.go @@ -23,6 +23,7 @@ import ( "code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/user" "code.vikunja.io/api/pkg/utils" + "github.com/coreos/go-oidc/v3/oidc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -37,7 +38,10 @@ func TestGetOrCreateUser(t *testing.T) { Email: "test@example.com", PreferredUsername: "someUserWhoDoesNotExistYet", } - u, err := getOrCreateUser(s, cl, "https://some.issuer", "12345") + provider := &Provider{} + idToken := &oidc.IDToken{Issuer: "https://some.issuer", Subject: "12345"} + + u, err := getOrCreateUser(s, cl, provider, idToken) require.NoError(t, err) err = s.Commit() require.NoError(t, err) @@ -57,7 +61,10 @@ func TestGetOrCreateUser(t *testing.T) { Email: "test@example.com", PreferredUsername: "", } - u, err := getOrCreateUser(s, cl, "https://some.issuer", "12345") + provider := &Provider{} + idToken := &oidc.IDToken{Issuer: "https://some.issuer", Subject: "12345"} + + u, err := getOrCreateUser(s, cl, provider, idToken) require.NoError(t, err) assert.NotEmpty(t, u.Username) err = s.Commit() @@ -76,7 +83,10 @@ func TestGetOrCreateUser(t *testing.T) { cl := &claims{ Email: "", } - _, err := getOrCreateUser(s, cl, "https://some.issuer", "12345") + provider := &Provider{} + idToken := &oidc.IDToken{Issuer: "https://some.issuer", Subject: "12345"} + + _, err := getOrCreateUser(s, cl, provider, idToken) require.Error(t, err) }) t.Run("existing user, different email address", func(t *testing.T) { @@ -87,7 +97,10 @@ func TestGetOrCreateUser(t *testing.T) { cl := &claims{ Email: "other-email-address@some.service.com", } - u, err := getOrCreateUser(s, cl, "https://some.service.com", "12345") + provider := &Provider{} + idToken := &oidc.IDToken{Issuer: "https://some.service.com", Subject: "12345"} + + u, err := getOrCreateUser(s, cl, provider, idToken) require.NoError(t, err) err = s.Commit() require.NoError(t, err) @@ -111,7 +124,10 @@ func TestGetOrCreateUser(t *testing.T) { }, } - u, err := getOrCreateUser(s, cl, "https://some.service.com", "12345") + provider := &Provider{} + idToken := &oidc.IDToken{Issuer: "https://some.service.com", Subject: "12345"} + + u, err := getOrCreateUser(s, cl, provider, idToken) require.NoError(t, err) teamData, errs := getTeamDataFromToken(cl.VikunjaGroups, nil) for _, err := range errs { @@ -148,7 +164,10 @@ func TestGetOrCreateUser(t *testing.T) { }, } - u, err := getOrCreateUser(s, cl, "https://some.service.com", "12345") + provider := &Provider{} + idToken := &oidc.IDToken{Issuer: "https://some.service.com", Subject: "12345"} + + u, err := getOrCreateUser(s, cl, provider, idToken) require.NoError(t, err) teamData, errs := getTeamDataFromToken(cl.VikunjaGroups, nil) for _, err := range errs { @@ -231,4 +250,62 @@ func TestGetOrCreateUser(t *testing.T) { "id": oidcTeams, }) }) + t.Run("ProviderFallback : Match to existing local user on username", func(t *testing.T) { + db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + + cl := &claims{} + provider := &Provider{ + UsernameFallback: true, + } + idToken := &oidc.IDToken{Issuer: "https://some.issuer", Subject: "user11"} + + u, err := getOrCreateUser(s, cl, provider, idToken) + require.NoError(t, err) + assert.Equal(t, idToken.Subject, u.Username, "subject match username") + assert.Equal(t, user.IssuerLocal, u.Issuer, "User should be a local one") + assert.Equal(t, 11, int(u.ID), "user id 11 expected") + }) + t.Run("ProviderFallback : Match to existing local user on email", func(t *testing.T) { + db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + + cl := &claims{ + Email: "user11@example.com", + } + provider := &Provider{ + EmailFallback: true, + } + idToken := &oidc.IDToken{Issuer: "https://some.issuer", Subject: "user11"} + + u, err := getOrCreateUser(s, cl, provider, idToken) + require.NoError(t, err) + assert.Equal(t, cl.Email, u.Email, "email should match") + assert.Equal(t, user.IssuerLocal, u.Issuer, "User should be a local one") + assert.Equal(t, 11, int(u.ID), "user id 11 expected") + }) + t.Run("ProviderFallback : Match to existing local user on username and email", func(t *testing.T) { + + db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + + cl := &claims{ + Email: "user11@example.com", + } + provider := &Provider{ + UsernameFallback: true, + EmailFallback: true, + } + idToken := &oidc.IDToken{Issuer: "https://some.issuer", Subject: "user11"} + + u, err := getOrCreateUser(s, cl, provider, idToken) + require.NoError(t, err) + assert.Equal(t, cl.Email, u.Email, "email should match") + assert.Equal(t, idToken.Subject, u.Username, "subject match username") + assert.Equal(t, user.IssuerLocal, u.Issuer, "User should be a local one") + assert.Equal(t, 11, int(u.ID), "user id 11 expected") + }) } diff --git a/pkg/modules/auth/openid/providers.go b/pkg/modules/auth/openid/providers.go index a646e8658..e779da83e 100644 --- a/pkg/modules/auth/openid/providers.go +++ b/pkg/modules/auth/openid/providers.go @@ -123,6 +123,8 @@ func getProviderFromMap(pi map[string]interface{}, key string) (provider *Provid []string{ "logouturl", "scope", + "emailfallback", + "usernamefallback", }, requiredKeys..., ) @@ -162,14 +164,32 @@ func getProviderFromMap(pi map[string]interface{}, key string) (provider *Provid scope = "openid profile email" } + var emailFallback = false + emailFallbackValue, exists := pi["emailfallback"] + if exists { + emailFallbackTypedValue, ok := emailFallbackValue.(bool) + if ok { + emailFallback = emailFallbackTypedValue + } + } + var usernameFallback = false + usernameFallbackValue, exists := pi["usernamefallback"] + if exists { + usernameFallbackTypedValue, ok := usernameFallbackValue.(bool) + if ok { + usernameFallback = usernameFallbackTypedValue + } + } provider = &Provider{ - Name: name, - Key: key, - AuthURL: pi["authurl"].(string), - OriginalAuthURL: pi["authurl"].(string), - ClientSecret: pi["clientsecret"].(string), - LogoutURL: logoutURL, - Scope: scope, + Name: name, + Key: key, + AuthURL: pi["authurl"].(string), + OriginalAuthURL: pi["authurl"].(string), + ClientSecret: pi["clientsecret"].(string), + LogoutURL: logoutURL, + Scope: scope, + EmailFallback: emailFallback, + UsernameFallback: usernameFallback, } cl, is := pi["clientid"].(int) diff --git a/pkg/web/web.go b/pkg/web/web.go index e00bbbcd6..c204db83b 100644 --- a/pkg/web/web.go +++ b/pkg/web/web.go @@ -50,6 +50,11 @@ type HTTPError struct { Message string `json:"message"` } +type HTTPErrorWithDetails struct { + HTTPError + Details interface{} `json:"details"` +} + // Auth defines the authentication interface used to get some auth thing type Auth interface { // Most of the time, we need an ID from the auth object only. Having this method saves the need to cast it.