summaryrefslogtreecommitdiff
path: root/internal/configuration/validator/access_control.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/configuration/validator/access_control.go')
-rw-r--r--internal/configuration/validator/access_control.go125
1 files changed, 57 insertions, 68 deletions
diff --git a/internal/configuration/validator/access_control.go b/internal/configuration/validator/access_control.go
index aa392351e..86aa37996 100644
--- a/internal/configuration/validator/access_control.go
+++ b/internal/configuration/validator/access_control.go
@@ -2,7 +2,6 @@ package validator
import (
"fmt"
- "net"
"regexp"
"strings"
@@ -11,51 +10,6 @@ import (
"github.com/authelia/authelia/v4/internal/utils"
)
-// IsPolicyValid check if policy is valid.
-func IsPolicyValid(policy string) (isValid bool) {
- return utils.IsStringInSlice(policy, validACLRulePolicies)
-}
-
-// IsSubjectValid check if a subject is valid.
-func IsSubjectValid(subject string) (isValid bool) {
- return subject == "" || IsSubjectValidStrict(subject) || strings.HasPrefix(subject, "oauth2:client:")
-}
-
-func IsSubjectValidStrict(subject string) (isValid bool) {
- return strings.HasPrefix(subject, "user:") || strings.HasPrefix(subject, "group:")
-}
-
-// IsNetworkGroupValid check if a network group is valid.
-func IsNetworkGroupValid(config schema.AccessControl, network string) bool {
- for _, networks := range config.Networks {
- if network != networks.Name {
- continue
- } else {
- return true
- }
- }
-
- return false
-}
-
-// IsNetworkValid checks if a network is valid.
-func IsNetworkValid(network string) (isValid bool) {
- if net.ParseIP(network) == nil {
- _, _, err := net.ParseCIDR(network)
- return err == nil
- }
-
- return true
-}
-
-func ruleDescriptor(position int, rule schema.AccessControlRule) string {
- if len(rule.Domains) == 0 {
- return fmt.Sprintf("#%d", position)
- }
-
- return fmt.Sprintf("#%d (domain '%s')", position, strings.Join(rule.Domains, ","))
-}
-
// ValidateAccessControl validates access control configuration.
func ValidateAccessControl(config *schema.Configuration, validator *schema.StructValidator) {
if config.AccessControl.DefaultPolicy == "" {
@@ -65,14 +19,6 @@ func ValidateAccessControl(config *schema.Configuration, validator *schema.Struc
if !IsPolicyValid(config.AccessControl.DefaultPolicy) {
validator.Push(fmt.Errorf(errFmtAccessControlDefaultPolicyValue, utils.StringJoinOr(validACLRulePolicies), config.AccessControl.DefaultPolicy))
}
-
- for _, n := range config.AccessControl.Networks {
- for _, networks := range n.Networks {
- if !IsNetworkValid(networks) {
- validator.Push(fmt.Errorf(errFmtAccessControlNetworkGroupIPCIDRInvalid, n.Name, networks))
- }
- }
- }
}
// ValidateRules validates an ACL Rule configuration.
@@ -103,9 +49,7 @@ func ValidateRules(config *schema.Configuration, validator *schema.StructValidat
}
}
- validateNetworks(rulePosition, rule, config.AccessControl, validator)
-
- validateSubjects(rulePosition, rule, validator)
+ validateSubjects(rulePosition, rule, config, validator)
validateMethods(rulePosition, rule, validator)
@@ -142,21 +86,22 @@ func validateDomains(rulePosition int, rule schema.AccessControlRule, validator
}
}
-func validateNetworks(rulePosition int, rule schema.AccessControlRule, config schema.AccessControl, validator *schema.StructValidator) {
- for _, network := range rule.Networks {
- if !IsNetworkValid(network) {
- if !IsNetworkGroupValid(config, network) {
- validator.Push(fmt.Errorf(errFmtAccessControlRuleNetworksInvalid, ruleDescriptor(rulePosition, rule), network))
- }
- }
- }
-}
+func validateSubjects(rulePosition int, rule schema.AccessControlRule, config *schema.Configuration, validator *schema.StructValidator) {
+ var (
+ id string
+ isValid bool
+ )
-func validateSubjects(rulePosition int, rule schema.AccessControlRule, validator *schema.StructValidator) {
for _, subjectRule := range rule.Subjects {
for _, subject := range subjectRule {
- if !IsSubjectValid(subject) {
+ if id, isValid = IsSubjectValid(subject); !isValid {
validator.Push(fmt.Errorf(errFmtAccessControlRuleSubjectInvalid, ruleDescriptor(rulePosition, rule), subject))
+
+ continue
+ }
+
+ if len(id) != 0 && !IsSubjectValidOAuth20(config, id) {
+ validator.Push(fmt.Errorf(errFmtAccessControlRuleOAuth2ClientSubjectInvalid, ruleDescriptor(rulePosition, rule), subject, id))
}
}
}
@@ -230,3 +175,47 @@ func validateQuery(i int, rule schema.AccessControlRule, config *schema.Configur
}
}
}
+
+// IsPolicyValid check if policy is valid.
+func IsPolicyValid(policy string) (isValid bool) {
+ return utils.IsStringInSlice(policy, validACLRulePolicies)
+}
+
+// IsSubjectValid validates if a subject has a valid prefix and returns the client id if applicable.
+func IsSubjectValid(subject string) (id string, isValid bool) {
+ if IsSubjectValidBasic(subject) {
+ return "", true
+ }
+
+ if strings.HasPrefix(subject, "oauth2:client:") {
+ return strings.TrimPrefix(subject, "oauth2:client:"), true
+ }
+
+ return "", false
+}
+
+func IsSubjectValidBasic(subject string) (isValid bool) {
+ return strings.HasPrefix(subject, "user:") || strings.HasPrefix(subject, "group:")
+}
+
+func IsSubjectValidOAuth20(config *schema.Configuration, id string) (isValid bool) {
+ if config.IdentityProviders.OIDC == nil || len(config.IdentityProviders.OIDC.Clients) == 0 {
+ return false
+ }
+
+ for _, client := range config.IdentityProviders.OIDC.Clients {
+ if client.ID == id {
+ return true
+ }
+ }
+
+ return false
+}
+
+func ruleDescriptor(position int, rule schema.AccessControlRule) string {
+ if len(rule.Domains) == 0 {
+ return fmt.Sprintf("#%d", position)
+ }
+
+ return fmt.Sprintf("#%d (domain '%s')", position, strings.Join(rule.Domains, ","))
+}