diff options
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r-- | reference_model/src/subgraph_traverser.cc | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index fae0b30..33a9b94 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -595,6 +595,22 @@ int SubgraphTraverser::allocateTensor(std::string name) } } break; + case DType_FP8E4M3: + case DType_FP8E5M2: { + std::vector<float> fp32_data; + TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); + // Ensure valid fp8 stored in each float + if (tensor->getDtype() == TOSA_REF_TYPE_FP64) + { + std::vector<double> f64_data(fp32_data.begin(), fp32_data.end()); + tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); + } + else + { + tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); + } + } + break; case DType_FP32: { std::vector<float> fp32_data; TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); |