diff --git a/drivers/virtio/virtio-net.c b/drivers/virtio/virtio-net.c index 914b76616b..e6c0552a47 100644 --- a/drivers/virtio/virtio-net.c +++ b/drivers/virtio/virtio-net.c @@ -110,6 +110,8 @@ struct virtio_net_priv_s struct netdev_lowerhalf_s lower; /* The netdev lowerhalf */ #endif + spinlock_t lock[VIRTIO_NET_NUM]; + /* Virtio device information */ FAR struct virtio_device *vdev; /* Virtio device pointer */ @@ -264,11 +266,13 @@ static int virtio_net_addbuffer(FAR struct netdev_lowerhalf_s *dev, vrtinfo("Fill vq=%u, hdr=%p, count=%d\n", vq_id, hdr, iov_cnt); if (vq_id == VIRTIO_NET_RX) { - return virtqueue_add_buffer(vq, vb, 0, iov_cnt, hdr); + return virtqueue_add_buffer_lock(vq, vb, 0, iov_cnt, hdr, + &priv->lock[vq_id]); } else { - return virtqueue_add_buffer(vq, vb, iov_cnt, 0, hdr); + return virtqueue_add_buffer_lock(vq, vb, iov_cnt, 0, hdr, + &priv->lock[vq_id]); } } @@ -311,7 +315,7 @@ static void virtio_net_rxfill(FAR struct netdev_lowerhalf_s *dev) if (i > 0) { - virtqueue_kick(vq); + virtqueue_kick_lock(vq, &priv->lock[VIRTIO_NET_RX]); } } @@ -329,7 +333,8 @@ static void virtio_net_txfree(FAR struct netdev_lowerhalf_s *dev) { /* Get buffer from tx virtqueue */ - hdr = virtqueue_get_buffer(vq, NULL, NULL); + hdr = virtqueue_get_buffer_lock(vq, NULL, NULL, + &priv->lock[VIRTIO_NET_TX]); if (hdr == NULL) { break; @@ -363,7 +368,8 @@ static int virtio_net_ifup(FAR struct netdev_lowerhalf_s *dev) /* Prepare interrupt and packets for receiving */ - virtqueue_enable_cb(priv->vdev->vrings_info[VIRTIO_NET_RX].vq); + virtqueue_enable_cb_lock(priv->vdev->vrings_info[VIRTIO_NET_RX].vq, + &priv->lock[VIRTIO_NET_RX]); virtio_net_rxfill(dev); #ifdef CONFIG_DRIVERS_WIFI_SIM @@ -389,7 +395,8 @@ static int virtio_net_ifdown(FAR struct netdev_lowerhalf_s *dev) for (i = 0; i < VIRTIO_NET_NUM; i++) { - virtqueue_disable_cb(priv->vdev->vrings_info[i].vq); + virtqueue_disable_cb_lock(priv->vdev->vrings_info[i].vq, + &priv->lock[i]); } #ifdef CONFIG_DRIVERS_WIFI_SIM @@ -426,7 +433,7 @@ static int virtio_net_send(FAR struct netdev_lowerhalf_s *dev, /* Add buffer to vq and notify the other side */ virtio_net_addbuffer(dev, vq, pkt, VIRTIO_NET_TX); - virtqueue_kick(vq); + virtqueue_kick_lock(vq, &priv->lock[VIRTIO_NET_TX]); /* Try return Netpkt TX buffer to upper-half. */ @@ -436,7 +443,7 @@ static int virtio_net_send(FAR struct netdev_lowerhalf_s *dev, if (netdev_lower_quota_load(dev, NETPKT_TX) <= 0) { - virtqueue_enable_cb(vq); + virtqueue_enable_cb_lock(vq, &priv->lock[VIRTIO_NET_TX]); } return OK; @@ -451,6 +458,7 @@ static netpkt_t *virtio_net_recv(FAR struct netdev_lowerhalf_s *dev) FAR struct virtio_net_priv_s *priv = (FAR struct virtio_net_priv_s *)dev; FAR struct virtqueue *vq = priv->vdev->vrings_info[VIRTIO_NET_RX].vq; FAR struct virtio_net_llhdr_s *hdr; + irqstate_t flags; uint32_t len; /* Fill the free Netpkt RX buffer to the RX virtqueue */ @@ -459,16 +467,22 @@ static netpkt_t *virtio_net_recv(FAR struct netdev_lowerhalf_s *dev) /* Get received buffer form RX virtqueue */ + flags = spin_lock_irqsave(&priv->lock[VIRTIO_NET_RX]); hdr = virtqueue_get_buffer(vq, &len, NULL); if (hdr == NULL) { /* If we have no buffer left, enable RX callback. */ virtqueue_enable_cb(vq); + spin_unlock_irqrestore(&priv->lock[VIRTIO_NET_RX], flags); vrtinfo("get NULL buffer\n"); return NULL; } + else + { + spin_unlock_irqrestore(&priv->lock[VIRTIO_NET_RX], flags); + } /* Set the received pkt length */ @@ -519,7 +533,7 @@ static void virtio_net_rxready(FAR struct virtqueue *vq) { FAR struct virtio_net_priv_s *priv = vq->vq_dev->priv; - virtqueue_disable_cb(vq); + virtqueue_disable_cb_lock(vq, &priv->lock[VIRTIO_NET_RX]); netdev_lower_rxready((FAR struct netdev_lowerhalf_s *)priv); } @@ -531,7 +545,7 @@ static void virtio_net_txdone(FAR struct virtqueue *vq) { FAR struct virtio_net_priv_s *priv = vq->vq_dev->priv; - virtqueue_disable_cb(vq); + virtqueue_disable_cb_lock(vq, &priv->lock[VIRTIO_NET_TX]); netdev_lower_txdone((FAR struct netdev_lowerhalf_s *)priv); } @@ -546,6 +560,8 @@ static int virtio_net_init(FAR struct virtio_net_priv_s *priv, vq_callback callbacks[VIRTIO_NET_NUM]; int ret; + spin_lock_init(&priv->lock[VIRTIO_NET_RX]); + spin_lock_init(&priv->lock[VIRTIO_NET_TX]); priv->vdev = vdev; vdev->priv = priv;