aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp')
-rw-r--r--src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp122
1 files changed, 54 insertions, 68 deletions
diff --git a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
index 6247fc3153..676dc7120d 100644
--- a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
+++ b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
@@ -28,22 +28,17 @@ using TensorRawPtr = const tflite::TensorT *;
struct ParserFlatbuffersFixture
{
- ParserFlatbuffersFixture()
- : m_Parser(ITfLiteParser::Create()), m_NetworkIdentifier(-1)
+ ParserFlatbuffersFixture() :
+ m_Parser(ITfLiteParser::Create()),
+ m_Runtime(armnn::IRuntime::Create(armnn::IRuntime::CreationOptions())),
+ m_NetworkIdentifier(-1)
{
- armnn::IRuntime::CreationOptions options;
-
- const armnn::BackendIdSet availableBackendIds = armnn::BackendRegistryInstance().GetBackendIds();
- for (auto& backendId : availableBackendIds)
- {
- m_Runtimes.push_back(std::make_pair(armnn::IRuntime::Create(options), backendId));
- }
}
std::vector<uint8_t> m_GraphBinary;
std::string m_JsonString;
std::unique_ptr<ITfLiteParser, void (*)(ITfLiteParser *parser)> m_Parser;
- std::vector<std::pair<armnn::IRuntimePtr, armnn::BackendId>> m_Runtimes;
+ armnn::IRuntimePtr m_Runtime;
armnn::NetworkId m_NetworkIdentifier;
/// If the single-input-single-output overload of Setup() is called, these will store the input and output name
@@ -58,35 +53,29 @@ struct ParserFlatbuffersFixture
throw armnn::Exception("LoadNetwork failed while reading binary input");
}
- for (auto&& runtime : m_Runtimes)
+ armnn::INetworkPtr network =
+ m_Parser->CreateNetworkFromBinary(m_GraphBinary);
+
+ if (!network) {
+ throw armnn::Exception("The parser failed to create an ArmNN network");
+ }
+
+ auto optimized = Optimize(*network, { armnn::Compute::CpuRef },
+ m_Runtime->GetDeviceSpec());
+ std::string errorMessage;
+
+ armnn::Status ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, move(optimized), errorMessage);
+
+ if (ret != armnn::Status::Success)
{
- armnn::INetworkPtr network =
- m_Parser->CreateNetworkFromBinary(m_GraphBinary);
-
- if (!network) {
- throw armnn::Exception("The parser failed to create an ArmNN network");
- }
-
- auto optimized = Optimize(*network,
- { runtime.second, armnn::Compute::CpuRef },
- runtime.first->GetDeviceSpec());
- std::string errorMessage;
-
- armnn::Status ret = runtime.first->LoadNetwork(m_NetworkIdentifier,
- move(optimized),
- errorMessage);
-
- if (ret != armnn::Status::Success)
- {
- throw armnn::Exception(
- boost::str(
- boost::format("The runtime failed to load the network. "
- "Error was: %1%. in %2% [%3%:%4%]") %
- errorMessage %
- __func__ %
- __FILE__ %
- __LINE__));
- }
+ throw armnn::Exception(
+ boost::str(
+ boost::format("The runtime failed to load the network. "
+ "Error was: %1%. in %2% [%3%:%4%]") %
+ errorMessage %
+ __func__ %
+ __FILE__ %
+ __LINE__));
}
}
@@ -190,39 +179,36 @@ ParserFlatbuffersFixture::RunTest(size_t subgraphId,
const std::map<std::string, std::vector<DataType>>& inputData,
const std::map<std::string, std::vector<DataType>>& expectedOutputData)
{
- for (auto&& runtime : m_Runtimes)
- {
- using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
+ using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
- // Setup the armnn input tensors from the given vectors.
- armnn::InputTensors inputTensors;
- for (auto&& it : inputData)
- {
- BindingPointInfo bindingInfo = m_Parser->GetNetworkInputBindingInfo(subgraphId, it.first);
- armnn::VerifyTensorInfoDataType<DataType>(bindingInfo.second);
- inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) });
- }
+ // Setup the armnn input tensors from the given vectors.
+ armnn::InputTensors inputTensors;
+ for (auto&& it : inputData)
+ {
+ BindingPointInfo bindingInfo = m_Parser->GetNetworkInputBindingInfo(subgraphId, it.first);
+ armnn::VerifyTensorInfoDataType<DataType>(bindingInfo.second);
+ inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) });
+ }
- // Allocate storage for the output tensors to be written to and setup the armnn output tensors.
- std::map<std::string, boost::multi_array<DataType, NumOutputDimensions>> outputStorage;
- armnn::OutputTensors outputTensors;
- for (auto&& it : expectedOutputData)
- {
- BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
- armnn::VerifyTensorInfoDataType<DataType>(bindingInfo.second);
- outputStorage.emplace(it.first, MakeTensor<DataType, NumOutputDimensions>(bindingInfo.second));
- outputTensors.push_back(
- { bindingInfo.first, armnn::Tensor(bindingInfo.second, outputStorage.at(it.first).data()) });
- }
+ // Allocate storage for the output tensors to be written to and setup the armnn output tensors.
+ std::map<std::string, boost::multi_array<DataType, NumOutputDimensions>> outputStorage;
+ armnn::OutputTensors outputTensors;
+ for (auto&& it : expectedOutputData)
+ {
+ BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
+ armnn::VerifyTensorInfoDataType<DataType>(bindingInfo.second);
+ outputStorage.emplace(it.first, MakeTensor<DataType, NumOutputDimensions>(bindingInfo.second));
+ outputTensors.push_back(
+ { bindingInfo.first, armnn::Tensor(bindingInfo.second, outputStorage.at(it.first).data()) });
+ }
- runtime.first->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);
+ m_Runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);
- // Compare each output tensor to the expected values
- for (auto&& it : expectedOutputData)
- {
- BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
- auto outputExpected = MakeTensor<DataType, NumOutputDimensions>(bindingInfo.second, it.second);
- BOOST_TEST(CompareTensors(outputExpected, outputStorage[it.first]));
- }
+ // Compare each output tensor to the expected values
+ for (auto&& it : expectedOutputData)
+ {
+ BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
+ auto outputExpected = MakeTensor<DataType, NumOutputDimensions>(bindingInfo.second, it.second);
+ BOOST_TEST(CompareTensors(outputExpected, outputStorage[it.first]));
}
}