aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/tosa_serialization_handler.cpp36
1 files changed, 25 insertions, 11 deletions
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index 76b2198..eb71b69 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -15,6 +15,7 @@
#include "tosa_serialization_handler.h"
#include "half.hpp"
+#include "tosa_schema.h"
#include <iostream>
using namespace tosa;
@@ -235,6 +236,21 @@ tosa_err_t TosaSerializationHandler::LoadFileSchema(const char* schema_filename)
return TOSA_OK;
}
+tosa_err_t TosaSerializationHandler::LoadTosaSchema()
+{
+ bool ok;
+ ok = _parser.Parse(TOSA_SCHEMA);
+
+ if (!ok)
+ {
+ printf("Error parsing ISA schema contents \n");
+ return TOSA_FILE_ERROR;
+ }
+
+ _schemaLoaded = true;
+ return TOSA_OK;
+}
+
tosa_err_t TosaSerializationHandler::LoadFileJson(const char* filename)
{
std::string jsonfile;
@@ -940,7 +956,6 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in
tosa_err_t
TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bf16>& out)
{
- // Note: bf16 values returned in fp32 type
out.clear();
if (in.size() < out_size * sizeof(int16_t))
{
@@ -951,13 +966,12 @@ tosa_err_t
for (uint32_t i = 0; i < out_size; i++)
{
- uint32_t f32_byte2 = in[i * sizeof(int16_t)];
- uint32_t f32_byte3 = in[i * sizeof(int16_t) + 1];
- uint32_t val_u32 = (f32_byte2 << 16) + (f32_byte3 << 24);
+ uint8_t bf16_byte0 = in[i * sizeof(int16_t)];
+ uint8_t bf16_byte1 = in[i * sizeof(int16_t) + 1];
+ uint16_t val_u16 = (bf16_byte0) + (bf16_byte1 << 8);
- // Reinterpret u32 bytes as fp32
- float val_f32 = *(float*)&val_u32;
- float val_bf16 = static_cast<bf16>(val_f32);
+ // Reinterpret u16 bytes as bf16
+ bf16 val_bf16 = *(bf16*)&val_u16;
out.push_back(val_bf16);
}
return TOSA_OK;
@@ -1021,9 +1035,9 @@ tosa_err_t TosaSerializationHandler::ConvertU8toF16(const std::vector<uint8_t>&
for (uint32_t i = 0; i < out_size; i++)
{
- uint16_t f16_byte0 = in[i * sizeof(int16_t)];
- uint16_t f16_byte1 = in[i * sizeof(int16_t) + 1];
- uint16_t val_u16 = f16_byte0 + (f16_byte1 << 8);
+ uint8_t f16_byte0 = in[i * sizeof(int16_t)];
+ uint8_t f16_byte1 = in[i * sizeof(int16_t) + 1];
+ uint16_t val_u16 = f16_byte0 + (f16_byte1 << 8);
// Reinterpret u16 byte as fp16 then convert to fp32
half_float::half val_f16 = *(half_float::half*)&val_u16;
@@ -1191,7 +1205,7 @@ tosa_err_t
out_size, in.size());
return TOSA_USER_ERROR;
}
- for (size_t i = 0; i < in.size(); i++)
+ for (size_t i = 0; 2 * i < out_size; i++)
{
uint8_t val_u8 = in[i];
uint8_t val_0_u4 = val_u8 & 0xF;