Let tests pass more than one FD

Signed-off-by: Keith Packard <keithp@keithp.com>
This commit is contained in:
Keith Packard 2019-05-09 10:36:36 -07:00
parent d4d31cd5f1
commit f2355a5633
7 changed files with 134 additions and 78 deletions

View File

@ -17,19 +17,26 @@
#include "fdpass.h"
#define MAX_FDS 128
struct fd_pass {
struct cmsghdr cmsghdr;
int fd[MAX_FDS];
};
ssize_t
sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd)
sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd, int *nfdp)
{
ssize_t size;
printf ("sock_fd_read\n");
if (fd) {
struct msghdr msg;
struct iovec iov;
union {
struct cmsghdr cmsghdr;
char control[CMSG_SPACE(sizeof (int))];
} cmsgu;
struct cmsghdr *cmsg;
struct fd_pass pass;
int nfd_passed, nfd;
int i;
int *fd_passed;
iov.iov_base = buf;
iov.iov_len = bufsize;
@ -38,35 +45,47 @@ sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd)
msg.msg_namelen = 0;
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
msg.msg_control = cmsgu.control;
msg.msg_controllen = sizeof(cmsgu.control);
msg.msg_control = &pass;
msg.msg_controllen = sizeof pass;
size = recvmsg (sock, &msg, 0);
if (size < 0) {
perror ("recvmsg");
exit(1);
}
if (size > 0 && pass.cmsghdr.cmsg_len > sizeof (struct cmsghdr)) {
if ((msg.msg_flags & MSG_TRUNC) ||
(msg.msg_flags & MSG_CTRUNC)) {
fprintf (stderr, "control message truncated");
exit(1);
}
cmsg = CMSG_FIRSTHDR(&msg);
if (cmsg && cmsg->cmsg_len == CMSG_LEN(sizeof(int))) {
if (cmsg->cmsg_level != SOL_SOCKET) {
if (pass.cmsghdr.cmsg_level != SOL_SOCKET) {
fprintf (stderr, "invalid cmsg_level %d\n",
cmsg->cmsg_level);
pass.cmsghdr.cmsg_level);
exit(1);
}
if (cmsg->cmsg_type != SCM_RIGHTS) {
if (pass.cmsghdr.cmsg_type != SCM_RIGHTS) {
fprintf (stderr, "invalid cmsg_type %d\n",
cmsg->cmsg_type);
pass.cmsghdr.cmsg_type);
exit(1);
}
*fd = *((int *) CMSG_DATA(cmsg));
printf ("received fd %d\n", *fd);
nfd_passed = (pass.cmsghdr.cmsg_len - sizeof (struct cmsghdr)) / sizeof (int);
fd_passed = (int *) CMSG_DATA(&pass.cmsghdr);
nfd = *nfdp;
if (nfd > nfd_passed)
nfd = nfd_passed;
memcpy(fd, fd_passed, nfd * sizeof (int));
for (i = 0; i < nfd; i++)
printf ("received fd %d\n", fd[i]);
for (i = nfd; i < nfd_passed; i++) {
printf ("dropping fd %d\n", fd_passed[i]);
close(fd_passed[i]);
}
*nfdp = nfd;
} else
*fd = -1;
*nfdp = 0;
} else {
size = read (sock, buf, bufsize);
if (size < 0) {
@ -78,36 +97,43 @@ sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd)
}
ssize_t
sock_fd_write(int sock, void *buf, ssize_t buflen, int fd)
sock_fd_write(int sock, void *buf, ssize_t buflen, int *fd, int nfd)
{
ssize_t size;
struct msghdr msg;
struct iovec iov;
union {
struct cmsghdr cmsghdr;
char control[CMSG_SPACE(sizeof (int))];
} cmsgu;
struct cmsghdr *cmsg;
struct fd_pass pass;
int i;
iov.iov_base = buf;
iov.iov_len = buflen;
msg.msg_name = NULL;
msg.msg_namelen = 0;
if (buflen) {
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
} else {
msg.msg_iov = NULL;
msg.msg_iovlen = 0;
}
if (fd != -1) {
msg.msg_control = cmsgu.control;
msg.msg_controllen = sizeof(cmsgu.control);
if (nfd) {
if (nfd > MAX_FDS) {
printf ("passing too many fds\n");
exit(1);
}
cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_len = CMSG_LEN(sizeof (int));
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS;
msg.msg_control = &pass;
msg.msg_controllen = sizeof (struct cmsghdr) + nfd * sizeof (int);
printf ("passing fd %d\n", fd);
*((int *) CMSG_DATA(cmsg)) = fd;
pass.cmsghdr.cmsg_len = msg.msg_controllen;
pass.cmsghdr.cmsg_level = SOL_SOCKET;
pass.cmsghdr.cmsg_type = SCM_RIGHTS;
memcpy(&pass.fd, fd, nfd * sizeof (int));
for (i = 0; i < nfd; i++)
printf ("passing fd %d\n", pass.fd[i]);
} else {
msg.msg_control = NULL;
msg.msg_controllen = 0;

View File

@ -20,16 +20,18 @@
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <signal.h>
#include <unistd.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <fcntl.h>
ssize_t
sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd);
sock_fd_read(int sock, void *buf, ssize_t bufsize, int *fd, int *nfd);
ssize_t
sock_fd_write(int sock, void *buf, ssize_t buflen, int fd);
sock_fd_write(int sock, void *buf, ssize_t buflen, int *fd, int nfd);
#endif

View File

@ -17,22 +17,29 @@
#include "fdpass.h"
#define NFD 128
void
child(int sock)
{
int fd;
int fd[NFD];
int nfd;
char buf[16];
ssize_t size;
int i;
sleep(1);
for (;;) {
size = sock_fd_read(sock, buf, sizeof(buf), &fd);
nfd = NFD;
size = sock_fd_read(sock, buf, sizeof(buf), fd, &nfd);
if (size <= 0)
break;
printf ("read %d\n", size);
if (fd != -1) {
write(fd, "hello, world\n", 13);
close(fd);
if (nfd) {
for (i = 0; i < nfd; i++) {
write(fd[i], "hello, world\n", 13);
close(fd[i]);
}
}
}
}
@ -42,10 +49,11 @@ parent(int sock)
{
ssize_t size;
int i;
int fd;
int fd[NFD];
fd = 1;
size = sock_fd_write(sock, "1", 1, 1);
for (i = 0; i < NFD; i++)
fd[i] = 1 + (i & 1);
size = sock_fd_write(sock, "1", 1, fd, NFD);
printf ("wrote %d\n", size);
}

View File

@ -17,25 +17,32 @@
#include "fdpass.h"
#define NFD 16
void
child(int sock)
{
int fd;
int fd[NFD];
int nfd;
char buf[16];
ssize_t size;
int i;
sleep(1);
size = sock_fd_read(sock, buf, sizeof(buf), NULL);
size = sock_fd_read(sock, buf, sizeof(buf), NULL, NULL);
if (size <= 0)
return;
printf ("read %d\n", size);
size = sock_fd_read(sock, buf, sizeof(buf), &fd);
nfd = NFD;
size = sock_fd_read(sock, buf, sizeof(buf), fd, &nfd);
if (size <= 0)
return;
printf ("read %d\n", size);
if (fd != -1) {
write(fd, "hello, world\n", 13);
close(fd);
if (nfd) {
for (i = 0; i < nfd; i++) {
write(fd[i], "hello, world\n", 13);
close(fd[i]);
}
}
}
@ -47,9 +54,10 @@ parent(int sock)
int fd;
fd = 1;
size = sock_fd_write(sock, "1", 1, 1);
size = sock_fd_write(sock, "1", 1, &fd, 1);
printf ("wrote %d without fd\n", size);
size = sock_fd_write(sock, "1", 1, 2);
fd = 2;
size = sock_fd_write(sock, "1", 1, &fd, 1);
printf ("wrote %d with fd\n", size);
}

View File

@ -17,22 +17,29 @@
#include "fdpass.h"
#define NFD 32
void
child(int sock)
{
int fd;
int fd[NFD];
char buf[16];
ssize_t size;
int nfd;
int i;
sleep(1);
for (;;) {
size = sock_fd_read(sock, buf, sizeof(buf), &fd);
nfd = NFD;
size = sock_fd_read(sock, buf, sizeof(buf), fd, &nfd);
if (size <= 0)
break;
printf ("read %d\n", size);
if (fd != -1) {
write(fd, "hello, world\n", 13);
close(fd);
if (nfd) {
for (i = 0; i < nfd; i++) {
write(fd[i], "hello, world\n", 13);
close(fd[i]);
}
}
}
}
@ -45,11 +52,11 @@ parent(int sock)
int fd;
fd = 1;
size = sock_fd_write(sock, "1", 1, -1);
size = sock_fd_write(sock, "1", 1, NULL, 0);
printf ("wrote %d without fd\n", size);
size = sock_fd_write(sock, "1", 1, 1);
size = sock_fd_write(sock, "1", 1, &fd, 1);
printf ("wrote %d with fd\n", size);
size = sock_fd_write(sock, "1", 1, -1);
size = sock_fd_write(sock, "1", 1, NULL, 0);
printf ("wrote %d without fd\n", size);
}

14
xreq.c
View File

@ -27,11 +27,13 @@ child(int sock)
int i, reqlen;
ssize_t size, fdsize;
int fd = -1, *fdp;
int nfd;
int j;
sleep (1);
for (j = 0;; j++) {
size = sock_fd_read(sock, xreq, sizeof (xreq), NULL);
nfd = 0;
size = sock_fd_read(sock, xreq, sizeof (xreq), NULL, &nfd);
printf ("got %d\n", size);
if (size == 0)
break;
@ -48,7 +50,8 @@ child(int sock)
fprintf (stderr, "Got fd req, but not at end of input %d < %d\n",
i, size);
}
fdsize = sock_fd_read(sock, xnop, sizeof (xnop), &fd);
nfd = 1;
fdsize = sock_fd_read(sock, xnop, sizeof (xnop), &fd, &nfd);
if (fd == -1) {
fprintf (stderr, "no fd received\n");
} else {
@ -89,19 +92,20 @@ parent(int sock)
for (i = 0; i < 8; i++) {
xreq[0] = 0;
xreq[1] = sizeof (xreq);
sock_fd_write(sock, xreq, sizeof (xreq), -1);
sock_fd_write(sock, xreq, sizeof (xreq), NULL, 0);
}
/* Write our 'pass an fd' request with a 'useless' FD to block the receiver */
xreq[0] = 1;
xreq[1] = sizeof(xreq);
sock_fd_write(sock, xreq, sizeof (xreq), 1);
fd = 1;
sock_fd_write(sock, xreq, sizeof (xreq), &fd, 1);
/* Pass an fd */
xnop[0] = 2;
xnop[1] = sizeof (xnop);
fd = tmp_file(j);
sock_fd_write(sock, xnop, sizeof (xnop), fd);
sock_fd_write(sock, xnop, sizeof (xnop), &fd, 1);
close(fd);
}
}

View File

@ -20,17 +20,18 @@
void
child(int sock)
{
int fd;
int fd, nfd;
char buf[16];
ssize_t size;
sleep(1);
for (;;) {
size = sock_fd_read(sock, buf, sizeof(buf), &fd);
nfd = 1;
size = sock_fd_read(sock, buf, sizeof(buf), &fd, &nfd);
if (size <= 0)
break;
printf ("read %d\n", size);
if (fd != -1) {
printf ("read %d nfd %d\n", size, nfd);
if (nfd == 1) {
write(fd, "hello, world\n", 13);
close(fd);
}
@ -45,11 +46,11 @@ parent(int sock)
int fd;
fd = 1;
size = sock_fd_write(sock, "1", 1, -1);
size = sock_fd_write(sock, "1", 1, NULL, 0);
printf ("wrote %d without fd\n", size);
size = sock_fd_write(sock, NULL, 0, 1);
size = sock_fd_write(sock, NULL, 0, &fd, 1);
printf ("wrote %d with fd\n", size);
size = sock_fd_write(sock, "1", 1, -1);
size = sock_fd_write(sock, "1", 1, NULL, 0);
printf ("wrote %d without fd\n", size);
}