aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/test/Unsupported.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/test/Unsupported.cpp')
-rw-r--r--src/armnnTfLiteParser/test/Unsupported.cpp76
1 files changed, 43 insertions, 33 deletions
diff --git a/src/armnnTfLiteParser/test/Unsupported.cpp b/src/armnnTfLiteParser/test/Unsupported.cpp
index b405e1958c..8426246414 100644
--- a/src/armnnTfLiteParser/test/Unsupported.cpp
+++ b/src/armnnTfLiteParser/test/Unsupported.cpp
@@ -5,7 +5,7 @@
#include "ParserFlatbuffersFixture.hpp"
-#include <armnn/LayerVisitorBase.hpp>
+#include <armnn/StrategyBase.hpp>
#include <armnn/utility/Assert.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>
@@ -19,45 +19,55 @@ TEST_SUITE("TensorflowLiteParser_Unsupported")
{
using namespace armnn;
-class StandInLayerVerifier : public LayerVisitorBase<VisitorThrowingPolicy>
+class StandInLayerVerifier : public StrategyBase<NoThrowStrategy>
{
public:
StandInLayerVerifier(const std::vector<TensorInfo>& inputInfos,
const std::vector<TensorInfo>& outputInfos)
- : LayerVisitorBase<VisitorThrowingPolicy>()
- , m_InputInfos(inputInfos)
+ : m_InputInfos(inputInfos)
, m_OutputInfos(outputInfos) {}
- void VisitInputLayer(const IConnectableLayer*, LayerBindingId, const char*) override {}
-
- void VisitOutputLayer(const IConnectableLayer*, LayerBindingId, const char*) override {}
-
- void VisitStandInLayer(const IConnectableLayer* layer,
- const StandInDescriptor& descriptor,
- const char*) override
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- unsigned int numInputs = armnn::numeric_cast<unsigned int>(m_InputInfos.size());
- CHECK(descriptor.m_NumInputs == numInputs);
- CHECK(layer->GetNumInputSlots() == numInputs);
-
- unsigned int numOutputs = armnn::numeric_cast<unsigned int>(m_OutputInfos.size());
- CHECK(descriptor.m_NumOutputs == numOutputs);
- CHECK(layer->GetNumOutputSlots() == numOutputs);
-
- const StandInLayer* standInLayer = PolymorphicDowncast<const StandInLayer*>(layer);
- for (unsigned int i = 0u; i < numInputs; ++i)
- {
- const OutputSlot* connectedSlot = standInLayer->GetInputSlot(i).GetConnectedOutputSlot();
- CHECK(connectedSlot != nullptr);
-
- const TensorInfo& inputInfo = connectedSlot->GetTensorInfo();
- CHECK(inputInfo == m_InputInfos[i]);
- }
-
- for (unsigned int i = 0u; i < numOutputs; ++i)
+ armnn::IgnoreUnused(descriptor, constants, id);
+ switch (layer->GetType())
{
- const TensorInfo& outputInfo = layer->GetOutputSlot(i).GetTensorInfo();
- CHECK(outputInfo == m_OutputInfos[i]);
+ case armnn::LayerType::StandIn:
+ {
+ auto standInDescriptor = static_cast<const armnn::StandInDescriptor&>(descriptor);
+ unsigned int numInputs = armnn::numeric_cast<unsigned int>(m_InputInfos.size());
+ CHECK(standInDescriptor.m_NumInputs == numInputs);
+ CHECK(layer->GetNumInputSlots() == numInputs);
+
+ unsigned int numOutputs = armnn::numeric_cast<unsigned int>(m_OutputInfos.size());
+ CHECK(standInDescriptor.m_NumOutputs == numOutputs);
+ CHECK(layer->GetNumOutputSlots() == numOutputs);
+
+ const StandInLayer* standInLayer = PolymorphicDowncast<const StandInLayer*>(layer);
+ for (unsigned int i = 0u; i < numInputs; ++i)
+ {
+ const OutputSlot* connectedSlot = standInLayer->GetInputSlot(i).GetConnectedOutputSlot();
+ CHECK(connectedSlot != nullptr);
+
+ const TensorInfo& inputInfo = connectedSlot->GetTensorInfo();
+ CHECK(inputInfo == m_InputInfos[i]);
+ }
+
+ for (unsigned int i = 0u; i < numOutputs; ++i)
+ {
+ const TensorInfo& outputInfo = layer->GetOutputSlot(i).GetTensorInfo();
+ CHECK(outputInfo == m_OutputInfos[i]);
+ }
+ break;
+ }
+ default:
+ {
+ m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
+ }
}
}
@@ -164,7 +174,7 @@ public:
void RunTest()
{
INetworkPtr network = m_Parser->CreateNetworkFromBinary(m_GraphBinary);
- network->Accept(m_StandInLayerVerifier);
+ network->ExecuteStrategy(m_StandInLayerVerifier);
}
private: