diff options
Diffstat (limited to 'include/numpy_utils.h')
-rw-r--r-- | include/numpy_utils.h | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/include/numpy_utils.h b/include/numpy_utils.h index 60cf77e..ade2f2d 100644 --- a/include/numpy_utils.h +++ b/include/numpy_utils.h @@ -24,8 +24,13 @@ #include <cstring> #include <vector> +#include "cfloat.h" #include "half.hpp" +using bf16 = ct::cfloat<int16_t, 8, true, true, true>; +using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>; +using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>; + class NumpyUtilities { public: @@ -85,6 +90,18 @@ public: { return "'<f2'"; } + if (std::is_same<T, bf16>::value) + { + return "'<V2'"; + } + if (std::is_same<T, fp8e4m3>::value) + { + return "'<V1'"; + } + if (std::is_same<T, fp8e5m2>::value) + { + return "'<f1'"; + } assert(false && "unsupported Dtype"); }; |