aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaCommon/test/OneToManyMappingTests.cpp
diff options
context:
space:
mode:
authorKevin May <kevin.may@arm.com>2023-12-12 11:18:46 +0000
committerKevin May <kevin.may@arm.com>2023-12-14 10:05:56 +0000
commit1bea6beb042635c7716ae43220ee19eedb2de9ff (patch)
tree638a60e9128e28efb613bee96e2de62f2c08b3fc /src/backends/tosaCommon/test/OneToManyMappingTests.cpp
parentce65588484ed1e553bdebf24123a30b5575f1bce (diff)
downloadarmnn-1bea6beb042635c7716ae43220ee19eedb2de9ff.tar.gz
Add Split support to TOSA Reference Backend
* Resolves IVGCVSW-7918 Signed-off-by: Kevin May <kevin.may@arm.com> Change-Id: Ic2afaa55f7ee88ce4c9b8ea696eef5f28663f8c6
Diffstat (limited to 'src/backends/tosaCommon/test/OneToManyMappingTests.cpp')
-rw-r--r--src/backends/tosaCommon/test/OneToManyMappingTests.cpp68
1 files changed, 67 insertions, 1 deletions
diff --git a/src/backends/tosaCommon/test/OneToManyMappingTests.cpp b/src/backends/tosaCommon/test/OneToManyMappingTests.cpp
index b8d28f0405..94dd537a30 100644
--- a/src/backends/tosaCommon/test/OneToManyMappingTests.cpp
+++ b/src/backends/tosaCommon/test/OneToManyMappingTests.cpp
@@ -1,9 +1,10 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "AvgPool2DIgnoreValueChecker.hpp"
+#include "SplitChecker.hpp"
#include <armnn/IRuntime.hpp>
using namespace armnn;
@@ -82,4 +83,69 @@ TEST_CASE("GetTosaMappingFromLayer_AvgPool2DIgnoreValueLayer")
intermediateShape,
descriptor);
}
+
+TEST_CASE("GetTosaMapping_SplitLayer")
+{
+ const unsigned int numViews = 3;
+ const unsigned int numDimensions = 4;
+ armnn::ViewsDescriptor descriptor(numViews, numDimensions);
+ descriptor.SetAxis(static_cast<int32_t>(1));
+
+ std::vector<std::vector<int32_t>> inShape = {{ 1, 18, 4, 4 }};
+ std::vector<std::vector<int32_t>> outShape = {{ 1, 6, 4, 4 },{ 1, 6, 4, 4 },{ 1, 6, 4, 4 }};
+
+ armnn::TensorInfo inputTensorInfo({1, 18, 4, 4}, DataType::Float32);
+ armnn::TensorInfo outputTensorInfo({1, 6, 4, 4}, DataType::Float32);
+
+ TosaSerializationBasicBlock* basicBlock =
+ GetTosaMapping(nullptr, LayerType::Splitter, {&inputTensorInfo}, {&outputTensorInfo}, descriptor);
+
+ VerifySplit(basicBlock,
+ inShape,
+ outShape,
+ descriptor);
+}
+
+TEST_CASE("GetTosaMappingFromLayer_SplitLayer")
+{
+ IRuntime::CreationOptions options;
+ IRuntimePtr runtime(IRuntime::Create(options));
+
+ // Builds up the structure of the network.
+ INetworkPtr net(INetwork::Create());
+
+ const unsigned int numViews = 3;
+ const unsigned int numDimensions = 4;
+ armnn::ViewsDescriptor descriptor(numViews, numDimensions);
+ descriptor.SetAxis(static_cast<int32_t>(1));
+
+ std::vector<std::vector<int32_t>> inShape = {{ 1, 18, 4, 4 }};
+ std::vector<std::vector<int32_t>> outShape = {{ 1, 6, 4, 4 },{ 1, 6, 4, 4 },{ 1, 6, 4, 4 }};
+
+ IConnectableLayer* input0 = net->AddInputLayer(0, "input0");
+ IConnectableLayer* split = net->AddSplitterLayer(descriptor, "split");
+ IConnectableLayer* output0 = net->AddOutputLayer(0, "output0");
+ IConnectableLayer* output1 = net->AddOutputLayer(1, "output1");
+ IConnectableLayer* output2 = net->AddOutputLayer(2, "output2");
+
+ input0->GetOutputSlot(0).Connect(split->GetInputSlot(0));
+ split->GetOutputSlot(0).Connect(output0->GetInputSlot(0));
+ split->GetOutputSlot(1).Connect(output1->GetInputSlot(0));
+ split->GetOutputSlot(2).Connect(output2->GetInputSlot(0));
+
+ armnn::TensorInfo inputTensorInfo({1, 18, 4, 4}, DataType::Float32);
+ armnn::TensorInfo outputTensorInfo({1, 6, 4, 4}, DataType::Float32);
+
+ input0->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
+ split->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ split->GetOutputSlot(1).SetTensorInfo(outputTensorInfo);
+ split->GetOutputSlot(2).SetTensorInfo(outputTensorInfo);
+
+ TosaSerializationBasicBlock* basicBlock = GetTosaMappingFromLayer(PolymorphicDowncast<Layer*>(split));
+
+ VerifySplit(basicBlock,
+ inShape,
+ outShape,
+ descriptor);
+}
} \ No newline at end of file