aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- include/numpy_utils.h 17
-rw-r--r-- include/tosa_serialization_handler.h 12
-rw-r--r-- python/serializer/tosa_serializer.py 42
-rw-r--r-- src/numpy_utils.cpp 29
-rw-r--r-- src/tosa_serialization_handler.cpp 54
5 files changed, 87 insertions, 67 deletions
diff --git a/include/numpy_utils.h b/include/numpy_utils.h
index 60cf77e..ade2f2d 100644
--- a/include/numpy_utils.h
+++ b/include/numpy_utils.h
@@ -24,8 +24,13 @@
#include <cstring>
#include <vector>
+#include "cfloat.h"
#include "half.hpp"
+using bf16 = ct::cfloat<int16_t, 8, true, true, true>;
+using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>;
+using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>;
+
class NumpyUtilities
{
public:
@@ -85,6 +90,18 @@ public:
{
return "'<f2'";
}
+ if (std::is_same<T, bf16>::value)
+ {
+ return "'<V2'";
+ }
+ if (std::is_same<T, fp8e4m3>::value)
+ {
+ return "'<V1'";
+ }
+ if (std::is_same<T, fp8e5m2>::value)
+ {
+ return "'<f1'";
+ }
assert(false && "unsupported Dtype");
};
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index 139a476..c09a47d 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -412,9 +412,9 @@ public:
tosa_err_t LoadFileSchema(const char* schema_filename);
// data format conversion. little-endian.
- static tosa_err_t ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
- static tosa_err_t ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
- static tosa_err_t ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertBF16toU8(const std::vector<bf16>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertFP8E4M3toU8(const std::vector<fp8e4m3>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertFP8E5M2toU8(const std::vector<fp8e5m2>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
@@ -425,9 +425,9 @@ public:
static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out);
- static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
- static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
- static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
+ static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bf16>& out);
+ static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<fp8e4m3>& out);
+ static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<fp8e5m2>& out);
static tosa_err_t
ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<half_float::half>& out);
static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index c328aaf..7122216 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -17,7 +17,7 @@ import serializer.tosa_serializer as ts
import json
import flatbuffers
import numpy as np
-import struct
+from ml_dtypes import bfloat16, float8_e4m3fn, float8_e5m2
from enum import IntEnum, unique
from tosa import (
TosaGraph,
@@ -392,13 +392,14 @@ class TosaSerializerTensor:
self.shape = shape
self.dtype = dtype
- if (
- dtype == DType.FP32
- or dtype == DType.BF16
- or dtype == DType.FP8E4M3
- or dtype == DType.FP8E5M2
- ):
+ if dtype == DType.FP32:
fntype = np.float32
+ elif dtype == DType.BF16:
+ fntype = bfloat16
+ elif dtype == DType.FP8E4M3:
+ fntype = float8_e4m3fn
+ elif dtype == DType.FP8E5M2:
+ fntype = float8_e5m2
elif dtype == DType.FP16:
fntype = np.float16
else:
@@ -943,35 +944,20 @@ class TosaSerializer:
np_arr = np.array(data, dtype=np.float16)
u8_data.extend(np_arr.view(np.uint8))
elif dtype == DType.FP32:
- # for val in data:
- # b = struct.pack("!f", val)
- # u8_data.extend([b[3], b[2], b[1], b[0]])
np_arr = np.array(data, dtype=np.float32)
u8_data.extend(np_arr.view(np.uint8))
elif dtype == DType.BF16:
for val in data:
- # convert val to little endian byte arrays b
- b = struct.pack("<f", val)
- # val => [ b[3], b[2], b[1], b[0] ]
- # keep only most significant 2 bytes for bf16
- # in little endian ordering
- u8_data.extend([b[2], b[3]])
+ np_arr = np.array([val], dtype=bfloat16)  # convert only this element; using all of `data` here re-appends the whole tensor on every iteration
+ u8_data.extend(np_arr.view(np.uint8))  # the 1-element array views as the 2 bytes of this bf16 value
elif dtype == DType.FP8E4M3:
for val in data:
- # convert val to fp8_bits then to single byte
- f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0]
- f32_bits = f"{f32_as_int:032b}"
- fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12]
- fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little")
- u8_data.extend(fp8_bytes)
+ val_f8 = np.array(val).astype(float8_e4m3fn).view(np.uint8)
+ u8_data.append(val_f8)
elif dtype == DType.FP8E5M2:
for val in data:
- # convert val to fp8_bits then to single byte
- f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0]
- f32_bits = f"{f32_as_int:032b}"
- fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11]
- fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little")
- u8_data.extend(fp8_bytes)
+ val_f8 = np.array(val).astype(float8_e5m2).view(np.uint8)
+ u8_data.append(val_f8)
elif dtype == TosaDType.DType:
# Serialize DType enum data as uint8 bytes
for val in data:
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp
index e4171d7..7cf5f94 100644
--- a/src/numpy_utils.cpp
+++ b/src/numpy_utils.cpp
@@ -247,6 +247,14 @@ NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint3
while (isspace(*ptr))
ptr++;
+ // ml_dtypes writes '<f1' for 'numpy.dtype' in the header for float8_e5m2, but
+ // default NumPy does not understand this notation, which causes trouble
+ // when other code tries to open this file.
+ // To avoid this, '|u1' notation is used when the file is written, and the uint8
+ // data is viewed as float8_e5m2 later when the file is read.
+ if (!strcmp(dtype_str, "'<f1'"))
+ dtype_str = "'|u1'";
+
if (strcmp(ptr, dtype_str))
{
return FILE_TYPE_MISMATCH;
@@ -430,6 +438,13 @@ NumpyUtilities::NPError
memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
headerPos += sizeof(NUMPY_HEADER_STR) - 1;
+ // NumPy does not understand float8_e5m2, so change it to uint8 type, so that
+ // Python can read .npy files.
+ if (!strcmp(dtype_str, "'<f1'"))
+ {
+ dtype_str = "'|u1'";
+ }
+
// Output the format dictionary
// Hard-coded for I32 for now
headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos,
@@ -438,7 +453,19 @@ NumpyUtilities::NPError
// Add shape contents (if any - as this will be empty for rank 0)
for (i = 0; i < shape.size(); i++)
{
- headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
+ // A rank-1 shape is written as "N,"; higher ranks are written as "A, B, ..." with no
+ // trailing comma, matching the NumPy files that tosa_refmodel_sut_run generates.
+ if (i == 0)
+ {
+ if (shape.size() == 1)
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d,", shape[i]);
+ else
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d", shape[i]);
+ }
+ else
+ {
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, ", %d", shape[i]);
+ }
}
// Close off the dictionary
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index 0ce6211..76b2198 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -19,9 +19,6 @@
#include <iostream>
using namespace tosa;
-using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>;
-using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>;
-
TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name,
const flatbuffers::Vector<int32_t>* shape,
DType dtype,
@@ -750,45 +747,41 @@ void TosaSerializationHandler::ForceAlignTensorData(std::vector<uint8_t>& buf)
}
}
-tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
+tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector<bf16>& in, std::vector<uint8_t>& out)
{
// Note: Writes the 16-bit encoding of each bf16 value as two bytes, low byte first (little endian)
out.clear();
for (auto val : in)
{
- uint32_t* val_u32 = reinterpret_cast<uint32_t*>(&val);
- uint8_t f32_byte2 = (*val_u32 >> 16) & 0xFF;
- uint8_t f32_byte3 = (*val_u32 >> 24) & 0xFF;
- // little endian: byte2 followed by byte3
- out.push_back(f32_byte2);
- out.push_back(f32_byte3);
+ uint8_t bf16_byte0 = val.bits() & 0xFF;
+ uint8_t bf16_byte1 = (val.bits() >> 8) & 0xFF;
+ out.push_back(bf16_byte0);
+ out.push_back(bf16_byte1);
}
ForceAlignTensorData(out);
return TOSA_OK;
}
-tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
+tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector<fp8e4m3>& in, std::vector<uint8_t>& out)
{
// Note: Emits the raw 8-bit encoding of each FP8E4M3 value as a single uint8_t
out.clear();
for (auto val : in)
{
- auto f8 = static_cast<fp8e4m3>(val);
- uint8_t b8 = f8.bits();
+ uint8_t b8 = val.bits();
out.push_back(b8);
}
ForceAlignTensorData(out);
return TOSA_OK;
}
-tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
+tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector<fp8e5m2>& in, std::vector<uint8_t>& out)
{
// Note: Emits the raw 8-bit encoding of each FP8E5M2 value as a single uint8_t
out.clear();
for (auto val : in)
{
- auto f8 = static_cast<fp8e5m2>(val);
- uint8_t b8 = f8.bits();
+ uint8_t b8 = val.bits();
out.push_back(b8);
}
ForceAlignTensorData(out);
@@ -944,9 +937,8 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in
return TOSA_OK;
}
-tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& in,
- uint32_t out_size,
- std::vector<float>& out)
+tosa_err_t
+ TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bf16>& out)
{
// Note: values are returned in the native bf16 type
out.clear();
@@ -964,17 +956,17 @@ tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>&
uint32_t val_u32 = (f32_byte2 << 16) + (f32_byte3 << 24);
// Reinterpret u32 bytes as fp32
- float val_f32 = *(float*)&val_u32;
- out.push_back(val_f32);
+ float val_f32 = *(float*)&val_u32;
+ bf16 val_bf16 = static_cast<bf16>(val_f32);    // keep the native bf16 type; avoids a bf16->float->bf16 round trip
+ out.push_back(val_bf16);
}
return TOSA_OK;
}
tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector<uint8_t>& in,
uint32_t out_size,
- std::vector<float>& out)
+ std::vector<fp8e4m3>& out)
{
- // Note: FP8E4M3 values returned in fp32 type
out.clear();
if (in.size() < out_size * sizeof(int8_t))
{
@@ -985,17 +977,16 @@ tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector<uint8_
for (uint32_t i = 0; i < out_size; i++)
{
- int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]);
- auto f8 = fp8e4m3::from_bits(bits);
- float val_f32 = static_cast<float>(f8);
- out.push_back(val_f32);
+ int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]);
+ auto f8 = fp8e4m3::from_bits(bits);
+ out.push_back(f8);
}
return TOSA_OK;
}
tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector<uint8_t>& in,
uint32_t out_size,
- std::vector<float>& out)
+ std::vector<fp8e5m2>& out)
{
// Note: values are returned in the native fp8e5m2 type
out.clear();
@@ -1008,10 +999,9 @@ tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector<uint8_
for (uint32_t i = 0; i < out_size; i++)
{
- int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]);
- auto f8 = fp8e5m2::from_bits(bits);
- float val_f32 = static_cast<float>(f8);
- out.push_back(val_f32);
+ int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]);
+ auto f8 = fp8e5m2::from_bits(bits);
+ out.push_back(f8);
}
return TOSA_OK;
}