summaryrefslogtreecommitdiff
path: root/internal/commands/helpers.go
blob: 84ff9074c2731f3f45181a8736f957241c80a69c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
package commands

import (
	"encoding/base32"
	"errors"
	"fmt"

	"github.com/spf13/pflag"

	"github.com/authelia/authelia/v4/internal/configuration/schema"
	"github.com/authelia/authelia/v4/internal/model"
	"github.com/authelia/authelia/v4/internal/storage"
)

// getStorageProvider builds the storage provider for the first storage backend
// that is configured on ctx.config.Storage. The backends are checked in the
// order PostgreSQL, MySQL, then Local (which is served by the SQLite provider).
// The result is nil when none of those sections are set.
func getStorageProvider(ctx *CmdCtx) (provider storage.Provider) {
	if ctx.config.Storage.PostgreSQL != nil {
		return storage.NewPostgreSQLProvider(ctx.config, ctx.trusted)
	}

	if ctx.config.Storage.MySQL != nil {
		return storage.NewMySQLProvider(ctx.config, ctx.trusted)
	}

	if ctx.config.Storage.Local != nil {
		return storage.NewSQLiteProvider(ctx.config)
	}

	// No known backend is configured.
	return nil
}

// containsIdentifier reports whether identifiers holds an entry that has the
// same Service, SectorID, and Username as identifier. Only those three fields
// take part in the comparison.
func containsIdentifier(identifier model.UserOpaqueIdentifier, identifiers []model.UserOpaqueIdentifier) bool {
	for _, candidate := range identifiers {
		if candidate.Service != identifier.Service {
			continue
		}

		if candidate.SectorID != identifier.SectorID {
			continue
		}

		if candidate.Username != identifier.Username {
			continue
		}

		return true
	}

	return false
}

// storageWrapCheckSchemaErr decorates schema check errors with a message that
// explains what the command needs. Errors matching errStorageSchemaIncompatible
// or errStorageSchemaOutdated (as determined by errors.Is) are wrapped with %w
// so the original error is still reachable via errors.Is / errors.Unwrap; any
// other error, including nil, is returned unchanged.
func storageWrapCheckSchemaErr(err error) error {
	switch {
	case errors.Is(err, errStorageSchemaIncompatible):
		return fmt.Errorf("command requires the use of a compatible schema version: %w", err)
	case errors.Is(err, errStorageSchemaOutdated):
		return fmt.Errorf("command requires the use of an up to date schema version: %w", err)
	default:
		return err
	}
}

// storageTOTPGenerateRunEOptsFromFlags reads the "force", "path", and "secret"
// flags and returns them as force, filename, and secret respectively.
//
// An error is returned if any flag lookup fails, or if a non-empty secret is
// too short: the length an unpadded standard base32 string of that size would
// decode to must be at least schema.TOTPSecretSizeMinimum. The check is derived
// purely from the string length; the secret is not actually decoded here, so
// this function does not verify that it is valid base32. An empty secret skips
// the length check entirely.
func storageTOTPGenerateRunEOptsFromFlags(flags *pflag.FlagSet) (force bool, filename, secret string, err error) {
	if force, err = flags.GetBool("force"); err != nil {
		return force, filename, secret, err
	}

	if filename, err = flags.GetString("path"); err != nil {
		return force, filename, secret, err
	}

	if secret, err = flags.GetString("secret"); err != nil {
		return force, filename, secret, err
	}

	// Nothing to validate when no secret was supplied.
	if secret == "" {
		return force, filename, secret, nil
	}

	secretLength := base32.StdEncoding.WithPadding(base32.NoPadding).DecodedLen(len(secret))

	// A decoded length equal to the minimum is acceptable, hence "at least".
	if secretLength < schema.TOTPSecretSizeMinimum {
		return force, filename, secret, fmt.Errorf("decoded length of the base32 secret must have "+
			"a length of at least %d but '%s' has a decoded length of %d", schema.TOTPSecretSizeMinimum, secret, secretLength)
	}

	return force, filename, secret, nil
}

// storageWebAuthnDeleteRunEOptsFromFlags resolves the options for the WebAuthn
// delete command from the flag set and positional arguments.
//
// The first positional argument, when present, is returned as user. Each of the
// all, description, and key ID flags is only read when it was explicitly
// changed on the command line; byKID reports whether the key ID flag was
// changed. Exactly one of those three flags must have been changed, and a user
// must be supplied unless the key ID flag is the one in use. Any violation, or
// a failed flag lookup, is reported through err.
func storageWebAuthnDeleteRunEOptsFromFlags(flags *pflag.FlagSet, args []string) (all, byKID bool, description, kid, user string, err error) {
	if len(args) > 0 {
		user = args[0]
	}

	// Number of mutually exclusive selector flags that were explicitly set.
	var selectors int

	if flags.Changed(cmdFlagNameAll) {
		all, err = flags.GetBool(cmdFlagNameAll)
		if err != nil {
			return all, byKID, description, kid, user, err
		}

		selectors++
	}

	if flags.Changed(cmdFlagNameDescription) {
		description, err = flags.GetString(cmdFlagNameDescription)
		if err != nil {
			return all, byKID, description, kid, user, err
		}

		selectors++
	}

	byKID = flags.Changed(cmdFlagNameKeyID)

	if byKID {
		kid, err = flags.GetString(cmdFlagNameKeyID)
		if err != nil {
			return all, byKID, description, kid, user, err
		}

		selectors++
	}

	// err is nil at this point; the first failing rule (if any) sets it.
	switch {
	case selectors > 1:
		err = fmt.Errorf("must only supply one of the flags --all, --description, and --kid but %d were specified", selectors)
	case selectors == 0:
		err = fmt.Errorf("must supply one of the flags --all, --description, or --kid")
	case !byKID && user == "":
		err = fmt.Errorf("must supply the username or the --kid flag")
	}

	return all, byKID, description, kid, user, err
}