diff options
Diffstat (limited to 'internal/configuration/validator/access_control.go')
| -rw-r--r-- | internal/configuration/validator/access_control.go | 125 |
1 files changed, 57 insertions, 68 deletions
diff --git a/internal/configuration/validator/access_control.go b/internal/configuration/validator/access_control.go index aa392351e..86aa37996 100644 --- a/internal/configuration/validator/access_control.go +++ b/internal/configuration/validator/access_control.go @@ -2,7 +2,6 @@ package validator import ( "fmt" - "net" "regexp" "strings" @@ -11,51 +10,6 @@ import ( "github.com/authelia/authelia/v4/internal/utils" ) -// IsPolicyValid check if policy is valid. -func IsPolicyValid(policy string) (isValid bool) { - return utils.IsStringInSlice(policy, validACLRulePolicies) -} - -// IsSubjectValid check if a subject is valid. -func IsSubjectValid(subject string) (isValid bool) { - return subject == "" || IsSubjectValidStrict(subject) || strings.HasPrefix(subject, "oauth2:client:") -} - -func IsSubjectValidStrict(subject string) (isValid bool) { - return strings.HasPrefix(subject, "user:") || strings.HasPrefix(subject, "group:") -} - -// IsNetworkGroupValid check if a network group is valid. -func IsNetworkGroupValid(config schema.AccessControl, network string) bool { - for _, networks := range config.Networks { - if network != networks.Name { - continue - } else { - return true - } - } - - return false -} - -// IsNetworkValid checks if a network is valid. -func IsNetworkValid(network string) (isValid bool) { - if net.ParseIP(network) == nil { - _, _, err := net.ParseCIDR(network) - return err == nil - } - - return true -} - -func ruleDescriptor(position int, rule schema.AccessControlRule) string { - if len(rule.Domains) == 0 { - return fmt.Sprintf("#%d", position) - } - - return fmt.Sprintf("#%d (domain '%s')", position, strings.Join(rule.Domains, ",")) -} - // ValidateAccessControl validates access control configuration. func ValidateAccessControl(config *schema.Configuration, validator *schema.StructValidator) { if config.AccessControl.DefaultPolicy == "" { @@ -65,14 +19,6 @@ func ValidateAccessControl(config *schema.Configuration, validator *schema.Struc if !IsPolicyValid(config.AccessControl.DefaultPolicy) { validator.Push(fmt.Errorf(errFmtAccessControlDefaultPolicyValue, utils.StringJoinOr(validACLRulePolicies), config.AccessControl.DefaultPolicy)) } - - for _, n := range config.AccessControl.Networks { - for _, networks := range n.Networks { - if !IsNetworkValid(networks) { - validator.Push(fmt.Errorf(errFmtAccessControlNetworkGroupIPCIDRInvalid, n.Name, networks)) - } - } - } } // ValidateRules validates an ACL Rule configuration. @@ -103,9 +49,7 @@ func ValidateRules(config *schema.Configuration, validator *schema.StructValidat } } - validateNetworks(rulePosition, rule, config.AccessControl, validator) - - validateSubjects(rulePosition, rule, validator) + validateSubjects(rulePosition, rule, config, validator) validateMethods(rulePosition, rule, validator) @@ -142,21 +86,22 @@ func validateDomains(rulePosition int, rule schema.AccessControlRule, validator } } -func validateNetworks(rulePosition int, rule schema.AccessControlRule, config schema.AccessControl, validator *schema.StructValidator) { - for _, network := range rule.Networks { - if !IsNetworkValid(network) { - if !IsNetworkGroupValid(config, network) { - validator.Push(fmt.Errorf(errFmtAccessControlRuleNetworksInvalid, ruleDescriptor(rulePosition, rule), network)) - } - } - } -} +func validateSubjects(rulePosition int, rule schema.AccessControlRule, config *schema.Configuration, validator *schema.StructValidator) { + var ( + id string + isValid bool + ) -func validateSubjects(rulePosition int, rule schema.AccessControlRule, validator *schema.StructValidator) { for _, subjectRule := range rule.Subjects { for _, subject := range subjectRule { - if !IsSubjectValid(subject) { + if id, isValid = IsSubjectValid(subject); !isValid { validator.Push(fmt.Errorf(errFmtAccessControlRuleSubjectInvalid, ruleDescriptor(rulePosition, rule), subject)) + + continue + } + + if len(id) != 0 && !IsSubjectValidOAuth20(config, id) { + validator.Push(fmt.Errorf(errFmtAccessControlRuleOAuth2ClientSubjectInvalid, ruleDescriptor(rulePosition, rule), subject, id)) } } } @@ -230,3 +175,47 @@ func validateQuery(i int, rule schema.AccessControlRule, config *schema.Configur } } } + +// IsPolicyValid check if policy is valid. +func IsPolicyValid(policy string) (isValid bool) { + return utils.IsStringInSlice(policy, validACLRulePolicies) +} + +// IsSubjectValid validates if a subject has a valid prefix and returns the client id if applicable. +func IsSubjectValid(subject string) (id string, isValid bool) { + if IsSubjectValidBasic(subject) { + return "", true + } + + if strings.HasPrefix(subject, "oauth2:client:") { + return strings.TrimPrefix(subject, "oauth2:client:"), true + } + + return "", false +} + +func IsSubjectValidBasic(subject string) (isValid bool) { + return strings.HasPrefix(subject, "user:") || strings.HasPrefix(subject, "group:") +} + +func IsSubjectValidOAuth20(config *schema.Configuration, id string) (isValid bool) { + if config.IdentityProviders.OIDC == nil || len(config.IdentityProviders.OIDC.Clients) == 0 { + return false + } + + for _, client := range config.IdentityProviders.OIDC.Clients { + if client.ID == id { + return true + } + } + + return false +} + +func ruleDescriptor(position int, rule schema.AccessControlRule) string { + if len(rule.Domains) == 0 { + return fmt.Sprintf("#%d", position) + } + + return fmt.Sprintf("#%d (domain '%s')", position, strings.Join(rule.Domains, ",")) +} |
