Merge branch 'for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/ebiederm...
[~shefty/rdma-dev.git] / net / sunrpc / auth_gss / auth_gss.c
index 911ef00..6ea29f4 100644 (file)
@@ -255,7 +255,7 @@ err:
 
 struct gss_upcall_msg {
        atomic_t count;
-       uid_t   uid;
+       kuid_t  uid;
        struct rpc_pipe_msg msg;
        struct list_head list;
        struct gss_auth *auth;
@@ -302,11 +302,11 @@ gss_release_msg(struct gss_upcall_msg *gss_msg)
 }
 
 static struct gss_upcall_msg *
-__gss_find_upcall(struct rpc_pipe *pipe, uid_t uid)
+__gss_find_upcall(struct rpc_pipe *pipe, kuid_t uid)
 {
        struct gss_upcall_msg *pos;
        list_for_each_entry(pos, &pipe->in_downcall, list) {
-               if (pos->uid != uid)
+               if (!uid_eq(pos->uid, uid))
                        continue;
                atomic_inc(&pos->count);
                dprintk("RPC:       %s found msg %p\n", __func__, pos);
@@ -394,8 +394,11 @@ gss_upcall_callback(struct rpc_task *task)
 
 static void gss_encode_v0_msg(struct gss_upcall_msg *gss_msg)
 {
-       gss_msg->msg.data = &gss_msg->uid;
-       gss_msg->msg.len = sizeof(gss_msg->uid);
+       uid_t uid = from_kuid(&init_user_ns, gss_msg->uid);
+       memcpy(gss_msg->databuf, &uid, sizeof(uid));
+       gss_msg->msg.data = gss_msg->databuf;
+       gss_msg->msg.len = sizeof(uid);
+       BUG_ON(sizeof(uid) > UPCALL_BUF_LEN);
 }
 
 static void gss_encode_v1_msg(struct gss_upcall_msg *gss_msg,
@@ -408,7 +411,7 @@ static void gss_encode_v1_msg(struct gss_upcall_msg *gss_msg,
 
        gss_msg->msg.len = sprintf(gss_msg->databuf, "mech=%s uid=%d ",
                                   mech->gm_name,
-                                  gss_msg->uid);
+                                  from_kuid(&init_user_ns, gss_msg->uid));
        p += gss_msg->msg.len;
        if (clnt->cl_principal) {
                len = sprintf(p, "target=%s ", clnt->cl_principal);
@@ -444,7 +447,7 @@ static void gss_encode_msg(struct gss_upcall_msg *gss_msg,
 
 static struct gss_upcall_msg *
 gss_alloc_msg(struct gss_auth *gss_auth, struct rpc_clnt *clnt,
-               uid_t uid, const char *service_name)
+               kuid_t uid, const char *service_name)
 {
        struct gss_upcall_msg *gss_msg;
        int vers;
@@ -474,7 +477,7 @@ gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cr
        struct gss_cred *gss_cred = container_of(cred,
                        struct gss_cred, gc_base);
        struct gss_upcall_msg *gss_new, *gss_msg;
-       uid_t uid = cred->cr_uid;
+       kuid_t uid = cred->cr_uid;
 
        gss_new = gss_alloc_msg(gss_auth, clnt, uid, gss_cred->gc_principal);
        if (IS_ERR(gss_new))
@@ -516,7 +519,7 @@ gss_refresh_upcall(struct rpc_task *task)
        int err = 0;
 
        dprintk("RPC: %5u %s for uid %u\n",
-               task->tk_pid, __func__, cred->cr_uid);
+               task->tk_pid, __func__, from_kuid(&init_user_ns, cred->cr_uid));
        gss_msg = gss_setup_upcall(task->tk_client, gss_auth, cred);
        if (PTR_ERR(gss_msg) == -EAGAIN) {
                /* XXX: warning on the first, under the assumption we
@@ -548,7 +551,8 @@ gss_refresh_upcall(struct rpc_task *task)
        gss_release_msg(gss_msg);
 out:
        dprintk("RPC: %5u %s for uid %u result %d\n",
-               task->tk_pid, __func__, cred->cr_uid, err);
+               task->tk_pid, __func__,
+               from_kuid(&init_user_ns, cred->cr_uid), err);
        return err;
 }
 
@@ -561,7 +565,8 @@ gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred)
        DEFINE_WAIT(wait);
        int err = 0;
 
-       dprintk("RPC:       %s for uid %u\n", __func__, cred->cr_uid);
+       dprintk("RPC:       %s for uid %u\n",
+               __func__, from_kuid(&init_user_ns, cred->cr_uid));
 retry:
        gss_msg = gss_setup_upcall(gss_auth->client, gss_auth, cred);
        if (PTR_ERR(gss_msg) == -EAGAIN) {
@@ -603,7 +608,7 @@ out_intr:
        gss_release_msg(gss_msg);
 out:
        dprintk("RPC:       %s for uid %u result %d\n",
-               __func__, cred->cr_uid, err);
+               __func__, from_kuid(&init_user_ns, cred->cr_uid), err);
        return err;
 }
 
@@ -617,7 +622,8 @@ gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen)
        struct gss_upcall_msg *gss_msg;
        struct rpc_pipe *pipe = RPC_I(filp->f_dentry->d_inode)->pipe;
        struct gss_cl_ctx *ctx;
-       uid_t uid;
+       uid_t id;
+       kuid_t uid;
        ssize_t err = -EFBIG;
 
        if (mlen > MSG_BUF_MAXSIZE)
@@ -632,12 +638,18 @@ gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen)
                goto err;
 
        end = (const void *)((char *)buf + mlen);
-       p = simple_get_bytes(buf, end, &uid, sizeof(uid));
+       p = simple_get_bytes(buf, end, &id, sizeof(id));
        if (IS_ERR(p)) {
                err = PTR_ERR(p);
                goto err;
        }
 
+       uid = make_kuid(&init_user_ns, id);
+       if (!uid_valid(uid)) {
+               err = -EINVAL;
+               goto err;
+       }
+
        err = -ENOMEM;
        ctx = gss_alloc_context();
        if (ctx == NULL)
@@ -1058,7 +1070,8 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
        int err = -ENOMEM;
 
        dprintk("RPC:       %s for uid %d, flavor %d\n",
-               __func__, acred->uid, auth->au_flavor);
+               __func__, from_kuid(&init_user_ns, acred->uid),
+               auth->au_flavor);
 
        if (!(cred = kzalloc(sizeof(*cred), GFP_NOFS)))
                goto out_err;
@@ -1114,7 +1127,7 @@ out:
        }
        if (gss_cred->gc_principal != NULL)
                return 0;
-       return rc->cr_uid == acred->uid;
+       return uid_eq(rc->cr_uid, acred->uid);
 }
 
 /*