summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames Elliott <james-d-elliott@users.noreply.github.com>2025-03-09 01:53:44 +1100
committerGitHub <noreply@github.com>2025-03-09 01:53:44 +1100
commit9241731a4dd5592b4a02b5352c903b4d06b6f4ab (patch)
tree5184b98751912a261ff70fd8721b9cd4f1c98f1e
parentbbcb38ab9ff35e69d5d52a71ab56346749f5e8b1 (diff)
feat(embed): make authelia embedable (#8841)
This adds a highly experimental option for developers looking to embed Authelia within another go binary. Closes #5803 Signed-off-by: James Elliott <james-d-elliott@users.noreply.github.com>
-rw-r--r--docs/content/contributing/guidelines/commit-message.md3
-rw-r--r--experimental/embed/config.go71
-rw-r--r--experimental/embed/config_test.go195
-rw-r--r--experimental/embed/context.go44
-rw-r--r--experimental/embed/context_test.go43
-rw-r--r--experimental/embed/doc.go13
-rw-r--r--experimental/embed/embed.go7
-rw-r--r--experimental/embed/embed_test.go19
-rw-r--r--experimental/embed/provider/authentication.go22
-rw-r--r--experimental/embed/provider/general.go113
-rw-r--r--experimental/embed/provider/notification.go24
-rw-r--r--experimental/embed/provider/storage.go29
-rw-r--r--experimental/embed/types.go24
-rw-r--r--experimental/embed/types_test.go15
-rw-r--r--internal/authentication/file_user_provider.go2
-rw-r--r--internal/authentication/ldap_user_provider_lifecycle.go2
-rw-r--r--internal/authentication/ldap_user_provider_test.go16
-rw-r--r--internal/authentication/user_provider.go2
-rw-r--r--internal/commands/const.go24
-rw-r--r--internal/commands/context.go60
-rw-r--r--internal/commands/helpers.go11
-rw-r--r--internal/commands/root.go102
-rw-r--r--internal/commands/services.go463
-rw-r--r--internal/commands/util.go17
-rw-r--r--internal/configuration/koanf_util.go3
-rw-r--r--internal/configuration/test_resources/config_with_definitions.yml7
-rw-r--r--internal/middlewares/authelia_context.go5
-rw-r--r--internal/middlewares/authelia_context_test.go8
-rw-r--r--internal/middlewares/const.go14
-rw-r--r--internal/middlewares/startup.go124
-rw-r--r--internal/middlewares/timing_attack_delay_test.go2
-rw-r--r--internal/middlewares/types.go10
-rw-r--r--internal/middlewares/util.go64
-rw-r--r--internal/mocks/user_provider.go28
-rw-r--r--internal/server/handlers.go222
-rw-r--r--internal/server/server.go20
-rw-r--r--internal/server/server_test.go2
-rw-r--r--internal/service/const.go15
-rw-r--r--internal/service/file_watcher.go174
-rw-r--r--internal/service/file_watcher_test.go211
-rw-r--r--internal/service/provider.go52
-rw-r--r--internal/service/provider_test.go13
-rw-r--r--internal/service/server.go120
-rw-r--r--internal/service/sever_test.go121
-rw-r--r--internal/service/signal.go81
-rw-r--r--internal/service/signal_test.go (renamed from internal/commands/services_test.go)33
-rw-r--r--internal/service/util.go120
-rw-r--r--internal/storage/sql_provider.go15
-rw-r--r--web/.commitlintrc.cjs1
49 files changed, 1963 insertions, 823 deletions
diff --git a/docs/content/contributing/guidelines/commit-message.md b/docs/content/contributing/guidelines/commit-message.md
index 4a8ba796f..99d109cbc 100644
--- a/docs/content/contributing/guidelines/commit-message.md
+++ b/docs/content/contributing/guidelines/commit-message.md
@@ -54,7 +54,7 @@ for, and the structure it must have.
│ cmd|codecov|commands|configuration|deps|docker|duo|expression|go|
│ golangci-lint|handlers|husky|logging|metrics|middlewares|mocks|model|
│ notification|npm|ntp|oidc|random|regulation|renovate|reviewdog|server|
- │ session|storage|suites|templates|totp|utils|web|webauthn
+ │ service|session|storage|suites|templates|totp|utils|web|webauthn
└─⫸ Commit Type: build|ci|docs|feat|fix|i18n|perf|refactor|release|revert|test
```
@@ -100,6 +100,7 @@ commit messages).
* random
* regulation
* server
+* service
* session
* storage
* suites
diff --git a/experimental/embed/config.go b/experimental/embed/config.go
new file mode 100644
index 000000000..81d21463b
--- /dev/null
+++ b/experimental/embed/config.go
@@ -0,0 +1,71 @@
+package embed
+
+import (
+ "fmt"
+
+ "github.com/authelia/authelia/v4/internal/configuration"
+ "github.com/authelia/authelia/v4/internal/configuration/schema"
+ "github.com/authelia/authelia/v4/internal/configuration/validator"
+)
+
+// NewConfiguration builds a new configuration given a list of paths and filters. The filters can either be nil or
+// generated using NewNamedConfigFileFilters. This function essentially operates the same as Authelia does normally in
+// configuration steps.
+func NewConfiguration(paths []string, filters []configuration.BytesFilter) (keys []string, config *schema.Configuration, val *schema.StructValidator, err error) {
+ sources := configuration.NewDefaultSourcesWithDefaults(
+ paths,
+ filters,
+ configuration.DefaultEnvPrefix,
+ configuration.DefaultEnvDelimiter,
+ []configuration.Source{configuration.NewMapSource(configuration.Defaults())})
+
+ val = schema.NewStructValidator()
+
+ var definitions *schema.Definitions
+
+ if definitions, err = configuration.LoadDefinitions(val, sources...); err != nil {
+ return nil, nil, nil, err
+ }
+
+ config = &schema.Configuration{}
+
+ if keys, err = configuration.LoadAdvanced(
+ val,
+ "",
+ config,
+ definitions,
+ sources...); err != nil {
+ return nil, nil, nil, err
+ }
+
+ return keys, config, val, nil
+}
+
+// ValidateConfigurationAndKeys performs all configuration validation steps. The provided *schema.StructValidator should
+// at minimum be checked for errors before continuing.
+func ValidateConfigurationAndKeys(config *schema.Configuration, keys []string, val *schema.StructValidator) {
+ ValidateConfigurationKeys(keys, val)
+ ValidateConfiguration(config, val)
+}
+
+// ValidateConfigurationKeys just the keys validation steps. The provided *schema.StructValidator should
+// at minimum be checked for errors before continuing. This should be used prior to using ValidateConfiguration.
+func ValidateConfigurationKeys(keys []string, val *schema.StructValidator) {
+ validator.ValidateKeys(keys, configuration.GetMultiKeyMappedDeprecationKeys(), configuration.DefaultEnvPrefix, val)
+}
+
+// ValidateConfiguration just the configuration validation steps. The provided *schema.StructValidator should
+// at minimum be checked for errors before continuing. This should be used after using ValidateConfigurationKeys.
+func ValidateConfiguration(config *schema.Configuration, val *schema.StructValidator) {
+ validator.ValidateConfiguration(config, val)
+}
+
+// NewNamedConfigFileFilters allows configuring a set of file filters. The officially supported filter has the name
+// 'template'. The only other one at this stage is 'expand-env' which is deprecated.
+func NewNamedConfigFileFilters(names ...string) (filters []configuration.BytesFilter, err error) {
+ if filters, err = configuration.NewFileFilters(names); err != nil {
+ return nil, fmt.Errorf("error occurred loading filters: %w", err)
+ }
+
+ return filters, nil
+}
diff --git a/experimental/embed/config_test.go b/experimental/embed/config_test.go
new file mode 100644
index 000000000..eb0ce6e2f
--- /dev/null
+++ b/experimental/embed/config_test.go
@@ -0,0 +1,195 @@
+package embed
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/authelia/authelia/v4/internal/configuration"
+)
+
+func TestNewConfiguration(t *testing.T) {
+ testCases := []struct {
+ name string
+ paths []string
+ filters []configuration.BytesFilter
+ keys []string
+ warnings []string
+ errors []string
+ err string
+ }{
+ {
+ name: "ShouldHandleWebAuthn",
+ paths: []string{"../../internal/configuration/test_resources/config.webauthn.yml"},
+ keys: []string{
+ "server.endpoints.rate_limits.reset_password_finish.enable",
+ "server.endpoints.rate_limits.reset_password_start.enable",
+ "server.endpoints.rate_limits.second_factor_duo.enable",
+ "server.endpoints.rate_limits.second_factor_totp.enable",
+ "server.endpoints.rate_limits.session_elevation_finish.enable",
+ "server.endpoints.rate_limits.session_elevation_start.enable",
+ "webauthn.selection_criteria.attachment",
+ "webauthn.selection_criteria.discoverability",
+ "webauthn.selection_criteria.user_verification",
+ },
+ warnings: nil,
+ errors: []string{
+ "identity_validation: reset_password: option 'jwt_secret' is required when the reset password functionality isn't disabled",
+ "authentication_backend: you must ensure either the 'file' or 'ldap' authentication backend is configured",
+ "access_control: 'default_policy' option 'deny' is invalid: when no rules are specified it must be 'two_factor' or 'one_factor'",
+ "session: option 'cookies' is required",
+ "storage: option 'encryption_key' is required",
+ "storage: configuration for a 'local', 'mysql' or 'postgres' database must be provided",
+ "notifier: you must ensure either the 'smtp' or 'filesystem' notifier is configured",
+ },
+ },
+ {
+ name: "ShouldHandleConfigWithDefinitions",
+ paths: []string{"../../internal/configuration/test_resources/config_with_definitions.yml"},
+ keys: []string{
+ "access_control.default_policy",
+ "access_control.networks",
+ "access_control.networks[].name",
+ "access_control.networks[].networks",
+ "access_control.rules",
+ "access_control.rules[].domain",
+ "access_control.rules[].networks",
+ "access_control.rules[].policy",
+ "access_control.rules[].resources",
+ "access_control.rules[].subject",
+ "authentication_backend.ldap.additional_groups_dn",
+ "authentication_backend.ldap.additional_users_dn",
+ "authentication_backend.ldap.address",
+ "authentication_backend.ldap.attributes.group_name",
+ "authentication_backend.ldap.attributes.mail",
+ "authentication_backend.ldap.attributes.username",
+ "authentication_backend.ldap.base_dn",
+ "authentication_backend.ldap.groups_filter",
+ "authentication_backend.ldap.tls.private_key",
+ "authentication_backend.ldap.user",
+ "authentication_backend.ldap.users_filter",
+ "authentication_backend.refresh_interval",
+ "definitions.network.lan",
+ "definitions.user_attributes.example.expression",
+ "duo_api.hostname",
+ "duo_api.integration_key",
+ "log.level",
+ "notifier.smtp.address",
+ "notifier.smtp.disable_require_tls",
+ "notifier.smtp.sender",
+ "notifier.smtp.username",
+ "regulation.ban_time",
+ "regulation.find_time",
+ "regulation.max_retries",
+ "server.address",
+ "server.endpoints.authz.auth-request.authn_strategies",
+ "server.endpoints.authz.auth-request.authn_strategies[].name",
+ "server.endpoints.authz.auth-request.implementation",
+ "server.endpoints.authz.ext-authz.authn_strategies",
+ "server.endpoints.authz.ext-authz.authn_strategies[].name",
+ "server.endpoints.authz.ext-authz.implementation",
+ "server.endpoints.authz.forward-auth.authn_strategies",
+ "server.endpoints.authz.forward-auth.authn_strategies[].name",
+ "server.endpoints.authz.forward-auth.implementation",
+ "server.endpoints.authz.legacy.implementation",
+ "server.endpoints.rate_limits.reset_password_finish.enable",
+ "server.endpoints.rate_limits.reset_password_start.enable",
+ "server.endpoints.rate_limits.second_factor_duo.enable",
+ "server.endpoints.rate_limits.second_factor_totp.enable",
+ "server.endpoints.rate_limits.session_elevation_finish.enable",
+ "server.endpoints.rate_limits.session_elevation_start.enable",
+ "session.cookies",
+ "session.cookies[].authelia_url",
+ "session.cookies[].default_redirection_url",
+ "session.cookies[].domain",
+ "session.expiration",
+ "session.inactivity",
+ "session.name",
+ "session.redis.high_availability.sentinel_name",
+ "session.redis.host",
+ "session.redis.port",
+ "storage.mysql.address",
+ "storage.mysql.database",
+ "storage.mysql.username",
+ "totp.issuer",
+ "webauthn.selection_criteria.discoverability",
+ "webauthn.selection_criteria.user_verification",
+ },
+ warnings: nil,
+ errors: []string{
+ "duo_api: option 'secret_key' is required when duo is enabled but it's absent",
+ "identity_validation: reset_password: option 'jwt_secret' is required when the reset password functionality isn't disabled",
+ "authentication_backend: ldap: option 'password' is required",
+ "session: option 'secret' is required when using the 'redis' provider",
+ "storage: option 'encryption_key' is required",
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Run("Individual", func(t *testing.T) {
+ keys, config, val, err := NewConfiguration(tc.paths, tc.filters)
+
+ assert.Equal(t, tc.keys, keys)
+
+ if tc.err != "" {
+ assert.EqualError(t, err, tc.err)
+ assert.Nil(t, config)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, config)
+ ValidateConfigurationKeys(keys, val)
+ ValidateConfiguration(config, val)
+ }
+
+ require.Len(t, val.Warnings(), len(tc.warnings))
+ require.Len(t, val.Errors(), len(tc.errors))
+
+ for i, err := range val.Warnings() {
+ assert.EqualError(t, err, tc.warnings[i])
+ }
+
+ for i, err := range val.Errors() {
+ assert.EqualError(t, err, tc.errors[i])
+ }
+ })
+ t.Run("Combined", func(t *testing.T) {
+ keys, config, val, err := NewConfiguration(tc.paths, tc.filters)
+
+ assert.Equal(t, tc.keys, keys)
+
+ if tc.err != "" {
+ assert.EqualError(t, err, tc.err)
+ assert.Nil(t, config)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, config)
+ ValidateConfigurationAndKeys(config, keys, val)
+ }
+
+ require.Len(t, val.Warnings(), len(tc.warnings))
+ require.Len(t, val.Errors(), len(tc.errors))
+
+ for i, err := range val.Warnings() {
+ assert.EqualError(t, err, tc.warnings[i])
+ }
+
+ for i, err := range val.Errors() {
+ assert.EqualError(t, err, tc.errors[i])
+ }
+ })
+ })
+ }
+}
+
+func TestNewNamedConfigFileFilters(t *testing.T) {
+ filters, err := NewNamedConfigFileFilters("abc")
+ assert.Nil(t, filters)
+ assert.EqualError(t, err, "error occurred loading filters: invalid filter named 'abc'")
+
+ filters, err = NewNamedConfigFileFilters("template")
+ assert.NotNil(t, filters)
+ assert.NoError(t, err)
+}
diff --git a/experimental/embed/context.go b/experimental/embed/context.go
new file mode 100644
index 000000000..84cf483aa
--- /dev/null
+++ b/experimental/embed/context.go
@@ -0,0 +1,44 @@
+package embed
+
+import (
+ "context"
+
+ "github.com/sirupsen/logrus"
+
+ "github.com/authelia/authelia/v4/internal/configuration/schema"
+ "github.com/authelia/authelia/v4/internal/middlewares"
+)
+
+// Context is an interface used in various areas of Authelia to simplify access to important elements like the
+// configuration, providers, and logger.
+type Context interface {
+ GetLogger() *logrus.Entry
+ GetProviders() middlewares.Providers
+ GetConfiguration() *schema.Configuration
+
+ context.Context
+}
+
+type ctxEmbed struct {
+ Configuration *Configuration
+ Providers Providers
+ Logger *logrus.Entry
+
+ context.Context
+}
+
+func (c *ctxEmbed) GetConfiguration() *schema.Configuration {
+ return c.Configuration.ToInternal()
+}
+
+func (c *ctxEmbed) GetProviders() middlewares.Providers {
+ return c.Providers.ToInternal()
+}
+
+func (c *ctxEmbed) GetLogger() *logrus.Entry {
+ return c.Logger
+}
+
+var (
+ _ middlewares.Context = (*ctxEmbed)(nil)
+)
diff --git a/experimental/embed/context_test.go b/experimental/embed/context_test.go
new file mode 100644
index 000000000..cb5a2486e
--- /dev/null
+++ b/experimental/embed/context_test.go
@@ -0,0 +1,43 @@
+package embed
+
+import (
+ "testing"
+
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
+
+ "github.com/authelia/authelia/v4/internal/logging"
+)
+
+func TestContext(t *testing.T) {
+ ctx := &ctxEmbed{}
+
+ assert.Nil(t, ctx.GetConfiguration())
+ assert.Nil(t, ctx.GetLogger())
+
+ providers := ctx.GetProviders()
+
+ assert.Nil(t, providers.StorageProvider)
+ assert.Nil(t, providers.Notifier)
+ assert.Nil(t, providers.UserProvider)
+ assert.Nil(t, providers.SessionProvider)
+ assert.Nil(t, providers.MetaDataService)
+ assert.Nil(t, providers.Metrics)
+ assert.Nil(t, providers.Templates)
+ assert.Nil(t, providers.Random)
+ assert.Nil(t, providers.OpenIDConnect)
+ assert.Nil(t, providers.UserAttributeResolver)
+ assert.Nil(t, providers.Authorizer)
+ assert.Nil(t, providers.NTP)
+ assert.Nil(t, providers.TOTP)
+}
+
+func TestContextWithValues(t *testing.T) {
+ ctx := &ctxEmbed{
+ Configuration: &Configuration{},
+ Logger: logrus.NewEntry(logging.Logger()),
+ }
+
+ assert.NotNil(t, ctx.GetConfiguration())
+ assert.NotNil(t, ctx.GetLogger())
+}
diff --git a/experimental/embed/doc.go b/experimental/embed/doc.go
new file mode 100644
index 000000000..9e9d4d09b
--- /dev/null
+++ b/experimental/embed/doc.go
@@ -0,0 +1,13 @@
+// Package embed provides tooling useful to embed Authelia into an external go process. This package is considered
+// experimental and as such is not supported by the standard versioning policy. It's strongly recommended that care is
+// taken when integrating with this package and appropriate tests are conducted when upgrading.
+//
+// This package and all subpackages are intended to facilitate differing levels of embedability within Authelia. It's
+// likely this package and subpackages will break often.
+//
+// The following considerations should be made in using this package:
+// - It's likely that many methods within this package can panic if not properly utilized.
+// - The package is likely at this stage to be changed abruptly from version to version in a breaking way.
+// - The package will likely have breaking changes at any minor version bump well into the future (breaking changes to
+// this package as a result of changing internal packages will not be a consideration that will slow development).
+package embed
diff --git a/experimental/embed/embed.go b/experimental/embed/embed.go
new file mode 100644
index 000000000..0a2d91168
--- /dev/null
+++ b/experimental/embed/embed.go
@@ -0,0 +1,7 @@
+package embed
+
+func ProvidersStartupCheck(ctx Context, log bool) (err error) {
+ providers := ctx.GetProviders()
+
+ return providers.StartupChecks(ctx, log)
+}
diff --git a/experimental/embed/embed_test.go b/experimental/embed/embed_test.go
new file mode 100644
index 000000000..c20e2cb77
--- /dev/null
+++ b/experimental/embed/embed_test.go
@@ -0,0 +1,19 @@
+package embed
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestShouldPanicNilCtx(t *testing.T) {
+ assert.Panics(t, func() {
+ _ = ProvidersStartupCheck(nil, false)
+ })
+
+ ctx := &ctxEmbed{}
+
+ assert.Panics(t, func() {
+ _ = ProvidersStartupCheck(ctx, false)
+ })
+}
diff --git a/experimental/embed/provider/authentication.go b/experimental/embed/provider/authentication.go
new file mode 100644
index 000000000..3903e6c1c
--- /dev/null
+++ b/experimental/embed/provider/authentication.go
@@ -0,0 +1,22 @@
+package provider
+
+import (
+ "crypto/x509"
+
+ "github.com/authelia/authelia/v4/internal/authentication"
+ "github.com/authelia/authelia/v4/internal/configuration/schema"
+)
+
+// NewAuthenticationFile directly instantiates a new authentication.UserProvider using a *authentication.FileUserProvider.
+//
+// Warning: This method may panic if the provided configuration isn't validated.
+func NewAuthenticationFile(config *schema.Configuration) authentication.UserProvider {
+ return authentication.NewFileUserProvider(config.AuthenticationBackend.File)
+}
+
+// NewAuthenticationLDAP directly instantiates a new authentication.UserProvider using a *authentication.LDAPUserProvider.
+//
+// Warning: This method may panic if the provided configuration isn't validated.
+func NewAuthenticationLDAP(config *schema.Configuration, caCertPool *x509.CertPool) authentication.UserProvider {
+ return authentication.NewLDAPUserProvider(config.AuthenticationBackend, caCertPool)
+}
diff --git a/experimental/embed/provider/general.go b/experimental/embed/provider/general.go
new file mode 100644
index 000000000..95e5d04a7
--- /dev/null
+++ b/experimental/embed/provider/general.go
@@ -0,0 +1,113 @@
+package provider
+
+import (
+ "crypto/x509"
+
+ "github.com/authelia/authelia/v4/internal/authorization"
+ "github.com/authelia/authelia/v4/internal/clock"
+ "github.com/authelia/authelia/v4/internal/configuration/schema"
+ "github.com/authelia/authelia/v4/internal/expression"
+ "github.com/authelia/authelia/v4/internal/metrics"
+ "github.com/authelia/authelia/v4/internal/middlewares"
+ "github.com/authelia/authelia/v4/internal/ntp"
+ "github.com/authelia/authelia/v4/internal/oidc"
+ "github.com/authelia/authelia/v4/internal/random"
+ "github.com/authelia/authelia/v4/internal/regulation"
+ "github.com/authelia/authelia/v4/internal/session"
+ "github.com/authelia/authelia/v4/internal/storage"
+ "github.com/authelia/authelia/v4/internal/templates"
+ "github.com/authelia/authelia/v4/internal/totp"
+ "github.com/authelia/authelia/v4/internal/webauthn"
+)
+
+// New returns a completely new set of providers using the internal API. It is expected you'll check the errs return
+// value for any errors, and handle any warnings in a graceful way. If errors are returned the providers should not be
+// utilized to run anything.
+func New(config *schema.Configuration, caCertPool *x509.CertPool) (providers middlewares.Providers, warns []error, errs []error) {
+ return middlewares.NewProviders(config, caCertPool)
+}
+
+// NewClock creates a new clock provider.
+func NewClock() clock.Provider {
+ return clock.New()
+}
+
+// NewAuthorizer creates a new *authorization.Authorizer.
+//
+// Warning: This method may panic if the provided configuration isn't validated.
+func NewAuthorizer(config *schema.Configuration) *authorization.Authorizer {
+ return authorization.NewAuthorizer(config)
+}
+
+// NewSession creates a new *session.Provider given a valid configuration.
+//
+// Warning: This method may panic if the provided configuration isn't validated.
+func NewSession(config *schema.Configuration, caCertPool *x509.CertPool) *session.Provider {
+ return session.NewProvider(config.Session, caCertPool)
+}
+
+// NewRegulator creates a new *regulation.Regulator given a valid configuration.
+//
+// Warning: This method may panic if the provided configuration isn't validated.
+func NewRegulator(config *schema.Configuration, storage storage.RegulatorProvider, clock clock.Provider) *regulation.Regulator {
+ return regulation.NewRegulator(config.Regulation, storage, clock)
+}
+
+// NewMetrics creates a new metrics.Provider.
+func NewMetrics() metrics.Provider {
+ return metrics.NewPrometheus()
+}
+
+// NewNTP creates a new *ntp.Provider given a valid configuration.
+//
+// Warning: This method may panic if the provided configuration isn't validated.
+func NewNTP(config *schema.Configuration) *ntp.Provider {
+ return ntp.NewProvider(&config.NTP)
+}
+
+// NewOpenIDConnect creates a new *oidc.OpenIDConnectProvider given a valid configuration.
+//
+// Warning: This method may panic if the provided configuration isn't validated.
+func NewOpenIDConnect(config *schema.Configuration, storage storage.Provider, templates *templates.Provider) *oidc.OpenIDConnectProvider {
+ return oidc.NewOpenIDConnectProvider(config, storage, templates)
+}
+
+// NewTemplates creates a new *templates.Provider given a valid configuration.
+//
+// Warning: This method may panic if the provided configuration isn't validated.
+func NewTemplates(config *schema.Configuration) (provider *templates.Provider, err error) {
+ return templates.New(templates.Config{EmailTemplatesPath: config.Notifier.TemplatePath})
+}
+
+// NewTOTP creates a new totp.Provider given a valid configuration.
+//
+// Warning: This method may panic if the provided configuration isn't validated.
+func NewTOTP(config *schema.Configuration) totp.Provider {
+ return totp.NewTimeBasedProvider(config.TOTP)
+}
+
+// NewPasswordPolicy creates a new middlewares.PasswordPolicyProvider given a valid configuration.
+//
+// Warning: This method may panic if the provided configuration isn't validated.
+func NewPasswordPolicy(config *schema.Configuration) middlewares.PasswordPolicyProvider {
+ return middlewares.NewPasswordPolicyProvider(config.PasswordPolicy)
+}
+
+// NewRandom creates a new random.Provider given a valid configuration. This uses the rand/crypto package.
+func NewRandom() random.Provider {
+ return &random.Cryptographical{}
+}
+
+// NewUserAttributeResolver creates a new expression.UserAttributeResolver given a valid configuration.
+//
+// Warning: This method may panic if the provided configuration isn't validated.
+func NewUserAttributeResolver(config *schema.Configuration) expression.UserAttributeResolver {
+ return expression.NewUserAttributes(config)
+}
+
+// NewMetaDataService creates a new webauthn.MetaDataProvider given a valid configuration.
+//
+// Warning: This method may panic if the provided configuration isn't validated.
+func NewMetaDataService(config *schema.Configuration, store storage.CachedDataProvider) (provider webauthn.MetaDataProvider, err error) {
+ return webauthn.NewMetaDataProvider(config, store)
+}
diff --git a/experimental/embed/provider/notification.go b/experimental/embed/provider/notification.go
new file mode 100644
index 000000000..a566cd20d
--- /dev/null
+++ b/experimental/embed/provider/notification.go
@@ -0,0 +1,24 @@
+package provider
+
+import (
+ "crypto/x509"
+
+ "github.com/authelia/authelia/v4/internal/configuration/schema"
+ "github.com/authelia/authelia/v4/internal/notification"
+)
+
+// NewNotificationSMTP creates a new notification.Notifier using the *notification.SMTPNotifier given a valid
+// configuration.
+//
+// Warning: This method may panic if the provided configuration isn't validated.
+func NewNotificationSMTP(config *schema.Configuration, caCertPool *x509.CertPool) notification.Notifier {
+ return notification.NewSMTPNotifier(config.Notifier.SMTP, caCertPool)
+}
+
+// NewNotificationFile creates a new notification.Notifier using the *notification.FileNotifier given a valid
+// configuration.
+//
+// Warning: This method may panic if the provided configuration isn't validated.
+func NewNotificationFile(config *schema.Configuration, caCertPool *x509.CertPool) notification.Notifier {
+ return notification.NewSMTPNotifier(config.Notifier.SMTP, caCertPool)
+}
diff --git a/experimental/embed/provider/storage.go b/experimental/embed/provider/storage.go
new file mode 100644
index 000000000..42c40aefd
--- /dev/null
+++ b/experimental/embed/provider/storage.go
@@ -0,0 +1,29 @@
+package provider
+
+import (
+ "crypto/x509"
+
+ "github.com/authelia/authelia/v4/internal/configuration/schema"
+ "github.com/authelia/authelia/v4/internal/storage"
+)
+
+// NewStoragePostgreSQL creates a new storage.Provider using the *storage.PostgreSQLProvider given a valid configuration.
+//
+// Warning: This method may panic if the provided configuration isn't validated.
+func NewStoragePostgreSQL(config *schema.Configuration, caCertPool *x509.CertPool) storage.Provider {
+ return storage.NewPostgreSQLProvider(config, caCertPool)
+}
+
+// NewStorageMySQL creates a new storage.Provider using the *storage.MySQLProvider given a valid configuration.
+//
+// Warning: This method may panic if the provided configuration isn't validated.
+func NewStorageMySQL(config *schema.Configuration, caCertPool *x509.CertPool) storage.Provider {
+ return storage.NewMySQLProvider(config, caCertPool)
+}
+
+// NewStorageSQLite creates a new storage.Provider using the *storage.SQLiteProvider given a valid configuration.
+//
+// Warning: This method may panic if the provided configuration isn't validated.
+func NewStorageSQLite(config *schema.Configuration) storage.Provider {
+ return storage.NewSQLiteProvider(config)
+}
diff --git a/experimental/embed/types.go b/experimental/embed/types.go
new file mode 100644
index 000000000..bd048206e
--- /dev/null
+++ b/experimental/embed/types.go
@@ -0,0 +1,24 @@
+package embed
+
+import (
+ "github.com/authelia/authelia/v4/internal/configuration/schema"
+ "github.com/authelia/authelia/v4/internal/middlewares"
+)
+
+// Configuration is a type alias for the internal schema.Configuration type. It allows manually configuring Authelia
+// and transitioning to the internal implementation.
+type Configuration schema.Configuration
+
+// ToInternal converts this Configuration struct into a *schema.Configuration struct using a type cast.
+func (c *Configuration) ToInternal() *schema.Configuration {
+ return (*schema.Configuration)(c)
+}
+
+// Providers is a type alias for the internal middlewares.Providers type. It allows manually performing setup of the
+// various Authelia providers and transitioning to the internal implementation.
+type Providers middlewares.Providers
+
+// ToInternal converts this Providers struct into a middlewares.Providers struct using a type cast.
+func (p Providers) ToInternal() middlewares.Providers {
+ return middlewares.Providers(p)
+}
diff --git a/experimental/embed/types_test.go b/experimental/embed/types_test.go
new file mode 100644
index 000000000..3a93ce183
--- /dev/null
+++ b/experimental/embed/types_test.go
@@ -0,0 +1,15 @@
+package embed
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestTypeBasics(t *testing.T) {
+ c := &Configuration{}
+ assert.NotNil(t, c.ToInternal())
+
+ p := &Providers{}
+ assert.NotNil(t, p.ToInternal())
+}
diff --git a/internal/authentication/file_user_provider.go b/internal/authentication/file_user_provider.go
index 9be6d8edf..bd7af3417 100644
--- a/internal/authentication/file_user_provider.go
+++ b/internal/authentication/file_user_provider.go
@@ -77,7 +77,7 @@ func (p *FileUserProvider) Reload() (reloaded bool, err error) {
return true, nil
}
-func (p *FileUserProvider) Shutdown() (err error) {
+func (p *FileUserProvider) Close() (err error) {
return nil
}
diff --git a/internal/authentication/ldap_user_provider_lifecycle.go b/internal/authentication/ldap_user_provider_lifecycle.go
index 73dceaa6a..8f45c99f2 100644
--- a/internal/authentication/ldap_user_provider_lifecycle.go
+++ b/internal/authentication/ldap_user_provider_lifecycle.go
@@ -10,7 +10,7 @@ import (
"github.com/authelia/authelia/v4/internal/utils"
)
-func (p *LDAPUserProvider) Shutdown() (err error) {
+func (p *LDAPUserProvider) Close() (err error) {
return p.factory.Close()
}
diff --git a/internal/authentication/ldap_user_provider_test.go b/internal/authentication/ldap_user_provider_test.go
index 3975faa39..7b6b4ee73 100644
--- a/internal/authentication/ldap_user_provider_test.go
+++ b/internal/authentication/ldap_user_provider_test.go
@@ -457,7 +457,7 @@ func TestShouldCheckLDAPServerExtensionsPooled(t *testing.T) {
assert.False(t, provider.features.ControlTypes.MsftPwdPolHints)
assert.False(t, provider.features.ControlTypes.MsftPwdPolHintsDeprecated)
- assert.EqualError(t, provider.Shutdown(), "errors occurred closing the client pool: close error")
+ assert.EqualError(t, provider.Close(), "errors occurred closing the client pool: close error")
}
func TestShouldNotCheckLDAPServerExtensionsWhenRootDSEReturnsMoreThanOneEntry(t *testing.T) {
@@ -597,7 +597,7 @@ func TestShouldNotCheckLDAPServerExtensionsWhenRootDSEReturnsMoreThanOneEntryPoo
assert.False(t, provider.features.ControlTypes.MsftPwdPolHints)
assert.False(t, provider.features.ControlTypes.MsftPwdPolHintsDeprecated)
- assert.NoError(t, provider.Shutdown())
+ assert.NoError(t, provider.Close())
}
func TestShouldNotCheckLDAPServerExtensionsWhenRootDSEReturnsMoreThanOneEntryPooledClosing(t *testing.T) {
@@ -680,7 +680,7 @@ func TestShouldNotCheckLDAPServerExtensionsWhenRootDSEReturnsMoreThanOneEntryPoo
assert.False(t, provider.features.ControlTypes.MsftPwdPolHints)
assert.False(t, provider.features.ControlTypes.MsftPwdPolHintsDeprecated)
- assert.NoError(t, provider.Shutdown())
+ assert.NoError(t, provider.Close())
}
func TestShouldCheckLDAPServerControlTypes(t *testing.T) {
@@ -820,7 +820,7 @@ func TestShouldCheckLDAPServerControlTypesPooled(t *testing.T) {
assert.True(t, provider.features.ControlTypes.MsftPwdPolHints)
assert.True(t, provider.features.ControlTypes.MsftPwdPolHintsDeprecated)
- assert.NoError(t, provider.Shutdown())
+ assert.NoError(t, provider.Close())
}
func TestShouldNotEnablePasswdModifyExtensionOrControlTypes(t *testing.T) {
@@ -886,7 +886,7 @@ func TestShouldNotEnablePasswdModifyExtensionOrControlTypes(t *testing.T) {
assert.False(t, provider.features.ControlTypes.MsftPwdPolHints)
assert.False(t, provider.features.ControlTypes.MsftPwdPolHintsDeprecated)
- assert.NoError(t, provider.Shutdown())
+ assert.NoError(t, provider.Close())
}
func TestShouldNotEnablePasswdModifyExtensionOrControlTypesPooled(t *testing.T) {
@@ -962,7 +962,7 @@ func TestShouldNotEnablePasswdModifyExtensionOrControlTypesPooled(t *testing.T)
assert.False(t, provider.features.ControlTypes.MsftPwdPolHints)
assert.False(t, provider.features.ControlTypes.MsftPwdPolHintsDeprecated)
- assert.NoError(t, provider.Shutdown())
+ assert.NoError(t, provider.Close())
}
func TestShouldReturnCheckServerConnectError(t *testing.T) {
@@ -1092,7 +1092,7 @@ func TestShouldReturnCheckServerSearchErrorPooled(t *testing.T) {
assert.False(t, provider.features.Extensions.PwdModifyExOp)
- assert.NoError(t, provider.Shutdown())
+ assert.NoError(t, provider.Close())
}
func TestShouldPermitRootDSEFailure(t *testing.T) {
@@ -1187,7 +1187,7 @@ func TestShouldPermitRootDSEFailurePooled(t *testing.T) {
)
assert.NoError(t, provider.StartupCheck())
- assert.NoError(t, provider.Shutdown())
+ assert.NoError(t, provider.Close())
}
type SearchRequestMatcher struct {
diff --git a/internal/authentication/user_provider.go b/internal/authentication/user_provider.go
index 9d56e65cd..b1addebc9 100644
--- a/internal/authentication/user_provider.go
+++ b/internal/authentication/user_provider.go
@@ -22,5 +22,5 @@ type UserProvider interface {
// ChangePassword is used to change a user's password but requires their old password to be successfully verified.
ChangePassword(username string, oldPassword string, newPassword string) (err error)
- Shutdown() (err error)
+ Close() (err error)
}
diff --git a/internal/commands/const.go b/internal/commands/const.go
index ffe416472..095f21de0 100644
--- a/internal/commands/const.go
+++ b/internal/commands/const.go
@@ -926,8 +926,6 @@ new one is compatible for and retrofitting it would be incredibly difficult.`
)
const (
- fmtLogServerListening = "Listening for %s connections on '%s' path '%s'"
-
fmtYAMLConfigTemplateHeader = `
---
##
@@ -947,28 +945,6 @@ const (
)
const (
- logFieldService = "service"
- logFieldFile = "file"
- logFieldOP = "op"
-
- serviceTypeServer = "server"
- serviceTypeWatcher = "watcher"
- serviceTypeSignal = "signal"
-
- logFieldProvider = "provider"
- logMessageStartupCheckError = "Error occurred running a startup check"
- logMessageStartupCheckPerforming = "Performing Startup Check"
- logMessageStartupCheckSuccess = "Startup Check Completed Successfully"
-
- providerNameNTP = "ntp"
- providerNameStorage = "storage"
- providerNameUser = "user"
- providerNameNotification = "notification"
- providerNameExpressions = "expressions"
- providerNameWebAuthnMetaData = "webauthn-metadata"
-)
-
-const (
wordYes = "Yes"
wordNo = "No"
)
diff --git a/internal/commands/context.go b/internal/commands/context.go
index 77e9c8491..a3dd32f4f 100644
--- a/internal/commands/context.go
+++ b/internal/commands/context.go
@@ -15,27 +15,14 @@ import (
"github.com/spf13/cobra"
"github.com/spf13/pflag"
- "github.com/authelia/authelia/v4/internal/authentication"
- "github.com/authelia/authelia/v4/internal/authorization"
- "github.com/authelia/authelia/v4/internal/clock"
"github.com/authelia/authelia/v4/internal/configuration"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/configuration/validator"
- "github.com/authelia/authelia/v4/internal/expression"
"github.com/authelia/authelia/v4/internal/logging"
- "github.com/authelia/authelia/v4/internal/metrics"
"github.com/authelia/authelia/v4/internal/middlewares"
- "github.com/authelia/authelia/v4/internal/notification"
- "github.com/authelia/authelia/v4/internal/ntp"
- "github.com/authelia/authelia/v4/internal/oidc"
"github.com/authelia/authelia/v4/internal/random"
- "github.com/authelia/authelia/v4/internal/regulation"
- "github.com/authelia/authelia/v4/internal/session"
"github.com/authelia/authelia/v4/internal/storage"
- "github.com/authelia/authelia/v4/internal/templates"
- "github.com/authelia/authelia/v4/internal/totp"
"github.com/authelia/authelia/v4/internal/utils"
- "github.com/authelia/authelia/v4/internal/webauthn"
)
// NewCmdCtx returns a new CmdCtx.
@@ -44,7 +31,7 @@ func NewCmdCtx() *CmdCtx {
return &CmdCtx{
Context: ctx,
- log: logging.Logger(),
+ log: logrus.NewEntry(logging.Logger()),
providers: middlewares.Providers{
Random: &random.Cryptographical{},
},
@@ -56,7 +43,7 @@ func NewCmdCtx() *CmdCtx {
type CmdCtx struct {
context.Context
- log *logrus.Logger
+ log *logrus.Entry
config *schema.Configuration
providers middlewares.Providers
@@ -87,7 +74,7 @@ type CmdCtxConfig struct {
type CobraRunECmd func(cmd *cobra.Command, args []string) (err error)
// GetLogger returns the *logrus.Logger satisfying part of the ServiceCtx.
-func (ctx *CmdCtx) GetLogger() *logrus.Logger {
+func (ctx *CmdCtx) GetLogger() *logrus.Entry {
return ctx.log
}
@@ -154,50 +141,11 @@ func (ctx *CmdCtx) LoadTrustedCertificates() (warns, errs []error) {
// LoadProviders loads all providers into the CmdCtx.
func (ctx *CmdCtx) LoadProviders() (warns, errs []error) {
- // TODO: Adjust this so the CertPool can be used like a provider.
if warns, errs = ctx.LoadTrustedCertificates(); len(warns) != 0 || len(errs) != 0 {
return warns, errs
}
- ctx.providers.StorageProvider = getStorageProvider(ctx)
-
- ctx.providers.Authorizer = authorization.NewAuthorizer(ctx.config)
- ctx.providers.NTP = ntp.NewProvider(&ctx.config.NTP)
- ctx.providers.PasswordPolicy = middlewares.NewPasswordPolicyProvider(ctx.config.PasswordPolicy)
- ctx.providers.Regulator = regulation.NewRegulator(ctx.config.Regulation, ctx.providers.StorageProvider, clock.New())
- ctx.providers.SessionProvider = session.NewProvider(ctx.config.Session, ctx.trusted)
- ctx.providers.TOTP = totp.NewTimeBasedProvider(ctx.config.TOTP)
- ctx.providers.UserAttributeResolver = expression.NewUserAttributes(ctx.config)
-
- var err error
-
- switch {
- case ctx.config.AuthenticationBackend.File != nil:
- ctx.providers.UserProvider = authentication.NewFileUserProvider(ctx.config.AuthenticationBackend.File)
- case ctx.config.AuthenticationBackend.LDAP != nil:
- ctx.providers.UserProvider = authentication.NewLDAPUserProvider(ctx.config.AuthenticationBackend, ctx.trusted)
- }
-
- if ctx.providers.Templates, err = templates.New(templates.Config{EmailTemplatesPath: ctx.config.Notifier.TemplatePath}); err != nil {
- errs = append(errs, err)
- }
-
- if ctx.providers.MetaDataService, err = webauthn.NewMetaDataProvider(ctx.config, ctx.providers.StorageProvider); err != nil {
- errs = append(errs, err)
- }
-
- switch {
- case ctx.config.Notifier.SMTP != nil:
- ctx.providers.Notifier = notification.NewSMTPNotifier(ctx.config.Notifier.SMTP, ctx.trusted)
- case ctx.config.Notifier.FileSystem != nil:
- ctx.providers.Notifier = notification.NewFileNotifier(*ctx.config.Notifier.FileSystem)
- }
-
- ctx.providers.OpenIDConnect = oidc.NewOpenIDConnectProvider(ctx.config, ctx.providers.StorageProvider, ctx.providers.Templates)
-
- if ctx.config.Telemetry.Metrics.Enabled {
- ctx.providers.Metrics = metrics.NewPrometheus()
- }
+ ctx.providers, warns, errs = middlewares.NewProviders(ctx.config, ctx.trusted)
return warns, errs
}
diff --git a/internal/commands/helpers.go b/internal/commands/helpers.go
index 84ff9074c..af889af1c 100644
--- a/internal/commands/helpers.go
+++ b/internal/commands/helpers.go
@@ -13,16 +13,7 @@ import (
)
func getStorageProvider(ctx *CmdCtx) (provider storage.Provider) {
- switch {
- case ctx.config.Storage.PostgreSQL != nil:
- return storage.NewPostgreSQLProvider(ctx.config, ctx.trusted)
- case ctx.config.Storage.MySQL != nil:
- return storage.NewMySQLProvider(ctx.config, ctx.trusted)
- case ctx.config.Storage.Local != nil:
- return storage.NewSQLiteProvider(ctx.config)
- default:
- return nil
- }
+ return storage.NewProvider(ctx.config, ctx.trusted)
}
func containsIdentifier(identifier model.UserOpaqueIdentifier, identifiers []model.UserOpaqueIdentifier) bool {
diff --git a/internal/commands/root.go b/internal/commands/root.go
index e294cccfd..690779d07 100644
--- a/internal/commands/root.go
+++ b/internal/commands/root.go
@@ -1,13 +1,15 @@
package commands
import (
+ "errors"
"fmt"
"os"
"github.com/spf13/cobra"
"github.com/authelia/authelia/v4/internal/logging"
- "github.com/authelia/authelia/v4/internal/model"
+ "github.com/authelia/authelia/v4/internal/middlewares"
+ "github.com/authelia/authelia/v4/internal/service"
"github.com/authelia/authelia/v4/internal/utils"
)
@@ -82,102 +84,22 @@ func (ctx *CmdCtx) RootRunE(_ *cobra.Command, _ []string) (err error) {
ctx.log.Error(err)
}
- ctx.log.Fatalf("Errors occurred provisioning providers.")
+ ctx.log.Fatal("Errors occurred provisioning providers")
}
- doStartupChecks(ctx)
+ if err = ctx.providers.StartupChecks(ctx, true); err != nil {
+ var scerr *middlewares.ErrProviderStartupCheck
- ctx.cconfig = nil
-
- ctx.log.Trace("Starting Services")
-
- servicesRun(ctx)
-
- return nil
-}
-
-func doStartupChecks(ctx *CmdCtx) {
- var (
- failures []string
- err error
- )
-
- ctx.log.WithFields(map[string]any{logFieldProvider: providerNameStorage}).Trace(logMessageStartupCheckPerforming)
-
- if err = doStartupCheck(ctx, providerNameStorage, ctx.providers.StorageProvider, false); err != nil {
- ctx.log.WithError(err).WithField(logFieldProvider, providerNameStorage).Error(logMessageStartupCheckError)
-
- failures = append(failures, providerNameStorage)
- } else {
- ctx.log.WithFields(map[string]any{logFieldProvider: providerNameStorage}).Trace(logMessageStartupCheckSuccess)
- }
-
- ctx.log.WithFields(map[string]any{logFieldProvider: providerNameUser}).Trace(logMessageStartupCheckPerforming)
-
- if err = doStartupCheck(ctx, providerNameUser, ctx.providers.UserProvider, false); err != nil {
- ctx.log.WithError(err).WithField(logFieldProvider, providerNameUser).Error(logMessageStartupCheckError)
-
- failures = append(failures, providerNameUser)
- } else {
- ctx.log.WithFields(map[string]any{logFieldProvider: providerNameUser}).Trace(logMessageStartupCheckSuccess)
- }
-
- ctx.log.WithFields(map[string]any{logFieldProvider: providerNameNotification}).Trace(logMessageStartupCheckPerforming)
-
- if err = doStartupCheck(ctx, providerNameNotification, ctx.providers.Notifier, ctx.config.Notifier.DisableStartupCheck); err != nil {
- ctx.log.WithError(err).WithField(logFieldProvider, providerNameNotification).Error(logMessageStartupCheckError)
-
- failures = append(failures, providerNameNotification)
- } else {
- ctx.log.WithFields(map[string]any{logFieldProvider: providerNameNotification}).Trace(logMessageStartupCheckSuccess)
- }
-
- ctx.log.WithFields(map[string]any{logFieldProvider: providerNameNTP}).Trace(logMessageStartupCheckPerforming)
-
- if err = doStartupCheck(ctx, providerNameNTP, ctx.providers.NTP, ctx.config.NTP.DisableStartupCheck); err != nil {
- if !ctx.config.NTP.DisableFailure {
- ctx.log.WithError(err).WithField(logFieldProvider, providerNameNTP).Error(logMessageStartupCheckError)
-
- failures = append(failures, providerNameNTP)
+ if errors.As(err, &scerr) {
+ ctx.GetLogger().WithField("providers", scerr.Failed()).Fatalf("One or more providers had fatal failures performing startup checks, for more details check the error level logs")
} else {
- ctx.log.WithError(err).WithField(logFieldProvider, providerNameNTP).Warn(logMessageStartupCheckError)
+ ctx.log.Fatal("Errors occurred performing startup checks")
}
- } else {
- ctx.log.WithFields(map[string]any{logFieldProvider: providerNameNTP}).Trace(logMessageStartupCheckSuccess)
- }
-
- ctx.log.WithFields(map[string]any{logFieldProvider: providerNameExpressions}).Trace(logMessageStartupCheckPerforming)
-
- if err = doStartupCheck(ctx, providerNameExpressions, ctx.providers.UserAttributeResolver, false); err != nil {
- ctx.log.WithError(err).WithField(logFieldProvider, providerNameExpressions).Error(logMessageStartupCheckError)
-
- failures = append(failures, providerNameExpressions)
- } else {
- ctx.log.WithFields(map[string]any{logFieldProvider: providerNameExpressions}).Trace(logMessageStartupCheckSuccess)
- }
-
- if err = doStartupCheck(ctx, providerNameWebAuthnMetaData, ctx.providers.MetaDataService, !ctx.config.WebAuthn.Metadata.Enabled || ctx.providers.MetaDataService == nil); err != nil {
- ctx.log.WithError(err).WithField(logFieldProvider, providerNameWebAuthnMetaData).Error(logMessageStartupCheckError)
-
- failures = append(failures, providerNameWebAuthnMetaData)
- } else {
- ctx.log.WithFields(map[string]any{logFieldProvider: providerNameWebAuthnMetaData}).Trace("Startup Check Completed Successfully")
}
- if len(failures) != 0 {
- ctx.log.WithField("providers", failures).Fatalf("One or more providers had fatal failures performing startup checks, for more detail check the error level logs")
- }
-}
-
-func doStartupCheck(ctx *CmdCtx, name string, provider model.StartupCheck, disabled bool) error {
- if disabled {
- ctx.log.Debugf("%s provider: startup check skipped as it is disabled", name)
- return nil
- }
+ ctx.cconfig = nil
- if provider == nil {
- return fmt.Errorf("unrecognized provider or it is not configured properly")
- }
+ ctx.log.Trace("Starting Services")
- return provider.StartupCheck()
+ return service.RunAll(ctx)
}
diff --git a/internal/commands/services.go b/internal/commands/services.go
deleted file mode 100644
index 18598b488..000000000
--- a/internal/commands/services.go
+++ /dev/null
@@ -1,463 +0,0 @@
-package commands
-
-import (
- "context"
- "fmt"
- "net"
- "os"
- "os/signal"
- "path/filepath"
- "strings"
- "sync"
- "syscall"
- "time"
-
- "github.com/fsnotify/fsnotify"
- "github.com/sirupsen/logrus"
- "github.com/valyala/fasthttp"
- "golang.org/x/sync/errgroup"
-
- "github.com/authelia/authelia/v4/internal/authentication"
- "github.com/authelia/authelia/v4/internal/configuration/schema"
- "github.com/authelia/authelia/v4/internal/logging"
- "github.com/authelia/authelia/v4/internal/middlewares"
- "github.com/authelia/authelia/v4/internal/server"
-)
-
-// NewServerService creates a new ServerService with the appropriate logger etc.
-func NewServerService(name string, server *fasthttp.Server, listener net.Listener, paths []string, isTLS bool, log *logrus.Logger) (service *ServerService) {
- return &ServerService{
- name: name,
- server: server,
- listener: listener,
- paths: paths,
- isTLS: isTLS,
- log: log.WithFields(map[string]any{logFieldService: serviceTypeServer, serviceTypeServer: name}),
- }
-}
-
-// NewFileWatcherService creates a new FileWatcherService with the appropriate logger etc.
-func NewFileWatcherService(name, path string, reload ProviderReload, log *logrus.Logger) (service *FileWatcherService, err error) {
- if path == "" {
- return nil, fmt.Errorf("path must be specified")
- }
-
- var info os.FileInfo
-
- if info, err = os.Stat(path); err != nil {
- return nil, fmt.Errorf("error stating file '%s': %w", path, err)
- }
-
- if path, err = filepath.Abs(path); err != nil {
- return nil, fmt.Errorf("error determining absolute path of file '%s': %w", path, err)
- }
-
- var watcher *fsnotify.Watcher
-
- if watcher, err = fsnotify.NewWatcher(); err != nil {
- return nil, err
- }
-
- entry := log.WithFields(map[string]any{logFieldService: serviceTypeWatcher, serviceTypeWatcher: name})
-
- if info.IsDir() {
- service = &FileWatcherService{
- name: name,
- watcher: watcher,
- reload: reload,
- log: entry,
- directory: filepath.Clean(path),
- }
- } else {
- service = &FileWatcherService{
- name: name,
- watcher: watcher,
- reload: reload,
- log: entry,
- directory: filepath.Dir(path),
- file: filepath.Base(path),
- }
- }
-
- if err = service.watcher.Add(service.directory); err != nil {
- return nil, fmt.Errorf("failed to add path '%s' to watch list: %w", path, err)
- }
-
- return service, nil
-}
-
-// NewSignalService creates a new SignalService with the appropriate logger etc.
-func NewSignalService(name string, action func() (err error), log *logrus.Logger, signals ...os.Signal) (service *SignalService) {
- return &SignalService{
- name: name,
- signals: signals,
- action: action,
- log: log.WithFields(map[string]any{logFieldService: serviceTypeSignal, serviceTypeSignal: name}),
- }
-}
-
-type ServiceCtx interface {
- GetLogger() *logrus.Logger
- GetProviders() middlewares.Providers
- GetConfiguration() *schema.Configuration
-
- context.Context
-}
-
-// ProviderReload represents the required methods to support reloading a provider.
-type ProviderReload interface {
- Reload() (reloaded bool, err error)
-}
-
-// Service represents the required methods to support handling a service.
-type Service interface {
- // ServiceType returns the type name for the Service.
- ServiceType() string
-
- // ServiceName returns the individual name for the Service.
- ServiceName() string
-
- // Run performs the running operations for the Service.
- Run() (err error)
-
- // Shutdown perform the shutdown cleanup and termination operations for the Service.
- Shutdown()
-
- // Log returns the logger configured for the service.
- Log() *logrus.Entry
-}
-
-// ServerService is a Service which runs a web server.
-type ServerService struct {
- name string
- server *fasthttp.Server
- paths []string
- isTLS bool
- listener net.Listener
- log *logrus.Entry
-}
-
-// ServiceType returns the service type for this service, which is always 'server'.
-func (service *ServerService) ServiceType() string {
- return serviceTypeServer
-}
-
-// ServiceName returns the individual name for this service.
-func (service *ServerService) ServiceName() string {
- return service.name
-}
-
-// Run the ServerService.
-func (service *ServerService) Run() (err error) {
- defer func() {
- if r := recover(); r != nil {
- service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)")
- }
- }()
-
- service.log.Infof(fmtLogServerListening, connectionType(service.isTLS), service.listener.Addr().String(), strings.Join(service.paths, "' and '"))
-
- if err = service.server.Serve(service.listener); err != nil {
- service.log.WithError(err).Error("Error returned attempting to serve requests")
-
- return err
- }
-
- return nil
-}
-
-// Shutdown the ServerService.
-func (service *ServerService) Shutdown() {
- ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
-
- defer cancel()
-
- if err := service.server.ShutdownWithContext(ctx); err != nil {
- service.log.WithError(err).Error("Error occurred during shutdown")
- }
-}
-
-// Log returns the *logrus.Entry of the ServerService.
-func (service *ServerService) Log() *logrus.Entry {
- return service.log
-}
-
-// FileWatcherService is a Service that watches files for changes.
-type FileWatcherService struct {
- name string
-
- watcher *fsnotify.Watcher
- reload ProviderReload
-
- log *logrus.Entry
- file string
- directory string
-}
-
-// ServiceType returns the service type for this service, which is always 'watcher'.
-func (service *FileWatcherService) ServiceType() string {
- return serviceTypeWatcher
-}
-
-// ServiceName returns the individual name for this service.
-func (service *FileWatcherService) ServiceName() string {
- return service.name
-}
-
-// Run the FileWatcherService.
-func (service *FileWatcherService) Run() (err error) {
- defer func() {
- if r := recover(); r != nil {
- service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)")
- }
- }()
-
- service.log.WithField(logFieldFile, filepath.Join(service.directory, service.file)).Info("Watching file for changes")
-
- for {
- select {
- case event, ok := <-service.watcher.Events:
- if !ok {
- return nil
- }
-
- log := service.log.WithFields(map[string]any{logFieldFile: event.Name, logFieldOP: event.Op})
-
- if service.file != "" && service.file != filepath.Base(event.Name) {
- log.Trace("File modification detected to irrelevant file")
- break
- }
-
- switch {
- case event.Op&fsnotify.Write == fsnotify.Write, event.Op&fsnotify.Create == fsnotify.Create:
- log.Debug("File modification was detected")
-
- var reloaded bool
-
- switch reloaded, err = service.reload.Reload(); {
- case err != nil:
- log.WithError(err).Error("Error occurred during reload")
- case reloaded:
- log.Info("Reloaded successfully")
- default:
- log.Debug("Reload was triggered but it was skipped")
- }
- case event.Op&fsnotify.Remove == fsnotify.Remove:
- log.Debug("File remove was detected")
- }
- case err, ok := <-service.watcher.Errors:
- if !ok {
- return nil
- }
-
- service.log.WithError(err).Error("Error while watching file for changes")
- }
- }
-}
-
-// Shutdown the FileWatcherService.
-func (service *FileWatcherService) Shutdown() {
- if err := service.watcher.Close(); err != nil {
- service.log.WithError(err).Error("Error occurred during shutdown")
- }
-}
-
-// Log returns the *logrus.Entry of the FileWatcherService.
-func (service *FileWatcherService) Log() *logrus.Entry {
- return service.log
-}
-
-// SignalService is a Service which performs actions on signals.
-type SignalService struct {
- name string
- signals []os.Signal
- action func() (err error)
- log *logrus.Entry
-
- notify chan os.Signal
- quit chan struct{}
-}
-
-// ServiceType returns the service type for this service, which is always 'server'.
-func (service *SignalService) ServiceType() string {
- return serviceTypeSignal
-}
-
-// ServiceName returns the individual name for this service.
-func (service *SignalService) ServiceName() string {
- return service.name
-}
-
-// Run the ServerService.
-func (service *SignalService) Run() (err error) {
- service.quit = make(chan struct{})
-
- service.notify = make(chan os.Signal, 1)
-
- signal.Notify(service.notify, service.signals...)
-
- for {
- select {
- case s := <-service.notify:
- if err = service.action(); err != nil {
- service.log.WithError(err).Error("Error occurred executing service action.")
- } else {
- service.log.WithFields(map[string]any{"signal-received": s.String()}).Debug("Successfully executed service action.")
- }
- case <-service.quit:
- return
- }
- }
-}
-
-// Shutdown the ServerService.
-func (service *SignalService) Shutdown() {
- signal.Stop(service.notify)
-
- service.quit <- struct{}{}
-}
-
-// Log returns the *logrus.Entry of the ServerService.
-func (service *SignalService) Log() *logrus.Entry {
- return service.log
-}
-
-func svcSvrMainFunc(ctx ServiceCtx) (service Service) {
- switch svr, listener, paths, isTLS, err := server.CreateDefaultServer(ctx.GetConfiguration(), ctx.GetProviders()); {
- case err != nil:
- ctx.GetLogger().WithError(err).Fatal("Create Server Service (main) returned error")
- case svr != nil && listener != nil:
- service = NewServerService("main", svr, listener, paths, isTLS, ctx.GetLogger())
- default:
- ctx.GetLogger().Fatal("Create Server Service (main) failed")
- }
-
- return service
-}
-
-func svcSvrMetricsFunc(ctx ServiceCtx) (service Service) {
- switch svr, listener, paths, isTLS, err := server.CreateMetricsServer(ctx.GetConfiguration(), ctx.GetProviders()); {
- case err != nil:
- ctx.GetLogger().WithError(err).Fatal("Create Server Service (metrics) returned error")
- case svr != nil && listener != nil:
- service = NewServerService("metrics", svr, listener, paths, isTLS, ctx.GetLogger())
- default:
- ctx.GetLogger().Debug("Create Server Service (metrics) skipped")
- }
-
- return service
-}
-
-func svcWatcherUsersFunc(ctx ServiceCtx) (service Service) {
- var err error
-
- config := ctx.GetConfiguration()
-
- if config.AuthenticationBackend.File != nil && config.AuthenticationBackend.File.Watch {
- provider := ctx.GetProviders().UserProvider.(*authentication.FileUserProvider)
-
- if service, err = NewFileWatcherService("users", config.AuthenticationBackend.File.Path, provider, ctx.GetLogger()); err != nil {
- ctx.GetLogger().WithError(err).Fatal("Create Watcher Service (users) returned error")
- }
- }
-
- return service
-}
-
-func svcSignalLogReOpenFunc(ctx ServiceCtx) (service Service) {
- config := ctx.GetConfiguration()
-
- if config.Log.FilePath == "" {
- return nil
- }
-
- return NewSignalService("log-reload", logging.Reopen, ctx.GetLogger(), syscall.SIGHUP)
-}
-
-func connectionType(isTLS bool) string {
- if isTLS {
- return "TLS"
- }
-
- return "non-TLS"
-}
-
-func servicesRun(ctx ServiceCtx) {
- cctx, cancel := context.WithCancel(ctx)
-
- group, cctx := errgroup.WithContext(cctx)
-
- defer cancel()
-
- quit := make(chan os.Signal, 1)
-
- signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
-
- defer signal.Stop(quit)
-
- var (
- services []Service
- )
-
- for _, serviceFunc := range []func(ctx ServiceCtx) Service{
- svcSvrMainFunc, svcSvrMetricsFunc,
- svcWatcherUsersFunc, svcSignalLogReOpenFunc,
- } {
- if service := serviceFunc(ctx); service != nil {
- service.Log().Trace("Service Loaded")
-
- services = append(services, service)
-
- group.Go(service.Run)
- }
- }
-
- ctx.GetLogger().Info("Startup complete")
-
- select {
- case s := <-quit:
- ctx.GetLogger().WithField("signal", s.String()).Debug("Shutdown initiated due to process signal")
- case <-cctx.Done():
- ctx.GetLogger().Debug("Shutdown initiated due to context completion")
- }
-
- cancel()
-
- ctx.GetLogger().Info("Shutdown initiated")
-
- wgShutdown := &sync.WaitGroup{}
-
- ctx.GetLogger().Tracef("Shutdown of %d services is required", len(services))
-
- for _, service := range services {
- wgShutdown.Add(1)
-
- go func(service Service) {
- service.Log().Trace("Shutdown of service initiated")
-
- service.Shutdown()
-
- wgShutdown.Done()
-
- service.Log().Trace("Shutdown of service complete")
- }(service)
- }
-
- wgShutdown.Wait()
-
- var err error
-
- if err = ctx.GetProviders().UserProvider.Shutdown(); err != nil {
- ctx.GetLogger().WithError(err).Error("Error occurred closing authentication connections")
- }
-
- if err = ctx.GetProviders().StorageProvider.Close(); err != nil {
- ctx.GetLogger().WithError(err).Error("Error occurred closing database connections")
- }
-
- if err = group.Wait(); err != nil {
- ctx.GetLogger().WithError(err).Error("Error occurred waiting for shutdown")
- }
-
- ctx.GetLogger().Info("Shutdown complete")
-}
diff --git a/internal/commands/util.go b/internal/commands/util.go
index 8686c2303..4950f38a8 100644
--- a/internal/commands/util.go
+++ b/internal/commands/util.go
@@ -21,19 +21,6 @@ import (
"github.com/authelia/authelia/v4/internal/utils"
)
-func recoverErr(i any) error {
- switch v := i.(type) {
- case nil:
- return nil
- case string:
- return fmt.Errorf("recovered panic: %s", v)
- case error:
- return fmt.Errorf("recovered panic: %w", v)
- default:
- return fmt.Errorf("recovered panic with unknown type: %v", v)
- }
-}
-
func flagsGetUserIdentifiersGenerateOptions(flags *pflag.FlagSet) (users, services, sectors []string, err error) {
if users, err = flags.GetStringSlice(cmdFlagNameUsers); err != nil {
return nil, nil, nil, err
@@ -72,7 +59,7 @@ func flagsGetRandomCharacters(flags *pflag.FlagSet, flagNameLength, flagNameChar
}
switch {
- case useCharSet, !useCharSet && !useCharacters:
+ case useCharSet, !useCharacters:
var c string
if c, err = flags.GetString(flagNameCharSet); err != nil {
@@ -107,7 +94,7 @@ func flagsGetRandomCharacters(flags *pflag.FlagSet, flagNameLength, flagNameChar
default:
return "", fmt.Errorf("flag '--%s' with value '%s' is invalid, must be one of 'ascii', 'alphanumeric', 'alphabetic', 'numeric', 'numeric-hex', or 'rfc3986'", flagNameCharSet, c)
}
- case useCharacters:
+ default:
if charset, err = flags.GetString(flagNameCharacters); err != nil {
return "", err
}
diff --git a/internal/configuration/koanf_util.go b/internal/configuration/koanf_util.go
index d0a0b1b0a..2b123d978 100644
--- a/internal/configuration/koanf_util.go
+++ b/internal/configuration/koanf_util.go
@@ -2,6 +2,7 @@ package configuration
import (
"fmt"
+ "sort"
"strings"
"github.com/knadh/koanf/providers/confmap"
@@ -35,6 +36,8 @@ func koanfGetKeys(ko *koanf.Koanf) (keys []string) {
}
}
+ sort.Strings(keys)
+
return keys
}
diff --git a/internal/configuration/test_resources/config_with_definitions.yml b/internal/configuration/test_resources/config_with_definitions.yml
index 99449df2d..cb34f4741 100644
--- a/internal/configuration/test_resources/config_with_definitions.yml
+++ b/internal/configuration/test_resources/config_with_definitions.yml
@@ -1,6 +1,4 @@
---
-default_redirection_url: 'https://home.example.com:8080/'
-
server:
address: 'tcp://127.0.0.1:9091'
endpoints:
@@ -161,7 +159,10 @@ session:
name: 'authelia_session'
expiration: '1h' # 1 hour
inactivity: '5m' # 5 minutes
- domain: 'example.com'
+ cookies:
+ - domain: 'example.com'
+ default_redirection_url: 'https://home.example.com:8080/'
+ authelia_url: 'https://auth.example.com'
redis:
host: '127.0.0.1'
port: 6379
diff --git a/internal/middlewares/authelia_context.go b/internal/middlewares/authelia_context.go
index 17ca8a81c..06587f05c 100644
--- a/internal/middlewares/authelia_context.go
+++ b/internal/middlewares/authelia_context.go
@@ -667,6 +667,11 @@ func (ctx *AutheliaCtx) GetConfiguration() (config schema.Configuration) {
return ctx.Configuration
}
+// GetProviders returns the providers for this context.
+func (ctx *AutheliaCtx) GetProviders() (providers Providers) {
+ return ctx.Providers
+}
+
func (ctx *AutheliaCtx) GetWebAuthnProvider() (w *webauthn.WebAuthn, err error) {
var (
origin *url.URL
diff --git a/internal/middlewares/authelia_context_test.go b/internal/middlewares/authelia_context_test.go
index 60c972f0c..54451890a 100644
--- a/internal/middlewares/authelia_context_test.go
+++ b/internal/middlewares/authelia_context_test.go
@@ -423,6 +423,14 @@ func TestShouldDetectNonXHR(t *testing.T) {
assert.False(t, mock.Ctx.IsXHR())
}
+func TestAutheliaCtxMisc(t *testing.T) {
+ ctx := middlewares.NewAutheliaCtx(&fasthttp.RequestCtx{}, schema.Configuration{}, middlewares.Providers{})
+
+ assert.NotNil(t, ctx.GetConfiguration())
+ assert.NotNil(t, ctx.GetProviders())
+ assert.NotNil(t, ctx.GetLogger())
+}
+
func TestShouldReturnCorrectSecondFactorMethods(t *testing.T) {
mock := mocks.NewMockAutheliaCtx(t)
defer mock.Close()
diff --git a/internal/middlewares/const.go b/internal/middlewares/const.go
index 25bf67093..7fb3bb30b 100644
--- a/internal/middlewares/const.go
+++ b/internal/middlewares/const.go
@@ -95,6 +95,20 @@ const (
UserValueRouterKeyExtAuthzPath = "extauthz"
)
+const (
+ LogFieldProvider = "provider"
+ LogMessageStartupCheckError = "Error occurred running a startup check"
+ LogMessageStartupCheckPerforming = "Performing Startup Check"
+ LogMessageStartupCheckSuccess = "Startup Check Completed Successfully"
+
+ ProviderNameNTP = "ntp"
+ ProviderNameStorage = "storage"
+ ProviderNameUser = "user"
+ ProviderNameNotification = "notification"
+ ProviderNameExpressions = "expressions"
+ ProviderNameWebAuthnMetaData = "webauthn-metadata"
+)
+
var (
protoHTTPS = []byte(strProtoHTTPS)
protoHTTP = []byte(strProtoHTTP)
diff --git a/internal/middlewares/startup.go b/internal/middlewares/startup.go
new file mode 100644
index 000000000..0664be342
--- /dev/null
+++ b/internal/middlewares/startup.go
@@ -0,0 +1,124 @@
+package middlewares
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/authelia/authelia/v4/internal/model"
+ "github.com/authelia/authelia/v4/internal/utils"
+)
+
+func (p *Providers) StartupChecks(ctx Context, log bool) (err error) {
+ e := &ErrProviderStartupCheck{errors: map[string]error{}}
+
+ var (
+ disable bool
+ provider model.StartupCheck
+ )
+
+ provider, disable = ctx.GetProviders().StorageProvider, false
+ doStartupCheck(ctx, ProviderNameStorage, provider, disable, log, e.errors)
+
+ provider, disable = ctx.GetProviders().UserProvider, false
+ doStartupCheck(ctx, ProviderNameUser, provider, disable, log, e.errors)
+
+ provider, disable = ctx.GetProviders().Notifier, false
+ doStartupCheck(ctx, ProviderNameNotification, provider, disable, log, e.errors)
+
+ provider, disable = ctx.GetProviders().NTP, ctx.GetConfiguration().NTP.DisableStartupCheck
+ doStartupCheck(ctx, ProviderNameNTP, provider, disable, log, e.errors)
+
+ provider, disable = ctx.GetProviders().UserAttributeResolver, false
+ doStartupCheck(ctx, ProviderNameExpressions, provider, disable, log, e.errors)
+
+ provider = ctx.GetProviders().MetaDataService
+ disable = !ctx.GetConfiguration().WebAuthn.Metadata.Enabled || ctx.GetProviders().MetaDataService == nil
+ doStartupCheck(ctx, ProviderNameWebAuthnMetaData, provider, disable, log, e.errors)
+
+ var filters []string
+
+ if ctx.GetConfiguration().NTP.DisableFailure {
+ filters = append(filters, ProviderNameNTP)
+ }
+
+ return e.FilterError(filters...)
+}
+
+func doStartupCheck(ctx Context, name string, provider model.StartupCheck, disabled, log bool, errors map[string]error) {
+ if log {
+ ctx.GetLogger().WithFields(map[string]any{LogFieldProvider: name}).Trace(LogMessageStartupCheckPerforming)
+ }
+
+ if disabled {
+ if log {
+ ctx.GetLogger().Debugf("%s provider: startup check skipped as it is disabled", name)
+ }
+
+ return
+ }
+
+ if provider == nil {
+ errors[name] = fmt.Errorf("unrecognized provider or it is not configured properly")
+
+ return
+ }
+
+ var err error
+
+ if err = provider.StartupCheck(); err != nil {
+ if log {
+ ctx.GetLogger().WithError(err).WithField(LogFieldProvider, name).Error(LogMessageStartupCheckError)
+ }
+
+ errors[name] = err
+
+ return
+ }
+
+ if log {
+ ctx.GetLogger().WithFields(map[string]any{LogFieldProvider: name}).Trace("Startup Check Completed Successfully")
+ }
+}
+
+type ErrProviderStartupCheck struct {
+ errors map[string]error
+}
+
+func (e *ErrProviderStartupCheck) Error() string {
+ keys := make([]string, 0, len(e.errors))
+ for k := range e.errors {
+ keys = append(keys, k)
+ }
+
+ return fmt.Sprintf("errors occurred performing checks on the '%s' providers", strings.Join(keys, ", "))
+}
+
+func (e *ErrProviderStartupCheck) Failed() (failed []string) {
+ for key := range e.errors {
+ failed = append(failed, key)
+ }
+
+ return failed
+}
+
+func (e *ErrProviderStartupCheck) FilterError(providers ...string) error {
+ filtered := map[string]error{}
+
+ for provider, err := range e.errors {
+ if utils.IsStringInSlice(provider, providers) {
+ continue
+ }
+
+ filtered[provider] = err
+ }
+
+ if len(filtered) == 0 {
+ return nil
+ }
+
+ return &ErrProviderStartupCheck{errors: filtered}
+}
+
+func (e *ErrProviderStartupCheck) ErrorMap() map[string]error {
+ return e.errors
+}
diff --git a/internal/middlewares/timing_attack_delay_test.go b/internal/middlewares/timing_attack_delay_test.go
index 976be04a8..8548d0e27 100644
--- a/internal/middlewares/timing_attack_delay_test.go
+++ b/internal/middlewares/timing_attack_delay_test.go
@@ -47,7 +47,7 @@ func TestTimingAttackDelayCalculations(t *testing.T) {
expectedMinimumDelayMs := avgExecDurationMs - float64(execDuration.Milliseconds())
ctx := &AutheliaCtx{
- Logger: logging.Logger().WithFields(logrus.Fields{}),
+ Logger: logrus.NewEntry(logging.Logger()),
Providers: Providers{
Random: &random.Cryptographical{},
},
diff --git a/internal/middlewares/types.go b/internal/middlewares/types.go
index c099c322e..dbace4014 100644
--- a/internal/middlewares/types.go
+++ b/internal/middlewares/types.go
@@ -1,6 +1,8 @@
package middlewares
import (
+ "context"
+
"github.com/sirupsen/logrus"
"github.com/valyala/fasthttp"
@@ -54,6 +56,14 @@ type Providers struct {
MetaDataService webauthn.MetaDataProvider
}
+type Context interface {
+ GetLogger() *logrus.Entry
+ GetProviders() Providers
+ GetConfiguration() *schema.Configuration
+
+ context.Context
+}
+
// RequestHandler represents an Authelia request handler.
type RequestHandler = func(*AutheliaCtx)
diff --git a/internal/middlewares/util.go b/internal/middlewares/util.go
index 03b80f557..7a7eb39e4 100644
--- a/internal/middlewares/util.go
+++ b/internal/middlewares/util.go
@@ -1,7 +1,26 @@
package middlewares
import (
+ "crypto/x509"
+
"github.com/valyala/fasthttp"
+
+ "github.com/authelia/authelia/v4/internal/authentication"
+ "github.com/authelia/authelia/v4/internal/authorization"
+ "github.com/authelia/authelia/v4/internal/clock"
+ "github.com/authelia/authelia/v4/internal/configuration/schema"
+ "github.com/authelia/authelia/v4/internal/expression"
+ "github.com/authelia/authelia/v4/internal/metrics"
+ "github.com/authelia/authelia/v4/internal/notification"
+ "github.com/authelia/authelia/v4/internal/ntp"
+ "github.com/authelia/authelia/v4/internal/oidc"
+ "github.com/authelia/authelia/v4/internal/random"
+ "github.com/authelia/authelia/v4/internal/regulation"
+ "github.com/authelia/authelia/v4/internal/session"
+ "github.com/authelia/authelia/v4/internal/storage"
+ "github.com/authelia/authelia/v4/internal/templates"
+ "github.com/authelia/authelia/v4/internal/totp"
+ "github.com/authelia/authelia/v4/internal/webauthn"
)
// SetContentTypeApplicationJSON sets the Content-Type header to `application/json; charset=utf-8`.
@@ -13,3 +32,48 @@ func SetContentTypeApplicationJSON(ctx *fasthttp.RequestCtx) {
func SetContentTypeTextPlain(ctx *fasthttp.RequestCtx) {
ctx.SetContentTypeBytes(contentTypeTextPlain)
}
+
+// NewProviders provisions all providers based on the configuration provided.
+func NewProviders(config *schema.Configuration, caCertPool *x509.CertPool) (providers Providers, warns, errs []error) {
+ providers.Random = &random.Cryptographical{}
+ providers.StorageProvider = storage.NewProvider(config, caCertPool)
+ providers.Authorizer = authorization.NewAuthorizer(config)
+ providers.NTP = ntp.NewProvider(&config.NTP)
+ providers.PasswordPolicy = NewPasswordPolicyProvider(config.PasswordPolicy)
+ providers.Regulator = regulation.NewRegulator(config.Regulation, providers.StorageProvider, clock.New())
+ providers.SessionProvider = session.NewProvider(config.Session, caCertPool)
+ providers.TOTP = totp.NewTimeBasedProvider(config.TOTP)
+ providers.UserAttributeResolver = expression.NewUserAttributes(config)
+
+ var err error
+
+ switch {
+ case config.AuthenticationBackend.File != nil:
+ providers.UserProvider = authentication.NewFileUserProvider(config.AuthenticationBackend.File)
+ case config.AuthenticationBackend.LDAP != nil:
+ providers.UserProvider = authentication.NewLDAPUserProvider(config.AuthenticationBackend, caCertPool)
+ }
+
+ if providers.Templates, err = templates.New(templates.Config{EmailTemplatesPath: config.Notifier.TemplatePath}); err != nil {
+ errs = append(errs, err)
+ }
+
+ if providers.MetaDataService, err = webauthn.NewMetaDataProvider(config, providers.StorageProvider); err != nil {
+ errs = append(errs, err)
+ }
+
+ switch {
+ case config.Notifier.SMTP != nil:
+ providers.Notifier = notification.NewSMTPNotifier(config.Notifier.SMTP, caCertPool)
+ case config.Notifier.FileSystem != nil:
+ providers.Notifier = notification.NewFileNotifier(*config.Notifier.FileSystem)
+ }
+
+ providers.OpenIDConnect = oidc.NewOpenIDConnectProvider(config, providers.StorageProvider, providers.Templates)
+
+ if config.Telemetry.Metrics.Enabled {
+ providers.Metrics = metrics.NewPrometheus()
+ }
+
+ return providers, warns, errs
+}
diff --git a/internal/mocks/user_provider.go b/internal/mocks/user_provider.go
index 204940221..46f17f3ca 100644
--- a/internal/mocks/user_provider.go
+++ b/internal/mocks/user_provider.go
@@ -69,6 +69,20 @@ func (mr *MockUserProviderMockRecorder) CheckUserPassword(username, password any
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckUserPassword", reflect.TypeOf((*MockUserProvider)(nil).CheckUserPassword), username, password)
}
+// Close mocks base method.
+func (m *MockUserProvider) Close() error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Close")
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Close indicates an expected call of Close.
+func (mr *MockUserProviderMockRecorder) Close() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockUserProvider)(nil).Close))
+}
+
// GetDetails mocks base method.
func (m *MockUserProvider) GetDetails(username string) (*authentication.UserDetails, error) {
m.ctrl.T.Helper()
@@ -99,20 +113,6 @@ func (mr *MockUserProviderMockRecorder) GetDetailsExtended(username any) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDetailsExtended", reflect.TypeOf((*MockUserProvider)(nil).GetDetailsExtended), username)
}
-// Shutdown mocks base method.
-func (m *MockUserProvider) Shutdown() error {
- m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Shutdown")
- ret0, _ := ret[0].(error)
- return ret0
-}
-
-// Shutdown indicates an expected call of Shutdown.
-func (mr *MockUserProviderMockRecorder) Shutdown() *gomock.Call {
- mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Shutdown", reflect.TypeOf((*MockUserProvider)(nil).Shutdown))
-}
-
// StartupCheck mocks base method.
func (m *MockUserProvider) StartupCheck() error {
m.ctrl.T.Helper()
diff --git a/internal/server/handlers.go b/internal/server/handlers.go
index d633f503e..9cff4755c 100644
--- a/internal/server/handlers.go
+++ b/internal/server/handlers.go
@@ -112,10 +112,10 @@ func handleMethodNotAllowed(ctx *fasthttp.RequestCtx) {
ctx.SetBodyString(fmt.Sprintf("%d %s", fasthttp.StatusMethodNotAllowed, fasthttp.StatusMessage(fasthttp.StatusMethodNotAllowed)))
}
+type RegisterRoutesBridgedFunc = func(r *router.Router, config *schema.Configuration, providers middlewares.Providers, bridge middlewares.Bridge)
+
//nolint:gocyclo
func handleRouter(config *schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
- log := logging.Logger()
-
optsTemplatedFile := NewTemplatedFileOptions(config)
serveIndexHandler := ServeTemplatedFile(providers.Templates.GetAssetIndexTemplate(), optsTemplatedFile)
@@ -215,32 +215,14 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers)
switch name {
case "legacy":
- log.
- WithField("path_prefix", pathAuthzLegacy).
- WithField("implementation", endpoint.Implementation).
- WithField("methods", "*").
- Trace("Registering Authz Endpoint")
-
r.ANY(pathAuthzLegacy, handler)
r.ANY(path.Join(pathAuthzLegacy, pathParamAuthzEnvoy), handler)
default:
switch endpoint.Implementation {
case handlers.AuthzImplLegacy.String(), handlers.AuthzImplExtAuthz.String():
- log.
- WithField("path_prefix", uri).
- WithField("implementation", endpoint.Implementation).
- WithField("methods", "*").
- Trace("Registering Authz Endpoint")
-
r.ANY(uri, handler)
r.ANY(path.Join(uri, pathParamAuthzEnvoy), handler)
default:
- log.
- WithField("path", uri).
- WithField("implementation", endpoint.Implementation).
- WithField("methods", []string{fasthttp.MethodGet, fasthttp.MethodHead}).
- Trace("Registering Authz Endpoint")
-
r.GET(uri, handler)
r.HEAD(uri, handler)
}
@@ -367,127 +349,141 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers)
}
if providers.OpenIDConnect != nil {
- bridgeOIDC := middlewares.NewBridgeBuilder(*config, providers).WithPreMiddlewares(
- middlewares.SecurityHeadersBase, middlewares.SecurityHeadersCSPNoneOpenIDConnect, middlewares.SecurityHeadersNoStore,
- ).Build()
+ RegisterOpenIDConnectRoutes(r, config, providers)
+ }
- r.GET("/api/oidc/consent", bridgeOIDC(handlers.OpenIDConnectConsentGET))
- r.POST("/api/oidc/consent", bridgeOIDC(handlers.OpenIDConnectConsentPOST))
+ r.RedirectFixedPath = false
+ r.HandleMethodNotAllowed = true
+ r.MethodNotAllowed = handleMethodNotAllowed
+ r.NotFound = handleNotFound(bridge(serveIndexHandler))
- allowedOrigins := utils.StringSliceFromURLs(config.IdentityProviders.OIDC.CORS.AllowedOrigins)
+ handler := middlewares.LogRequest(r.Handler)
+ if config.Server.Address.RouterPath() != "/" {
+ handler = middlewares.StripPath(config.Server.Address.RouterPath())(handler)
+ }
- r.OPTIONS(oidc.EndpointPathWellKnownOpenIDConfiguration, policyCORSPublicGET.HandleOPTIONS)
- r.GET(oidc.EndpointPathWellKnownOpenIDConfiguration, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, "openid_configuration"), policyCORSPublicGET.Middleware(bridgeOIDC(handlers.OpenIDConnectConfigurationWellKnownGET))))
+ handler = middlewares.MultiWrap(handler, middlewares.RecoverPanic, middlewares.NewMetricsRequest(providers.Metrics))
- r.OPTIONS(oidc.EndpointPathWellKnownOAuthAuthorizationServer, policyCORSPublicGET.HandleOPTIONS)
- r.GET(oidc.EndpointPathWellKnownOAuthAuthorizationServer, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, "oauth_configuration"), policyCORSPublicGET.Middleware(bridgeOIDC(handlers.OAuthAuthorizationServerWellKnownGET))))
+ return handler
+}
- r.OPTIONS(oidc.EndpointPathJWKs, policyCORSPublicGET.HandleOPTIONS)
- r.GET(oidc.EndpointPathJWKs, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, "jwks"), policyCORSPublicGET.Middleware(middlewareAPI(handlers.JSONWebKeySetGET))))
+// RegisterOpenIDConnectRoutes handles registration of OpenID Connect 1.0 routes.
+func RegisterOpenIDConnectRoutes(r *router.Router, config *schema.Configuration, providers middlewares.Providers) {
+ middlewareAPI := middlewares.NewBridgeBuilder(*config, providers).
+ WithPreMiddlewares(middlewares.SecurityHeadersBase, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone).
+ Build()
- // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
- r.OPTIONS("/api/oidc/jwks", policyCORSPublicGET.HandleOPTIONS)
- r.GET("/api/oidc/jwks", middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, "jwks"), policyCORSPublicGET.Middleware(bridgeOIDC(handlers.JSONWebKeySetGET))))
+ policyCORSPublicGET := middlewares.NewCORSPolicyBuilder().
+ WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodGet).
+ WithAllowedOrigins("*").
+ Build()
- policyCORSAuthorization := middlewares.NewCORSPolicyBuilder().
- WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodGet, fasthttp.MethodPost).
- WithAllowedOrigins(allowedOrigins...).
- WithEnabled(utils.IsStringInSlice(oidc.EndpointAuthorization, config.IdentityProviders.OIDC.CORS.Endpoints)).
- Build()
+ bridgeOIDC := middlewares.NewBridgeBuilder(*config, providers).WithPreMiddlewares(
+ middlewares.SecurityHeadersBase, middlewares.SecurityHeadersCSPNoneOpenIDConnect, middlewares.SecurityHeadersNoStore,
+ ).Build()
- authorization := middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointAuthorization), policyCORSAuthorization.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorization))))
+ r.GET("/api/oidc/consent", bridgeOIDC(handlers.OpenIDConnectConsentGET))
+ r.POST("/api/oidc/consent", bridgeOIDC(handlers.OpenIDConnectConsentPOST))
- r.OPTIONS(oidc.EndpointPathAuthorization, policyCORSAuthorization.HandleOnlyOPTIONS)
- r.GET(oidc.EndpointPathAuthorization, authorization)
- r.POST(oidc.EndpointPathAuthorization, authorization)
+ allowedOrigins := utils.StringSliceFromURLs(config.IdentityProviders.OIDC.CORS.AllowedOrigins)
- // TODO (james-d-elliott): Remove in GA. This is a legacy endpoint.
- r.OPTIONS("/api/oidc/authorize", policyCORSAuthorization.HandleOnlyOPTIONS)
- r.GET("/api/oidc/authorize", authorization)
- r.POST("/api/oidc/authorize", authorization)
+ r.OPTIONS(oidc.EndpointPathWellKnownOpenIDConfiguration, policyCORSPublicGET.HandleOPTIONS)
+ r.GET(oidc.EndpointPathWellKnownOpenIDConfiguration, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, "openid_configuration"), policyCORSPublicGET.Middleware(bridgeOIDC(handlers.OpenIDConnectConfigurationWellKnownGET))))
- policyCORSDeviceAuthorization := middlewares.NewCORSPolicyBuilder().
- WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost).
- WithAllowedOrigins(allowedOrigins...).
- WithEnabled(utils.IsStringInSlice(oidc.EndpointDeviceAuthorization, config.IdentityProviders.OIDC.CORS.Endpoints)).
- Build()
+ r.OPTIONS(oidc.EndpointPathWellKnownOAuthAuthorizationServer, policyCORSPublicGET.HandleOPTIONS)
+ r.GET(oidc.EndpointPathWellKnownOAuthAuthorizationServer, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, "oauth_configuration"), policyCORSPublicGET.Middleware(bridgeOIDC(handlers.OAuthAuthorizationServerWellKnownGET))))
- r.OPTIONS(oidc.EndpointPathDeviceAuthorization, policyCORSDeviceAuthorization.HandleOnlyOPTIONS)
- r.POST(oidc.EndpointPathDeviceAuthorization, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointDeviceAuthorization), policyCORSDeviceAuthorization.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthDeviceAuthorizationPOST)))))
- r.PUT(oidc.EndpointPathDeviceAuthorization, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointDeviceAuthorization), bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthDeviceAuthorizationPUT))))
+ r.OPTIONS(oidc.EndpointPathJWKs, policyCORSPublicGET.HandleOPTIONS)
+ r.GET(oidc.EndpointPathJWKs, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, "jwks"), policyCORSPublicGET.Middleware(middlewareAPI(handlers.JSONWebKeySetGET))))
- policyCORSPAR := middlewares.NewCORSPolicyBuilder().
- WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost).
- WithAllowedOrigins(allowedOrigins...).
- WithEnabled(utils.IsStringInSliceFold(oidc.EndpointPushedAuthorizationRequest, config.IdentityProviders.OIDC.CORS.Endpoints)).
- Build()
+ // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
+ r.OPTIONS("/api/oidc/jwks", policyCORSPublicGET.HandleOPTIONS)
+ r.GET("/api/oidc/jwks", middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, "jwks"), policyCORSPublicGET.Middleware(bridgeOIDC(handlers.JSONWebKeySetGET))))
- r.OPTIONS(oidc.EndpointPathPushedAuthorizationRequest, policyCORSPAR.HandleOnlyOPTIONS)
- r.POST(oidc.EndpointPathPushedAuthorizationRequest, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointPushedAuthorizationRequest), policyCORSPAR.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectPushedAuthorizationRequest)))))
+ policyCORSAuthorization := middlewares.NewCORSPolicyBuilder().
+ WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodGet, fasthttp.MethodPost).
+ WithAllowedOrigins(allowedOrigins...).
+ WithEnabled(utils.IsStringInSlice(oidc.EndpointAuthorization, config.IdentityProviders.OIDC.CORS.Endpoints)).
+ Build()
- policyCORSToken := middlewares.NewCORSPolicyBuilder().
- WithAllowCredentials(true).
- WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost).
- WithAllowedOrigins(allowedOrigins...).
- WithEnabled(utils.IsStringInSlice(oidc.EndpointToken, config.IdentityProviders.OIDC.CORS.Endpoints)).
- Build()
+ authorization := middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointAuthorization), policyCORSAuthorization.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorization))))
- r.OPTIONS(oidc.EndpointPathToken, policyCORSToken.HandleOPTIONS)
- r.POST(oidc.EndpointPathToken, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointToken), policyCORSToken.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectTokenPOST)))))
+ r.OPTIONS(oidc.EndpointPathAuthorization, policyCORSAuthorization.HandleOnlyOPTIONS)
+ r.GET(oidc.EndpointPathAuthorization, authorization)
+ r.POST(oidc.EndpointPathAuthorization, authorization)
- policyCORSUserinfo := middlewares.NewCORSPolicyBuilder().
- WithAllowCredentials(true).
- WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodGet, fasthttp.MethodPost).
- WithAllowedOrigins(allowedOrigins...).
- WithEnabled(utils.IsStringInSlice(oidc.EndpointUserinfo, config.IdentityProviders.OIDC.CORS.Endpoints)).
- Build()
+ // TODO (james-d-elliott): Remove in GA. This is a legacy endpoint.
+ r.OPTIONS("/api/oidc/authorize", policyCORSAuthorization.HandleOnlyOPTIONS)
+ r.GET("/api/oidc/authorize", authorization)
+ r.POST("/api/oidc/authorize", authorization)
- r.OPTIONS(oidc.EndpointPathUserinfo, policyCORSUserinfo.HandleOPTIONS)
- r.GET(oidc.EndpointPathUserinfo, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointUserinfo), policyCORSUserinfo.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo)))))
- r.POST(oidc.EndpointPathUserinfo, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointUserinfo), policyCORSUserinfo.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo)))))
+ policyCORSDeviceAuthorization := middlewares.NewCORSPolicyBuilder().
+ WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost).
+ WithAllowedOrigins(allowedOrigins...).
+ WithEnabled(utils.IsStringInSlice(oidc.EndpointDeviceAuthorization, config.IdentityProviders.OIDC.CORS.Endpoints)).
+ Build()
- policyCORSIntrospection := middlewares.NewCORSPolicyBuilder().
- WithAllowCredentials(true).
- WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost).
- WithAllowedOrigins(allowedOrigins...).
- WithEnabled(utils.IsStringInSlice(oidc.EndpointIntrospection, config.IdentityProviders.OIDC.CORS.Endpoints)).
- Build()
+ r.OPTIONS(oidc.EndpointPathDeviceAuthorization, policyCORSDeviceAuthorization.HandleOnlyOPTIONS)
+ r.POST(oidc.EndpointPathDeviceAuthorization, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointDeviceAuthorization), policyCORSDeviceAuthorization.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthDeviceAuthorizationPOST)))))
+ r.PUT(oidc.EndpointPathDeviceAuthorization, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointDeviceAuthorization), bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthDeviceAuthorizationPUT))))
- r.OPTIONS(oidc.EndpointPathIntrospection, policyCORSIntrospection.HandleOPTIONS)
- r.POST(oidc.EndpointPathIntrospection, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointIntrospection), policyCORSIntrospection.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST)))))
+ policyCORSPAR := middlewares.NewCORSPolicyBuilder().
+ WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost).
+ WithAllowedOrigins(allowedOrigins...).
+ WithEnabled(utils.IsStringInSliceFold(oidc.EndpointPushedAuthorizationRequest, config.IdentityProviders.OIDC.CORS.Endpoints)).
+ Build()
- // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
- r.OPTIONS("/api/oidc/introspect", policyCORSIntrospection.HandleOPTIONS)
- r.POST("/api/oidc/introspect", middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointIntrospection), policyCORSIntrospection.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST)))))
+ r.OPTIONS(oidc.EndpointPathPushedAuthorizationRequest, policyCORSPAR.HandleOnlyOPTIONS)
+ r.POST(oidc.EndpointPathPushedAuthorizationRequest, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointPushedAuthorizationRequest), policyCORSPAR.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectPushedAuthorizationRequest)))))
- policyCORSRevocation := middlewares.NewCORSPolicyBuilder().
- WithAllowCredentials(true).
- WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost).
- WithAllowedOrigins(allowedOrigins...).
- WithEnabled(utils.IsStringInSlice(oidc.EndpointRevocation, config.IdentityProviders.OIDC.CORS.Endpoints)).
- Build()
+ policyCORSToken := middlewares.NewCORSPolicyBuilder().
+ WithAllowCredentials(true).
+ WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost).
+ WithAllowedOrigins(allowedOrigins...).
+ WithEnabled(utils.IsStringInSlice(oidc.EndpointToken, config.IdentityProviders.OIDC.CORS.Endpoints)).
+ Build()
- r.OPTIONS(oidc.EndpointPathRevocation, policyCORSRevocation.HandleOPTIONS)
- r.POST(oidc.EndpointPathRevocation, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointRevocation), policyCORSRevocation.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST)))))
+ r.OPTIONS(oidc.EndpointPathToken, policyCORSToken.HandleOPTIONS)
+ r.POST(oidc.EndpointPathToken, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointToken), policyCORSToken.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectTokenPOST)))))
- // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
- r.OPTIONS("/api/oidc/revoke", policyCORSRevocation.HandleOPTIONS)
- r.POST("/api/oidc/revoke", middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointRevocation), policyCORSRevocation.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST)))))
- }
+ policyCORSUserinfo := middlewares.NewCORSPolicyBuilder().
+ WithAllowCredentials(true).
+ WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodGet, fasthttp.MethodPost).
+ WithAllowedOrigins(allowedOrigins...).
+ WithEnabled(utils.IsStringInSlice(oidc.EndpointUserinfo, config.IdentityProviders.OIDC.CORS.Endpoints)).
+ Build()
- r.RedirectFixedPath = false
- r.HandleMethodNotAllowed = true
- r.MethodNotAllowed = handleMethodNotAllowed
- r.NotFound = handleNotFound(bridge(serveIndexHandler))
+ r.OPTIONS(oidc.EndpointPathUserinfo, policyCORSUserinfo.HandleOPTIONS)
+ r.GET(oidc.EndpointPathUserinfo, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointUserinfo), policyCORSUserinfo.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo)))))
+ r.POST(oidc.EndpointPathUserinfo, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointUserinfo), policyCORSUserinfo.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo)))))
- handler := middlewares.LogRequest(r.Handler)
- if config.Server.Address.RouterPath() != "/" {
- handler = middlewares.StripPath(config.Server.Address.RouterPath())(handler)
- }
+ policyCORSIntrospection := middlewares.NewCORSPolicyBuilder().
+ WithAllowCredentials(true).
+ WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost).
+ WithAllowedOrigins(allowedOrigins...).
+ WithEnabled(utils.IsStringInSlice(oidc.EndpointIntrospection, config.IdentityProviders.OIDC.CORS.Endpoints)).
+ Build()
- handler = middlewares.MultiWrap(handler, middlewares.RecoverPanic, middlewares.NewMetricsRequest(providers.Metrics))
+ r.OPTIONS(oidc.EndpointPathIntrospection, policyCORSIntrospection.HandleOPTIONS)
+ r.POST(oidc.EndpointPathIntrospection, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointIntrospection), policyCORSIntrospection.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST)))))
- return handler
+ // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
+ r.OPTIONS("/api/oidc/introspect", policyCORSIntrospection.HandleOPTIONS)
+ r.POST("/api/oidc/introspect", middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointIntrospection), policyCORSIntrospection.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST)))))
+
+ policyCORSRevocation := middlewares.NewCORSPolicyBuilder().
+ WithAllowCredentials(true).
+ WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost).
+ WithAllowedOrigins(allowedOrigins...).
+ WithEnabled(utils.IsStringInSlice(oidc.EndpointRevocation, config.IdentityProviders.OIDC.CORS.Endpoints)).
+ Build()
+
+ r.OPTIONS(oidc.EndpointPathRevocation, policyCORSRevocation.HandleOPTIONS)
+ r.POST(oidc.EndpointPathRevocation, middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointRevocation), policyCORSRevocation.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST)))))
+
+ // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
+ r.OPTIONS("/api/oidc/revoke", policyCORSRevocation.HandleOPTIONS)
+ r.POST("/api/oidc/revoke", middlewares.Wrap(middlewares.NewMetricsRequestOpenIDConnect(providers.Metrics, oidc.EndpointRevocation), policyCORSRevocation.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST)))))
}
func handleMetrics(path string) fasthttp.RequestHandler {
diff --git a/internal/server/server.go b/internal/server/server.go
index 564460155..c9836f863 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -15,10 +15,10 @@ import (
"github.com/authelia/authelia/v4/internal/middlewares"
)
-// CreateDefaultServer Create Authelia's internal web server with the given configuration and providers.
-func CreateDefaultServer(config *schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, paths []string, isTLS bool, err error) {
+// New Create Authelia's internal web server with the given configuration and providers.
+func New(config *schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, paths []string, isTLS bool, err error) {
if err = providers.Templates.LoadTemplatedAssets(assets); err != nil {
- return nil, nil, nil, false, fmt.Errorf("failed to load templated assets: %w", err)
+ return nil, nil, nil, false, fmt.Errorf("error occurred initializing main server: error occurred loading templated assets: %w", err)
}
server = &fasthttp.Server{
@@ -38,14 +38,14 @@ func CreateDefaultServer(config *schema.Configuration, providers middlewares.Pro
)
if listener, err = config.Server.Address.Listener(); err != nil {
- return nil, nil, nil, false, fmt.Errorf("error occurred while attempting to initialize main server listener for address '%s': %w", config.Server.Address.String(), err)
+ return nil, nil, nil, false, fmt.Errorf("error occurred initializing main server listener for address '%s': %w", config.Server.Address.String(), err)
}
if config.Server.TLS.Certificate != "" && config.Server.TLS.Key != "" {
isTLS, connectionScheme = true, schemeHTTPS
if err = server.AppendCert(config.Server.TLS.Certificate, config.Server.TLS.Key); err != nil {
- return nil, nil, nil, false, fmt.Errorf("unable to load tls server certificate '%s' or private key '%s': %w", config.Server.TLS.Certificate, config.Server.TLS.Key, err)
+ return nil, nil, nil, false, fmt.Errorf("error occurred initializing main server tls parameters: failed to load certificate '%s' or private key '%s': %w", config.Server.TLS.Certificate, config.Server.TLS.Key, err)
}
if len(config.Server.TLS.ClientCertificates) > 0 {
@@ -55,7 +55,7 @@ func CreateDefaultServer(config *schema.Configuration, providers middlewares.Pro
for _, path := range config.Server.TLS.ClientCertificates {
if cert, err = os.ReadFile(path); err != nil {
- return nil, nil, nil, false, fmt.Errorf("unable to load tls client certificate '%s': %w", path, err)
+ return nil, nil, nil, false, fmt.Errorf("error occurred initializing main server tls parameters: failed to load client certificate '%s': %w", path, err)
}
caCertPool.AppendCertsFromPEM(cert)
@@ -72,7 +72,7 @@ func CreateDefaultServer(config *schema.Configuration, providers middlewares.Pro
if err = writeHealthCheckEnv(config.Server.DisableHealthcheck, connectionScheme, config.Server.Address.Hostname(),
config.Server.Address.RouterPath(), config.Server.Address.Port()); err != nil {
- return nil, nil, nil, false, fmt.Errorf("unable to configure healthcheck: %w", err)
+ return nil, nil, nil, false, fmt.Errorf("error occurred initializing main server healthcheck metadata: %w", err)
}
paths = []string{"/"}
@@ -87,8 +87,8 @@ func CreateDefaultServer(config *schema.Configuration, providers middlewares.Pro
return server, listener, paths, isTLS, nil
}
-// CreateMetricsServer creates a metrics server.
-func CreateMetricsServer(config *schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, paths []string, tls bool, err error) {
+// NewMetrics creates a metrics server.
+func NewMetrics(config *schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, paths []string, tls bool, err error) {
if providers.Metrics == nil {
return
}
@@ -106,7 +106,7 @@ func CreateMetricsServer(config *schema.Configuration, providers middlewares.Pro
}
if listener, err = config.Telemetry.Metrics.Address.Listener(); err != nil {
- return nil, nil, nil, false, fmt.Errorf("error occurred while attempting to initialize metrics telemetry server listener for address '%s': %w", config.Telemetry.Metrics.Address.String(), err)
+ return nil, nil, nil, false, fmt.Errorf("error occurred initializing metrics telemetry server listener for address '%s': %w", config.Telemetry.Metrics.Address.String(), err)
}
return server, listener, []string{config.Telemetry.Metrics.Address.RouterPath()}, false, nil
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index 356285fc6..914906238 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -148,7 +148,7 @@ func NewTLSServerContext(configuration schema.Configuration) (serverContext *TLS
return nil, err
}
- s, listener, _, _, err := CreateDefaultServer(&configuration, providers)
+ s, listener, _, _, err := New(&configuration, providers)
if err != nil {
return nil, err
diff --git a/internal/service/const.go b/internal/service/const.go
new file mode 100644
index 000000000..ef60ade9b
--- /dev/null
+++ b/internal/service/const.go
@@ -0,0 +1,15 @@
+package service
+
+const (
+ fmtLogServerListening = "Listening for %s connections on '%s' path '%s'"
+)
+
+const (
+ logFieldService = "service"
+ logFieldFile = "file"
+ logFieldOP = "op"
+
+ serviceTypeServer = "server"
+ serviceTypeWatcher = "watcher"
+ serviceTypeSignal = "signal"
+)
diff --git a/internal/service/file_watcher.go b/internal/service/file_watcher.go
new file mode 100644
index 000000000..55202b806
--- /dev/null
+++ b/internal/service/file_watcher.go
@@ -0,0 +1,174 @@
+package service
+
+import (
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+
+ "github.com/fsnotify/fsnotify"
+ "github.com/sirupsen/logrus"
+
+ "github.com/authelia/authelia/v4/internal/authentication"
+)
+
+func ProvisionUsersFileWatcher(ctx Context) (service Provider, err error) {
+ config := ctx.GetConfiguration()
+ providers := ctx.GetProviders()
+
+ if config.AuthenticationBackend.File != nil && config.AuthenticationBackend.File.Watch {
+ provider, ok := providers.UserProvider.(*authentication.FileUserProvider)
+
+ if !ok {
+ return nil, errors.New("error occurred asserting user provider")
+ }
+
+ if service, err = NewFileWatcher("users", config.AuthenticationBackend.File.Path, provider, ctx.GetLogger()); err != nil {
+ return nil, err
+ }
+ }
+
+ return service, nil
+}
+
+// NewFileWatcher creates a new FileWatcher with the appropriate logger etc.
+func NewFileWatcher(name, path string, reload ReloadableProvider, log *logrus.Entry) (service *FileWatcher, err error) {
+ if path == "" {
+ return nil, fmt.Errorf("error initializing file watcher: path must be specified")
+ }
+
+ if path, err = filepath.Abs(path); err != nil {
+ return nil, fmt.Errorf("error initializing file watcher: could not determine the absolute path of file '%s': %w", path, err)
+ }
+
+ var info os.FileInfo
+
+ if info, err = os.Stat(path); err != nil {
+ switch {
+ case os.IsNotExist(err):
+ return nil, fmt.Errorf("error initializing file watcher: error stating file '%s': file does not exist", path)
+ case os.IsPermission(err):
+ return nil, fmt.Errorf("error initializing file watcher: error stating file '%s': permission denied trying to read the file", path)
+ default:
+ return nil, fmt.Errorf("error initializing file watcher: error stating file '%s': %w", path, err)
+ }
+ }
+
+ var watcher *fsnotify.Watcher
+
+ if watcher, err = fsnotify.NewWatcher(); err != nil {
+ return nil, err
+ }
+
+ entry := log.WithFields(map[string]any{logFieldService: serviceTypeWatcher, serviceTypeWatcher: name})
+
+ if info.IsDir() {
+ service = &FileWatcher{
+ name: name,
+ watcher: watcher,
+ reload: reload,
+ log: entry,
+ directory: filepath.Clean(path),
+ }
+ } else {
+ service = &FileWatcher{
+ name: name,
+ watcher: watcher,
+ reload: reload,
+ log: entry,
+ directory: filepath.Dir(path),
+ file: filepath.Base(path),
+ }
+ }
+
+ if err = service.watcher.Add(service.directory); err != nil {
+ return nil, fmt.Errorf("failed to add path '%s' to watch list: %w", path, err)
+ }
+
+ return service, nil
+}
+
+// FileWatcher is a Provider that watches files for changes.
+type FileWatcher struct {
+ name string
+
+ watcher *fsnotify.Watcher
+ reload ReloadableProvider
+
+ log *logrus.Entry
+ file string
+ directory string
+}
+
+// ServiceType returns the service type for this service, which is always 'watcher'.
+func (service *FileWatcher) ServiceType() string {
+ return serviceTypeWatcher
+}
+
+// ServiceName returns the individual name for this service.
+func (service *FileWatcher) ServiceName() string {
+ return service.name
+}
+
+// Run the FileWatcher.
+func (service *FileWatcher) Run() (err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)")
+ }
+ }()
+
+ service.log.WithField(logFieldFile, filepath.Join(service.directory, service.file)).Info("Watching file for changes")
+
+ for {
+ select {
+ case event, ok := <-service.watcher.Events:
+ if !ok {
+ return nil
+ }
+
+ log := service.log.WithFields(map[string]any{logFieldFile: event.Name, logFieldOP: event.Op})
+
+ if service.file != "" && service.file != filepath.Base(event.Name) {
+ log.Trace("File modification detected to irrelevant file")
+ break
+ }
+
+ switch {
+ case event.Op&fsnotify.Write == fsnotify.Write, event.Op&fsnotify.Create == fsnotify.Create:
+ log.Debug("File modification was detected")
+
+ var reloaded bool
+
+ switch reloaded, err = service.reload.Reload(); {
+ case err != nil:
+ log.WithError(err).Error("Error occurred during reload")
+ case reloaded:
+ log.Info("Reloaded successfully")
+ default:
+ log.Debug("Reload was triggered but it was skipped")
+ }
+ case event.Op&fsnotify.Remove == fsnotify.Remove:
+ log.Debug("File remove was detected")
+ }
+ case err, ok := <-service.watcher.Errors:
+ if !ok {
+ return nil
+ }
+
+ service.log.WithError(err).Error("Error while watching file for changes")
+ }
+ }
+}
+
+// Shutdown the FileWatcher.
+func (service *FileWatcher) Shutdown() {
+ if err := service.watcher.Close(); err != nil {
+ service.log.WithError(err).Error("Error occurred during shutdown")
+ }
+}
+
+// Log returns the *logrus.Entry of the FileWatcher.
+func (service *FileWatcher) Log() *logrus.Entry {
+ return service.log
+}
diff --git a/internal/service/file_watcher_test.go b/internal/service/file_watcher_test.go
new file mode 100644
index 000000000..f4dce095f
--- /dev/null
+++ b/internal/service/file_watcher_test.go
@@ -0,0 +1,211 @@
+package service
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "regexp"
+ "testing"
+ "time"
+
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/authelia/authelia/v4/internal/authentication"
+ "github.com/authelia/authelia/v4/internal/configuration/schema"
+ "github.com/authelia/authelia/v4/internal/logging"
+ "github.com/authelia/authelia/v4/internal/middlewares"
+ "github.com/authelia/authelia/v4/internal/templates"
+)
+
+func TestProvisionUsersFileWatcher(t *testing.T) {
+ dir := t.TempDir()
+
+ f, err := os.Create(filepath.Join(dir, "users.yml"))
+ require.NoError(t, err)
+ require.NoError(t, f.Close())
+
+ tx, err := templates.New(templates.Config{})
+ require.NoError(t, err)
+
+ address, err := schema.NewAddress("tcp://:9091")
+ require.NoError(t, err)
+
+ config := &schema.Configuration{
+ Server: schema.Server{
+ Address: &schema.AddressTCP{Address: *address},
+ },
+ }
+
+ provision := ProvisionUsersFileWatcher
+
+ ctx := &testCtx{
+ Context: context.Background(),
+ Configuration: config,
+ Providers: middlewares.Providers{
+ Templates: tx,
+ },
+ Logger: logrus.NewEntry(logging.Logger()),
+ }
+
+ watcher, err := provision(ctx)
+ assert.NoError(t, err)
+ assert.Nil(t, watcher)
+
+ watcher, err = provision(ctx)
+ assert.NoError(t, err)
+ assert.Nil(t, watcher)
+
+ config.AuthenticationBackend.File = &schema.AuthenticationBackendFile{
+ Path: filepath.Join(dir, "users.yml"),
+ Watch: true,
+ }
+
+ watcher, err = provision(ctx)
+ assert.EqualError(t, err, "error occurred asserting user provider")
+ assert.Nil(t, watcher)
+
+ ctx.Providers.UserProvider = authentication.NewFileUserProvider(config.AuthenticationBackend.File)
+
+ config.AuthenticationBackend.File = &schema.AuthenticationBackendFile{
+ Watch: true,
+ }
+
+ watcher, err = provision(ctx)
+ assert.EqualError(t, err, "error initializing file watcher: path must be specified")
+ assert.Nil(t, watcher)
+
+ config.AuthenticationBackend.File = &schema.AuthenticationBackendFile{
+ Path: filepath.Join(dir, "users.yml"),
+ Watch: true,
+ }
+
+ watcher, err = provision(ctx)
+ assert.NoError(t, err)
+ assert.NotNil(t, watcher)
+ assert.NotNil(t, watcher.Log())
+ assert.Equal(t, "users", watcher.ServiceName())
+ assert.Equal(t, "watcher", watcher.ServiceType())
+
+ watcher.Shutdown()
+}
+
+func TestNewFileWatcher(t *testing.T) {
+ dir := t.TempDir()
+
+ reloader := &testReloader{reload: true}
+
+ f, err := os.Create(filepath.Join(dir, "test.log"))
+ require.NoError(t, err)
+
+ service, err := NewFileWatcher("example", filepath.Join(dir, "test.log"), reloader, logrus.NewEntry(logging.Logger()))
+
+ assert.NoError(t, err)
+
+ go func() {
+ require.NoError(t, service.Run())
+ }()
+
+ // Give the service a moment to start.
+ time.Sleep(100 * time.Millisecond)
+
+ _, err = f.Write([]byte("test"))
+ require.NoError(t, err)
+
+ require.NoError(t, f.Close())
+
+ time.Sleep(time.Second)
+
+ assert.Equal(t, 1, reloader.count)
+
+ assert.NoError(t, os.WriteFile(filepath.Join(dir, "test2.log"), []byte("test"), 0600))
+
+ assert.Equal(t, 1, reloader.count)
+
+ assert.NoError(t, os.Remove(filepath.Join(dir, "test2.log")))
+
+ assert.Equal(t, 1, reloader.count)
+
+ service.Shutdown()
+}
+
+func TestNewFileWatcherDirectory(t *testing.T) {
+ dir := t.TempDir()
+
+ reloader := &testReloader{reload: true}
+
+ service, err := NewFileWatcher("example", dir, reloader, logrus.NewEntry(logging.Logger()))
+
+ assert.NoError(t, err)
+
+ go func() {
+ require.NoError(t, service.Run())
+ }()
+
+ // Give the service a moment to start.
+ time.Sleep(100 * time.Millisecond)
+
+ f, err := os.Create(filepath.Join(dir, "test.log"))
+ require.NoError(t, err)
+
+ _, err = f.Write([]byte("test"))
+ require.NoError(t, err)
+
+ require.NoError(t, f.Close())
+
+ time.Sleep(time.Second)
+
+ assert.Equal(t, 2, reloader.count)
+
+ service.Shutdown()
+}
+
+func TestNewFileWatcherBadPath(t *testing.T) {
+ dir := t.TempDir()
+
+ reloader := &testReloader{reload: true}
+
+ service, err := NewFileWatcher("example", filepath.Join(dir, "test.log"), reloader, logrus.NewEntry(logging.Logger()))
+
+ require.Error(t, err)
+ assert.Regexp(t, regexp.MustCompile(`^error initializing file watcher: error stating file '/tmp/[^/]+/\d+/test.log': file does not exist$`), err.Error())
+
+ assert.Nil(t, service)
+}
+
+func TestNewFileWatcherBadPermission(t *testing.T) {
+ dir := t.TempDir()
+
+ reloader := &testReloader{reload: true}
+
+ require.NoError(t, os.Mkdir(filepath.Join(dir, "tmp"), 0700))
+
+ f, err := os.Create(filepath.Join(dir, "tmp", "test.log"))
+
+ require.NoError(t, err)
+ require.NoError(t, f.Close())
+
+ require.NoError(t, os.Chmod(filepath.Join(dir, "tmp"), 0o000))
+
+ service, err := NewFileWatcher("example", filepath.Join(dir, "tmp", "test.log"), reloader, logrus.NewEntry(logging.Logger()))
+
+ require.Error(t, err)
+ assert.Regexp(t, regexp.MustCompile(`^error initializing file watcher: error stating file '/tmp/[^/]+/\d+/tmp/test.log': permission denied trying to read the file$`), err.Error())
+
+ require.NoError(t, os.Chmod(filepath.Join(dir, "tmp"), 0o700))
+
+ assert.Nil(t, service)
+}
+
+type testReloader struct {
+ count int
+ reload bool
+ err error
+}
+
+func (r *testReloader) Reload() (bool, error) {
+ r.count++
+
+ return r.reload, r.err
+}
diff --git a/internal/service/provider.go b/internal/service/provider.go
new file mode 100644
index 000000000..182ddba67
--- /dev/null
+++ b/internal/service/provider.go
@@ -0,0 +1,52 @@
+package service
+
+import (
+ "context"
+
+ "github.com/sirupsen/logrus"
+
+ "github.com/authelia/authelia/v4/internal/configuration/schema"
+ "github.com/authelia/authelia/v4/internal/middlewares"
+)
+
+// Provider represents the required methods to support handling a service.
+type Provider interface {
+ // ServiceType returns the type name for the Provider.
+ ServiceType() string
+
+ // ServiceName returns the individual name for the Provider.
+ ServiceName() string
+
+ // Run performs the running operations for the Provider.
+ Run() (err error)
+
+ // Shutdown perform the shutdown cleanup and termination operations for the Provider.
+ Shutdown()
+
+ // Log returns the logger configured for the service.
+ Log() *logrus.Entry
+}
+
+// ReloadableProvider represents the required methods to support reloading a provider.
+type ReloadableProvider interface {
+ Reload() (reloaded bool, err error)
+}
+
+type Provisioner func(ctx Context) (provider Provider, err error)
+
+func GetProvisioners() []Provisioner {
+ return []Provisioner{
+ ProvisionServer,
+ ProvisionServerMetrics,
+ ProvisionUsersFileWatcher,
+ ProvisionLoggingSignal,
+ }
+}
+
+type Context interface {
+ GetLogger() *logrus.Entry
+ GetProviders() middlewares.Providers
+ GetConfiguration() *schema.Configuration
+
+ context.Context
+}
diff --git a/internal/service/provider_test.go b/internal/service/provider_test.go
new file mode 100644
index 000000000..58a653557
--- /dev/null
+++ b/internal/service/provider_test.go
@@ -0,0 +1,13 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestGetProvisioners(t *testing.T) {
+ provisioners := GetProvisioners()
+
+ assert.Len(t, provisioners, 4)
+}
diff --git a/internal/service/server.go b/internal/service/server.go
new file mode 100644
index 000000000..15921af01
--- /dev/null
+++ b/internal/service/server.go
@@ -0,0 +1,120 @@
+package service
+
+import (
+ "context"
+ "net"
+ "strings"
+ "time"
+
+ "github.com/sirupsen/logrus"
+ "github.com/valyala/fasthttp"
+
+ "github.com/authelia/authelia/v4/internal/server"
+)
+
+func ProvisionServer(ctx Context) (service Provider, err error) {
+ var (
+ s *fasthttp.Server
+ listener net.Listener
+ paths []string
+ isTLS bool
+ )
+
+ switch s, listener, paths, isTLS, err = server.New(ctx.GetConfiguration(), ctx.GetProviders()); {
+ case err != nil:
+ return nil, err
+ case s != nil && listener != nil:
+ service = NewBaseServer("main", s, listener, paths, isTLS, ctx.GetLogger())
+ default:
+ return nil, nil
+ }
+
+ return service, nil
+}
+
+func ProvisionServerMetrics(ctx Context) (service Provider, err error) {
+ var (
+ s *fasthttp.Server
+ listener net.Listener
+ paths []string
+ isTLS bool
+ )
+
+ switch s, listener, paths, isTLS, err = server.NewMetrics(ctx.GetConfiguration(), ctx.GetProviders()); {
+ case err != nil:
+ return nil, err
+ case s != nil && listener != nil:
+ service = NewBaseServer("metrics", s, listener, paths, isTLS, ctx.GetLogger())
+ default:
+ return nil, nil
+ }
+
+ return service, nil
+}
+
+// NewBaseServer creates a new Server with the appropriate logger etc.
+func NewBaseServer(name string, server *fasthttp.Server, listener net.Listener, paths []string, isTLS bool, log *logrus.Entry) (service *Server) {
+ return &Server{
+ name: name,
+ server: server,
+ listener: listener,
+ paths: paths,
+ isTLS: isTLS,
+ log: log.WithFields(map[string]any{logFieldService: serviceTypeServer, serviceTypeServer: name}),
+ }
+}
+
+// Server is a Provider which runs a web server.
+type Server struct {
+ name string
+ server *fasthttp.Server
+ paths []string
+ isTLS bool
+ listener net.Listener
+ log *logrus.Entry
+}
+
+// ServiceType returns the service type for this service, which is always 'server'.
+func (service *Server) ServiceType() string {
+ return serviceTypeServer
+}
+
+// ServiceName returns the individual name for this service.
+func (service *Server) ServiceName() string {
+ return service.name
+}
+
+// Run the Server.
+func (service *Server) Run() (err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)")
+ }
+ }()
+
+ service.log.Infof(fmtLogServerListening, connectionType(service.isTLS), service.listener.Addr().String(), strings.Join(service.paths, "' and '"))
+
+ if err = service.server.Serve(service.listener); err != nil {
+ service.log.WithError(err).Error("Error returned attempting to serve requests")
+
+ return err
+ }
+
+ return nil
+}
+
+// Shutdown the Server.
+func (service *Server) Shutdown() {
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+
+ defer cancel()
+
+ if err := service.server.ShutdownWithContext(ctx); err != nil {
+ service.log.WithError(err).Error("Error occurred during shutdown")
+ }
+}
+
+// Log returns the *logrus.Entry of the Server.
+func (service *Server) Log() *logrus.Entry {
+ return service.log
+}
diff --git a/internal/service/sever_test.go b/internal/service/sever_test.go
new file mode 100644
index 000000000..6effa88e9
--- /dev/null
+++ b/internal/service/sever_test.go
@@ -0,0 +1,121 @@
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/authelia/authelia/v4/internal/configuration/schema"
+ "github.com/authelia/authelia/v4/internal/logging"
+ "github.com/authelia/authelia/v4/internal/metrics"
+ "github.com/authelia/authelia/v4/internal/middlewares"
+ "github.com/authelia/authelia/v4/internal/templates"
+)
+
+func TestNewMainServer(t *testing.T) {
+ tx, err := templates.New(templates.Config{})
+ require.NoError(t, err)
+
+ address, err := schema.NewAddress("tcp://:9091")
+ require.NoError(t, err)
+
+ config := &schema.Configuration{
+ Server: schema.Server{
+ Address: &schema.AddressTCP{Address: *address},
+ },
+ }
+
+ ctx := &testCtx{
+ Context: context.Background(),
+ Configuration: config,
+ Providers: middlewares.Providers{
+ Templates: tx,
+ },
+ Logger: logrus.NewEntry(logging.Logger()),
+ }
+
+ server, err := ProvisionServer(ctx)
+ assert.NoError(t, err)
+ assert.NotNil(t, server)
+
+ go func() {
+ require.NoError(t, server.Run())
+ }()
+
+ // Give the service a moment to start.
+ time.Sleep(100 * time.Millisecond)
+
+ assert.Equal(t, "main", server.ServiceName())
+ assert.Equal(t, "server", server.ServiceType())
+ assert.NotNil(t, server.Log())
+
+ server.Shutdown()
+}
+
+func TestNewMetricsServer(t *testing.T) {
+ tx, err := templates.New(templates.Config{})
+ require.NoError(t, err)
+
+ address, err := schema.NewAddress("tcp://:9891/metrics")
+ require.NoError(t, err)
+
+ config := &schema.Configuration{
+ Telemetry: schema.Telemetry{
+ Metrics: schema.TelemetryMetrics{
+ Enabled: true,
+ Address: &schema.AddressTCP{Address: *address},
+ },
+ },
+ }
+
+ ctx := &testCtx{
+ Context: context.Background(),
+ Configuration: config,
+ Providers: middlewares.Providers{
+ Templates: tx,
+ Metrics: metrics.NewPrometheus(),
+ },
+ Logger: logrus.NewEntry(logging.Logger()),
+ }
+
+ server, err := ProvisionServerMetrics(ctx)
+ assert.NoError(t, err)
+ assert.NotNil(t, server)
+
+ go func() {
+ require.NoError(t, server.Run())
+ }()
+
+ // Give the service a moment to start.
+ time.Sleep(100 * time.Millisecond)
+
+ assert.Equal(t, "metrics", server.ServiceName())
+ assert.Equal(t, "server", server.ServiceType())
+ assert.NotNil(t, server.Log())
+
+ server.Shutdown()
+}
+
+type testCtx struct {
+ Configuration *schema.Configuration
+ Providers middlewares.Providers
+ Logger *logrus.Entry
+
+ context.Context
+}
+
+func (c *testCtx) GetConfiguration() *schema.Configuration {
+ return c.Configuration
+}
+
+func (c *testCtx) GetProviders() middlewares.Providers {
+ return c.Providers
+}
+
+func (c *testCtx) GetLogger() *logrus.Entry {
+ return c.Logger
+}
diff --git a/internal/service/signal.go b/internal/service/signal.go
new file mode 100644
index 000000000..b6445a792
--- /dev/null
+++ b/internal/service/signal.go
@@ -0,0 +1,81 @@
+package service
+
+import (
+ "os"
+ "os/signal"
+ "syscall"
+
+ "github.com/sirupsen/logrus"
+
+ "github.com/authelia/authelia/v4/internal/logging"
+)
+
+func ProvisionLoggingSignal(ctx Context) (service Provider, err error) {
+ config := ctx.GetConfiguration()
+
+ if config == nil || len(config.Log.FilePath) == 0 {
+ return nil, nil
+ }
+
+ return &Signal{
+ name: "log-reload",
+ signals: []os.Signal{syscall.SIGHUP},
+ action: logging.Reopen,
+ log: ctx.GetLogger().WithFields(map[string]any{logFieldService: serviceTypeSignal, serviceTypeSignal: "log-reload"}),
+ }, nil
+}
+
+// Signal is a Service which performs actions on signals.
+type Signal struct {
+ name string
+ signals []os.Signal
+ action func() (err error)
+ log *logrus.Entry
+
+ notify chan os.Signal
+ quit chan struct{}
+}
+
+// ServiceType returns the service type for this service, which is always 'server'.
+func (service *Signal) ServiceType() string {
+ return serviceTypeSignal
+}
+
+// ServiceName returns the individual name for this service.
+func (service *Signal) ServiceName() string {
+ return service.name
+}
+
+// Run the ServerService.
+func (service *Signal) Run() (err error) {
+ service.quit = make(chan struct{})
+
+ service.notify = make(chan os.Signal, 1)
+
+ signal.Notify(service.notify, service.signals...)
+
+ for {
+ select {
+ case s := <-service.notify:
+ if err = service.action(); err != nil {
+ service.log.WithError(err).Error("Error occurred executing service action.")
+ } else {
+ service.log.WithFields(map[string]any{"signal-received": s.String()}).Debug("Successfully executed service action.")
+ }
+ case <-service.quit:
+ return
+ }
+ }
+}
+
+// Shutdown the ServerService.
+func (service *Signal) Shutdown() {
+ signal.Stop(service.notify)
+
+ service.quit <- struct{}{}
+}
+
+// Log returns the *logrus.Entry of the ServerService.
+func (service *Signal) Log() *logrus.Entry {
+ return service.log
+}
diff --git a/internal/commands/services_test.go b/internal/service/signal_test.go
index 3a9364e16..a43b7db03 100644
--- a/internal/commands/services_test.go
+++ b/internal/service/signal_test.go
@@ -1,4 +1,4 @@
-package commands
+package service
import (
"context"
@@ -13,9 +13,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/authelia/authelia/v4/internal/logging"
-
"github.com/authelia/authelia/v4/internal/configuration/schema"
+ "github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/middlewares"
)
@@ -23,11 +22,11 @@ import (
type mockServiceCtx struct {
ctx context.Context
config *schema.Configuration
- logger *logrus.Logger
+ logger *logrus.Entry
providers middlewares.Providers
}
-func (m *mockServiceCtx) GetLogger() *logrus.Logger {
+func (m *mockServiceCtx) GetLogger() *logrus.Entry {
return m.logger
}
@@ -64,7 +63,7 @@ func newMockServiceCtx() *mockServiceCtx {
return &mockServiceCtx{
ctx: context.Background(),
config: config,
- logger: logger,
+ logger: logrus.NewEntry(logger),
providers: middlewares.Providers{},
}
}
@@ -98,7 +97,12 @@ func TestSignalService_Run(t *testing.T) {
return tc.actionError
}
- service := NewSignalService("test", action, logger, tc.signal)
+ service := &Signal{
+ name: "log-reload",
+ signals: []os.Signal{syscall.SIGHUP},
+ action: action,
+ log: logger.WithFields(map[string]any{logFieldService: serviceTypeSignal, serviceTypeSignal: "log-reload"}),
+ }
errChan := make(chan error, 1)
done := make(chan struct{})
@@ -110,6 +114,7 @@ func TestSignalService_Run(t *testing.T) {
close(done)
}()
+ // Give the service a moment to start.
time.Sleep(100 * time.Millisecond)
p, err := os.FindProcess(os.Getpid())
@@ -163,7 +168,7 @@ func TestSvcSignalLogReOpenFunc(t *testing.T) {
mockCtx := newMockServiceCtx()
mockCtx.config.Log.FilePath = tc.logFilePath
- service := svcSignalLogReOpenFunc(mockCtx)
+ service, _ := ProvisionLoggingSignal(mockCtx)
if tc.expectService {
require.NotNil(t, service)
@@ -197,11 +202,11 @@ func TestLogReopenFiles(t *testing.T) {
ctx := &mockServiceCtx{
ctx: context.Background(),
config: config,
- logger: logging.Logger(),
+ logger: logrus.NewEntry(logging.Logger()),
providers: middlewares.Providers{},
}
- service := svcSignalLogReOpenFunc(ctx)
+ service, _ := ProvisionLoggingSignal(ctx)
require.NotNil(t, service)
errChan := make(chan error, 1)
@@ -214,6 +219,7 @@ func TestLogReopenFiles(t *testing.T) {
close(done)
}()
+ // Give the service a moment to start.
time.Sleep(100 * time.Millisecond)
p, err := os.FindProcess(os.Getpid())
@@ -237,7 +243,12 @@ func TestLogReopenFiles(t *testing.T) {
func TestSignalService_Shutdown(t *testing.T) {
logger := logrus.New()
action := func() error { return nil }
- service := NewSignalService("test", action, logger, syscall.SIGHUP)
+ service := &Signal{
+ name: "test",
+ signals: []os.Signal{syscall.SIGHUP},
+ action: action,
+ log: logger.WithFields(map[string]any{logFieldService: serviceTypeSignal, serviceTypeSignal: "test"}),
+ }
done := make(chan struct{})
go func() {
diff --git a/internal/service/util.go b/internal/service/util.go
new file mode 100644
index 000000000..3e4a7f35c
--- /dev/null
+++ b/internal/service/util.go
@@ -0,0 +1,120 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "os/signal"
+ "sync"
+ "syscall"
+
+ "golang.org/x/sync/errgroup"
+)
+
+func RunAll(ctx Context) (err error) {
+ provisioners := GetProvisioners()
+
+ return Run(ctx, provisioners...)
+}
+
+func Run(ctx Context, provisioners ...Provisioner) (err error) {
+ cctx, cancel := context.WithCancel(ctx)
+
+ group, cctx := errgroup.WithContext(cctx)
+
+ defer cancel()
+
+ quit := make(chan os.Signal, 1)
+
+ signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
+
+ defer signal.Stop(quit)
+
+ var (
+ services []Provider
+ )
+
+ log := ctx.GetLogger()
+
+ for _, provisioner := range provisioners {
+ if service, err := provisioner(ctx); err != nil {
+ return fmt.Errorf("error occurred provisioning services: %w", err)
+ } else if service != nil {
+ services = append(services, service)
+ }
+ }
+
+ for _, service := range services {
+ group.Go(service.Run)
+ }
+
+ log.Info("Startup complete")
+
+ select {
+ case s := <-quit:
+ log.WithField("signal", s.String()).Debug("Shutdown initiated due to process signal")
+ case <-cctx.Done():
+ log.Debug("Shutdown initiated due to context completion")
+ }
+
+ cancel()
+
+ log.Info("Shutdown initiated")
+
+ wgShutdown := &sync.WaitGroup{}
+
+ log.Tracef("Shutdown of %d services is required", len(services))
+
+ for _, service := range services {
+ wgShutdown.Add(1)
+
+ go func(service Provider) {
+ service.Log().Trace("Shutdown of service initiated")
+
+ service.Shutdown()
+
+ wgShutdown.Done()
+
+ service.Log().Trace("Shutdown of service complete")
+ }(service)
+ }
+
+ wgShutdown.Wait()
+
+ if err = ctx.GetProviders().UserProvider.Close(); err != nil {
+ ctx.GetLogger().WithError(err).Error("Error occurred closing authentication connections")
+ }
+
+ if err = ctx.GetProviders().StorageProvider.Close(); err != nil {
+ log.WithError(err).Error("Error occurred closing database connections")
+ }
+
+ if err = group.Wait(); err != nil {
+ log.WithError(err).Error("Error occurred waiting for shutdown")
+ }
+
+ log.Info("Shutdown complete")
+
+ return nil
+}
+
+func connectionType(isTLS bool) string {
+ if isTLS {
+ return "TLS"
+ }
+
+ return "non-TLS"
+}
+
+func recoverErr(i any) error {
+ switch v := i.(type) {
+ case nil:
+ return nil
+ case string:
+ return fmt.Errorf("recovered panic: %s", v)
+ case error:
+ return fmt.Errorf("recovered panic: %w", v)
+ default:
+ return fmt.Errorf("recovered panic with unknown type: %v", v)
+ }
+}
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go
index d738cd80c..3ba918b87 100644
--- a/internal/storage/sql_provider.go
+++ b/internal/storage/sql_provider.go
@@ -3,6 +3,7 @@ package storage
import (
"context"
"crypto/sha256"
+ "crypto/x509"
"database/sql"
"errors"
"fmt"
@@ -19,6 +20,20 @@ import (
"github.com/authelia/authelia/v4/internal/model"
)
+// NewProvider dynamically initializes a storage.Provider given a *schema.Configuration and *x509.CertPool.
+func NewProvider(config *schema.Configuration, caCertPool *x509.CertPool) (provider Provider) {
+ switch {
+ case config.Storage.PostgreSQL != nil:
+ return NewPostgreSQLProvider(config, caCertPool)
+ case config.Storage.MySQL != nil:
+ return NewMySQLProvider(config, caCertPool)
+ case config.Storage.Local != nil:
+ return NewSQLiteProvider(config)
+ default:
+ return nil
+ }
+}
+
// NewSQLProvider generates a generic SQLProvider to be used with other SQL provider NewUp's.
func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceName string) (provider SQLProvider) {
db, err := sqlx.Open(driverName, dataSourceName)
diff --git a/web/.commitlintrc.cjs b/web/.commitlintrc.cjs
index 82516cc34..ea897e475 100644
--- a/web/.commitlintrc.cjs
+++ b/web/.commitlintrc.cjs
@@ -47,6 +47,7 @@ module.exports = {
"renovate",
"reviewdog",
"server",
+ "service",
"session",
"storage",
"suites",