Let tests pass more than one FD

Signed-off-by: Keith Packard <keithp@keithp.com>
This commit is contained in:
Keith Packard 2019-05-09 10:36:36 -07:00
parent d4d31cd5f1
commit f2355a5633
7 changed files with 134 additions and 78 deletions

106
fdpass.c
View File

@ -17,19 +17,26 @@
#include "fdpass.h" #include "fdpass.h"
#define MAX_FDS 128
struct fd_pass {
struct cmsghdr cmsghdr;
int fd[MAX_FDS];
};
ssize_t ssize_t
sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd) sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd, int *nfdp)
{ {
ssize_t size; ssize_t size;
printf ("sock_fd_read\n");
if (fd) { if (fd) {
struct msghdr msg; struct msghdr msg;
struct iovec iov; struct iovec iov;
union { struct fd_pass pass;
struct cmsghdr cmsghdr; int nfd_passed, nfd;
char control[CMSG_SPACE(sizeof (int))]; int i;
} cmsgu; int *fd_passed;
struct cmsghdr *cmsg;
iov.iov_base = buf; iov.iov_base = buf;
iov.iov_len = bufsize; iov.iov_len = bufsize;
@ -38,35 +45,47 @@ sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd)
msg.msg_namelen = 0; msg.msg_namelen = 0;
msg.msg_iov = &iov; msg.msg_iov = &iov;
msg.msg_iovlen = 1; msg.msg_iovlen = 1;
msg.msg_control = cmsgu.control; msg.msg_control = &pass;
msg.msg_controllen = sizeof(cmsgu.control); msg.msg_controllen = sizeof pass;
size = recvmsg (sock, &msg, 0); size = recvmsg (sock, &msg, 0);
if (size < 0) { if (size < 0) {
perror ("recvmsg"); perror ("recvmsg");
exit(1); exit(1);
} }
if ((msg.msg_flags & MSG_TRUNC) || if (size > 0 && pass.cmsghdr.cmsg_len > sizeof (struct cmsghdr)) {
(msg.msg_flags & MSG_CTRUNC)) { if ((msg.msg_flags & MSG_TRUNC) ||
fprintf (stderr, "control message truncated"); (msg.msg_flags & MSG_CTRUNC)) {
exit(1); fprintf (stderr, "control message truncated");
}
cmsg = CMSG_FIRSTHDR(&msg);
if (cmsg && cmsg->cmsg_len == CMSG_LEN(sizeof(int))) {
if (cmsg->cmsg_level != SOL_SOCKET) {
fprintf (stderr, "invalid cmsg_level %d\n",
cmsg->cmsg_level);
exit(1); exit(1);
} }
if (cmsg->cmsg_type != SCM_RIGHTS) { if (pass.cmsghdr.cmsg_level != SOL_SOCKET) {
fprintf (stderr, "invalid cmsg_level %d\n",
pass.cmsghdr.cmsg_level);
exit(1);
}
if (pass.cmsghdr.cmsg_type != SCM_RIGHTS) {
fprintf (stderr, "invalid cmsg_type %d\n", fprintf (stderr, "invalid cmsg_type %d\n",
cmsg->cmsg_type); pass.cmsghdr.cmsg_type);
exit(1); exit(1);
} }
*fd = *((int *) CMSG_DATA(cmsg)); nfd_passed = (pass.cmsghdr.cmsg_len - sizeof (struct cmsghdr)) / sizeof (int);
printf ("received fd %d\n", *fd); fd_passed = (int *) CMSG_DATA(&pass.cmsghdr);
nfd = *nfdp;
if (nfd > nfd_passed)
nfd = nfd_passed;
memcpy(fd, fd_passed, nfd * sizeof (int));
for (i = 0; i < nfd; i++)
printf ("received fd %d\n", fd[i]);
for (i = nfd; i < nfd_passed; i++) {
printf ("dropping fd %d\n", fd_passed[i]);
close(fd_passed[i]);
}
*nfdp = nfd;
} else } else
*fd = -1; *nfdp = 0;
} else { } else {
size = read (sock, buf, bufsize); size = read (sock, buf, bufsize);
if (size < 0) { if (size < 0) {
@ -78,36 +97,43 @@ sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd)
} }
ssize_t ssize_t
sock_fd_write(int sock, void *buf, ssize_t buflen, int fd) sock_fd_write(int sock, void *buf, ssize_t buflen, int *fd, int nfd)
{ {
ssize_t size; ssize_t size;
struct msghdr msg; struct msghdr msg;
struct iovec iov; struct iovec iov;
union { struct fd_pass pass;
struct cmsghdr cmsghdr; int i;
char control[CMSG_SPACE(sizeof (int))];
} cmsgu;
struct cmsghdr *cmsg;
iov.iov_base = buf; iov.iov_base = buf;
iov.iov_len = buflen; iov.iov_len = buflen;
msg.msg_name = NULL; msg.msg_name = NULL;
msg.msg_namelen = 0; msg.msg_namelen = 0;
msg.msg_iov = &iov; if (buflen) {
msg.msg_iovlen = 1; msg.msg_iov = &iov;
msg.msg_iovlen = 1;
} else {
msg.msg_iov = NULL;
msg.msg_iovlen = 0;
}
if (fd != -1) { if (nfd) {
msg.msg_control = cmsgu.control; if (nfd > MAX_FDS) {
msg.msg_controllen = sizeof(cmsgu.control); printf ("passing too many fds\n");
exit(1);
}
cmsg = CMSG_FIRSTHDR(&msg); msg.msg_control = &pass;
cmsg->cmsg_len = CMSG_LEN(sizeof (int)); msg.msg_controllen = sizeof (struct cmsghdr) + nfd * sizeof (int);
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS;
printf ("passing fd %d\n", fd); pass.cmsghdr.cmsg_len = msg.msg_controllen;
*((int *) CMSG_DATA(cmsg)) = fd; pass.cmsghdr.cmsg_level = SOL_SOCKET;
pass.cmsghdr.cmsg_type = SCM_RIGHTS;
memcpy(&pass.fd, fd, nfd * sizeof (int));
for (i = 0; i < nfd; i++)
printf ("passing fd %d\n", pass.fd[i]);
} else { } else {
msg.msg_control = NULL; msg.msg_control = NULL;
msg.msg_controllen = 0; msg.msg_controllen = 0;

View File

@ -20,16 +20,18 @@
#include <sys/types.h> #include <sys/types.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/stat.h>
#include <signal.h> #include <signal.h>
#include <unistd.h> #include <unistd.h>
#include <stdlib.h> #include <stdlib.h>
#include <stdio.h> #include <stdio.h>
#include <string.h> #include <string.h>
#include <fcntl.h>
ssize_t ssize_t
sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd); sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd, int *nfd);
ssize_t ssize_t
sock_fd_write(int sock, void *buf, ssize_t buflen, int fd); sock_fd_write(int sock, void *buf, ssize_t buflen, int *fd, int nfd);
#endif #endif

View File

@ -17,22 +17,29 @@
#include "fdpass.h" #include "fdpass.h"
#define NFD 128
void void
child(int sock) child(int sock)
{ {
int fd; int fd[NFD];
int nfd;
char buf[16]; char buf[16];
ssize_t size; ssize_t size;
int i;
sleep(1); sleep(1);
for (;;) { for (;;) {
size = sock_fd_read(sock, buf, sizeof(buf), &fd); nfd = NFD;
size = sock_fd_read(sock, buf, sizeof(buf), fd, &nfd);
if (size <= 0) if (size <= 0)
break; break;
printf ("read %d\n", size); printf ("read %d\n", size);
if (fd != -1) { if (nfd) {
write(fd, "hello, world\n", 13); for (i = 0; i < nfd; i++) {
close(fd); write(fd[i], "hello, world\n", 13);
close(fd[i]);
}
} }
} }
} }
@ -42,10 +49,11 @@ parent(int sock)
{ {
ssize_t size; ssize_t size;
int i; int i;
int fd; int fd[NFD];
fd = 1; for (i = 0; i < NFD; i++)
size = sock_fd_write(sock, "1", 1, 1); fd[i] = 1 + (i & 1);
size = sock_fd_write(sock, "1", 1, fd, NFD);
printf ("wrote %d\n", size); printf ("wrote %d\n", size);
} }

View File

@ -17,25 +17,32 @@
#include "fdpass.h" #include "fdpass.h"
#define NFD 16
void void
child(int sock) child(int sock)
{ {
int fd; int fd[NFD];
int nfd;
char buf[16]; char buf[16];
ssize_t size; ssize_t size;
int i;
sleep(1); sleep(1);
size = sock_fd_read(sock, buf, sizeof(buf), NULL); size = sock_fd_read(sock, buf, sizeof(buf), NULL, NULL);
if (size <= 0) if (size <= 0)
return; return;
printf ("read %d\n", size); printf ("read %d\n", size);
size = sock_fd_read(sock, buf, sizeof(buf), &fd); nfd = NFD;
size = sock_fd_read(sock, buf, sizeof(buf), fd, &nfd);
if (size <= 0) if (size <= 0)
return; return;
printf ("read %d\n", size); printf ("read %d\n", size);
if (fd != -1) { if (nfd) {
write(fd, "hello, world\n", 13); for (i = 0; i < nfd; i++) {
close(fd); write(fd[i], "hello, world\n", 13);
close(fd[i]);
}
} }
} }
@ -47,9 +54,10 @@ parent(int sock)
int fd; int fd;
fd = 1; fd = 1;
size = sock_fd_write(sock, "1", 1, 1); size = sock_fd_write(sock, "1", 1, &fd, 1);
printf ("wrote %d without fd\n", size); printf ("wrote %d without fd\n", size);
size = sock_fd_write(sock, "1", 1, 2); fd = 2;
size = sock_fd_write(sock, "1", 1, &fd, 1);
printf ("wrote %d with fd\n", size); printf ("wrote %d with fd\n", size);
} }

View File

@ -17,22 +17,29 @@
#include "fdpass.h" #include "fdpass.h"
#define NFD 32
void void
child(int sock) child(int sock)
{ {
int fd; int fd[NFD];
char buf[16]; char buf[16];
ssize_t size; ssize_t size;
int nfd;
int i;
sleep(1); sleep(1);
for (;;) { for (;;) {
size = sock_fd_read(sock, buf, sizeof(buf), &fd); nfd = NFD;
size = sock_fd_read(sock, buf, sizeof(buf), fd, &nfd);
if (size <= 0) if (size <= 0)
break; break;
printf ("read %d\n", size); printf ("read %d\n", size);
if (fd != -1) { if (nfd) {
write(fd, "hello, world\n", 13); for (i = 0; i < nfd; i++) {
close(fd); write(fd[i], "hello, world\n", 13);
close(fd[i]);
}
} }
} }
} }
@ -45,11 +52,11 @@ parent(int sock)
int fd; int fd;
fd = 1; fd = 1;
size = sock_fd_write(sock, "1", 1, -1); size = sock_fd_write(sock, "1", 1, NULL, 0);
printf ("wrote %d without fd\n", size); printf ("wrote %d without fd\n", size);
size = sock_fd_write(sock, "1", 1, 1); size = sock_fd_write(sock, "1", 1, &fd, 1);
printf ("wrote %d with fd\n", size); printf ("wrote %d with fd\n", size);
size = sock_fd_write(sock, "1", 1, -1); size = sock_fd_write(sock, "1", 1, NULL, 0);
printf ("wrote %d without fd\n", size); printf ("wrote %d without fd\n", size);
} }

14
xreq.c
View File

@ -27,11 +27,13 @@ child(int sock)
int i, reqlen; int i, reqlen;
ssize_t size, fdsize; ssize_t size, fdsize;
int fd = -1, *fdp; int fd = -1, *fdp;
int nfd;
int j; int j;
sleep (1); sleep (1);
for (j = 0;; j++) { for (j = 0;; j++) {
size = sock_fd_read(sock, xreq, sizeof (xreq), NULL); nfd = 0;
size = sock_fd_read(sock, xreq, sizeof (xreq), NULL, &nfd);
printf ("got %d\n", size); printf ("got %d\n", size);
if (size == 0) if (size == 0)
break; break;
@ -48,7 +50,8 @@ child(int sock)
fprintf (stderr, "Got fd req, but not at end of input %d < %d\n", fprintf (stderr, "Got fd req, but not at end of input %d < %d\n",
i, size); i, size);
} }
fdsize = sock_fd_read(sock, xnop, sizeof (xnop), &fd); nfd = 1;
fdsize = sock_fd_read(sock, xnop, sizeof (xnop), &fd, &nfd);
if (fd == -1) { if (fd == -1) {
fprintf (stderr, "no fd received\n"); fprintf (stderr, "no fd received\n");
} else { } else {
@ -89,19 +92,20 @@ parent(int sock)
for (i = 0; i < 8; i++) { for (i = 0; i < 8; i++) {
xreq[0] = 0; xreq[0] = 0;
xreq[1] = sizeof (xreq); xreq[1] = sizeof (xreq);
sock_fd_write(sock, xreq, sizeof (xreq), -1); sock_fd_write(sock, xreq, sizeof (xreq), NULL, 0);
} }
/* Write our 'pass an fd' request with a 'useless' FD to block the receiver */ /* Write our 'pass an fd' request with a 'useless' FD to block the receiver */
xreq[0] = 1; xreq[0] = 1;
xreq[1] = sizeof(xreq); xreq[1] = sizeof(xreq);
sock_fd_write(sock, xreq, sizeof (xreq), 1); fd = 1;
sock_fd_write(sock, xreq, sizeof (xreq), &fd, 1);
/* Pass an fd */ /* Pass an fd */
xnop[0] = 2; xnop[0] = 2;
xnop[1] = sizeof (xnop); xnop[1] = sizeof (xnop);
fd = tmp_file(j); fd = tmp_file(j);
sock_fd_write(sock, xnop, sizeof (xnop), fd); sock_fd_write(sock, xnop, sizeof (xnop), &fd, 1);
close(fd); close(fd);
} }
} }

View File

@ -20,17 +20,18 @@
void void
child(int sock) child(int sock)
{ {
int fd; int fd, nfd;
char buf[16]; char buf[16];
ssize_t size; ssize_t size;
sleep(1); sleep(1);
for (;;) { for (;;) {
size = sock_fd_read(sock, buf, sizeof(buf), &fd); nfd = 1;
size = sock_fd_read(sock, buf, sizeof(buf), &fd, &nfd);
if (size <= 0) if (size <= 0)
break; break;
printf ("read %d\n", size); printf ("read %d nfd %d\n", size, nfd);
if (fd != -1) { if (nfd == 1) {
write(fd, "hello, world\n", 13); write(fd, "hello, world\n", 13);
close(fd); close(fd);
} }
@ -45,11 +46,11 @@ parent(int sock)
int fd; int fd;
fd = 1; fd = 1;
size = sock_fd_write(sock, "1", 1, -1); size = sock_fd_write(sock, "1", 1, NULL, 0);
printf ("wrote %d without fd\n", size); printf ("wrote %d without fd\n", size);
size = sock_fd_write(sock, NULL, 0, 1); size = sock_fd_write(sock, NULL, 0, &fd, 1);
printf ("wrote %d with fd\n", size); printf ("wrote %d with fd\n", size);
size = sock_fd_write(sock, "1", 1, -1); size = sock_fd_write(sock, "1", 1, NULL, 0);
printf ("wrote %d without fd\n", size); printf ("wrote %d without fd\n", size);
} }