aboutsummaryrefslogtreecommitdiff
path: root/src/armnnDeserializer/test/ParserFlatbuffersSerializeFixture.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnDeserializer/test/ParserFlatbuffersSerializeFixture.hpp')
-rw-r--r--  src/armnnDeserializer/test/ParserFlatbuffersSerializeFixture.hpp  79
1 file changed, 78 insertions(+), 1 deletion(-)
diff --git a/src/armnnDeserializer/test/ParserFlatbuffersSerializeFixture.hpp b/src/armnnDeserializer/test/ParserFlatbuffersSerializeFixture.hpp
index 31ff026887..0b717bc0fd 100644
--- a/src/armnnDeserializer/test/ParserFlatbuffersSerializeFixture.hpp
+++ b/src/armnnDeserializer/test/ParserFlatbuffersSerializeFixture.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2019-2023 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -150,6 +150,18 @@ struct ParserFlatbuffersSerializeFixture
const std::map<std::string, std::vector<InputDataType>>& inputData,
const std::map<std::string, std::vector<OutputDataType>>& expectedOutputData);
    /// Runs the deserialized network identified by layersId with two input maps
    /// whose element types may differ (ArmnnInputType0 / ArmnnInputType1), then
    /// compares every produced output tensor against expectedOutputData.
    /// Each map is keyed by the network's input/output binding name.
    template<std::size_t NumOutputDimensions,
             armnn::DataType ArmnnInputType0,
             armnn::DataType ArmnnInputType1,
             armnn::DataType ArmnnOutputType,
             typename InputDataType0 = armnn::ResolveType<ArmnnInputType0>,
             typename InputDataType1 = armnn::ResolveType<ArmnnInputType1>,
             typename OutputDataType = armnn::ResolveType<ArmnnOutputType>>
    void RunTest(unsigned int layersId,
                 const std::map<std::string, std::vector<InputDataType0>>& inputData0,
                 const std::map<std::string, std::vector<InputDataType1>>& inputData1,
                 const std::map<std::string, std::vector<OutputDataType>>& expectedOutputData);
+
void CheckTensors(const TensorRawPtr& tensors, size_t shapeSize, const std::vector<int32_t>& shape,
armnnSerializer::TensorInfo tensorType, const std::string& name,
const float scale, const int64_t zeroPoint)
@@ -246,3 +258,68 @@ void ParserFlatbuffersSerializeFixture::RunTest(
CHECK_MESSAGE(result.m_Result, result.m_Message.str());
}
}
+
+template<std::size_t NumOutputDimensions,
+ armnn::DataType ArmnnInputType0,
+ armnn::DataType ArmnnInputType1,
+ armnn::DataType ArmnnOutputType,
+ typename InputDataType0,
+ typename InputDataType1,
+ typename OutputDataType>
+void ParserFlatbuffersSerializeFixture::RunTest(
+ unsigned int layersId,
+ const std::map<std::string, std::vector<InputDataType0>>& inputData0,
+ const std::map<std::string, std::vector<InputDataType1>>& inputData1,
+ const std::map<std::string, std::vector<OutputDataType>>& expectedOutputData)
+{
+ auto ConvertBindingInfo = [](const armnnDeserializer::BindingPointInfo& bindingInfo)
+ {
+ return std::make_pair(bindingInfo.m_BindingId, bindingInfo.m_TensorInfo);
+ };
+
+ // Setup the armnn input tensors from the given vectors.
+ armnn::InputTensors inputTensors;
+ for (auto&& it : inputData0)
+ {
+ armnn::BindingPointInfo bindingInfo = ConvertBindingInfo(
+ m_Parser->GetNetworkInputBindingInfo(layersId, it.first));
+ bindingInfo.second.SetConstant(true);
+ armnn::VerifyTensorInfoDataType(bindingInfo.second, ArmnnInputType0);
+ inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) });
+ }
+
+ for (auto&& it : inputData1)
+ {
+ armnn::BindingPointInfo bindingInfo = ConvertBindingInfo(
+ m_Parser->GetNetworkInputBindingInfo(layersId, it.first));
+ bindingInfo.second.SetConstant(true);
+ armnn::VerifyTensorInfoDataType(bindingInfo.second, ArmnnInputType1);
+ inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) });
+ }
+
+ // Allocate storage for the output tensors to be written to and setup the armnn output tensors.
+ std::map<std::string, std::vector<OutputDataType>> outputStorage;
+ armnn::OutputTensors outputTensors;
+ for (auto&& it : expectedOutputData)
+ {
+ armnn::BindingPointInfo bindingInfo = ConvertBindingInfo(
+ m_Parser->GetNetworkOutputBindingInfo(layersId, it.first));
+ armnn::VerifyTensorInfoDataType(bindingInfo.second, ArmnnOutputType);
+ outputStorage.emplace(it.first, std::vector<OutputDataType>(bindingInfo.second.GetNumElements()));
+ outputTensors.push_back(
+ { bindingInfo.first, armnn::Tensor(bindingInfo.second, outputStorage.at(it.first).data()) });
+ }
+
+ m_Runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);
+
+ // Compare each output tensor to the expected values
+ for (auto&& it : expectedOutputData)
+ {
+ armnn::BindingPointInfo bindingInfo = ConvertBindingInfo(
+ m_Parser->GetNetworkOutputBindingInfo(layersId, it.first));
+ auto outputExpected = it.second;
+ auto result = CompareTensors(outputExpected, outputStorage[it.first],
+ bindingInfo.second.GetShape(), bindingInfo.second.GetShape());
+ CHECK_MESSAGE(result.m_Result, result.m_Message.str());
+ }
+}