aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp')
-rw-r--r--src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp99
1 files changed, 87 insertions, 12 deletions
diff --git a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
index c4c75594a3..b4653cd8db 100644
--- a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
+++ b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -13,7 +13,7 @@
#include <armnn/BackendRegistry.hpp>
#include <armnn/utility/Assert.hpp>
-#include <armnnTfLiteParser/ITfLiteParser.hpp>
+#include "../TfLiteParser.hpp"
#include <ResolveType.hpp>
@@ -28,7 +28,6 @@
#include <schema_generated.h>
-#include <iostream>
using armnnTfLiteParser::ITfLiteParser;
using armnnTfLiteParser::ITfLiteParserPtr;
@@ -37,37 +36,97 @@ using TensorRawPtr = const tflite::TensorT *;
struct ParserFlatbuffersFixture
{
ParserFlatbuffersFixture() :
- m_Parser(nullptr, &ITfLiteParser::Destroy),
- m_Runtime(armnn::IRuntime::Create(armnn::IRuntime::CreationOptions())),
- m_NetworkIdentifier(-1)
+ m_Runtime(armnn::IRuntime::Create(armnn::IRuntime::CreationOptions())),
+ m_NetworkIdentifier(0),
+ m_DynamicNetworkIdentifier(1)
{
ITfLiteParser::TfLiteParserOptions options;
options.m_StandInLayerForUnsupported = true;
options.m_InferAndValidate = true;
- m_Parser.reset(ITfLiteParser::CreateRaw(armnn::Optional<ITfLiteParser::TfLiteParserOptions>(options)));
+ m_Parser = std::make_unique<armnnTfLiteParser::TfLiteParserImpl>(
+ armnn::Optional<ITfLiteParser::TfLiteParserOptions>(options));
}
std::vector<uint8_t> m_GraphBinary;
std::string m_JsonString;
- ITfLiteParserPtr m_Parser;
armnn::IRuntimePtr m_Runtime;
armnn::NetworkId m_NetworkIdentifier;
+ armnn::NetworkId m_DynamicNetworkIdentifier;
+ bool m_TestDynamic;
+ std::unique_ptr<armnnTfLiteParser::TfLiteParserImpl> m_Parser;
/// If the single-input-single-output overload of Setup() is called, these will store the input and output name
/// so they don't need to be passed to the single-input-single-output overload of RunTest().
std::string m_SingleInputName;
std::string m_SingleOutputName;
- void Setup()
+ void Setup(bool testDynamic = true)
+ {
+ m_TestDynamic = testDynamic;
+ loadNetwork(m_NetworkIdentifier, false);
+
+ if (m_TestDynamic)
+ {
+ loadNetwork(m_DynamicNetworkIdentifier, true);
+ }
+ }
+
+ std::unique_ptr<tflite::ModelT> MakeModelDynamic(std::vector<uint8_t> graphBinary)
+ {
+ const uint8_t* binaryContent = graphBinary.data();
+ const size_t len = graphBinary.size();
+ if (binaryContent == nullptr)
+ {
+ throw armnn::InvalidArgumentException(fmt::format("Invalid (null) binary content {}",
+ CHECK_LOCATION().AsString()));
+ }
+ flatbuffers::Verifier verifier(binaryContent, len);
+ if (verifier.VerifyBuffer<tflite::Model>() == false)
+ {
+ throw armnn::ParseException(fmt::format("Buffer doesn't conform to the expected Tensorflow Lite "
+ "flatbuffers format. size:{} {}",
+ len,
+ CHECK_LOCATION().AsString()));
+ }
+ auto model = tflite::UnPackModel(binaryContent);
+
+ for (auto const& subgraph : model->subgraphs)
+ {
+ std::vector<int32_t> inputIds = subgraph->inputs;
+ for (unsigned int tensorIndex = 0; tensorIndex < subgraph->tensors.size(); ++tensorIndex)
+ {
+ if (std::find(inputIds.begin(), inputIds.end(), tensorIndex) != inputIds.end())
+ {
+ continue;
+ }
+ for (auto const& tensor : subgraph->tensors)
+ {
+ if (tensor->shape_signature.size() != 0)
+ {
+ continue;
+ }
+
+ for (unsigned int i = 0; i < tensor->shape.size(); ++i)
+ {
+ tensor->shape_signature.push_back(-1);
+ }
+ }
+ }
+ }
+
+ return model;
+ }
+
+ void loadNetwork(armnn::NetworkId networkId, bool loadDynamic)
{
bool ok = ReadStringToBinary();
if (!ok) {
throw armnn::Exception("LoadNetwork failed while reading binary input");
}
- armnn::INetworkPtr network =
- m_Parser->CreateNetworkFromBinary(m_GraphBinary);
+ armnn::INetworkPtr network = loadDynamic ? m_Parser->LoadModel(MakeModelDynamic(m_GraphBinary))
+ : m_Parser->CreateNetworkFromBinary(m_GraphBinary);
if (!network) {
throw armnn::Exception("The parser failed to create an ArmNN network");
@@ -77,7 +136,7 @@ struct ParserFlatbuffersFixture
m_Runtime->GetDeviceSpec());
std::string errorMessage;
- armnn::Status ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, move(optimized), errorMessage);
+ armnn::Status ret = m_Runtime->LoadNetwork(networkId, move(optimized), errorMessage);
if (ret != armnn::Status::Success)
{
@@ -337,6 +396,22 @@ void ParserFlatbuffersFixture::RunTest(size_t subgraphId,
CHECK_MESSAGE(result.m_Result, result.m_Message.str());
}
}
+
+ if (isDynamic)
+ {
+ m_Runtime->EnqueueWorkload(m_DynamicNetworkIdentifier, inputTensors, outputTensors);
+
+ // Compare each output tensor to the expected values
+ for (auto&& it : expectedOutputData)
+ {
+ armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
+ auto outputExpected = it.second;
+ auto result = CompareTensors(outputExpected, outputStorage[it.first],
+ bindingInfo.second.GetShape(), bindingInfo.second.GetShape(),
+ false, isDynamic);
+ CHECK_MESSAGE(result.m_Result, result.m_Message.str());
+ }
+ }
}
/// Multiple Inputs, Multiple Outputs w/ Variable Datatypes and different dimension sizes.