aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/subgraph_traverser.cc
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-04-05 01:19:31 +0000
committerTai Ly <tai.ly@arm.com>2024-04-15 14:28:29 +0000
commit5d0e9c7f3748e80d6f14a3eeaef858eeb912e1fd (patch)
treed9dddba756207cee68b948d434502801be93d6c4 /reference_model/src/subgraph_traverser.cc
parent6dc755bf141726a7582ad1a844f97cb3f50c9b21 (diff)
downloadreference_model-5d0e9c7f3748e80d6f14a3eeaef858eeb912e1fd.tar.gz
[ref model] fix const/pad/clamp attribute serialization
This changes to use native type serialization and deserialization for pad_const, clamp min_val/max_val and const data attribute values whereby fp16 values are stored as 2 bytes each, fp8 values are stored in 1 byte each, etc. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Ia95d320fe8c546ce1d1ccc035d6e9bcaadcc9ca3
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r--reference_model/src/subgraph_traverser.cc20
1 files changed, 16 insertions, 4 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 52b1806..6aa0a45 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -580,7 +580,7 @@ int SubgraphTraverser::allocateTensor(std::string name)
break;
case DType_BF16: {
std::vector<float> fp32_data;
- TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
+ TosaSerializationHandler::ConvertU8toBF16(ts->GetData(), tensor->getElementCount(), fp32_data);
// Ensure valid bfloat16 stored in each float
for (auto f : fp32_data)
ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f);
@@ -595,11 +595,23 @@ int SubgraphTraverser::allocateTensor(std::string name)
}
}
break;
- case DType_FP8E4M3:
+ case DType_FP8E4M3: {
+ std::vector<float> fp32_data;
+ TosaSerializationHandler::ConvertU8toFP8E4M3(ts->GetData(), tensor->getElementCount(), fp32_data);
+ if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
+ {
+ std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
+ tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
+ }
+ else
+ {
+ tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
+ }
+ }
+ break;
case DType_FP8E5M2: {
std::vector<float> fp32_data;
- TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
- // Ensure valid fp8 stored in each float
+ TosaSerializationHandler::ConvertU8toFP8E5M2(ts->GetData(), tensor->getElementCount(), fp32_data);
if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
{
std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());