diff options
Diffstat (limited to 'src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp')
-rw-r--r-- | src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp | 38 |
1 files changed, 23 insertions, 15 deletions
diff --git a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp index b372a604f3..8d0ee01aa9 100644 --- a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp +++ b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp @@ -14,6 +14,7 @@ #include <armnn/TypesUtils.hpp> #include "test/TensorHelpers.hpp" +#include "TypeUtils.hpp" #include "armnnTfLiteParser/ITfLiteParser.hpp" #include <backendsCommon/BackendRegistry.hpp> @@ -116,14 +117,18 @@ struct ParserFlatbuffersFixture /// Executes the network with the given input tensor and checks the result against the given output tensor. /// This overload assumes the network has a single input and a single output. - template <std::size_t NumOutputDimensions, typename DataType> + template <std::size_t NumOutputDimensions, + armnn::DataType ArmnnType, + typename DataType = armnn::ResolveType<ArmnnType>> void RunTest(size_t subgraphId, - const std::vector<DataType>& inputData, - const std::vector<DataType>& expectedOutputData); + const std::vector<DataType>& inputData, + const std::vector<DataType>& expectedOutputData); /// Executes the network with the given input tensors and checks the results against the given output tensors. /// This overload supports multiple inputs and multiple outputs, identified by name. - template <std::size_t NumOutputDimensions, typename DataType> + template <std::size_t NumOutputDimensions, + armnn::DataType ArmnnType, + typename DataType = armnn::ResolveType<ArmnnType>> void RunTest(size_t subgraphId, const std::map<std::string, std::vector<DataType>>& inputData, const std::map<std::string, std::vector<DataType>>& expectedOutputData); @@ -152,21 +157,24 @@ struct ParserFlatbuffersFixture } }; -template <std::size_t NumOutputDimensions, typename DataType> +template <std::size_t NumOutputDimensions, + armnn::DataType ArmnnType, + typename DataType> void ParserFlatbuffersFixture::RunTest(size_t subgraphId, const std::vector<DataType>& inputData, const std::vector<DataType>& expectedOutputData) { - RunTest<NumOutputDimensions, DataType>(subgraphId, - { { m_SingleInputName, inputData } }, - { { m_SingleOutputName, expectedOutputData } }); + RunTest<NumOutputDimensions, ArmnnType>(subgraphId, + { { m_SingleInputName, inputData } }, + { { m_SingleOutputName, expectedOutputData } }); } -template <std::size_t NumOutputDimensions, typename DataType> -void -ParserFlatbuffersFixture::RunTest(size_t subgraphId, - const std::map<std::string, std::vector<DataType>>& inputData, - const std::map<std::string, std::vector<DataType>>& expectedOutputData) +template <std::size_t NumOutputDimensions, + armnn::DataType ArmnnType, + typename DataType> +void ParserFlatbuffersFixture::RunTest(size_t subgraphId, + const std::map<std::string, std::vector<DataType>>& inputData, + const std::map<std::string, std::vector<DataType>>& expectedOutputData) { using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>; @@ -175,7 +183,7 @@ ParserFlatbuffersFixture::RunTest(size_t subgraphId, for (auto&& it : inputData) { BindingPointInfo bindingInfo = m_Parser->GetNetworkInputBindingInfo(subgraphId, it.first); - armnn::VerifyTensorInfoDataType<DataType>(bindingInfo.second); + armnn::VerifyTensorInfoDataType<ArmnnType>(bindingInfo.second); inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) }); } @@ -185,7 +193,7 @@ ParserFlatbuffersFixture::RunTest(size_t subgraphId, for (auto&& it : expectedOutputData) { BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first); - armnn::VerifyTensorInfoDataType<DataType>(bindingInfo.second); + armnn::VerifyTensorInfoDataType<ArmnnType>(bindingInfo.second); outputStorage.emplace(it.first, MakeTensor<DataType, NumOutputDimensions>(bindingInfo.second)); outputTensors.push_back( { bindingInfo.first, armnn::Tensor(bindingInfo.second, outputStorage.at(it.first).data()) }); |