summaryrefslogtreecommitdiff
path: root/internal/configuration/decode_hooks.go
diff options
context:
space:
mode:
authorJames Elliott <james-d-elliott@users.noreply.github.com>2022-04-03 22:44:52 +1000
committerGitHub <noreply@github.com>2022-04-03 22:44:52 +1000
commit7230db7ceac7854dccb4d69d8b2d632ceeca2f0b (patch)
tree08dbc7f208e7c787d2c93f3dfff3be9fbed9d1f5 /internal/configuration/decode_hooks.go
parentbfd5d66ed8fd60a12b88ecf56144c9f74085bb73 (diff)
refactor(configuration): decode_hooks blackbox and better testing (#3097)
Diffstat (limited to 'internal/configuration/decode_hooks.go')
-rw-r--r--internal/configuration/decode_hooks.go137
1 files changed, 88 insertions, 49 deletions
diff --git a/internal/configuration/decode_hooks.go b/internal/configuration/decode_hooks.go
index 4258513a4..e7fa14988 100644
--- a/internal/configuration/decode_hooks.go
+++ b/internal/configuration/decode_hooks.go
@@ -13,32 +13,53 @@ import (
"github.com/authelia/authelia/v4/internal/utils"
)
-// StringToMailAddressHookFunc decodes a string into a mail.Address.
+// StringToMailAddressHookFunc decodes a string into a mail.Address or *mail.Address.
func StringToMailAddressHookFunc() mapstructure.DecodeHookFuncType {
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
- if f.Kind() != reflect.String || t != reflect.TypeOf(mail.Address{}) {
+ var ptr bool
+
+ if f.Kind() != reflect.String {
+ return data, nil
+ }
+
+ kindStr := "mail.Address (RFC5322)"
+
+ if t.Kind() == reflect.Ptr {
+ ptr = true
+ kindStr = "*" + kindStr
+ }
+
+ expectedType := reflect.TypeOf(mail.Address{})
+
+ if ptr && t.Elem() != expectedType {
+ return data, nil
+ } else if !ptr && t != expectedType {
return data, nil
}
dataStr := data.(string)
- if dataStr == "" {
- return mail.Address{}, nil
+ var result *mail.Address
+
+ if dataStr != "" {
+ if result, err = mail.ParseAddress(dataStr); err != nil {
+ return nil, fmt.Errorf(errFmtDecodeHookCouldNotParse, dataStr, kindStr, err)
+ }
}
- var (
- parsedAddress *mail.Address
- )
+ if ptr {
+ return result, nil
+ }
- if parsedAddress, err = mail.ParseAddress(dataStr); err != nil {
- return nil, fmt.Errorf("could not parse '%s' as a RFC5322 address: %w", dataStr, err)
+ if result == nil {
+ return mail.Address{}, nil
}
- return *parsedAddress, nil
+ return *result, nil
}
}
-// StringToURLHookFunc converts string types into a url.URL.
+// StringToURLHookFunc converts string types into a url.URL or *url.URL.
func StringToURLHookFunc() mapstructure.DecodeHookFuncType {
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
var ptr bool
@@ -47,37 +68,40 @@ func StringToURLHookFunc() mapstructure.DecodeHookFuncType {
return data, nil
}
- ptr = t.Kind() == reflect.Ptr
+ kindStr := "url.URL"
- typeURL := reflect.TypeOf(url.URL{})
+ if t.Kind() == reflect.Ptr {
+ ptr = true
+ kindStr = "*" + kindStr
+ }
- if ptr && t.Elem() != typeURL {
+ expectedType := reflect.TypeOf(url.URL{})
+
+ if ptr && t.Elem() != expectedType {
return data, nil
- } else if !ptr && t != typeURL {
+ } else if !ptr && t != expectedType {
return data, nil
}
dataStr := data.(string)
- var parsedURL *url.URL
+ var result *url.URL
- // Return an empty URL if there is an empty string.
if dataStr != "" {
- if parsedURL, err = url.Parse(dataStr); err != nil {
- return nil, fmt.Errorf("could not parse '%s' as a URL: %w", dataStr, err)
+ if result, err = url.Parse(dataStr); err != nil {
+ return nil, fmt.Errorf(errFmtDecodeHookCouldNotParse, dataStr, kindStr, err)
}
}
if ptr {
- return parsedURL, nil
+ return result, nil
}
- // Return an empty URL if there is an empty string.
- if parsedURL == nil {
+ if result == nil {
return url.URL{}, nil
}
- return *parsedURL, nil
+ return *result, nil
}
}
@@ -94,48 +118,51 @@ func ToTimeDurationHookFunc() mapstructure.DecodeHookFuncType {
return data, nil
}
- typeTimeDuration := reflect.TypeOf(time.Hour)
+ kindStr := "time.Duration"
if t.Kind() == reflect.Ptr {
- if t.Elem() != typeTimeDuration {
- return data, nil
- }
-
ptr = true
- } else if t != typeTimeDuration {
+ kindStr = "*" + kindStr
+ }
+
+ expectedType := reflect.TypeOf(time.Duration(0))
+
+ if ptr && t.Elem() != expectedType {
+ return data, nil
+ } else if !ptr && t != expectedType {
return data, nil
}
- var duration time.Duration
+ var result time.Duration
switch {
case f.Kind() == reflect.String:
dataStr := data.(string)
- if duration, err = utils.ParseDurationString(dataStr); err != nil {
- return nil, err
+ if result, err = utils.ParseDurationString(dataStr); err != nil {
+ return nil, fmt.Errorf(errFmtDecodeHookCouldNotParse, dataStr, kindStr, err)
}
case f.Kind() == reflect.Int:
seconds := data.(int)
- duration = time.Second * time.Duration(seconds)
+ result = time.Second * time.Duration(seconds)
case f.Kind() == reflect.Int32:
seconds := data.(int32)
- duration = time.Second * time.Duration(seconds)
- case f == typeTimeDuration:
- duration = data.(time.Duration)
+ result = time.Second * time.Duration(seconds)
+ case f == expectedType:
+ result = data.(time.Duration)
case f.Kind() == reflect.Int64:
seconds := data.(int64)
- duration = time.Second * time.Duration(seconds)
+ result = time.Second * time.Duration(seconds)
}
if ptr {
- return &duration, nil
+ return &result, nil
}
- return duration, nil
+ return result, nil
}
}
@@ -148,27 +175,39 @@ func StringToRegexpFunc() mapstructure.DecodeHookFuncType {
return data, nil
}
- ptr = t.Kind() == reflect.Ptr
+ kindStr := "regexp.Regexp"
- typeRegexp := reflect.TypeOf(regexp.Regexp{})
+ if t.Kind() == reflect.Ptr {
+ ptr = true
+ kindStr = "*" + kindStr
+ }
- if ptr && t.Elem() != typeRegexp {
+ expectedType := reflect.TypeOf(regexp.Regexp{})
+
+ if ptr && t.Elem() != expectedType {
return data, nil
- } else if !ptr && t != typeRegexp {
+ } else if !ptr && t != expectedType {
return data, nil
}
- regexStr := data.(string)
+ dataStr := data.(string)
+
+ var result *regexp.Regexp
- pattern, err := regexp.Compile(regexStr)
- if err != nil {
- return nil, fmt.Errorf("could not parse '%s' as regexp: %w", regexStr, err)
+ if dataStr != "" {
+ if result, err = regexp.Compile(dataStr); err != nil {
+ return nil, fmt.Errorf(errFmtDecodeHookCouldNotParse, dataStr, kindStr, err)
+ }
}
if ptr {
- return pattern, nil
+ return result, nil
+ }
+
+ if result == nil {
+ return nil, fmt.Errorf(errFmtDecodeHookCouldNotParseEmptyValue, kindStr, errDecodeNonPtrMustHaveValue)
}
- return *pattern, nil
+ return *result, nil
}
}