diff --git a/io_uring/net.c b/io_uring/net.c index caa6a803cb72..8c7226b5bf41 100644 --- a/io_uring/net.c +++ b/io_uring/net.c @@ -46,6 +46,7 @@ struct io_connect { struct file *file; struct sockaddr __user *addr; int addr_len; + bool in_progress; }; struct io_sr_msg { @@ -1386,6 +1387,7 @@ int io_connect_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe) conn->addr = u64_to_user_ptr(READ_ONCE(sqe->addr)); conn->addr_len = READ_ONCE(sqe->addr2); + conn->in_progress = false; return 0; } @@ -1397,6 +1399,16 @@ int io_connect(struct io_kiocb *req, unsigned int issue_flags) int ret; bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK; + if (connect->in_progress) { + struct socket *socket; + + ret = -ENOTSOCK; + socket = sock_from_file(req->file); + if (socket) + ret = sock_error(socket->sk); + goto out; + } + if (req_has_async_data(req)) { io = req->async_data; } else { @@ -1413,13 +1425,17 @@ int io_connect(struct io_kiocb *req, unsigned int issue_flags) ret = __sys_connect_file(req->file, &io->address, connect->addr_len, file_flags); if ((ret == -EAGAIN || ret == -EINPROGRESS) && force_nonblock) { - if (req_has_async_data(req)) - return -EAGAIN; - if (io_alloc_async_data(req)) { - ret = -ENOMEM; - goto out; + if (ret == -EINPROGRESS) { + connect->in_progress = true; + } else { + if (req_has_async_data(req)) + return -EAGAIN; + if (io_alloc_async_data(req)) { + ret = -ENOMEM; + goto out; + } + memcpy(req->async_data, &__io, sizeof(__io)); } - memcpy(req->async_data, &__io, sizeof(__io)); return -EAGAIN; } if (ret == -ERESTARTSYS)