Fix compatibility with clone J-Link devices (#2732)

diff --git a/changelog/fixed-jlink-clone.md b/changelog/fixed-jlink-clone.md
new file mode 100644
index 0000000..f097995
--- /dev/null
+++ b/changelog/fixed-jlink-clone.md
@@ -0,0 +1 @@
+Improved compatibility with certain J-Link clones
\ No newline at end of file
diff --git a/probe-rs/src/probe/jlink/config.rs b/probe-rs/src/probe/jlink/config.rs
new file mode 100644
index 0000000..95b5a60
--- /dev/null
+++ b/probe-rs/src/probe/jlink/config.rs
@@ -0,0 +1,54 @@
+#[derive(Default, Clone, Copy, Debug)]
+#[allow(dead_code)]
+pub struct JlinkConfig {
+    pub usb_address: Option<u8>,
+    pub kickstart_power: Option<bool>,
+    pub ip_address: Option<[u8; 4]>,
+    pub subnet_mask: Option<[u8; 4]>,
+    pub mac_address: Option<[u8; 6]>,
+}
+
+impl JlinkConfig {
+    pub fn parse(data: [u8; 256]) -> Result<Self, String> {
+        let usb_address = match data[0] {
+            0 => Some(0),
+            1 => Some(1),
+            2 => Some(2),
+            0xFF => None,
+            other => return Err(format!("Unexpected USB address configured: {other}")),
+        };
+
+        let kickstart_power = match u32::from_le_bytes([data[4], data[5], data[6], data[7]]) {
+            0 => Some(false),
+            1 => Some(true),
+            u32::MAX => None,
+            other => return Err(format!("Unexpected kickstart power value: {other:#010x}")),
+        };
+
+        let ip_address = match data[32..36] {
+            [0xFF, 0xFF, 0xFF, 0xFF] => None,
+            [a, b, c, d] => Some([a, b, c, d]),
+            _ => unreachable!(),
+        };
+
+        let subnet_mask = match data[36..40] {
+            [0xFF, 0xFF, 0xFF, 0xFF] => None,
+            [a, b, c, d] => Some([a, b, c, d]),
+            _ => unreachable!(),
+        };
+
+        let mac_address = match data[48..54] {
+            [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF] => None,
+            [a, b, c, d, e, f] => Some([a, b, c, d, e, f]),
+            _ => unreachable!(),
+        };
+
+        Ok(Self {
+            usb_address,
+            kickstart_power,
+            ip_address,
+            subnet_mask,
+            mac_address,
+        })
+    }
+}
diff --git a/probe-rs/src/probe/jlink/connection.rs b/probe-rs/src/probe/jlink/connection.rs
new file mode 100644
index 0000000..29b99a8
--- /dev/null
+++ b/probe-rs/src/probe/jlink/connection.rs
@@ -0,0 +1,42 @@
+#[derive(Clone, Copy, Debug)]
+pub struct JlinkConnection {
+    /// Handle
+    pub handle: u16,
+    /// Process ID
+    pub pid: u32,
+    /// Host ID
+    pub hid: [u8; 4],
+    /// IID - unknown
+    pub iid: u8,
+    /// CID - unknown
+    pub cid: u8,
+}
+
+impl JlinkConnection {
+    pub fn usb(handle: u16) -> Self {
+        Self {
+            handle,
+            pid: 0,
+            hid: [0; 4],
+            iid: 0,
+            cid: 0,
+        }
+    }
+
+    pub(crate) fn into_bytes(self) -> [u8; 12] {
+        [
+            self.pid as u8,
+            (self.pid >> 8) as u8,
+            (self.pid >> 16) as u8,
+            (self.pid >> 24) as u8,
+            self.hid[0],
+            self.hid[1],
+            self.hid[2],
+            self.hid[3],
+            self.iid,
+            self.cid,
+            self.handle as u8,
+            (self.handle >> 8) as u8,
+        ]
+    }
+}
diff --git a/probe-rs/src/probe/jlink/mod.rs b/probe-rs/src/probe/jlink/mod.rs
index d9046f0..8df2b7a 100644
--- a/probe-rs/src/probe/jlink/mod.rs
+++ b/probe-rs/src/probe/jlink/mod.rs
@@ -2,6 +2,8 @@
 
 mod bits;
 pub mod capabilities;
+mod config;
+mod connection;
 mod error;
 mod interface;
 mod speed;
@@ -17,7 +19,6 @@
 use nusb::transfer::{Direction, EndpointType};
 use nusb::DeviceInfo;
 use probe_rs_target::ScanChainElement;
-use tracing::{debug, trace, warn};
 
 use self::bits::BitIter;
 use self::capabilities::{Capabilities, Capability};
@@ -31,6 +32,8 @@
 };
 use crate::probe::common::{JtagDriverState, RawJtagIo};
 use crate::probe::jlink::bits::IteratorExt;
+use crate::probe::jlink::config::JlinkConfig;
+use crate::probe::jlink::connection::JlinkConnection;
 use crate::probe::usb_util::InterfaceExt;
 use crate::probe::JTAGAccess;
 use crate::{
@@ -99,24 +102,24 @@
         let configs: Vec<_> = handle.configurations().collect();
 
         if configs.len() != 1 {
-            warn!("device has {} configurations, expected 1", configs.len());
+            tracing::warn!("device has {} configurations, expected 1", configs.len());
         }
 
         let conf = &configs[0];
-        debug!("scanning {} interfaces", conf.interfaces().count());
-        trace!("active configuration descriptor: {:#x?}", conf);
+        tracing::debug!("scanning {} interfaces", conf.interfaces().count());
+        tracing::trace!("active configuration descriptor: {:#x?}", conf);
 
         let mut jlink_intf = None;
         for intf in conf.interfaces() {
-            trace!("interface #{} descriptors:", intf.interface_number());
+            tracing::trace!("interface #{} descriptors:", intf.interface_number());
 
             for descr in intf.alt_settings() {
-                trace!("{:#x?}", descr);
+                tracing::trace!("{:#x?}", descr);
 
                 // We detect the proprietary J-Link interface using the vendor-specific class codes
                 // and the endpoint properties
                 if descr.class() == 0xff && descr.subclass() == 0xff && descr.protocol() == 0xff {
-                    if let Some((intf, _, _)) = jlink_intf {
+                    if let Some((intf, _, _, _)) = jlink_intf {
                         Err(JlinkError::Other(format!(
                             "found multiple matching USB interfaces ({} and {})",
                             intf,
@@ -125,9 +128,9 @@
                     }
 
                     let endpoints: Vec<_> = descr.endpoints().collect();
-                    trace!("endpoint descriptors: {:#x?}", endpoints);
+                    tracing::trace!("endpoint descriptors: {:#x?}", endpoints);
                     if endpoints.len() != 2 {
-                        warn!("vendor-specific interface with {} endpoints, expected 2 (skipping interface)", endpoints.len());
+                        tracing::warn!("vendor-specific interface with {} endpoints, expected 2 (skipping interface)", endpoints.len());
                         continue;
                     }
 
@@ -135,7 +138,7 @@
                         .iter()
                         .all(|ep| ep.transfer_type() == EndpointType::Bulk)
                     {
-                        warn!(
+                        tracing::warn!(
                             "encountered non-bulk endpoints, skipping interface: {:#x?}",
                             endpoints
                         );
@@ -143,18 +146,23 @@
                     }
 
                     let (read_ep, write_ep) = if endpoints[0].direction() == Direction::In {
-                        (endpoints[0].address(), endpoints[1].address())
+                        (&endpoints[0], &endpoints[1])
                     } else {
-                        (endpoints[1].address(), endpoints[0].address())
+                        (&endpoints[1], &endpoints[0])
                     };
 
-                    jlink_intf = Some((descr.interface_number(), read_ep, write_ep));
-                    debug!("J-Link interface is #{}", descr.interface_number());
+                    jlink_intf = Some((
+                        descr.interface_number(),
+                        read_ep.address(),
+                        write_ep.address(),
+                        read_ep.max_packet_size(),
+                    ));
+                    tracing::debug!("J-Link interface is #{}", descr.interface_number());
                 }
             }
         }
 
-        let Some((intf, read_ep, write_ep)) = jlink_intf else {
+        let Some((intf, read_ep, write_ep, max_read_ep_packet)) = jlink_intf else {
             Err(JlinkError::Other(
                 "device is not a J-Link device".to_string(),
             ))?
@@ -167,6 +175,7 @@
         let mut this = JLink {
             read_ep,
             write_ep,
+            max_read_ep_packet,
             caps: Capabilities::from_raw_legacy(0), // dummy value
             interface: Interface::Spi,              // dummy value, must not be JTAG
             interfaces: Interfaces::from_bits_warn(0), // dummy value
@@ -174,6 +183,7 @@
 
             supported_protocols: vec![],  // dummy value
             protocol: WireProtocol::Jtag, // dummy value
+            connection_handle: None,
 
             swo_config: None,
             speed_khz: 0, // default is unknown
@@ -188,32 +198,28 @@
 
             max_mem_block_size: 0, // dummy value
             jtag_chunk_size: 0,    // dummy value
+
+            config: JlinkConfig::default(),
         };
         this.fill_capabilities()?;
         this.fill_interfaces()?;
 
         this.supported_protocols = if this.caps.contains(Capability::SelectIf) {
-            let protocols: Vec<_> = this
-                .interfaces
+            this.interfaces
                 .into_iter()
-                .map(WireProtocol::try_from)
-                .collect();
-
-            protocols
-                .iter()
-                .filter(|p| p.is_err())
-                .for_each(|protocol| {
-                    if let Err(JlinkError::UnknownInterface(interface)) = protocol {
+                .filter_map(|p| match WireProtocol::try_from(p) {
+                    Ok(protocol) => Some(protocol),
+                    Err(JlinkError::UnknownInterface(interface)) => {
+                        // We ignore unknown protocols.
                         tracing::debug!(
                             "J-Link returned interface {:?}, which is not supported by probe-rs.",
                             interface
                         );
+                        None
                     }
-                });
-
-            // We ignore unknown protocols, the chance that this happens is pretty low,
-            // and we can just work with the ones we know and support.
-            protocols.into_iter().filter_map(Result::ok).collect()
+                    Err(_) => None,
+                })
+                .collect::<Vec<_>>()
         } else {
             // The J-Link cannot report which interfaces it supports, and cannot
             // switch interfaces. We assume it just supports JTAG.
@@ -255,6 +261,12 @@
             // Assume the lowest value is a safe default
             _ => 504,
         };
+        this.config = this.read_device_config()?;
+        this.connection_handle = match selector.product_id {
+            // 0x1051: J-Link OB-K22-SiFive: reports "hardware fault or protocol violation"
+            0x1051 => None,
+            _ => Some(this.register_connection()?),
+        };
 
         Ok(Box::new(this))
     }
@@ -264,10 +276,17 @@
     }
 }
 
+impl Drop for JLink {
+    fn drop(&mut self) {
+        self.unregister_connection().ok();
+    }
+}
+
 #[repr(u8)]
 #[allow(dead_code)]
 enum Command {
     Version = 0x01,
+    Register = 0x09,
     GetSpeeds = 0xC0,
     GetMaxMemBlock = 0xD4,
     GetCaps = 0xE8,
@@ -320,6 +339,7 @@
 
     read_ep: u8,
     write_ep: u8,
+    max_read_ep_packet: usize,
 
     /// The capabilities reported by the device. They're fetched once, when the device is opened.
     caps: Capabilities,
@@ -331,6 +351,10 @@
     /// when performing target I/O operations.
     interface: Interface,
 
+    /// Device configuration, fetched once when the device is opened.
+    config: JlinkConfig,
+    connection_handle: Option<u16>,
+
     swo_config: Option<SwoConfig>,
 
     /// Protocols supported by the probe.
@@ -376,7 +400,7 @@
 
         let caps = self.read_u32().map(Capabilities::from_raw_legacy)?;
 
-        debug!("legacy caps: {:?}", caps);
+        tracing::debug!("legacy caps: {:?}", caps);
 
         // If the `GET_CAPS_EX` capability is set, use the extended capability command to fetch
         // all the capabilities.
@@ -390,10 +414,10 @@
                     caps, real_caps
                 )));
             }
-            debug!("extended caps: {:?}", real_caps);
+            tracing::debug!("extended caps: {:?}", real_caps);
             self.caps = real_caps;
         } else {
-            debug!("extended caps not supported");
+            tracing::debug!("extended caps not supported");
             self.caps = caps;
         }
 
@@ -417,7 +441,7 @@
     }
 
     fn write_cmd(&self, cmd: &[u8]) -> Result<(), JlinkError> {
-        trace!("write {} bytes: {:x?}", cmd.len(), cmd);
+        tracing::trace!("write {} bytes: {:x?}", cmd.len(), cmd);
 
         let n = self
             .handle
@@ -434,16 +458,34 @@
     }
 
     fn read(&self, buf: &mut [u8]) -> Result<(), JlinkError> {
-        let mut total = 0;
+        let needs_workaround = buf.len() % self.max_read_ep_packet == 0;
+        let len = buf.len();
 
-        while total < buf.len() {
+        let mut tmp_buffer;
+        let dst = if needs_workaround {
+            // For some unknown reason, reading 256 bytes of config data leaves the interface in
+            // an unusable state. Force-reading one more byte works around this issue.
+            tmp_buffer = vec![0; len + 1];
+            &mut tmp_buffer
+        } else {
+            tmp_buffer = vec![];
+            &mut buf[..]
+        };
+
+        let mut total = 0;
+        while total < len {
             let n = self
                 .handle
-                .read_bulk(self.read_ep, &mut buf[total..], TIMEOUT_DEFAULT)?;
+                .read_bulk(self.read_ep, &mut dst[total..], TIMEOUT_DEFAULT)?;
+
             total += n;
         }
 
-        trace!("read {} bytes: {:x?}", buf.len(), buf);
+        if needs_workaround {
+            buf.copy_from_slice(&tmp_buffer[..len]);
+        }
+
+        tracing::trace!("read {total} bytes: {buf:x?}");
 
         Ok(())
     }
@@ -502,6 +544,28 @@
         self.read_u32()
     }
 
+    fn read_device_config(&self) -> Result<JlinkConfig, JlinkError> {
+        if self.caps.contains(Capability::ReadConfig) {
+            self.write_cmd(&[Command::ReadConfig as u8])?;
+            let bytes = self.read_n::<256>()?;
+
+            let config = match JlinkConfig::parse(bytes) {
+                Ok(config) => {
+                    tracing::debug!("J-Link config: {:?}", config);
+                    config
+                }
+                Err(error) => {
+                    tracing::warn!("Failed to parse J-Link config: {error}");
+                    JlinkConfig::default()
+                }
+            };
+
+            Ok(config)
+        } else {
+            Ok(JlinkConfig::default())
+        }
+    }
+
     /// Reads the firmware version string from the device.
     fn read_firmware_version(&self) -> Result<String, JlinkError> {
         self.write_cmd(&[Command::Version as u8])?;
@@ -664,7 +728,7 @@
 
         // Round bit count up to multple of 8 to get the number of response bytes.
         let num_resp_bytes = tms_bit_count.div_ceil(8);
-        trace!(
+        tracing::trace!(
             "{} TMS/TDI bits sent; reading {} response bytes",
             tms_bit_count,
             num_resp_bytes
@@ -780,6 +844,74 @@
         self.require_capability(Capability::SetKsPower)?;
         self.write_cmd(&[Command::SetKsPower as u8, if enable { 1 } else { 0 }])
     }
+
+    fn register_connection(&mut self) -> Result<u16, JlinkError> {
+        if !self.caps.contains(Capability::Register) {
+            return Ok(0);
+        }
+
+        // Undocumented, taken from OpenOCD/libjaylink
+        let mut buf = vec![Command::Register as u8, 0x64];
+        buf.extend(JlinkConnection::usb(0).into_bytes());
+        self.write_cmd(&buf)?;
+
+        let handle = self.read_registration_response()?;
+
+        if handle == 0 {
+            return Err(JlinkError::Other("Invalid registration handle".to_string()));
+        }
+
+        Ok(handle)
+    }
+
+    fn unregister_connection(&mut self) -> Result<(), JlinkError> {
+        if !self.caps.contains(Capability::Register) {
+            return Ok(());
+        }
+
+        if let Some(handle) = self.connection_handle.take() {
+            let mut buf = vec![Command::Register as u8, 0x65];
+            buf.extend(JlinkConnection::usb(handle).into_bytes());
+            self.write_cmd(&buf)?;
+            self.read_registration_response()?;
+        }
+
+        Ok(())
+    }
+
+    fn read_registration_response(&mut self) -> Result<u16, JlinkError> {
+        const REG_HEADER_SIZE: usize = 8;
+        const REG_MIN_SIZE: usize = 76;
+        const REG_MAX_SIZE: usize = 512;
+
+        let mut response = [0; REG_MAX_SIZE];
+        self.read(&mut response[..REG_MIN_SIZE])?;
+
+        let handle = u16::from_le_bytes([response[0], response[1]]);
+        let num = u16::from_le_bytes([response[2], response[3]]) as usize;
+        let entry_size = u16::from_le_bytes([response[4], response[5]]) as usize;
+        let info_size = u16::from_le_bytes([response[6], response[7]]) as usize;
+
+        let table_size = num * entry_size;
+        let size = REG_HEADER_SIZE + table_size + info_size;
+
+        tracing::debug!("Registration response size: {size}");
+
+        if size > REG_MAX_SIZE {
+            return Err(JlinkError::Other(format!(
+                "Maximum registration size exceeded: {size} bytes",
+            )));
+        }
+
+        if size > REG_MIN_SIZE {
+            // Read the rest of the response.
+            self.read(&mut response[REG_MIN_SIZE..size])?;
+        }
+
+        // TODO: we should process the response, and return the list of connections.
+
+        Ok(handle)
+    }
 }
 
 impl DebugProbe for JLink {
@@ -898,6 +1030,14 @@
             }
         }
 
+        self.write_cmd(&[Command::HwReset1 as u8])?;
+        self.write_cmd(&[Command::HwTrst1 as u8])?;
+
+        // Set a default speed if not already set
+        if self.speed_khz == 0 {
+            self.set_speed(400)?;
+        }
+
         tracing::debug!("Attached succesfully");
 
         Ok(())
@@ -1159,7 +1299,7 @@
     }
 }
 
-#[tracing::instrument(skip_all)]
+#[tracing::instrument]
 fn list_jlink_devices() -> Vec<DebugProbeInfo> {
     let Ok(devices) = nusb::list_devices() else {
         return vec![];
diff --git a/probe-rs/src/probe/jlink/speed.rs b/probe-rs/src/probe/jlink/speed.rs
index 2ca98f5..021d286 100644
--- a/probe-rs/src/probe/jlink/speed.rs
+++ b/probe-rs/src/probe/jlink/speed.rs
@@ -26,7 +26,8 @@
     #[allow(unused)]
     pub(crate) fn max_speed_config(&self) -> SpeedConfig {
         let khz = cmp::min(self.max_speed_hz() / 1000, 0xFFFE);
-        SpeedConfig::khz(khz.try_into().unwrap()).unwrap()
+        // khz is guaranteed to be in the range 1..=0xFFFE, so let's skip the constructor
+        SpeedConfig { raw: khz as u16 }
     }
 }
 
@@ -34,7 +35,7 @@
 ///
 /// This determines the clock frequency of the target communication. Supported speeds for the
 /// currently selected target interface can be fetched via [`JLink::read_interface_speeds()`].
-#[derive(Debug, Copy, Clone)]
+#[derive(Debug, Copy, Clone, PartialEq)]
 pub struct SpeedConfig {
     raw: u16,
 }
@@ -50,7 +51,7 @@
     /// Returns `None` if the value is the invalid value `0xFFFF`. Note that this doesn't mean that
     /// every other value will be accepted by the device.
     pub(crate) fn khz(khz: u16) -> Option<Self> {
-        if khz == 0xFFFF {
+        if khz == SpeedConfig::ADAPTIVE.raw {
             None
         } else {
             Some(Self { raw: khz })
@@ -60,7 +61,7 @@
 
 impl fmt::Display for SpeedConfig {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        if self.raw == Self::ADAPTIVE.raw {
+        if *self == Self::ADAPTIVE {
             f.write_str("adaptive")
         } else {
             write!(f, "{} kHz", self.raw)
@@ -101,10 +102,12 @@
     /// any API method that automatically selects an interface), the communication speed is reset to
     /// some unspecified default value.
     pub(super) fn set_interface_clock_speed(&mut self, speed: SpeedConfig) -> Result<()> {
-        if speed.raw == SpeedConfig::ADAPTIVE.raw {
+        if speed == SpeedConfig::ADAPTIVE {
             self.require_capability(Capability::AdaptiveClocking)?;
         }
 
+        tracing::info!("Selecting speed: {} Hz", speed.raw);
+
         let [low, high] = speed.raw.to_le_bytes();
         self.write_cmd(&[Command::SetSpeed as u8, low, high])?;
 
diff --git a/probe-rs/src/probe/jlink/swo.rs b/probe-rs/src/probe/jlink/swo.rs
index c1fc89c..eb58ce1 100644
--- a/probe-rs/src/probe/jlink/swo.rs
+++ b/probe-rs/src/probe/jlink/swo.rs
@@ -8,7 +8,6 @@
 use super::interface::Interface;
 
 use std::{cmp, ops::Deref};
-use tracing::warn;
 
 type Result<T> = std::result::Result<T, JlinkError>;
 
@@ -56,7 +55,7 @@
     fn new(bits: u32) -> Self {
         let flags = bits & Self::ALL_MASK;
         if flags != bits {
-            warn!("Unknown SWO status flag bits: {:#010x}", bits);
+            tracing::warn!("Unknown SWO status flag bits: {:#010x}", bits);
         }
         Self(flags)
     }
@@ -251,7 +250,7 @@
         };
 
         if status.contains(SwoStatus::OVERRUN) {
-            warn!("SWO probe buffer overrun");
+            tracing::warn!("SWO probe buffer overrun");
         }
 
         let buf = &mut data[..length];