diff options
Diffstat (limited to 'src/armnnDeserializer/test/ParserFlatbuffersSerializeFixture.hpp')
-rw-r--r-- | src/armnnDeserializer/test/ParserFlatbuffersSerializeFixture.hpp | 79 |
1 files changed, 78 insertions, 1 deletions
diff --git a/src/armnnDeserializer/test/ParserFlatbuffersSerializeFixture.hpp b/src/armnnDeserializer/test/ParserFlatbuffersSerializeFixture.hpp index 31ff026887..0b717bc0fd 100644 --- a/src/armnnDeserializer/test/ParserFlatbuffersSerializeFixture.hpp +++ b/src/armnnDeserializer/test/ParserFlatbuffersSerializeFixture.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2019-2023 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // @@ -150,6 +150,18 @@ struct ParserFlatbuffersSerializeFixture const std::map<std::string, std::vector<InputDataType>>& inputData, const std::map<std::string, std::vector<OutputDataType>>& expectedOutputData); + template<std::size_t NumOutputDimensions, + armnn::DataType ArmnnInputType0, + armnn::DataType ArmnnInputType1, + armnn::DataType ArmnnOutputType, + typename InputDataType0 = armnn::ResolveType<ArmnnInputType0>, + typename InputDataType1 = armnn::ResolveType<ArmnnInputType1>, + typename OutputDataType = armnn::ResolveType<ArmnnOutputType>> + void RunTest(unsigned int layersId, + const std::map<std::string, std::vector<InputDataType0>>& inputData0, + const std::map<std::string, std::vector<InputDataType1>>& inputData1, + const std::map<std::string, std::vector<OutputDataType>>& expectedOutputData); + void CheckTensors(const TensorRawPtr& tensors, size_t shapeSize, const std::vector<int32_t>& shape, armnnSerializer::TensorInfo tensorType, const std::string& name, const float scale, const int64_t zeroPoint) @@ -246,3 +258,68 @@ void ParserFlatbuffersSerializeFixture::RunTest( CHECK_MESSAGE(result.m_Result, result.m_Message.str()); } } + +template<std::size_t NumOutputDimensions, + armnn::DataType ArmnnInputType0, + armnn::DataType ArmnnInputType1, + armnn::DataType ArmnnOutputType, + typename InputDataType0, + typename InputDataType1, + typename OutputDataType> +void ParserFlatbuffersSerializeFixture::RunTest( + unsigned int layersId, + const std::map<std::string, std::vector<InputDataType0>>& inputData0, + const std::map<std::string, std::vector<InputDataType1>>& inputData1, + const std::map<std::string, std::vector<OutputDataType>>& expectedOutputData) +{ + auto ConvertBindingInfo = [](const armnnDeserializer::BindingPointInfo& bindingInfo) + { + return std::make_pair(bindingInfo.m_BindingId, bindingInfo.m_TensorInfo); + }; + + // Setup the armnn input tensors from the given vectors. + armnn::InputTensors inputTensors; + for (auto&& it : inputData0) + { + armnn::BindingPointInfo bindingInfo = ConvertBindingInfo( + m_Parser->GetNetworkInputBindingInfo(layersId, it.first)); + bindingInfo.second.SetConstant(true); + armnn::VerifyTensorInfoDataType(bindingInfo.second, ArmnnInputType0); + inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) }); + } + + for (auto&& it : inputData1) + { + armnn::BindingPointInfo bindingInfo = ConvertBindingInfo( + m_Parser->GetNetworkInputBindingInfo(layersId, it.first)); + bindingInfo.second.SetConstant(true); + armnn::VerifyTensorInfoDataType(bindingInfo.second, ArmnnInputType1); + inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) }); + } + + // Allocate storage for the output tensors to be written to and setup the armnn output tensors. + std::map<std::string, std::vector<OutputDataType>> outputStorage; + armnn::OutputTensors outputTensors; + for (auto&& it : expectedOutputData) + { + armnn::BindingPointInfo bindingInfo = ConvertBindingInfo( + m_Parser->GetNetworkOutputBindingInfo(layersId, it.first)); + armnn::VerifyTensorInfoDataType(bindingInfo.second, ArmnnOutputType); + outputStorage.emplace(it.first, std::vector<OutputDataType>(bindingInfo.second.GetNumElements())); + outputTensors.push_back( + { bindingInfo.first, armnn::Tensor(bindingInfo.second, outputStorage.at(it.first).data()) }); + } + + m_Runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors); + + // Compare each output tensor to the expected values + for (auto&& it : expectedOutputData) + { + armnn::BindingPointInfo bindingInfo = ConvertBindingInfo( + m_Parser->GetNetworkOutputBindingInfo(layersId, it.first)); + auto outputExpected = it.second; + auto result = CompareTensors(outputExpected, outputStorage[it.first], + bindingInfo.second.GetShape(), bindingInfo.second.GetShape()); + CHECK_MESSAGE(result.m_Result, result.m_Message.str()); + } +} |