diff options
author | Kevin May <kevin.may@arm.com> | 2023-12-12 11:18:46 +0000 |
---|---|---|
committer | Kevin May <kevin.may@arm.com> | 2023-12-14 10:05:56 +0000 |
commit | 1bea6beb042635c7716ae43220ee19eedb2de9ff (patch) | |
tree | 638a60e9128e28efb613bee96e2de62f2c08b3fc /src/backends/tosaCommon/test/OneToManyMappingTests.cpp | |
parent | ce65588484ed1e553bdebf24123a30b5575f1bce (diff) | |
download | armnn-1bea6beb042635c7716ae43220ee19eedb2de9ff.tar.gz |
Add Split support to TOSA Reference Backend
* Resolves IVGCVSW-7918
Signed-off-by: Kevin May <kevin.may@arm.com>
Change-Id: Ic2afaa55f7ee88ce4c9b8ea696eef5f28663f8c6
Diffstat (limited to 'src/backends/tosaCommon/test/OneToManyMappingTests.cpp')
-rw-r--r-- | src/backends/tosaCommon/test/OneToManyMappingTests.cpp | 68 |
1 files changed, 67 insertions, 1 deletions
diff --git a/src/backends/tosaCommon/test/OneToManyMappingTests.cpp b/src/backends/tosaCommon/test/OneToManyMappingTests.cpp index b8d28f0405..94dd537a30 100644 --- a/src/backends/tosaCommon/test/OneToManyMappingTests.cpp +++ b/src/backends/tosaCommon/test/OneToManyMappingTests.cpp @@ -1,9 +1,10 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "AvgPool2DIgnoreValueChecker.hpp" +#include "SplitChecker.hpp" #include <armnn/IRuntime.hpp> using namespace armnn; @@ -82,4 +83,69 @@ TEST_CASE("GetTosaMappingFromLayer_AvgPool2DIgnoreValueLayer") intermediateShape, descriptor); } + +TEST_CASE("GetTosaMapping_SplitLayer") +{ + const unsigned int numViews = 3; + const unsigned int numDimensions = 4; + armnn::ViewsDescriptor descriptor(numViews, numDimensions); + descriptor.SetAxis(static_cast<int32_t>(1)); + + std::vector<std::vector<int32_t>> inShape = {{ 1, 18, 4, 4 }}; + std::vector<std::vector<int32_t>> outShape = {{ 1, 6, 4, 4 },{ 1, 6, 4, 4 },{ 1, 6, 4, 4 }}; + + armnn::TensorInfo inputTensorInfo({1, 18, 4, 4}, DataType::Float32); + armnn::TensorInfo outputTensorInfo({1, 6, 4, 4}, DataType::Float32); + + TosaSerializationBasicBlock* basicBlock = + GetTosaMapping(nullptr, LayerType::Splitter, {&inputTensorInfo}, {&outputTensorInfo}, descriptor); + + VerifySplit(basicBlock, + inShape, + outShape, + descriptor); +} + +TEST_CASE("GetTosaMappingFromLayer_SplitLayer") +{ + IRuntime::CreationOptions options; + IRuntimePtr runtime(IRuntime::Create(options)); + + // Builds up the structure of the network. + INetworkPtr net(INetwork::Create()); + + const unsigned int numViews = 3; + const unsigned int numDimensions = 4; + armnn::ViewsDescriptor descriptor(numViews, numDimensions); + descriptor.SetAxis(static_cast<int32_t>(1)); + + std::vector<std::vector<int32_t>> inShape = {{ 1, 18, 4, 4 }}; + std::vector<std::vector<int32_t>> outShape = {{ 1, 6, 4, 4 },{ 1, 6, 4, 4 },{ 1, 6, 4, 4 }}; + + IConnectableLayer* input0 = net->AddInputLayer(0, "input0"); + IConnectableLayer* split = net->AddSplitterLayer(descriptor, "split"); + IConnectableLayer* output0 = net->AddOutputLayer(0, "output0"); + IConnectableLayer* output1 = net->AddOutputLayer(1, "output1"); + IConnectableLayer* output2 = net->AddOutputLayer(2, "output2"); + + input0->GetOutputSlot(0).Connect(split->GetInputSlot(0)); + split->GetOutputSlot(0).Connect(output0->GetInputSlot(0)); + split->GetOutputSlot(1).Connect(output1->GetInputSlot(0)); + split->GetOutputSlot(2).Connect(output2->GetInputSlot(0)); + + armnn::TensorInfo inputTensorInfo({1, 18, 4, 4}, DataType::Float32); + armnn::TensorInfo outputTensorInfo({1, 6, 4, 4}, DataType::Float32); + + input0->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + split->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + split->GetOutputSlot(1).SetTensorInfo(outputTensorInfo); + split->GetOutputSlot(2).SetTensorInfo(outputTensorInfo); + + TosaSerializationBasicBlock* basicBlock = GetTosaMappingFromLayer(PolymorphicDowncast<Layer*>(split)); + + VerifySplit(basicBlock, + inShape, + outShape, + descriptor); +} }
\ No newline at end of file |