diff options
author | Tai Ly <tai.ly@arm.com> | 2024-04-05 01:19:31 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2024-04-15 14:28:29 +0000 |
commit | 5d0e9c7f3748e80d6f14a3eeaef858eeb912e1fd (patch) | |
tree | d9dddba756207cee68b948d434502801be93d6c4 /reference_model/src/subgraph_traverser.cc | |
parent | 6dc755bf141726a7582ad1a844f97cb3f50c9b21 (diff) | |
download | reference_model-5d0e9c7f3748e80d6f14a3eeaef858eeb912e1fd.tar.gz |
[ref model] fix const/pad/clamp attribute serialization
This changes to use native type serialization and deserialization
for pad_const, clamp min_val/max_val and const data attribute values
whereby fp16 values are stored as 2 bytes each, fp8 values are stored
in 1 byte each, etc.
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: Ia95d320fe8c546ce1d1ccc035d6e9bcaadcc9ca3
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r-- | reference_model/src/subgraph_traverser.cc | 20 |
1 files changed, 16 insertions, 4 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 52b1806..6aa0a45 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -580,7 +580,7 @@ int SubgraphTraverser::allocateTensor(std::string name) break; case DType_BF16: { std::vector<float> fp32_data; - TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); + TosaSerializationHandler::ConvertU8toBF16(ts->GetData(), tensor->getElementCount(), fp32_data); // Ensure valid bfloat16 stored in each float for (auto f : fp32_data) ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f); @@ -595,11 +595,23 @@ int SubgraphTraverser::allocateTensor(std::string name) } } break; - case DType_FP8E4M3: + case DType_FP8E4M3: { + std::vector<float> fp32_data; + TosaSerializationHandler::ConvertU8toFP8E4M3(ts->GetData(), tensor->getElementCount(), fp32_data); + if (tensor->getDtype() == TOSA_REF_TYPE_FP64) + { + std::vector<double> f64_data(fp32_data.begin(), fp32_data.end()); + tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); + } + else + { + tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); + } + } + break; case DType_FP8E5M2: { std::vector<float> fp32_data; - TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); - // Ensure valid fp8 stored in each float + TosaSerializationHandler::ConvertU8toFP8E5M2(ts->GetData(), tensor->getElementCount(), fp32_data); if (tensor->getDtype() == TOSA_REF_TYPE_FP64) { std::vector<double> f64_data(fp32_data.begin(), fp32_data.end()); |