aboutsummaryrefslogtreecommitdiffstats
path: root/drivers/vhost/net.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost/net.c')
-rw-r--r--drivers/vhost/net.c112
1 files changed, 110 insertions, 2 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index f274826..7ef84c1 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -12,6 +12,7 @@
#include <linux/virtio_net.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
+#include <linux/moduleparam.h>
#include <linux/mutex.h>
#include <linux/workqueue.h>
#include <linux/rcupdate.h>
@@ -28,10 +29,18 @@
#include "vhost.h"
+static int experimental_zcopytx;
+module_param(experimental_zcopytx, int, 0444);
+MODULE_PARM_DESC(experimental_zcopytx, "Enable Experimental Zero Copy TX");
+
/* Max number of bytes transferred before requeueing the job.
* Using this limit prevents one virtqueue from starving others. */
#define VHOST_NET_WEIGHT 0x80000
+/* MAX number of TX used buffers for outstanding zerocopy */
+#define VHOST_MAX_PEND 128
+#define VHOST_GOODCOPY_LEN 256
+
enum {
VHOST_NET_VQ_RX = 0,
VHOST_NET_VQ_TX = 1,
@@ -54,6 +63,12 @@ struct vhost_net {
enum vhost_net_poll_state tx_poll_state;
};
+static bool vhost_sock_zcopy(struct socket *sock)
+{
+ return unlikely(experimental_zcopytx) &&
+ sock_flag(sock->sk, SOCK_ZEROCOPY);
+}
+
/* Pop first len bytes from iovec. Return number of segments used. */
static int move_iovec_hdr(struct iovec *from, struct iovec *to,
size_t len, int iov_count)
@@ -129,6 +144,8 @@ static void handle_tx(struct vhost_net *net)
int err, wmem;
size_t hdr_size;
struct socket *sock;
+ struct vhost_ubuf_ref *uninitialized_var(ubufs);
+ bool zcopy;
/* TODO: check that we are running from vhost_worker? */
sock = rcu_dereference_check(vq->private_data, 1);
@@ -149,8 +166,13 @@ static void handle_tx(struct vhost_net *net)
if (wmem < sock->sk->sk_sndbuf / 2)
tx_poll_stop(net);
hdr_size = vq->vhost_hlen;
+ zcopy = vhost_sock_zcopy(sock);
for (;;) {
+ /* Release DMAs done buffers first */
+ if (zcopy)
+ vhost_zerocopy_signal_used(vq);
+
head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
ARRAY_SIZE(vq->iov),
&out, &in,
@@ -160,12 +182,25 @@ static void handle_tx(struct vhost_net *net)
break;
/* Nothing new? Wait for eventfd to tell us they refilled. */
if (head == vq->num) {
+ int num_pends;
+
wmem = atomic_read(&sock->sk->sk_wmem_alloc);
if (wmem >= sock->sk->sk_sndbuf * 3 / 4) {
tx_poll_start(net, sock);
set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
break;
}
+ /* If more outstanding DMAs, queue the work.
+ * Handle upend_idx wrap around
+ */
+ num_pends = likely(vq->upend_idx >= vq->done_idx) ?
+ (vq->upend_idx - vq->done_idx) :
+ (vq->upend_idx + UIO_MAXIOV - vq->done_idx);
+ if (unlikely(num_pends > VHOST_MAX_PEND)) {
+ tx_poll_start(net, sock);
+ set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
+ break;
+ }
if (unlikely(vhost_enable_notify(&net->dev, vq))) {
vhost_disable_notify(&net->dev, vq);
continue;
@@ -188,9 +223,40 @@ static void handle_tx(struct vhost_net *net)
iov_length(vq->hdr, s), hdr_size);
break;
}
+ /* use msg_control to pass vhost zerocopy ubuf info to skb */
+ if (zcopy) {
+ vq->heads[vq->upend_idx].id = head;
+ if (len < VHOST_GOODCOPY_LEN) {
+ /* copy don't need to wait for DMA done */
+ vq->heads[vq->upend_idx].len =
+ VHOST_DMA_DONE_LEN;
+ msg.msg_control = NULL;
+ msg.msg_controllen = 0;
+ ubufs = NULL;
+ } else {
+ struct ubuf_info *ubuf;
+ ubuf = vq->ubuf_info + vq->upend_idx;
+
+ vq->heads[vq->upend_idx].len = len;
+ ubuf->callback = vhost_zerocopy_callback;
+ ubuf->arg = vq->ubufs;
+ ubuf->desc = vq->upend_idx;
+ msg.msg_control = ubuf;
+ msg.msg_controllen = sizeof(ubuf);
+ ubufs = vq->ubufs;
+ kref_get(&ubufs->kref);
+ }
+ vq->upend_idx = (vq->upend_idx + 1) % UIO_MAXIOV;
+ }
/* TODO: Check specific error and bomb out unless ENOBUFS? */
err = sock->ops->sendmsg(NULL, sock, &msg, len);
if (unlikely(err < 0)) {
+ if (zcopy) {
+ if (ubufs)
+ vhost_ubuf_put(ubufs);
+ vq->upend_idx = ((unsigned)vq->upend_idx - 1) %
+ UIO_MAXIOV;
+ }
vhost_discard_vq_desc(vq, 1);
tx_poll_start(net, sock);
break;
@@ -198,7 +264,8 @@ static void handle_tx(struct vhost_net *net)
if (err != len)
pr_debug("Truncated TX packet: "
" len %d != %zd\n", err, len);
- vhost_add_used_and_signal(&net->dev, vq, head, 0);
+ if (!zcopy)
+ vhost_add_used_and_signal(&net->dev, vq, head, 0);
total_len += len;
if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
vhost_poll_queue(&vq->poll);
@@ -252,9 +319,13 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
r = -ENOBUFS;
goto err;
}
- d = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg,
+ r = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg,
ARRAY_SIZE(vq->iov) - seg, &out,
&in, log, log_num);
+ if (unlikely(r < 0))
+ goto err;
+
+ d = r;
if (d == vq->num) {
r = 0;
goto err;
@@ -279,6 +350,12 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
*iovcount = seg;
if (unlikely(log))
*log_num = nlogs;
+
+ /* Detect overrun */
+ if (unlikely(datalen > 0)) {
+ r = UIO_MAXIOV + 1;
+ goto err;
+ }
return headcount;
err:
vhost_discard_vq_desc(vq, headcount);
@@ -333,6 +410,14 @@ static void handle_rx(struct vhost_net *net)
/* On error, stop handling until the next kick. */
if (unlikely(headcount < 0))
break;
+ /* On overrun, truncate and discard */
+ if (unlikely(headcount > UIO_MAXIOV)) {
+ msg.msg_iovlen = 1;
+ err = sock->ops->recvmsg(NULL, sock, &msg,
+ 1, MSG_DONTWAIT | MSG_TRUNC);
+ pr_debug("Discarded rx packet: len %zd\n", sock_len);
+ continue;
+ }
/* OK, now we need to know about added descriptors. */
if (!headcount) {
if (unlikely(vhost_enable_notify(&net->dev, vq))) {
@@ -604,6 +689,7 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
{
struct socket *sock, *oldsock;
struct vhost_virtqueue *vq;
+ struct vhost_ubuf_ref *ubufs, *oldubufs = NULL;
int r;
mutex_lock(&n->dev.mutex);
@@ -633,13 +719,31 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
oldsock = rcu_dereference_protected(vq->private_data,
lockdep_is_held(&vq->mutex));
if (sock != oldsock) {
+ ubufs = vhost_ubuf_alloc(vq, sock && vhost_sock_zcopy(sock));
+ if (IS_ERR(ubufs)) {
+ r = PTR_ERR(ubufs);
+ goto err_ubufs;
+ }
+ oldubufs = vq->ubufs;
+ vq->ubufs = ubufs;
vhost_net_disable_vq(n, vq);
rcu_assign_pointer(vq->private_data, sock);
vhost_net_enable_vq(n, vq);
+
+ r = vhost_init_used(vq);
+ if (r)
+ goto err_vq;
}
mutex_unlock(&vq->mutex);
+ if (oldubufs) {
+ vhost_ubuf_put_and_wait(oldubufs);
+ mutex_lock(&vq->mutex);
+ vhost_zerocopy_signal_used(vq);
+ mutex_unlock(&vq->mutex);
+ }
+
if (oldsock) {
vhost_net_flush_vq(n, index);
fput(oldsock->file);
@@ -648,6 +752,8 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
mutex_unlock(&n->dev.mutex);
return 0;
+err_ubufs:
+ fput(sock->file);
err_vq:
mutex_unlock(&vq->mutex);
err:
@@ -777,6 +883,8 @@ static struct miscdevice vhost_net_misc = {
static int vhost_net_init(void)
{
+ if (experimental_zcopytx)
+ vhost_enable_zcopy(VHOST_NET_VQ_TX);
return misc_register(&vhost_net_misc);
}
module_init(vhost_net_init);