diff options
| author | James Elliott <james-d-elliott@users.noreply.github.com> | 2025-03-09 01:53:44 +1100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-03-09 01:53:44 +1100 |
| commit | 9241731a4dd5592b4a02b5352c903b4d06b6f4ab (patch) | |
| tree | 5184b98751912a261ff70fd8721b9cd4f1c98f1e | |
| parent | bbcb38ab9ff35e69d5d52a71ab56346749f5e8b1 (diff) | |
feat(embed): make authelia embedable (#8841)
This adds a highly experimental option for developers looking to embed Authelia within another go binary.
Closes #5803
Signed-off-by: James Elliott <james-d-elliott@users.noreply.github.com>
49 files changed, 1963 insertions, 823 deletions
diff --git a/docs/content/contributing/guidelines/commit-message.md b/docs/content/contributing/guidelines/commit-message.md index 4a8ba796f..99d109cbc 100644 --- a/docs/content/contributing/guidelines/commit-message.md +++ b/docs/content/contributing/guidelines/commit-message.md @@ -54,7 +54,7 @@ for, and the structure it must have. │ cmd|codecov|commands|configuration|deps|docker|duo|expression|go| │ golangci-lint|handlers|husky|logging|metrics|middlewares|mocks|model| │ notification|npm|ntp|oidc|random|regulation|renovate|reviewdog|server| - │ session|storage|suites|templates|totp|utils|web|webauthn + │ service|session|storage|suites|templates|totp|utils|web|webauthn │ └─⫸ Commit Type: build|ci|docs|feat|fix|i18n|perf|refactor|release|revert|test ``` @@ -100,6 +100,7 @@ commit messages). * random * regulation * server +* service * session * storage * suites diff --git a/experimental/embed/config.go b/experimental/embed/config.go new file mode 100644 index 000000000..81d21463b --- /dev/null +++ b/experimental/embed/config.go @@ -0,0 +1,71 @@ +package embed + +import ( + "fmt" + + "github.com/authelia/authelia/v4/internal/configuration" + "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/configuration/validator" +) + +// NewConfiguration builds a new configuration given a list of paths and filters. The filters can either be nil or +// generated using NewNamedConfigFileFilters. This function essentially operates the same as Authelia does normally in +// configuration steps. +func NewConfiguration(paths []string, filters []configuration.BytesFilter) (keys []string, config *schema.Configuration, val *schema.StructValidator, err error) { + sources := configuration.NewDefaultSourcesWithDefaults( + paths, + filters, + configuration.DefaultEnvPrefix, + configuration.DefaultEnvDelimiter, + []configuration.Source{configuration.NewMapSource(configuration.Defaults())}) + + val = schema.NewStructValidator() + + var definitions *schema.Definitions + + if definitions, err = configuration.LoadDefinitions(val, sources...); err != nil { + return nil, nil, nil, err + } + + config = &schema.Configuration{} + + if keys, err = configuration.LoadAdvanced( + val, + "", + config, + definitions, + sources...); err != nil { + return nil, nil, nil, err + } + + return keys, config, val, nil +} + +// ValidateConfigurationAndKeys performs all configuration validation steps. The provided *schema.StructValidator should +// at minimum be checked for errors before continuing. +func ValidateConfigurationAndKeys(config *schema.Configuration, keys []string, val *schema.StructValidator) { + ValidateConfigurationKeys(keys, val) + ValidateConfiguration(config, val) +} + +// ValidateConfigurationKeys just the keys validation steps. The provided *schema.StructValidator should +// at minimum be checked for errors before continuing. This should be used prior to using ValidateConfiguration. +func ValidateConfigurationKeys(keys []string, val *schema.StructValidator) { + validator.ValidateKeys(keys, configuration.GetMultiKeyMappedDeprecationKeys(), configuration.DefaultEnvPrefix, val) +} + +// ValidateConfiguration just the configuration validation steps. The provided *schema.StructValidator should +// at minimum be checked for errors before continuing. This should be used after using ValidateConfigurationKeys. +func ValidateConfiguration(config *schema.Configuration, val *schema.StructValidator) { + validator.ValidateConfiguration(config, val) +} + +// NewNamedConfigFileFilters allows configuring a set of file filters. The officially supported filter has the name +// 'template'. The only other one at this stage is 'expand-env' which is deprecated. +func NewNamedConfigFileFilters(names ...string) (filters []configuration.BytesFilter, err error) { + if filters, err = configuration.NewFileFilters(names); err != nil { + return nil, fmt.Errorf("error occurred loading filters: %w", err) + } + + return filters, nil +} diff --git a/experimental/embed/config_test.go b/experimental/embed/config_test.go new file mode 100644 index 000000000..eb0ce6e2f --- /dev/null +++ b/experimental/embed/config_test.go @@ -0,0 +1,195 @@ +package embed + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/authelia/authelia/v4/internal/configuration" +) + +func TestNewConfiguration(t *testing.T) { + testCases := []struct { + name string + paths []string + filters []configuration.BytesFilter + keys []string + warnings []string + errors []string + err string + }{ + { + name: "ShouldHandleWebAuthn", + paths: []string{"../../internal/configuration/test_resources/config.webauthn.yml"}, + keys: []string{ + "server.endpoints.rate_limits.reset_password_finish.enable", + "server.endpoints.rate_limits.reset_password_start.enable", + "server.endpoints.rate_limits.second_factor_duo.enable", + "server.endpoints.rate_limits.second_factor_totp.enable", + "server.endpoints.rate_limits.session_elevation_finish.enable", + "server.endpoints.rate_limits.session_elevation_start.enable", + "webauthn.selection_criteria.attachment", + "webauthn.selection_criteria.discoverability", + "webauthn.selection_criteria.user_verification", + }, + warnings: nil, + errors: []string{ + "identity_validation: reset_password: option 'jwt_secret' is required when the reset password functionality isn't disabled", + "authentication_backend: you must ensure either the 'file' or 'ldap' authentication backend is configured", + "access_control: 'default_policy' option 'deny' is invalid: when no rules are specified it must be 'two_factor' or 'one_factor'", + "session: option 'cookies' is required", + "storage: option 'encryption_key' is required", + "storage: configuration for a 'local', 'mysql' or 'postgres' database must be provided", + "notifier: you must ensure either the 'smtp' or 'filesystem' notifier is configured", + }, + }, + { + name: "ShouldHandleConfigWithDefinitions", + paths: []string{"../../internal/configuration/test_resources/config_with_definitions.yml"}, + keys: []string{ + "access_control.default_policy", + "access_control.networks", + "access_control.networks[].name", + "access_control.networks[].networks", + "access_control.rules", + "access_control.rules[].domain", + "access_control.rules[].networks", + "access_control.rules[].policy", + "access_control.rules[].resources", + "access_control.rules[].subject", + "authentication_backend.ldap.additional_groups_dn", + "authentication_backend.ldap.additional_users_dn", + "authentication_backend.ldap.address", + "authentication_backend.ldap.attributes.group_name", + "authentication_backend.ldap.attributes.mail", + "authentication_backend.ldap.attributes.username", + "authentication_backend.ldap.base_dn", + "authentication_backend.ldap.groups_filter", + "authentication_backend.ldap.tls.private_key", + "authentication_backend.ldap.user", + "authentication_backend.ldap.users_filter", + "authentication_backend.refresh_interval", + "definitions.network.lan", + "definitions.user_attributes.example.expression", + "duo_api.hostname", + "duo_api.integration_key", + "log.level", + "notifier.smtp.address", + "notifier.smtp.disable_require_tls", + "notifier.smtp.sender", + "notifier.smtp.username", + "regulation.ban_time", + "regulation.find_time", + "regulation.max_retries", + "server.address", + "server.endpoints.authz.auth-request.authn_strategies", + "server.endpoints.authz.auth-request.authn_strategies[].name", + "server.endpoints.authz.auth-request.implementation", + "server.endpoints.authz.ext-authz.authn_strategies", + "server.endpoints.authz.ext-authz.authn_strategies[].name", + "server.endpoints.authz.ext-authz.implementation", + "server.endpoints.authz.forward-auth.authn_strategies", + "server.endpoints.authz.forward-auth.authn_strategies[].name", + "server.endpoints.authz.forward-auth.implementation", + "server.endpoints.authz.legacy.implementation", + "server.endpoints.rate_limits.reset_password_finish.enable", + "server.endpoints.rate_limits.reset_password_start.enable", + "server.endpoints.rate_limits.second_factor_duo.enable", + "server.endpoints.rate_limits.second_factor_totp.enable", + "server.endpoints.rate_limits.session_elevation_finish.enable", + "server.endpoints.rate_limits.session_elevation_start.enable", + "session.cookies", + "session.cookies[].authelia_url", + "session.cookies[].default_redirection_url", + "session.cookies[].domain", + "session.expiration", + "session.inactivity", + "session.name", + "session.redis.high_availability.sentinel_name", + "session.redis.host", + "session.redis.port", + "storage.mysql.address", + "storage.mysql.database", + "storage.mysql.username", + "totp.issuer", + "webauthn.selection_criteria.discoverability", + "webauthn.selection_criteria.user_verification", + }, + warnings: nil, + errors: []string{ + "duo_api: option 'secret_key' is required when duo is enabled but it's absent", + "identity_validation: reset_password: option 'jwt_secret' is required when the reset password functionality isn't disabled", + "authentication_backend: ldap: option 'password' is required", + "session: option 'secret' is required when using the 'redis' provider", + "storage: option 'encryption_key' is required", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Run("Individual", func(t *testing.T) { + keys, config, val, err := NewConfiguration(tc.paths, tc.filters) + + assert.Equal(t, tc.keys, keys) + + if tc.err != "" { + assert.EqualError(t, err, tc.err) + assert.Nil(t, config) + } else { + assert.NoError(t, err) + assert.NotNil(t, config) + ValidateConfigurationKeys(keys, val) + ValidateConfiguration(config, val) + } + + require.Len(t, val.Warnings(), len(tc.warnings)) + require.Len(t, val.Errors(), len(tc.errors)) + + for i, err := range val.Warnings() { + assert.EqualError(t, err, tc.warnings[i]) + } + + for i, err := range val.Errors() { + assert.EqualError(t, err, tc.errors[i]) + } + }) + t.Run("Combined", func(t *testing.T) { + keys, config, val, err := NewConfiguration(tc.paths, tc.filters) + + assert.Equal(t, tc.keys, keys) + + if tc.err != "" { + assert.EqualError(t, err, tc.err) + assert.Nil(t, config) + } else { + assert.NoError(t, err) + assert.NotNil(t, config) + ValidateConfigurationAndKeys(config, keys, val) + } + + require.Len(t, val.Warnings(), len(tc.warnings)) + require.Len(t, val.Errors(), len(tc.errors)) + + for i, err := range val.Warnings() { + assert.EqualError(t, err, tc.warnings[i]) + } + + for i, err := range val.Errors() { + assert.EqualError(t, err, tc.errors[i]) + } + }) + }) + } +} + +func TestNewNamedConfigFileFilters(t *testing.T) { + filters, err := NewNamedConfigFileFilters("abc") + assert.Nil(t, filters) + assert.EqualError(t, err, "error occurred loading filters: invalid filter named 'abc'") + + filters, err = NewNamedConfigFileFilters("template") + assert.NotNil(t, filters) + assert.NoError(t, err) +} diff --git a/experimental/embed/context.go b/experimental/embed/context.go new file mode 100644 index 000000000..84cf483aa --- /dev/null +++ b/experimental/embed/context.go @@ -0,0 +1,44 @@ +package embed + +import ( + "context" + + "github.com/sirupsen/logrus" + + "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/middlewares" +) + +// Context is an interface used in various areas of Authelia to simplify access to important elements like the +// configuration, providers, and logger. +type Context interface { + GetLogger() *logrus.Entry + GetProviders() middlewares.Providers + GetConfiguration() *schema.Configuration + + context.Context +} + +type ctxEmbed struct { + Configuration *Configuration + Providers Providers + Logger *logrus.Entry + + context.Context +} + +func (c *ctxEmbed) GetConfiguration() *schema.Configuration { + return c.Configuration.ToInternal() +} + +func (c *ctxEmbed) GetProviders() middlewares.Providers { + return c.Providers.ToInternal() +} + +func (c *ctxEmbed) GetLogger() *logrus.Entry { + return c.Logger +} + +var ( + _ middlewares.Context = (*ctxEmbed)(nil) +) diff --git a/experimental/embed/context_test.go b/experimental/embed/context_test.go new file mode 100644 index 000000000..cb5a2486e --- /dev/null +++ b/experimental/embed/context_test.go @@ -0,0 +1,43 @@ +package embed + +import ( + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/authelia/authelia/v4/internal/logging" +) + +func TestContext(t *testing.T) { + ctx := &ctxEmbed{} + + assert.Nil(t, ctx.GetConfiguration()) + assert.Nil(t, ctx.GetLogger()) + + providers := ctx.GetProviders() + + assert.Nil(t, providers.StorageProvider) + assert.Nil(t, providers.Notifier) + assert.Nil(t, providers.UserProvider) + assert.Nil(t, providers.SessionProvider) + assert.Nil(t, providers.MetaDataService) + assert.Nil(t, providers.Metrics) + assert.Nil(t, providers.Templates) + assert.Nil(t, providers.Random) + assert.Nil(t, providers.OpenIDConnect) + assert.Nil(t, providers.UserAttributeResolver) + assert.Nil(t, providers.Authorizer) + assert.Nil(t, providers.NTP) + assert.Nil(t, providers.TOTP) +} + +func TestContextWithValues(t *testing.T) { + ctx := &ctxEmbed{ + Configuration: &Configuration{}, + Logger: logrus.NewEntry(logging.Logger()), + } + + assert.NotNil(t, ctx.GetConfiguration()) + assert.NotNil(t, ctx.GetLogger()) +} diff --git a/experimental/embed/doc.go b/experimental/embed/doc.go new file mode 100644 index 000000000..9e9d4d09b --- /dev/null +++ b/experimental/embed/doc.go @@ -0,0 +1,13 @@ +// Package embed provides tooling useful to embed Authelia into an external go process. This package is considered +// experimental and as such is not supported by the standard versioning policy. It's strongly recommended that care is +// taken when integrating with this package and appropriate tests are conducted when upgrading. +// +// This package and all subpackages are intended to facilitate differing levels of embedability within Authelia. It's +// likely this package and subpackages will break often. +// +// The following considerations should be made in using this package: +// - It's likely that many methods within this package can panic if not properly utilized. +// - The package is likely at this stage to be changed abruptly from version to version in a breaking way. +// - The package will likely have breaking changes at any minor version bump well into the future (breaking changes to +// this package as a result of changing internal packages will not be a consideration that will slow development). +package embed diff --git a/experimental/embed/embed.go b/experimental/embed/embed.go new file mode 100644 index 000000000..0a2d91168 --- /dev/null +++ b/experimental/embed/embed.go @@ -0,0 +1,7 @@ +package embed + +func ProvidersStartupCheck(ctx Context, log bool) (err error) { + providers := ctx.GetProviders() + + return providers.StartupChecks(ctx, log) +} diff --git a/experimental/embed/embed_test.go b/experimental/embed/embed_test.go new file mode 100644 index 000000000..c20e2cb77 --- /dev/null +++ b/experimental/embed/embed_test.go @@ -0,0 +1,19 @@ +package embed + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestShouldPanicNilCtx(t *testing.T) { + assert.Panics(t, func() { + _ = ProvidersStartupCheck(nil, false) + }) + + ctx := &ctxEmbed{} + + assert.Panics(t, func() { + _ = ProvidersStartupCheck(ctx, false) + }) +} diff --git a/experimental/embed/provider/authentication.go b/experimental/embed/provider/authentication.go new file mode 100644 index 000000000..3903e6c1c --- /dev/null +++ b/experimental/embed/provider/authentication.go @@ -0,0 +1,22 @@ +package provider + +import ( + "crypto/x509" + + "github.com/authelia/authelia/v4/internal/authentication" + "github.com/authelia/authelia/v4/internal/configuration/schema" +) + +// NewAuthenticationFile directly instantiates a new authentication.UserProvider using a *authentication.FileUserProvider. +// +// Warning: This method may panic if the provided configuration isn't validated. +func NewAuthenticationFile(config *schema.Configuration) authentication.UserProvider { + return authentication.NewFileUserProvider(config.AuthenticationBackend.File) +} + +// NewAuthenticationLDAP directly instantiates a new authentication.UserProvider using a *authentication.LDAPUserProvider. +// +// Warning: This method may panic if the provided configuration isn't validated. +func NewAuthenticationLDAP(config *schema.Configuration, caCertPool *x509.CertPool) authentication.UserProvider { + return authentication.NewLDAPUserProvider(config.AuthenticationBackend, caCertPool) +} diff --git a/experimental/embed/provider/general.go b/experimental/embed/provider/general.go new file mode 100644 index 000000000..95e5d04a7 --- /dev/null +++ b/experimental/embed/provider/general.go @@ -0,0 +1,113 @@ +package provider + +import ( + "crypto/x509" + + "github.com/authelia/authelia/v4/internal/authorization" + "github.com/authelia/authelia/v4/internal/clock" + "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/expression" + "github.com/authelia/authelia/v4/internal/metrics" + "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/ntp" + "github.com/authelia/authelia/v4/internal/oidc" + "github.com/authelia/authelia/v4/internal/random" + "github.com/authelia/authelia/v4/internal/regulation" + "github.com/authelia/authelia/v4/internal/session" + "github.com/authelia/authelia/v4/internal/storage" + "github.com/authelia/authelia/v4/internal/templates" + "github.com/authelia/authelia/v4/internal/totp" + "github.com/authelia/authelia/v4/internal/webauthn" +) + +// New returns a completely new set of providers using the internal API. It is expected you'll check the errs return +// value for any errors, and handle any warnings in a graceful way. If errors are returned the providers should not be +// utilized to run anything. +func New(config *schema.Configuration, caCertPool *x509.CertPool) (providers middlewares.Providers, warns []error, errs []error) { + return middlewares.NewProviders(config, caCertPool) +} + +// NewClock creates a new clock provider. +func NewClock() clock.Provider { + return clock.New() +} + +// NewAuthorizer creates a new *authorization.Authorizer. +// +// Warning: This method may panic if the provided configuration isn't validated. +func NewAuthorizer(config *schema.Configuration) *authorization.Authorizer { + return authorization.NewAuthorizer(config) +} + +// NewSession creates a new *session.Provider given a valid configuration. +// +// Warning: This method may panic if the provided configuration isn't validated. +func NewSession(config *schema.Configuration, caCertPool *x509.CertPool) *session.Provider { + return session.NewProvider(config.Session, caCertPool) +} + +// NewRegulator creates a new *regulation.Regulator given a valid configuration. +// +// Warning: This method may panic if the provided configuration isn't validated. +func NewRegulator(config *schema.Configuration, storage storage.RegulatorProvider, clock clock.Provider) *regulation.Regulator { + return regulation.NewRegulator(config.Regulation, storage, clock) +} + +// NewMetrics creates a new metrics.Provider. +func NewMetrics() metrics.Provider { + return metrics.NewPrometheus() +} + +// NewNTP creates a new *ntp.Provider given a valid configuration. +// +// Warning: This method may panic if the provided configuration isn't validated. +func NewNTP(config *schema.Configuration) *ntp.Provider { + return ntp.NewProvider(&config.NTP) +} + +// NewOpenIDConnect creates a new *oidc.OpenIDConnectProvider given a valid configuration. +// +// Warning: This method may panic if the provided configuration isn't validated. +func NewOpenIDConnect(config *schema.Configuration, storage storage.Provider, templates *templates.Provider) *oidc.OpenIDConnectProvider { + return oidc.NewOpenIDConnectProvider(config, storage, templates) +} + +// NewTemplates creates a new *templates.Provider given a valid configuration. +// +// Warning: This method may panic if the provided configuration isn't validated. +func NewTemplates(config *schema.Configuration) (provider *templates.Provider, err error) { + return templates.New(templates.Config{EmailTemplatesPath: config.Notifier.TemplatePath}) +} + +// NewTOTP creates a new totp.Provider given a valid configuration. +// +// Warning: This method may panic if the provided configuration isn't validated. +func NewTOTP(config *schema.Configuration) totp.Provider { + return totp.NewTimeBasedProvider(config.TOTP) +} + +// NewPasswordPolicy creates a new middlewares.PasswordPolicyProvider given a valid configuration. +// +// Warning: This method may panic if the provided configuration isn't validated. +func NewPasswordPolicy(config *schema.Configuration) middlewares.PasswordPolicyProvider { + return middlewares.NewPasswordPolicyProvider(config.PasswordPolicy) +} + +// NewRandom creates a new random.Provider given a valid configuration. This uses the rand/crypto package. +func NewRandom() random.Provider { + return &random.Cryptographical{} +} + +// NewUserAttributeResolver creates a new expression.UserAttributeResolver given a valid configuration. +// +// Warning: This method may panic if the provided configuration isn't validated. +func NewUserAttributeResolver(config *schema.Configuration) expression.UserAttributeResolver { + return expression.NewUserAttributes(config) +} + +// NewMetaDataService creates a new webauthn.MetaDataProvider given a valid configuration. +// +// Warning: This method may panic if the provided configuration isn't validated. +func NewMetaDataService(config *schema.Configuration, store storage.CachedDataProvider) (provider webauthn.MetaDataProvider, err error) { + return webauthn.NewMetaDataProvider(config, store) +} diff --git a/experimental/embed/provider/notification.go b/experimental/embed/provider/notification.go new file mode 100644 index 000000000..a566cd20d --- /dev/null +++ b/experimental/embed/provider/notification.go @@ -0,0 +1,24 @@ +package provider + +import ( + "crypto/x509" + + "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/notification" +) + +// NewNotificationSMTP creates a new notification.Notifier using the *notification.SMTPNotifier given a valid +// configuration. +// +// Warning: This method may panic if the provided configuration isn't validated. +func NewNotificationSMTP(config *schema.Configuration, caCertPool *x509.CertPool) notification.Notifier { + return notification.NewSMTPNotifier(config.Notifier.SMTP, caCertPool) +} + +// NewNotificationFile creates a new notification.Notifier using the *notification.FileNotifier given a valid +// configuration. +// +// Warning: This method may panic if the provided configuration isn't validated. +func NewNotificationFile(config *schema.Configuration, caCertPool *x509.CertPool) notification.Notifier { + return notification.NewSMTPNotifier(config.Notifier.SMTP, caCertPool) +} diff --git a/experimental/embed/provider/storage.go b/experimental/embed/provider/storage.go new file mode 100644 index 000000000..42c40aefd --- /dev/null +++ b/experimental/embed/provider/storage.go @@ -0,0 +1,29 @@ +package provider + +import ( + "crypto/x509" + + "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/storage" +) + +// NewStoragePostgreSQL creates a new storage.Provider using the *storage.PostgreSQLProvider given a valid configuration. +// +// Warning: This method may panic if the provided configuration isn't validated. +func NewStoragePostgreSQL(config *schema.Configuration, caCertPool *x509.CertPool) storage.Provider { + return storage.NewPostgreSQLProvider(config, caCertPool) +} + +// NewStorageMySQL creates a new storage.Provider using the *storage.MySQLProvider given a valid configuration. +// +// Warning: This method may panic if the provided configuration isn't validated. +func NewStorageMySQL(config *schema.Configuration, caCertPool *x509.CertPool) storage.Provider { + return storage.NewMySQLProvider(config, caCertPool) +} + +// NewStorageSQLite creates a new storage.Provider using the *storage.SQLiteProvider given a valid configuration. +// +// Warning: This method may panic if the provided configuration isn't validated. +func NewStorageSQLite(config *schema.Configuration) storage.Provider { + return storage.NewSQLiteProvider(config) +} diff --git a/experimental/embed/types.go b/experimental/embed/types.go new file mode 100644 index 000000000..bd048206e --- /dev/null +++ b/experimental/embed/types.go @@ -0,0 +1,24 @@ +package embed + +import ( + "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/middlewares" +) + +// Configuration is a type alias for the internal schema.Configuration type. It allows manually configuring Authelia +// and transitioning to the internal implementation. +type Configuration schema.Configuration + +// ToInternal converts this Configuration struct into a *schema.Configuration struct using a type cast. +func (c *Configuration) ToInternal() *schema.Configuration { + return (*schema.Configuration)(c) +} + +// Providers is a type alias for the internal middlewares.Providers type. It allows manually performing setup of the +// various Authelia providers and transitioning to the internal implementation. +type Providers middlewares.Providers + +// ToInternal converts this Providers struct into a middlewares.Providers struct using a type cast. +func (p Providers) ToInternal() middlewares.Providers { + return middlewares.Providers(p) +} diff --git a/experimental/embed/types_test.go b/experimental/embed/types_test.go new file mode 100644 index 000000000..3a93ce183 --- /dev/null +++ b/experimental/embed/types_test.go @@ -0,0 +1,15 @@ +package embed + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTypeBasics(t *testing.T) { + c := &Configuration{} + assert.NotNil(t, c.ToInternal()) + + p := &Providers{} + assert.NotNil(t, p.ToInternal()) +} diff --git a/internal/authentication/file_user_provider.go b/internal/authentication/file_user_provider.go index 9be6d8edf..bd7af3417 100644 --- a/internal/authentication/file_user_provider.go +++ b/internal/authentication/file_user_provider.go @@ -77,7 +77,7 @@ func (p *FileUserProvider) Reload() (reloaded bool, err error) { return true, nil } -func (p *FileUserProvider) Shutdown() (err error) { +func (p *FileUserProvider) Close() (err error) { return nil } diff --git a/internal/authentication/ldap_user_provider_lifecycle.go b/internal/authentication/ldap_user_provider_lifecycle.go index 73dceaa6a..8f45c99f2 100644 --- a/internal/authentication/ldap_user_provider_lifecycle.go +++ b/internal/authentication/ldap_user_provider_lifecycle.go @@ -10,7 +10,7 @@ import ( "github.com/authelia/authelia/v4/internal/utils" ) -func (p *LDAPUserProvider) Shutdown() (err error) { +func (p *LDAPUserProvider) Close() (err error) { return p.factory.Close() } diff --git a/internal/authentication/ldap_user_provider_test.go b/internal/authentication/ldap_user_provider_test.go index 3975faa39..7b6b4ee73 100644 --- a/internal/authentication/ldap_user_provider_test.go +++ b/internal/authentication/ldap_user_provider_test.go @@ -457,7 +457,7 @@ func TestShouldCheckLDAPServerExtensionsPooled(t *testing.T) { assert.False(t, provider.features.ControlTypes.MsftPwdPolHints) assert.False(t, provider.features.ControlTypes.MsftPwdPolHintsDeprecated) - assert.EqualError(t, provider.Shutdown(), "errors occurred closing the client pool: close error") + assert.EqualError(t, provider.Close(), "errors occurred closing the client pool: close error") } func TestShouldNotCheckLDAPServerExtensionsWhenRootDSEReturnsMoreThanOneEntry(t *testing.T) { @@ -597,7 +597,7 @@ func TestShouldNotCheckLDAPServerExtensionsWhenRootDSEReturnsMoreThanOneEntryPoo assert.False(t, provider.features.ControlTypes.MsftPwdPolHints) assert.False(t, provider.features.ControlTypes.MsftPwdPolHintsDeprecated) - assert.NoError(t, provider.Shutdown()) + assert.NoError(t, provider.Close()) } func TestShouldNotCheckLDAPServerExtensionsWhenRootDSEReturnsMoreThanOneEntryPooledClosing(t *testing.T) { @@ -680,7 +680,7 @@ func TestShouldNotCheckLDAPServerExtensionsWhenRootDSEReturnsMoreThanOneEntryPoo assert.False(t, provider.features.ControlTypes.MsftPwdPolHints) assert.False(t, provider.features.ControlTypes.MsftPwdPolHintsDeprecated) - assert.NoError(t, provider.Shutdown()) + assert.NoError(t, provider.Close()) } func TestShouldCheckLDAPServerControlTypes(t *testing.T) { @@ -820,7 +820,7 @@ func TestShouldCheckLDAPServerControlTypesPooled(t *testing.T) { assert.True(t, provider.features.ControlTypes.MsftPwdPolHints) assert.True(t, provider.features.ControlTypes.MsftPwdPolHintsDeprecated) - assert.NoError(t, provider.Shutdown()) + assert.NoError(t, provider.Close()) } func TestShouldNotEnablePasswdModifyExtensionOrControlTypes(t *testing.T) { @@ -886,7 +886,7 @@ func TestShouldNotEnablePasswdModifyExtensionOrControlTypes(t *testing.T) { assert.False(t, provider.features.ControlTypes.MsftPwdPolHints) assert.False(t, provider.features.ControlTypes.MsftPwdPolHintsDeprecated) - assert.NoError(t, provider.Shutdown()) + assert.NoError(t, provider.Close()) } func TestShouldNotEnablePasswdModifyExtensionOrControlTypesPooled(t *testing.T) { @@ -962,7 +962,7 @@ func TestShouldNotEnablePasswdModifyExtensionOrControlTypesPooled(t *testing.T) assert.False(t, provider.features.ControlTypes.MsftPwdPolHints) assert.False(t, provider.features.ControlTypes.MsftPwdPolHintsDeprecated) - assert.NoError(t, provider.Shutdown()) + assert.NoError(t, provider.Close()) } func TestShouldReturnCheckServerConnectError(t *testing.T) { @@ -1092,7 +1092,7 @@ func TestShouldReturnCheckServerSearchErrorPooled(t *testing.T) { assert.False(t, provider.features.Extensions.PwdModifyExOp) - assert.NoError(t, provider.Shutdown()) + assert.NoError(t, provider.Close()) } func TestShouldPermitRootDSEFailure(t *testing.T) { @@ -1187,7 +1187,7 @@ func TestShouldPermitRootDSEFailurePooled(t *testing.T) { ) assert.NoError(t, provider.StartupCheck()) - assert.NoError(t, provider.Shutdown()) + assert.NoError(t, provider.Close()) } type SearchRequestMatcher struct { diff --git a/internal/authentication/user_provider.go b/internal/authentication/user_provider.go index 9d56e65cd..b1addebc9 100644 --- a/internal/authentication/user_provider.go +++ b/internal/authentication/user_provider.go @@ -22,5 +22,5 @@ type UserProvider interface { // ChangePassword is used to change a user's password but requires their old password to be successfully verified. ChangePassword(username string, oldPassword string, newPassword string) (err error) - Shutdown() (err error) + Close() (err error) } diff --git a/internal/commands/const.go b/internal/commands/const.go index ffe416472..095f21de0 100644 --- a/internal/commands/const.go +++ b/internal/commands/const.go @@ -926,8 +926,6 @@ new one is compatible for and retrofitting it would be incredibly difficult.` ) const ( - fmtLogServerListening = "Listening for %s connections on '%s' path '%s'" - fmtYAMLConfigTemplateHeader = ` --- ## @@ -947,28 +945,6 @@ const ( ) const ( - logFieldService = "service" - logFieldFile = "file" - logFieldOP = "op" - - serviceTypeServer = "server" - serviceTypeWatcher = "watcher" - serviceTypeSignal = "signal" - - logFieldProvider = "provider" - logMessageStartupCheckError = "Error occurred running a startup check" - logMessageStartupCheckPerforming = "Performing Startup Check" - logMessageStartupCheckSuccess = "Startup Check Completed Successfully" - - providerNameNTP = "ntp" - providerNameStorage = "storage" - providerNameUser = "user" - providerNameNotification = "notification" - providerNameExpressions = "expressions" - providerNameWebAuthnMetaData = "webauthn-metadata" -) - -const ( wordYes = "Yes" wordNo = "No" ) diff --git a/internal/commands/context.go b/internal/commands/context.go index 77e9c8491..a3dd32f4f 100644 --- a/internal/commands/context.go +++ b/internal/commands/context.go @@ -15,27 +15,14 @@ import ( "github.com/spf13/cobra" "github.com/spf13/pflag" - "github.com/authelia/authelia/v4/internal/authentication" - "github.com/authelia/authelia/v4/internal/authorization" - "github.com/authelia/authelia/v4/internal/clock" "github.com/authelia/authelia/v4/internal/configuration" "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/validator" - "github.com/authelia/authelia/v4/internal/expression" "github.com/authelia/authelia/v4/internal/logging" - "github.com/authelia/authelia/v4/internal/metrics" "github.com/authelia/authelia/v4/internal/middlewares" - "github.com/authelia/authelia/v4/internal/notification" - "github.com/authelia/authelia/v4/internal/ntp" - "github.com/authelia/authelia/v4/internal/oidc" "github.com/authelia/authelia/v4/internal/random" - "github.com/authelia/authelia/v4/internal/regulation" - "github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/storage" - "github.com/authelia/authelia/v4/internal/templates" - "github.com/authelia/authelia/v4/internal/totp" "github.com/authelia/authelia/v4/internal/utils" - "github.com/authelia/authelia/v4/internal/webauthn" ) // NewCmdCtx returns a new CmdCtx. @@ -44,7 +31,7 @@ func NewCmdCtx() *CmdCtx { return &CmdCtx{ Context: ctx, - log: logging.Logger(), + log: logrus.NewEntry(logging.Logger()), providers: middlewares.Providers{ Random: &random.Cryptographical{}, }, @@ -56,7 +43,7 @@ func NewCmdCtx() *CmdCtx { type CmdCtx struct { context.Context - log *logrus.Logger + log *logrus.Entry config *schema.Configuration providers middlewares.Providers @@ -87,7 +74,7 @@ type CmdCtxConfig struct { type CobraRunECmd func(cmd *cobra.Command, args []string) (err error) // GetLogger returns the *logrus.Logger satisfying part of the ServiceCtx. -func (ctx *CmdCtx) GetLogger() *logrus.Logger { +func (ctx *CmdCtx) GetLogger() *logrus.Entry { return ctx.log } @@ -154,50 +141,11 @@ func (ctx *CmdCtx) LoadTrustedCertificates() (warns, errs []error) { // LoadProviders loads all providers into the CmdCtx. func (ctx *CmdCtx) LoadProviders() (warns, errs []error) { - // TODO: Adjust this so the CertPool can be used like a provider. if warns, errs = ctx.LoadTrustedCertificates(); len(warns) != 0 || len(errs) != 0 { return warns, errs } - ctx.providers.StorageProvider = getStorageProvider(ctx) - - ctx.providers.Authorizer = authorization.NewAuthorizer(ctx.config) - ctx.providers.NTP = ntp.NewProvider(&ctx.config.NTP) - ctx.providers.PasswordPolicy = middlewares.NewPasswordPolicyProvider(ctx.config.PasswordPolicy) - ctx.providers.Regulator = regulation.NewRegulator(ctx.config.Regulation, ctx.providers.StorageProvider, clock.New()) - ctx.providers.SessionProvider = session.NewProvider(ctx.config.Session, ctx.trusted) - ctx.providers.TOTP = totp.NewTimeBasedProvider(ctx.config.TOTP) - ctx.providers.UserAttributeResolver = expression.NewUserAttributes(ctx.config) - - var err error - - switch { - case ctx.config.AuthenticationBackend.File != nil: - ctx.providers.UserProvider = authentication.NewFileUserProvider(ctx.config.AuthenticationBackend.File) - case ctx.config.AuthenticationBackend.LDAP != nil: - ctx.providers.UserProvider = authentication.NewLDAPUserProvider(ctx.config.AuthenticationBackend, ctx.trusted) - } - - if ctx.providers.Templates, err = templates.New(templates.Config{EmailTemplatesPath: ctx.config.Notifier.TemplatePath}); err != nil { - errs = append(errs, err) - } - - if ctx.providers.MetaDataService, err = webauthn.NewMetaDataProvider(ctx.config, ctx.providers.StorageProvider); err != nil { - errs = append(errs, err) - } - - switch { - case ctx.config.Notifier.SMTP != nil: - ctx.providers.Notifier = notification.NewSMTPNotifier(ctx.config.Notifier.SMTP, ctx.trusted) - case ctx.config.Notifier.FileSystem != nil: - ctx.providers.Notifier = notification.NewFileNotifier(*ctx.config.Notifier.FileSystem) - } - - ctx.providers.OpenIDConnect = oidc.NewOpenIDConnectProvider(ctx.config, ctx.providers.StorageProvider, ctx.providers.Templates) - - if ctx.config.Telemetry.Metrics.Enabled { - ctx.providers.Metrics = metrics.NewPrometheus() - } + ctx.providers, warns, errs = middlewares.NewProviders(ctx.config, ctx.trusted) return warns, errs } diff --git a/internal/commands/helpers.go b/internal/commands/helpers.go index 84ff9074c..af889af1c 100644 --- a/internal/commands/helpers.go +++ b/internal/commands/helpers.go @@ -13,16 +13,7 @@ import ( ) func getStorageProvider(ctx *CmdCtx) (provider storage.Provider) { - switch { - case ctx.config.Storage.PostgreSQL != nil: - return storage.NewPostgreSQLProvider(ctx.config, ctx.trusted) - case ctx.config.Storage.MySQL != nil: - return storage.NewMySQLProvider(ctx.config, ctx.trusted) - case ctx.config.Storage.Local != nil: - return storage.NewSQLiteProvider(ctx.config) - default: - return nil - } + return storage.NewProvider(ctx.config, ctx.trusted) } func containsIdentifier(identifier model.UserOpaqueIdentifier, identifiers []model.UserOpaqueIdentifier) bool { diff --git a/internal/commands/root.go b/internal/commands/root.go index e294cccfd..690779d07 100644 --- a/internal/commands/root.go +++ b/internal/commands/root.go @@ -1,13 +1,15 @@ package commands import ( + "errors" "fmt" "os" "github.com/spf13/cobra" "github.com/authelia/authelia/v4/internal/logging" - "github.com/authelia/authelia/v4/internal/model" + "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/service" "github.com/authelia/authelia/v4/internal/utils" ) @@ -82,102 +84,22 @@ func (ctx *CmdCtx) RootRunE(_ *cobra.Command, _ []string) (err error) { ctx.log.Error(err) } - ctx.log.Fatalf("Errors occurred provisioning providers.") + ctx.log.Fatal("Errors occurred provisioning providers") } - doStartupChecks(ctx) + if err = ctx.providers.StartupChecks(ctx, true); err != nil { + var scerr *middlewares.ErrProviderStartupCheck - ctx.cconfig = nil - - ctx.log.Trace("Starting Services") - - servicesRun(ctx) - - return nil -} - -func doStartupChecks(ctx *CmdCtx) { - var ( - failures []string - err error - ) - - ctx.log.WithFields(map[string]any{logFieldProvider: providerNameStorage}).Trace(logMessageStartupCheckPerforming) - - if err = doStartupCheck(ctx, providerNameStorage, ctx.providers.StorageProvider, false); err != nil { - ctx.log.WithError(err).WithField(logFieldProvider, providerNameStorage).Error(logMessageStartupCheckError) - - failures = append(failures, providerNameStorage) - } else { - ctx.log.WithFields(map[string]any{logFieldProvider: providerNameStorage}).Trace(logMessageStartupCheckSuccess) - } - - ctx.log.WithFields(map[string]any{logFieldProvider: providerNameUser}).Trace(logMessageStartupCheckPerforming) - - if err = doStartupCheck(ctx, providerNameUser, ctx.providers.UserProvider, false); err != nil { - ctx.log.WithError(err).WithField(logFieldProvider, providerNameUser).Error(logMessageStartupCheckError) - - failures = append(failures, providerNameUser) - } else { - ctx.log.WithFields(map[string]any{logFieldProvider: providerNameUser}).Trace(logMessageStartupCheckSuccess) - } - - ctx.log.WithFields(map[string]any{logFieldProvider: providerNameNotification}).Trace(logMessageStartupCheckPerforming) - - if err = doStartupCheck(ctx, providerNameNotification, ctx.providers.Notifier, ctx.config.Notifier.DisableStartupCheck); err != nil { - ctx.log.WithError(err).WithField(logFieldProvider, providerNameNotification).Error(logMessageStartupCheckError) - - failures = append(failures, providerNameNotification) - } else { - ctx.log.WithFields(map[string]any{logFieldProvider: providerNameNotification}).Trace(logMessageStartupCheckSuccess) - } - - ctx.log.WithFields(map[string]any{logFieldProvider: providerNameNTP}).Trace(logMessageStartupCheckPerforming) - - if err = doStartupCheck(ctx, providerNameNTP, ctx.providers.NTP, ctx.config.NTP.DisableStartupCheck); err != nil { - if !ctx.config.NTP.DisableFailure { - ctx.log.WithError(err).WithField(logFieldProvider, providerNameNTP).Error(logMessageStartupCheckError) - - failures = append(failures, providerNameNTP) + if errors.As(err, &scerr) { + ctx.GetLogger().WithField("providers", scerr.Failed()).Fatalf("One or more providers had fatal failures performing startup checks, for more details check the error level logs") } else { - ctx.log.WithError(err).WithField(logFieldProvider, providerNameNTP).Warn(logMessageStartupCheckError) + ctx.log.Fatal("Errors occurred performing startup checks") } - } else { - ctx.log.WithFields(map[string]any{logFieldProvider: providerNameNTP}).Trace(logMessageStartupCheckSuccess) - } - - ctx.log.WithFields(map[string]any{logFieldProvider: providerNameExpressions}).Trace(logMessageStartupCheckPerforming) - - if err = doStartupCheck(ctx, providerNameExpressions, ctx.providers.UserAttributeResolver, false); err != nil { - ctx.log.WithError(err).WithField(logFieldProvider, providerNameExpressions).Error(logMessageStartupCheckError) - - failures = append(failures, providerNameExpressions) - } else { - ctx.log.WithFields(map[string]any{logFieldProvider: providerNameExpressions}).Trace(logMessageStartupCheckSuccess) - } - - if err = doStartupCheck(ctx, providerNameWebAuthnMetaData, ctx.providers.MetaDataService, !ctx.config.WebAuthn.Metadata.Enabled || ctx.providers.MetaDataService == nil); err != nil { - ctx.log.WithError(err).WithField(logFieldProvider, providerNameWebAuthnMetaData).Error(logMessageStartupCheckError) - - failures = append(failures, providerNameWebAuthnMetaData) - } else { - ctx.log.WithFields(map[string]any{logFieldProvider: providerNameWebAuthnMetaData}).Trace("Startup Check Completed Successfully") } - if len(failures) != 0 { - ctx.log.WithField("providers", failures).Fatalf("One or more providers had fatal failures performing startup checks, for more detail check the error level logs") - } -} - -func doStartupCheck(ctx *CmdCtx, name string, provider model.StartupCheck, disabled bool) error { - if disabled { - ctx.log.Debugf("%s provider: startup check skipped as it is disabled", name) - return nil - } + ctx.cconfig = nil - if provider == nil { - return fmt.Errorf("unrecognized provider or it is not configured properly") - } + ctx.log.Trace("Starting Services") - return provider.StartupCheck() + return service.RunAll(ctx) } diff --git a/internal/commands/services.go b/internal/commands/services.go deleted file mode 100644 index 18598b488..000000000 --- a/internal/commands/services.go +++ /dev/null @@ -1,463 +0,0 @@ -package commands - -import ( - "context" - "fmt" - "net" - "os" - "os/signal" - "path/filepath" - "strings" - "sync" - "syscall" - "time" - - "github.com/fsnotify/fsnotify" - "github.com/sirupsen/logrus" - "github.com/valyala/fasthttp" - "golang.org/x/sync/errgroup" - - "github.com/authelia/authelia/v4/internal/authentication" - "github.com/authelia/authelia/v4/internal/configuration/schema" - "github.com/authelia/authelia/v4/internal/logging" - "github.com/authelia/authelia/v4/internal/middlewares" - "github.com/authelia/authelia/v4/internal/server" -) - -// NewServerService creates a new ServerService with the appropriate logger etc. -func NewServerService(name string, server *fasthttp.Server, listener net.Listener, paths []string, isTLS bool, log *logrus.Logger) (service *ServerService) { - return &ServerService{ - name: name, - server: server, - listener: listener, - paths: paths, - isTLS: isTLS, - log: log.WithFields(map[string]any{logFieldService: serviceTypeServer, serviceTypeServer: name}), - } -} - -// NewFileWatcherService creates a new FileWatcherService with the appropriate logger etc. -func NewFileWatcherService(name, path string, reload ProviderReload, log *logrus.Logger) (service *FileWatcherService, err error) { - if path == "" { - return nil, fmt.Errorf("path must be specified") - } - - var info os.FileInfo - - if info, err = os.Stat(path); err != nil { - return nil, fmt.Errorf("error stating file '%s': %w", path, err) - } - - if path, err = filepath.Abs(path); err != nil { - return nil, fmt.Errorf("error determining absolute path of file '%s': %w", path, err) - } - - var watcher *fsnotify.Watcher - - if watcher, err = fsnotify.NewWatcher(); err != nil { - return nil, err - } - - entry := log.WithFields(map[string]any{logFieldService: serviceTypeWatcher, serviceTypeWatcher: name}) - - if info.IsDir() { - service = &FileWatcherService{ - name: name, - watcher: watcher, - reload: reload, - log: entry, - directory: filepath.Clean(path), - } - } else { - service = &FileWatcherService{ - name: name, - watcher: watcher, - reload: reload, - log: entry, - directory: filepath.Dir(path), - file: filepath.Base(path), - } - } - - if err = service.watcher.Add(service.directory); err != nil { - return nil, fmt.Errorf("failed to add path '%s' to watch list: %w", path, err) - } - - return service, nil -} - -// NewSignalService creates a new SignalService with the appropriate logger etc. -func NewSignalService(name string, action func() (err error), log *logrus.Logger, signals ...os.Signal) (service *SignalService) { - return &SignalService{ - name: name, - signals: signals, - action: action, - log: log.WithFields(map[string]any{logFieldService: serviceTypeSignal, serviceTypeSignal: name}), - } -} - -type ServiceCtx interface { - GetLogger() *logrus.Logger - GetProviders() middlewares.Providers - GetConfiguration() *schema.Configuration - - context.Context -} - -// ProviderReload represents the required methods to support reloading a provider. -type ProviderReload interface { - Reload() (reloaded bool, err error) -} - -// Service represents the required methods to support handling a service. -type Service interface { - // ServiceType returns the type name for the Service. - ServiceType() string - - // ServiceName returns the individual name for the Service. - ServiceName() string - - // Run performs the running operations for the Service. - Run() (err error) - - // Shutdown perform the shutdown cleanup and termination operations for the Service. - Shutdown() - - // Log returns the logger configured for the service. - Log() *logrus.Entry -} - -// ServerService is a Service which runs a web server. -type ServerService struct { - name string - server *fasthttp.Server - paths []string - isTLS bool - listener net.Listener - log *logrus.Entry -} - -// ServiceType returns the service type for this service, which is always 'server'. -func (service *ServerService) ServiceType() string { - return serviceTypeServer -} - -// ServiceName returns the individual name for this service. -func (service *ServerService) ServiceName() string { - return service.name -} - -// Run the ServerService. -func (service *ServerService) Run() (err error) { - defer func() { - if r := recover(); r != nil { - service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)") - } - }() - - service.log.Infof(fmtLogServerListening, connectionType(service.isTLS), service.listener.Addr().String(), strings.Join(service.paths, "' and '")) - - if err = service.server.Serve(service.listener); err != nil { - service.log.WithError(err).Error("Error returned attempting to serve requests") - - return err - } - - return nil -} - -// Shutdown the ServerService. -func (service *ServerService) Shutdown() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - - defer cancel() - - if err := service.server.ShutdownWithContext(ctx); err != nil { - service.log.WithError(err).Error("Error occurred during shutdown") - } -} - -// Log returns the *logrus.Entry of the ServerService. -func (service *ServerService) Log() *logrus.Entry { - return service.log -} - -// FileWatcherService is a Service that watches files for changes. -type FileWatcherService struct { - name string - - watcher *fsnotify.Watcher - reload ProviderReload - - log *logrus.Entry - file string - directory string -} - -// ServiceType returns the service type for this service, which is always 'watcher'. -func (service *FileWatcherService) ServiceType() string { - return serviceTypeWatcher -} - -// ServiceName returns the individual name for this service. -func (service *FileWatcherService) ServiceName() string { - return service.name -} - -// Run the FileWatcherService. -func (service *FileWatcherService) Run() (err error) { - defer func() { - if r := recover(); r != nil { - service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)") - } - }() - - service.log.WithField(logFieldFile, filepath.Join(service.directory, service.file)).Info("Watching file for changes") - - for { - select { - case event, ok := <-service.watcher.Events: - if !ok { - return nil - } - - log := service.log.WithFields(map[string]any{logFieldFile: event.Name, logFieldOP: event.Op}) - - if service.file != "" && service.file != filepath.Base(event.Name) { - log.Trace("File modification detected to irrelevant file") - break - } - - switch { - case event.Op&fsnotify.Write == fsnotify.Write, event.Op&fsnotify.Create == fsnotify.Create: - log.Debug("File modification was detected") - - var reloaded bool - - switch reloaded, err = service.reload.Reload(); { - case err != nil: - log.WithError(err).Error("Error occurred during reload") - case reloaded: - log.Info("Reloaded successfully") - default: - log.Debug("Reload was triggered but it was skipped") - } - case event.Op&fsnotify.Remove == fsnotify.Remove: - log.Debug("File remove was detected") - } - case err, ok := <-service.watcher.Errors: - if !ok { - return nil - } - - service.log.WithError(err).Error("Error while watching file for changes") - } - } -} - -// Shutdown the FileWatcherService. -func (service *FileWatcherService) Shutdown() { - if err := service.watcher.Close(); err != nil { - service.log.WithError(err).Error("Error occurred during shutdown") - } -} - -// Log returns the *logrus.Entry of the FileWatcherService. -func (service *FileWatcherService) Log() *logrus.Entry { - return service.log -} - -// SignalService is a Service which performs actions on signals. -type SignalService struct { - name string - signals []os.Signal - action func() (err error) - log *logrus.Entry - - notify chan os.Signal - quit chan struct{} -} - -// ServiceType returns the service type for this service, which is always 'server'. -func (service *SignalService) ServiceType() string { - return serviceTypeSignal -} - -// ServiceName returns the individual name for this service. -func (service *SignalService) ServiceName() string { - return service.name -} - -// Run the ServerService. -func (service *SignalService) Run() (err error) { - service.quit = make(chan struct{}) - - service.notify = make(chan os.Signal, 1) - - signal.Notify(service.notify, service.signals...) - - for { - select { - case s := <-service.notify: - if err = service.action(); err != nil { - service.log.WithError(err).Error("Error occurred executing service action.") - } else { - service.log.WithFields(map[string]any{"signal-received": s.String()}).Debug("Successfully executed service action.") - } - case <-service.quit: - return - } - } -} - -// Shutdown the ServerService. -func (service *SignalService) Shutdown() { - signal.Stop(service.notify) - - service.quit <- struct{}{} -} - -// Log returns the *logrus.Entry of the ServerService. -func (service *SignalService) Log() *logrus.Entry { - return service.log -} - -func svcSvrMainFunc(ctx ServiceCtx) (service Service) { - switch svr, listener, paths, isTLS, err := server.CreateDefaultServer(ctx.GetConfiguration(), ctx.GetProviders()); { - case err != nil: - ctx.GetLogger().WithError(err).Fatal("Create Server Service (main) returned error") - case svr != nil && listener != nil: - service = NewServerService("main", svr, listener, paths, isTLS, ctx.GetLogger()) - default: - ctx.GetLogger().Fatal("Create Server Service (main) failed") - } - - return service -} - -func svcSvrMetricsFunc(ctx ServiceCtx) (service Service) { - switch svr, listener, paths, isTLS, err := server.CreateMetricsServer(ctx.GetConfiguration(), ctx.GetProviders()); { - case err != nil: - ctx.GetLogger().WithError(err).Fatal("Create Server Service (metrics) returned error") - case svr != nil && listener != nil: - service = NewServerService("metrics", svr, listener, paths, isTLS, ctx.GetLogger()) - default: - ctx.GetLogger().Debug("Create Server Service (metrics) skipped") - } - - return service -} - -func svcWatcherUsersFunc(ctx ServiceCtx) (service Service) { - var err error - - config := ctx.GetConfiguration() - - if config.AuthenticationBackend.File != nil && config.AuthenticationBackend.File.Watch { - provider := ctx.GetProviders().UserProvider.(*authentication.FileUserProvider) - - if service, err = NewFileWatcherService("users", config.AuthenticationBackend.File.Path, provider, ctx.GetLogger()); err != nil { - ctx.GetLogger().WithError(err).Fatal("Create Watcher Service (users) returned error") - } - } - - return service -} - -func svcSignalLogReOpenFunc(ctx ServiceCtx) (service Service) { - config := ctx.GetConfiguration() - - if config.Log.FilePath == "" { - return nil - } - - return NewSignalService("log-reload", logging.Reopen, ctx.GetLogger(), syscall.SIGHUP) -} - -func connectionType(isTLS bool) string { - if isTLS { - return "TLS" - } - - return "non-TLS" -} - -func servicesRun(ctx ServiceCtx) { - cctx, cancel := context.WithCancel(ctx) - - group, cctx := errgroup.WithContext(cctx) - - defer cancel() - - quit := make(chan os.Signal, 1) - - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - - defer signal.Stop(quit) - - var ( - services []Service - ) - - for _, serviceFunc := range []func(ctx ServiceCtx) Service{ - svcSvrMainFunc, svcSvrMetricsFunc, - svcWatcherUsersFunc, svcSignalLogReOpenFunc, - } { - if service := serviceFunc(ctx); service != nil { - service.Log().Trace("Service Loaded") - - services = append(services, service) - - group.Go(service.Run) - } - } - - ctx.GetLogger().Info("Startup complete") - - select { - case s := <-quit: - ctx.GetLogger().WithField("signal", s.String()).Debug("Shutdown initiated due to process signal") - case <-cctx.Done(): - ctx.GetLogger().Debug("Shutdown initiated due to context completion") - } - - cancel() - - ctx.GetLogger().Info("Shutdown initiated") - - wgShutdown := &sync.WaitGroup{} - - ctx.GetLogger().Tracef("Shutdown of %d services is required", len(services)) - - for _, service := range services { - wgShutdown.Add(1) - - go func(service Service) { - service.Log().Trace("Shutdown of service initiated") - - service.Shutdown() - - wgShutdown.Done() - - service.Log().Trace("Shutdown of service complete") - }(service) - } - - wgShutdown.Wait() - - var err error - - if err = ctx.GetProviders().UserProvider.Shutdown(); err != nil { - ctx.GetLogger().WithError(err).Error("Error occurred closing authentication connections") - } - - if err = ctx.GetProviders().StorageProvider.Close(); err != nil { - ctx.GetLogger().WithError(err).Error("Error occurred closing database connections") - } - - if err = group.Wait(); err != nil { - ctx.GetLogger().WithError(err).Error("Error occurred waiting for shutdown") - } - - ctx.GetLogger().Info("Shutdown complete") -} diff --git a/internal/commands/util.go b/internal/commands/util.go index 8686c2303..4950f38a8 100644 --- a/internal/commands/util.go +++ b/internal/commands/util.go @@ -21,19 +21,6 @@ import ( "github.com/authelia/authelia/v4/internal/utils" ) -func recoverErr(i any) error { - switch v := i.(type) { - case nil: - return nil - case string: - return fmt.Errorf("recovered panic: %s", v) - case error: - return fmt.Errorf("recovered panic: %w", v) - default: - return fmt.Errorf("recovered panic with unknown type: %v", v) - } -} - func flagsGetUserIdentifiersGenerateOptions(flags *pflag.FlagSet) (users, services, sectors []string, err error) { if users, err = flags.GetStringSlice(cmdFlagNameUsers); err != nil { return nil, nil, nil, err @@ -72,7 +59,7 @@ func flagsGetRandomCharacters(flags *pflag.FlagSet, flagNameLength, flagNameChar } switch { - case useCharSet, !useCharSet && !useCharacters: + case useCharSet, !useCharacters: var c string if c, err = flags.GetString(flagNameCharSet); err != nil { @@ -107,7 +94,7 @@ func flagsGetRandomCharacters(flags *pflag.FlagSet, flagNameLength, flagNameChar default: return "", fmt.Errorf("flag '--%s' with value '%s' is invalid, must be one of 'ascii', 'alphanumeric', 'alphabetic', 'numeric', 'numeric-hex', or 'rfc3986'", flagNameCharSet, c) } - case useCharacters: + default: if charset, err = flags.GetString(flagNameCharacters); err != nil { return "", err } diff --git a/internal/configuration/koanf_util.go b/internal/configuration/koanf_util.go index d0a0b1b0a..2b123d978 100644 --- a/internal/configuration/koanf_util.go +++ b/internal/configuration/koanf_util.go @@ -2,6 +2,7 @@ package configuration import ( "fmt" + "sort" "strings" "github.com/knadh/koanf/providers/confmap" @@ -35,6 +36,8 @@ func koanfGetKeys(ko *koanf.Koanf) (keys []string) { } } + sort.Strings(keys) + return keys } diff --git a/internal/configuration/test_resources/config_with_definitions.yml b/internal/configuration/test_resources/config_with_definitions.yml index 99449df2d..cb34f4741 100644 --- a/internal/configuration/test_resources/config_with_definitions.yml +++ b/internal/configuration/test_resources/config_with_definitions.yml @@ -1,6 +1,4 @@ --- -default_redirection_url: 'https://home.example.com:8080/' - server: address: 'tcp://127.0.0.1:9091' endpoints: @@ -161,7 +159,10 @@ session: name: 'authelia_session' expiration: '1h' # 1 hour inactivity: '5m' # 5 minutes - domain: 'example.com' + cookies: + - domain: 'example.com' + default_redirection_url: 'https://home.example.com:8080/' + authelia_url: 'https://auth.example.com' redis: host: '127.0.0.1' port: 6379 diff --git a/internal/middlewares/authelia_context.go b/internal/middlewares/authelia_context.go index 17ca8a81c..06587f05c 100644 --- a/internal/middlewares/authelia_context.go +++ b/internal/middlewares/authelia_context.go @@ -667,6 +667,11 @@ func (ctx *AutheliaCtx) GetConfiguration() (config schema.Configuration) { return ctx.Configuration } +// GetProviders returns the providers for this context. +func (ctx *AutheliaCtx) GetProviders() (providers Providers) { + return ctx.Providers +} + func (ctx *AutheliaCtx) GetWebAuthnProvider() (w *webauthn.WebAuthn, err error) { var ( origin *url.URL diff --git a/internal/middlewares/authelia_context_test.go b/internal/middlewares/authelia_context_test.go index 60c972f0c..54451890a 100644 --- a/internal/middlewares/authelia_context_test.go +++ b/internal/middlewares/authelia_context_test.go @@ -423,6 +423,14 @@ func TestShouldDetectNonXHR(t *testing.T) { assert.False(t, mock.Ctx.IsXHR()) } +func TestAutheliaCtxMisc(t *testing.T) { + ctx := middlewares.NewAutheliaCtx(&fasthttp.RequestCtx{}, schema.Configuration{}, middlewares.Providers{}) + + assert.NotNil(t, ctx.GetConfiguration()) + assert.NotNil(t, ctx.GetProviders()) + assert.NotNil(t, ctx.GetLogger()) +} + func TestShouldReturnCorrectSecondFactorMethods(t *testing.T) { mock := mocks.NewMockAutheliaCtx(t) defer mock.Close() diff --git a/internal/middlewares/const.go b/internal/middlewares/const.go index 25bf67093..7fb3bb30b 100644 --- a/internal/middlewares/const.go +++ b/internal/middlewares/const.go @@ -95,6 +95,20 @@ const ( UserValueRouterKeyExtAuthzPath = "extauthz" ) +const ( + LogFieldProvider = "provider" + LogMessageStartupCheckError = "Error occurred running a startup check" + LogMessageStartupCheckPerforming = "Performing Startup Check" + LogMessageStartupCheckSuccess = "Startup Check Completed Successfully" + + ProviderNameNTP = "ntp" + ProviderNameStorage = "storage" + ProviderNameUser = "user" + ProviderNameNotification = "notification" + ProviderNameExpressions = "expressions" + ProviderNameWebAuthnMetaData = "webauthn-metadata" +) + var ( protoHTTPS = []byte(strProtoHTTPS) protoHTTP = []byte(strProtoHTTP) diff --git a/internal/middlewares/startup.go b/internal/middlewares/startup.go new file mode 100644 index 000000000..0664be342 --- /dev/null +++ b/internal/middlewares/startup.go @@ -0,0 +1,124 @@ +package middlewares + +import ( + "fmt" + "strings" + + "github.com/authelia/authelia/v4/internal/model" + "github.com/authelia/authelia/v4/internal/utils" +) + +func (p *Providers) StartupChecks(ctx Context, log bool) (err error) { + e := &ErrProviderStartupCheck{errors: map[string]error{}} + + var ( + disable bool + provider model.StartupCheck + ) + + provider, disable = ctx.GetProviders().StorageProvider, false + doStartupCheck(ctx, ProviderNameStorage, provider, disable, log, e.errors) + + provider, disable = ctx.GetProviders().UserProvider, false + doStartupCheck(ctx, ProviderNameUser, provider, disable, log, e.errors) + + provider, disable = ctx.GetProviders().Notifier, false + doStartupCheck(ctx, ProviderNameNotification, provider, disable, log, e.errors) + + provider, disable = ctx.GetProviders().NTP, ctx.GetConfiguration().NTP.DisableStartupCheck + doStartupCheck(ctx, ProviderNameNTP, provider, disable, log, e.errors) + + provider, disable = ctx.GetProviders().UserAttributeResolver, false + doStartupCheck(ctx, ProviderNameExpressions, provider, disable, log, e.errors) + + provider = ctx.GetProviders().MetaDataService + disable = !ctx.GetConfiguration().WebAuthn.Metadata.Enabled || ctx.GetProviders().MetaDataService == nil + doStartupCheck(ctx, ProviderNameWebAuthnMetaData, provider, disable, log, e.errors) + + var filters []string + + if ctx.GetConfiguration().NTP.DisableFailure { + filters = append(filters, ProviderNameNTP) + } + + return e.FilterError(filters...) +} + +func doStartupCheck(ctx Context, name string, provider model.StartupCheck, disabled, log bool, errors map[string]error) { + if log { + ctx.GetLogger().WithFields(map[string]any{LogFieldProvider: name}).Trace(LogMessageStartupCheckPerforming) + } + + if disabled { + if log { + ctx.GetLogger().Debugf("%s provider: startup check skipped as it is disabled", name) + } + + return + } + + if provider == nil { + errors[name] = fmt.Errorf("unrecognized provider or it is not configured properly") + + return + } + + var err error + + if err = provider.StartupCheck(); err != nil { + if log { + ctx.GetLogger().WithError(err).WithField(LogFieldProvider, name).Error(LogMessageStartupCheckError) + } + + errors[name] = err + + return + } + + if log { + ctx.GetLogger().WithFields(map[string]any{LogFieldProvider: name}).Trace("Startup Check Completed Successfully") + } +} + +type ErrProviderStartupCheck struct { + errors map[string]error +} + +func (e *ErrProviderStartupCheck) Error() string { + keys := make([]string, 0, len(e.errors)) + for k := range e.errors { + keys = append(keys, k) + } + + return fmt.Sprintf("errors occurred performing checks on the '%s' providers", strings.Join(keys, ", ")) +} + +func (e *ErrProviderStartupCheck) Failed() (failed []string) { + for key := range e.errors { + failed = append(failed, key) + } + + return failed +} + +func (e *ErrProviderStartupCheck) FilterError(providers ...string) error { + filtered := map[string]error{} + + for provider, err := range e.errors { + if utils.IsStringInSlice(provider, providers) { + continue + } + + filtered[provider] = err + } + + if len(filtered) == 0 { + return nil + } + + return &ErrProviderStartupCheck{errors: filtered} +} + +func (e *ErrProviderStartupCheck) ErrorMap() map[string]error { + return e.errors +} diff --git a/internal/middlewares/timing_attack_delay_test.go b/internal/middlewares/timing_attack_delay_test.go index 976be04a8..8548d0e27 100644 --- a/internal/middlewares/timing_attack_delay_test.go +++ b/internal/middlewares/timing_attack_delay_test.go @@ -47,7 +47,7 @@ func TestTimingAttackDelayCalculations(t *testing.T) { expectedMinimumDelayMs := avgExecDurationMs - float64(execDuration.Milliseconds()) ctx := &AutheliaCtx{ - Logger: logging.Logger().WithFields(logrus.Fields{}), + Logger: logrus.NewEntry(logging.Logger()), Providers: Providers{ Random: &random.Cryptographical{}, }, diff --git a/internal/middlewares/types.go b/internal/middlewares/types.go index c099c322e..dbace4014 100644 --- a/internal/middlewares/types.go +++ b/internal/middlewares/types.go @@ -1,6 +1,8 @@ package middlewares import ( + "context" + "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" @@ -54,6 +56,14 @@ type Providers struct { MetaDataService webauthn.MetaDataProvider } +type Context interface { + GetLogger() *logrus.Entry + GetProviders() Providers + GetConfiguration() *schema.Configuration + + context.Context +} + // RequestHandler represents an Authelia request handler. type RequestHandler = func(*AutheliaCtx) diff --git a/internal/middlewares/util.go b/internal/middlewares/util.go index 03b80f557..7a7eb39e4 100644 --- a/internal/middlewares/util.go +++ b/internal/middlewares/util.go @@ -1,7 +1,26 @@ package middlewares import ( + "crypto/x509" + "github.com/valyala/fasthttp" + + "github.com/authelia/authelia/v4/internal/authentication" + "github.com/authelia/authelia/v4/internal/authorization" + "github.com/authelia/authelia/v4/internal/clock" + "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/expression" + "github.com/authelia/authelia/v4/internal/metrics" + "github.com/authelia/authelia/v4/internal/notification" + "github.com/authelia/authelia/v4/internal/ntp" + "github.com/authelia/authelia/v4/internal/oidc" + "github.com/authelia/authelia/v4/internal/random" + "github.com/authelia/authelia/v4/internal/regulation" + "github.com/authelia/authelia/v4/internal/session" + "github.com/authelia/authelia/v4/internal/storage" + "github.com/authelia/authelia/v4/internal/templates" + "github.com/authelia/authelia/v4/internal/totp" + "github.com/authelia/authelia/v4/internal/webauthn" ) // SetContentTypeApplicationJSON sets the Content-Type header to `application/json; charset=utf-8`. @@ -13,3 +32,48 @@ func SetContentTypeApplicationJSON(ctx *fasthttp.RequestCtx) { func SetContentTypeTextPlain(ctx *fasthttp.RequestCtx) { ctx.SetContentTypeBytes(contentTypeTextPlain) } + +// NewProviders provisions all providers based on the configuration provided. +func NewProviders(config *schema.Configuration, caCertPool *x509.CertPool) (providers Providers, warns, errs []error) { + providers.Random = &random.Cryptographical{} + providers.StorageProvider = storage.NewProvider(config, caCertPool) + providers.Authorizer = authorization.NewAuthorizer(config) + providers.NTP = ntp.NewProvider(&config.NTP) + providers.PasswordPolicy = NewPasswordPolicyProvider(config.PasswordPolicy) + providers.Regulator = regulation.NewRegulator(config.Regulation, providers.StorageProvider, clock.New()) + providers.SessionProvider = session.NewProvider(config.Session, caCertPool) + providers.TOTP = totp.NewTimeBasedProvider(config.TOTP) + providers.UserAttributeResolver = expression.NewUserAttributes(config) + + var err error + + switch { + case config.AuthenticationBackend.File != nil: + providers.UserProvider = authentication.NewFileUserProvider(config.AuthenticationBackend.File) + case config.AuthenticationBackend.LDAP != nil: + providers.UserProvider = authentication.NewLDAPUserProvider(config.AuthenticationBackend, caCertPool) + } + + if providers.Templates, err = templates.New(templates.Config{EmailTemplatesPath: config.Notifier.TemplatePath}); err != nil { + errs = append(errs, err) + } + + if providers.MetaDataService, err = webauthn.NewMetaDataProvider(config, providers.StorageProvider); err != nil { + errs = append(errs, err) + } + + switch { + case config.Notifier.SMTP != nil: + providers.Notifier = notification.NewSMTPNotifier(config.Notifier.SMTP, caCertPool) + case config.Notifier.FileSystem != nil: + providers.Notifier = notification.NewFileNotifier(*config.Notifier.FileSystem) + } + + providers.OpenIDConnect = oidc.NewOpenIDConnectProvider(config, providers.StorageProvider, providers.Templates) + + if config.Telemetry.Metrics.Enabled { + providers.Metrics = metrics.NewPrometheus() + } + + return providers, warns, errs +} diff --git a/internal/mocks/user_provider.go b/internal/mocks/user_provider.go index 204940221..46f17f3ca 100644 --- a/internal/mocks/user_provider.go +++ b/internal/mocks/user_provider.go @@ -69,6 +69,20 @@ func (mr *MockUserProviderMockRecorder) CheckUserPassword(username, password any return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckUserPassword", reflect.TypeOf((*MockUserProvider)(nil).CheckUserPassword), username, password) } +// Close mocks base method. +func (m *MockUserProvider) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockUserProviderMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockUserProvider)(nil).Close)) +} + // GetDetails mocks base method. func (m *MockUserProvider) GetDetails(username string) (*authentication.UserDetails, error) { m.ctrl.T.Helper() @@ -99,20 +113,6 @@ func (mr *MockUserProviderMockRecorder) GetDetailsExtended(username any) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDetailsExtended", reflect.TypeOf((*MockUserProvider)(nil).GetDetailsExtended), username) } -// Shutdown mocks base method. -func (m *MockUserProvider) Shutdown() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Shutdown") - ret0, _ := ret[0].(error) - return ret0 -} - -// Shutdown indicates an expected call of Shutdown. -func (mr *MockUserProviderMockRecorder) Shutdown() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Shutdown", reflect.TypeOf((*MockUserProvider)(nil).Shutdown)) -} - // StartupCheck mocks base method. func (m *MockUserProvider) StartupCheck() error { m.ctrl.T.Helper() diff --git a/internal/server/handlers.go b/internal/server/handlers.go index d633f503e..9cff4755c 100644 --- a/internal/server/handlers.go +++ b/internal/server/handlers.go @@ -112,10 +112,10 @@ func handleMethodNotAllowed(ctx *fasthttp.RequestCtx) { ctx.SetBodyString(fmt.Sprintf("%d %s", fasthttp.StatusMethodNotAllowed, fasthttp.StatusMessage(fasthttp.StatusMethodNotAllowed))) } +type RegisterRoutesBridgedFunc = func(r *router.Router, config *schema.Configuration, providers middlewares.Providers, bridge middlewares.Bridge) + //nolint:gocyclo func handleRouter(config *schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler { - log := logging.Logger() - optsTemplatedFile := NewTemplatedFileOptions(config) serveIndexHandler := ServeTemplatedFile(providers.Templates.GetAssetIndexTemplate(), optsTemplatedFile) @@ -215,32 +215,14 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers) switch name { case "legacy": - log. - WithField("path_prefix", pathAuthzLegacy). - WithField("implementation", endpoint.Implementation). - WithField("methods", "*"). - Trace("Registering Authz Endpoint") - r.ANY(pathAuthzLegacy, handler) r.ANY(path.Join(pathAuthzLegacy, pathParamAuthzEnvoy), handler) default: switch endpoint.Implementation { case handlers.AuthzImplLegacy.String(), handlers.AuthzImplExtAuthz.String(): - log. - WithField("path_prefix", uri). - WithField("implementation", endpoint.Implementation). - WithField("methods", "*"). - Trace("Registering Authz Endpoint") - r.ANY(uri, handler) r.ANY(path.Join(uri, pathParamAuthzEnvoy), handler) default: - log. - WithField("path", uri). - WithField("implementation", endpoint.Implementation). - WithField("methods", []string{fasthttp.MethodGet, fasthttp.MethodHead}). - Trace("Registering Authz Endpoint") - r.GET(uri, handler) r.HEAD(uri, handler) } @@ -367,127 +349,141 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers) } if providers.OpenIDConnect != nil { - bridgeOIDC := middlewares.NewBridgeBuilder(*config, providers).WithPreMiddlewares( - middlewares.SecurityHeadersBase, middlewares.SecurityHeadersCSPNoneOpenIDConnect, middlewares.SecurityHeadersNoStore, - ).Build() + RegisterOpenIDConnectRoutes(r, config, providers) + } - r.GET("/api/oidc/consent", bridgeOIDC(handlers.OpenIDConnectConsentGET)) - r.POST("/api/oidc/consent", bridgeOIDC(handlers.OpenIDConnectConsentPOST)) + r.RedirectFixedPath = false + r.HandleMethodNotAllowed = true + r.MethodNotAllowed = handleMethodNotAllowed + r.NotFound = handleNotFound(bridge(serveIndexHandler)) - allowedOrigins := utils.StringSliceFromURLs(config.IdentityProviders.OIDC.CORS.AllowedOrigins) + handler := middlewares.LogRequest(r.Handler) + if config.Server.Address.RouterPath() != "/" { + handler = middlewares.StripPath(config.Server.Address.RouterPath())(handler) + } - r.OPTIONS(oidc.EndpointPathWellKnownOpenIDConfiguration, policyCORSPublicGET.HandleOPTIONS) - r.GET(oidc.EndpointPathWellKnownOpenIDConfiguration, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, "openid_configuration"), policyCORSPublicGET.Middleware(bridgeOIDC(handlers.OpenIDConnectConfigurationWellKnownGET)))) + handler = middlewares.MultiWrap(handler, middlewares.RecoverPanic, middlewares.NewMetricsRequest(providers.Metrics)) - r.OPTIONS(oidc.EndpointPathWellKnownOAuthAuthorizationServer, policyCORSPublicGET.HandleOPTIONS) - r.GET(oidc.EndpointPathWellKnownOAuthAuthorizationServer, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, "oauth_configuration"), policyCORSPublicGET.Middleware(bridgeOIDC(handlers.OAuthAuthorizationServerWellKnownGET)))) + return handler +} - r.OPTIONS(oidc.EndpointPathJWKs, policyCORSPublicGET.HandleOPTIONS) - r.GET(oidc.EndpointPathJWKs, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, "jwks"), policyCORSPublicGET.Middleware(middlewareAPI(handlers.JSONWebKeySetGET)))) +// RegisterOpenIDConnectRoutes handles registration of OpenID Connect 1.0 routes. +func RegisterOpenIDConnectRoutes(r *router.Router, config *schema.Configuration, providers middlewares.Providers) { + middlewareAPI := middlewares.NewBridgeBuilder(*config, providers). + WithPreMiddlewares(middlewares.SecurityHeadersBase, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone). + Build() - // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint. - r.OPTIONS("/api/oidc/jwks", policyCORSPublicGET.HandleOPTIONS) - r.GET("/api/oidc/jwks", middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, "jwks"), policyCORSPublicGET.Middleware(bridgeOIDC(handlers.JSONWebKeySetGET)))) + policyCORSPublicGET := middlewares.NewCORSPolicyBuilder(). + WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodGet). + WithAllowedOrigins("*"). + Build() - policyCORSAuthorization := middlewares.NewCORSPolicyBuilder(). - WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodGet, fasthttp.MethodPost). - WithAllowedOrigins(allowedOrigins...). - WithEnabled(utils.IsStringInSlice(oidc.EndpointAuthorization, config.IdentityProviders.OIDC.CORS.Endpoints)). - Build() + bridgeOIDC := middlewares.NewBridgeBuilder(*config, providers).WithPreMiddlewares( + middlewares.SecurityHeadersBase, middlewares.SecurityHeadersCSPNoneOpenIDConnect, middlewares.SecurityHeadersNoStore, + ).Build() - authorization := middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointAuthorization), policyCORSAuthorization.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorization)))) + r.GET("/api/oidc/consent", bridgeOIDC(handlers.OpenIDConnectConsentGET)) + r.POST("/api/oidc/consent", bridgeOIDC(handlers.OpenIDConnectConsentPOST)) - r.OPTIONS(oidc.EndpointPathAuthorization, policyCORSAuthorization.HandleOnlyOPTIONS) - r.GET(oidc.EndpointPathAuthorization, authorization) - r.POST(oidc.EndpointPathAuthorization, authorization) + allowedOrigins := utils.StringSliceFromURLs(config.IdentityProviders.OIDC.CORS.AllowedOrigins) - // TODO (james-d-elliott): Remove in GA. This is a legacy endpoint. - r.OPTIONS("/api/oidc/authorize", policyCORSAuthorization.HandleOnlyOPTIONS) - r.GET("/api/oidc/authorize", authorization) - r.POST("/api/oidc/authorize", authorization) + r.OPTIONS(oidc.EndpointPathWellKnownOpenIDConfiguration, policyCORSPublicGET.HandleOPTIONS) + r.GET(oidc.EndpointPathWellKnownOpenIDConfiguration, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, "openid_configuration"), policyCORSPublicGET.Middleware(bridgeOIDC(handlers.OpenIDConnectConfigurationWellKnownGET)))) - policyCORSDeviceAuthorization := middlewares.NewCORSPolicyBuilder(). - WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost). - WithAllowedOrigins(allowedOrigins...). - WithEnabled(utils.IsStringInSlice(oidc.EndpointDeviceAuthorization, config.IdentityProviders.OIDC.CORS.Endpoints)). - Build() + r.OPTIONS(oidc.EndpointPathWellKnownOAuthAuthorizationServer, policyCORSPublicGET.HandleOPTIONS) + r.GET(oidc.EndpointPathWellKnownOAuthAuthorizationServer, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, "oauth_configuration"), policyCORSPublicGET.Middleware(bridgeOIDC(handlers.OAuthAuthorizationServerWellKnownGET)))) - r.OPTIONS(oidc.EndpointPathDeviceAuthorization, policyCORSDeviceAuthorization.HandleOnlyOPTIONS) - r.POST(oidc.EndpointPathDeviceAuthorization, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointDeviceAuthorization), policyCORSDeviceAuthorization.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthDeviceAuthorizationPOST))))) - r.PUT(oidc.EndpointPathDeviceAuthorization, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointDeviceAuthorization), bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthDeviceAuthorizationPUT)))) + r.OPTIONS(oidc.EndpointPathJWKs, policyCORSPublicGET.HandleOPTIONS) + r.GET(oidc.EndpointPathJWKs, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, "jwks"), policyCORSPublicGET.Middleware(middlewareAPI(handlers.JSONWebKeySetGET)))) - policyCORSPAR := middlewares.NewCORSPolicyBuilder(). - WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost). - WithAllowedOrigins(allowedOrigins...). - WithEnabled(utils.IsStringInSliceFold(oidc.EndpointPushedAuthorizationRequest, config.IdentityProviders.OIDC.CORS.Endpoints)). - Build() + // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint. + r.OPTIONS("/api/oidc/jwks", policyCORSPublicGET.HandleOPTIONS) + r.GET("/api/oidc/jwks", middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, "jwks"), policyCORSPublicGET.Middleware(bridgeOIDC(handlers.JSONWebKeySetGET)))) - r.OPTIONS(oidc.EndpointPathPushedAuthorizationRequest, policyCORSPAR.HandleOnlyOPTIONS) - r.POST(oidc.EndpointPathPushedAuthorizationRequest, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointPushedAuthorizationRequest), policyCORSPAR.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectPushedAuthorizationRequest))))) + policyCORSAuthorization := middlewares.NewCORSPolicyBuilder(). + WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodGet, fasthttp.MethodPost). + WithAllowedOrigins(allowedOrigins...). + WithEnabled(utils.IsStringInSlice(oidc.EndpointAuthorization, config.IdentityProviders.OIDC.CORS.Endpoints)). + Build() - policyCORSToken := middlewares.NewCORSPolicyBuilder(). - WithAllowCredentials(true). - WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost). - WithAllowedOrigins(allowedOrigins...). - WithEnabled(utils.IsStringInSlice(oidc.EndpointToken, config.IdentityProviders.OIDC.CORS.Endpoints)). - Build() + authorization := middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointAuthorization), policyCORSAuthorization.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorization)))) - r.OPTIONS(oidc.EndpointPathToken, policyCORSToken.HandleOPTIONS) - r.POST(oidc.EndpointPathToken, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointToken), policyCORSToken.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectTokenPOST))))) + r.OPTIONS(oidc.EndpointPathAuthorization, policyCORSAuthorization.HandleOnlyOPTIONS) + r.GET(oidc.EndpointPathAuthorization, authorization) + r.POST(oidc.EndpointPathAuthorization, authorization) - policyCORSUserinfo := middlewares.NewCORSPolicyBuilder(). - WithAllowCredentials(true). - WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodGet, fasthttp.MethodPost). - WithAllowedOrigins(allowedOrigins...). - WithEnabled(utils.IsStringInSlice(oidc.EndpointUserinfo, config.IdentityProviders.OIDC.CORS.Endpoints)). - Build() + // TODO (james-d-elliott): Remove in GA. This is a legacy endpoint. + r.OPTIONS("/api/oidc/authorize", policyCORSAuthorization.HandleOnlyOPTIONS) + r.GET("/api/oidc/authorize", authorization) + r.POST("/api/oidc/authorize", authorization) - r.OPTIONS(oidc.EndpointPathUserinfo, policyCORSUserinfo.HandleOPTIONS) - r.GET(oidc.EndpointPathUserinfo, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointUserinfo), policyCORSUserinfo.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo))))) - r.POST(oidc.EndpointPathUserinfo, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointUserinfo), policyCORSUserinfo.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo))))) + policyCORSDeviceAuthorization := middlewares.NewCORSPolicyBuilder(). + WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost). + WithAllowedOrigins(allowedOrigins...). + WithEnabled(utils.IsStringInSlice(oidc.EndpointDeviceAuthorization, config.IdentityProviders.OIDC.CORS.Endpoints)). + Build() - policyCORSIntrospection := middlewares.NewCORSPolicyBuilder(). - WithAllowCredentials(true). - WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost). - WithAllowedOrigins(allowedOrigins...). - WithEnabled(utils.IsStringInSlice(oidc.EndpointIntrospection, config.IdentityProviders.OIDC.CORS.Endpoints)). - Build() + r.OPTIONS(oidc.EndpointPathDeviceAuthorization, policyCORSDeviceAuthorization.HandleOnlyOPTIONS) + r.POST(oidc.EndpointPathDeviceAuthorization, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointDeviceAuthorization), policyCORSDeviceAuthorization.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthDeviceAuthorizationPOST))))) + r.PUT(oidc.EndpointPathDeviceAuthorization, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointDeviceAuthorization), bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthDeviceAuthorizationPUT)))) - r.OPTIONS(oidc.EndpointPathIntrospection, policyCORSIntrospection.HandleOPTIONS) - r.POST(oidc.EndpointPathIntrospection, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointIntrospection), policyCORSIntrospection.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST))))) + policyCORSPAR := middlewares.NewCORSPolicyBuilder(). + WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost). + WithAllowedOrigins(allowedOrigins...). + WithEnabled(utils.IsStringInSliceFold(oidc.EndpointPushedAuthorizationRequest, config.IdentityProviders.OIDC.CORS.Endpoints)). + Build() - // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint. - r.OPTIONS("/api/oidc/introspect", policyCORSIntrospection.HandleOPTIONS) - r.POST("/api/oidc/introspect", middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointIntrospection), policyCORSIntrospection.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST))))) + r.OPTIONS(oidc.EndpointPathPushedAuthorizationRequest, policyCORSPAR.HandleOnlyOPTIONS) + r.POST(oidc.EndpointPathPushedAuthorizationRequest, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointPushedAuthorizationRequest), policyCORSPAR.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectPushedAuthorizationRequest))))) - policyCORSRevocation := middlewares.NewCORSPolicyBuilder(). - WithAllowCredentials(true). - WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost). - WithAllowedOrigins(allowedOrigins...). - WithEnabled(utils.IsStringInSlice(oidc.EndpointRevocation, config.IdentityProviders.OIDC.CORS.Endpoints)). - Build() + policyCORSToken := middlewares.NewCORSPolicyBuilder(). + WithAllowCredentials(true). + WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost). + WithAllowedOrigins(allowedOrigins...). + WithEnabled(utils.IsStringInSlice(oidc.EndpointToken, config.IdentityProviders.OIDC.CORS.Endpoints)). + Build() - r.OPTIONS(oidc.EndpointPathRevocation, policyCORSRevocation.HandleOPTIONS) - r.POST(oidc.EndpointPathRevocation, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointRevocation), policyCORSRevocation.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST))))) + r.OPTIONS(oidc.EndpointPathToken, policyCORSToken.HandleOPTIONS) + r.POST(oidc.EndpointPathToken, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointToken), policyCORSToken.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectTokenPOST))))) - // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint. - r.OPTIONS("/api/oidc/revoke", policyCORSRevocation.HandleOPTIONS) - r.POST("/api/oidc/revoke", middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointRevocation), policyCORSRevocation.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST))))) - } + policyCORSUserinfo := middlewares.NewCORSPolicyBuilder(). + WithAllowCredentials(true). + WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodGet, fasthttp.MethodPost). + WithAllowedOrigins(allowedOrigins...). + WithEnabled(utils.IsStringInSlice(oidc.EndpointUserinfo, config.IdentityProviders.OIDC.CORS.Endpoints)). + Build() - r.RedirectFixedPath = false - r.HandleMethodNotAllowed = true - r.MethodNotAllowed = handleMethodNotAllowed - r.NotFound = handleNotFound(bridge(serveIndexHandler)) + r.OPTIONS(oidc.EndpointPathUserinfo, policyCORSUserinfo.HandleOPTIONS) + r.GET(oidc.EndpointPathUserinfo, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointUserinfo), policyCORSUserinfo.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo))))) + r.POST(oidc.EndpointPathUserinfo, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointUserinfo), policyCORSUserinfo.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo))))) - handler := middlewares.LogRequest(r.Handler) - if config.Server.Address.RouterPath() != "/" { - handler = middlewares.StripPath(config.Server.Address.RouterPath())(handler) - } + policyCORSIntrospection := middlewares.NewCORSPolicyBuilder(). + WithAllowCredentials(true). + WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost). + WithAllowedOrigins(allowedOrigins...). + WithEnabled(utils.IsStringInSlice(oidc.EndpointIntrospection, config.IdentityProviders.OIDC.CORS.Endpoints)). + Build() - handler = middlewares.MultiWrap(handler, middlewares.RecoverPanic, middlewares.NewMetricsRequest(providers.Metrics)) + r.OPTIONS(oidc.EndpointPathIntrospection, policyCORSIntrospection.HandleOPTIONS) + r.POST(oidc.EndpointPathIntrospection, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointIntrospection), policyCORSIntrospection.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST))))) - return handler + // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint. + r.OPTIONS("/api/oidc/introspect", policyCORSIntrospection.HandleOPTIONS) + r.POST("/api/oidc/introspect", middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointIntrospection), policyCORSIntrospection.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST))))) + + policyCORSRevocation := middlewares.NewCORSPolicyBuilder(). + WithAllowCredentials(true). + WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost). + WithAllowedOrigins(allowedOrigins...). + WithEnabled(utils.IsStringInSlice(oidc.EndpointRevocation, config.IdentityProviders.OIDC.CORS.Endpoints)). + Build() + + r.OPTIONS(oidc.EndpointPathRevocation, policyCORSRevocation.HandleOPTIONS) + r.POST(oidc.EndpointPathRevocation, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointRevocation), policyCORSRevocation.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST))))) + + // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint. + r.OPTIONS("/api/oidc/revoke", policyCORSRevocation.HandleOPTIONS) + r.POST("/api/oidc/revoke", middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointRevocation), policyCORSRevocation.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST))))) } func handleMetrics(path string) fasthttp.RequestHandler { diff --git a/internal/server/server.go b/internal/server/server.go index 564460155..c9836f863 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -15,10 +15,10 @@ import ( "github.com/authelia/authelia/v4/internal/middlewares" ) -// CreateDefaultServer Create Authelia's internal web server with the given configuration and providers. -func CreateDefaultServer(config *schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, paths []string, isTLS bool, err error) { +// New Create Authelia's internal web server with the given configuration and providers. +func New(config *schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, paths []string, isTLS bool, err error) { if err = providers.Templates.LoadTemplatedAssets(assets); err != nil { - return nil, nil, nil, false, fmt.Errorf("failed to load templated assets: %w", err) + return nil, nil, nil, false, fmt.Errorf("error occurred initializing main server: error occurred loading templated assets: %w", err) } server = &fasthttp.Server{ @@ -38,14 +38,14 @@ func CreateDefaultServer(config *schema.Configuration, providers middlewares.Pro ) if listener, err = config.Server.Address.Listener(); err != nil { - return nil, nil, nil, false, fmt.Errorf("error occurred while attempting to initialize main server listener for address '%s': %w", config.Server.Address.String(), err) + return nil, nil, nil, false, fmt.Errorf("error occurred initializing main server listener for address '%s': %w", config.Server.Address.String(), err) } if config.Server.TLS.Certificate != "" && config.Server.TLS.Key != "" { isTLS, connectionScheme = true, schemeHTTPS if err = server.AppendCert(config.Server.TLS.Certificate, config.Server.TLS.Key); err != nil { - return nil, nil, nil, false, fmt.Errorf("unable to load tls server certificate '%s' or private key '%s': %w", config.Server.TLS.Certificate, config.Server.TLS.Key, err) + return nil, nil, nil, false, fmt.Errorf("error occurred initializing main server tls parameters: failed to load certificate '%s' or private key '%s': %w", config.Server.TLS.Certificate, config.Server.TLS.Key, err) } if len(config.Server.TLS.ClientCertificates) > 0 { @@ -55,7 +55,7 @@ func CreateDefaultServer(config *schema.Configuration, providers middlewares.Pro for _, path := range config.Server.TLS.ClientCertificates { if cert, err = os.ReadFile(path); err != nil { - return nil, nil, nil, false, fmt.Errorf("unable to load tls client certificate '%s': %w", path, err) + return nil, nil, nil, false, fmt.Errorf("error occurred initializing main server tls parameters: failed to load client certificate '%s': %w", path, err) } caCertPool.AppendCertsFromPEM(cert) @@ -72,7 +72,7 @@ func CreateDefaultServer(config *schema.Configuration, providers middlewares.Pro if err = writeHealthCheckEnv(config.Server.DisableHealthcheck, connectionScheme, config.Server.Address.Hostname(), config.Server.Address.RouterPath(), config.Server.Address.Port()); err != nil { - return nil, nil, nil, false, fmt.Errorf("unable to configure healthcheck: %w", err) + return nil, nil, nil, false, fmt.Errorf("error occurred initializing main server healthcheck metadata: %w", err) } paths = []string{"/"} @@ -87,8 +87,8 @@ func CreateDefaultServer(config *schema.Configuration, providers middlewares.Pro return server, listener, paths, isTLS, nil } -// CreateMetricsServer creates a metrics server. -func CreateMetricsServer(config *schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, paths []string, tls bool, err error) { +// NewMetrics creates a metrics server. +func NewMetrics(config *schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, paths []string, tls bool, err error) { if providers.Metrics == nil { return } @@ -106,7 +106,7 @@ func CreateMetricsServer(config *schema.Configuration, providers middlewares.Pro } if listener, err = config.Telemetry.Metrics.Address.Listener(); err != nil { - return nil, nil, nil, false, fmt.Errorf("error occurred while attempting to initialize metrics telemetry server listener for address '%s': %w", config.Telemetry.Metrics.Address.String(), err) + return nil, nil, nil, false, fmt.Errorf("error occurred initializing metrics telemetry server listener for address '%s': %w", config.Telemetry.Metrics.Address.String(), err) } return server, listener, []string{config.Telemetry.Metrics.Address.RouterPath()}, false, nil diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 356285fc6..914906238 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -148,7 +148,7 @@ func NewTLSServerContext(configuration schema.Configuration) (serverContext *TLS return nil, err } - s, listener, _, _, err := CreateDefaultServer(&configuration, providers) + s, listener, _, _, err := New(&configuration, providers) if err != nil { return nil, err diff --git a/internal/service/const.go b/internal/service/const.go new file mode 100644 index 000000000..ef60ade9b --- /dev/null +++ b/internal/service/const.go @@ -0,0 +1,15 @@ +package service + +const ( + fmtLogServerListening = "Listening for %s connections on '%s' path '%s'" +) + +const ( + logFieldService = "service" + logFieldFile = "file" + logFieldOP = "op" + + serviceTypeServer = "server" + serviceTypeWatcher = "watcher" + serviceTypeSignal = "signal" +) diff --git a/internal/service/file_watcher.go b/internal/service/file_watcher.go new file mode 100644 index 000000000..55202b806 --- /dev/null +++ b/internal/service/file_watcher.go @@ -0,0 +1,174 @@ +package service + +import ( + "errors" + "fmt" + "os" + "path/filepath" + + "github.com/fsnotify/fsnotify" + "github.com/sirupsen/logrus" + + "github.com/authelia/authelia/v4/internal/authentication" +) + +func ProvisionUsersFileWatcher(ctx Context) (service Provider, err error) { + config := ctx.GetConfiguration() + providers := ctx.GetProviders() + + if config.AuthenticationBackend.File != nil && config.AuthenticationBackend.File.Watch { + provider, ok := providers.UserProvider.(*authentication.FileUserProvider) + + if !ok { + return nil, errors.New("error occurred asserting user provider") + } + + if service, err = NewFileWatcher("users", config.AuthenticationBackend.File.Path, provider, ctx.GetLogger()); err != nil { + return nil, err + } + } + + return service, nil +} + +// NewFileWatcher creates a new FileWatcher with the appropriate logger etc. +func NewFileWatcher(name, path string, reload ReloadableProvider, log *logrus.Entry) (service *FileWatcher, err error) { + if path == "" { + return nil, fmt.Errorf("error initializing file watcher: path must be specified") + } + + if path, err = filepath.Abs(path); err != nil { + return nil, fmt.Errorf("error initializing file watcher: could not determine the absolute path of file '%s': %w", path, err) + } + + var info os.FileInfo + + if info, err = os.Stat(path); err != nil { + switch { + case os.IsNotExist(err): + return nil, fmt.Errorf("error initializing file watcher: error stating file '%s': file does not exist", path) + case os.IsPermission(err): + return nil, fmt.Errorf("error initializing file watcher: error stating file '%s': permission denied trying to read the file", path) + default: + return nil, fmt.Errorf("error initializing file watcher: error stating file '%s': %w", path, err) + } + } + + var watcher *fsnotify.Watcher + + if watcher, err = fsnotify.NewWatcher(); err != nil { + return nil, err + } + + entry := log.WithFields(map[string]any{logFieldService: serviceTypeWatcher, serviceTypeWatcher: name}) + + if info.IsDir() { + service = &FileWatcher{ + name: name, + watcher: watcher, + reload: reload, + log: entry, + directory: filepath.Clean(path), + } + } else { + service = &FileWatcher{ + name: name, + watcher: watcher, + reload: reload, + log: entry, + directory: filepath.Dir(path), + file: filepath.Base(path), + } + } + + if err = service.watcher.Add(service.directory); err != nil { + return nil, fmt.Errorf("failed to add path '%s' to watch list: %w", path, err) + } + + return service, nil +} + +// FileWatcher is a Provider that watches files for changes. +type FileWatcher struct { + name string + + watcher *fsnotify.Watcher + reload ReloadableProvider + + log *logrus.Entry + file string + directory string +} + +// ServiceType returns the service type for this service, which is always 'watcher'. +func (service *FileWatcher) ServiceType() string { + return serviceTypeWatcher +} + +// ServiceName returns the individual name for this service. +func (service *FileWatcher) ServiceName() string { + return service.name +} + +// Run the FileWatcher. +func (service *FileWatcher) Run() (err error) { + defer func() { + if r := recover(); r != nil { + service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)") + } + }() + + service.log.WithField(logFieldFile, filepath.Join(service.directory, service.file)).Info("Watching file for changes") + + for { + select { + case event, ok := <-service.watcher.Events: + if !ok { + return nil + } + + log := service.log.WithFields(map[string]any{logFieldFile: event.Name, logFieldOP: event.Op}) + + if service.file != "" && service.file != filepath.Base(event.Name) { + log.Trace("File modification detected to irrelevant file") + break + } + + switch { + case event.Op&fsnotify.Write == fsnotify.Write, event.Op&fsnotify.Create == fsnotify.Create: + log.Debug("File modification was detected") + + var reloaded bool + + switch reloaded, err = service.reload.Reload(); { + case err != nil: + log.WithError(err).Error("Error occurred during reload") + case reloaded: + log.Info("Reloaded successfully") + default: + log.Debug("Reload was triggered but it was skipped") + } + case event.Op&fsnotify.Remove == fsnotify.Remove: + log.Debug("File remove was detected") + } + case err, ok := <-service.watcher.Errors: + if !ok { + return nil + } + + service.log.WithError(err).Error("Error while watching file for changes") + } + } +} + +// Shutdown the FileWatcher. +func (service *FileWatcher) Shutdown() { + if err := service.watcher.Close(); err != nil { + service.log.WithError(err).Error("Error occurred during shutdown") + } +} + +// Log returns the *logrus.Entry of the FileWatcher. +func (service *FileWatcher) Log() *logrus.Entry { + return service.log +} diff --git a/internal/service/file_watcher_test.go b/internal/service/file_watcher_test.go new file mode 100644 index 000000000..f4dce095f --- /dev/null +++ b/internal/service/file_watcher_test.go @@ -0,0 +1,211 @@ +package service + +import ( + "context" + "os" + "path/filepath" + "regexp" + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/authelia/authelia/v4/internal/authentication" + "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/logging" + "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/templates" +) + +func TestProvisionUsersFileWatcher(t *testing.T) { + dir := t.TempDir() + + f, err := os.Create(filepath.Join(dir, "users.yml")) + require.NoError(t, err) + require.NoError(t, f.Close()) + + tx, err := templates.New(templates.Config{}) + require.NoError(t, err) + + address, err := schema.NewAddress("tcp://:9091") + require.NoError(t, err) + + config := &schema.Configuration{ + Server: schema.Server{ + Address: &schema.AddressTCP{Address: *address}, + }, + } + + provision := ProvisionUsersFileWatcher + + ctx := &testCtx{ + Context: context.Background(), + Configuration: config, + Providers: middlewares.Providers{ + Templates: tx, + }, + Logger: logrus.NewEntry(logging.Logger()), + } + + watcher, err := provision(ctx) + assert.NoError(t, err) + assert.Nil(t, watcher) + + watcher, err = provision(ctx) + assert.NoError(t, err) + assert.Nil(t, watcher) + + config.AuthenticationBackend.File = &schema.AuthenticationBackendFile{ + Path: filepath.Join(dir, "users.yml"), + Watch: true, + } + + watcher, err = provision(ctx) + assert.EqualError(t, err, "error occurred asserting user provider") + assert.Nil(t, watcher) + + ctx.Providers.UserProvider = authentication.NewFileUserProvider(config.AuthenticationBackend.File) + + config.AuthenticationBackend.File = &schema.AuthenticationBackendFile{ + Watch: true, + } + + watcher, err = provision(ctx) + assert.EqualError(t, err, "error initializing file watcher: path must be specified") + assert.Nil(t, watcher) + + config.AuthenticationBackend.File = &schema.AuthenticationBackendFile{ + Path: filepath.Join(dir, "users.yml"), + Watch: true, + } + + watcher, err = provision(ctx) + assert.NoError(t, err) + assert.NotNil(t, watcher) + assert.NotNil(t, watcher.Log()) + assert.Equal(t, "users", watcher.ServiceName()) + assert.Equal(t, "watcher", watcher.ServiceType()) + + watcher.Shutdown() +} + +func TestNewFileWatcher(t *testing.T) { + dir := t.TempDir() + + reloader := &testReloader{reload: true} + + f, err := os.Create(filepath.Join(dir, "test.log")) + require.NoError(t, err) + + service, err := NewFileWatcher("example", filepath.Join(dir, "test.log"), reloader, logrus.NewEntry(logging.Logger())) + + assert.NoError(t, err) + + go func() { + require.NoError(t, service.Run()) + }() + + // Give the service a moment to start. + time.Sleep(100 * time.Millisecond) + + _, err = f.Write([]byte("test")) + require.NoError(t, err) + + require.NoError(t, f.Close()) + + time.Sleep(time.Second) + + assert.Equal(t, 1, reloader.count) + + assert.NoError(t, os.WriteFile(filepath.Join(dir, "test2.log"), []byte("test"), 0600)) + + assert.Equal(t, 1, reloader.count) + + assert.NoError(t, os.Remove(filepath.Join(dir, "test2.log"))) + + assert.Equal(t, 1, reloader.count) + + service.Shutdown() +} + +func TestNewFileWatcherDirectory(t *testing.T) { + dir := t.TempDir() + + reloader := &testReloader{reload: true} + + service, err := NewFileWatcher("example", dir, reloader, logrus.NewEntry(logging.Logger())) + + assert.NoError(t, err) + + go func() { + require.NoError(t, service.Run()) + }() + + // Give the service a moment to start. + time.Sleep(100 * time.Millisecond) + + f, err := os.Create(filepath.Join(dir, "test.log")) + require.NoError(t, err) + + _, err = f.Write([]byte("test")) + require.NoError(t, err) + + require.NoError(t, f.Close()) + + time.Sleep(time.Second) + + assert.Equal(t, 2, reloader.count) + + service.Shutdown() +} + +func TestNewFileWatcherBadPath(t *testing.T) { + dir := t.TempDir() + + reloader := &testReloader{reload: true} + + service, err := NewFileWatcher("example", filepath.Join(dir, "test.log"), reloader, logrus.NewEntry(logging.Logger())) + + require.Error(t, err) + assert.Regexp(t, regexp.MustCompile(`^error initializing file watcher: error stating file '/tmp/[^/]+/\d+/test.log': file does not exist$`), err.Error()) + + assert.Nil(t, service) +} + +func TestNewFileWatcherBadPermission(t *testing.T) { + dir := t.TempDir() + + reloader := &testReloader{reload: true} + + require.NoError(t, os.Mkdir(filepath.Join(dir, "tmp"), 0700)) + + f, err := os.Create(filepath.Join(dir, "tmp", "test.log")) + + require.NoError(t, err) + require.NoError(t, f.Close()) + + require.NoError(t, os.Chmod(filepath.Join(dir, "tmp"), 0o000)) + + service, err := NewFileWatcher("example", filepath.Join(dir, "tmp", "test.log"), reloader, logrus.NewEntry(logging.Logger())) + + require.Error(t, err) + assert.Regexp(t, regexp.MustCompile(`^error initializing file watcher: error stating file '/tmp/[^/]+/\d+/tmp/test.log': permission denied trying to read the file$`), err.Error()) + + require.NoError(t, os.Chmod(filepath.Join(dir, "tmp"), 0o700)) + + assert.Nil(t, service) +} + +type testReloader struct { + count int + reload bool + err error +} + +func (r *testReloader) Reload() (bool, error) { + r.count++ + + return r.reload, r.err +} diff --git a/internal/service/provider.go b/internal/service/provider.go new file mode 100644 index 000000000..182ddba67 --- /dev/null +++ b/internal/service/provider.go @@ -0,0 +1,52 @@ +package service + +import ( + "context" + + "github.com/sirupsen/logrus" + + "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/middlewares" +) + +// Provider represents the required methods to support handling a service. +type Provider interface { + // ServiceType returns the type name for the Provider. + ServiceType() string + + // ServiceName returns the individual name for the Provider. + ServiceName() string + + // Run performs the running operations for the Provider. + Run() (err error) + + // Shutdown perform the shutdown cleanup and termination operations for the Provider. + Shutdown() + + // Log returns the logger configured for the service. + Log() *logrus.Entry +} + +// ReloadableProvider represents the required methods to support reloading a provider. +type ReloadableProvider interface { + Reload() (reloaded bool, err error) +} + +type Provisioner func(ctx Context) (provider Provider, err error) + +func GetProvisioners() []Provisioner { + return []Provisioner{ + ProvisionServer, + ProvisionServerMetrics, + ProvisionUsersFileWatcher, + ProvisionLoggingSignal, + } +} + +type Context interface { + GetLogger() *logrus.Entry + GetProviders() middlewares.Providers + GetConfiguration() *schema.Configuration + + context.Context +} diff --git a/internal/service/provider_test.go b/internal/service/provider_test.go new file mode 100644 index 000000000..58a653557 --- /dev/null +++ b/internal/service/provider_test.go @@ -0,0 +1,13 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetProvisioners(t *testing.T) { + provisioners := GetProvisioners() + + assert.Len(t, provisioners, 4) +} diff --git a/internal/service/server.go b/internal/service/server.go new file mode 100644 index 000000000..15921af01 --- /dev/null +++ b/internal/service/server.go @@ -0,0 +1,120 @@ +package service + +import ( + "context" + "net" + "strings" + "time" + + "github.com/sirupsen/logrus" + "github.com/valyala/fasthttp" + + "github.com/authelia/authelia/v4/internal/server" +) + +func ProvisionServer(ctx Context) (service Provider, err error) { + var ( + s *fasthttp.Server + listener net.Listener + paths []string + isTLS bool + ) + + switch s, listener, paths, isTLS, err = server.New(ctx.GetConfiguration(), ctx.GetProviders()); { + case err != nil: + return nil, err + case s != nil && listener != nil: + service = NewBaseServer("main", s, listener, paths, isTLS, ctx.GetLogger()) + default: + return nil, nil + } + + return service, nil +} + +func ProvisionServerMetrics(ctx Context) (service Provider, err error) { + var ( + s *fasthttp.Server + listener net.Listener + paths []string + isTLS bool + ) + + switch s, listener, paths, isTLS, err = server.NewMetrics(ctx.GetConfiguration(), ctx.GetProviders()); { + case err != nil: + return nil, err + case s != nil && listener != nil: + service = NewBaseServer("metrics", s, listener, paths, isTLS, ctx.GetLogger()) + default: + return nil, nil + } + + return service, nil +} + +// NewBaseServer creates a new Server with the appropriate logger etc. +func NewBaseServer(name string, server *fasthttp.Server, listener net.Listener, paths []string, isTLS bool, log *logrus.Entry) (service *Server) { + return &Server{ + name: name, + server: server, + listener: listener, + paths: paths, + isTLS: isTLS, + log: log.WithFields(map[string]any{logFieldService: serviceTypeServer, serviceTypeServer: name}), + } +} + +// Server is a Provider which runs a web server. +type Server struct { + name string + server *fasthttp.Server + paths []string + isTLS bool + listener net.Listener + log *logrus.Entry +} + +// ServiceType returns the service type for this service, which is always 'server'. +func (service *Server) ServiceType() string { + return serviceTypeServer +} + +// ServiceName returns the individual name for this service. +func (service *Server) ServiceName() string { + return service.name +} + +// Run the Server. +func (service *Server) Run() (err error) { + defer func() { + if r := recover(); r != nil { + service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)") + } + }() + + service.log.Infof(fmtLogServerListening, connectionType(service.isTLS), service.listener.Addr().String(), strings.Join(service.paths, "' and '")) + + if err = service.server.Serve(service.listener); err != nil { + service.log.WithError(err).Error("Error returned attempting to serve requests") + + return err + } + + return nil +} + +// Shutdown the Server. +func (service *Server) Shutdown() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + + defer cancel() + + if err := service.server.ShutdownWithContext(ctx); err != nil { + service.log.WithError(err).Error("Error occurred during shutdown") + } +} + +// Log returns the *logrus.Entry of the Server. +func (service *Server) Log() *logrus.Entry { + return service.log +} diff --git a/internal/service/sever_test.go b/internal/service/sever_test.go new file mode 100644 index 000000000..6effa88e9 --- /dev/null +++ b/internal/service/sever_test.go @@ -0,0 +1,121 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/logging" + "github.com/authelia/authelia/v4/internal/metrics" + "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/templates" +) + +func TestNewMainServer(t *testing.T) { + tx, err := templates.New(templates.Config{}) + require.NoError(t, err) + + address, err := schema.NewAddress("tcp://:9091") + require.NoError(t, err) + + config := &schema.Configuration{ + Server: schema.Server{ + Address: &schema.AddressTCP{Address: *address}, + }, + } + + ctx := &testCtx{ + Context: context.Background(), + Configuration: config, + Providers: middlewares.Providers{ + Templates: tx, + }, + Logger: logrus.NewEntry(logging.Logger()), + } + + server, err := ProvisionServer(ctx) + assert.NoError(t, err) + assert.NotNil(t, server) + + go func() { + require.NoError(t, server.Run()) + }() + + // Give the service a moment to start. + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, "main", server.ServiceName()) + assert.Equal(t, "server", server.ServiceType()) + assert.NotNil(t, server.Log()) + + server.Shutdown() +} + +func TestNewMetricsServer(t *testing.T) { + tx, err := templates.New(templates.Config{}) + require.NoError(t, err) + + address, err := schema.NewAddress("tcp://:9891/metrics") + require.NoError(t, err) + + config := &schema.Configuration{ + Telemetry: schema.Telemetry{ + Metrics: schema.TelemetryMetrics{ + Enabled: true, + Address: &schema.AddressTCP{Address: *address}, + }, + }, + } + + ctx := &testCtx{ + Context: context.Background(), + Configuration: config, + Providers: middlewares.Providers{ + Templates: tx, + Metrics: metrics.NewPrometheus(), + }, + Logger: logrus.NewEntry(logging.Logger()), + } + + server, err := ProvisionServerMetrics(ctx) + assert.NoError(t, err) + assert.NotNil(t, server) + + go func() { + require.NoError(t, server.Run()) + }() + + // Give the service a moment to start. + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, "metrics", server.ServiceName()) + assert.Equal(t, "server", server.ServiceType()) + assert.NotNil(t, server.Log()) + + server.Shutdown() +} + +type testCtx struct { + Configuration *schema.Configuration + Providers middlewares.Providers + Logger *logrus.Entry + + context.Context +} + +func (c *testCtx) GetConfiguration() *schema.Configuration { + return c.Configuration +} + +func (c *testCtx) GetProviders() middlewares.Providers { + return c.Providers +} + +func (c *testCtx) GetLogger() *logrus.Entry { + return c.Logger +} diff --git a/internal/service/signal.go b/internal/service/signal.go new file mode 100644 index 000000000..b6445a792 --- /dev/null +++ b/internal/service/signal.go @@ -0,0 +1,81 @@ +package service + +import ( + "os" + "os/signal" + "syscall" + + "github.com/sirupsen/logrus" + + "github.com/authelia/authelia/v4/internal/logging" +) + +func ProvisionLoggingSignal(ctx Context) (service Provider, err error) { + config := ctx.GetConfiguration() + + if config == nil || len(config.Log.FilePath) == 0 { + return nil, nil + } + + return &Signal{ + name: "log-reload", + signals: []os.Signal{syscall.SIGHUP}, + action: logging.Reopen, + log: ctx.GetLogger().WithFields(map[string]any{logFieldService: serviceTypeSignal, serviceTypeSignal: "log-reload"}), + }, nil +} + +// Signal is a Service which performs actions on signals. +type Signal struct { + name string + signals []os.Signal + action func() (err error) + log *logrus.Entry + + notify chan os.Signal + quit chan struct{} +} + +// ServiceType returns the service type for this service, which is always 'server'. +func (service *Signal) ServiceType() string { + return serviceTypeSignal +} + +// ServiceName returns the individual name for this service. +func (service *Signal) ServiceName() string { + return service.name +} + +// Run the ServerService. +func (service *Signal) Run() (err error) { + service.quit = make(chan struct{}) + + service.notify = make(chan os.Signal, 1) + + signal.Notify(service.notify, service.signals...) + + for { + select { + case s := <-service.notify: + if err = service.action(); err != nil { + service.log.WithError(err).Error("Error occurred executing service action.") + } else { + service.log.WithFields(map[string]any{"signal-received": s.String()}).Debug("Successfully executed service action.") + } + case <-service.quit: + return + } + } +} + +// Shutdown the ServerService. +func (service *Signal) Shutdown() { + signal.Stop(service.notify) + + service.quit <- struct{}{} +} + +// Log returns the *logrus.Entry of the ServerService. +func (service *Signal) Log() *logrus.Entry { + return service.log +} diff --git a/internal/commands/services_test.go b/internal/service/signal_test.go index 3a9364e16..a43b7db03 100644 --- a/internal/commands/services_test.go +++ b/internal/service/signal_test.go @@ -1,4 +1,4 @@ -package commands +package service import ( "context" @@ -13,9 +13,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/authelia/authelia/v4/internal/logging" - "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/middlewares" ) @@ -23,11 +22,11 @@ import ( type mockServiceCtx struct { ctx context.Context config *schema.Configuration - logger *logrus.Logger + logger *logrus.Entry providers middlewares.Providers } -func (m *mockServiceCtx) GetLogger() *logrus.Logger { +func (m *mockServiceCtx) GetLogger() *logrus.Entry { return m.logger } @@ -64,7 +63,7 @@ func newMockServiceCtx() *mockServiceCtx { return &mockServiceCtx{ ctx: context.Background(), config: config, - logger: logger, + logger: logrus.NewEntry(logger), providers: middlewares.Providers{}, } } @@ -98,7 +97,12 @@ func TestSignalService_Run(t *testing.T) { return tc.actionError } - service := NewSignalService("test", action, logger, tc.signal) + service := &Signal{ + name: "log-reload", + signals: []os.Signal{syscall.SIGHUP}, + action: action, + log: logger.WithFields(map[string]any{logFieldService: serviceTypeSignal, serviceTypeSignal: "log-reload"}), + } errChan := make(chan error, 1) done := make(chan struct{}) @@ -110,6 +114,7 @@ func TestSignalService_Run(t *testing.T) { close(done) }() + // Give the service a moment to start. time.Sleep(100 * time.Millisecond) p, err := os.FindProcess(os.Getpid()) @@ -163,7 +168,7 @@ func TestSvcSignalLogReOpenFunc(t *testing.T) { mockCtx := newMockServiceCtx() mockCtx.config.Log.FilePath = tc.logFilePath - service := svcSignalLogReOpenFunc(mockCtx) + service, _ := ProvisionLoggingSignal(mockCtx) if tc.expectService { require.NotNil(t, service) @@ -197,11 +202,11 @@ func TestLogReopenFiles(t *testing.T) { ctx := &mockServiceCtx{ ctx: context.Background(), config: config, - logger: logging.Logger(), + logger: logrus.NewEntry(logging.Logger()), providers: middlewares.Providers{}, } - service := svcSignalLogReOpenFunc(ctx) + service, _ := ProvisionLoggingSignal(ctx) require.NotNil(t, service) errChan := make(chan error, 1) @@ -214,6 +219,7 @@ func TestLogReopenFiles(t *testing.T) { close(done) }() + // Give the service a moment to start. time.Sleep(100 * time.Millisecond) p, err := os.FindProcess(os.Getpid()) @@ -237,7 +243,12 @@ func TestLogReopenFiles(t *testing.T) { func TestSignalService_Shutdown(t *testing.T) { logger := logrus.New() action := func() error { return nil } - service := NewSignalService("test", action, logger, syscall.SIGHUP) + service := &Signal{ + name: "test", + signals: []os.Signal{syscall.SIGHUP}, + action: action, + log: logger.WithFields(map[string]any{logFieldService: serviceTypeSignal, serviceTypeSignal: "test"}), + } done := make(chan struct{}) go func() { diff --git a/internal/service/util.go b/internal/service/util.go new file mode 100644 index 000000000..3e4a7f35c --- /dev/null +++ b/internal/service/util.go @@ -0,0 +1,120 @@ +package service + +import ( + "context" + "fmt" + "os" + "os/signal" + "sync" + "syscall" + + "golang.org/x/sync/errgroup" +) + +func RunAll(ctx Context) (err error) { + provisioners := GetProvisioners() + + return Run(ctx, provisioners...) +} + +func Run(ctx Context, provisioners ...Provisioner) (err error) { + cctx, cancel := context.WithCancel(ctx) + + group, cctx := errgroup.WithContext(cctx) + + defer cancel() + + quit := make(chan os.Signal, 1) + + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + + defer signal.Stop(quit) + + var ( + services []Provider + ) + + log := ctx.GetLogger() + + for _, provisioner := range provisioners { + if service, err := provisioner(ctx); err != nil { + return fmt.Errorf("error occurred provisioning services: %w", err) + } else if service != nil { + services = append(services, service) + } + } + + for _, service := range services { + group.Go(service.Run) + } + + log.Info("Startup complete") + + select { + case s := <-quit: + log.WithField("signal", s.String()).Debug("Shutdown initiated due to process signal") + case <-cctx.Done(): + log.Debug("Shutdown initiated due to context completion") + } + + cancel() + + log.Info("Shutdown initiated") + + wgShutdown := &sync.WaitGroup{} + + log.Tracef("Shutdown of %d services is required", len(services)) + + for _, service := range services { + wgShutdown.Add(1) + + go func(service Provider) { + service.Log().Trace("Shutdown of service initiated") + + service.Shutdown() + + wgShutdown.Done() + + service.Log().Trace("Shutdown of service complete") + }(service) + } + + wgShutdown.Wait() + + if err = ctx.GetProviders().UserProvider.Close(); err != nil { + ctx.GetLogger().WithError(err).Error("Error occurred closing authentication connections") + } + + if err = ctx.GetProviders().StorageProvider.Close(); err != nil { + log.WithError(err).Error("Error occurred closing database connections") + } + + if err = group.Wait(); err != nil { + log.WithError(err).Error("Error occurred waiting for shutdown") + } + + log.Info("Shutdown complete") + + return nil +} + +func connectionType(isTLS bool) string { + if isTLS { + return "TLS" + } + + return "non-TLS" +} + +func recoverErr(i any) error { + switch v := i.(type) { + case nil: + return nil + case string: + return fmt.Errorf("recovered panic: %s", v) + case error: + return fmt.Errorf("recovered panic: %w", v) + default: + return fmt.Errorf("recovered panic with unknown type: %v", v) + } +} diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index d738cd80c..3ba918b87 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -3,6 +3,7 @@ package storage import ( "context" "crypto/sha256" + "crypto/x509" "database/sql" "errors" "fmt" @@ -19,6 +20,20 @@ import ( "github.com/authelia/authelia/v4/internal/model" ) +// NewProvider dynamically initializes a storage.Provider given a *schema.Configuration and *x509.CertPool. +func NewProvider(config *schema.Configuration, caCertPool *x509.CertPool) (provider Provider) { + switch { + case config.Storage.PostgreSQL != nil: + return NewPostgreSQLProvider(config, caCertPool) + case config.Storage.MySQL != nil: + return NewMySQLProvider(config, caCertPool) + case config.Storage.Local != nil: + return NewSQLiteProvider(config) + default: + return nil + } +} + // NewSQLProvider generates a generic SQLProvider to be used with other SQL provider NewUp's. func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceName string) (provider SQLProvider) { db, err := sqlx.Open(driverName, dataSourceName) diff --git a/web/.commitlintrc.cjs b/web/.commitlintrc.cjs index 82516cc34..ea897e475 100644 --- a/web/.commitlintrc.cjs +++ b/web/.commitlintrc.cjs @@ -47,6 +47,7 @@ module.exports = { "renovate", "reviewdog", "server", + "service", "session", "storage", "suites", |
