diff options
Diffstat (limited to 'src/armnn/test/NetworkTests.cpp')
-rw-r--r-- | src/armnn/test/NetworkTests.cpp | 118 |
1 files changed, 77 insertions, 41 deletions
diff --git a/src/armnn/test/NetworkTests.cpp b/src/armnn/test/NetworkTests.cpp index 9acb60df4a..25dab596fd 100644 --- a/src/armnn/test/NetworkTests.cpp +++ b/src/armnn/test/NetworkTests.cpp @@ -398,26 +398,44 @@ TEST_CASE("NetworkModification_SplitterMultiplication") TEST_CASE("Network_AddQuantize") { - struct Test : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy> + struct Test : public armnn::IStrategy { - void VisitQuantizeLayer(const armnn::IConnectableLayer* layer, const char* name) override + void ExecuteStrategy(const armnn::IConnectableLayer* layer, + const armnn::BaseDescriptor& descriptor, + const std::vector<armnn::ConstTensor>& constants, + const char* name, + const armnn::LayerBindingId id = 0) override { - m_Visited = true; - - CHECK(layer); - - std::string expectedName = std::string("quantize"); - CHECK(std::string(layer->GetName()) == expectedName); - CHECK(std::string(name) == expectedName); - - CHECK(layer->GetNumInputSlots() == 1); - CHECK(layer->GetNumOutputSlots() == 1); - - const armnn::TensorInfo& infoIn = layer->GetInputSlot(0).GetConnection()->GetTensorInfo(); - CHECK((infoIn.GetDataType() == armnn::DataType::Float32)); - - const armnn::TensorInfo& infoOut = layer->GetOutputSlot(0).GetTensorInfo(); - CHECK((infoOut.GetDataType() == armnn::DataType::QAsymmU8)); + armnn::IgnoreUnused(descriptor, constants, id); + switch (layer->GetType()) + { + case armnn::LayerType::Input: break; + case armnn::LayerType::Output: break; + case armnn::LayerType::Quantize: + { + m_Visited = true; + + CHECK(layer); + + std::string expectedName = std::string("quantize"); + CHECK(std::string(layer->GetName()) == expectedName); + CHECK(std::string(name) == expectedName); + + CHECK(layer->GetNumInputSlots() == 1); + CHECK(layer->GetNumOutputSlots() == 1); + + const armnn::TensorInfo& infoIn = layer->GetInputSlot(0).GetConnection()->GetTensorInfo(); + CHECK((infoIn.GetDataType() == armnn::DataType::Float32)); + + const armnn::TensorInfo& infoOut = layer->GetOutputSlot(0).GetTensorInfo(); + CHECK((infoOut.GetDataType() == armnn::DataType::QAsymmU8)); + break; + } + default: + { + // nothing + } + } } bool m_Visited = false; @@ -440,7 +458,7 @@ TEST_CASE("Network_AddQuantize") quantize->GetOutputSlot(0).SetTensorInfo(infoOut); Test testQuantize; - graph->Accept(testQuantize); + graph->ExecuteStrategy(testQuantize); CHECK(testQuantize.m_Visited == true); @@ -448,29 +466,47 @@ TEST_CASE("Network_AddQuantize") TEST_CASE("Network_AddMerge") { - struct Test : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy> + struct Test : public armnn::IStrategy { - void VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name) override + void ExecuteStrategy(const armnn::IConnectableLayer* layer, + const armnn::BaseDescriptor& descriptor, + const std::vector<armnn::ConstTensor>& constants, + const char* name, + const armnn::LayerBindingId id = 0) override { - m_Visited = true; - - CHECK(layer); - - std::string expectedName = std::string("merge"); - CHECK(std::string(layer->GetName()) == expectedName); - CHECK(std::string(name) == expectedName); - - CHECK(layer->GetNumInputSlots() == 2); - CHECK(layer->GetNumOutputSlots() == 1); - - const armnn::TensorInfo& infoIn0 = layer->GetInputSlot(0).GetConnection()->GetTensorInfo(); - CHECK((infoIn0.GetDataType() == armnn::DataType::Float32)); - - const armnn::TensorInfo& infoIn1 = layer->GetInputSlot(1).GetConnection()->GetTensorInfo(); - CHECK((infoIn1.GetDataType() == armnn::DataType::Float32)); - - const armnn::TensorInfo& infoOut = layer->GetOutputSlot(0).GetTensorInfo(); - CHECK((infoOut.GetDataType() == armnn::DataType::Float32)); + armnn::IgnoreUnused(descriptor, constants, id); + switch (layer->GetType()) + { + case armnn::LayerType::Input: break; + case armnn::LayerType::Output: break; + case armnn::LayerType::Merge: + { + m_Visited = true; + + CHECK(layer); + + std::string expectedName = std::string("merge"); + CHECK(std::string(layer->GetName()) == expectedName); + CHECK(std::string(name) == expectedName); + + CHECK(layer->GetNumInputSlots() == 2); + CHECK(layer->GetNumOutputSlots() == 1); + + const armnn::TensorInfo& infoIn0 = layer->GetInputSlot(0).GetConnection()->GetTensorInfo(); + CHECK((infoIn0.GetDataType() == armnn::DataType::Float32)); + + const armnn::TensorInfo& infoIn1 = layer->GetInputSlot(1).GetConnection()->GetTensorInfo(); + CHECK((infoIn1.GetDataType() == armnn::DataType::Float32)); + + const armnn::TensorInfo& infoOut = layer->GetOutputSlot(0).GetTensorInfo(); + CHECK((infoOut.GetDataType() == armnn::DataType::Float32)); + break; + } + default: + { + // nothing + } + } } bool m_Visited = false; @@ -493,7 +529,7 @@ TEST_CASE("Network_AddMerge") merge->GetOutputSlot(0).SetTensorInfo(info); Test testMerge; - network->Accept(testMerge); + network->ExecuteStrategy(testMerge); CHECK(testMerge.m_Visited == true); } |