diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/numpy_utils.cpp | 29 | ||||
-rw-r--r-- | src/tosa_serialization_handler.cpp | 68 |
2 files changed, 56 insertions, 41 deletions
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp index e4171d7..7cf5f94 100644 --- a/src/numpy_utils.cpp +++ b/src/numpy_utils.cpp @@ -247,6 +247,14 @@ NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint3 while (isspace(*ptr)) ptr++; + // ml_dtypes writes '<f1' for 'numpy.dtype' in the header for float8_e5m2, but + // default NumPy does not understand this notation, which causes trouble + // when other code tries to open this file. + // To avoid this, '|u1' notation is used when the file is written, and the uint8 + // data is viewed as float8_e5m2 later when the file is read. + if (!strcmp(dtype_str, "'<f1'")) + dtype_str = "'|u1'"; + if (strcmp(ptr, dtype_str)) { return FILE_TYPE_MISMATCH; @@ -430,6 +438,13 @@ NumpyUtilities::NPError memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1); headerPos += sizeof(NUMPY_HEADER_STR) - 1; + // NumPy does not understand float8_e5m2, so change it to uint8 type, so that + // Python can read .npy files. + if (!strcmp(dtype_str, "'<f1'")) + { + dtype_str = "'|u1'"; + } + // Output the format dictionary // Hard-coded for I32 for now headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, @@ -438,7 +453,19 @@ NumpyUtilities::NPError // Add shape contents (if any - as this will be empty for rank 0) for (i = 0; i < shape.size(); i++) { - headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]); + // Output NumPy file from tosa_refmodel_sut_run generates the shape information + // without a trailing comma when the rank is greater than 1. + if (i == 0) + { + if (shape.size() == 1) + headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d,", shape[i]); + else + headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d", shape[i]); + } + else + { + headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, ", %d", shape[i]); + } } // Close off the dictionary diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index 0ce6211..74f66d8 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -19,9 +19,6 @@ #include <iostream> using namespace tosa; -using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>; -using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>; - TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name, const flatbuffers::Vector<int32_t>* shape, DType dtype, @@ -750,45 +747,41 @@ void TosaSerializationHandler::ForceAlignTensorData(std::vector<uint8_t>& buf) } } -tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out) +tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector<bf16>& in, std::vector<uint8_t>& out) { // Note: Converts fp32->bf16 by ignoring the least significant 16 bits out.clear(); for (auto val : in) { - uint32_t* val_u32 = reinterpret_cast<uint32_t*>(&val); - uint8_t f32_byte2 = (*val_u32 >> 16) & 0xFF; - uint8_t f32_byte3 = (*val_u32 >> 24) & 0xFF; - // little endian: byte2 followed by byte3 - out.push_back(f32_byte2); - out.push_back(f32_byte3); + uint8_t bf16_byte0 = val.bits() & 0xFF; + uint8_t bf16_byte1 = (val.bits() >> 8) & 0xFF; + out.push_back(bf16_byte0); + out.push_back(bf16_byte1); } ForceAlignTensorData(out); return TOSA_OK; } -tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out) +tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector<fp8e4m3>& in, std::vector<uint8_t>& out) { // Note: Converts fp32->FP8E4M3 before converting to unint8_t out.clear(); for (auto val : in) { - auto f8 = static_cast<fp8e4m3>(val); - uint8_t b8 = f8.bits(); + uint8_t b8 = val.bits(); out.push_back(b8); } ForceAlignTensorData(out); return TOSA_OK; } -tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out) +tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector<fp8e5m2>& in, std::vector<uint8_t>& out) { // Note: Converts fp32->FP8E5M2 before converting to uint8_t out.clear(); for (auto val : in) { - auto f8 = static_cast<fp8e5m2>(val); - uint8_t b8 = f8.bits(); + uint8_t b8 = val.bits(); out.push_back(b8); } ForceAlignTensorData(out); @@ -944,11 +937,9 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in return TOSA_OK; } -tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& in, - uint32_t out_size, - std::vector<float>& out) +tosa_err_t + TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bf16>& out) { - // Note: bf16 values returned in fp32 type out.clear(); if (in.size() < out_size * sizeof(int16_t)) { @@ -959,22 +950,21 @@ tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& for (uint32_t i = 0; i < out_size; i++) { - uint32_t f32_byte2 = in[i * sizeof(int16_t)]; - uint32_t f32_byte3 = in[i * sizeof(int16_t) + 1]; - uint32_t val_u32 = (f32_byte2 << 16) + (f32_byte3 << 24); + uint8_t bf16_byte0 = in[i * sizeof(int16_t)]; + uint8_t bf16_byte1 = in[i * sizeof(int16_t) + 1]; + uint16_t val_u16 = (bf16_byte0) + (bf16_byte1 << 8); - // Reinterpret u32 bytes as fp32 - float val_f32 = *(float*)&val_u32; - out.push_back(val_f32); + // Reinterpret u16 bytes as bf16 + bf16 val_bf16 = *(bf16*)&val_u16; + out.push_back(val_bf16); } return TOSA_OK; } tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, - std::vector<float>& out) + std::vector<fp8e4m3>& out) { - // Note: FP8E4M3 values returned in fp32 type out.clear(); if (in.size() < out_size * sizeof(int8_t)) { @@ -985,17 +975,16 @@ tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector<uint8_ for (uint32_t i = 0; i < out_size; i++) { - int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]); - auto f8 = fp8e4m3::from_bits(bits); - float val_f32 = static_cast<float>(f8); - out.push_back(val_f32); + int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]); + auto f8 = fp8e4m3::from_bits(bits); + out.push_back(f8); } return TOSA_OK; } tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, - std::vector<float>& out) + std::vector<fp8e5m2>& out) { // Note: FP8E5M2 values returned in fp32 type out.clear(); @@ -1008,10 +997,9 @@ tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector<uint8_ for (uint32_t i = 0; i < out_size; i++) { - int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]); - auto f8 = fp8e5m2::from_bits(bits); - float val_f32 = static_cast<float>(f8); - out.push_back(val_f32); + int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]); + auto f8 = fp8e5m2::from_bits(bits); + out.push_back(f8); } return TOSA_OK; } @@ -1031,9 +1019,9 @@ tosa_err_t TosaSerializationHandler::ConvertU8toF16(const std::vector<uint8_t>& for (uint32_t i = 0; i < out_size; i++) { - uint16_t f16_byte0 = in[i * sizeof(int16_t)]; - uint16_t f16_byte1 = in[i * sizeof(int16_t) + 1]; - uint16_t val_u16 = f16_byte0 + (f16_byte1 << 8); + uint8_t f16_byte0 = in[i * sizeof(int16_t)]; + uint8_t f16_byte1 = in[i * sizeof(int16_t) + 1]; + uint16_t val_u16 = f16_byte0 + (f16_byte1 << 8); // Reinterpret u16 byte as fp16 then convert to fp32 half_float::half val_f16 = *(half_float::half*)&val_u16; |