rsocket: handle peer disconnect/shutdown and rshutdown(SHUT_RD)

Review changes relative to the earlier version of this patch:

* rs_poll_cq() no longer calls rshutdown() on RS_CTRL_DISCONNECT.
  rshutdown() first switches an O_NONBLOCK socket to blocking mode and
  only restores the flags while rs_connected is set; because
  rs_poll_cq() had just set rs->state = rs_disconnected, the restore
  was skipped and the socket was silently left blocking. Re-entering
  rshutdown() from the CQ polling path was also inconsistent with the
  RS_CTRL_SHUTDOWN branch, which only updates rs->state. The local
  teardown (CQ notify + rdma_disconnect) still runs when the
  application calls rshutdown() on a disconnected socket.

* rshutdown() no longer ends with "rs->state = rs_shutdown". That
  assignment wiped the bits a half-close must keep (SHUT_RD leaves the
  socket writable, SHUT_WR leaves it readable) as well as
  rs_disconnected/rs_error, replacing them with a value nothing else in
  this patch checks. The now-unused rs_shutdown enum value is dropped.

* rshutdown() rejects a 'how' other than SHUT_RD, SHUT_WR or SHUT_RDWR
  with EINVAL, as shutdown(2) does, instead of treating it as SHUT_RD.

diff --git a/src/rsocket.c b/src/rsocket.c
index abdd392..76fbb85 100644
--- a/src/rsocket.c
+++ b/src/rsocket.c
@@ -1786,9 +1786,15 @@ static int rs_poll_cq(struct rsocket *rs)
 			case RS_OP_CTRL:
 				if (rs_msg_data(msg) == RS_CTRL_DISCONNECT) {
 					rs->state = rs_disconnected;
 					return 0;
 				} else if (rs_msg_data(msg) == RS_CTRL_SHUTDOWN) {
-					rs->state &= ~rs_readable;
+					if (rs->state & rs_writable) {
+						rs->state &= ~rs_readable;
+					} else {
+						/* Our write side is already closed too. */
+						rs->state = rs_disconnected;
+						return 0;
+					}
 				}
 				break;
 			case RS_OP_WRITE:
@@ -2914,10 +2920,12 @@ static int rs_poll_events(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds)
 
 		rs = idm_lookup(&idm, fds[i].fd);
 		if (rs) {
+			fastlock_acquire(&rs->cq_wait_lock);
 			if (rs->type == SOCK_STREAM)
 				rs_get_cq_event(rs);
 			else
 				ds_get_cq_event(rs);
+			fastlock_release(&rs->cq_wait_lock);
 			fds[i].revents = rs_poll_rs(rs, fds[i].events, 1, rs_poll_all);
 		} else {
 			fds[i].revents = rfds[i].revents;
@@ -3064,7 +3072,8 @@ int rselect(int nfds, fd_set *readfds, fd_set *writefds,
 
 /*
  * For graceful disconnect, notify the remote side that we're
- * disconnecting and wait until all outstanding sends complete.
+ * disconnecting and wait until all outstanding sends complete, provided
+ * that the remote side has not sent a disconnect message.
  */
 int rshutdown(int socket, int how)
 {
@@ -3072,11 +3081,11 @@ int rshutdown(int socket, int how)
 	int ctrl, ret = 0;
 
 	rs = idm_at(&idm, socket);
-	if (how == SHUT_RD) {
-		rs->state &= ~rs_readable;
-		return 0;
+	if (how != SHUT_RD && how != SHUT_WR && how != SHUT_RDWR) {
+		errno = EINVAL;
+		return -1;
 	}
 
 	if (rs->fd_flags & O_NONBLOCK)
 		rs_set_nonblocking(rs, 0);
 
@@ -3084,15 +3093,21 @@ int rshutdown(int socket, int how)
 		if (how == SHUT_RDWR) {
 			ctrl = RS_CTRL_DISCONNECT;
 			rs->state &= ~(rs_readable | rs_writable);
-		} else {
+		} else if (how == SHUT_WR) {
 			rs->state &= ~rs_writable;
 			ctrl = (rs->state & rs_readable) ?
 				RS_CTRL_SHUTDOWN : RS_CTRL_DISCONNECT;
+		} else {
+			/* SHUT_RD: nothing is sent while our write side is open. */
+			rs->state &= ~rs_readable;
+			if (rs->state & rs_writable)
+				goto out;
+			ctrl = RS_CTRL_DISCONNECT;
 		}
 		if (!rs->ctrl_avail) {
 			ret = rs_process_cq(rs, 0, rs_conn_can_send_ctrl);
 			if (ret)
-				return ret;
+				goto out;
 		}
 
 		if ((rs->state & rs_connected) && rs->ctrl_avail) {
@@ -3104,10 +3119,17 @@ int rshutdown(int socket, int how)
 	if (rs->state & rs_connected)
 		rs_process_cq(rs, 0, rs_conn_all_sends_done);
 
+out:
 	if ((rs->fd_flags & O_NONBLOCK) && (rs->state & rs_connected))
 		rs_set_nonblocking(rs, rs->fd_flags);
 
-	return 0;
+	if (rs->state & rs_disconnected) {
+		/* Generate event by flushing receives to unblock rpoll */
+		ibv_req_notify_cq(rs->cm_id->recv_cq, 0);
+		rdma_disconnect(rs->cm_id);
+	}
+
+	return ret;
 }
 
 static void ds_shutdown(struct rsocket *rs)