aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/optimizations/ConvertConstPermuteLayersToConstLayersTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/optimizations/ConvertConstPermuteLayersToConstLayersTest.cpp')
-rw-r--r--src/armnn/test/optimizations/ConvertConstPermuteLayersToConstLayersTest.cpp60
1 files changed, 60 insertions, 0 deletions
diff --git a/src/armnn/test/optimizations/ConvertConstPermuteLayersToConstLayersTest.cpp b/src/armnn/test/optimizations/ConvertConstPermuteLayersToConstLayersTest.cpp
new file mode 100644
index 0000000000..1fcba0e581
--- /dev/null
+++ b/src/armnn/test/optimizations/ConvertConstPermuteLayersToConstLayersTest.cpp
@@ -0,0 +1,60 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "LayersFwd.hpp"
+#include <Network.hpp>
+#include <doctest/doctest.h>
+#include <Optimizer.hpp>
+#include <TestUtils.hpp>
+
+TEST_SUITE("Optimizer")
+{
+using namespace armnn;
+using namespace armnn::optimizations;
+
+TEST_CASE("ConvertConstPermuteToConst")
+{
+ Graph graph;
+ const unsigned int shape[] = {1, 2, 2, 3};
+
+ const TensorInfo constTensorInfo(4, shape, DataType::Float32, 1.0, 0, true);
+
+ ConstantLayer* constant = graph.AddLayer<ConstantLayer>("constant");
+ std::vector<float> constantValues(constTensorInfo.GetNumElements(), 4.5f);
+ ConstTensor constTensor(constTensorInfo, constantValues.data());
+ constant->m_LayerOutput = std::make_shared<ScopedTensorHandle>(constTensor);
+ constant->GetOutputSlot().SetTensorInfo(constTensorInfo);
+
+ PermuteDescriptor desc({ 0, 2, 3, 1 });
+ PermuteLayer* permuteLayer = graph.AddLayer<PermuteLayer>(desc, "permute");
+ TensorInfo infoPermuted = armnnUtils::Permuted(constTensorInfo, { 0, 2, 3, 1 });
+ permuteLayer->GetOutputSlot().SetTensorInfo(infoPermuted);
+
+ OutputLayer* output = graph.AddLayer<OutputLayer>(0, "output");
+
+ // Connect up constant -> permute -> output
+ constant->GetOutputSlot().Connect(permuteLayer->GetInputSlot(0));
+ permuteLayer->GetOutputSlot().Connect(output->GetInputSlot(0));
+
+ CHECK(CheckSequence(graph.cbegin(), graph.cend(),
+ &IsLayerOfType<ConstantLayer>,
+ &IsLayerOfType<PermuteLayer>,
+ &IsLayerOfType<OutputLayer>));
+
+ armnn::Optimizer::Pass(graph, MakeOptimizations(FusePermuteIntoConstLayer()));
+
+ CHECK(CheckSequence(graph.cbegin(), graph.cend(),
+ &IsLayerOfType<ConstantLayer>,
+ &IsLayerOfType<OutputLayer>));
+
+ TensorShape tensorShape = constant->GetOutputSlot(0).GetTensorInfo().GetShape();
+ CHECK(tensorShape[0] == shape[0]);
+ CHECK(tensorShape[1] == shape[3]);
+ CHECK(tensorShape[2] == shape[1]);
+ CHECK(tensorShape[3] == shape[2]);
+
+}
+
+}