aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/NetworkTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/NetworkTests.cpp')
-rw-r--r--src/armnn/test/NetworkTests.cpp118
1 files changed, 77 insertions, 41 deletions
diff --git a/src/armnn/test/NetworkTests.cpp b/src/armnn/test/NetworkTests.cpp
index 9acb60df4a..25dab596fd 100644
--- a/src/armnn/test/NetworkTests.cpp
+++ b/src/armnn/test/NetworkTests.cpp
@@ -398,26 +398,44 @@ TEST_CASE("NetworkModification_SplitterMultiplication")
TEST_CASE("Network_AddQuantize")
{
- struct Test : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
+ struct Test : public armnn::IStrategy
{
- void VisitQuantizeLayer(const armnn::IConnectableLayer* layer, const char* name) override
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- m_Visited = true;
-
- CHECK(layer);
-
- std::string expectedName = std::string("quantize");
- CHECK(std::string(layer->GetName()) == expectedName);
- CHECK(std::string(name) == expectedName);
-
- CHECK(layer->GetNumInputSlots() == 1);
- CHECK(layer->GetNumOutputSlots() == 1);
-
- const armnn::TensorInfo& infoIn = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
- CHECK((infoIn.GetDataType() == armnn::DataType::Float32));
-
- const armnn::TensorInfo& infoOut = layer->GetOutputSlot(0).GetTensorInfo();
- CHECK((infoOut.GetDataType() == armnn::DataType::QAsymmU8));
+ armnn::IgnoreUnused(descriptor, constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::Input: break;
+ case armnn::LayerType::Output: break;
+ case armnn::LayerType::Quantize:
+ {
+ m_Visited = true;
+
+ CHECK(layer);
+
+ std::string expectedName = std::string("quantize");
+ CHECK(std::string(layer->GetName()) == expectedName);
+ CHECK(std::string(name) == expectedName);
+
+ CHECK(layer->GetNumInputSlots() == 1);
+ CHECK(layer->GetNumOutputSlots() == 1);
+
+ const armnn::TensorInfo& infoIn = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
+ CHECK((infoIn.GetDataType() == armnn::DataType::Float32));
+
+ const armnn::TensorInfo& infoOut = layer->GetOutputSlot(0).GetTensorInfo();
+ CHECK((infoOut.GetDataType() == armnn::DataType::QAsymmU8));
+ break;
+ }
+ default:
+ {
+ // nothing
+ }
+ }
}
bool m_Visited = false;
@@ -440,7 +458,7 @@ TEST_CASE("Network_AddQuantize")
quantize->GetOutputSlot(0).SetTensorInfo(infoOut);
Test testQuantize;
- graph->Accept(testQuantize);
+ graph->ExecuteStrategy(testQuantize);
CHECK(testQuantize.m_Visited == true);
@@ -448,29 +466,47 @@ TEST_CASE("Network_AddQuantize")
TEST_CASE("Network_AddMerge")
{
- struct Test : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
+ struct Test : public armnn::IStrategy
{
- void VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name) override
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- m_Visited = true;
-
- CHECK(layer);
-
- std::string expectedName = std::string("merge");
- CHECK(std::string(layer->GetName()) == expectedName);
- CHECK(std::string(name) == expectedName);
-
- CHECK(layer->GetNumInputSlots() == 2);
- CHECK(layer->GetNumOutputSlots() == 1);
-
- const armnn::TensorInfo& infoIn0 = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
- CHECK((infoIn0.GetDataType() == armnn::DataType::Float32));
-
- const armnn::TensorInfo& infoIn1 = layer->GetInputSlot(1).GetConnection()->GetTensorInfo();
- CHECK((infoIn1.GetDataType() == armnn::DataType::Float32));
-
- const armnn::TensorInfo& infoOut = layer->GetOutputSlot(0).GetTensorInfo();
- CHECK((infoOut.GetDataType() == armnn::DataType::Float32));
+ armnn::IgnoreUnused(descriptor, constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::Input: break;
+ case armnn::LayerType::Output: break;
+ case armnn::LayerType::Merge:
+ {
+ m_Visited = true;
+
+ CHECK(layer);
+
+ std::string expectedName = std::string("merge");
+ CHECK(std::string(layer->GetName()) == expectedName);
+ CHECK(std::string(name) == expectedName);
+
+ CHECK(layer->GetNumInputSlots() == 2);
+ CHECK(layer->GetNumOutputSlots() == 1);
+
+ const armnn::TensorInfo& infoIn0 = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
+ CHECK((infoIn0.GetDataType() == armnn::DataType::Float32));
+
+ const armnn::TensorInfo& infoIn1 = layer->GetInputSlot(1).GetConnection()->GetTensorInfo();
+ CHECK((infoIn1.GetDataType() == armnn::DataType::Float32));
+
+ const armnn::TensorInfo& infoOut = layer->GetOutputSlot(0).GetTensorInfo();
+ CHECK((infoOut.GetDataType() == armnn::DataType::Float32));
+ break;
+ }
+ default:
+ {
+ // nothing
+ }
+ }
}
bool m_Visited = false;
@@ -493,7 +529,7 @@ TEST_CASE("Network_AddMerge")
merge->GetOutputSlot(0).SetTensorInfo(info);
Test testMerge;
- network->Accept(testMerge);
+ network->ExecuteStrategy(testMerge);
CHECK(testMerge.m_Visited == true);
}