summaryrefslogtreecommitdiff
path: root/lib/privs.c
diff options
context:
space:
mode:
Diffstat (limited to 'lib/privs.c')
-rw-r--r--lib/privs.c48
1 files changed, 34 insertions, 14 deletions
diff --git a/lib/privs.c b/lib/privs.c
index 838ff8fc92..59f24afe4a 100644
--- a/lib/privs.c
+++ b/lib/privs.c
@@ -706,13 +706,21 @@ struct zebra_privs_t *_zprivs_raise(struct zebra_privs_t *privs,
if (!privs)
return NULL;
- errno = 0;
- if (privs->change(ZPRIVS_RAISE)) {
- zlog_err("%s: Failed to raise privileges (%s)",
- funcname, safe_strerror(errno));
+ /* If we're already elevated, just return */
+ pthread_mutex_lock(&(privs->mutex));
+ {
+ if (++(privs->refcount) == 1) {
+ errno = 0;
+ if (privs->change(ZPRIVS_RAISE)) {
+ zlog_err("%s: Failed to raise privileges (%s)",
+ funcname, safe_strerror(errno));
+ }
+ errno = save_errno;
+ privs->raised_in_funcname = funcname;
+ }
}
- errno = save_errno;
- privs->raised_in_funcname = funcname;
+ pthread_mutex_unlock(&(privs->mutex));
+
return privs;
}
@@ -723,13 +731,22 @@ void _zprivs_lower(struct zebra_privs_t **privs)
if (!*privs)
return;
- errno = 0;
- if ((*privs)->change(ZPRIVS_LOWER)) {
- zlog_err("%s: Failed to lower privileges (%s)",
- (*privs)->raised_in_funcname, safe_strerror(errno));
+ /* Don't lower privs if there's another caller */
+ pthread_mutex_lock(&(*privs)->mutex);
+ {
+ if (--((*privs)->refcount) == 0) {
+ errno = 0;
+ if ((*privs)->change(ZPRIVS_LOWER)) {
+ zlog_err("%s: Failed to lower privileges (%s)",
+ (*privs)->raised_in_funcname,
+ safe_strerror(errno));
+ }
+ errno = save_errno;
+ (*privs)->raised_in_funcname = NULL;
+ }
}
- errno = save_errno;
- (*privs)->raised_in_funcname = NULL;
+ pthread_mutex_unlock(&(*privs)->mutex);
+
*privs = NULL;
}
@@ -743,6 +760,9 @@ void zprivs_preinit(struct zebra_privs_t *zprivs)
exit(1);
}
+ pthread_mutex_init(&(zprivs->mutex), NULL);
+ zprivs->refcount = 0;
+
if (zprivs->vty_group) {
/* in a "NULL" setup, this is allowed to fail too, but still
* try. */
@@ -789,7 +809,7 @@ void zprivs_preinit(struct zebra_privs_t *zprivs)
void zprivs_init(struct zebra_privs_t *zprivs)
{
- gid_t groups[NGROUPS_MAX];
+ gid_t groups[NGROUPS_MAX] = {};
int i, ngroups = 0;
int found = 0;
@@ -799,7 +819,7 @@ void zprivs_init(struct zebra_privs_t *zprivs)
return;
if (zprivs->user) {
- ngroups = sizeof(groups);
+ ngroups = array_size(groups);
if (getgrouplist(zprivs->user, zprivs_state.zgid, groups,
&ngroups)
< 0) {