diff --git a/include/linux/virtio_net.h b/include/linux/virtio_net.h
index 9397628a196714dc2177552465fe91fd18b9627d..cb462f9ab7dd592bc1d613c86aefeb787cdd9321 100644
--- a/include/linux/virtio_net.h
+++ b/include/linux/virtio_net.h
@@ -5,6 +5,24 @@
 #include <linux/if_vlan.h>
 #include <uapi/linux/virtio_net.h>
 
+static inline int virtio_net_hdr_set_proto(struct sk_buff *skb,
+					   const struct virtio_net_hdr *hdr)
+{
+	switch (hdr->gso_type & ~VIRTIO_NET_HDR_GSO_ECN) {
+	case VIRTIO_NET_HDR_GSO_TCPV4:
+	case VIRTIO_NET_HDR_GSO_UDP:
+		skb->protocol = cpu_to_be16(ETH_P_IP);
+		break;
+	case VIRTIO_NET_HDR_GSO_TCPV6:
+		skb->protocol = cpu_to_be16(ETH_P_IPV6);
+		break;
+	default:
+		return -EINVAL;
+	}
+
+	return 0;
+}
+
 static inline int virtio_net_hdr_to_skb(struct sk_buff *skb,
 					const struct virtio_net_hdr *hdr,
 					bool little_endian)
diff --git a/net/packet/af_packet.c b/net/packet/af_packet.c
index 75c92a87e7b2481141161c8945f5e7eef8e0abf8..d6e94dc7e2900bdf91baf98ec0319586a9e06660 100644
--- a/net/packet/af_packet.c
+++ b/net/packet/af_packet.c
@@ -2715,10 +2715,12 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
 			}
 		}
 
-		if (po->has_vnet_hdr && virtio_net_hdr_to_skb(skb, vnet_hdr,
-							      vio_le())) {
-			tp_len = -EINVAL;
-			goto tpacket_error;
+		if (po->has_vnet_hdr) {
+			if (virtio_net_hdr_to_skb(skb, vnet_hdr, vio_le())) {
+				tp_len = -EINVAL;
+				goto tpacket_error;
+			}
+			virtio_net_hdr_set_proto(skb, vnet_hdr);
 		}
 
 		skb->destructor = tpacket_destruct_skb;
@@ -2915,6 +2917,7 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
 		if (err)
 			goto out_free;
 		len += sizeof(vnet_hdr);
+		virtio_net_hdr_set_proto(skb, &vnet_hdr);
 	}
 
 	skb_probe_transport_header(skb, reserve);