diff --git a/internal/tests/mock_services.go b/internal/tests/mock_services.go index 0f2f1720..b25cb24f 100644 --- a/internal/tests/mock_services.go +++ b/internal/tests/mock_services.go @@ -313,3 +313,45 @@ func (m *MockServiceRegistry) Get(name string) any { args := m.Called(name) return args.Get(0) } + +type MockSecondaryStorage struct { + mock.Mock +} + +func (m *MockSecondaryStorage) Get(ctx context.Context, key string) (any, error) { + args := m.Called(ctx, key) + return args.Get(0), args.Error(1) +} + +func (m *MockSecondaryStorage) Set(ctx context.Context, key string, value any, ttl *time.Duration) error { + args := m.Called(ctx, key, value, ttl) + return args.Error(0) +} + +func (m *MockSecondaryStorage) Delete(ctx context.Context, key string) error { + args := m.Called(ctx, key) + return args.Error(0) +} + +func (m *MockSecondaryStorage) Incr(ctx context.Context, key string, ttl *time.Duration) (int, error) { + args := m.Called(ctx, key, ttl) + return args.Int(0), args.Error(1) +} + +func (m *MockSecondaryStorage) TTL(ctx context.Context, key string) (*time.Duration, error) { + args := m.Called(ctx, key) + if v := args.Get(0); v != nil { + return v.(*time.Duration), args.Error(1) + } + return nil, args.Error(1) +} + +func (m *MockSecondaryStorage) Scan(ctx context.Context, prefix string) ([]string, error) { + args := m.Called(ctx, prefix) + return args.Get(0).([]string), args.Error(1) +} + +func (m *MockSecondaryStorage) Close() error { + args := m.Called() + return args.Error(0) +} diff --git a/models/storage.go b/models/storage.go index 294abf81..e1e55ce8 100644 --- a/models/storage.go +++ b/models/storage.go @@ -25,6 +25,8 @@ type SecondaryStorage interface { Incr(ctx context.Context, key string, ttl *time.Duration) (int, error) // TTL retrieves the time-to-live (TTL) for the given key. TTL(ctx context.Context, key string) (*time.Duration, error) + // Scan returns all keys matching the given prefix that have not expired. + Scan(ctx context.Context, prefix string) ([]string, error) // Close closes the storage and releases any resources. Close() error } diff --git a/plugins/bearer/hooks.go b/plugins/bearer/hooks.go index 8bf55eb1..e72f673b 100644 --- a/plugins/bearer/hooks.go +++ b/plugins/bearer/hooks.go @@ -49,7 +49,7 @@ func (p *BearerPlugin) validateBearerToken(reqCtx *models.RequestContext) error return nil } - userID, err := p.jwtService.ValidateToken(token) + actor, err := p.jwtService.ValidateToken(reqCtx.Request.Context(), token) if err != nil { reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{ "message": "Bearer token invalid or expired", @@ -58,10 +58,7 @@ func (p *BearerPlugin) validateBearerToken(reqCtx *models.RequestContext) error return nil } - reqCtx.SetActorInContext(&models.Actor{ - ID: userID, - Type: models.ActorUser, - }) + reqCtx.SetActorInContext(actor) return nil } @@ -78,15 +75,12 @@ func (p *BearerPlugin) validateBearerTokenOptional(reqCtx *models.RequestContext return nil } - userID, err := p.jwtService.ValidateToken(token) + actor, err := p.jwtService.ValidateToken(reqCtx.Request.Context(), token) if err != nil { return nil } - reqCtx.SetActorInContext(&models.Actor{ - ID: userID, - Type: models.ActorUser, - }) + reqCtx.SetActorInContext(actor) return nil } diff --git a/plugins/bearer/hooks_test.go b/plugins/bearer/hooks_test.go new file mode 100644 index 00000000..0cecc27e --- /dev/null +++ b/plugins/bearer/hooks_test.go @@ -0,0 +1,179 @@ +package bearer + +import ( + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/models" + bearertests "github.com/Authula/authula/plugins/bearer/tests" +) + +func newTestBearerPlugin(jwtSvc *bearertests.MockJWTService) *BearerPlugin { + return &BearerPlugin{ + config: BearerPluginConfig{HeaderName: "Authorization"}, + jwtService: jwtSvc, + } +} + +func newBearerRequestCtx(t *testing.T, header string) *models.RequestContext { + t.Helper() + req, _, reqCtx := internaltests.NewHandlerRequestWithActor(t, http.MethodGet, "/test", nil, nil) + if header != "" { + req.Header.Set("Authorization", header) + reqCtx.Headers = req.Header + } + return reqCtx +} + +func TestValidateBearerToken(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + header string + setupMock func(*bearertests.MockJWTService) + preSetActor *models.Actor + wantHandled bool + wantStatus int + wantActor *models.Actor + }{ + { + name: "actor_already_set", + preSetActor: &models.Actor{ID: "existing-user", Type: models.ActorUser}, + setupMock: func(m *bearertests.MockJWTService) { + }, + wantHandled: false, + wantActor: &models.Actor{ID: "existing-user", Type: models.ActorUser}, + }, + { + name: "no_token", + header: "", + setupMock: func(m *bearertests.MockJWTService) { + }, + wantHandled: true, + wantStatus: http.StatusUnauthorized, + }, + { + name: "invalid_token", + header: "Bearer invalid-token", + setupMock: func(m *bearertests.MockJWTService) { + m.On("ValidateToken", mock.Anything, "invalid-token").Return(nil, errors.New("invalid token")).Once() + }, + wantHandled: true, + wantStatus: http.StatusUnauthorized, + }, + { + name: "valid_user_token", + header: "Bearer valid-user-token", + setupMock: func(m *bearertests.MockJWTService) { + m.On("ValidateToken", mock.Anything, "valid-user-token").Return(&models.Actor{ID: "user-1", Type: models.ActorUser}, nil).Once() + }, + wantActor: &models.Actor{ID: "user-1", Type: models.ActorUser, Scopes: []string{}, Metadata: map[string]any{}}, + }, + { + name: "valid_machine_token", + header: "Bearer valid-machine-token", + setupMock: func(m *bearertests.MockJWTService) { + m.On("ValidateToken", mock.Anything, "valid-machine-token").Return(&models.Actor{ID: "client-1", Type: models.ActorMachine, OrganizationID: internaltests.PtrString("org-1"), Scopes: []string{"read"}}, nil).Once() + }, + wantActor: &models.Actor{ID: "client-1", Type: models.ActorMachine, OrganizationID: internaltests.PtrString("org-1"), Scopes: []string{"read"}, Metadata: map[string]any{}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSvc := &bearertests.MockJWTService{} + tt.setupMock(mockSvc) + + p := newTestBearerPlugin(mockSvc) + reqCtx := newBearerRequestCtx(t, tt.header) + if tt.preSetActor != nil { + reqCtx.Actor = tt.preSetActor + } + + err := p.validateBearerToken(reqCtx) + require.NoError(t, err) + + require.Equal(t, tt.wantHandled, reqCtx.Handled) + if tt.wantStatus != 0 { + require.Equal(t, tt.wantStatus, reqCtx.ResponseStatus) + } + if tt.wantActor != nil { + require.Equal(t, tt.wantActor, reqCtx.Actor) + } + mockSvc.AssertExpectations(t) + }) + } +} + +func TestValidateBearerTokenOptional(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + header string + setupMock func(*bearertests.MockJWTService) + preSetActor *models.Actor + wantHandled bool + wantActor *models.Actor + }{ + { + name: "actor_already_set_optional", + preSetActor: &models.Actor{ID: "existing-user", Type: models.ActorUser}, + setupMock: func(m *bearertests.MockJWTService) { + }, + wantHandled: false, + wantActor: &models.Actor{ID: "existing-user", Type: models.ActorUser}, + }, + { + name: "no_token_optional", + header: "", + setupMock: func(m *bearertests.MockJWTService) { + }, + wantHandled: false, + }, + { + name: "invalid_token_optional", + header: "Bearer invalid-token", + setupMock: func(m *bearertests.MockJWTService) { + m.On("ValidateToken", mock.Anything, "invalid-token").Return(nil, errors.New("invalid token")).Once() + }, + wantHandled: false, + }, + { + name: "valid_token_optional", + header: "Bearer valid-user-token", + setupMock: func(m *bearertests.MockJWTService) { + m.On("ValidateToken", mock.Anything, "valid-user-token").Return(&models.Actor{ID: "user-1", Type: models.ActorUser}, nil).Once() + }, + wantActor: &models.Actor{ID: "user-1", Type: models.ActorUser, Scopes: []string{}, Metadata: map[string]any{}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSvc := &bearertests.MockJWTService{} + tt.setupMock(mockSvc) + + p := newTestBearerPlugin(mockSvc) + reqCtx := newBearerRequestCtx(t, tt.header) + if tt.preSetActor != nil { + reqCtx.Actor = tt.preSetActor + } + + err := p.validateBearerTokenOptional(reqCtx) + require.NoError(t, err) + + require.Equal(t, tt.wantHandled, reqCtx.Handled) + if tt.wantActor != nil { + require.Equal(t, tt.wantActor, reqCtx.Actor) + } + mockSvc.AssertExpectations(t) + }) + } +} diff --git a/plugins/bearer/plugin.go b/plugins/bearer/plugin.go index cfadfa72..419802e5 100644 --- a/plugins/bearer/plugin.go +++ b/plugins/bearer/plugin.go @@ -73,13 +73,16 @@ func (p *BearerPlugin) AuthMiddleware() func(http.Handler) http.Handler { return } - userID, err := p.jwtService.ValidateToken(token) + actor, err := p.jwtService.ValidateToken(r.Context(), token) if err != nil { p.writeUnauthorized(w, err) return } - ctx := context.WithValue(r.Context(), models.ContextUserID, userID) + ctx := context.WithValue(r.Context(), models.ContextAuthActor, actor) + if actor.ID != "" { + ctx = context.WithValue(ctx, models.ContextUserID, actor.ID) + } next.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -90,8 +93,11 @@ func (p *BearerPlugin) OptionalAuthMiddleware() func(http.Handler) http.Handler return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { token, err := p.extractToken(r) if err == nil && token != "" { - if userID, validateErr := p.jwtService.ValidateToken(token); validateErr == nil { - ctx := context.WithValue(r.Context(), models.ContextUserID, userID) + if actor, validateErr := p.jwtService.ValidateToken(r.Context(), token); validateErr == nil { + ctx := context.WithValue(r.Context(), models.ContextAuthActor, actor) + if actor.ID != "" { + ctx = context.WithValue(ctx, models.ContextUserID, actor.ID) + } r = r.WithContext(ctx) } } diff --git a/plugins/bearer/plugin_test.go b/plugins/bearer/plugin_test.go new file mode 100644 index 00000000..56b209eb --- /dev/null +++ b/plugins/bearer/plugin_test.go @@ -0,0 +1,68 @@ +package bearer + +import ( + "testing" + + "github.com/stretchr/testify/require" + + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/models" + bearertests "github.com/Authula/authula/plugins/bearer/tests" +) + +func TestBearerPlugin_Metadata(t *testing.T) { + t.Parallel() + + plugin := New(BearerPluginConfig{}) + metadata := plugin.Metadata() + + require.NotEmpty(t, metadata.ID) + require.NotEmpty(t, metadata.Version) + require.NotEmpty(t, metadata.Description) +} + +func TestBearerPlugin_Config(t *testing.T) { + t.Parallel() + + cfg := BearerPluginConfig{HeaderName: "Custom-Auth", Enabled: true} + plugin := New(cfg) + + returnedCfg := plugin.Config() + require.Equal(t, cfg, returnedCfg) +} + +func TestBearerPlugin_Init(t *testing.T) { + t.Parallel() + + t.Run("missing_jwt_service", func(t *testing.T) { + t.Parallel() + reg := &internaltests.MockServiceRegistry{} + reg.On("Get", models.ServiceJWT.String()).Return(nil).Once() + + plugin := New(BearerPluginConfig{}) + err := plugin.Init(&models.PluginContext{ + Logger: &internaltests.MockLogger{}, + ServiceRegistry: reg, + GetConfig: func() *models.Config { return &models.Config{} }, + }) + require.Error(t, err) + reg.AssertExpectations(t) + }) + + t.Run("success", func(t *testing.T) { + t.Parallel() + mockSvc := &bearertests.MockJWTService{} + reg := &internaltests.MockServiceRegistry{} + reg.On("Get", models.ServiceJWT.String()).Return(mockSvc).Once() + + plugin := New(BearerPluginConfig{}) + err := plugin.Init(&models.PluginContext{ + Logger: &internaltests.MockLogger{}, + ServiceRegistry: reg, + GetConfig: func() *models.Config { return &models.Config{} }, + }) + require.NoError(t, err) + require.Equal(t, mockSvc, plugin.jwtService) + reg.AssertExpectations(t) + }) +} diff --git a/plugins/bearer/tests/mocks.go b/plugins/bearer/tests/mocks.go new file mode 100644 index 00000000..a7ed3a06 --- /dev/null +++ b/plugins/bearer/tests/mocks.go @@ -0,0 +1,21 @@ +package tests + +import ( + "context" + + "github.com/stretchr/testify/mock" + + "github.com/Authula/authula/models" +) + +type MockJWTService struct { + mock.Mock +} + +func (m *MockJWTService) ValidateToken(ctx context.Context, token string) (*models.Actor, error) { + args := m.Called(ctx, token) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*models.Actor), args.Error(1) +} diff --git a/plugins/jwt/constants/constants.go b/plugins/jwt/constants/constants.go index a96976ef..a9515f23 100644 --- a/plugins/jwt/constants/constants.go +++ b/plugins/jwt/constants/constants.go @@ -2,14 +2,12 @@ package constants import "errors" -// Event type constants const ( EventTokenReuseRecovered = "jwt.token.reuse.recovered" EventTokenReuseThrottled = "jwt.token.reuse.throttled" EventTokenReuseMalicious = "jwt.token.reuse.malicious" ) -// Error definitions var ( ErrInvalidToken = errors.New("provided token is invalid or malformed") ErrTokenExpired = errors.New("token has expired") diff --git a/plugins/jwt/handlers/refresh_handler_test.go b/plugins/jwt/handlers/refresh_handler_test.go new file mode 100644 index 00000000..c6ec1d2a --- /dev/null +++ b/plugins/jwt/handlers/refresh_handler_test.go @@ -0,0 +1,89 @@ +package handlers + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/plugins/jwt/types" + "github.com/Authula/authula/plugins/jwt/usecases" +) + +type mockRefreshTokenUseCase struct { + mock.Mock +} + +func (m *mockRefreshTokenUseCase) RefreshTokens(ctx context.Context, refreshToken string) (*usecases.RefreshTokenResult, error) { + args := m.Called(ctx, refreshToken) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*usecases.RefreshTokenResult), args.Error(1) +} + +func TestRefreshTokenHandler(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + body []byte + prepare func(*mockRefreshTokenUseCase) + expectedStatus int + }{ + { + name: "invalid_json", + body: []byte("{"), + expectedStatus: http.StatusUnprocessableEntity, + }, + { + name: "missing_refresh_token", + body: internaltests.MarshalToJSON(t, types.RefreshTokenRequest{RefreshToken: ""}), + expectedStatus: http.StatusUnprocessableEntity, + }, + { + name: "use_case_error", + body: internaltests.MarshalToJSON(t, types.RefreshTokenRequest{RefreshToken: "bad-token"}), + prepare: func(m *mockRefreshTokenUseCase) { + m.On("RefreshTokens", mock.Anything, "bad-token").Return((*usecases.RefreshTokenResult)(nil), errors.New("invalid")).Once() + }, + expectedStatus: http.StatusUnauthorized, + }, + { + name: "success", + body: internaltests.MarshalToJSON(t, types.RefreshTokenRequest{RefreshToken: "valid-token"}), + prepare: func(m *mockRefreshTokenUseCase) { + m.On("RefreshTokens", mock.Anything, "valid-token").Return(&usecases.RefreshTokenResult{ + AccessToken: "new-access", + RefreshToken: "new-refresh", + }, nil).Once() + }, + expectedStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mockUC := &mockRefreshTokenUseCase{} + if tt.prepare != nil { + tt.prepare(mockUC) + } + + handler := &RefreshTokenHandler{ + Logger: &internaltests.MockLogger{}, + RefreshTokenUseCase: mockUC, + } + + req, w, reqCtx := internaltests.NewHandlerRequestWithActor(t, http.MethodPost, "/token/refresh", tt.body, nil) + handler.Handle().ServeHTTP(w, req) + require.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + mockUC.AssertExpectations(t) + }) + } +} diff --git a/plugins/jwt/handlers/wellknown_jwks_handler_test.go b/plugins/jwt/handlers/wellknown_jwks_handler_test.go new file mode 100644 index 00000000..bddd9897 --- /dev/null +++ b/plugins/jwt/handlers/wellknown_jwks_handler_test.go @@ -0,0 +1,73 @@ +package handlers + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/plugins/jwt/usecases" +) + +type mockJWKSUseCase struct { + mock.Mock +} + +func (m *mockJWKSUseCase) GetJWKS(ctx context.Context) (*usecases.JWKSResult, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*usecases.JWKSResult), args.Error(1) +} + +func TestWellKnownJWKSHandler(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + prepare func(*mockJWKSUseCase) + expectedStatus int + }{ + { + name: "use_case_error", + prepare: func(m *mockJWKSUseCase) { + m.On("GetJWKS", mock.Anything).Return((*usecases.JWKSResult)(nil), errors.New("no keys")).Once() + }, + expectedStatus: http.StatusInternalServerError, + }, + { + name: "success", + prepare: func(m *mockJWKSUseCase) { + m.On("GetJWKS", mock.Anything).Return(&usecases.JWKSResult{KeySet: jwk.NewSet()}, nil).Once() + }, + expectedStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mockUC := &mockJWKSUseCase{} + if tt.prepare != nil { + tt.prepare(mockUC) + } + + handler := &WellKnownJWKSHandler{ + Logger: &internaltests.MockLogger{}, + JWKSUseCase: mockUC, + } + + req, w, reqCtx := internaltests.NewHandlerRequestWithActor(t, http.MethodGet, "/.well-known/jwks.json", nil, nil) + handler.Handle().ServeHTTP(w, req) + require.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + mockUC.AssertExpectations(t) + }) + } +} diff --git a/plugins/jwt/hooks.go b/plugins/jwt/hooks.go index 6ab41bc9..15c8110a 100644 --- a/plugins/jwt/hooks.go +++ b/plugins/jwt/hooks.go @@ -1,12 +1,12 @@ package jwt import ( - "context" "fmt" "net/http" "time" "github.com/Authula/authula/models" + jwtservices "github.com/Authula/authula/plugins/jwt/services" "github.com/Authula/authula/plugins/jwt/types" ) @@ -43,7 +43,7 @@ func (p *JWTPlugin) issueTokensHookMatcher(reqCtx *models.RequestContext) bool { } func (p *JWTPlugin) issueTokensHook(reqCtx *models.RequestContext) error { - if reqCtx.Actor == nil || reqCtx.Actor.Type != models.ActorUser { + if reqCtx.Actor == nil { return nil } @@ -51,45 +51,71 @@ func (p *JWTPlugin) issueTokensHook(reqCtx *models.RequestContext) error { return nil } - sessionID, ok := reqCtx.Values[models.ContextSessionID.String()].(string) - if !ok || sessionID == "" { - return nil - } - - tokenPair, err := p.jwtService.GenerateTokens(context.Background(), reqCtx.Actor.ID, sessionID) - if err != nil { - p.Logger.Error("failed to generate JWT tokens", "user_id", reqCtx.Actor.ID, "session_id", sessionID, "error", err) - return fmt.Errorf("failed to generate authentication tokens: %w", err) - } + ctx := reqCtx.Request.Context() - expiresAt := time.Now().Add(p.pluginConfig.RefreshExpiresIn) - if err := p.refreshService.StoreInitialRefreshToken(reqCtx.Request.Context(), tokenPair.RefreshToken, sessionID, expiresAt); err != nil { - p.Logger.Error("failed to store refresh token", "user_id", reqCtx.Actor.ID, "session_id", sessionID, "error", err) - return fmt.Errorf("failed to store refresh token: %w", err) + switch reqCtx.Actor.Type { + case models.ActorMachine: + { + orgID := "" + if reqCtx.Actor.OrganizationID != nil { + orgID = *reqCtx.Actor.OrganizationID + } + tokenPair, err := p.jwtService.(jwtservices.TokenService).GenerateMachineToken( + ctx, reqCtx.Actor.ID, orgID, reqCtx.Actor.Scopes, + ) + if err != nil { + p.Logger.Error("failed to generate machine JWT token", "client_id", reqCtx.Actor.ID, "error", err) + return fmt.Errorf("failed to generate machine authentication tokens: %w", err) + } + reqCtx.Values[types.JWTTokenTypeAccess.String()] = tokenPair.AccessToken + } + case models.ActorUser: + { + sessionID, ok := reqCtx.Values[models.ContextSessionID.String()].(string) + if !ok || sessionID == "" { + return nil + } + + tokenPair, err := p.jwtService.(jwtservices.TokenService).GenerateUserToken(ctx, reqCtx.Actor.ID, sessionID) + if err != nil { + p.Logger.Error("failed to generate user JWT tokens", "user_id", reqCtx.Actor.ID, "session_id", sessionID, "error", err) + return fmt.Errorf("failed to generate authentication tokens: %w", err) + } + + expiresAt := time.Now().Add(p.pluginConfig.RefreshExpiresIn) + if err := p.refreshService.StoreInitialRefreshToken(ctx, tokenPair.RefreshToken, sessionID, expiresAt); err != nil { + p.Logger.Error("failed to store refresh token", "user_id", reqCtx.Actor.ID, "session_id", sessionID, "error", err) + return fmt.Errorf("failed to store refresh token: %w", err) + } + + reqCtx.Values[types.JWTTokenTypeAccess.String()] = tokenPair.AccessToken + reqCtx.Values[types.JWTTokenTypeRefresh.String()] = tokenPair.RefreshToken + } } - reqCtx.Values[types.JWTTokenTypeAccess.String()] = tokenPair.AccessToken - reqCtx.Values[types.JWTTokenTypeRefresh.String()] = tokenPair.RefreshToken - return nil } func (p *JWTPlugin) respondHook(reqCtx *models.RequestContext) error { - if reqCtx.Actor == nil || reqCtx.Actor.Type != models.ActorUser { + if reqCtx.Actor == nil { return nil } - access, ok1 := reqCtx.Values[types.JWTTokenTypeAccess.String()].(string) - refresh, ok2 := reqCtx.Values[types.JWTTokenTypeRefresh.String()].(string) - if !ok1 || !ok2 { + access, ok := reqCtx.Values[types.JWTTokenTypeAccess.String()].(string) + if !ok || access == "" { return nil } - reqCtx.SetJSONResponse(http.StatusOK, map[string]any{ - "access_token": access, - "refresh_token": refresh, - }) - reqCtx.Handled = true + payload := map[string]any{ + "access_token": access, + "token_type": "Bearer", + } + if refresh, ok := reqCtx.Values[types.JWTTokenTypeRefresh.String()].(string); ok && refresh != "" { + payload["refresh_token"] = refresh + } + + reqCtx.SetJSONResponse(http.StatusOK, payload) + reqCtx.Handled = true return nil } diff --git a/plugins/jwt/hooks_test.go b/plugins/jwt/hooks_test.go new file mode 100644 index 00000000..04442e6c --- /dev/null +++ b/plugins/jwt/hooks_test.go @@ -0,0 +1,263 @@ +package jwt + +import ( + "errors" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/models" + jwttests "github.com/Authula/authula/plugins/jwt/tests" + "github.com/Authula/authula/plugins/jwt/types" +) + +var errHookTest = errors.New("hook test error") + +func newTestPlugin(tokenSvc *jwttests.MockTokenService, refreshSvc *jwttests.MockRefreshTokenService) *JWTPlugin { + return &JWTPlugin{ + pluginConfig: types.JWTPluginConfig{ + ExpiresIn: 15 * time.Minute, + RefreshExpiresIn: 7 * 24 * time.Hour, + }, + jwtService: tokenSvc, + refreshService: refreshSvc, + Logger: &internaltests.MockLogger{}, + } +} + +func TestIssueTokensHook(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + actor *models.Actor + setup func(*testing.T, *models.RequestContext) + mock func(*jwttests.MockTokenService, *jwttests.MockRefreshTokenService) + wantErr string + check func(*testing.T, *models.RequestContext) + }{ + { + name: "nil_actor", + actor: nil, + mock: func(tokenSvc *jwttests.MockTokenService, refreshSvc *jwttests.MockRefreshTokenService) {}, + check: func(t *testing.T, reqCtx *models.RequestContext) { + require.Empty(t, reqCtx.Values) + }, + }, + { + name: "skip_mint_flag", + actor: &models.Actor{ID: "user-1", Type: models.ActorUser}, + setup: func(t *testing.T, reqCtx *models.RequestContext) { + reqCtx.Values[models.ContextAuthIdempotentSkipTokensMint.String()] = true + }, + mock: func(tokenSvc *jwttests.MockTokenService, refreshSvc *jwttests.MockRefreshTokenService) {}, + check: func(t *testing.T, reqCtx *models.RequestContext) { + require.Empty(t, reqCtx.Values[types.JWTTokenTypeAccess.String()]) + }, + }, + { + name: "user_success", + actor: &models.Actor{ID: "user-1", Type: models.ActorUser}, + setup: func(t *testing.T, reqCtx *models.RequestContext) { + reqCtx.Values[models.ContextSessionID.String()] = "sess-1" + }, + mock: func(tokenSvc *jwttests.MockTokenService, refreshSvc *jwttests.MockRefreshTokenService) { + pair := &types.TokenPair{ + AccessToken: "access-token-1", + RefreshToken: "refresh-token-1", + } + tokenSvc.On("GenerateUserToken", mock.Anything, "user-1", "sess-1").Return(pair, nil) + refreshSvc.On("StoreInitialRefreshToken", mock.Anything, "refresh-token-1", "sess-1", mock.Anything).Return(nil) + }, + check: func(t *testing.T, reqCtx *models.RequestContext) { + require.Equal(t, "access-token-1", reqCtx.Values[types.JWTTokenTypeAccess.String()]) + require.Equal(t, "refresh-token-1", reqCtx.Values[types.JWTTokenTypeRefresh.String()]) + }, + }, + { + name: "user_no_session_id", + actor: &models.Actor{ID: "user-1", Type: models.ActorUser}, + mock: func(tokenSvc *jwttests.MockTokenService, refreshSvc *jwttests.MockRefreshTokenService) {}, + check: func(t *testing.T, reqCtx *models.RequestContext) { + require.Empty(t, reqCtx.Values) + }, + }, + { + name: "user_token_error", + actor: &models.Actor{ID: "user-1", Type: models.ActorUser}, + setup: func(t *testing.T, reqCtx *models.RequestContext) { + reqCtx.Values[models.ContextSessionID.String()] = "sess-1" + }, + mock: func(tokenSvc *jwttests.MockTokenService, refreshSvc *jwttests.MockRefreshTokenService) { + tokenSvc.On("GenerateUserToken", mock.Anything, "user-1", "sess-1").Return(nil, errHookTest) + }, + wantErr: "failed to generate authentication tokens", + }, + { + name: "machine_success", + actor: &models.Actor{ + ID: "client-1", + Type: models.ActorMachine, + OrganizationID: new("org-1"), + Scopes: []string{"read", "write"}, + }, + mock: func(tokenSvc *jwttests.MockTokenService, refreshSvc *jwttests.MockRefreshTokenService) { + pair := &types.TokenPair{AccessToken: "machine-access-token"} + tokenSvc.On("GenerateMachineToken", mock.Anything, "client-1", "org-1", []string{"read", "write"}).Return(pair, nil) + }, + check: func(t *testing.T, reqCtx *models.RequestContext) { + require.Equal(t, "machine-access-token", reqCtx.Values[types.JWTTokenTypeAccess.String()]) + _, hasRefresh := reqCtx.Values[types.JWTTokenTypeRefresh.String()] + require.False(t, hasRefresh) + }, + }, + { + name: "machine_no_optional_fields", + actor: &models.Actor{ + ID: "client-2", + Type: models.ActorMachine, + }, + mock: func(tokenSvc *jwttests.MockTokenService, refreshSvc *jwttests.MockRefreshTokenService) { + pair := &types.TokenPair{AccessToken: "machine-access-token-2"} + tokenSvc.On("GenerateMachineToken", mock.Anything, "client-2", "", []string(nil)).Return(pair, nil) + }, + check: func(t *testing.T, reqCtx *models.RequestContext) { + require.Equal(t, "machine-access-token-2", reqCtx.Values[types.JWTTokenTypeAccess.String()]) + }, + }, + { + name: "machine_token_error", + actor: &models.Actor{ + ID: "client-3", + Type: models.ActorMachine, + }, + mock: func(tokenSvc *jwttests.MockTokenService, refreshSvc *jwttests.MockRefreshTokenService) { + tokenSvc.On("GenerateMachineToken", mock.Anything, "client-3", "", []string(nil)).Return(nil, errHookTest) + }, + wantErr: "failed to generate machine authentication tokens", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + tokenSvc := new(jwttests.MockTokenService) + refreshSvc := new(jwttests.MockRefreshTokenService) + + _, _, reqCtx := internaltests.NewHandlerRequestWithActor(t, http.MethodGet, "/test", nil, tt.actor) + + if tt.setup != nil { + tt.setup(t, reqCtx) + } + + if tt.mock != nil { + tt.mock(tokenSvc, refreshSvc) + } + + plugin := newTestPlugin(tokenSvc, refreshSvc) + + err := plugin.issueTokensHook(reqCtx) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + } + + if tt.check != nil { + tt.check(t, reqCtx) + } + + tokenSvc.AssertExpectations(t) + refreshSvc.AssertExpectations(t) + }) + } +} + +func TestRespondHook(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + actor *models.Actor + setup func(*testing.T, *models.RequestContext) + wantErr string + check func(*testing.T, *models.RequestContext) + }{ + { + name: "nil_actor", + actor: nil, + check: func(t *testing.T, reqCtx *models.RequestContext) { + require.False(t, reqCtx.Handled) + }, + }, + { + name: "no_access_token", + actor: &models.Actor{ID: "user-1", Type: models.ActorUser}, + check: func(t *testing.T, reqCtx *models.RequestContext) { + require.False(t, reqCtx.Handled) + }, + }, + { + name: "user_success", + actor: &models.Actor{ID: "user-1", Type: models.ActorUser}, + setup: func(t *testing.T, reqCtx *models.RequestContext) { + reqCtx.Values[types.JWTTokenTypeAccess.String()] = "access-1" + reqCtx.Values[types.JWTTokenTypeRefresh.String()] = "refresh-1" + }, + check: func(t *testing.T, reqCtx *models.RequestContext) { + require.True(t, reqCtx.Handled) + require.Equal(t, http.StatusOK, reqCtx.ResponseStatus) + require.JSONEq(t, `{"access_token":"access-1","token_type":"Bearer","refresh_token":"refresh-1"}`, string(reqCtx.ResponseBody)) + }, + }, + { + name: "machine_success", + actor: &models.Actor{ + ID: "client-1", + Type: models.ActorMachine, + OrganizationID: new("org-1"), + }, + setup: func(t *testing.T, reqCtx *models.RequestContext) { + reqCtx.Values[types.JWTTokenTypeAccess.String()] = "machine-access-1" + }, + check: func(t *testing.T, reqCtx *models.RequestContext) { + require.True(t, reqCtx.Handled) + require.Equal(t, http.StatusOK, reqCtx.ResponseStatus) + require.JSONEq(t, `{"access_token":"machine-access-1","token_type":"Bearer"}`, string(reqCtx.ResponseBody)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, _, reqCtx := internaltests.NewHandlerRequestWithActor(t, http.MethodGet, "/test", nil, tt.actor) + + if tt.setup != nil { + tt.setup(t, reqCtx) + } + + plugin := newTestPlugin(new(jwttests.MockTokenService), new(jwttests.MockRefreshTokenService)) + + err := plugin.respondHook(reqCtx) + + if tt.wantErr != "" { + require.Error(t, err) + } else { + require.NoError(t, err) + } + + if tt.check != nil { + tt.check(t, reqCtx) + } + }) + } +} diff --git a/plugins/jwt/migrations.go b/plugins/jwt/migrations.go index 6c8ae619..90c89a10 100644 --- a/plugins/jwt/migrations.go +++ b/plugins/jwt/migrations.go @@ -1,157 +1,10 @@ package jwt import ( - "context" - - "github.com/uptrace/bun" - "github.com/Authula/authula/migrations" + "github.com/Authula/authula/plugins/jwt/migrationset" ) -func jwtMigrationsForProvider(provider string) []migrations.Migration { - return migrations.ForProvider(provider, migrations.ProviderVariants{ - "sqlite": func() []migrations.Migration { return []migrations.Migration{jwtSQLiteInitial()} }, - "postgres": func() []migrations.Migration { return []migrations.Migration{jwtPostgresInitial()} }, - "mysql": func() []migrations.Migration { return []migrations.Migration{jwtMySQLInitial()} }, - }) -} - -func jwtSQLiteInitial() migrations.Migration { - return migrations.Migration{ - Version: "20260131000000_jwt_initial", - Up: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `CREATE TABLE IF NOT EXISTS jwks ( - id TEXT PRIMARY KEY, - public_key TEXT NOT NULL, - private_key TEXT NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - expires_at TIMESTAMP NULL -);`, - `CREATE INDEX IF NOT EXISTS idx_jwks_expires_at ON jwks(expires_at);`, - `CREATE TABLE IF NOT EXISTS refresh_tokens ( - id TEXT PRIMARY KEY, - session_id TEXT NOT NULL, - token_hash TEXT NOT NULL UNIQUE, - expires_at TIMESTAMP NOT NULL, - is_revoked INTEGER DEFAULT 0, - revoked_at TIMESTAMP NULL, - last_reuse_attempt TIMESTAMP NULL DEFAULT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP -);`, - `CREATE INDEX IF NOT EXISTS idx_refresh_tokens_session_id ON refresh_tokens(session_id);`, - `CREATE INDEX IF NOT EXISTS idx_refresh_tokens_token_hash ON refresh_tokens(token_hash);`, - `CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);`, - `CREATE INDEX IF NOT EXISTS idx_refresh_tokens_revoked_only ON refresh_tokens(is_revoked) WHERE is_revoked = 1;`, - ) - }, - Down: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `DROP TABLE IF EXISTS refresh_tokens;`, - `DROP TABLE IF EXISTS jwks;`, - ) - }, - } -} - -func jwtPostgresInitial() migrations.Migration { - return migrations.Migration{ - Version: "20260131000000_jwt_initial", - Up: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `CREATE TABLE IF NOT EXISTS jwks ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - public_key TEXT NOT NULL, - private_key TEXT NOT NULL, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, - expires_at TIMESTAMP WITH TIME ZONE NULL -);`, - `CREATE INDEX IF NOT EXISTS idx_jwks_expires_at ON jwks(expires_at);`, - `CREATE TABLE IF NOT EXISTS refresh_tokens ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - session_id UUID NOT NULL, - token_hash VARCHAR(64) UNIQUE NOT NULL, - expires_at TIMESTAMP WITH TIME ZONE NOT NULL, - is_revoked BOOLEAN DEFAULT FALSE, - revoked_at TIMESTAMP WITH TIME ZONE NULL, - last_reuse_attempt TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - CONSTRAINT fk_refresh_tokens_session FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE -);`, - `CREATE INDEX IF NOT EXISTS idx_refresh_tokens_session_id ON refresh_tokens(session_id);`, - `CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);`, - `CREATE INDEX IF NOT EXISTS idx_refresh_tokens_revoked_only ON refresh_tokens(is_revoked) WHERE is_revoked = TRUE;`, - `CREATE OR REPLACE FUNCTION cleanup_expired_refresh_tokens() -RETURNS VOID AS $$ -BEGIN - DELETE FROM refresh_tokens WHERE expires_at < NOW(); -END; -$$ LANGUAGE plpgsql;`, - ) - }, - Down: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `DROP FUNCTION IF EXISTS cleanup_expired_refresh_tokens();`, - `DROP TABLE IF EXISTS refresh_tokens;`, - `DROP TABLE IF EXISTS jwks;`, - ) - }, - } -} - -func jwtMySQLInitial() migrations.Migration { - return migrations.Migration{ - Version: "20260131000000_jwt_initial", - Up: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `CREATE TABLE IF NOT EXISTS jwks ( - id BINARY(16) NOT NULL PRIMARY KEY, - public_key TEXT NOT NULL, - private_key TEXT NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - expires_at TIMESTAMP NULL -);`, - `CREATE INDEX idx_jwks_expires_at ON jwks(expires_at);`, - `CREATE TABLE IF NOT EXISTS refresh_tokens ( - id BINARY(16) NOT NULL PRIMARY KEY, - session_id BINARY(16) NOT NULL, - token_hash VARCHAR(64) UNIQUE NOT NULL, - expires_at TIMESTAMP NOT NULL, - is_revoked BOOLEAN DEFAULT FALSE, - revoked_at TIMESTAMP NULL, - last_reuse_attempt TIMESTAMP NULL DEFAULT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - CONSTRAINT fk_refresh_tokens_session FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE -);`, - `CREATE INDEX idx_refresh_tokens_session_id ON refresh_tokens(session_id);`, - `CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);`, - `CREATE INDEX idx_refresh_tokens_active_session ON refresh_tokens(session_id, is_revoked);`, - `CREATE INDEX idx_refresh_tokens_last_reuse_attempt ON refresh_tokens(last_reuse_attempt);`, - `DROP PROCEDURE IF EXISTS cleanup_expired_refresh_tokens;`, - `CREATE PROCEDURE cleanup_expired_refresh_tokens() -BEGIN - DELETE FROM refresh_tokens WHERE expires_at < NOW(); -END;`, - ) - }, - Down: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `DROP PROCEDURE IF EXISTS cleanup_expired_refresh_tokens;`, - `DROP TABLE IF EXISTS refresh_tokens;`, - `DROP TABLE IF EXISTS jwks;`, - ) - }, - } +func JWTMigrationsForProvider(provider string) []migrations.Migration { + return migrationset.JWTMigrationsForProvider(provider) } diff --git a/plugins/jwt/migrationset/migrations.go b/plugins/jwt/migrationset/migrations.go new file mode 100644 index 00000000..46cc823e --- /dev/null +++ b/plugins/jwt/migrationset/migrations.go @@ -0,0 +1,157 @@ +package migrationset + +import ( + "context" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/migrations" +) + +func JWTMigrationsForProvider(provider string) []migrations.Migration { + return migrations.ForProvider(provider, migrations.ProviderVariants{ + "sqlite": func() []migrations.Migration { return []migrations.Migration{jwtSQLiteInitial()} }, + "postgres": func() []migrations.Migration { return []migrations.Migration{jwtPostgresInitial()} }, + "mysql": func() []migrations.Migration { return []migrations.Migration{jwtMySQLInitial()} }, + }) +} + +func jwtSQLiteInitial() migrations.Migration { + return migrations.Migration{ + Version: "20260131000000_jwt_initial", + Up: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `CREATE TABLE IF NOT EXISTS jwks ( + id TEXT PRIMARY KEY, + public_key TEXT NOT NULL, + private_key TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + expires_at TIMESTAMP NULL + );`, + `CREATE INDEX IF NOT EXISTS idx_jwks_expires_at ON jwks(expires_at);`, + `CREATE TABLE IF NOT EXISTS refresh_tokens ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + token_hash TEXT NOT NULL UNIQUE, + expires_at TIMESTAMP NOT NULL, + is_revoked INTEGER DEFAULT 0, + revoked_at TIMESTAMP NULL, + last_reuse_attempt TIMESTAMP NULL DEFAULT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + );`, + `CREATE INDEX IF NOT EXISTS idx_refresh_tokens_session_id ON refresh_tokens(session_id);`, + `CREATE INDEX IF NOT EXISTS idx_refresh_tokens_token_hash ON refresh_tokens(token_hash);`, + `CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);`, + `CREATE INDEX IF NOT EXISTS idx_refresh_tokens_revoked_only ON refresh_tokens(is_revoked) WHERE is_revoked = 1;`, + ) + }, + Down: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `DROP TABLE IF EXISTS refresh_tokens;`, + `DROP TABLE IF EXISTS jwks;`, + ) + }, + } +} + +func jwtPostgresInitial() migrations.Migration { + return migrations.Migration{ + Version: "20260131000000_jwt_initial", + Up: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `CREATE TABLE IF NOT EXISTS jwks ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + public_key TEXT NOT NULL, + private_key TEXT NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + expires_at TIMESTAMP WITH TIME ZONE NULL + );`, + `CREATE INDEX IF NOT EXISTS idx_jwks_expires_at ON jwks(expires_at);`, + `CREATE TABLE IF NOT EXISTS refresh_tokens ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + session_id UUID NOT NULL, + token_hash VARCHAR(64) UNIQUE NOT NULL, + expires_at TIMESTAMP WITH TIME ZONE NOT NULL, + is_revoked BOOLEAN DEFAULT FALSE, + revoked_at TIMESTAMP WITH TIME ZONE NULL, + last_reuse_attempt TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + CONSTRAINT fk_refresh_tokens_session FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE + );`, + `CREATE INDEX IF NOT EXISTS idx_refresh_tokens_session_id ON refresh_tokens(session_id);`, + `CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);`, + `CREATE INDEX IF NOT EXISTS idx_refresh_tokens_revoked_only ON refresh_tokens(is_revoked) WHERE is_revoked = TRUE;`, + `CREATE OR REPLACE FUNCTION cleanup_expired_refresh_tokens() + RETURNS VOID AS $$ + BEGIN + DELETE FROM refresh_tokens WHERE expires_at < NOW(); + END; + $$ LANGUAGE plpgsql;`, + ) + }, + Down: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `DROP FUNCTION IF EXISTS cleanup_expired_refresh_tokens();`, + `DROP TABLE IF EXISTS refresh_tokens;`, + `DROP TABLE IF EXISTS jwks;`, + ) + }, + } +} + +func jwtMySQLInitial() migrations.Migration { + return migrations.Migration{ + Version: "20260131000000_jwt_initial", + Up: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `CREATE TABLE IF NOT EXISTS jwks ( + id BINARY(16) NOT NULL PRIMARY KEY, + public_key TEXT NOT NULL, + private_key TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + expires_at TIMESTAMP NULL + );`, + `CREATE INDEX idx_jwks_expires_at ON jwks(expires_at);`, + `CREATE TABLE IF NOT EXISTS refresh_tokens ( + id BINARY(16) NOT NULL PRIMARY KEY, + session_id BINARY(16) NOT NULL, + token_hash VARCHAR(64) UNIQUE NOT NULL, + expires_at TIMESTAMP NOT NULL, + is_revoked BOOLEAN DEFAULT FALSE, + revoked_at TIMESTAMP NULL, + last_reuse_attempt TIMESTAMP NULL DEFAULT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT fk_refresh_tokens_session FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE + );`, + `CREATE INDEX idx_refresh_tokens_session_id ON refresh_tokens(session_id);`, + `CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at);`, + `CREATE INDEX idx_refresh_tokens_active_session ON refresh_tokens(session_id, is_revoked);`, + `CREATE INDEX idx_refresh_tokens_last_reuse_attempt ON refresh_tokens(last_reuse_attempt);`, + `DROP PROCEDURE IF EXISTS cleanup_expired_refresh_tokens;`, + `CREATE PROCEDURE cleanup_expired_refresh_tokens() + BEGIN + DELETE FROM refresh_tokens WHERE expires_at < NOW(); + END;`, + ) + }, + Down: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `DROP PROCEDURE IF EXISTS cleanup_expired_refresh_tokens;`, + `DROP TABLE IF EXISTS refresh_tokens;`, + `DROP TABLE IF EXISTS jwks;`, + ) + }, + } +} diff --git a/plugins/jwt/plugin.go b/plugins/jwt/plugin.go index 7c1a8810..b4ec171e 100644 --- a/plugins/jwt/plugin.go +++ b/plugins/jwt/plugin.go @@ -21,7 +21,7 @@ type JWTPlugin struct { Logger models.Logger sessionService services.SessionService tokenService services.TokenService - jwtService *jwtservices.JWTServiceImpl + jwtService services.JWTService refreshService jwtservices.RefreshTokenService keyService jwtservices.KeyService cacheService jwtservices.CacheService @@ -55,11 +55,6 @@ func (p *JWTPlugin) Init(ctx *models.PluginContext) error { return err } - if err := p.pluginConfig.NormalizeAlgorithm(); err != nil { - p.Logger.Error("invalid jwt algorithm in plugin config", "error", err) - return err - } - sessionService, ok := ctx.ServiceRegistry.Get(models.ServiceSession.String()).(services.SessionService) if !ok { p.Logger.Error("session service not found") @@ -81,7 +76,7 @@ func (p *JWTPlugin) Init(ctx *models.PluginContext) error { jwksRepo := repositories.NewBunJWKSRepository(ctx.DB) refreshTokenRepo := repositories.NewRefreshTokenRepository(ctx.DB) - p.keyService = jwtservices.NewKeyService(jwksRepo, p.Logger, p.tokenService, p.globalConfig.Secret, p.pluginConfig.Algorithm) + p.keyService = jwtservices.NewKeyService(jwksRepo, p.Logger, p.tokenService, p.globalConfig.Secret) p.cacheService = jwtservices.NewCacheService(jwksRepo, p.secondaryStorage, p.Logger, p.pluginConfig.JWKSCacheTTL) if p.secondaryStorage == nil { @@ -113,7 +108,7 @@ func (p *JWTPlugin) Init(ctx *models.PluginContext) error { p.Logger.Warn("failed to pre-populate cache on startup", "error", err) } - jwtServiceImpl, ok := jwtservices.NewJWTService( + p.jwtService = jwtservices.NewJWTService( p.Logger, p.sessionService, p.tokenService, @@ -122,29 +117,25 @@ func (p *JWTPlugin) Init(ctx *models.PluginContext) error { p.blacklistService, p.pluginConfig.ExpiresIn, p.pluginConfig.RefreshExpiresIn, - ).(*jwtservices.JWTServiceImpl) - if !ok { - return errors.New("failed to create JWT service") - } - p.jwtService = jwtServiceImpl + ) p.refreshService = jwtservices.NewRefreshTokenService( p.Logger, ctx.EventBus, p.sessionService, - p.jwtService, + p.jwtService.(jwtservices.TokenService), refreshTokenRepo, p.pluginConfig.RefreshGracePeriod, p.pluginConfig.RefreshExpiresIn, ) - ctx.ServiceRegistry.Register(models.ServiceJWT.String(), jwtServiceImpl) + ctx.ServiceRegistry.Register(models.ServiceJWT.String(), p.jwtService) return nil } func (p *JWTPlugin) Migrations(provider string) []migrations.Migration { - return jwtMigrationsForProvider(provider) + return JWTMigrationsForProvider(provider) } func (p *JWTPlugin) DependsOn() []string { diff --git a/plugins/jwt/plugin_test.go b/plugins/jwt/plugin_test.go index 29082b3f..ec9bc13d 100644 --- a/plugins/jwt/plugin_test.go +++ b/plugins/jwt/plugin_test.go @@ -14,24 +14,6 @@ func TestJWTPluginConfig_DefaultConfig(t *testing.T) { config types.JWTPluginConfig check func(*testing.T, types.JWTPluginConfig) }{ - { - name: "sets default algorithm", - config: types.JWTPluginConfig{}, - check: func(t *testing.T, c types.JWTPluginConfig) { - if c.Algorithm != types.JWTAlgEdDSA { - t.Errorf("Algorithm = %v, want %v", c.Algorithm, types.JWTAlgEdDSA) - } - }, - }, - { - name: "preserves custom algorithm", - config: types.JWTPluginConfig{Algorithm: "rs256"}, - check: func(t *testing.T, c types.JWTPluginConfig) { - if c.Algorithm != "rs256" { - t.Errorf("Algorithm = %v, want rs256", c.Algorithm) - } - }, - }, { name: "sets default key rotation interval", config: types.JWTPluginConfig{}, @@ -75,16 +57,12 @@ func TestJWTPluginConfig_DefaultConfig(t *testing.T) { { name: "preserves custom values", config: types.JWTPluginConfig{ - Algorithm: "es256", KeyRotationInterval: 30 * 24 * time.Hour, ExpiresIn: 30 * time.Minute, RefreshExpiresIn: 14 * 24 * time.Hour, JWKSCacheTTL: 12 * time.Hour, }, check: func(t *testing.T, c types.JWTPluginConfig) { - if c.Algorithm != "es256" { - t.Errorf("Algorithm = %v, want es256", c.Algorithm) - } if c.KeyRotationInterval != 30*24*time.Hour { t.Errorf("KeyRotationInterval not preserved") } @@ -134,7 +112,6 @@ func TestJWTPlugin_Metadata(t *testing.T) { func TestJWTPlugin_Config(t *testing.T) { config := types.JWTPluginConfig{ - Algorithm: "es256", ExpiresIn: 30 * time.Minute, } @@ -150,8 +127,8 @@ func TestJWTPlugin_Config(t *testing.T) { t.Fatal("Config() did not return types.JWTPluginConfig type") } - if cfg.Algorithm != config.Algorithm { - t.Errorf("Config Algorithm = %v, want %v", cfg.Algorithm, config.Algorithm) + if cfg.ExpiresIn != config.ExpiresIn { + t.Errorf("Config ExpiresIn = %v, want %v", cfg.ExpiresIn, config.ExpiresIn) } } @@ -212,38 +189,8 @@ func TestKeyRotationInterval_ConfigPreservation(t *testing.T) { } } -func TestKeyRotationInterval_AlgorithmCompatibility(t *testing.T) { - algorithms := []types.JWTAlgorithm{ - types.JWTAlgEdDSA, - types.JWTAlgRS256, - types.JWTAlgPS256, - types.JWTAlgES256, - types.JWTAlgES512, - } - - for _, alg := range algorithms { - t.Run(string(alg), func(t *testing.T) { - config := types.JWTPluginConfig{ - Algorithm: alg, - KeyRotationInterval: 30 * 24 * time.Hour, - } - config.ApplyDefaults() - - if config.Algorithm != alg { - t.Errorf("Algorithm changed from %v to %v", alg, config.Algorithm) - } - - if config.KeyRotationInterval != 30*24*time.Hour { - t.Errorf("KeyRotationInterval = %v, want %v", - config.KeyRotationInterval, 30*24*time.Hour) - } - }) - } -} - func TestKeyRotationInterval_WithOtherConfigOptions(t *testing.T) { config := types.JWTPluginConfig{ - Algorithm: types.JWTAlgEdDSA, KeyRotationInterval: 45 * 24 * time.Hour, ExpiresIn: 5 * time.Minute, RefreshExpiresIn: 30 * 24 * time.Hour, @@ -256,10 +203,6 @@ func TestKeyRotationInterval_WithOtherConfigOptions(t *testing.T) { t.Errorf("KeyRotationInterval = %v, want %v", config.KeyRotationInterval, 45*24*time.Hour) } - if config.Algorithm != types.JWTAlgEdDSA { - t.Errorf("Algorithm = %v, want %v", config.Algorithm, types.JWTAlgEdDSA) - } - if config.ExpiresIn != 5*time.Minute { t.Errorf("ExpiresIn = %v, want %v", config.ExpiresIn, 5*time.Minute) } @@ -501,7 +444,6 @@ func TestKeyRotationGracePeriod_EdgeCases(t *testing.T) { func TestKeyRotationGracePeriod_ConfigWithOtherOptions(t *testing.T) { config := types.JWTPluginConfig{ - Algorithm: types.JWTAlgEdDSA, KeyRotationInterval: 45 * 24 * time.Hour, KeyRotationGracePeriod: 30 * time.Minute, ExpiresIn: 5 * time.Minute, @@ -521,10 +463,6 @@ func TestKeyRotationGracePeriod_ConfigWithOtherOptions(t *testing.T) { config.KeyRotationGracePeriod, 30*time.Minute) } - if config.Algorithm != types.JWTAlgEdDSA { - t.Errorf("Algorithm = %v, want %v", config.Algorithm, types.JWTAlgEdDSA) - } - if config.ExpiresIn != 5*time.Minute { t.Errorf("ExpiresIn = %v, want %v", config.ExpiresIn, 5*time.Minute) } diff --git a/plugins/jwt/repositories/jwks_repository_test.go b/plugins/jwt/repositories/jwks_repository_test.go new file mode 100644 index 00000000..ce859597 --- /dev/null +++ b/plugins/jwt/repositories/jwks_repository_test.go @@ -0,0 +1,188 @@ +package repositories + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "github.com/uptrace/bun" + + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/migrations" + "github.com/Authula/authula/plugins/jwt/migrationset" + "github.com/Authula/authula/plugins/jwt/types" +) + +func setupJWKSRepo(t *testing.T) (*bun.DB, *bunJWKSRepository) { + t.Helper() + db := internaltests.NewSQLiteIntegrationDB(t) + migrator, err := migrations.NewMigrator(db, &internaltests.MockLogger{}) + require.NoError(t, err) + err = migrator.Migrate(context.Background(), []migrations.MigrationSet{ + { + PluginID: "jwt", + Migrations: migrationset.JWTMigrationsForProvider("sqlite"), + }, + }) + require.NoError(t, err) + return db, &bunJWKSRepository{db: db} +} + +func TestBunJWKSRepository(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T, repo *bunJWKSRepository, ctx context.Context) + }{ + { + name: "StoreJWKSKey", + run: func(t *testing.T, repo *bunJWKSRepository, ctx context.Context) { + key := &types.JWKS{ + ID: uuid.New().String(), + PublicKey: "public-key-1", + PrivateKey: "private-key-1", + } + err := repo.StoreJWKSKey(ctx, key) + require.NoError(t, err) + + keys, err := repo.GetJWKSKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 1) + require.Equal(t, key.ID, keys[0].ID) + require.Equal(t, key.PublicKey, keys[0].PublicKey) + require.Equal(t, key.PrivateKey, keys[0].PrivateKey) + require.NotZero(t, keys[0].CreatedAt) + require.Nil(t, keys[0].ExpiresAt) + }, + }, + { + name: "GetJWKSKeyByID", + run: func(t *testing.T, repo *bunJWKSRepository, ctx context.Context) { + key := &types.JWKS{ + ID: uuid.New().String(), + PublicKey: "public-key-2", + PrivateKey: "private-key-2", + } + err := repo.StoreJWKSKey(ctx, key) + require.NoError(t, err) + + found, err := repo.GetJWKSKeyByID(ctx, key.ID) + require.NoError(t, err) + require.NotNil(t, found) + require.Equal(t, key.ID, found.ID) + require.Equal(t, key.PublicKey, found.PublicKey) + require.Equal(t, key.PrivateKey, found.PrivateKey) + }, + }, + { + name: "GetJWKSKeyByID_not_found", + run: func(t *testing.T, repo *bunJWKSRepository, ctx context.Context) { + found, err := repo.GetJWKSKeyByID(ctx, "non-existent-id") + require.NoError(t, err) + require.Nil(t, found) + }, + }, + { + name: "GetJWKSKeys_expired_excluded", + run: func(t *testing.T, repo *bunJWKSRepository, ctx context.Context) { + now := time.Now() + past := now.Add(-1 * time.Hour) + + unexpiredKey := &types.JWKS{ + ID: uuid.New().String(), + PublicKey: "public-key-unexpired", + PrivateKey: "private-key-unexpired", + ExpiresAt: nil, + } + err := repo.StoreJWKSKey(ctx, unexpiredKey) + require.NoError(t, err) + + expiredKey := &types.JWKS{ + ID: uuid.New().String(), + PublicKey: "public-key-expired", + PrivateKey: "private-key-expired", + ExpiresAt: &past, + } + err = repo.StoreJWKSKey(ctx, expiredKey) + require.NoError(t, err) + + keys, err := repo.GetJWKSKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 1) + require.Equal(t, unexpiredKey.ID, keys[0].ID) + }, + }, + { + name: "MarkKeyExpired", + run: func(t *testing.T, repo *bunJWKSRepository, ctx context.Context) { + key := &types.JWKS{ + ID: uuid.New().String(), + PublicKey: "public-key-3", + PrivateKey: "private-key-3", + } + err := repo.StoreJWKSKey(ctx, key) + require.NoError(t, err) + + past := time.Now().Add(-1 * time.Hour) + err = repo.MarkKeyExpired(ctx, key.ID, past) + require.NoError(t, err) + + keys, err := repo.GetJWKSKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 0) + }, + }, + { + name: "PurgeExpiredKeys", + run: func(t *testing.T, repo *bunJWKSRepository, ctx context.Context) { + twoDaysAgo := time.Now().Add(-48 * time.Hour) + key := &types.JWKS{ + ID: uuid.New().String(), + PublicKey: "public-key-4", + PrivateKey: "private-key-4", + ExpiresAt: &twoDaysAgo, + } + err := repo.StoreJWKSKey(ctx, key) + require.NoError(t, err) + + err = repo.PurgeExpiredKeys(ctx) + require.NoError(t, err) + + keys, err := repo.GetJWKSKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 0) + }, + }, + { + name: "UpdateJWKSKey", + run: func(t *testing.T, repo *bunJWKSRepository, ctx context.Context) { + key := &types.JWKS{ + ID: uuid.New().String(), + PublicKey: "original-public-key", + PrivateKey: "original-private-key", + } + err := repo.StoreJWKSKey(ctx, key) + require.NoError(t, err) + + key.PublicKey = "updated-public-key" + err = repo.UpdateJWKSKey(ctx, key) + require.NoError(t, err) + + found, err := repo.GetJWKSKeyByID(ctx, key.ID) + require.NoError(t, err) + require.NotNil(t, found) + require.Equal(t, "updated-public-key", found.PublicKey) + require.Equal(t, "original-private-key", found.PrivateKey) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, repo := setupJWKSRepo(t) + ctx := context.Background() + tc.run(t, repo, ctx) + }) + } +} diff --git a/plugins/jwt/repositories/refresh_token_repository.go b/plugins/jwt/repositories/refresh_token_repository.go index 3bf6dc11..6090862d 100644 --- a/plugins/jwt/repositories/refresh_token_repository.go +++ b/plugins/jwt/repositories/refresh_token_repository.go @@ -6,23 +6,11 @@ import ( "github.com/Authula/authula/plugins/jwt/types" ) -// RefreshTokenRepository provides data access for refresh token records type RefreshTokenRepository interface { - // StoreRefreshToken saves a refresh token record StoreRefreshToken(ctx context.Context, record *types.RefreshToken) error - - // GetRefreshToken retrieves a refresh token by hash GetRefreshToken(ctx context.Context, tokenHash string) (*types.RefreshToken, error) - - // RevokeRefreshToken marks a token as revoked RevokeRefreshToken(ctx context.Context, tokenHash string) error - - // RevokeAllSessionTokens revokes all refresh tokens for a session RevokeAllSessionTokens(ctx context.Context, sessionID string) error - - // SetLastReuseAttempt updates the last reuse attempt timestamp for a token SetLastReuseAttempt(ctx context.Context, tokenHash string) error - - // CleanupExpiredTokens removes expired refresh token records CleanupExpiredTokens(ctx context.Context) error } diff --git a/plugins/jwt/repositories/refresh_token_repository_test.go b/plugins/jwt/repositories/refresh_token_repository_test.go new file mode 100644 index 00000000..c72ceaff --- /dev/null +++ b/plugins/jwt/repositories/refresh_token_repository_test.go @@ -0,0 +1,324 @@ +package repositories + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "github.com/uptrace/bun" + + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/migrations" + "github.com/Authula/authula/plugins/jwt/migrationset" + "github.com/Authula/authula/plugins/jwt/types" +) + +func setupRefreshTokenRepo(t *testing.T) (*bun.DB, *refreshTokenRepositoryImpl) { + t.Helper() + db := internaltests.NewSQLiteIntegrationDB(t) + migrator, err := migrations.NewMigrator(db, &internaltests.MockLogger{}) + require.NoError(t, err) + err = migrator.Migrate(context.Background(), []migrations.MigrationSet{ + { + PluginID: "jwt", + Migrations: migrationset.JWTMigrationsForProvider("sqlite"), + }, + }) + require.NoError(t, err) + return db, &refreshTokenRepositoryImpl{db: db} +} + +func TestRefreshTokenRepository_Store(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + record func() *types.RefreshToken + }{ + { + name: "stores and retrieves refresh token", + record: func() *types.RefreshToken { + return &types.RefreshToken{ + ID: uuid.New().String(), + SessionID: uuid.New().String(), + TokenHash: uuid.New().String(), + ExpiresAt: time.Now().Add(time.Hour), + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, repo := setupRefreshTokenRepo(t) + ctx := context.Background() + + record := tt.record() + now := time.Now().Truncate(time.Millisecond) + + err := repo.StoreRefreshToken(ctx, record) + require.NoError(t, err) + + got, err := repo.GetRefreshToken(ctx, record.TokenHash) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, record.ID, got.ID) + require.Equal(t, record.SessionID, got.SessionID) + require.Equal(t, record.TokenHash, got.TokenHash) + require.WithinDuration(t, record.ExpiresAt, got.ExpiresAt, time.Second) + require.False(t, got.IsRevoked) + require.Nil(t, got.RevokedAt) + require.Nil(t, got.LastReuseAttempt) + require.WithinDuration(t, now, got.CreatedAt, time.Second) + }) + } +} + +func TestRefreshTokenRepository_GetRefreshToken(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + tokenHash string + wantNil bool + }{ + { + name: "returns token when found", + tokenHash: uuid.New().String(), + wantNil: false, + }, + { + name: "returns nil when not found", + tokenHash: "nonexistent-hash", + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, repo := setupRefreshTokenRepo(t) + ctx := context.Background() + + if !tt.wantNil { + err := repo.StoreRefreshToken(ctx, &types.RefreshToken{ + ID: uuid.New().String(), + SessionID: uuid.New().String(), + TokenHash: tt.tokenHash, + ExpiresAt: time.Now().Add(time.Hour), + }) + require.NoError(t, err) + } + + got, err := repo.GetRefreshToken(ctx, tt.tokenHash) + + require.NoError(t, err) + if tt.wantNil { + require.Nil(t, got) + } else { + require.NotNil(t, got) + require.Equal(t, tt.tokenHash, got.TokenHash) + } + }) + } +} + +func TestRefreshTokenRepository_RevokeRefreshToken(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + record *types.RefreshToken + }{ + { + name: "revokes token and sets revoked_at", + record: &types.RefreshToken{ + ID: uuid.New().String(), + SessionID: uuid.New().String(), + TokenHash: uuid.New().String(), + ExpiresAt: time.Now().Add(time.Hour), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, repo := setupRefreshTokenRepo(t) + ctx := context.Background() + + err := repo.StoreRefreshToken(ctx, tt.record) + require.NoError(t, err) + + err = repo.RevokeRefreshToken(ctx, tt.record.TokenHash) + require.NoError(t, err) + + got, err := repo.GetRefreshToken(ctx, tt.record.TokenHash) + require.NoError(t, err) + require.NotNil(t, got) + require.True(t, got.IsRevoked) + require.NotNil(t, got.RevokedAt) + }) + } +} + +func TestRefreshTokenRepository_RevokeAllSessionTokens(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + targetSession string + otherSession string + }{ + { + name: "revokes only tokens for the target session", + targetSession: uuid.New().String(), + otherSession: uuid.New().String(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, repo := setupRefreshTokenRepo(t) + ctx := context.Background() + + token1 := &types.RefreshToken{ + ID: uuid.New().String(), + SessionID: tt.targetSession, + TokenHash: uuid.New().String(), + ExpiresAt: time.Now().Add(time.Hour), + } + token2 := &types.RefreshToken{ + ID: uuid.New().String(), + SessionID: tt.targetSession, + TokenHash: uuid.New().String(), + ExpiresAt: time.Now().Add(time.Hour), + } + token3 := &types.RefreshToken{ + ID: uuid.New().String(), + SessionID: tt.otherSession, + TokenHash: uuid.New().String(), + ExpiresAt: time.Now().Add(time.Hour), + } + + require.NoError(t, repo.StoreRefreshToken(ctx, token1)) + require.NoError(t, repo.StoreRefreshToken(ctx, token2)) + require.NoError(t, repo.StoreRefreshToken(ctx, token3)) + + err := repo.RevokeAllSessionTokens(ctx, tt.targetSession) + require.NoError(t, err) + + got1, err := repo.GetRefreshToken(ctx, token1.TokenHash) + require.NoError(t, err) + require.True(t, got1.IsRevoked) + require.NotNil(t, got1.RevokedAt) + + got2, err := repo.GetRefreshToken(ctx, token2.TokenHash) + require.NoError(t, err) + require.True(t, got2.IsRevoked) + require.NotNil(t, got2.RevokedAt) + + got3, err := repo.GetRefreshToken(ctx, token3.TokenHash) + require.NoError(t, err) + require.NotNil(t, got3) + require.False(t, got3.IsRevoked) + require.Nil(t, got3.RevokedAt) + }) + } +} + +func TestRefreshTokenRepository_SetLastReuseAttempt(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + record *types.RefreshToken + }{ + { + name: "sets last_reuse_attempt on token", + record: &types.RefreshToken{ + ID: uuid.New().String(), + SessionID: uuid.New().String(), + TokenHash: uuid.New().String(), + ExpiresAt: time.Now().Add(time.Hour), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, repo := setupRefreshTokenRepo(t) + ctx := context.Background() + + err := repo.StoreRefreshToken(ctx, tt.record) + require.NoError(t, err) + + err = repo.SetLastReuseAttempt(ctx, tt.record.TokenHash) + require.NoError(t, err) + + got, err := repo.GetRefreshToken(ctx, tt.record.TokenHash) + require.NoError(t, err) + require.NotNil(t, got) + require.NotNil(t, got.LastReuseAttempt) + }) + } +} + +func TestRefreshTokenRepository_CleanupExpiredTokens(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + expiredRecord *types.RefreshToken + validRecord *types.RefreshToken + }{ + { + name: "removes expired tokens and keeps valid ones", + expiredRecord: &types.RefreshToken{ + ID: uuid.New().String(), + SessionID: uuid.New().String(), + TokenHash: uuid.New().String(), + ExpiresAt: time.Now().Add(-time.Hour), + }, + validRecord: &types.RefreshToken{ + ID: uuid.New().String(), + SessionID: uuid.New().String(), + TokenHash: uuid.New().String(), + ExpiresAt: time.Now().Add(time.Hour), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, repo := setupRefreshTokenRepo(t) + ctx := context.Background() + + require.NoError(t, repo.StoreRefreshToken(ctx, tt.expiredRecord)) + require.NoError(t, repo.StoreRefreshToken(ctx, tt.validRecord)) + + err := repo.CleanupExpiredTokens(ctx) + require.NoError(t, err) + + gotExpired, err := repo.GetRefreshToken(ctx, tt.expiredRecord.TokenHash) + require.NoError(t, err) + require.Nil(t, gotExpired) + + gotValid, err := repo.GetRefreshToken(ctx, tt.validRecord.TokenHash) + require.NoError(t, err) + require.NotNil(t, gotValid) + }) + } +} diff --git a/plugins/jwt/services/blacklist_service.go b/plugins/jwt/services/blacklist_service.go index 47c07776..2276d02a 100644 --- a/plugins/jwt/services/blacklist_service.go +++ b/plugins/jwt/services/blacklist_service.go @@ -82,17 +82,39 @@ func (s *blacklistService) BlacklistAllSessionTokens(ctx context.Context, sessio return nil } +const ( + jwtBlacklistTokenPrefix = "jwt:blacklist:token:" + jwtBlacklistSessionPrefix = "jwt:blacklist:session:" +) + func (s *blacklistService) CleanupExpired(ctx context.Context) error { - // With storage TTL, cleanup happens automatically - // This method is a no-op for cache-based implementation - // If using database storage, implement cleanup logic here + prefixes := []string{jwtBlacklistTokenPrefix, jwtBlacklistSessionPrefix} + for _, prefix := range prefixes { + keys, err := s.storage.Scan(ctx, prefix) + if err != nil { + s.logger.Error("failed to scan blacklist keys", "prefix", prefix, "error", err) + return fmt.Errorf("failed to scan blacklist keys: %w", err) + } + for _, key := range keys { + ttl, err := s.storage.TTL(ctx, key) + if err != nil { + s.logger.Error("failed to check TTL for key", "key", key, "error", err) + continue + } + if ttl == nil || *ttl <= 0 { + if err := s.storage.Delete(ctx, key); err != nil { + s.logger.Error("failed to delete expired blacklist entry", "key", key, "error", err) + } + } + } + } return nil } func (s *blacklistService) blacklistKey(jti string) string { - return fmt.Sprintf("jwt:blacklist:token:%s", jti) + return fmt.Sprintf("%s%s", jwtBlacklistTokenPrefix, jti) } func (s *blacklistService) sessionBlacklistKey(sessionID string) string { - return fmt.Sprintf("jwt:blacklist:session:%s", sessionID) + return fmt.Sprintf("%s%s", jwtBlacklistSessionPrefix, sessionID) } diff --git a/plugins/jwt/services/blacklist_service_test.go b/plugins/jwt/services/blacklist_service_test.go new file mode 100644 index 00000000..1626c709 --- /dev/null +++ b/plugins/jwt/services/blacklist_service_test.go @@ -0,0 +1,276 @@ +package services + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + internaltests "github.com/Authula/authula/internal/tests" +) + +func TestBlacklistService_BlacklistToken(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + jti string + expiresAt time.Time + setupMock func(*internaltests.MockSecondaryStorage) + wantErr string + }{ + { + name: "success", + jti: "test-jti", + expiresAt: time.Now().Add(1 * time.Hour), + setupMock: func(storage *internaltests.MockSecondaryStorage) { + storage.On("Set", mock.Anything, "jwt:blacklist:token:test-jti", "1", mock.Anything).Return(nil) + }, + }, + { + name: "empty jti", + jti: "", + expiresAt: time.Now().Add(1 * time.Hour), + setupMock: func(storage *internaltests.MockSecondaryStorage) {}, + wantErr: "jti cannot be empty", + }, + { + name: "expired token", + jti: "test-jti", + expiresAt: time.Now().Add(-1 * time.Hour), + setupMock: func(storage *internaltests.MockSecondaryStorage) {}, + }, + { + name: "storage error", + jti: "test-jti", + expiresAt: time.Now().Add(1 * time.Hour), + setupMock: func(storage *internaltests.MockSecondaryStorage) { + storage.On("Set", mock.Anything, "jwt:blacklist:token:test-jti", "1", mock.Anything).Return(assert.AnError) + }, + wantErr: "failed to blacklist token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + storage := new(internaltests.MockSecondaryStorage) + svc := NewBlacklistService(storage, &internaltests.MockLogger{}) + tt.setupMock(storage) + + err := svc.BlacklistToken(context.Background(), tt.jti, tt.expiresAt) + + if tt.wantErr != "" { + assert.ErrorContains(t, err, tt.wantErr) + } else { + assert.NoError(t, err) + } + storage.AssertExpectations(t) + }) + } +} + +func TestBlacklistService_IsBlacklisted(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + jti string + setupMock func(*internaltests.MockSecondaryStorage) + wantBlacklisted bool + wantErr string + }{ + { + name: "blacklisted", + jti: "test-jti", + setupMock: func(storage *internaltests.MockSecondaryStorage) { + storage.On("Get", mock.Anything, "jwt:blacklist:token:test-jti").Return("1", nil) + }, + wantBlacklisted: true, + }, + { + name: "not blacklisted", + jti: "test-jti", + setupMock: func(storage *internaltests.MockSecondaryStorage) { + storage.On("Get", mock.Anything, "jwt:blacklist:token:test-jti").Return(nil, nil) + }, + }, + { + name: "empty jti", + jti: "", + setupMock: func(storage *internaltests.MockSecondaryStorage) {}, + }, + { + name: "storage error", + jti: "test-jti", + setupMock: func(storage *internaltests.MockSecondaryStorage) { + storage.On("Get", mock.Anything, "jwt:blacklist:token:test-jti").Return(nil, assert.AnError) + }, + wantErr: "failed to check blacklist", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + storage := new(internaltests.MockSecondaryStorage) + svc := NewBlacklistService(storage, &internaltests.MockLogger{}) + tt.setupMock(storage) + + blacklisted, err := svc.IsBlacklisted(context.Background(), tt.jti) + + if tt.wantErr != "" { + assert.ErrorContains(t, err, tt.wantErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.wantBlacklisted, blacklisted) + storage.AssertExpectations(t) + }) + } +} + +func TestBlacklistService_BlacklistAllSessionTokens(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + sessionID string + expiresAt time.Time + setupMock func(*internaltests.MockSecondaryStorage) + wantErr string + }{ + { + name: "success", + sessionID: "sess-1", + expiresAt: time.Now().Add(1 * time.Hour), + setupMock: func(storage *internaltests.MockSecondaryStorage) { + storage.On("Set", mock.Anything, "jwt:blacklist:session:sess-1", "1", mock.Anything).Return(nil) + }, + }, + { + name: "empty sessionID", + sessionID: "", + expiresAt: time.Now().Add(1 * time.Hour), + setupMock: func(storage *internaltests.MockSecondaryStorage) {}, + wantErr: "sessionID cannot be empty", + }, + { + name: "expired", + sessionID: "sess-1", + expiresAt: time.Now().Add(-1 * time.Hour), + setupMock: func(storage *internaltests.MockSecondaryStorage) {}, + }, + { + name: "storage error", + sessionID: "sess-1", + expiresAt: time.Now().Add(1 * time.Hour), + setupMock: func(storage *internaltests.MockSecondaryStorage) { + storage.On("Set", mock.Anything, "jwt:blacklist:session:sess-1", "1", mock.Anything).Return(assert.AnError) + }, + wantErr: "failed to blacklist session tokens", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + storage := new(internaltests.MockSecondaryStorage) + svc := NewBlacklistService(storage, &internaltests.MockLogger{}) + tt.setupMock(storage) + + err := svc.BlacklistAllSessionTokens(context.Background(), tt.sessionID, tt.expiresAt) + + if tt.wantErr != "" { + assert.ErrorContains(t, err, tt.wantErr) + } else { + assert.NoError(t, err) + } + storage.AssertExpectations(t) + }) + } +} + +func TestBlacklistService_CleanupExpired(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupMock func(*internaltests.MockSecondaryStorage) + wantErr string + }{ + { + name: "no keys to clean", + setupMock: func(storage *internaltests.MockSecondaryStorage) { + storage.On("Scan", mock.Anything, "jwt:blacklist:token:").Return([]string{}, nil) + storage.On("Scan", mock.Anything, "jwt:blacklist:session:").Return([]string{}, nil) + }, + }, + { + name: "scan error on first prefix", + setupMock: func(storage *internaltests.MockSecondaryStorage) { + storage.On("Scan", mock.Anything, "jwt:blacklist:token:").Return([]string{}, assert.AnError) + }, + wantErr: "failed to scan blacklist keys", + }, + { + name: "expired keys deleted", + setupMock: func(storage *internaltests.MockSecondaryStorage) { + storage.On("Scan", mock.Anything, "jwt:blacklist:token:").Return([]string{"jwt:blacklist:token:key1"}, nil) + storage.On("TTL", mock.Anything, "jwt:blacklist:token:key1").Return(nil, nil) + storage.On("Delete", mock.Anything, "jwt:blacklist:token:key1").Return(nil) + storage.On("Scan", mock.Anything, "jwt:blacklist:session:").Return([]string{}, nil) + }, + }, + { + name: "non-expired keys skipped", + setupMock: func(storage *internaltests.MockSecondaryStorage) { + dur := time.Hour + storage.On("Scan", mock.Anything, "jwt:blacklist:token:").Return([]string{"jwt:blacklist:token:key1"}, nil) + storage.On("TTL", mock.Anything, "jwt:blacklist:token:key1").Return(&dur, nil) + storage.On("Scan", mock.Anything, "jwt:blacklist:session:").Return([]string{}, nil) + }, + }, + { + name: "TTL error continues", + setupMock: func(storage *internaltests.MockSecondaryStorage) { + storage.On("Scan", mock.Anything, "jwt:blacklist:token:").Return([]string{"jwt:blacklist:token:key1"}, nil) + storage.On("TTL", mock.Anything, "jwt:blacklist:token:key1").Return(nil, assert.AnError) + storage.On("Scan", mock.Anything, "jwt:blacklist:session:").Return([]string{}, nil) + }, + }, + { + name: "Delete error continues", + setupMock: func(storage *internaltests.MockSecondaryStorage) { + storage.On("Scan", mock.Anything, "jwt:blacklist:token:").Return([]string{"jwt:blacklist:token:key1"}, nil) + storage.On("TTL", mock.Anything, "jwt:blacklist:token:key1").Return(nil, nil) + storage.On("Delete", mock.Anything, "jwt:blacklist:token:key1").Return(assert.AnError) + storage.On("Scan", mock.Anything, "jwt:blacklist:session:").Return([]string{}, nil) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + storage := new(internaltests.MockSecondaryStorage) + svc := NewBlacklistService(storage, &internaltests.MockLogger{}) + tt.setupMock(storage) + + err := svc.CleanupExpired(context.Background()) + + if tt.wantErr != "" { + assert.ErrorContains(t, err, tt.wantErr) + } else { + assert.NoError(t, err) + } + storage.AssertExpectations(t) + }) + } +} diff --git a/plugins/jwt/services/cache_service.go b/plugins/jwt/services/cache_service.go index 6854358d..7c839d08 100644 --- a/plugins/jwt/services/cache_service.go +++ b/plugins/jwt/services/cache_service.go @@ -22,7 +22,6 @@ type cacheService struct { cacheTTL time.Duration } -// NewCacheService creates a new cache service func NewCacheService(repo repositories.JWKSRepository, secondaryStorage models.SecondaryStorage, logger models.Logger, cacheTTL time.Duration) CacheService { return &cacheService{ repo: repo, @@ -32,7 +31,6 @@ func NewCacheService(repo repositories.JWKSRepository, secondaryStorage models.S } } -// GetCachedJWKS retrieves JWKS from cache if available and not expired func (s *cacheService) GetCachedJWKS(ctx context.Context) (jwk.Set, error) { if s.secondaryStorage == nil { return nil, errors.New("secondary storage not available") @@ -57,7 +55,6 @@ func (s *cacheService) GetCachedJWKS(ctx context.Context) (jwk.Set, error) { return set, nil } -// FetchJWKSFromDatabase loads all non-expired public keys from the database func (s *cacheService) FetchJWKSFromDatabase(ctx context.Context) (jwk.Set, error) { jwksKeys, err := s.repo.GetJWKSKeys(ctx) if err != nil { @@ -72,25 +69,8 @@ func (s *cacheService) FetchJWKSFromDatabase(ctx context.Context) (jwk.Set, erro continue } - // Set the Key ID so JWT validation can match the token's kid to the correct key _ = pubKey.Set(jwk.KeyIDKey, wk.ID) - - // Ensure algorithm is properly set based on key type - // This helps the JWT library know which algorithm to use for verification - keyType := pubKey.KeyType().String() - var alg string - switch keyType { - case "OKP": - alg = "EdDSA" - case "RSA": - alg = "RS256" - case "EC": - alg = "ES256" - } - if alg != "" { - _ = pubKey.Set(jwk.AlgorithmKey, alg) - } - + _ = pubKey.Set(jwk.AlgorithmKey, "EdDSA") _ = set.AddKey(pubKey) } @@ -101,7 +81,6 @@ func (s *cacheService) FetchJWKSFromDatabase(ctx context.Context) (jwk.Set, erro return set, nil } -// CacheJWKS stores the JWKS in the cache with the configured TTL func (s *cacheService) CacheJWKS(ctx context.Context, set jwk.Set) error { if s.secondaryStorage == nil { return nil @@ -119,7 +98,6 @@ func (s *cacheService) CacheJWKS(ctx context.Context, set jwk.Set) error { return nil } -// InvalidateCache removes the cached JWKS immediately and fetches fresh from DB func (s *cacheService) InvalidateCache(ctx context.Context) error { if s.secondaryStorage == nil { return nil @@ -138,7 +116,6 @@ func (s *cacheService) InvalidateCache(ctx context.Context) error { return s.CacheJWKS(ctx, set) } -// GetJWKSWithFallback retrieves JWKS from cache with database fallback func (s *cacheService) GetJWKSWithFallback(ctx context.Context) (jwk.Set, error) { set, err := s.GetCachedJWKS(ctx) if err == nil { diff --git a/plugins/jwt/services/cache_service_test.go b/plugins/jwt/services/cache_service_test.go new file mode 100644 index 00000000..790d18db --- /dev/null +++ b/plugins/jwt/services/cache_service_test.go @@ -0,0 +1,364 @@ +package services + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "crypto/x509" + "encoding/pem" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "github.com/uptrace/bun" + + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/migrations" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/jwt/migrationset" + "github.com/Authula/authula/plugins/jwt/repositories" + jwttests "github.com/Authula/authula/plugins/jwt/tests" + "github.com/Authula/authula/plugins/jwt/types" +) + +type cacheTestFixture struct { + db *bun.DB + repo repositories.JWKSRepository + storage *jwttests.InMemoryStorage + logger models.Logger + ttl time.Duration +} + +func newCacheTestFixture(t *testing.T) *cacheTestFixture { + t.Helper() + db := internaltests.NewSQLiteIntegrationDB(t) + + migrator, err := migrations.NewMigrator(db, &internaltests.MockLogger{}) + require.NoError(t, err) + err = migrator.Migrate(context.Background(), []migrations.MigrationSet{ + { + PluginID: models.PluginJWT.String(), + Migrations: migrationset.JWTMigrationsForProvider("sqlite"), + }, + }) + require.NoError(t, err) + + return &cacheTestFixture{ + db: db, + repo: repositories.NewBunJWKSRepository(db), + storage: jwttests.NewInMemoryStorage(), + logger: &internaltests.MockLogger{}, + ttl: 24 * time.Hour, + } +} + +func (f *cacheTestFixture) newCacheService() CacheService { + return &cacheService{ + repo: f.repo, + secondaryStorage: f.storage, + logger: f.logger, + cacheTTL: f.ttl, + } +} + +func generateTestJWKS(t *testing.T) *types.JWKS { + t.Helper() + _, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + privBytes, _ := x509.MarshalPKCS8PrivateKey(priv) + privPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}) + + pubBytes, _ := x509.MarshalPKIXPublicKey(priv.Public()) + pubPEM := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubBytes}) + + return &types.JWKS{ + ID: uuid.New().String(), + PublicKey: string(pubPEM), + PrivateKey: string(privPEM), + CreatedAt: time.Now(), + } +} + +func seedJWKSKeys(t *testing.T, ctx context.Context, repo repositories.JWKSRepository, count int) []*types.JWKS { + t.Helper() + keys := make([]*types.JWKS, count) + for i := range count { + key := generateTestJWKS(t) + err := repo.StoreJWKSKey(ctx, key) + require.NoError(t, err) + keys[i] = key + } + return keys +} + +func TestCacheService_GetCachedJWKS(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupSvc func(*cacheTestFixture) CacheService + setupMock func(*cacheTestFixture) + wantErr string + }{ + { + name: "no secondary storage", + setupSvc: func(f *cacheTestFixture) CacheService { + return &cacheService{ + repo: f.repo, + secondaryStorage: nil, + logger: f.logger, + cacheTTL: f.ttl, + } + }, + wantErr: "secondary storage not available", + }, + { + name: "empty cache", + setupSvc: func(f *cacheTestFixture) CacheService { + return &cacheService{ + repo: f.repo, + secondaryStorage: f.storage, + logger: f.logger, + cacheTTL: f.ttl, + } + }, + wantErr: "cached JWKS is empty or invalid type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + f := newCacheTestFixture(t) + svc := tt.setupSvc(f) + + _, err := svc.GetCachedJWKS(context.Background()) + + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + }) + } +} + +func TestCacheService_FetchJWKSFromDatabase(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(context.Context, *cacheTestFixture) + wantKeys int + wantErr string + }{ + { + name: "no keys in database", + setup: func(ctx context.Context, f *cacheTestFixture) {}, + wantErr: "no valid keys found", + }, + { + name: "keys found", + setup: func(ctx context.Context, f *cacheTestFixture) { + seedJWKSKeys(t, ctx, f.repo, 2) + }, + wantKeys: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + f := newCacheTestFixture(t) + svc := f.newCacheService() + ctx := context.Background() + + tt.setup(ctx, f) + + set, err := svc.FetchJWKSFromDatabase(ctx) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + require.NotNil(t, set) + require.Equal(t, tt.wantKeys, set.Len()) + } + }) + } +} + +func TestCacheService_GetJWKSWithFallback(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(context.Context, *cacheTestFixture) + preCache bool + wantErr string + wantKeyLen int + }{ + { + name: "cache miss", + setup: func(ctx context.Context, f *cacheTestFixture) { + seedJWKSKeys(t, ctx, f.repo, 1) + }, + wantKeyLen: 1, + }, + { + name: "cache hit", + setup: func(ctx context.Context, f *cacheTestFixture) { + seedJWKSKeys(t, ctx, f.repo, 1) + }, + preCache: true, + wantKeyLen: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + f := newCacheTestFixture(t) + svc := f.newCacheService() + ctx := context.Background() + + tt.setup(ctx, f) + + // Pre-populate cache if requested + if tt.preCache { + _, err := svc.GetJWKSWithFallback(ctx) + require.NoError(t, err) + } + + set, err := svc.GetJWKSWithFallback(ctx) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + require.NotNil(t, set) + require.Equal(t, tt.wantKeyLen, set.Len()) + } + }) + } +} + +func TestCacheService_InvalidateCache(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + noStorage bool + setup func(context.Context, *cacheTestFixture) + wantErr string + }{ + { + name: "no secondary storage", + noStorage: true, + setup: func(ctx context.Context, f *cacheTestFixture) {}, + }, + { + name: "success", + setup: func(ctx context.Context, f *cacheTestFixture) { + seedJWKSKeys(t, ctx, f.repo, 2) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + f := newCacheTestFixture(t) + + var svc CacheService + if tt.noStorage { + svc = &cacheService{ + repo: f.repo, + secondaryStorage: nil, + logger: f.logger, + cacheTTL: f.ttl, + } + } else { + svc = f.newCacheService() + } + + ctx := context.Background() + tt.setup(ctx, f) + + if tt.noStorage { + err := svc.InvalidateCache(ctx) + require.NoError(t, err) + return + } + + // Populate cache first + _, err := svc.GetJWKSWithFallback(ctx) + require.NoError(t, err) + + err = svc.InvalidateCache(ctx) + require.NoError(t, err) + + // Cache should be repopulated + set, err := svc.GetCachedJWKS(ctx) + require.NoError(t, err) + require.NotNil(t, set) + }) + } +} + +func TestCacheService_CacheJWKS(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + noStorage bool + wantErr string + }{ + { + name: "no secondary storage", + noStorage: true, + }, + { + name: "success", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + f := newCacheTestFixture(t) + + var svc CacheService + if tt.noStorage { + svc = &cacheService{ + repo: f.repo, + secondaryStorage: nil, + logger: f.logger, + cacheTTL: f.ttl, + } + } else { + svc = f.newCacheService() + } + + ctx := context.Background() + seedJWKSKeys(t, ctx, f.repo, 1) + set, err := svc.FetchJWKSFromDatabase(ctx) + require.NoError(t, err) + + err = svc.CacheJWKS(ctx, set) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/plugins/jwt/services/interfaces.go b/plugins/jwt/services/interfaces.go index dc77b62b..934a5635 100644 --- a/plugins/jwt/services/interfaces.go +++ b/plugins/jwt/services/interfaces.go @@ -9,44 +9,21 @@ import ( "github.com/Authula/authula/plugins/jwt/types" ) -// RefreshTokenResponse contains the result of a token refresh operation -type RefreshTokenResponse struct { - AccessToken string - RefreshToken string +type TokenService interface { + GenerateUserToken(ctx context.Context, userID string, sessionID string) (*types.TokenPair, error) + GenerateMachineToken(ctx context.Context, clientID string, organizationID string, scopes []string) (*types.TokenPair, error) } -// JwtService defines the JWT operations -type JwtService interface { - GenerateTokens(ctx context.Context, userID string, sessionID string) (*types.TokenPair, error) -} - -// KeyService manages cryptographic key generation, rotation, and retrieval type KeyService interface { - // GenerateKeysIfMissing generates the initial key pair if none exist in the database GenerateKeysIfMissing(ctx context.Context) error - - // GetActiveKey retrieves the currently active (non-expired) key GetActiveKey(ctx context.Context) (*types.JWKS, error) - - // IsKeyRotationDue returns true if the active key's age exceeds the rotation interval IsKeyRotationDue(ctx context.Context, rotationInterval time.Duration) bool - // RotateKeysIfNeeded rotates keys if they're past the rotation interval // gracePeriod specifies how long old keys remain valid after rotation // Returns true if rotation occurred, false otherwise RotateKeysIfNeeded(ctx context.Context, rotationInterval time.Duration, gracePeriod time.Duration, invalidateCacheFunc func(context.Context) error) (bool, error) } -// RefreshTokenStorage defines storage operations for refresh tokens -type RefreshTokenStorage interface { - StoreRefreshToken(ctx context.Context, record *types.RefreshToken) error - GetRefreshToken(ctx context.Context, tokenHash string) (*types.RefreshToken, error) - RevokeRefreshToken(ctx context.Context, tokenHash string) error - SetLastReuseAttempt(ctx context.Context, tokenHash string) error - RevokeAllSessionTokens(ctx context.Context, sessionID string) error -} - -// RefreshTokenRepository defines data access operations for refresh tokens type RefreshTokenRepository interface { StoreRefreshToken(ctx context.Context, record *types.RefreshToken) error GetRefreshToken(ctx context.Context, tokenHash string) (*types.RefreshToken, error) @@ -56,44 +33,22 @@ type RefreshTokenRepository interface { CleanupExpiredTokens(ctx context.Context) error } -// RefreshTokenService handles refresh token operations type RefreshTokenService interface { - // RefreshTokens refreshes the access and refresh tokens using the provided refresh token - RefreshTokens(ctx context.Context, refreshToken string) (*RefreshTokenResponse, error) - - // StoreInitialRefreshToken stores the initial refresh token along with its session ID and expiration time + RefreshTokens(ctx context.Context, refreshToken string) (*types.RefreshTokenResponse, error) StoreInitialRefreshToken(ctx context.Context, refreshToken string, sessionID string, expiresAt time.Time) error } -// BlacklistService handles token blacklisting/revocation type BlacklistService interface { - // BlacklistToken adds a token JTI to the blacklist with TTL BlacklistToken(ctx context.Context, jti string, expiresAt time.Time) error - - // IsBlacklisted checks if a token JTI is blacklisted IsBlacklisted(ctx context.Context, jti string) (bool, error) - - // BlacklistAllSessionTokens blacklists all tokens for a session BlacklistAllSessionTokens(ctx context.Context, sessionID string, expiresAt time.Time) error - - // CleanupExpired removes expired blacklist entries (for non-TTL stores) CleanupExpired(ctx context.Context) error } -// CacheService manages JWKS caching with database fallback type CacheService interface { - // GetCachedJWKS retrieves JWKS from cache if available and not expired GetCachedJWKS(ctx context.Context) (jwk.Set, error) - - // FetchJWKSFromDatabase loads all non-expired public keys from the database FetchJWKSFromDatabase(ctx context.Context) (jwk.Set, error) - - // CacheJWKS stores the JWKS in the cache with the configured TTL CacheJWKS(ctx context.Context, set jwk.Set) error - - // InvalidateCache removes the cached JWKS immediately and fetches fresh from DB InvalidateCache(ctx context.Context) error - - // GetJWKSWithFallback retrieves JWKS from cache with database fallback GetJWKSWithFallback(ctx context.Context) (jwk.Set, error) } diff --git a/plugins/jwt/services/jwt_service.go b/plugins/jwt/services/jwt_service.go deleted file mode 100644 index 598b8ad5..00000000 --- a/plugins/jwt/services/jwt_service.go +++ /dev/null @@ -1,241 +0,0 @@ -package services - -import ( - "context" - "errors" - "fmt" - "log/slog" - "time" - - "github.com/google/uuid" - "github.com/lestrrat-go/jwx/v3/jwa" - "github.com/lestrrat-go/jwx/v3/jwk" - "github.com/lestrrat-go/jwx/v3/jwt" - - "github.com/Authula/authula/models" - "github.com/Authula/authula/plugins/jwt/types" - "github.com/Authula/authula/services" -) - -// JWTServiceImpl is the concrete implementation of the JWTService interface -type JWTServiceImpl struct { - logger models.Logger - tokenService services.TokenService - keyService KeyService - cacheService CacheService - blacklistService BlacklistService - sessionService services.SessionService - expiresIn time.Duration - refreshExpiresIn time.Duration -} - -// NewJWTService creates a new JWT service implementation -func NewJWTService( - logger models.Logger, - sessionService services.SessionService, - tokenService services.TokenService, - keyService KeyService, - cacheService CacheService, - blacklistService BlacklistService, - expiresIn time.Duration, - refreshExpiresIn time.Duration, -) services.JWTService { - return &JWTServiceImpl{ - logger: logger, - sessionService: sessionService, - tokenService: tokenService, - keyService: keyService, - cacheService: cacheService, - blacklistService: blacklistService, - expiresIn: expiresIn, - refreshExpiresIn: refreshExpiresIn, - } -} - -// GenerateTokens creates access and refresh JWT tokens tied to a session -func (s *JWTServiceImpl) GenerateTokens(ctx context.Context, userID string, sessionID string) (*types.TokenPair, error) { - if sessionID == "" { - return nil, errors.New("session id is required to generate tokens") - } - - jwksKey, err := s.keyService.GetActiveKey(context.Background()) - if err != nil { - return nil, fmt.Errorf("failed to get active key: %w", err) - } - - privateKeyPEM, err := s.tokenService.Decrypt(jwksKey.PrivateKey) - if err != nil { - return nil, fmt.Errorf("failed to decrypt private key: %w", err) - } - - privKey, err := jwk.ParseKey([]byte(privateKeyPEM), jwk.WithPEM(true)) - if err != nil { - return nil, fmt.Errorf("failed to parse private key: %w", err) - } - - // Set the Key ID (kid) on the key so it's included in the JWT header - if err := privKey.Set(jwk.KeyIDKey, jwksKey.ID); err != nil { - return nil, fmt.Errorf("failed to set key ID: %w", err) - } - - keyAlgorithm := s.detectAlgorithmFromKey(privKey) - - now := time.Now() - jti := uuid.New().String() - - accessClaims := jwt.New() - if err := accessClaims.Set(jwt.SubjectKey, userID); err != nil { - return nil, fmt.Errorf("failed to set subject: %w", err) - } - if err := accessClaims.Set(jwt.IssuedAtKey, now); err != nil { - return nil, fmt.Errorf("failed to set issued at: %w", err) - } - if err := accessClaims.Set(jwt.ExpirationKey, now.Add(s.expiresIn)); err != nil { - return nil, fmt.Errorf("failed to set expiration: %w", err) - } - if err := accessClaims.Set(jwt.JwtIDKey, jti); err != nil { - return nil, fmt.Errorf("failed to set JWT ID: %w", err) - } - if err := accessClaims.Set("user_id", userID); err != nil { - return nil, fmt.Errorf("failed to set user_id: %w", err) - } - if err := accessClaims.Set("session_id", sessionID); err != nil { - return nil, fmt.Errorf("failed to set session_id: %w", err) - } - if err := accessClaims.Set("type", types.JWTTokenTypeAccess.String()); err != nil { - return nil, fmt.Errorf("failed to set type: %w", err) - } - - accessTokenBytes, err := jwt.Sign(accessClaims, jwt.WithKey(keyAlgorithm, privKey)) - if err != nil { - return nil, fmt.Errorf("failed to sign access token: %w", err) - } - - refreshClaims := jwt.New() - if err := refreshClaims.Set(jwt.SubjectKey, userID); err != nil { - return nil, fmt.Errorf("failed to set subject in refresh token: %w", err) - } - if err := refreshClaims.Set(jwt.IssuedAtKey, now); err != nil { - return nil, fmt.Errorf("failed to set issued at in refresh token: %w", err) - } - if err := refreshClaims.Set(jwt.ExpirationKey, now.Add(s.refreshExpiresIn)); err != nil { - return nil, fmt.Errorf("failed to set expiration in refresh token: %w", err) - } - if err := refreshClaims.Set(jwt.JwtIDKey, jti); err != nil { - return nil, fmt.Errorf("failed to set JWT ID in refresh token: %w", err) - } - if err := refreshClaims.Set("user_id", userID); err != nil { - return nil, fmt.Errorf("failed to set user_id in refresh token: %w", err) - } - if err := refreshClaims.Set("session_id", sessionID); err != nil { - return nil, fmt.Errorf("failed to set session_id in refresh token: %w", err) - } - if err := refreshClaims.Set("type", types.JWTTokenTypeRefresh.String()); err != nil { - return nil, fmt.Errorf("failed to set type in refresh token: %w", err) - } - - refreshTokenBytes, err := jwt.Sign(refreshClaims, jwt.WithKey(keyAlgorithm, privKey)) - if err != nil { - return nil, fmt.Errorf("failed to sign refresh token: %w", err) - } - - return &types.TokenPair{ - AccessToken: string(accessTokenBytes), - RefreshToken: string(refreshTokenBytes), - ExpiresIn: s.expiresIn, - TokenType: "Bearer", - }, nil -} - -// ValidateToken validates a JWT token and ensures the referenced session is still active -func (s *JWTServiceImpl) ValidateToken(token string) (userID string, err error) { - jwkSet, err := s.cacheService.GetJWKSWithFallback(context.Background()) - if err != nil { - return "", fmt.Errorf("failed to get JWKS: %w", err) - } - - parsedToken, err := jwt.Parse([]byte(token), jwt.WithKeySet(jwkSet), jwt.WithValidate(true)) - if err != nil { - return "", fmt.Errorf("failed to parse token: %w", err) - } - - jti, ok := parsedToken.JwtID() - if ok && jti != "" && s.blacklistService != nil { - isBlacklisted, err := s.blacklistService.IsBlacklisted(context.Background(), jti) - if err != nil { - // Don't fail validation on blacklist check error, but continue - } else if isBlacklisted { - return "", errors.New("token has been revoked") - } - } - - var tokenType string - if err := parsedToken.Get("type", &tokenType); err != nil { - slog.Debug("parsedToken", "token", parsedToken) - return "", errors.New("missing token type claim") - } - - if tokenType != types.JWTTokenTypeAccess.String() { - return "", errors.New("invalid token type") - } - - var extractedUserID string - if err := parsedToken.Get("user_id", &extractedUserID); err != nil { - return "", errors.New("missing user_id claim") - } - - if extractedUserID == "" { - return "", errors.New("missing user_id claim") - } - - var sessionID string - if err := parsedToken.Get("session_id", &sessionID); err != nil { - return "", errors.New("missing session_id claim") - } - - if sessionID == "" { - return "", errors.New("missing session_id claim") - } - - if s.blacklistService != nil { - isBlacklisted, err := s.blacklistService.IsBlacklisted(context.Background(), "session:"+sessionID) - if err != nil { - // Don't fail validation on blacklist check error, but continue - } else if isBlacklisted { - return "", errors.New("session has been revoked") - } - } - - // Ensure the session is still active - session, err := s.sessionService.GetByID(context.Background(), sessionID) - if err != nil || session == nil { - return "", errors.New("session not found or invalid") - } - - return extractedUserID, nil -} - -func (s *JWTServiceImpl) detectAlgorithmFromKey(k jwk.Key) jwa.SignatureAlgorithm { - if alg, ok := k.Algorithm(); ok { - if sigAlg, ok := alg.(jwa.SignatureAlgorithm); ok { - return sigAlg - } - } - - keyType := k.KeyType().String() - var detectedAlg jwa.SignatureAlgorithm - switch keyType { - case "OKP": - detectedAlg = jwa.EdDSA() - case "RSA": - detectedAlg = jwa.RS256() - case "EC": - detectedAlg = jwa.ES256() - case "oct": - detectedAlg = jwa.HS256() - default: - detectedAlg = jwa.EdDSA() - } - - return detectedAlg -} diff --git a/plugins/jwt/services/key_service.go b/plugins/jwt/services/key_service.go index 744da340..b5199012 100644 --- a/plugins/jwt/services/key_service.go +++ b/plugins/jwt/services/key_service.go @@ -2,11 +2,8 @@ package services import ( "context" - "crypto/ecdsa" "crypto/ed25519" - "crypto/elliptic" "crypto/rand" - "crypto/rsa" "crypto/x509" "encoding/pem" "errors" @@ -25,17 +22,14 @@ type keyService struct { repo repositories.JWKSRepository logger models.Logger secret string - algorithm types.JWTAlgorithm tokenService coreservices.TokenService } -// NewKeyService creates a new key service -func NewKeyService(repo repositories.JWKSRepository, logger models.Logger, tokenService coreservices.TokenService, secret string, algorithm types.JWTAlgorithm) KeyService { +func NewKeyService(repo repositories.JWKSRepository, logger models.Logger, tokenService coreservices.TokenService, secret string) KeyService { return &keyService{ repo: repo, logger: logger, secret: secret, - algorithm: algorithm, tokenService: tokenService, } } @@ -121,59 +115,19 @@ func (s *keyService) RotateKeysIfNeeded(ctx context.Context, rotationInterval ti return true, nil } -// generateKey returns a newly generated private/public key pair for the given algorithm -func generateKey(alg types.JWTAlgorithm) (priv any, pub any, err error) { - switch alg { - case types.JWTAlgRS256, types.JWTAlgPS256: - priv, err = rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, nil, err - } - pub = &priv.(*rsa.PrivateKey).PublicKey - return - - case types.JWTAlgES256: - priv, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - return nil, nil, err - } - pub = &priv.(*ecdsa.PrivateKey).PublicKey - return - - case types.JWTAlgES512: - priv, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader) - if err != nil { - return nil, nil, err - } - pub = &priv.(*ecdsa.PrivateKey).PublicKey - return - - case types.JWTAlgEdDSA: - var seed [32]byte - if _, err := rand.Read(seed[:]); err != nil { - return nil, nil, fmt.Errorf("failed to read random seed: %w", err) - } - priv = ed25519.NewKeyFromSeed(seed[:]) - pub = priv.(ed25519.PrivateKey).Public() - return - - case types.JWTAlgECDHES: - // ECDH-ES uses EC P-256 keys for key agreement (future JWE) - priv, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - return nil, nil, err - } - pub = &priv.(*ecdsa.PrivateKey).PublicKey - return - - default: - return nil, nil, fmt.Errorf("unsupported algorithm: %s", alg) +func generateKey() (priv any, pub any, err error) { + var seed [32]byte + if _, err := rand.Read(seed[:]); err != nil { + return nil, nil, fmt.Errorf("failed to read random seed: %w", err) } + priv = ed25519.NewKeyFromSeed(seed[:]) + pub = priv.(ed25519.PrivateKey).Public() + return } // generateAndStoreKey generates a key pair and stores it in the database func (s *keyService) generateAndStoreKey(ctx context.Context) error { - privKey, pubKey, err := generateKey(s.algorithm) + privKey, pubKey, err := generateKey() if err != nil { return fmt.Errorf("failed to generate key pair: %w", err) } @@ -205,7 +159,7 @@ func (s *keyService) generateAndStoreKey(ctx context.Context) error { return fmt.Errorf("failed to store key: %w", err) } - s.logger.Info("generated and stored key", "id", jwksKey.ID, "algorithm", s.algorithm.String()) + s.logger.Info("generated and stored key", "id", jwksKey.ID, "algorithm", "EdDSA") return nil } @@ -215,12 +169,6 @@ func privateKeyToPEM(privKey any) ([]byte, error) { var keyType string switch pk := privKey.(type) { - case *rsa.PrivateKey: - keyBytes, _ = x509.MarshalPKCS8PrivateKey(pk) - keyType = "PRIVATE KEY" - case *ecdsa.PrivateKey: - keyBytes, _ = x509.MarshalPKCS8PrivateKey(pk) - keyType = "PRIVATE KEY" case ed25519.PrivateKey: keyBytes, _ = x509.MarshalPKCS8PrivateKey(pk) keyType = "PRIVATE KEY" @@ -242,12 +190,6 @@ func publicKeyToPEM(pubKey any) ([]byte, error) { var keyType string switch pk := pubKey.(type) { - case *rsa.PublicKey: - keyBytes, _ = x509.MarshalPKIXPublicKey(pk) - keyType = "PUBLIC KEY" - case *ecdsa.PublicKey: - keyBytes, _ = x509.MarshalPKIXPublicKey(pk) - keyType = "PUBLIC KEY" case ed25519.PublicKey: keyBytes, _ = x509.MarshalPKIXPublicKey(pk) keyType = "PUBLIC KEY" diff --git a/plugins/jwt/services/key_service_test.go b/plugins/jwt/services/key_service_test.go new file mode 100644 index 00000000..b35d5387 --- /dev/null +++ b/plugins/jwt/services/key_service_test.go @@ -0,0 +1,247 @@ +package services + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/Authula/authula/models" + coreservices "github.com/Authula/authula/services" + + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/migrations" + "github.com/Authula/authula/plugins/jwt/migrationset" + "github.com/Authula/authula/plugins/jwt/repositories" + jwttests "github.com/Authula/authula/plugins/jwt/tests" + "github.com/Authula/authula/plugins/jwt/types" +) + +func setupKeyServiceTest(t *testing.T) (KeyService, repositories.JWKSRepository) { + t.Helper() + db := internaltests.NewSQLiteIntegrationDB(t) + + migrator, err := migrations.NewMigrator(db, &internaltests.MockLogger{}) + require.NoError(t, err) + err = migrator.Migrate(context.Background(), []migrations.MigrationSet{ + { + PluginID: models.PluginJWT.String(), + Migrations: migrationset.JWTMigrationsForProvider("sqlite"), + }, + }) + require.NoError(t, err) + + repo := repositories.NewBunJWKSRepository(db) + logger := &internaltests.MockLogger{} + coreTokenSvc := jwttests.NopTokenService{} + + svc := NewKeyService(repo, logger, coreservices.TokenService(coreTokenSvc), "test-secret") + return svc, repo +} + +func TestKeyService_GenerateKeysIfMissing(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(context.Context, repositories.JWKSRepository) + wantNewID bool + }{ + { + name: "no keys", + setup: func(ctx context.Context, repo repositories.JWKSRepository) {}, + wantNewID: true, + }, + { + name: "keys exist", + setup: func(ctx context.Context, repo repositories.JWKSRepository) { + err := repo.StoreJWKSKey(ctx, &types.JWKS{ + ID: "pre-seeded-key", + PublicKey: "pre-seeded-public-key", + PrivateKey: "pre-seeded-private-key", + }) + require.NoError(t, err) + }, + wantNewID: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + svc, repo := setupKeyServiceTest(t) + ctx := context.Background() + + tt.setup(ctx, repo) + + err := svc.GenerateKeysIfMissing(ctx) + require.NoError(t, err) + + keys, err := repo.GetJWKSKeys(ctx) + require.NoError(t, err) + + if tt.wantNewID { + require.Len(t, keys, 1) + require.NotEmpty(t, keys[0].ID) + assert.Contains(t, keys[0].PublicKey, "BEGIN PUBLIC KEY") + assert.Contains(t, keys[0].PrivateKey, "BEGIN PRIVATE KEY") + } else { + require.Len(t, keys, 1) + require.Equal(t, "pre-seeded-key", keys[0].ID) + } + }) + } +} + +func TestKeyService_GetActiveKey(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(context.Context, repositories.JWKSRepository) + wantID string + wantErr string + }{ + { + name: "no keys", + setup: func(ctx context.Context, repo repositories.JWKSRepository) {}, + wantErr: "no active key found", + }, + { + name: "single key", + setup: func(ctx context.Context, repo repositories.JWKSRepository) { + err := repo.StoreJWKSKey(ctx, &types.JWKS{ + ID: "single-key", + PublicKey: "public-key-1", + PrivateKey: "private-key-1", + CreatedAt: time.Now(), + }) + require.NoError(t, err) + }, + wantID: "single-key", + }, + { + name: "returns most recent key", + setup: func(ctx context.Context, repo repositories.JWKSRepository) { + err := repo.StoreJWKSKey(ctx, &types.JWKS{ + ID: "old-key", + PublicKey: "public-key-old", + PrivateKey: "private-key-old", + CreatedAt: time.Now().Add(-1 * time.Hour), + }) + require.NoError(t, err) + + err = repo.StoreJWKSKey(ctx, &types.JWKS{ + ID: "new-key", + PublicKey: "public-key-new", + PrivateKey: "private-key-new", + CreatedAt: time.Now(), + }) + require.NoError(t, err) + }, + wantID: "new-key", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + svc, repo := setupKeyServiceTest(t) + ctx := context.Background() + + tt.setup(ctx, repo) + + active, err := svc.GetActiveKey(ctx) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + require.Nil(t, active) + } else { + require.NoError(t, err) + require.NotNil(t, active) + require.Equal(t, tt.wantID, active.ID) + } + }) + } +} + +func TestKeyService_RotateKeysIfNeeded(t *testing.T) { + t.Parallel() + + rotationInterval := 24 * time.Hour + gracePeriod := 1 * time.Hour + + tests := []struct { + name string + setup func(context.Context, repositories.JWKSRepository) + wantRotated bool + wantErr string + }{ + { + name: "rotation due", + setup: func(ctx context.Context, repo repositories.JWKSRepository) { + err := repo.StoreJWKSKey(ctx, &types.JWKS{ + ID: "old-key", + PublicKey: "old-public-key", + PrivateKey: "old-private-key", + CreatedAt: time.Now().Add(-25 * time.Hour), + }) + require.NoError(t, err) + }, + wantRotated: true, + }, + { + name: "not due", + setup: func(ctx context.Context, repo repositories.JWKSRepository) { + err := repo.StoreJWKSKey(ctx, &types.JWKS{ + ID: "recent-key", + PublicKey: "recent-public-key", + PrivateKey: "recent-private-key", + CreatedAt: time.Now().Add(-1 * time.Hour), + }) + require.NoError(t, err) + }, + wantRotated: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + svc, repo := setupKeyServiceTest(t) + ctx := context.Background() + + tt.setup(ctx, repo) + + rotated, err := svc.RotateKeysIfNeeded(ctx, rotationInterval, gracePeriod, nil) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + require.Equal(t, tt.wantRotated, rotated) + + keys, err := repo.GetJWKSKeys(ctx) + require.NoError(t, err) + + if tt.wantRotated { + require.Len(t, keys, 2) + // Old key should have expires_at set + oldKey, err := repo.GetJWKSKeyByID(ctx, "old-key") + require.NoError(t, err) + require.NotNil(t, oldKey) + require.NotNil(t, oldKey.ExpiresAt) + } else { + require.Len(t, keys, 1) + } + } + }) + } +} diff --git a/plugins/jwt/services/refresh_token_service.go b/plugins/jwt/services/refresh_token_service.go index be38b60d..c9e7df15 100644 --- a/plugins/jwt/services/refresh_token_service.go +++ b/plugins/jwt/services/refresh_token_service.go @@ -21,7 +21,7 @@ type refreshTokenService struct { logger models.Logger eventBus models.EventBus sessionService coreservices.SessionService - jwtService JwtService + jwtService TokenService storage RefreshTokenRepository gracePeriod time.Duration refreshExpiresIn time.Duration @@ -31,7 +31,7 @@ func NewRefreshTokenService( logger models.Logger, eventBus models.EventBus, sessionService coreservices.SessionService, - jwtService JwtService, + jwtService TokenService, storage RefreshTokenRepository, gracePeriod time.Duration, refreshExpiresIn time.Duration, @@ -47,12 +47,12 @@ func NewRefreshTokenService( } } -func (s *refreshTokenService) RefreshTokens(ctx context.Context, refreshToken string) (*RefreshTokenResponse, error) { +func (s *refreshTokenService) RefreshTokens(ctx context.Context, refreshToken string) (*types.RefreshTokenResponse, error) { return s.RefreshTokensWithMetadata(ctx, refreshToken, events.AuditMetadata{}) } // RefreshTokensWithMetadata refreshes tokens with optional audit metadata for event logging -func (s *refreshTokenService) RefreshTokensWithMetadata(ctx context.Context, refreshToken string, auditMeta events.AuditMetadata) (*RefreshTokenResponse, error) { +func (s *refreshTokenService) RefreshTokensWithMetadata(ctx context.Context, refreshToken string, auditMeta events.AuditMetadata) (*types.RefreshTokenResponse, error) { // Hash the incoming refresh token tokenHash := HashRefreshToken(refreshToken) @@ -136,7 +136,7 @@ func (s *refreshTokenService) RefreshTokensWithMetadata(ctx context.Context, ref } // completeTokenRotation handles the token rotation after validation passes -func (s *refreshTokenService) completeTokenRotation(ctx context.Context, tokenHash string, record *types.RefreshToken) (*RefreshTokenResponse, error) { +func (s *refreshTokenService) completeTokenRotation(ctx context.Context, tokenHash string, record *types.RefreshToken) (*types.RefreshTokenResponse, error) { // Check if token is expired if time.Now().After(record.ExpiresAt) { return nil, fmt.Errorf("refresh token expired") @@ -161,7 +161,7 @@ func (s *refreshTokenService) completeTokenRotation(ctx context.Context, tokenHa } // STEP 2: Generate new token pair - tokenPair, err := s.jwtService.GenerateTokens(ctx, session.UserID, record.SessionID) + tokenPair, err := s.jwtService.GenerateUserToken(ctx, session.UserID, record.SessionID) if err != nil { s.logger.Error("failed to generate new tokens", "user_id", session.UserID, "session_id", record.SessionID, "error", err) return nil, fmt.Errorf("failed to generate tokens") @@ -186,7 +186,7 @@ func (s *refreshTokenService) completeTokenRotation(ctx context.Context, tokenHa return nil, fmt.Errorf("failed to rotate token") } - return &RefreshTokenResponse{ + return &types.RefreshTokenResponse{ AccessToken: tokenPair.AccessToken, RefreshToken: tokenPair.RefreshToken, }, nil diff --git a/plugins/jwt/services/refresh_token_service_test.go b/plugins/jwt/services/refresh_token_service_test.go new file mode 100644 index 00000000..687bcca8 --- /dev/null +++ b/plugins/jwt/services/refresh_token_service_test.go @@ -0,0 +1,274 @@ +package services + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internalmocks "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/jwt/tests" + "github.com/Authula/authula/plugins/jwt/types" +) + +type refreshTokenTestFixture struct { + logger *internalmocks.MockLogger + eventBus *internalmocks.MockEventBus + sessionSvc *tests.MockSessionService + tokenSvc *tests.MockTokenService + repo *tests.MockRefreshTokenRepository + gracePeriod time.Duration + refreshExpiresIn time.Duration +} + +func newRefreshTokenTestFixture() *refreshTokenTestFixture { + return &refreshTokenTestFixture{ + logger: &internalmocks.MockLogger{}, + eventBus: &internalmocks.MockEventBus{}, + sessionSvc: &tests.MockSessionService{}, + tokenSvc: &tests.MockTokenService{}, + repo: &tests.MockRefreshTokenRepository{}, + gracePeriod: 10 * time.Second, + refreshExpiresIn: 7 * 24 * time.Hour, + } +} + +func (f *refreshTokenTestFixture) newService() RefreshTokenService { + return &refreshTokenService{ + logger: f.logger, + eventBus: f.eventBus, + sessionService: f.sessionSvc, + jwtService: f.tokenSvc, + storage: f.repo, + gracePeriod: f.gracePeriod, + refreshExpiresIn: f.refreshExpiresIn, + } +} + +func TestRefreshTokenService_RefreshTokens(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + setup func(f *refreshTokenTestFixture) + wantErr string + }{ + { + name: "successful_rotation", + setup: func(f *refreshTokenTestFixture) { + now := time.Now() + record := &types.RefreshToken{ + ID: uuid.New().String(), + SessionID: "sess-1", + TokenHash: "hash", + ExpiresAt: now.Add(1 * time.Hour), + IsRevoked: false, + } + session := &models.Session{ + ID: "sess-1", + UserID: "user-1", + } + tokenPair := &types.TokenPair{ + AccessToken: "access-token-1", + RefreshToken: "refresh-token-1", + } + + f.repo.On("GetRefreshToken", ctx, mock.Anything).Return(record, nil) + f.sessionSvc.On("GetByID", ctx, "sess-1").Return(session, nil) + f.repo.On("RevokeRefreshToken", ctx, mock.Anything).Return(nil) + f.tokenSvc.On("GenerateUserToken", ctx, "user-1", "sess-1").Return(tokenPair, nil) + f.repo.On("StoreRefreshToken", ctx, mock.Anything).Return(nil) + }, + }, + { + name: "token_not_found", + setup: func(f *refreshTokenTestFixture) { + f.repo.On("GetRefreshToken", ctx, mock.Anything).Return(nil, nil) + }, + wantErr: "invalid refresh token", + }, + { + name: "token_expired", + setup: func(f *refreshTokenTestFixture) { + now := time.Now() + record := &types.RefreshToken{ + ID: uuid.New().String(), + SessionID: "sess-1", + TokenHash: "hash", + ExpiresAt: now.Add(-1 * time.Hour), + IsRevoked: false, + } + f.repo.On("GetRefreshToken", ctx, mock.Anything).Return(record, nil) + }, + wantErr: "refresh token expired", + }, + { + name: "session_not_found", + setup: func(f *refreshTokenTestFixture) { + now := time.Now() + record := &types.RefreshToken{ + ID: uuid.New().String(), + SessionID: "sess-1", + TokenHash: "hash", + ExpiresAt: now.Add(1 * time.Hour), + IsRevoked: false, + } + f.repo.On("GetRefreshToken", ctx, mock.Anything).Return(record, nil) + f.sessionSvc.On("GetByID", ctx, "sess-1").Return(nil, nil) + }, + wantErr: "session expired or invalid", + }, + { + name: "reuse_tier1_recovery", + setup: func(f *refreshTokenTestFixture) { + now := time.Now() + revokedAt := now.Add(-5 * time.Second) + record := &types.RefreshToken{ + ID: uuid.New().String(), + SessionID: "sess-1", + TokenHash: "hash", + ExpiresAt: now.Add(1 * time.Hour), + IsRevoked: true, + RevokedAt: &revokedAt, + LastReuseAttempt: nil, + } + session := &models.Session{ + ID: "sess-1", + UserID: "user-1", + } + tokenPair := &types.TokenPair{ + AccessToken: "access-token-1", + RefreshToken: "refresh-token-1", + } + + f.repo.On("GetRefreshToken", ctx, mock.Anything).Return(record, nil) + f.repo.On("SetLastReuseAttempt", ctx, mock.Anything).Return(nil) + f.sessionSvc.On("GetByID", ctx, "sess-1").Return(session, nil) + f.repo.On("RevokeRefreshToken", ctx, mock.Anything).Return(nil) + f.tokenSvc.On("GenerateUserToken", ctx, "user-1", "sess-1").Return(tokenPair, nil) + f.repo.On("StoreRefreshToken", ctx, mock.Anything).Return(nil) + }, + }, + { + name: "reuse_tier2_throttle", + setup: func(f *refreshTokenTestFixture) { + now := time.Now() + revokedAt := now.Add(-5 * time.Second) + lastReuse := now.Add(-2 * time.Second) + record := &types.RefreshToken{ + ID: uuid.New().String(), + SessionID: "sess-1", + TokenHash: "hash", + ExpiresAt: now.Add(1 * time.Hour), + IsRevoked: true, + RevokedAt: &revokedAt, + LastReuseAttempt: &lastReuse, + } + f.repo.On("GetRefreshToken", ctx, mock.Anything).Return(record, nil) + }, + wantErr: "invalid refresh token", + }, + { + name: "reuse_tier3_reject", + setup: func(f *refreshTokenTestFixture) { + now := time.Now() + revokedAt := now.Add(-30 * time.Second) + record := &types.RefreshToken{ + ID: uuid.New().String(), + SessionID: "sess-1", + TokenHash: "hash", + ExpiresAt: now.Add(1 * time.Hour), + IsRevoked: true, + RevokedAt: &revokedAt, + LastReuseAttempt: nil, + } + f.repo.On("GetRefreshToken", ctx, mock.Anything).Return(record, nil) + }, + wantErr: "invalid refresh token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := newRefreshTokenTestFixture() + f.eventBus.On("Publish", mock.Anything).Return(nil).Maybe() + if tt.setup != nil { + tt.setup(f) + } + svc := f.newService() + + resp, err := svc.RefreshTokens(ctx, "refresh-token") + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + require.Nil(t, resp) + } else { + require.NoError(t, err) + require.NotNil(t, resp) + require.NotEmpty(t, resp.AccessToken) + require.NotEmpty(t, resp.RefreshToken) + } + + f.repo.AssertExpectations(t) + f.sessionSvc.AssertExpectations(t) + f.tokenSvc.AssertExpectations(t) + }) + } +} + +func TestRefreshTokenService_StoreInitialRefreshToken(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + refreshToken string + sessionID string + setupMock func(f *refreshTokenTestFixture) + wantErr string + }{ + { + name: "success", + refreshToken: "refresh-token", + sessionID: "sess-1", + setupMock: func(f *refreshTokenTestFixture) { + f.repo.On("StoreRefreshToken", mock.Anything, mock.Anything).Return(nil) + }, + }, + { + name: "storage error", + refreshToken: "refresh-token", + sessionID: "sess-1", + setupMock: func(f *refreshTokenTestFixture) { + f.repo.On("StoreRefreshToken", mock.Anything, mock.Anything).Return(assert.AnError) + }, + wantErr: assert.AnError.Error(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + f := newRefreshTokenTestFixture() + f.eventBus.On("Publish", mock.Anything).Return(nil).Maybe() + tt.setupMock(f) + svc := f.newService() + + future := time.Now().Add(24 * time.Hour) + err := svc.StoreInitialRefreshToken(context.Background(), tt.refreshToken, tt.sessionID, future) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + } + f.repo.AssertExpectations(t) + }) + } +} diff --git a/plugins/jwt/services/services_test.go b/plugins/jwt/services/services_test.go new file mode 100644 index 00000000..9b8c2dbd --- /dev/null +++ b/plugins/jwt/services/services_test.go @@ -0,0 +1,165 @@ +package services + +import ( + "crypto/ed25519" + "crypto/rand" + "crypto/x509" + "encoding/json" + "encoding/pem" + "testing" + "time" + + "github.com/google/uuid" + "github.com/lestrrat-go/jwx/v3/jwa" + "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/lestrrat-go/jwx/v3/jwt" + "github.com/stretchr/testify/mock" + + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/models" + jwttests "github.com/Authula/authula/plugins/jwt/tests" + "github.com/Authula/authula/plugins/jwt/types" +) + +type serviceTestFixture struct { + logger models.Logger + sessionSvc *jwttests.MockSessionService + coreTokenSvc *jwttests.MockTokenServiceCore + keySvc *jwttests.MockKeyService + cacheSvc *jwttests.MockCacheService + blacklistSvc *jwttests.MockBlacklistService + refreshTokenRepo *jwttests.MockRefreshTokenRepository + activeKey *types.JWKS + expiresIn time.Duration + refreshExpiresIn time.Duration +} + +func newServiceTestFixture(t *testing.T) *serviceTestFixture { + t.Helper() + + pubPEM, privPEM, err := generateEd25519KeyPair() + if err != nil { + t.Fatalf("failed to generate test key: %v", err) + } + + activeKey := &types.JWKS{ + ID: uuid.New().String(), + PublicKey: string(pubPEM), + PrivateKey: string(privPEM), + CreatedAt: time.Now(), + } + + return &serviceTestFixture{ + logger: &internaltests.MockLogger{}, + sessionSvc: &jwttests.MockSessionService{}, + coreTokenSvc: &jwttests.MockTokenServiceCore{}, + keySvc: &jwttests.MockKeyService{}, + cacheSvc: &jwttests.MockCacheService{}, + blacklistSvc: &jwttests.MockBlacklistService{}, + refreshTokenRepo: &jwttests.MockRefreshTokenRepository{}, + activeKey: activeKey, + expiresIn: 15 * time.Minute, + refreshExpiresIn: 7 * 24 * time.Hour, + } +} + +func (f *serviceTestFixture) newJWTService() *tokenService { + return &tokenService{ + logger: f.logger, + coreTokenService: f.coreTokenSvc, + keyService: f.keySvc, + cacheService: f.cacheSvc, + blacklistService: f.blacklistSvc, + sessionService: f.sessionSvc, + expiresIn: f.expiresIn, + refreshExpiresIn: f.refreshExpiresIn, + } +} + +func (f *serviceTestFixture) signTestToken(t *testing.T, claims map[string]any) string { + t.Helper() + + privKey, err := jwk.ParseKey([]byte(f.activeKey.PrivateKey), jwk.WithPEM(true)) + if err != nil { + t.Fatalf("failed to parse private key: %v", err) + } + + if err := privKey.Set(jwk.KeyIDKey, f.activeKey.ID); err != nil { + t.Fatalf("failed to set key ID: %v", err) + } + + token := jwt.New() + for k, v := range claims { + if err := token.Set(k, v); err != nil { + t.Fatalf("failed to set claim %s: %v", k, err) + } + } + + signed, err := jwt.Sign(token, jwt.WithKey(jwa.EdDSA(), privKey)) + if err != nil { + t.Fatalf("failed to sign token: %v", err) + } + + return string(signed) +} + +func (f *serviceTestFixture) setupKeyServiceMock() { + f.keySvc.On("GetActiveKey", mock.Anything).Return(f.activeKey, nil).Maybe() + + pubKey, err := jwk.ParseKey([]byte(f.activeKey.PublicKey), jwk.WithPEM(true)) + if err != nil { + panic(err) + } + + if err := pubKey.Set(jwk.KeyIDKey, f.activeKey.ID); err != nil { + panic(err) + } + if err := pubKey.Set(jwk.AlgorithmKey, "EdDSA"); err != nil { + panic(err) + } + + set := jwk.NewSet() + if err := set.AddKey(pubKey); err != nil { + panic(err) + } + + data, err := json.Marshal(set) + if err != nil { + panic(err) + } + parsedSet, err := jwk.Parse(data) + if err != nil { + panic(err) + } + + f.cacheSvc.On("GetJWKSWithFallback", mock.Anything).Return(parsedSet, nil).Maybe() +} + +func generateEd25519KeyPair() (pubPEM, privPEM []byte, err error) { + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + + privBytes, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + return nil, nil, err + } + + privPEM = pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: privBytes, + }) + + pubBytes, err := x509.MarshalPKIXPublicKey(priv.Public()) + if err != nil { + return nil, nil, err + } + + pubPEM = pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: pubBytes, + }) + + return pubPEM, privPEM, nil +} diff --git a/plugins/jwt/services/token_service.go b/plugins/jwt/services/token_service.go new file mode 100644 index 00000000..075ca31f --- /dev/null +++ b/plugins/jwt/services/token_service.go @@ -0,0 +1,322 @@ +package services + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/lestrrat-go/jwx/v3/jwa" + "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/lestrrat-go/jwx/v3/jwt" + + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/jwt/types" + "github.com/Authula/authula/services" +) + +type tokenService struct { + logger models.Logger + coreTokenService services.TokenService + keyService KeyService + cacheService CacheService + blacklistService BlacklistService + sessionService services.SessionService + expiresIn time.Duration + refreshExpiresIn time.Duration +} + +func NewJWTService( + logger models.Logger, + sessionService services.SessionService, + coreTokenService services.TokenService, + keyService KeyService, + cacheService CacheService, + blacklistService BlacklistService, + expiresIn time.Duration, + refreshExpiresIn time.Duration, +) services.JWTService { + return &tokenService{ + logger: logger, + sessionService: sessionService, + coreTokenService: coreTokenService, + keyService: keyService, + cacheService: cacheService, + blacklistService: blacklistService, + expiresIn: expiresIn, + refreshExpiresIn: refreshExpiresIn, + } +} + +func (s *tokenService) detectAlgorithmFromKey(k jwk.Key) jwa.SignatureAlgorithm { + return jwa.EdDSA() +} + +func (s *tokenService) GenerateUserToken(ctx context.Context, userID string, sessionID string) (*types.TokenPair, error) { + if sessionID == "" { + return nil, errors.New("session id is required to generate tokens") + } + + jwksKey, err := s.keyService.GetActiveKey(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get active key: %w", err) + } + + privateKeyPEM, err := s.coreTokenService.Decrypt(jwksKey.PrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to decrypt private key: %w", err) + } + + privKey, err := jwk.ParseKey([]byte(privateKeyPEM), jwk.WithPEM(true)) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + + if err := privKey.Set(jwk.KeyIDKey, jwksKey.ID); err != nil { + return nil, fmt.Errorf("failed to set key ID: %w", err) + } + + keyAlgorithm := s.detectAlgorithmFromKey(privKey) + now := time.Now() + jti := uuid.New().String() + + accessClaims := jwt.New() + if err := accessClaims.Set(jwt.SubjectKey, userID); err != nil { + return nil, fmt.Errorf("failed to set subject: %w", err) + } + if err := accessClaims.Set(jwt.IssuedAtKey, now); err != nil { + return nil, fmt.Errorf("failed to set issued at: %w", err) + } + if err := accessClaims.Set(jwt.ExpirationKey, now.Add(s.expiresIn)); err != nil { + return nil, fmt.Errorf("failed to set expiration: %w", err) + } + if err := accessClaims.Set(jwt.JwtIDKey, jti); err != nil { + return nil, fmt.Errorf("failed to set JWT ID: %w", err) + } + if err := accessClaims.Set("user_id", userID); err != nil { + return nil, fmt.Errorf("failed to set user_id: %w", err) + } + if err := accessClaims.Set("session_id", sessionID); err != nil { + return nil, fmt.Errorf("failed to set session_id: %w", err) + } + if err := accessClaims.Set("token_type", types.JWTTokenTypeAccess.String()); err != nil { + return nil, fmt.Errorf("failed to set token_type: %w", err) + } + if err := accessClaims.Set("actor_type", "user"); err != nil { + return nil, fmt.Errorf("failed to set actor_type: %w", err) + } + + accessTokenBytes, err := jwt.Sign(accessClaims, jwt.WithKey(keyAlgorithm, privKey)) + if err != nil { + return nil, fmt.Errorf("failed to sign access token: %w", err) + } + + refreshClaims := jwt.New() + if err := refreshClaims.Set(jwt.SubjectKey, userID); err != nil { + return nil, fmt.Errorf("failed to set subject in refresh token: %w", err) + } + if err := refreshClaims.Set(jwt.IssuedAtKey, now); err != nil { + return nil, fmt.Errorf("failed to set issued at in refresh token: %w", err) + } + if err := refreshClaims.Set(jwt.ExpirationKey, now.Add(s.refreshExpiresIn)); err != nil { + return nil, fmt.Errorf("failed to set expiration in refresh token: %w", err) + } + if err := refreshClaims.Set(jwt.JwtIDKey, jti); err != nil { + return nil, fmt.Errorf("failed to set JWT ID in refresh token: %w", err) + } + if err := refreshClaims.Set("user_id", userID); err != nil { + return nil, fmt.Errorf("failed to set user_id in refresh token: %w", err) + } + if err := refreshClaims.Set("session_id", sessionID); err != nil { + return nil, fmt.Errorf("failed to set session_id in refresh token: %w", err) + } + if err := refreshClaims.Set("token_type", types.JWTTokenTypeRefresh.String()); err != nil { + return nil, fmt.Errorf("failed to set token_type in refresh token: %w", err) + } + if err := refreshClaims.Set("actor_type", "user"); err != nil { + return nil, fmt.Errorf("failed to set actor_type in refresh token: %w", err) + } + + refreshTokenBytes, err := jwt.Sign(refreshClaims, jwt.WithKey(keyAlgorithm, privKey)) + if err != nil { + return nil, fmt.Errorf("failed to sign refresh token: %w", err) + } + + return &types.TokenPair{ + AccessToken: string(accessTokenBytes), + RefreshToken: string(refreshTokenBytes), + ExpiresIn: s.expiresIn, + TokenType: "Bearer", + }, nil +} + +func (s *tokenService) GenerateMachineToken(ctx context.Context, clientID string, organizationID string, scopes []string) (*types.TokenPair, error) { + jwksKey, err := s.keyService.GetActiveKey(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get active key: %w", err) + } + + privateKeyPEM, err := s.coreTokenService.Decrypt(jwksKey.PrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to decrypt private key: %w", err) + } + + privKey, err := jwk.ParseKey([]byte(privateKeyPEM), jwk.WithPEM(true)) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + + if err := privKey.Set(jwk.KeyIDKey, jwksKey.ID); err != nil { + return nil, fmt.Errorf("failed to set key ID: %w", err) + } + + keyAlgorithm := s.detectAlgorithmFromKey(privKey) + now := time.Now() + jti := uuid.New().String() + + accessClaims := jwt.New() + if err := accessClaims.Set(jwt.SubjectKey, clientID); err != nil { + return nil, fmt.Errorf("failed to set subject: %w", err) + } + if err := accessClaims.Set(jwt.IssuedAtKey, now); err != nil { + return nil, fmt.Errorf("failed to set issued at: %w", err) + } + if err := accessClaims.Set(jwt.ExpirationKey, now.Add(s.expiresIn)); err != nil { + return nil, fmt.Errorf("failed to set expiration: %w", err) + } + if err := accessClaims.Set(jwt.JwtIDKey, jti); err != nil { + return nil, fmt.Errorf("failed to set JWT ID: %w", err) + } + if err := accessClaims.Set("token_type", types.JWTTokenTypeAccess.String()); err != nil { + return nil, fmt.Errorf("failed to set token_type: %w", err) + } + if err := accessClaims.Set("actor_type", "machine"); err != nil { + return nil, fmt.Errorf("failed to set actor_type: %w", err) + } + + if organizationID != "" { + if err := accessClaims.Set("org_id", organizationID); err != nil { + return nil, fmt.Errorf("failed to set org_id: %w", err) + } + } + + if len(scopes) > 0 { + if err := accessClaims.Set("scopes", scopes); err != nil { + return nil, fmt.Errorf("failed to set scopes: %w", err) + } + } + + accessTokenBytes, err := jwt.Sign(accessClaims, jwt.WithKey(keyAlgorithm, privKey)) + if err != nil { + return nil, fmt.Errorf("failed to sign access token: %w", err) + } + + return &types.TokenPair{ + AccessToken: string(accessTokenBytes), + RefreshToken: "", + ExpiresIn: s.expiresIn, + TokenType: "Bearer", + }, nil +} + +func (s *tokenService) ValidateToken(ctx context.Context, token string) (*models.Actor, error) { + jwkSet, err := s.cacheService.GetJWKSWithFallback(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get JWKS: %w", err) + } + + parsedToken, err := jwt.Parse([]byte(token), jwt.WithKeySet(jwkSet), jwt.WithValidate(true)) + if err != nil { + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + jti, ok := parsedToken.JwtID() + if ok && jti != "" && s.blacklistService != nil { + isBlacklisted, err := s.blacklistService.IsBlacklisted(ctx, jti) + if err == nil && isBlacklisted { + return nil, errors.New("token has been revoked") + } + } + + var tokenType string + if err := parsedToken.Get("token_type", &tokenType); err != nil { + return nil, errors.New("missing token_type claim") + } + + if tokenType != types.JWTTokenTypeAccess.String() { + return nil, errors.New("invalid token_type") + } + + var actorType string + if err := parsedToken.Get("actor_type", &actorType); err != nil || actorType == "" { + actorType = "user" + } + + actor := &models.Actor{ + Metadata: map[string]any{"auth_mechanism": "jwt_bearer"}, + } + + if actorType == "machine" { + var sub string + if err := parsedToken.Get(jwt.SubjectKey, &sub); err != nil || sub == "" { + return nil, errors.New("missing subject claim") + } + actor.ID = sub + actor.Type = models.ActorMachine + + var orgID string + if err := parsedToken.Get("org_id", &orgID); err == nil && orgID != "" { + actor.OrganizationID = &orgID + } + + var raw any + if err := parsedToken.Get("scopes", &raw); err == nil { + switch v := raw.(type) { + case []string: + if len(v) > 0 { + actor.Scopes = v + } + case []any: + scopes := make([]string, 0, len(v)) + for _, s := range v { + if str, ok := s.(string); ok { + scopes = append(scopes, str) + } + } + if len(scopes) > 0 { + actor.Scopes = scopes + } + } + } + + return actor, nil + } + + var userID string + if err := parsedToken.Get("user_id", &userID); err != nil || userID == "" { + return nil, errors.New("missing user_id claim") + } + + var sessionID string + if err := parsedToken.Get("session_id", &sessionID); err != nil || sessionID == "" { + return nil, errors.New("missing session_id claim") + } + + if s.blacklistService != nil { + isBlacklisted, err := s.blacklistService.IsBlacklisted(ctx, "session:"+sessionID) + if err == nil && isBlacklisted { + return nil, errors.New("session has been revoked") + } + } + + session, err := s.sessionService.GetByID(ctx, sessionID) + if err != nil || session == nil { + return nil, errors.New("session not found or invalid") + } + + actor.ID = userID + actor.Type = models.ActorUser + + return actor, nil +} diff --git a/plugins/jwt/services/token_service_test.go b/plugins/jwt/services/token_service_test.go new file mode 100644 index 00000000..bf132dca --- /dev/null +++ b/plugins/jwt/services/token_service_test.go @@ -0,0 +1,457 @@ +package services + +import ( + "context" + "testing" + "time" + + "github.com/lestrrat-go/jwx/v3/jwa" + "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/lestrrat-go/jwx/v3/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/jwt/types" +) + +func parseTestToken(t *testing.T, tokenStr string, f *serviceTestFixture) map[string]any { + t.Helper() + + pubKey, err := jwk.ParseKey([]byte(f.activeKey.PublicKey), jwk.WithPEM(true)) + require.NoError(t, err) + + parsed, err := jwt.Parse([]byte(tokenStr), jwt.WithKey(jwa.EdDSA(), pubKey), jwt.WithValidate(false)) + require.NoError(t, err) + + var sub, userID, sessionID, tokenType, actType, orgID, jti string + var scopes []any + + _ = parsed.Get("sub", &sub) + _ = parsed.Get("user_id", &userID) + _ = parsed.Get("session_id", &sessionID) + _ = parsed.Get("token_type", &tokenType) + _ = parsed.Get("actor_type", &actType) + _ = parsed.Get("org_id", &orgID) + _ = parsed.Get("scopes", &scopes) + _ = parsed.Get(jwt.JwtIDKey, &jti) + + result := make(map[string]any) + if sub != "" { + result["sub"] = sub + } + if userID != "" { + result["user_id"] = userID + } + if sessionID != "" { + result["session_id"] = sessionID + } + if tokenType != "" { + result["token_type"] = tokenType + } + if actType != "" { + result["actor_type"] = actType + } + if orgID != "" { + result["org_id"] = orgID + } + if len(scopes) > 0 { + result["scopes"] = scopes + } + if jti != "" { + result["jti"] = jti + } + return result +} + +func TestTokenService_ValidateToken(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + tests := []struct { + name string + claims map[string]any + setupMocks func(t *testing.T, f *serviceTestFixture) + wantActor func(t *testing.T, actor *models.Actor) + wantErr string + }{ + { + name: "valid_user_token", + claims: map[string]any{ + jwt.SubjectKey: "user-1", + "user_id": "user-1", + "session_id": "sess-1", + "token_type": types.JWTTokenTypeAccess.String(), + "actor_type": "user", + jwt.JwtIDKey: "jti-1", + jwt.IssuedAtKey: time.Now(), + jwt.ExpirationKey: time.Now().Add(15 * time.Minute), + }, + setupMocks: func(t *testing.T, f *serviceTestFixture) { + f.blacklistSvc.On("IsBlacklisted", mock.Anything, "jti-1").Return(false, nil) + f.blacklistSvc.On("IsBlacklisted", mock.Anything, "session:sess-1").Return(false, nil) + f.sessionSvc.On("GetByID", mock.Anything, "sess-1").Return(&models.Session{ID: "sess-1"}, nil) + }, + wantActor: func(t *testing.T, actor *models.Actor) { + require.Equal(t, "user-1", actor.ID) + require.Equal(t, models.ActorUser, actor.Type) + require.Nil(t, actor.OrganizationID) + require.Equal(t, "jwt_bearer", actor.Metadata["auth_mechanism"]) + }, + }, + { + name: "valid_user_token_no_actor_type", + claims: map[string]any{ + jwt.SubjectKey: "user-1", + "user_id": "user-1", + "session_id": "sess-1", + "token_type": types.JWTTokenTypeAccess.String(), + jwt.JwtIDKey: "jti-2", + jwt.IssuedAtKey: time.Now(), + jwt.ExpirationKey: time.Now().Add(15 * time.Minute), + }, + setupMocks: func(t *testing.T, f *serviceTestFixture) { + f.blacklistSvc.On("IsBlacklisted", mock.Anything, "jti-2").Return(false, nil) + f.blacklistSvc.On("IsBlacklisted", mock.Anything, "session:sess-1").Return(false, nil) + f.sessionSvc.On("GetByID", mock.Anything, "sess-1").Return(&models.Session{ID: "sess-1"}, nil) + }, + wantActor: func(t *testing.T, actor *models.Actor) { + require.Equal(t, "user-1", actor.ID) + require.Equal(t, models.ActorUser, actor.Type) + require.Equal(t, "jwt_bearer", actor.Metadata["auth_mechanism"]) + }, + }, + { + name: "valid_machine_token", + claims: map[string]any{ + jwt.SubjectKey: "client-1", + "token_type": types.JWTTokenTypeAccess.String(), + "actor_type": "machine", + "org_id": "org-1", + "scopes": []string{"read:users", "write:users"}, + jwt.JwtIDKey: "jti-3", + jwt.IssuedAtKey: time.Now(), + jwt.ExpirationKey: time.Now().Add(15 * time.Minute), + }, + setupMocks: func(t *testing.T, f *serviceTestFixture) { + f.blacklistSvc.On("IsBlacklisted", mock.Anything, "jti-3").Return(false, nil) + }, + wantActor: func(t *testing.T, actor *models.Actor) { + require.Equal(t, "client-1", actor.ID) + require.Equal(t, models.ActorMachine, actor.Type) + require.NotNil(t, actor.OrganizationID) + require.Equal(t, "org-1", *actor.OrganizationID) + require.Equal(t, []string{"read:users", "write:users"}, actor.Scopes) + require.Equal(t, "jwt_bearer", actor.Metadata["auth_mechanism"]) + }, + }, + { + name: "machine_token_no_optional_fields", + claims: map[string]any{ + jwt.SubjectKey: "client-2", + "token_type": types.JWTTokenTypeAccess.String(), + "actor_type": "machine", + jwt.JwtIDKey: "jti-4", + jwt.IssuedAtKey: time.Now(), + jwt.ExpirationKey: time.Now().Add(15 * time.Minute), + }, + setupMocks: func(t *testing.T, f *serviceTestFixture) { + f.blacklistSvc.On("IsBlacklisted", mock.Anything, "jti-4").Return(false, nil) + }, + wantActor: func(t *testing.T, actor *models.Actor) { + require.Equal(t, "client-2", actor.ID) + require.Equal(t, models.ActorMachine, actor.Type) + require.Nil(t, actor.OrganizationID) + require.Nil(t, actor.Scopes) + require.Equal(t, "jwt_bearer", actor.Metadata["auth_mechanism"]) + }, + }, + { + name: "expired_token", + claims: map[string]any{ + jwt.SubjectKey: "user-1", + "user_id": "user-1", + "session_id": "sess-1", + "token_type": types.JWTTokenTypeAccess.String(), + "actor_type": "user", + jwt.JwtIDKey: "jti-5", + jwt.IssuedAtKey: time.Now().Add(-2 * time.Hour), + jwt.ExpirationKey: time.Now().Add(-1 * time.Hour), + }, + setupMocks: func(t *testing.T, f *serviceTestFixture) {}, + wantErr: "failed to parse token", + }, + { + name: "blacklisted_token", + claims: map[string]any{ + jwt.SubjectKey: "user-1", + "user_id": "user-1", + "session_id": "sess-1", + "token_type": types.JWTTokenTypeAccess.String(), + "actor_type": "user", + jwt.JwtIDKey: "jti-6", + jwt.IssuedAtKey: time.Now(), + jwt.ExpirationKey: time.Now().Add(15 * time.Minute), + }, + setupMocks: func(t *testing.T, f *serviceTestFixture) { + f.blacklistSvc.On("IsBlacklisted", mock.Anything, "jti-6").Return(true, nil) + }, + wantErr: "token has been revoked", + }, + { + name: "missing_user_id", + claims: map[string]any{ + jwt.SubjectKey: "user-1", + "session_id": "sess-1", + "token_type": types.JWTTokenTypeAccess.String(), + "actor_type": "user", + jwt.JwtIDKey: "jti-7", + jwt.IssuedAtKey: time.Now(), + jwt.ExpirationKey: time.Now().Add(15 * time.Minute), + }, + setupMocks: func(t *testing.T, f *serviceTestFixture) { + f.blacklistSvc.On("IsBlacklisted", mock.Anything, "jti-7").Return(false, nil) + }, + wantErr: "missing user_id claim", + }, + { + name: "missing_session", + claims: map[string]any{ + jwt.SubjectKey: "user-1", + "user_id": "user-1", + "session_id": "sess-1", + "token_type": types.JWTTokenTypeAccess.String(), + "actor_type": "user", + jwt.JwtIDKey: "jti-8", + jwt.IssuedAtKey: time.Now(), + jwt.ExpirationKey: time.Now().Add(15 * time.Minute), + }, + setupMocks: func(t *testing.T, f *serviceTestFixture) { + f.blacklistSvc.On("IsBlacklisted", mock.Anything, "jti-8").Return(false, nil) + f.blacklistSvc.On("IsBlacklisted", mock.Anything, "session:sess-1").Return(false, nil) + f.sessionSvc.On("GetByID", mock.Anything, "sess-1").Return(nil, nil) + }, + wantErr: "session not found or invalid", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := newServiceTestFixture(t) + svc := f.newJWTService() + f.setupKeyServiceMock() + tt.setupMocks(t, f) + + tokenStr := f.signTestToken(t, tt.claims) + actor, err := svc.ValidateToken(ctx, tokenStr) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + require.Nil(t, actor) + } else { + require.NoError(t, err) + require.NotNil(t, actor) + tt.wantActor(t, actor) + } + + f.blacklistSvc.AssertExpectations(t) + f.sessionSvc.AssertExpectations(t) + }) + } +} + +func TestTokenService_GenerateUserToken(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + userID string + sessionID string + setupMocks func(*serviceTestFixture) + wantErr string + }{ + { + name: "success", + userID: "user-1", + sessionID: "sess-1", + setupMocks: func(f *serviceTestFixture) { + f.keySvc.On("GetActiveKey", mock.Anything).Return(f.activeKey, nil) + f.coreTokenSvc.On("Decrypt", f.activeKey.PrivateKey).Return(f.activeKey.PrivateKey, nil) + }, + }, + { + name: "empty session id", + userID: "user-1", + sessionID: "", + setupMocks: func(f *serviceTestFixture) {}, + wantErr: "session id is required", + }, + { + name: "key service error", + userID: "user-1", + sessionID: "sess-1", + setupMocks: func(f *serviceTestFixture) { + f.keySvc.On("GetActiveKey", mock.Anything).Return(nil, assert.AnError) + }, + wantErr: "failed to get active key", + }, + { + name: "decrypt error", + userID: "user-1", + sessionID: "sess-1", + setupMocks: func(f *serviceTestFixture) { + f.keySvc.On("GetActiveKey", mock.Anything).Return(f.activeKey, nil) + f.coreTokenSvc.On("Decrypt", f.activeKey.PrivateKey).Return("", assert.AnError) + }, + wantErr: "failed to decrypt private key", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + f := newServiceTestFixture(t) + svc := f.newJWTService() + tt.setupMocks(f) + + ctx := context.Background() + pair, err := svc.GenerateUserToken(ctx, tt.userID, tt.sessionID) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + require.Nil(t, pair) + } else { + require.NoError(t, err) + require.NotNil(t, pair) + require.NotEmpty(t, pair.AccessToken) + require.NotEmpty(t, pair.RefreshToken) + require.Equal(t, 15*time.Minute, pair.ExpiresIn) + require.Equal(t, "Bearer", pair.TokenType) + + claims := parseTestToken(t, pair.AccessToken, f) + require.Equal(t, tt.userID, claims["sub"]) + require.Equal(t, tt.userID, claims["user_id"]) + require.Equal(t, tt.sessionID, claims["session_id"]) + require.Equal(t, "user", claims["actor_type"]) + require.Equal(t, "access_token", claims["token_type"]) + require.NotEmpty(t, claims["jti"]) + } + + f.keySvc.AssertExpectations(t) + f.coreTokenSvc.AssertExpectations(t) + }) + } +} + +func TestTokenService_GenerateMachineToken(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + clientID string + organizationID string + scopes []string + setupMocks func(*serviceTestFixture) + wantErr string + wantOrgID string + wantScopes []any + }{ + { + name: "success", + clientID: "client-1", + organizationID: "org-1", + scopes: []string{"read", "write"}, + setupMocks: func(f *serviceTestFixture) { + f.keySvc.On("GetActiveKey", mock.Anything).Return(f.activeKey, nil) + f.coreTokenSvc.On("Decrypt", f.activeKey.PrivateKey).Return(f.activeKey.PrivateKey, nil) + }, + wantOrgID: "org-1", + wantScopes: []any{"read", "write"}, + }, + { + name: "no optional fields", + clientID: "client-2", + organizationID: "", + scopes: nil, + setupMocks: func(f *serviceTestFixture) { + f.keySvc.On("GetActiveKey", mock.Anything).Return(f.activeKey, nil) + f.coreTokenSvc.On("Decrypt", f.activeKey.PrivateKey).Return(f.activeKey.PrivateKey, nil) + }, + }, + { + name: "key service error", + clientID: "client-1", + organizationID: "org-1", + scopes: []string{"read"}, + setupMocks: func(f *serviceTestFixture) { + f.keySvc.On("GetActiveKey", mock.Anything).Return(nil, assert.AnError) + }, + wantErr: "failed to get active key", + }, + { + name: "decrypt error", + clientID: "client-1", + organizationID: "org-1", + scopes: []string{"read"}, + setupMocks: func(f *serviceTestFixture) { + f.keySvc.On("GetActiveKey", mock.Anything).Return(f.activeKey, nil) + f.coreTokenSvc.On("Decrypt", f.activeKey.PrivateKey).Return("", assert.AnError) + }, + wantErr: "failed to decrypt private key", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + f := newServiceTestFixture(t) + svc := f.newJWTService() + tt.setupMocks(f) + + ctx := context.Background() + pair, err := svc.GenerateMachineToken(ctx, tt.clientID, tt.organizationID, tt.scopes) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + require.Nil(t, pair) + } else { + require.NoError(t, err) + require.NotNil(t, pair) + require.NotEmpty(t, pair.AccessToken) + require.Empty(t, pair.RefreshToken) + require.Equal(t, 15*time.Minute, pair.ExpiresIn) + require.Equal(t, "Bearer", pair.TokenType) + + claims := parseTestToken(t, pair.AccessToken, f) + require.Equal(t, tt.clientID, claims["sub"]) + require.Equal(t, "machine", claims["actor_type"]) + require.Equal(t, "access_token", claims["token_type"]) + require.NotEmpty(t, claims["jti"]) + + if tt.wantOrgID != "" { + require.Equal(t, tt.wantOrgID, claims["org_id"]) + } else { + require.Empty(t, claims["org_id"]) + } + + if tt.wantScopes != nil { + scopes, ok := claims["scopes"].([]any) + require.True(t, ok) + require.ElementsMatch(t, tt.wantScopes, scopes) + } else { + require.Nil(t, claims["scopes"]) + } + } + + f.keySvc.AssertExpectations(t) + f.coreTokenSvc.AssertExpectations(t) + }) + } +} diff --git a/plugins/jwt/tests/mocks.go b/plugins/jwt/tests/mocks.go new file mode 100644 index 00000000..b85ca98d --- /dev/null +++ b/plugins/jwt/tests/mocks.go @@ -0,0 +1,380 @@ +package tests + +import ( + "context" + "strings" + "sync" + "time" + + "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/stretchr/testify/mock" + + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/jwt/repositories" + jwtservicesTypes "github.com/Authula/authula/plugins/jwt/types" +) + +type InMemoryStorage struct { + mu sync.RWMutex + data map[string]string +} + +func NewInMemoryStorage() *InMemoryStorage { + return &InMemoryStorage{data: make(map[string]string)} +} + +func (s *InMemoryStorage) Get(_ context.Context, key string) (any, error) { + s.mu.RLock() + defer s.mu.RUnlock() + val, ok := s.data[key] + if !ok { + return nil, nil + } + return val, nil +} + +func (s *InMemoryStorage) Set(_ context.Context, key string, value any, _ *time.Duration) error { + s.mu.Lock() + defer s.mu.Unlock() + s.data[key] = value.(string) + return nil +} + +func (s *InMemoryStorage) Delete(_ context.Context, key string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.data, key) + return nil +} + +func (s *InMemoryStorage) Incr(_ context.Context, _ string, _ *time.Duration) (int, error) { + return 0, nil +} + +func (s *InMemoryStorage) TTL(_ context.Context, _ string) (*time.Duration, error) { + return nil, nil +} + +func (s *InMemoryStorage) Scan(_ context.Context, prefix string) ([]string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + var keys []string + for key := range s.data { + if strings.HasPrefix(key, prefix) { + keys = append(keys, key) + } + } + return keys, nil +} + +func (s *InMemoryStorage) Close() error { + return nil +} + +type NopTokenService struct{} + +func (NopTokenService) Generate() (string, error) { return "", nil } +func (NopTokenService) Hash(token string) string { return token } +func (NopTokenService) Encrypt(token string) (string, error) { return token, nil } +func (NopTokenService) Decrypt(encrypted string) (string, error) { return encrypted, nil } + +var _ repositories.JWKSRepository = (*MockJWKSRepository)(nil) + +type MockTokenService struct{ mock.Mock } + +func (m *MockTokenService) GenerateUserToken(ctx context.Context, userID string, sessionID string) (*jwtservicesTypes.TokenPair, error) { + args := m.Called(ctx, userID, sessionID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*jwtservicesTypes.TokenPair), args.Error(1) +} + +func (m *MockTokenService) GenerateMachineToken(ctx context.Context, clientID string, organizationID string, scopes []string) (*jwtservicesTypes.TokenPair, error) { + args := m.Called(ctx, clientID, organizationID, scopes) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*jwtservicesTypes.TokenPair), args.Error(1) +} + +func (m *MockTokenService) ValidateToken(ctx context.Context, token string) (*models.Actor, error) { + args := m.Called(ctx, token) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*models.Actor), args.Error(1) +} + +type MockKeyService struct{ mock.Mock } + +func (m *MockKeyService) GenerateKeysIfMissing(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func (m *MockKeyService) GetActiveKey(ctx context.Context) (*jwtservicesTypes.JWKS, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*jwtservicesTypes.JWKS), args.Error(1) +} + +func (m *MockKeyService) IsKeyRotationDue(ctx context.Context, rotationInterval time.Duration) bool { + args := m.Called(ctx, rotationInterval) + return args.Bool(0) +} + +func (m *MockKeyService) RotateKeysIfNeeded(ctx context.Context, rotationInterval time.Duration, gracePeriod time.Duration, invalidateCacheFunc func(context.Context) error) (bool, error) { + args := m.Called(ctx, rotationInterval, gracePeriod, invalidateCacheFunc) + return args.Bool(0), args.Error(1) +} + +type MockCacheService struct{ mock.Mock } + +func (m *MockCacheService) GetCachedJWKS(ctx context.Context) (jwk.Set, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(jwk.Set), args.Error(1) +} + +func (m *MockCacheService) FetchJWKSFromDatabase(ctx context.Context) (jwk.Set, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(jwk.Set), args.Error(1) +} + +func (m *MockCacheService) CacheJWKS(ctx context.Context, set jwk.Set) error { + args := m.Called(ctx, set) + return args.Error(0) +} + +func (m *MockCacheService) InvalidateCache(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func (m *MockCacheService) GetJWKSWithFallback(ctx context.Context) (jwk.Set, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(jwk.Set), args.Error(1) +} + +type MockBlacklistService struct{ mock.Mock } + +func (m *MockBlacklistService) BlacklistToken(ctx context.Context, jti string, expiresAt time.Time) error { + args := m.Called(ctx, jti, expiresAt) + return args.Error(0) +} + +func (m *MockBlacklistService) IsBlacklisted(ctx context.Context, jti string) (bool, error) { + args := m.Called(ctx, jti) + return args.Bool(0), args.Error(1) +} + +func (m *MockBlacklistService) BlacklistAllSessionTokens(ctx context.Context, sessionID string, expiresAt time.Time) error { + args := m.Called(ctx, sessionID, expiresAt) + return args.Error(0) +} + +func (m *MockBlacklistService) CleanupExpired(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +type MockRefreshTokenRepository struct{ mock.Mock } + +func (m *MockRefreshTokenRepository) StoreRefreshToken(ctx context.Context, record *jwtservicesTypes.RefreshToken) error { + args := m.Called(ctx, record) + return args.Error(0) +} + +func (m *MockRefreshTokenRepository) GetRefreshToken(ctx context.Context, tokenHash string) (*jwtservicesTypes.RefreshToken, error) { + args := m.Called(ctx, tokenHash) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*jwtservicesTypes.RefreshToken), args.Error(1) +} + +func (m *MockRefreshTokenRepository) RevokeRefreshToken(ctx context.Context, tokenHash string) error { + args := m.Called(ctx, tokenHash) + return args.Error(0) +} + +func (m *MockRefreshTokenRepository) RevokeAllSessionTokens(ctx context.Context, sessionID string) error { + args := m.Called(ctx, sessionID) + return args.Error(0) +} + +func (m *MockRefreshTokenRepository) SetLastReuseAttempt(ctx context.Context, tokenHash string) error { + args := m.Called(ctx, tokenHash) + return args.Error(0) +} + +func (m *MockRefreshTokenRepository) CleanupExpiredTokens(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +type MockJWKSRepository struct{ mock.Mock } + +func (m *MockJWKSRepository) GetJWKSKeys(ctx context.Context) ([]*jwtservicesTypes.JWKS, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]*jwtservicesTypes.JWKS), args.Error(1) +} + +func (m *MockJWKSRepository) GetJWKSKeyByID(ctx context.Context, id string) (*jwtservicesTypes.JWKS, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*jwtservicesTypes.JWKS), args.Error(1) +} + +func (m *MockJWKSRepository) StoreJWKSKey(ctx context.Context, key *jwtservicesTypes.JWKS) error { + args := m.Called(ctx, key) + return args.Error(0) +} + +func (m *MockJWKSRepository) UpdateJWKSKey(ctx context.Context, key *jwtservicesTypes.JWKS) error { + args := m.Called(ctx, key) + return args.Error(0) +} + +func (m *MockJWKSRepository) MarkKeyExpired(ctx context.Context, id string, expiresAt time.Time) error { + args := m.Called(ctx, id, expiresAt) + return args.Error(0) +} + +func (m *MockJWKSRepository) PurgeExpiredKeys(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +type MockRefreshTokenService struct{ mock.Mock } + +func (m *MockRefreshTokenService) RefreshTokens(ctx context.Context, refreshToken string) (*jwtservicesTypes.RefreshTokenResponse, error) { + args := m.Called(ctx, refreshToken) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*jwtservicesTypes.RefreshTokenResponse), args.Error(1) +} + +func (m *MockRefreshTokenService) StoreInitialRefreshToken(ctx context.Context, refreshToken string, sessionID string, expiresAt time.Time) error { + args := m.Called(ctx, refreshToken, sessionID, expiresAt) + return args.Error(0) +} + +type MockJWTService struct{ mock.Mock } + +func (m *MockJWTService) ValidateToken(ctx context.Context, token string) (*models.Actor, error) { + args := m.Called(ctx, token) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*models.Actor), args.Error(1) +} + +type MockSessionService struct{ mock.Mock } + +func (m *MockSessionService) GetByID(ctx context.Context, id string) (*models.Session, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*models.Session), args.Error(1) +} + +func (m *MockSessionService) Create(ctx context.Context, userID string, hashedToken string, ipAddress *string, userAgent *string, maxAge time.Duration) (*models.Session, error) { + args := m.Called(ctx, userID, hashedToken, ipAddress, userAgent, maxAge) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*models.Session), args.Error(1) +} + +func (m *MockSessionService) GetByUserID(ctx context.Context, userID string) (*models.Session, error) { + args := m.Called(ctx, userID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*models.Session), args.Error(1) +} + +func (m *MockSessionService) GetByToken(ctx context.Context, hashedToken string) (*models.Session, error) { + args := m.Called(ctx, hashedToken) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*models.Session), args.Error(1) +} + +func (m *MockSessionService) Update(ctx context.Context, session *models.Session) (*models.Session, error) { + args := m.Called(ctx, session) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*models.Session), args.Error(1) +} + +func (m *MockSessionService) Delete(ctx context.Context, ID string) error { + args := m.Called(ctx, ID) + return args.Error(0) +} + +func (m *MockSessionService) DeleteAllByUserID(ctx context.Context, userID string) error { + args := m.Called(ctx, userID) + return args.Error(0) +} + +func (m *MockSessionService) DeleteAllExpired(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func (m *MockSessionService) GetDistinctUserIDs(ctx context.Context) ([]string, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]string), args.Error(1) +} + +func (m *MockSessionService) DeleteOldestByUserID(ctx context.Context, userID string, maxCount int) error { + args := m.Called(ctx, userID, maxCount) + return args.Error(0) +} + +type MockTokenServiceCore struct{ mock.Mock } + +func (m *MockTokenServiceCore) Generate() (string, error) { + args := m.Called() + return args.String(0), args.Error(1) +} + +func (m *MockTokenServiceCore) Hash(token string) string { + args := m.Called(token) + return args.String(0) +} + +func (m *MockTokenServiceCore) Encrypt(token string) (string, error) { + args := m.Called(token) + return args.String(0), args.Error(1) +} + +func (m *MockTokenServiceCore) Decrypt(encrypted string) (string, error) { + args := m.Called(encrypted) + return args.String(0), args.Error(1) +} diff --git a/plugins/jwt/tests/mocks_test.go b/plugins/jwt/tests/mocks_test.go new file mode 100644 index 00000000..2b682523 --- /dev/null +++ b/plugins/jwt/tests/mocks_test.go @@ -0,0 +1,17 @@ +package tests + +import ( + jwtservices "github.com/Authula/authula/plugins/jwt/services" + "github.com/Authula/authula/services" +) + +var _ jwtservices.TokenService = (*MockTokenService)(nil) +var _ jwtservices.KeyService = (*MockKeyService)(nil) +var _ jwtservices.CacheService = (*MockCacheService)(nil) +var _ jwtservices.BlacklistService = (*MockBlacklistService)(nil) +var _ jwtservices.RefreshTokenRepository = (*MockRefreshTokenRepository)(nil) +var _ jwtservices.RefreshTokenService = (*MockRefreshTokenService)(nil) +var _ services.JWTService = (*MockJWTService)(nil) +var _ services.JWTService = (*MockTokenService)(nil) +var _ services.SessionService = (*MockSessionService)(nil) +var _ services.TokenService = (*MockTokenServiceCore)(nil) diff --git a/plugins/jwt/types/api.go b/plugins/jwt/types/api.go index b83dc6e3..7f24a8be 100644 --- a/plugins/jwt/types/api.go +++ b/plugins/jwt/types/api.go @@ -1,39 +1,11 @@ package types import ( - "errors" "time" "github.com/Authula/authula/plugins/jwt/constants" ) -type JWTAlgorithm string - -const ( - JWTAlgEdDSA JWTAlgorithm = "eddsa" - JWTAlgRS256 JWTAlgorithm = "rs256" - JWTAlgPS256 JWTAlgorithm = "ps256" - JWTAlgES256 JWTAlgorithm = "es256" - JWTAlgES512 JWTAlgorithm = "es512" - JWTAlgECDHES JWTAlgorithm = "ecdh-es" -) - -func (a JWTAlgorithm) String() string { - return string(a) -} - -// ValidateAlgorithm enforces that the algorithm can be used for JWT signing -func ValidateAlgorithm(alg JWTAlgorithm) error { - switch alg { - case JWTAlgEdDSA, JWTAlgRS256, JWTAlgPS256, JWTAlgES256, JWTAlgES512: - return nil - case JWTAlgECDHES: - return errors.New("ECDH-ES cannot be used for JWT signing") - default: - return errors.New("unsupported JWT algorithm") - } -} - type JWTTokenType string const ( @@ -45,41 +17,19 @@ func (t JWTTokenType) String() string { return string(t) } -// ParseAlgorithm parses a string into an Algorithm, accepting only canonical names (case-insensitive input) -func ParseAlgorithm(s string) (JWTAlgorithm, error) { - switch s { - case "eddsa": - return JWTAlgEdDSA, nil - case "rs256": - return JWTAlgRS256, nil - case "ps256": - return JWTAlgPS256, nil - case "es256": - return JWTAlgES256, nil - case "es512": - return JWTAlgES512, nil - case "ecdh-es": - return JWTAlgECDHES, nil - default: - return "", errors.New("unsupported jwt algorithm") - } -} - -// Claims represents standard JWT claims -type Claims struct { - UserID string `json:"user_id"` - SessionID string `json:"sid"` - Type string `json:"type"` // "access_token" or "refresh_token" - Sub string `json:"sub"` - Iss string `json:"iss"` - Aud string `json:"aud"` - Exp int64 `json:"exp"` - Iat int64 `json:"iat"` - Nbf int64 `json:"nbf,omitempty"` - Jti string `json:"jti"` +type TokenClaims struct { + Subject string `json:"sub"` + UserID string `json:"user_id,omitempty"` + SessionID string `json:"session_id,omitempty"` + TokenType string `json:"token_type"` + ActorType string `json:"actor_type,omitempty"` + OrganizationID string `json:"org_id,omitempty"` + Scopes []string `json:"scopes,omitempty"` + JTI string `json:"jti"` + IssuedAt int64 `json:"iat"` + Expiration int64 `json:"exp"` } -// TokenPair holds both access and refresh tokens type TokenPair struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` diff --git a/plugins/jwt/types/config.go b/plugins/jwt/types/config.go index 60f0f384..4903c616 100644 --- a/plugins/jwt/types/config.go +++ b/plugins/jwt/types/config.go @@ -4,10 +4,8 @@ import ( "time" ) -// JWTPluginConfig configures the JWKS-based JWT plugin type JWTPluginConfig struct { Enabled bool `json:"enabled" toml:"enabled"` - Algorithm JWTAlgorithm `json:"algorithm" toml:"algorithm"` // EdDSA (default), RS256, PS256, ES256, ES512 KeyRotationInterval time.Duration `json:"key_rotation_interval" toml:"key_rotation_interval"` // Default: 30 days KeyRotationGracePeriod time.Duration `json:"key_rotation_grace_period" toml:"key_rotation_grace_period"` // Grace period for old key validity after rotation, default: 1 hour ExpiresIn time.Duration `json:"expires_in" toml:"expires_in"` // Access token TTL @@ -16,11 +14,7 @@ type JWTPluginConfig struct { RefreshGracePeriod time.Duration `json:"refresh_grace_period" toml:"refresh_grace_period"` // Grace period for refresh token reuse, default 10s } -// ApplyDefaults returns sensible defaults for the JWT plugin func (c *JWTPluginConfig) ApplyDefaults() { - if c.Algorithm == "" { - c.Algorithm = JWTAlgEdDSA - } if c.KeyRotationInterval == 0 { c.KeyRotationInterval = 30 * 24 * time.Hour } @@ -40,21 +34,3 @@ func (c *JWTPluginConfig) ApplyDefaults() { c.RefreshGracePeriod = 10 * time.Second } } - -// NormalizeAlgorithm normalizes and validates the algorithm string. Use when -// parsing config or on update to catch legacy or unsupported values. -func (c *JWTPluginConfig) NormalizeAlgorithm() error { - if c.Algorithm == "" { - c.Algorithm = JWTAlgEdDSA - return nil - } - parsed, err := ParseAlgorithm(string(c.Algorithm)) - if err != nil { - return err - } - if err := ValidateAlgorithm(parsed); err != nil { - return err - } - c.Algorithm = parsed - return nil -} diff --git a/plugins/jwt/types/models.go b/plugins/jwt/types/models.go index 5f3f6915..e29fe759 100644 --- a/plugins/jwt/types/models.go +++ b/plugins/jwt/types/models.go @@ -6,7 +6,6 @@ import ( "github.com/uptrace/bun" ) -// JWKS represents a cryptographic key pair for signing and verification type JWKS struct { bun.BaseModel `bun:"table:jwks"` @@ -17,7 +16,6 @@ type JWKS struct { CreatedAt time.Time `json:"created_at" bun:"column:created_at,default:current_timestamp"` } -// RefreshToken represents a stored refresh token in the database type RefreshToken struct { bun.BaseModel `bun:"table:refresh_tokens"` diff --git a/plugins/jwt/usecases/jwks_usecase_test.go b/plugins/jwt/usecases/jwks_usecase_test.go new file mode 100644 index 00000000..94d11e94 --- /dev/null +++ b/plugins/jwt/usecases/jwks_usecase_test.go @@ -0,0 +1,56 @@ +package usecases + +import ( + "context" + "errors" + "testing" + + "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestJWKSUseCase(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(*useCaseTestFixture) + assert func(*testing.T, *JWKSResult, error) + }{ + { + name: "success", + setup: func(f *useCaseTestFixture) { + set := jwk.NewSet() + f.cacheSvc.On("GetJWKSWithFallback", mock.Anything).Return(set, nil).Once() + }, + assert: func(t *testing.T, result *JWKSResult, err error) { + require.NoError(t, err) + require.NotNil(t, result.KeySet) + }, + }, + { + name: "service_error", + setup: func(f *useCaseTestFixture) { + f.cacheSvc.On("GetJWKSWithFallback", mock.Anything).Return(jwk.Set(nil), errors.New("cache error")).Once() + }, + assert: func(t *testing.T, result *JWKSResult, err error) { + require.Error(t, err) + require.Nil(t, result) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + f := newUseCaseTestFixture() + if tc.setup != nil { + tc.setup(f) + } + result, err := f.newJWKSUseCase().GetJWKS(context.Background()) + tc.assert(t, result, err) + f.cacheSvc.AssertExpectations(t) + }) + } +} diff --git a/plugins/jwt/usecases/refresh_token_usecase_test.go b/plugins/jwt/usecases/refresh_token_usecase_test.go new file mode 100644 index 00000000..80690433 --- /dev/null +++ b/plugins/jwt/usecases/refresh_token_usecase_test.go @@ -0,0 +1,63 @@ +package usecases + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/Authula/authula/plugins/jwt/types" +) + +func TestRefreshTokenUseCase(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + token string + setup func(*useCaseTestFixture) + assert func(*testing.T, *RefreshTokenResult, error) + }{ + { + name: "success", + token: "valid-token", + setup: func(f *useCaseTestFixture) { + f.refreshTokenSvc.On("RefreshTokens", mock.Anything, "valid-token").Return(&types.RefreshTokenResponse{ + AccessToken: "new-access", + RefreshToken: "new-refresh", + }, nil).Once() + }, + assert: func(t *testing.T, result *RefreshTokenResult, err error) { + require.NoError(t, err) + require.Equal(t, "new-access", result.AccessToken) + require.Equal(t, "new-refresh", result.RefreshToken) + }, + }, + { + name: "service_error", + token: "bad-token", + setup: func(f *useCaseTestFixture) { + f.refreshTokenSvc.On("RefreshTokens", mock.Anything, "bad-token").Return((*types.RefreshTokenResponse)(nil), errors.New("invalid refresh token")).Once() + }, + assert: func(t *testing.T, result *RefreshTokenResult, err error) { + require.Error(t, err) + require.Nil(t, result) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + f := newUseCaseTestFixture() + if tc.setup != nil { + tc.setup(f) + } + result, err := f.newRefreshTokenUseCase().RefreshTokens(context.Background(), tc.token) + tc.assert(t, result, err) + f.refreshTokenSvc.AssertExpectations(t) + }) + } +} diff --git a/plugins/jwt/usecases/usecases_test.go b/plugins/jwt/usecases/usecases_test.go new file mode 100644 index 00000000..43402b2c --- /dev/null +++ b/plugins/jwt/usecases/usecases_test.go @@ -0,0 +1,28 @@ +package usecases + +import ( + internaltests "github.com/Authula/authula/internal/tests" + jwttests "github.com/Authula/authula/plugins/jwt/tests" +) + +type useCaseTestFixture struct { + logger *internaltests.MockLogger + refreshTokenSvc *jwttests.MockRefreshTokenService + cacheSvc *jwttests.MockCacheService +} + +func newUseCaseTestFixture() *useCaseTestFixture { + return &useCaseTestFixture{ + logger: &internaltests.MockLogger{}, + refreshTokenSvc: &jwttests.MockRefreshTokenService{}, + cacheSvc: &jwttests.MockCacheService{}, + } +} + +func (f *useCaseTestFixture) newRefreshTokenUseCase() RefreshTokenUseCase { + return NewRefreshTokenUseCase(f.logger, f.refreshTokenSvc) +} + +func (f *useCaseTestFixture) newJWKSUseCase() JWKSUseCase { + return NewJWKSUseCase(f.logger, f.cacheSvc) +} diff --git a/plugins/rate-limit/services/secondary_storage_test.go b/plugins/rate-limit/services/secondary_storage_test.go index 89b009ed..7686f941 100644 --- a/plugins/rate-limit/services/secondary_storage_test.go +++ b/plugins/rate-limit/services/secondary_storage_test.go @@ -2,23 +2,24 @@ package services import ( "context" + "strings" "testing" "time" ) -type fakeSecondaryStorage struct { +type dummySecondaryStorage struct { values map[string]any ttls map[string]time.Duration } -func newFakeSecondaryStorage() *fakeSecondaryStorage { - return &fakeSecondaryStorage{values: map[string]any{}, ttls: map[string]time.Duration{}} +func newDummySecondaryStorage() *dummySecondaryStorage { + return &dummySecondaryStorage{values: map[string]any{}, ttls: map[string]time.Duration{}} } -func (s *fakeSecondaryStorage) Get(_ context.Context, key string) (any, error) { +func (s *dummySecondaryStorage) Get(_ context.Context, key string) (any, error) { return s.values[key], nil } -func (s *fakeSecondaryStorage) Set(_ context.Context, key string, value any, ttl *time.Duration) error { +func (s *dummySecondaryStorage) Set(_ context.Context, key string, value any, ttl *time.Duration) error { s.values[key] = value if ttl != nil { s.ttls[key] = *ttl @@ -27,12 +28,12 @@ func (s *fakeSecondaryStorage) Set(_ context.Context, key string, value any, ttl } return nil } -func (s *fakeSecondaryStorage) Delete(_ context.Context, key string) error { +func (s *dummySecondaryStorage) Delete(_ context.Context, key string) error { delete(s.values, key) delete(s.ttls, key) return nil } -func (s *fakeSecondaryStorage) Incr(_ context.Context, key string, ttl *time.Duration) (int, error) { +func (s *dummySecondaryStorage) Incr(_ context.Context, key string, ttl *time.Duration) (int, error) { if s.values[key] == nil { s.values[key] = 1 if ttl != nil { @@ -44,18 +45,28 @@ func (s *fakeSecondaryStorage) Incr(_ context.Context, key string, ttl *time.Dur s.values[key] = count return count, nil } -func (s *fakeSecondaryStorage) TTL(_ context.Context, key string) (*time.Duration, error) { +func (s *dummySecondaryStorage) TTL(_ context.Context, key string) (*time.Duration, error) { if ttl, ok := s.ttls[key]; ok { return &ttl, nil } return nil, nil } -func (s *fakeSecondaryStorage) Close() error { return nil } +func (s *dummySecondaryStorage) Scan(_ context.Context, prefix string) ([]string, error) { + var keys []string + for key := range s.values { + if strings.HasPrefix(key, prefix) { + keys = append(keys, key) + } + } + return keys, nil +} + +func (s *dummySecondaryStorage) Close() error { return nil } func TestSecondaryStorageProviderRuleLifecycle(t *testing.T) { t.Parallel() - storage := newFakeSecondaryStorage() + storage := newDummySecondaryStorage() provider := NewSecondaryStorageProvider("custom", storage) ctx := context.Background() window := 4 * time.Minute diff --git a/plugins/secondary-storage/database_secondary_storage.go b/plugins/secondary-storage/database_secondary_storage.go index c4cc8981..efebe875 100644 --- a/plugins/secondary-storage/database_secondary_storage.go +++ b/plugins/secondary-storage/database_secondary_storage.go @@ -289,6 +289,30 @@ func (storage *DatabaseSecondaryStorage) TTL(ctx context.Context, key string) (* return &ttl, nil } +// Scan returns all non-expired keys matching the given prefix. +func (storage *DatabaseSecondaryStorage) Scan(ctx context.Context, prefix string) ([]string, error) { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("context canceled: %w", ctx.Err()) + default: + } + + var entries []KeyValueStore + err := storage.db.NewSelect().Model(&entries). + Where("key LIKE ?", prefix+"%"). + Where("expires_at IS NULL OR expires_at > ?", time.Now()). + Scan(ctx) + if err != nil { + return nil, fmt.Errorf("database scan error: %w", err) + } + + keys := make([]string, len(entries)) + for i, entry := range entries { + keys[i] = entry.Key + } + return keys, nil +} + // Close gracefully shuts down the storage by stopping the cleanup goroutine. // This method is idempotent and safe to call multiple times. func (storage *DatabaseSecondaryStorage) Close() error { diff --git a/plugins/secondary-storage/memory_secondary_storage.go b/plugins/secondary-storage/memory_secondary_storage.go index 4fcd885e..705c13ff 100644 --- a/plugins/secondary-storage/memory_secondary_storage.go +++ b/plugins/secondary-storage/memory_secondary_storage.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "strconv" + "strings" "sync" "time" ) @@ -234,6 +235,31 @@ func (storage *MemorySecondaryStorage) TTL(ctx context.Context, key string) (*ti return &ttl, nil } +// Scan returns all non-expired keys matching the given prefix. +func (storage *MemorySecondaryStorage) Scan(ctx context.Context, prefix string) ([]string, error) { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("context canceled: %w", ctx.Err()) + default: + } + + storage.mu.RLock() + defer storage.mu.RUnlock() + + now := time.Now() + var keys []string + for key, entry := range storage.store { + if !strings.HasPrefix(key, prefix) { + continue + } + if entry.expiresAt != nil && now.After(*entry.expiresAt) { + continue + } + keys = append(keys, key) + } + return keys, nil +} + // Close gracefully shuts down the storage by stopping the cleanup goroutine. // This method is idempotent and safe to call multiple times. func (storage *MemorySecondaryStorage) Close() error { diff --git a/plugins/secondary-storage/redis_secondary_storage.go b/plugins/secondary-storage/redis_secondary_storage.go index c93fef66..c96917cb 100644 --- a/plugins/secondary-storage/redis_secondary_storage.go +++ b/plugins/secondary-storage/redis_secondary_storage.go @@ -150,6 +150,24 @@ func (rs *RedisSecondaryStorage) TTL(ctx context.Context, key string) (*time.Dur return &ttl, nil } +// Scan returns all keys matching the given prefix using Redis SCAN with MATCH. +func (rs *RedisSecondaryStorage) Scan(ctx context.Context, prefix string) ([]string, error) { + var keys []string + var cursor uint64 + for { + scanned, next, err := rs.client.Scan(ctx, cursor, prefix+"*", 100).Result() + if err != nil { + return nil, fmt.Errorf("redis scan error: %w", err) + } + keys = append(keys, scanned...) + cursor = next + if cursor == 0 { + break + } + } + return keys, nil +} + // Close closes the Redis connection func (rs *RedisSecondaryStorage) Close() error { if rs.client != nil { diff --git a/services/jwt.go b/services/jwt.go index 75262682..bf835db4 100644 --- a/services/jwt.go +++ b/services/jwt.go @@ -1,5 +1,11 @@ package services +import ( + "context" + + "github.com/Authula/authula/models" +) + type JWTService interface { - ValidateToken(token string) (userID string, err error) + ValidateToken(ctx context.Context, token string) (*models.Actor, error) }