summaryrefslogtreecommitdiff
path: root/internal/commands/root.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/commands/root.go')
-rw-r--r--internal/commands/root.go155
1 files changed, 155 insertions, 0 deletions
diff --git a/internal/commands/root.go b/internal/commands/root.go
new file mode 100644
index 000000000..369e92a80
--- /dev/null
+++ b/internal/commands/root.go
@@ -0,0 +1,155 @@
+package commands
+
+import (
+ "fmt"
+ "os"
+
+ "github.com/spf13/cobra"
+
+ "github.com/authelia/authelia/internal/authentication"
+ "github.com/authelia/authelia/internal/authorization"
+ "github.com/authelia/authelia/internal/configuration/schema"
+ "github.com/authelia/authelia/internal/logging"
+ "github.com/authelia/authelia/internal/middlewares"
+ "github.com/authelia/authelia/internal/notification"
+ "github.com/authelia/authelia/internal/oidc"
+ "github.com/authelia/authelia/internal/regulation"
+ "github.com/authelia/authelia/internal/server"
+ "github.com/authelia/authelia/internal/session"
+ "github.com/authelia/authelia/internal/storage"
+ "github.com/authelia/authelia/internal/utils"
+)
+
+// NewRootCmd returns a new Root Cmd.
+func NewRootCmd() (cmd *cobra.Command) {
+ version := utils.Version()
+
+ cmd = &cobra.Command{
+ Use: "authelia",
+ Example: cmdAutheliaExample,
+ Short: fmt.Sprintf("authelia %s", version),
+ Long: fmt.Sprintf(fmtAutheliaLong, version),
+ Version: version,
+ Args: cobra.NoArgs,
+ PreRun: newCmdWithConfigPreRun(true, true, true),
+ Run: cmdRootRun,
+ }
+
+ cmdWithConfigFlags(cmd)
+
+ cmd.AddCommand(
+ newBuildInfoCmd(),
+ NewCertificatesCmd(),
+ newCompletionCmd(),
+ NewHashPasswordCmd(),
+ NewRSACmd(),
+ newValidateConfigCmd(),
+ )
+
+ return cmd
+}
+
+func cmdRootRun(_ *cobra.Command, _ []string) {
+ logger := logging.Logger()
+
+ logger.Infof("Authelia %s is starting", utils.Version())
+
+ if os.Getenv("ENVIRONMENT") == "dev" {
+ logger.Info("===> Authelia is running in development mode. <===")
+ }
+
+ if err := logging.InitializeLogger(config.Log, true); err != nil {
+ logger.Fatalf("Cannot initialize logger: %v", err)
+ }
+
+ providers, warnings, errors := getProviders(config)
+ if len(warnings) != 0 {
+ for _, err := range warnings {
+ logger.Warn(err)
+ }
+ }
+
+ if len(errors) != 0 {
+ for _, err := range errors {
+ logger.Error(err)
+ }
+
+ logger.Fatalf("Errors occurred provisioning providers.")
+ }
+
+ server.Start(*config, providers)
+}
+
+func getProviders(config *schema.Configuration) (providers middlewares.Providers, warnings []error, errors []error) {
+ autheliaCertPool, warnings, errors := utils.NewX509CertPool(config.CertificatesDirectory)
+ if len(warnings) != 0 || len(errors) != 0 {
+ return providers, warnings, errors
+ }
+
+ var storageProvider storage.Provider
+
+ switch {
+ case config.Storage.PostgreSQL != nil:
+ storageProvider = storage.NewPostgreSQLProvider(*config.Storage.PostgreSQL)
+ case config.Storage.MySQL != nil:
+ storageProvider = storage.NewMySQLProvider(*config.Storage.MySQL)
+ case config.Storage.Local != nil:
+ storageProvider = storage.NewSQLiteProvider(config.Storage.Local.Path)
+ default:
+ errors = append(errors, fmt.Errorf("unrecognized storage provider"))
+ }
+
+ var (
+ userProvider authentication.UserProvider
+ err error
+ )
+
+ switch {
+ case config.AuthenticationBackend.File != nil:
+ userProvider = authentication.NewFileUserProvider(config.AuthenticationBackend.File)
+ case config.AuthenticationBackend.LDAP != nil:
+ userProvider, err = authentication.NewLDAPUserProvider(config.AuthenticationBackend, autheliaCertPool)
+ if err != nil {
+ errors = append(errors, fmt.Errorf("failed to check LDAP authentication backend: %w", err))
+ }
+ default:
+ errors = append(errors, fmt.Errorf("unrecognized user provider"))
+ }
+
+ var notifier notification.Notifier
+
+ switch {
+ case config.Notifier.SMTP != nil:
+ notifier = notification.NewSMTPNotifier(*config.Notifier.SMTP, autheliaCertPool)
+ case config.Notifier.FileSystem != nil:
+ notifier = notification.NewFileNotifier(*config.Notifier.FileSystem)
+ default:
+ errors = append(errors, fmt.Errorf("unrecognized notifier provider"))
+ }
+
+ if notifier != nil {
+ if _, err := notifier.StartupCheck(); err != nil {
+ errors = append(errors, fmt.Errorf("failed to check notification provider: %w", err))
+ }
+ }
+
+ clock := utils.RealClock{}
+ authorizer := authorization.NewAuthorizer(config)
+ sessionProvider := session.NewProvider(config.Session, autheliaCertPool)
+ regulator := regulation.NewRegulator(config.Regulation, storageProvider, clock)
+
+ oidcProvider, err := oidc.NewOpenIDConnectProvider(config.IdentityProviders.OIDC)
+ if err != nil {
+ errors = append(errors, err)
+ }
+
+ return middlewares.Providers{
+ Authorizer: authorizer,
+ UserProvider: userProvider,
+ Regulator: regulator,
+ OpenIDConnect: oidcProvider,
+ StorageProvider: storageProvider,
+ Notifier: notifier,
+ SessionProvider: sessionProvider,
+ }, warnings, errors
+}