aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
diff options
context:
space:
mode:
authorNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2019-01-22 16:10:44 +0000
committerNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2019-01-22 16:10:44 +0000
commit649dd9515ddf4bd00a0bff64d51dfd835a6c7b39 (patch)
treec938bc8eb11dd24223c0cb00a57d4372a907b943 /src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
parent382e21ce95c04479a6900afca81a57949b369f1e (diff)
downloadarmnn-649dd9515ddf4bd00a0bff64d51dfd835a6c7b39.tar.gz
IVGCVSW-2467 Remove GetDataType<T> function
Change-Id: I7359617a307b9abb4c30b3d5f2364dc6d0f828f0
Diffstat (limited to 'src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp')
-rw-r--r--src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp38
1 files changed, 23 insertions, 15 deletions
diff --git a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
index b372a604f3..8d0ee01aa9 100644
--- a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
+++ b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
@@ -14,6 +14,7 @@
#include <armnn/TypesUtils.hpp>
#include "test/TensorHelpers.hpp"
+#include "TypeUtils.hpp"
#include "armnnTfLiteParser/ITfLiteParser.hpp"
#include <backendsCommon/BackendRegistry.hpp>
@@ -116,14 +117,18 @@ struct ParserFlatbuffersFixture
/// Executes the network with the given input tensor and checks the result against the given output tensor.
/// This overload assumes the network has a single input and a single output.
- template <std::size_t NumOutputDimensions, typename DataType>
+ template <std::size_t NumOutputDimensions,
+ armnn::DataType ArmnnType,
+ typename DataType = armnn::ResolveType<ArmnnType>>
void RunTest(size_t subgraphId,
- const std::vector<DataType>& inputData,
- const std::vector<DataType>& expectedOutputData);
+ const std::vector<DataType>& inputData,
+ const std::vector<DataType>& expectedOutputData);
/// Executes the network with the given input tensors and checks the results against the given output tensors.
/// This overload supports multiple inputs and multiple outputs, identified by name.
- template <std::size_t NumOutputDimensions, typename DataType>
+ template <std::size_t NumOutputDimensions,
+ armnn::DataType ArmnnType,
+ typename DataType = armnn::ResolveType<ArmnnType>>
void RunTest(size_t subgraphId,
const std::map<std::string, std::vector<DataType>>& inputData,
const std::map<std::string, std::vector<DataType>>& expectedOutputData);
@@ -152,21 +157,24 @@ struct ParserFlatbuffersFixture
}
};
-template <std::size_t NumOutputDimensions, typename DataType>
+template <std::size_t NumOutputDimensions,
+ armnn::DataType ArmnnType,
+ typename DataType>
void ParserFlatbuffersFixture::RunTest(size_t subgraphId,
const std::vector<DataType>& inputData,
const std::vector<DataType>& expectedOutputData)
{
- RunTest<NumOutputDimensions, DataType>(subgraphId,
- { { m_SingleInputName, inputData } },
- { { m_SingleOutputName, expectedOutputData } });
+ RunTest<NumOutputDimensions, ArmnnType>(subgraphId,
+ { { m_SingleInputName, inputData } },
+ { { m_SingleOutputName, expectedOutputData } });
}
-template <std::size_t NumOutputDimensions, typename DataType>
-void
-ParserFlatbuffersFixture::RunTest(size_t subgraphId,
- const std::map<std::string, std::vector<DataType>>& inputData,
- const std::map<std::string, std::vector<DataType>>& expectedOutputData)
+template <std::size_t NumOutputDimensions,
+ armnn::DataType ArmnnType,
+ typename DataType>
+void ParserFlatbuffersFixture::RunTest(size_t subgraphId,
+ const std::map<std::string, std::vector<DataType>>& inputData,
+ const std::map<std::string, std::vector<DataType>>& expectedOutputData)
{
using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
@@ -175,7 +183,7 @@ ParserFlatbuffersFixture::RunTest(size_t subgraphId,
for (auto&& it : inputData)
{
BindingPointInfo bindingInfo = m_Parser->GetNetworkInputBindingInfo(subgraphId, it.first);
- armnn::VerifyTensorInfoDataType<DataType>(bindingInfo.second);
+ armnn::VerifyTensorInfoDataType<ArmnnType>(bindingInfo.second);
inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) });
}
@@ -185,7 +193,7 @@ ParserFlatbuffersFixture::RunTest(size_t subgraphId,
for (auto&& it : expectedOutputData)
{
BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
- armnn::VerifyTensorInfoDataType<DataType>(bindingInfo.second);
+ armnn::VerifyTensorInfoDataType<ArmnnType>(bindingInfo.second);
outputStorage.emplace(it.first, MakeTensor<DataType, NumOutputDimensions>(bindingInfo.second));
outputTensors.push_back(
{ bindingInfo.first, armnn::Tensor(bindingInfo.second, outputStorage.at(it.first).data()) });