From f2355a563354ea56330c427a1cb4f5f38ee8b7b9 Mon Sep 17 00:00:00 2001 From: Keith Packard Date: Thu, 9 May 2019 10:36:36 -0700 Subject: [PATCH] Let tests pass more than one FD Signed-off-by: Keith Packard --- fdpass.c | 106 ++++++++++++++++++++++++++++++++------------------- fdpass.h | 6 ++- fdpassing.c | 24 ++++++++---- lostfd.c | 24 ++++++++---- multiwrite.c | 23 +++++++---- xreq.c | 14 ++++--- zerowrite.c | 15 ++++---- 7 files changed, 134 insertions(+), 78 deletions(-) diff --git a/fdpass.c b/fdpass.c index 6e2e0c7..74ac148 100644 --- a/fdpass.c +++ b/fdpass.c @@ -17,19 +17,26 @@ #include "fdpass.h" +#define MAX_FDS 128 + +struct fd_pass { + struct cmsghdr cmsghdr; + int fd[MAX_FDS]; +}; + ssize_t -sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd) +sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd, int *nfdp) { ssize_t size; + printf ("sock_fd_read\n"); if (fd) { struct msghdr msg; struct iovec iov; - union { - struct cmsghdr cmsghdr; - char control[CMSG_SPACE(sizeof (int))]; - } cmsgu; - struct cmsghdr *cmsg; + struct fd_pass pass; + int nfd_passed, nfd; + int i; + int *fd_passed; iov.iov_base = buf; iov.iov_len = bufsize; @@ -38,35 +45,47 @@ sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd) msg.msg_namelen = 0; msg.msg_iov = &iov; msg.msg_iovlen = 1; - msg.msg_control = cmsgu.control; - msg.msg_controllen = sizeof(cmsgu.control); + msg.msg_control = &pass; + msg.msg_controllen = sizeof pass; size = recvmsg (sock, &msg, 0); if (size < 0) { perror ("recvmsg"); exit(1); } - if ((msg.msg_flags & MSG_TRUNC) || - (msg.msg_flags & MSG_CTRUNC)) { - fprintf (stderr, "control message truncated"); - exit(1); - } - cmsg = CMSG_FIRSTHDR(&msg); - if (cmsg && cmsg->cmsg_len == CMSG_LEN(sizeof(int))) { - if (cmsg->cmsg_level != SOL_SOCKET) { - fprintf (stderr, "invalid cmsg_level %d\n", - cmsg->cmsg_level); + if (size > 0 && pass.cmsghdr.cmsg_len > sizeof (struct cmsghdr)) { + if ((msg.msg_flags & MSG_TRUNC) || + (msg.msg_flags & MSG_CTRUNC)) { + fprintf (stderr, "control message truncated"); exit(1); } - if (cmsg->cmsg_type != SCM_RIGHTS) { + if (pass.cmsghdr.cmsg_level != SOL_SOCKET) { + fprintf (stderr, "invalid cmsg_level %d\n", + pass.cmsghdr.cmsg_level); + exit(1); + } + if (pass.cmsghdr.cmsg_type != SCM_RIGHTS) { fprintf (stderr, "invalid cmsg_type %d\n", - cmsg->cmsg_type); + pass.cmsghdr.cmsg_type); exit(1); } - *fd = *((int *) CMSG_DATA(cmsg)); - printf ("received fd %d\n", *fd); + nfd_passed = (pass.cmsghdr.cmsg_len - sizeof (struct cmsghdr)) / sizeof (int); + fd_passed = (int *) CMSG_DATA(&pass.cmsghdr); + + nfd = *nfdp; + if (nfd > nfd_passed) + nfd = nfd_passed; + + memcpy(fd, fd_passed, nfd * sizeof (int)); + for (i = 0; i < nfd; i++) + printf ("received fd %d\n", fd[i]); + for (i = nfd; i < nfd_passed; i++) { + printf ("dropping fd %d\n", fd_passed[i]); + close(fd_passed[i]); + } + *nfdp = nfd; } else - *fd = -1; + *nfdp = 0; } else { size = read (sock, buf, bufsize); if (size < 0) { @@ -78,36 +97,43 @@ sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd) } ssize_t -sock_fd_write(int sock, void *buf, ssize_t buflen, int fd) +sock_fd_write(int sock, void *buf, ssize_t buflen, int *fd, int nfd) { ssize_t size; struct msghdr msg; struct iovec iov; - union { - struct cmsghdr cmsghdr; - char control[CMSG_SPACE(sizeof (int))]; - } cmsgu; - struct cmsghdr *cmsg; + struct fd_pass pass; + int i; iov.iov_base = buf; iov.iov_len = buflen; msg.msg_name = NULL; msg.msg_namelen = 0; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; + if (buflen) { + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + } else { + msg.msg_iov = NULL; + msg.msg_iovlen = 0; + } - if (fd != -1) { - msg.msg_control = cmsgu.control; - msg.msg_controllen = sizeof(cmsgu.control); + if (nfd) { + if (nfd > MAX_FDS) { + printf ("passing too many fds\n"); + exit(1); + } - cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_len = CMSG_LEN(sizeof (int)); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; + msg.msg_control = &pass; + msg.msg_controllen = sizeof (struct cmsghdr) + nfd * sizeof (int); - printf ("passing fd %d\n", fd); - *((int *) CMSG_DATA(cmsg)) = fd; + pass.cmsghdr.cmsg_len = msg.msg_controllen; + pass.cmsghdr.cmsg_level = SOL_SOCKET; + pass.cmsghdr.cmsg_type = SCM_RIGHTS; + + memcpy(&pass.fd, fd, nfd * sizeof (int)); + for (i = 0; i < nfd; i++) + printf ("passing fd %d\n", pass.fd[i]); } else { msg.msg_control = NULL; msg.msg_controllen = 0; diff --git a/fdpass.h b/fdpass.h index 236849a..4f6df4a 100644 --- a/fdpass.h +++ b/fdpass.h @@ -20,16 +20,18 @@ #include #include +#include #include #include #include #include #include +#include ssize_t -sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd); +sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd, int *nfd); ssize_t -sock_fd_write(int sock, void *buf, ssize_t buflen, int fd); +sock_fd_write(int sock, void *buf, ssize_t buflen, int *fd, int nfd); #endif diff --git a/fdpassing.c b/fdpassing.c index 871e3bf..4ee63c2 100644 --- a/fdpassing.c +++ b/fdpassing.c @@ -17,22 +17,29 @@ #include "fdpass.h" +#define NFD 128 + void child(int sock) { - int fd; + int fd[NFD]; + int nfd; char buf[16]; ssize_t size; + int i; sleep(1); for (;;) { - size = sock_fd_read(sock, buf, sizeof(buf), &fd); + nfd = NFD; + size = sock_fd_read(sock, buf, sizeof(buf), fd, &nfd); if (size <= 0) break; printf ("read %d\n", size); - if (fd != -1) { - write(fd, "hello, world\n", 13); - close(fd); + if (nfd) { + for (i = 0; i < nfd; i++) { + write(fd[i], "hello, world\n", 13); + close(fd[i]); + } } } } @@ -42,10 +49,11 @@ parent(int sock) { ssize_t size; int i; - int fd; + int fd[NFD]; - fd = 1; - size = sock_fd_write(sock, "1", 1, 1); + for (i = 0; i < NFD; i++) + fd[i] = 1 + (i & 1); + size = sock_fd_write(sock, "1", 1, fd, NFD); printf ("wrote %d\n", size); } diff --git a/lostfd.c b/lostfd.c index d3cf996..bdd8c29 100644 --- a/lostfd.c +++ b/lostfd.c @@ -17,25 +17,32 @@ #include "fdpass.h" +#define NFD 16 + void child(int sock) { - int fd; + int fd[NFD]; + int nfd; char buf[16]; ssize_t size; + int i; sleep(1); - size = sock_fd_read(sock, buf, sizeof(buf), NULL); + size = sock_fd_read(sock, buf, sizeof(buf), NULL, NULL); if (size <= 0) return; printf ("read %d\n", size); - size = sock_fd_read(sock, buf, sizeof(buf), &fd); + nfd = NFD; + size = sock_fd_read(sock, buf, sizeof(buf), fd, &nfd); if (size <= 0) return; printf ("read %d\n", size); - if (fd != -1) { - write(fd, "hello, world\n", 13); - close(fd); + if (nfd) { + for (i = 0; i < nfd; i++) { + write(fd[i], "hello, world\n", 13); + close(fd[i]); + } } } @@ -47,9 +54,10 @@ parent(int sock) int fd; fd = 1; - size = sock_fd_write(sock, "1", 1, 1); + size = sock_fd_write(sock, "1", 1, &fd, 1); printf ("wrote %d without fd\n", size); - size = sock_fd_write(sock, "1", 1, 2); + fd = 2; + size = sock_fd_write(sock, "1", 1, &fd, 1); printf ("wrote %d with fd\n", size); } diff --git a/multiwrite.c b/multiwrite.c index d2d73c1..d4159f9 100644 --- a/multiwrite.c +++ b/multiwrite.c @@ -17,22 +17,29 @@ #include "fdpass.h" +#define NFD 32 + void child(int sock) { - int fd; + int fd[NFD]; char buf[16]; ssize_t size; + int nfd; + int i; sleep(1); for (;;) { - size = sock_fd_read(sock, buf, sizeof(buf), &fd); + nfd = NFD; + size = sock_fd_read(sock, buf, sizeof(buf), fd, &nfd); if (size <= 0) break; printf ("read %d\n", size); - if (fd != -1) { - write(fd, "hello, world\n", 13); - close(fd); + if (nfd) { + for (i = 0; i < nfd; i++) { + write(fd[i], "hello, world\n", 13); + close(fd[i]); + } } } } @@ -45,11 +52,11 @@ parent(int sock) int fd; fd = 1; - size = sock_fd_write(sock, "1", 1, -1); + size = sock_fd_write(sock, "1", 1, NULL, 0); printf ("wrote %d without fd\n", size); - size = sock_fd_write(sock, "1", 1, 1); + size = sock_fd_write(sock, "1", 1, &fd, 1); printf ("wrote %d with fd\n", size); - size = sock_fd_write(sock, "1", 1, -1); + size = sock_fd_write(sock, "1", 1, NULL, 0); printf ("wrote %d without fd\n", size); } diff --git a/xreq.c b/xreq.c index d266db6..58ff0be 100644 --- a/xreq.c +++ b/xreq.c @@ -27,11 +27,13 @@ child(int sock) int i, reqlen; ssize_t size, fdsize; int fd = -1, *fdp; + int nfd; int j; sleep (1); for (j = 0;; j++) { - size = sock_fd_read(sock, xreq, sizeof (xreq), NULL); + nfd = 0; + size = sock_fd_read(sock, xreq, sizeof (xreq), NULL, &nfd); printf ("got %d\n", size); if (size == 0) break; @@ -48,7 +50,8 @@ child(int sock) fprintf (stderr, "Got fd req, but not at end of input %d < %d\n", i, size); } - fdsize = sock_fd_read(sock, xnop, sizeof (xnop), &fd); + nfd = 1; + fdsize = sock_fd_read(sock, xnop, sizeof (xnop), &fd, &nfd); if (fd == -1) { fprintf (stderr, "no fd received\n"); } else { @@ -89,19 +92,20 @@ parent(int sock) for (i = 0; i < 8; i++) { xreq[0] = 0; xreq[1] = sizeof (xreq); - sock_fd_write(sock, xreq, sizeof (xreq), -1); + sock_fd_write(sock, xreq, sizeof (xreq), NULL, 0); } /* Write our 'pass an fd' request with a 'useless' FD to block the receiver */ xreq[0] = 1; xreq[1] = sizeof(xreq); - sock_fd_write(sock, xreq, sizeof (xreq), 1); + fd = 1; + sock_fd_write(sock, xreq, sizeof (xreq), &fd, 1); /* Pass an fd */ xnop[0] = 2; xnop[1] = sizeof (xnop); fd = tmp_file(j); - sock_fd_write(sock, xnop, sizeof (xnop), fd); + sock_fd_write(sock, xnop, sizeof (xnop), &fd, 1); close(fd); } } diff --git a/zerowrite.c b/zerowrite.c index 334af44..b7775c6 100644 --- a/zerowrite.c +++ b/zerowrite.c @@ -20,17 +20,18 @@ void child(int sock) { - int fd; + int fd, nfd; char buf[16]; ssize_t size; sleep(1); for (;;) { - size = sock_fd_read(sock, buf, sizeof(buf), &fd); + nfd = 1; + size = sock_fd_read(sock, buf, sizeof(buf), &fd, &nfd); if (size <= 0) break; - printf ("read %d\n", size); - if (fd != -1) { + printf ("read %d nfd %d\n", size, nfd); + if (nfd == 1) { write(fd, "hello, world\n", 13); close(fd); } @@ -45,11 +46,11 @@ parent(int sock) int fd; fd = 1; - size = sock_fd_write(sock, "1", 1, -1); + size = sock_fd_write(sock, "1", 1, NULL, 0); printf ("wrote %d without fd\n", size); - size = sock_fd_write(sock, NULL, 0, 1); + size = sock_fd_write(sock, NULL, 0, &fd, 1); printf ("wrote %d with fd\n", size); - size = sock_fd_write(sock, "1", 1, -1); + size = sock_fd_write(sock, "1", 1, NULL, 0); printf ("wrote %d without fd\n", size); }