diff options
| author | James Elliott <james-d-elliott@users.noreply.github.com> | 2022-04-03 22:44:52 +1000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-04-03 22:44:52 +1000 |
| commit | 7230db7ceac7854dccb4d69d8b2d632ceeca2f0b (patch) | |
| tree | 08dbc7f208e7c787d2c93f3dfff3be9fbed9d1f5 /internal/configuration/decode_hooks.go | |
| parent | bfd5d66ed8fd60a12b88ecf56144c9f74085bb73 (diff) | |
refactor(configuration): decode_hooks blackbox and better testing (#3097)
Diffstat (limited to 'internal/configuration/decode_hooks.go')
| -rw-r--r-- | internal/configuration/decode_hooks.go | 137 |
1 files changed, 88 insertions, 49 deletions
diff --git a/internal/configuration/decode_hooks.go b/internal/configuration/decode_hooks.go index 4258513a4..e7fa14988 100644 --- a/internal/configuration/decode_hooks.go +++ b/internal/configuration/decode_hooks.go @@ -13,32 +13,53 @@ import ( "github.com/authelia/authelia/v4/internal/utils" ) -// StringToMailAddressHookFunc decodes a string into a mail.Address. +// StringToMailAddressHookFunc decodes a string into a mail.Address or *mail.Address. func StringToMailAddressHookFunc() mapstructure.DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) { - if f.Kind() != reflect.String || t != reflect.TypeOf(mail.Address{}) { + var ptr bool + + if f.Kind() != reflect.String { + return data, nil + } + + kindStr := "mail.Address (RFC5322)" + + if t.Kind() == reflect.Ptr { + ptr = true + kindStr = "*" + kindStr + } + + expectedType := reflect.TypeOf(mail.Address{}) + + if ptr && t.Elem() != expectedType { + return data, nil + } else if !ptr && t != expectedType { return data, nil } dataStr := data.(string) - if dataStr == "" { - return mail.Address{}, nil + var result *mail.Address + + if dataStr != "" { + if result, err = mail.ParseAddress(dataStr); err != nil { + return nil, fmt.Errorf(errFmtDecodeHookCouldNotParse, dataStr, kindStr, err) + } } - var ( - parsedAddress *mail.Address - ) + if ptr { + return result, nil + } - if parsedAddress, err = mail.ParseAddress(dataStr); err != nil { - return nil, fmt.Errorf("could not parse '%s' as a RFC5322 address: %w", dataStr, err) + if result == nil { + return mail.Address{}, nil } - return *parsedAddress, nil + return *result, nil } } -// StringToURLHookFunc converts string types into a url.URL. +// StringToURLHookFunc converts string types into a url.URL or *url.URL. func StringToURLHookFunc() mapstructure.DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) { var ptr bool @@ -47,37 +68,40 @@ func StringToURLHookFunc() mapstructure.DecodeHookFuncType { return data, nil } - ptr = t.Kind() == reflect.Ptr + kindStr := "url.URL" - typeURL := reflect.TypeOf(url.URL{}) + if t.Kind() == reflect.Ptr { + ptr = true + kindStr = "*" + kindStr + } - if ptr && t.Elem() != typeURL { + expectedType := reflect.TypeOf(url.URL{}) + + if ptr && t.Elem() != expectedType { return data, nil - } else if !ptr && t != typeURL { + } else if !ptr && t != expectedType { return data, nil } dataStr := data.(string) - var parsedURL *url.URL + var result *url.URL - // Return an empty URL if there is an empty string. if dataStr != "" { - if parsedURL, err = url.Parse(dataStr); err != nil { - return nil, fmt.Errorf("could not parse '%s' as a URL: %w", dataStr, err) + if result, err = url.Parse(dataStr); err != nil { + return nil, fmt.Errorf(errFmtDecodeHookCouldNotParse, dataStr, kindStr, err) } } if ptr { - return parsedURL, nil + return result, nil } - // Return an empty URL if there is an empty string. - if parsedURL == nil { + if result == nil { return url.URL{}, nil } - return *parsedURL, nil + return *result, nil } } @@ -94,48 +118,51 @@ func ToTimeDurationHookFunc() mapstructure.DecodeHookFuncType { return data, nil } - typeTimeDuration := reflect.TypeOf(time.Hour) + kindStr := "time.Duration" if t.Kind() == reflect.Ptr { - if t.Elem() != typeTimeDuration { - return data, nil - } - ptr = true - } else if t != typeTimeDuration { + kindStr = "*" + kindStr + } + + expectedType := reflect.TypeOf(time.Duration(0)) + + if ptr && t.Elem() != expectedType { + return data, nil + } else if !ptr && t != expectedType { return data, nil } - var duration time.Duration + var result time.Duration switch { case f.Kind() == reflect.String: dataStr := data.(string) - if duration, err = utils.ParseDurationString(dataStr); err != nil { - return nil, err + if result, err = utils.ParseDurationString(dataStr); err != nil { + return nil, fmt.Errorf(errFmtDecodeHookCouldNotParse, dataStr, kindStr, err) } case f.Kind() == reflect.Int: seconds := data.(int) - duration = time.Second * time.Duration(seconds) + result = time.Second * time.Duration(seconds) case f.Kind() == reflect.Int32: seconds := data.(int32) - duration = time.Second * time.Duration(seconds) - case f == typeTimeDuration: - duration = data.(time.Duration) + result = time.Second * time.Duration(seconds) + case f == expectedType: + result = data.(time.Duration) case f.Kind() == reflect.Int64: seconds := data.(int64) - duration = time.Second * time.Duration(seconds) + result = time.Second * time.Duration(seconds) } if ptr { - return &duration, nil + return &result, nil } - return duration, nil + return result, nil } } @@ -148,27 +175,39 @@ func StringToRegexpFunc() mapstructure.DecodeHookFuncType { return data, nil } - ptr = t.Kind() == reflect.Ptr + kindStr := "regexp.Regexp" - typeRegexp := reflect.TypeOf(regexp.Regexp{}) + if t.Kind() == reflect.Ptr { + ptr = true + kindStr = "*" + kindStr + } - if ptr && t.Elem() != typeRegexp { + expectedType := reflect.TypeOf(regexp.Regexp{}) + + if ptr && t.Elem() != expectedType { return data, nil - } else if !ptr && t != typeRegexp { + } else if !ptr && t != expectedType { return data, nil } - regexStr := data.(string) + dataStr := data.(string) + + var result *regexp.Regexp - pattern, err := regexp.Compile(regexStr) - if err != nil { - return nil, fmt.Errorf("could not parse '%s' as regexp: %w", regexStr, err) + if dataStr != "" { + if result, err = regexp.Compile(dataStr); err != nil { + return nil, fmt.Errorf(errFmtDecodeHookCouldNotParse, dataStr, kindStr, err) + } } if ptr { - return pattern, nil + return result, nil + } + + if result == nil { + return nil, fmt.Errorf(errFmtDecodeHookCouldNotParseEmptyValue, kindStr, errDecodeNonPtrMustHaveValue) } - return *pattern, nil + return *result, nil } } |
