diff options
Diffstat (limited to 'internal/commands/root.go')
| -rw-r--r-- | internal/commands/root.go | 155 |
1 files changed, 155 insertions, 0 deletions
diff --git a/internal/commands/root.go b/internal/commands/root.go new file mode 100644 index 000000000..369e92a80 --- /dev/null +++ b/internal/commands/root.go @@ -0,0 +1,155 @@ +package commands + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + + "github.com/authelia/authelia/internal/authentication" + "github.com/authelia/authelia/internal/authorization" + "github.com/authelia/authelia/internal/configuration/schema" + "github.com/authelia/authelia/internal/logging" + "github.com/authelia/authelia/internal/middlewares" + "github.com/authelia/authelia/internal/notification" + "github.com/authelia/authelia/internal/oidc" + "github.com/authelia/authelia/internal/regulation" + "github.com/authelia/authelia/internal/server" + "github.com/authelia/authelia/internal/session" + "github.com/authelia/authelia/internal/storage" + "github.com/authelia/authelia/internal/utils" +) + +// NewRootCmd returns a new Root Cmd. +func NewRootCmd() (cmd *cobra.Command) { + version := utils.Version() + + cmd = &cobra.Command{ + Use: "authelia", + Example: cmdAutheliaExample, + Short: fmt.Sprintf("authelia %s", version), + Long: fmt.Sprintf(fmtAutheliaLong, version), + Version: version, + Args: cobra.NoArgs, + PreRun: newCmdWithConfigPreRun(true, true, true), + Run: cmdRootRun, + } + + cmdWithConfigFlags(cmd) + + cmd.AddCommand( + newBuildInfoCmd(), + NewCertificatesCmd(), + newCompletionCmd(), + NewHashPasswordCmd(), + NewRSACmd(), + newValidateConfigCmd(), + ) + + return cmd +} + +func cmdRootRun(_ *cobra.Command, _ []string) { + logger := logging.Logger() + + logger.Infof("Authelia %s is starting", utils.Version()) + + if os.Getenv("ENVIRONMENT") == "dev" { + logger.Info("===> Authelia is running in development mode. <===") + } + + if err := logging.InitializeLogger(config.Log, true); err != nil { + logger.Fatalf("Cannot initialize logger: %v", err) + } + + providers, warnings, errors := getProviders(config) + if len(warnings) != 0 { + for _, err := range warnings { + logger.Warn(err) + } + } + + if len(errors) != 0 { + for _, err := range errors { + logger.Error(err) + } + + logger.Fatalf("Errors occurred provisioning providers.") + } + + server.Start(*config, providers) +} + +func getProviders(config *schema.Configuration) (providers middlewares.Providers, warnings []error, errors []error) { + autheliaCertPool, warnings, errors := utils.NewX509CertPool(config.CertificatesDirectory) + if len(warnings) != 0 || len(errors) != 0 { + return providers, warnings, errors + } + + var storageProvider storage.Provider + + switch { + case config.Storage.PostgreSQL != nil: + storageProvider = storage.NewPostgreSQLProvider(*config.Storage.PostgreSQL) + case config.Storage.MySQL != nil: + storageProvider = storage.NewMySQLProvider(*config.Storage.MySQL) + case config.Storage.Local != nil: + storageProvider = storage.NewSQLiteProvider(config.Storage.Local.Path) + default: + errors = append(errors, fmt.Errorf("unrecognized storage provider")) + } + + var ( + userProvider authentication.UserProvider + err error + ) + + switch { + case config.AuthenticationBackend.File != nil: + userProvider = authentication.NewFileUserProvider(config.AuthenticationBackend.File) + case config.AuthenticationBackend.LDAP != nil: + userProvider, err = authentication.NewLDAPUserProvider(config.AuthenticationBackend, autheliaCertPool) + if err != nil { + errors = append(errors, fmt.Errorf("failed to check LDAP authentication backend: %w", err)) + } + default: + errors = append(errors, fmt.Errorf("unrecognized user provider")) + } + + var notifier notification.Notifier + + switch { + case config.Notifier.SMTP != nil: + notifier = notification.NewSMTPNotifier(*config.Notifier.SMTP, autheliaCertPool) + case config.Notifier.FileSystem != nil: + notifier = notification.NewFileNotifier(*config.Notifier.FileSystem) + default: + errors = append(errors, fmt.Errorf("unrecognized notifier provider")) + } + + if notifier != nil { + if _, err := notifier.StartupCheck(); err != nil { + errors = append(errors, fmt.Errorf("failed to check notification provider: %w", err)) + } + } + + clock := utils.RealClock{} + authorizer := authorization.NewAuthorizer(config) + sessionProvider := session.NewProvider(config.Session, autheliaCertPool) + regulator := regulation.NewRegulator(config.Regulation, storageProvider, clock) + + oidcProvider, err := oidc.NewOpenIDConnectProvider(config.IdentityProviders.OIDC) + if err != nil { + errors = append(errors, err) + } + + return middlewares.Providers{ + Authorizer: authorizer, + UserProvider: userProvider, + Regulator: regulator, + OpenIDConnect: oidcProvider, + StorageProvider: storageProvider, + Notifier: notifier, + SessionProvider: sessionProvider, + }, warnings, errors +} |
