diff options
Diffstat (limited to 'internal/configuration/decode_hooks.go')
| -rw-r--r-- | internal/configuration/decode_hooks.go | 52 |
1 files changed, 51 insertions, 1 deletions
diff --git a/internal/configuration/decode_hooks.go b/internal/configuration/decode_hooks.go index c92834d2f..bb1ee7fdd 100644 --- a/internal/configuration/decode_hooks.go +++ b/internal/configuration/decode_hooks.go @@ -18,6 +18,7 @@ import ( "github.com/go-crypt/crypt/algorithm/plaintext" "github.com/go-viper/mapstructure/v2" + "github.com/google/uuid" "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/utils" @@ -559,12 +560,13 @@ func StringToTLSVersionHookFunc() mapstructure.DecodeHookFuncType { // StringToCryptoPrivateKeyHookFunc decodes strings to schema.CryptographicPrivateKey's. func StringToCryptoPrivateKeyHookFunc() mapstructure.DecodeHookFuncType { + field, _ := reflect.TypeOf(schema.TLS{}).FieldByName("PrivateKey") + return func(f reflect.Type, t reflect.Type, data any) (value any, err error) { if f.Kind() != reflect.String { return data, nil } - field, _ := reflect.TypeOf(schema.TLS{}).FieldByName("PrivateKey") expectedType := field.Type if t != expectedType { @@ -823,3 +825,51 @@ func StringToIPNetworksHookFunc(definitions map[string][]*net.IPNet) mapstructur return networks, nil } } + +// StringToUUIDHookFunc decodes a string into a uuid.UUID. +func StringToUUIDHookFunc() mapstructure.DecodeHookFuncType { + return func(f reflect.Type, t reflect.Type, data any) (value any, err error) { + var ptr bool + + if f.Kind() != reflect.String { + return data, nil + } + + prefixType := "" + + if t.Kind() == reflect.Ptr { + ptr = true + prefixType = "*" + } + + expectedType := reflect.TypeOf(uuid.UUID{}) + + if ptr && t.Elem() != expectedType { + return data, nil + } else if !ptr && t != expectedType { + return data, nil + } + + dataStr := data.(string) + + var result uuid.UUID + + if dataStr == "" { + if ptr { + return (*uuid.UUID)(nil), nil + } else { + return nil, fmt.Errorf(errFmtDecodeHookCouldNotParseEmptyValue, prefixType, expectedType.String(), errDecodeNonPtrMustHaveValue) + } + } + + if result, err = uuid.Parse(dataStr); err != nil { + return nil, fmt.Errorf(errFmtDecodeHookCouldNotParse, dataStr, prefixType, expectedType.String(), err) + } + + if ptr { + return &result, nil + } + + return result, nil + } +} |
