aboutsummaryrefslogtreecommitdiff
path: root/include/numpy_utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/numpy_utils.h')
-rw-r--r--include/numpy_utils.h17
1 files changed, 17 insertions, 0 deletions
diff --git a/include/numpy_utils.h b/include/numpy_utils.h
index 60cf77e..ade2f2d 100644
--- a/include/numpy_utils.h
+++ b/include/numpy_utils.h
@@ -24,8 +24,13 @@
#include <cstring>
#include <vector>
+#include "cfloat.h"
#include "half.hpp"
+using bf16 = ct::cfloat<int16_t, 8, true, true, true>;
+using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>;
+using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>;
+
class NumpyUtilities
{
public:
@@ -85,6 +90,18 @@ public:
{
return "'<f2'";
}
+ if (std::is_same<T, bf16>::value)
+ {
+ return "'<V2'";
+ }
+ if (std::is_same<T, fp8e4m3>::value)
+ {
+ return "'<V1'";
+ }
+ if (std::is_same<T, fp8e5m2>::value)
+ {
+ return "'<f1'";
+ }
assert(false && "unsupported Dtype");
};