Add dtype::normalized_num and dtype::num_of (#5429)

* Add dtype::normalized_num and dtype::num_of

* Fix compiler warning and improve NumPy 1.x compatibility

* Fix clang-tidy warning

* Fix another clang-tidy warning

* Add extra comment
diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h
index 228e02c..ab224e1 100644
--- a/include/pybind11/numpy.h
+++ b/include/pybind11/numpy.h
@@ -212,6 +212,7 @@
 }
 
 struct npy_api {
+    // If you change this code, please review `normalized_dtype_num` below.
     enum constants {
         NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
         NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
@@ -384,6 +385,74 @@
     }
 };
 
+// This table normalizes typenums, mapping NPY_INT_, NPY_LONG_, ... to NPY_INT32_, NPY_INT64_, ...
+// This is needed to correctly handle situations where multiple typenums map to the same type,
+// e.g. NPY_LONG_ may be equivalent to NPY_INT_ or NPY_LONGLONG_ despite having a different
+// typenum. The normalized typenum should always match the values used in npy_format_descriptor.
+// If you change this code, please review `enum constants` above.
+static constexpr int normalized_dtype_num[npy_api::NPY_VOID_ + 1] = {
+    // NPY_BOOL_ =>
+    npy_api::NPY_BOOL_,
+    // NPY_BYTE_ =>
+    npy_api::NPY_BYTE_,
+    // NPY_UBYTE_ =>
+    npy_api::NPY_UBYTE_,
+    // NPY_SHORT_ =>
+    npy_api::NPY_INT16_,
+    // NPY_USHORT_ =>
+    npy_api::NPY_UINT16_,
+    // NPY_INT_ =>
+    sizeof(int) == sizeof(std::int16_t)   ? npy_api::NPY_INT16_
+    : sizeof(int) == sizeof(std::int32_t) ? npy_api::NPY_INT32_
+    : sizeof(int) == sizeof(std::int64_t) ? npy_api::NPY_INT64_
+                                          : npy_api::NPY_INT_,
+    // NPY_UINT_ =>
+    sizeof(unsigned int) == sizeof(std::uint16_t)   ? npy_api::NPY_UINT16_
+    : sizeof(unsigned int) == sizeof(std::uint32_t) ? npy_api::NPY_UINT32_
+    : sizeof(unsigned int) == sizeof(std::uint64_t) ? npy_api::NPY_UINT64_
+                                                    : npy_api::NPY_UINT_,
+    // NPY_LONG_ =>
+    sizeof(long) == sizeof(std::int16_t)   ? npy_api::NPY_INT16_
+    : sizeof(long) == sizeof(std::int32_t) ? npy_api::NPY_INT32_
+    : sizeof(long) == sizeof(std::int64_t) ? npy_api::NPY_INT64_
+                                           : npy_api::NPY_LONG_,
+    // NPY_ULONG_ =>
+    sizeof(unsigned long) == sizeof(std::uint16_t)   ? npy_api::NPY_UINT16_
+    : sizeof(unsigned long) == sizeof(std::uint32_t) ? npy_api::NPY_UINT32_
+    : sizeof(unsigned long) == sizeof(std::uint64_t) ? npy_api::NPY_UINT64_
+                                                     : npy_api::NPY_ULONG_,
+    // NPY_LONGLONG_ =>
+    sizeof(long long) == sizeof(std::int16_t)   ? npy_api::NPY_INT16_
+    : sizeof(long long) == sizeof(std::int32_t) ? npy_api::NPY_INT32_
+    : sizeof(long long) == sizeof(std::int64_t) ? npy_api::NPY_INT64_
+                                                : npy_api::NPY_LONGLONG_,
+    // NPY_ULONGLONG_ =>
+    sizeof(unsigned long long) == sizeof(std::uint16_t)   ? npy_api::NPY_UINT16_
+    : sizeof(unsigned long long) == sizeof(std::uint32_t) ? npy_api::NPY_UINT32_
+    : sizeof(unsigned long long) == sizeof(std::uint64_t) ? npy_api::NPY_UINT64_
+                                                          : npy_api::NPY_ULONGLONG_,
+    // NPY_FLOAT_ =>
+    npy_api::NPY_FLOAT_,
+    // NPY_DOUBLE_ =>
+    npy_api::NPY_DOUBLE_,
+    // NPY_LONGDOUBLE_ =>
+    npy_api::NPY_LONGDOUBLE_,
+    // NPY_CFLOAT_ =>
+    npy_api::NPY_CFLOAT_,
+    // NPY_CDOUBLE_ =>
+    npy_api::NPY_CDOUBLE_,
+    // NPY_CLONGDOUBLE_ =>
+    npy_api::NPY_CLONGDOUBLE_,
+    // NPY_OBJECT_ =>
+    npy_api::NPY_OBJECT_,
+    // NPY_STRING_ =>
+    npy_api::NPY_STRING_,
+    // NPY_UNICODE_ =>
+    npy_api::NPY_UNICODE_,
+    // NPY_VOID_ =>
+    npy_api::NPY_VOID_,
+};
+
 inline PyArray_Proxy *array_proxy(void *ptr) { return reinterpret_cast<PyArray_Proxy *>(ptr); }
 
 inline const PyArray_Proxy *array_proxy(const void *ptr) {
@@ -684,6 +753,13 @@
         return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
     }
 
+    /// Return the type number associated with a C++ type.
+    /// This is the constexpr equivalent of `dtype::of<T>().num()`.
+    template <typename T>
+    static constexpr int num_of() {
+        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::value;
+    }
+
     /// Size of the data type in bytes.
 #ifdef PYBIND11_NUMPY_1_ONLY
     ssize_t itemsize() const { return detail::array_descriptor_proxy(m_ptr)->elsize; }
@@ -725,7 +801,9 @@
         return detail::array_descriptor_proxy(m_ptr)->type;
     }
 
-    /// type number of dtype.
+    /// Type number of dtype. Note that different values may be returned for equivalent types,
+    /// e.g. even though ``long`` may be equivalent to ``int`` or ``long long``, they still have
+    /// different type numbers. Consider using `normalized_num` to avoid this.
     int num() const {
         // Note: The signature, `dtype::num` follows the naming of NumPy's public
         // Python API (i.e., ``dtype.num``), rather than its internal
@@ -733,6 +811,17 @@
         return detail::array_descriptor_proxy(m_ptr)->type_num;
     }
 
+    /// Type number of this dtype after normalization, i.e. the value that `num_of` returns for
+    /// an equivalent C++ type. Use it in switch statements so that equivalent types with
+    /// different raw type numbers end up in the same case.
+    int normalized_num() const {
+        const int raw = num();
+        // The lookup table only covers typenums 0..NPY_VOID_ (inclusive);
+        // anything outside that range is passed through unchanged.
+        const bool in_table = raw >= 0 && raw <= detail::npy_api::NPY_VOID_;
+        return in_table ? detail::normalized_dtype_num[raw] : raw;
+    }
+
     /// Single character for byteorder
     char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; }
 
diff --git a/tests/test_numpy_dtypes.cpp b/tests/test_numpy_dtypes.cpp
index 596d902..b6db439 100644
--- a/tests/test_numpy_dtypes.cpp
+++ b/tests/test_numpy_dtypes.cpp
@@ -11,6 +11,9 @@
 
 #include "pybind11_tests.h"
 
+#include <cstdint>
+#include <stdexcept>
+
 #ifdef __GNUC__
 #    define PYBIND11_PACKED(cls) cls __attribute__((__packed__))
 #else
@@ -297,6 +300,15 @@
     return list;
 }
 
+template <typename T>
+py::array_t<T> dispatch_array_increment(py::array_t<T> arr) { // returns a copy with +1 applied
+    py::array_t<T> incremented(arr.shape(0)); // sized from the input's first dimension
+    for (py::ssize_t idx = 0; idx < arr.shape(0); ++idx) {
+        incremented.mutable_at(idx) = static_cast<T>(arr.at(idx) + 1);
+    }
+    return incremented;
+}
+
 struct A {};
 struct B {};
 
@@ -496,6 +508,98 @@
         }
         return list;
     });
+    m.def("test_dtype_num_of", []() -> py::list {
+        py::list res; // one (dtype::of<T>().num(), dtype::num_of<T>()) tuple per tested type
+#define TEST_DTYPE(T) res.append(py::make_tuple(py::dtype::of<T>().num(), py::dtype::num_of<T>()));
+        TEST_DTYPE(bool)
+        TEST_DTYPE(char)
+        TEST_DTYPE(unsigned char)
+        TEST_DTYPE(short)
+        TEST_DTYPE(unsigned short)
+        TEST_DTYPE(int)
+        TEST_DTYPE(unsigned int)
+        TEST_DTYPE(long)
+        TEST_DTYPE(unsigned long)
+        TEST_DTYPE(long long)
+        TEST_DTYPE(unsigned long long)
+        TEST_DTYPE(float)
+        TEST_DTYPE(double)
+        TEST_DTYPE(long double)
+        TEST_DTYPE(std::complex<float>)
+        TEST_DTYPE(std::complex<double>)
+        TEST_DTYPE(std::complex<long double>)
+        TEST_DTYPE(int8_t)
+        TEST_DTYPE(uint8_t)
+        TEST_DTYPE(int16_t)
+        TEST_DTYPE(uint16_t)
+        TEST_DTYPE(int32_t)
+        TEST_DTYPE(uint32_t)
+        TEST_DTYPE(int64_t)
+        TEST_DTYPE(uint64_t)
+#undef TEST_DTYPE
+        return res;
+    });
+    m.def("test_dtype_normalized_num", []() -> py::list {
+        py::list res; // one (dtype(NT).normalized_num(), dtype::num_of<T>()) tuple per pairing
+#define TEST_DTYPE(NT, T)                                                                         \
+    res.append(py::make_tuple(py::dtype(py::detail::npy_api::NT).normalized_num(),                \
+                              py::dtype::num_of<T>()))
+        TEST_DTYPE(NPY_BOOL_, bool);
+        TEST_DTYPE(NPY_BYTE_, signed char); // not plain char: its signedness is platform-defined
+        TEST_DTYPE(NPY_UBYTE_, unsigned char);
+        TEST_DTYPE(NPY_SHORT_, short);
+        TEST_DTYPE(NPY_USHORT_, unsigned short);
+        TEST_DTYPE(NPY_INT_, int);
+        TEST_DTYPE(NPY_UINT_, unsigned int);
+        TEST_DTYPE(NPY_LONG_, long);
+        TEST_DTYPE(NPY_ULONG_, unsigned long);
+        TEST_DTYPE(NPY_LONGLONG_, long long);
+        TEST_DTYPE(NPY_ULONGLONG_, unsigned long long);
+        TEST_DTYPE(NPY_FLOAT_, float);
+        TEST_DTYPE(NPY_DOUBLE_, double);
+        TEST_DTYPE(NPY_LONGDOUBLE_, long double);
+        TEST_DTYPE(NPY_CFLOAT_, std::complex<float>);
+        TEST_DTYPE(NPY_CDOUBLE_, std::complex<double>);
+        TEST_DTYPE(NPY_CLONGDOUBLE_, std::complex<long double>);
+        TEST_DTYPE(NPY_INT8_, int8_t);
+        TEST_DTYPE(NPY_UINT8_, uint8_t);
+        TEST_DTYPE(NPY_INT16_, int16_t);
+        TEST_DTYPE(NPY_UINT16_, uint16_t);
+        TEST_DTYPE(NPY_INT32_, int32_t);
+        TEST_DTYPE(NPY_UINT32_, uint32_t);
+        TEST_DTYPE(NPY_INT64_, int64_t);
+        TEST_DTYPE(NPY_UINT64_, uint64_t);
+#undef TEST_DTYPE
+        return res;
+    });
+    m.def("test_dtype_switch", [](const py::array &arr) -> py::array {
+        switch (arr.dtype().normalized_num()) { // normalized to line up with num_of<T>() labels
+            case py::dtype::num_of<int8_t>():
+                return dispatch_array_increment<int8_t>(arr);
+            case py::dtype::num_of<uint8_t>():
+                return dispatch_array_increment<uint8_t>(arr);
+            case py::dtype::num_of<int16_t>():
+                return dispatch_array_increment<int16_t>(arr);
+            case py::dtype::num_of<uint16_t>():
+                return dispatch_array_increment<uint16_t>(arr);
+            case py::dtype::num_of<int32_t>():
+                return dispatch_array_increment<int32_t>(arr);
+            case py::dtype::num_of<uint32_t>():
+                return dispatch_array_increment<uint32_t>(arr);
+            case py::dtype::num_of<int64_t>():
+                return dispatch_array_increment<int64_t>(arr);
+            case py::dtype::num_of<uint64_t>():
+                return dispatch_array_increment<uint64_t>(arr);
+            case py::dtype::num_of<float>():
+                return dispatch_array_increment<float>(arr);
+            case py::dtype::num_of<double>():
+                return dispatch_array_increment<double>(arr);
+            case py::dtype::num_of<long double>():
+                return dispatch_array_increment<long double>(arr);
+            default: // normalized typenum matched none of the labels above
+                throw std::runtime_error("Unsupported dtype");
+        }
+    });
     m.def("test_dtype_methods", []() {
         py::list list;
         auto dt1 = py::dtype::of<int32_t>();
diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py
index 8ae239e..5d83993 100644
--- a/tests/test_numpy_dtypes.py
+++ b/tests/test_numpy_dtypes.py
@@ -188,6 +188,28 @@
             chr(np.dtype(ch).flags) for ch in expected_chars
         ]
 
+    for a, b in m.test_dtype_num_of():
+        assert a == b
+
+    for a, b in m.test_dtype_normalized_num():
+        assert a == b
+
+    arr = np.array([4, 84, 21, 36])
+    # Note: "ulong" does not work in NumPy 1.x, so we use "L"
+    assert (m.test_dtype_switch(arr.astype("byte")) == arr + 1).all()
+    assert (m.test_dtype_switch(arr.astype("ubyte")) == arr + 1).all()
+    assert (m.test_dtype_switch(arr.astype("short")) == arr + 1).all()
+    assert (m.test_dtype_switch(arr.astype("ushort")) == arr + 1).all()
+    assert (m.test_dtype_switch(arr.astype("intc")) == arr + 1).all()
+    assert (m.test_dtype_switch(arr.astype("uintc")) == arr + 1).all()
+    assert (m.test_dtype_switch(arr.astype("long")) == arr + 1).all()
+    assert (m.test_dtype_switch(arr.astype("L")) == arr + 1).all()
+    assert (m.test_dtype_switch(arr.astype("longlong")) == arr + 1).all()
+    assert (m.test_dtype_switch(arr.astype("ulonglong")) == arr + 1).all()
+    assert (m.test_dtype_switch(arr.astype("single")) == arr + 1).all()
+    assert (m.test_dtype_switch(arr.astype("double")) == arr + 1).all()
+    assert (m.test_dtype_switch(arr.astype("longdouble")) == arr + 1).all()
+
 
 def test_recarray(simple_dtype, packed_dtype):
     elements = [(False, 0, 0.0, -0.0), (True, 1, 1.5, -2.5), (False, 2, 3.0, -5.0)]