usb: disable previously enabled endpoint before re-enabling

When switching between alternate settings of an interface, it is
currently possible to call set_endpoint() multiple times on an endpoint
without first calling reset_endpoint(). For these situations, it is
beneficial to track endpoints for which set_endpoint() has previously
been called, and then reset them to properly terminate any transfers,
and to return the HAL to the correct state

Signed-off-by: Milind Paranjpe <mparanjpe@yahoo.com>
diff --git a/subsys/usb/device/usb_device.c b/subsys/usb/device/usb_device.c
index 0191af5..51311ef 100644
--- a/subsys/usb/device/usb_device.c
+++ b/subsys/usb/device/usb_device.c
@@ -135,6 +135,8 @@
 	uint8_t alt_setting[CONFIG_USB_MAX_ALT_SETTING];
 	/** Remote wakeup feature status */
 	bool remote_wakeup;
+	/** Tracks whether set_endpoint() had been called on an EP */
+	uint32_t ep_bm;
 } usb_dev;
 
 /* Setup packet definition used to read raw data from USB line */
@@ -146,6 +148,8 @@
 	uint16_t wLength;
 } __packed;
 
+static bool reset_endpoint(const struct usb_ep_descriptor *ep_desc);
+
 /*
  * @brief print the contents of a setup packet
  *
@@ -509,6 +513,37 @@
 }
 
 /*
+ * @brief Get 32-bit endpoint bitmask from index
+ *
+ * In the returned 32-bit word, the bit positions in the lower 16 bits
+ * indicate OUT endpoints, while the upper 16 bits indicate IN
+ * endpoints
+ *
+ * @param [in]  ep Endpoint of interest
+ *
+ * @return 32-bit bitmask
+ */
+static uint32_t get_ep_bm_from_addr(uint8_t ep)
+{
+	uint32_t ep_bm = 0;
+	uint8_t ep_idx;
+
+	ep_idx = ep & (~USB_EP_DIR_IN);
+	if (ep_idx > 15) {
+		LOG_ERR("Endpoint 0x%02x is invalid", ep);
+		goto done;
+	}
+
+	if (ep & USB_EP_DIR_IN) {
+		ep_bm = BIT(ep_idx + 16);
+	} else {
+		ep_bm = BIT(ep_idx);
+	}
+done:
+	return ep_bm;
+}
+
+/*
  * @brief configure and enable endpoint
  *
  * This function sets endpoint configuration according to one specified in USB
@@ -521,6 +556,7 @@
 static bool set_endpoint(const struct usb_ep_descriptor *ep_desc)
 {
 	struct usb_dc_ep_cfg_data ep_cfg;
+	uint32_t ep_bm;
 	int ret;
 
 	ep_cfg.ep_addr = ep_desc->bEndpointAddress;
@@ -530,6 +566,14 @@
 	LOG_DBG("Set endpoint 0x%x type %u MPS %u",
 		ep_cfg.ep_addr, ep_cfg.ep_type, ep_cfg.ep_mps);
 
+	/* if endpoint is has been set() previously, reset() it first */
+	ep_bm = get_ep_bm_from_addr(ep_desc->bEndpointAddress);
+	if (ep_bm & usb_dev.ep_bm) {
+		reset_endpoint(ep_desc);
+		/* allow any canceled transfers to terminate */
+		k_usleep(150);
+	}
+
 	ret = usb_dc_ep_configure(&ep_cfg);
 	if (ret == -EALREADY) {
 		LOG_WRN("Endpoint 0x%02x already configured", ep_cfg.ep_addr);
@@ -551,6 +595,7 @@
 	}
 
 	usb_dev.configured = true;
+	usb_dev.ep_bm |= ep_bm;
 
 	return true;
 }
@@ -568,6 +613,7 @@
 static bool reset_endpoint(const struct usb_ep_descriptor *ep_desc)
 {
 	struct usb_dc_ep_cfg_data ep_cfg;
+	uint32_t ep_bm;
 	int ret;
 
 	ep_cfg.ep_addr = ep_desc->bEndpointAddress;
@@ -588,6 +634,10 @@
 		;
 	}
 
+	/* clear endpoint mask */
+	ep_bm = get_ep_bm_from_addr(ep_desc->bEndpointAddress);
+	usb_dev.ep_bm &= ~ep_bm;
+
 	return true;
 }
 
@@ -1175,7 +1225,16 @@
 
 static int disable_interface_ep(const struct usb_ep_cfg_data *ep_data)
 {
-	return usb_dc_ep_disable(ep_data->ep_addr);
+	uint32_t ep_bm;
+	int ret;
+
+	ret = usb_dc_ep_disable(ep_data->ep_addr);
+
+	/* clear endpoint mask */
+	ep_bm = get_ep_bm_from_addr(ep_data->ep_addr);
+	usb_dev.ep_bm &= ~ep_bm;
+
+	return ret;
 }
 
 static void forward_status_cb(enum usb_dc_status_code status, const uint8_t *param)