// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #include #include #include using namespace armnn; TEST_SUITE("Optimizer") { using namespace armnn::optimizations; TEST_CASE("SquashEqualSiblingsTest") { armnn::Graph graph; armnn::LayerBindingId outputId = 0; const armnn::TensorInfo info({ 1, 2, 3, 5 }, armnn::DataType::Float32); const armnn::TensorInfo permuted({ 1, 5, 2, 3 }, armnn::DataType::Float32); auto input = graph.AddLayer(0, "input"); input->GetOutputSlot().SetTensorInfo(info); // Inserts equal permutes, equal reshapes and something else. const armnn::PermuteDescriptor permDesc({ 0, 2, 3, 1 }); const armnn::ReshapeDescriptor reshapeDesc{ { 1, 3, 1, 5 } }; armnn::Layer* layer; layer = graph.AddLayer(permDesc, ""); layer->GetOutputSlot().SetTensorInfo(permuted); layer->GetOutputSlot().Connect(graph.AddLayer(outputId++, "")->GetInputSlot(0)); input->GetOutputSlot().Connect(layer->GetInputSlot(0)); layer = graph.AddLayer(reshapeDesc, ""); layer->GetOutputSlot().Connect(graph.AddLayer(outputId++, "")->GetInputSlot(0)); input->GetOutputSlot().Connect(layer->GetInputSlot(0)); layer = graph.AddLayer(""); layer->GetOutputSlot().Connect(graph.AddLayer(outputId++, "")->GetInputSlot(0)); input->GetOutputSlot().Connect(layer->GetInputSlot(0)); layer = graph.AddLayer(reshapeDesc, ""); layer->GetOutputSlot().Connect(graph.AddLayer(outputId++, "")->GetInputSlot(0)); input->GetOutputSlot().Connect(layer->GetInputSlot(0)); layer = graph.AddLayer(permDesc, ""); layer->GetOutputSlot().SetTensorInfo(permuted); layer->GetOutputSlot().Connect(graph.AddLayer(outputId++, "")->GetInputSlot(0)); input->GetOutputSlot().Connect(layer->GetInputSlot(0)); CHECK(CheckSequence( graph.cbegin(), graph.cend(), &IsLayerOfType, &IsLayerOfType, &IsLayerOfType, &IsLayerOfType, &IsLayerOfType, &IsLayerOfType, &IsLayerOfType, &IsLayerOfType, &IsLayerOfType, &IsLayerOfType, &IsLayerOfType)); armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(SquashEqualPermuteSiblings(), SquashEqualReshapeSiblings())); // The permutes and reshapes are squashed. CHECK(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType, &IsLayerOfType, &IsLayerOfType, &IsLayerOfType, &IsLayerOfType, &IsLayerOfType, &IsLayerOfType, &IsLayerOfType, &IsLayerOfType)); } }