diff --git a/net/dsa/dsa.c b/net/dsa/dsa.c
index 76e3800765f881c554e93385c85f965cc96f611c..bf4ba15fb780e96ed77858c68a5d724a065962d6 100644
--- a/net/dsa/dsa.c
+++ b/net/dsa/dsa.c
@@ -634,6 +634,10 @@ static void dsa_of_free_platform_data(struct dsa_platform_data *pd)
 			port_index++;
 		}
 		kfree(pd->chip[i].rtable);
+
+		/* Drop our reference to the MDIO bus device */
+		if (pd->chip[i].host_dev)
+			put_device(pd->chip[i].host_dev);
 	}
 	kfree(pd->chip);
 }
@@ -661,16 +665,22 @@ static int dsa_of_probe(struct device *dev)
 		return -EPROBE_DEFER;
 
 	ethernet = of_parse_phandle(np, "dsa,ethernet", 0);
-	if (!ethernet)
-		return -EINVAL;
+	if (!ethernet) {
+		ret = -EINVAL;
+		goto out_put_mdio;
+	}
 
 	ethernet_dev = of_find_net_device_by_node(ethernet);
-	if (!ethernet_dev)
-		return -EPROBE_DEFER;
+	if (!ethernet_dev) {
+		ret = -EPROBE_DEFER;
+		goto out_put_mdio;
+	}
 
 	pd = kzalloc(sizeof(*pd), GFP_KERNEL);
-	if (!pd)
-		return -ENOMEM;
+	if (!pd) {
+		ret = -ENOMEM;
+		goto out_put_mdio;
+	}
 
 	dev->platform_data = pd;
 	pd->of_netdev = ethernet_dev;
@@ -691,7 +701,9 @@ static int dsa_of_probe(struct device *dev)
 		cd = &pd->chip[chip_index];
 
 		cd->of_node = child;
-		cd->host_dev = &mdio_bus->dev;
+
+		/* When assigning the host device, increment its refcount */
+		cd->host_dev = get_device(&mdio_bus->dev);
 
 		sw_addr = of_get_property(child, "reg", NULL);
 		if (!sw_addr)
@@ -711,6 +723,12 @@ static int dsa_of_probe(struct device *dev)
 				ret = -EPROBE_DEFER;
 				goto out_free_chip;
 			}
+
+			/* Drop the mdio_bus device ref, replacing the host
+			 * device with the mdio_bus_switch device, keeping
+			 * the refcount from of_mdio_find_bus() above.
+			 */
+			put_device(cd->host_dev);
 			cd->host_dev = &mdio_bus_switch->dev;
 		}
 
@@ -744,6 +762,10 @@ static int dsa_of_probe(struct device *dev)
 		}
 	}
 
+	/* The individual chips hold their own refcount on the mdio bus,
+	 * so drop ours */
+	put_device(&mdio_bus->dev);
+
 	return 0;
 
 out_free_chip:
@@ -751,6 +773,8 @@ static int dsa_of_probe(struct device *dev)
 out_free:
 	kfree(pd);
 	dev->platform_data = NULL;
+out_put_mdio:
+	put_device(&mdio_bus->dev);
 	return ret;
 }