summaryrefslogtreecommitdiff
path: root/internal/middlewares/startup.go
blob: 0664be3429813bfbf13257745d5c233fb99016cc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
package middlewares

import (
	"fmt"
	"sort"
	"strings"

	"github.com/authelia/authelia/v4/internal/model"
	"github.com/authelia/authelia/v4/internal/utils"
)

// StartupChecks runs the startup check of the storage, user, notification, NTP, expressions, and WebAuthn metadata
// providers obtained from ctx, in that order. It returns an *ErrProviderStartupCheck listing the providers whose check
// failed, or nil when none failed. When log is true each check is logged.
//
// The NTP check is skipped when NTP.DisableStartupCheck is set, and an NTP failure is left out of the returned error
// when NTP.DisableFailure is set. The WebAuthn metadata check is skipped when metadata is not enabled or the metadata
// provider is nil. The receiver is not consulted; all providers are read through ctx.
func (p *Providers) StartupChecks(ctx Context, log bool) (err error) {
	failures := &ErrProviderStartupCheck{errors: map[string]error{}}

	// check runs one provider's startup check and records any failure into failures.
	check := func(name string, provider model.StartupCheck, disabled bool) {
		doStartupCheck(ctx, name, provider, disabled, log, failures.errors)
	}

	check(ProviderNameStorage, ctx.GetProviders().StorageProvider, false)
	check(ProviderNameUser, ctx.GetProviders().UserProvider, false)
	check(ProviderNameNotification, ctx.GetProviders().Notifier, false)
	check(ProviderNameNTP, ctx.GetProviders().NTP, ctx.GetConfiguration().NTP.DisableStartupCheck)
	check(ProviderNameExpressions, ctx.GetProviders().UserAttributeResolver, false)

	// The metadata provider is skipped outright when it is nil rather than being reported as misconfigured.
	check(
		ProviderNameWebAuthnMetaData,
		ctx.GetProviders().MetaDataService,
		!ctx.GetConfiguration().WebAuthn.Metadata.Enabled || ctx.GetProviders().MetaDataService == nil,
	)

	var ignored []string

	if ctx.GetConfiguration().NTP.DisableFailure {
		ignored = append(ignored, ProviderNameNTP)
	}

	return failures.FilterError(ignored...)
}

// doStartupCheck performs the startup check of a single provider and records any failure in errs under name.
//
// When disabled is true the check is skipped and nothing is recorded. A nil provider is recorded as an error
// because it cannot be checked. When log is true, progress is traced, skips are logged at debug level, and every
// recorded failure (including the nil-provider case) is logged at error level.
//
// NOTE(review): the nil comparison only detects an untyped nil interface. A typed nil pointer stored in the
// model.StartupCheck interface is passed through to StartupCheck, so callers have to fold such cases into disabled
// themselves (StartupChecks does this for the WebAuthn metadata provider).
func doStartupCheck(ctx Context, name string, provider model.StartupCheck, disabled, log bool, errs map[string]error) {
	if log {
		ctx.GetLogger().WithFields(map[string]any{LogFieldProvider: name}).Trace(LogMessageStartupCheckPerforming)
	}

	if disabled {
		if log {
			ctx.GetLogger().Debugf("%s provider: startup check skipped as it is disabled", name)
		}

		return
	}

	// fail logs (when requested) and records a failure, so every failure path is reported the same way.
	fail := func(err error) {
		if log {
			ctx.GetLogger().WithError(err).WithField(LogFieldProvider, name).Error(LogMessageStartupCheckError)
		}

		errs[name] = err
	}

	if provider == nil {
		fail(fmt.Errorf("unrecognized provider or it is not configured properly"))

		return
	}

	if err := provider.StartupCheck(); err != nil {
		fail(err)

		return
	}

	if log {
		ctx.GetLogger().WithFields(map[string]any{LogFieldProvider: name}).Trace("Startup Check Completed Successfully")
	}
}

type ErrProviderStartupCheck struct {
	errors map[string]error
}

func (e *ErrProviderStartupCheck) Error() string {
	keys := make([]string, 0, len(e.errors))
	for k := range e.errors {
		keys = append(keys, k)
	}

	return fmt.Sprintf("errors occurred performing checks on the '%s' providers", strings.Join(keys, ", "))
}

func (e *ErrProviderStartupCheck) Failed() (failed []string) {
	for key := range e.errors {
		failed = append(failed, key)
	}

	return failed
}

// FilterError returns a new *ErrProviderStartupCheck holding only the failures whose provider name is not listed in
// providers. It returns a literal nil (not a typed nil) when no failures remain, so callers can compare with nil.
func (e *ErrProviderStartupCheck) FilterError(providers ...string) error {
	remaining := make(map[string]error)

	for name, failure := range e.errors {
		if !utils.IsStringInSlice(name, providers) {
			remaining[name] = failure
		}
	}

	if len(remaining) != 0 {
		return &ErrProviderStartupCheck{errors: remaining}
	}

	return nil
}

// ErrorMap returns the recorded failures keyed by provider name. The returned map is the receiver's own map rather
// than a copy, so callers should treat it as read-only.
func (e *ErrProviderStartupCheck) ErrorMap() map[string]error {
	failures := e.errors

	return failures
}