diff options
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r-- | reference_model/src/subgraph_traverser.cc | 20 |
1 files changed, 16 insertions, 4 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 52b1806..6aa0a45 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -580,7 +580,7 @@ int SubgraphTraverser::allocateTensor(std::string name) break; case DType_BF16: { std::vector<float> fp32_data; - TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); + TosaSerializationHandler::ConvertU8toBF16(ts->GetData(), tensor->getElementCount(), fp32_data); // Ensure valid bfloat16 stored in each float for (auto f : fp32_data) ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f); @@ -595,11 +595,23 @@ int SubgraphTraverser::allocateTensor(std::string name) } } break; - case DType_FP8E4M3: + case DType_FP8E4M3: { + std::vector<float> fp32_data; + TosaSerializationHandler::ConvertU8toFP8E4M3(ts->GetData(), tensor->getElementCount(), fp32_data); + if (tensor->getDtype() == TOSA_REF_TYPE_FP64) + { + std::vector<double> f64_data(fp32_data.begin(), fp32_data.end()); + tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); + } + else + { + tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); + } + } + break; case DType_FP8E5M2: { std::vector<float> fp32_data; - TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); - // Ensure valid fp8 stored in each float + TosaSerializationHandler::ConvertU8toFP8E5M2(ts->GetData(), tensor->getElementCount(), fp32_data); if (tensor->getDtype() == TOSA_REF_TYPE_FP64) { std::vector<double> f64_data(fp32_data.begin(), fp32_data.end()); |