aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/subgraph_traverser.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r--reference_model/src/subgraph_traverser.cc20
1 files changed, 16 insertions, 4 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 52b1806..6aa0a45 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -580,7 +580,7 @@ int SubgraphTraverser::allocateTensor(std::string name)
break;
case DType_BF16: {
std::vector<float> fp32_data;
- TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
+ TosaSerializationHandler::ConvertU8toBF16(ts->GetData(), tensor->getElementCount(), fp32_data);
// Ensure valid bfloat16 stored in each float
for (auto f : fp32_data)
ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f);
@@ -595,11 +595,23 @@ int SubgraphTraverser::allocateTensor(std::string name)
}
}
break;
- case DType_FP8E4M3:
+ case DType_FP8E4M3: {
+ std::vector<float> fp32_data;
+ TosaSerializationHandler::ConvertU8toFP8E4M3(ts->GetData(), tensor->getElementCount(), fp32_data);
+ if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
+ {
+ std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
+ tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
+ }
+ else
+ {
+ tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
+ }
+ }
+ break;
case DType_FP8E5M2: {
std::vector<float> fp32_data;
- TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
- // Ensure valid fp8 stored in each float
+ TosaSerializationHandler::ConvertU8toFP8E5M2(ts->GetData(), tensor->getElementCount(), fp32_data);
if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
{
std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());