author     Mike Kelly <mike.kelly@arm.com>    2023-07-07 15:43:06 +0100
committer  Mike Kelly <mike.kelly@arm.com>    2023-07-14 00:00:53 +0100
commit     4cc341cf8b5a6e6bb0543504cbbfde6fa11a2cdb (patch)
tree       7cac128e9ec6f2fd27f1afdb55f44b870f39e0b3 /src/backends
parent     6963b33221c23af4a8eff19ff4a5773230b0befd (diff)
IVGCVSW-7830 Add backend optimizations to remove Reshapes where possible
* Added optimization to remove reshapes for Neon and Ref Backends by using overridden TensorInfos
* Added ability to delete Subgraphs during Optimization
* Fixed naming error in NeonEndToEndTests and CLEndToEndTests
* Added LayerNameAndTypeCheck for testing.
* Fixed error where layers were not marked as altered when removed in CLBackend

Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I1ac25cd4ec9821470d961831ae2c8d24882276cc
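In short, the patch lets a backend delete a Reshape from its subgraph whenever neither the producing layer nor any consuming layer depends on an NCHW/NCDHW data layout; the Reshape's target shape is then pushed onto the consumers' input slots as an overridden TensorInfo, and the deleted layer is recorded on the OptimizationViews. A minimal sketch of that flow, condensed from the NeonBackend::OptimizeSubgraphView change further down and assuming its per-layer loop and the `untouched` layer map it already maintains:

    // Sketch only -- condensed from NeonBackend::OptimizeSubgraphView in this patch.
    if (base.GetType() == LayerType::Reshape)
    {
        ReshapeLayer* baseLayer = PolymorphicDowncast<ReshapeLayer*>(&base);

        // Keep the Reshape if the producer or any consumer runs with an NCHW/NCDHW layout.
        bool connectedToNCHW = IsNCHW(baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer());
        for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i)
        {
            connectedToNCHW |= IsNCHW(baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer());
        }

        if (!connectedToNCHW)
        {
            // Overrides the consumers' input TensorInfos with the Reshape's target shape
            // and records the layer via OptimizationViews::AddDeletedSubgraph.
            RemoveReshapeLayer(baseLayer, untouched, optimizationViews);
        }
    }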
Diffstat (limited to 'src/backends')
-rw-r--r--   src/backends/backendsCommon/OptimizationViews.cpp        |   8
-rw-r--r--   src/backends/backendsCommon/SubgraphUtils.hpp             | 186
-rw-r--r--   src/backends/backendsCommon/WorkloadFactory.cpp           |   6
-rw-r--r--   src/backends/backendsCommon/test/CMakeLists.txt           |   1
-rw-r--r--   src/backends/backendsCommon/test/SubgraphUtilsTest.hpp    | 399
-rw-r--r--   src/backends/cl/CMakeLists.txt                            |   1
-rw-r--r--   src/backends/cl/ClBackend.cpp                             |   6
-rw-r--r--   src/backends/cl/ClTensorHandle.cpp                        |  82
-rw-r--r--   src/backends/cl/ClTensorHandle.hpp                        | 184
-rw-r--r--   src/backends/cl/backend.mk                                |   1
-rw-r--r--   src/backends/cl/test/ClEndToEndTests.cpp                  |   7
-rw-r--r--   src/backends/neon/CMakeLists.txt                          |   1
-rw-r--r--   src/backends/neon/NeonBackend.cpp                         |  32
-rw-r--r--   src/backends/neon/NeonTensorHandle.cpp                    |  47
-rw-r--r--   src/backends/neon/NeonTensorHandle.hpp                    | 168
-rw-r--r--   src/backends/neon/backend.mk                              |   1
-rw-r--r--   src/backends/neon/test/NeonEndToEndTests.cpp              |  22
-rw-r--r--   src/backends/reference/RefBackend.cpp                     |  13
-rw-r--r--   src/backends/reference/RefTensorHandle.cpp                |  69
-rw-r--r--   src/backends/reference/RefTensorHandle.hpp                |  88
-rw-r--r--   src/backends/reference/test/RefCreateWorkloadTests.cpp    |  48
-rw-r--r--   src/backends/reference/test/RefEndToEndTests.cpp          |  17
22 files changed, 1349 insertions, 38 deletions
diff --git a/src/backends/backendsCommon/OptimizationViews.cpp b/src/backends/backendsCommon/OptimizationViews.cpp
index d887cbc93c..a358f88520 100644
--- a/src/backends/backendsCommon/OptimizationViews.cpp
+++ b/src/backends/backendsCommon/OptimizationViews.cpp
@@ -40,6 +40,12 @@ bool OptimizationViews::Validate(const armnn::SubgraphView& originalSubgraph) co
successful.m_SubstitutableSubgraph.GetIConnectableLayers().begin(),
successful.m_SubstitutableSubgraph.GetIConnectableLayers().end());
}
+ for (auto& successful : m_DeletedSubgraphs)
+ {
+ countedLayers.insert(countedLayers.end(),
+ successful.GetIConnectableLayers().begin(),
+ successful.GetIConnectableLayers().end());
+ }
countedLayers.sort();
// Compare the two lists to make sure they match
@@ -58,7 +64,7 @@ bool OptimizationViews::Validate(const armnn::SubgraphView& originalSubgraph) co
for (auto& substitution : m_SuccesfulOptimizations)
{
bool validSubstitution = true;
- const SubgraphView& replacement = substitution.m_ReplacementSubgraph;
+ const SubgraphView &replacement = substitution.m_ReplacementSubgraph;
const SubgraphView& old = substitution.m_SubstitutableSubgraph;
validSubstitution &= replacement.GetIInputSlots().size() == old.GetIInputSlots().size();
validSubstitution &= replacement.GetIOutputSlots().size() == old.GetIOutputSlots().size();
diff --git a/src/backends/backendsCommon/SubgraphUtils.hpp b/src/backends/backendsCommon/SubgraphUtils.hpp
index bd3d698a98..ade4b63976 100644
--- a/src/backends/backendsCommon/SubgraphUtils.hpp
+++ b/src/backends/backendsCommon/SubgraphUtils.hpp
@@ -1,10 +1,12 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
+#include <armnn/StrategyBase.hpp>
+#include <armnn/Descriptors.hpp>
#include <optimizations/FoldPadIntoLayer2d.hpp>
namespace armnn
@@ -13,6 +15,118 @@ namespace armnn
namespace
{
+/// Checks if a Layer has a DataLayout that is either NCHW or NCDHW.
+class CheckForNCHW : public StrategyBase<NoThrowStrategy>
+{
+public:
+ CheckForNCHW()
+ {}
+
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
+ {
+ armnn::IgnoreUnused(layer, constants, id, name);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::BatchMatMul:
+ {
+ auto desc = static_cast<const armnn::BatchMatMulDescriptor &>(descriptor);
+ m_Result = desc.m_DataLayoutX == DataLayout::NCHW || desc.m_DataLayoutY == DataLayout::NCHW;
+ break;
+ }
+ case armnn::LayerType::BatchNormalization:
+ {
+ CheckDescForNCHW(static_cast<const armnn::BatchNormalizationDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::BatchToSpaceNd:
+ {
+ CheckDescForNCHW(static_cast<const armnn::BatchToSpaceNdDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::Convolution2d:
+ {
+ CheckDescForNCHW(static_cast<const armnn::Convolution2dDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::Convolution3d:
+ {
+ CheckDescForNCHW(static_cast<const armnn::Convolution3dDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::DepthwiseConvolution2d:
+ {
+ CheckDescForNCHW(static_cast<const armnn::DepthwiseConvolution2dDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::InstanceNormalization:
+ {
+ CheckDescForNCHW(static_cast<const armnn::InstanceNormalizationDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::L2Normalization:
+ {
+ CheckDescForNCHW(static_cast<const armnn::L2NormalizationDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::Normalization:
+ {
+ CheckDescForNCHW(static_cast<const armnn::NormalizationDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::Pooling2d:
+ {
+ CheckDescForNCHW(static_cast<const armnn::Pooling2dDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::Pooling3d:
+ {
+ CheckDescForNCHW(static_cast<const armnn::Pooling3dDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::SpaceToBatchNd:
+ {
+ CheckDescForNCHW(static_cast<const armnn::SpaceToBatchNdDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::SpaceToDepth:
+ {
+ CheckDescForNCHW(static_cast<const armnn::SpaceToDepthDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::StridedSlice:
+ {
+ CheckDescForNCHW(static_cast<const armnn::StridedSliceDescriptor&>(descriptor));
+ break;
+ }
+ default:
+ {
+ m_Result = false;
+ }
+ }
+ }
+
+ /// Returns true if the Layer had a DataLayout and it was NCHW or NCDHW.
+ /// Returns false if the Layer either doesn't have a DataLayout or if it
+ /// had a DataLayout that was neither NCHW nor NCDHW.
+ bool Result()
+ {
+ return m_Result;
+ }
+
+private:
+ template<typename Descriptor>
+ void CheckDescForNCHW(const Descriptor& descriptor)
+ {
+ m_Result = (descriptor.m_DataLayout == DataLayout::NCHW) || (descriptor.m_DataLayout == DataLayout::NCDHW);
+ }
+
+ bool m_Result = false;
+};
+
//
// this helper only works if all layers where the inputs connect to are not selected
//
@@ -49,6 +163,13 @@ SubgraphView::IOutputSlots CreateIOutputsFrom(const std::vector<armnn::IConnecta
}
+inline bool IsNCHW(armnn::Layer& layer)
+{
+ CheckForNCHW check;
+ layer.ExecuteStrategy(check);
+ return check.Result();
+}
+
inline void ReportUntouchedLayers(OptimizationViews& optimizationViews, std::map<LayerGuid, Layer*> untouched)
{
std::vector<Layer*> untouchedVector;
@@ -78,22 +199,69 @@ LayerType* FoldPadLayer(OptimizationViews& optimizationViews,
return replacementLayer;
}
+inline void RemoveReshapeLayer(ReshapeLayer* baseLayer,
+ std::map<LayerGuid, Layer*>& untouched,
+ OptimizationViews& optimizationViews)
+{
+ if (baseLayer == nullptr)
+ {
+ return;
+ }
+ ReshapeDescriptor reshapeDescriptor = baseLayer->GetParameters();
+ Layer& parentLayer = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();
+
+ // Cannot currently remove the Reshape if it's connected to an Input, Constant or Splitter
+ if (parentLayer.GetType() == LayerType::Input || parentLayer.GetType() == LayerType::Constant)
+ {
+ return;
+ }
+
+ // Cannot currently remove the Reshape if it's connected to an OutputSlot or Concat
+ for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i)
+ {
+ Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer();
+
+ if (nextLayer.GetType() == LayerType::Output)
+ {
+ return;
+ }
+ }
+ auto it = untouched.find(baseLayer->GetGuid());
+ if (it == untouched.end())
+ {
+ // Already removed from map
+ return;
+ }
+ untouched.erase(it);
+
+ // Override the InputSlot TensorInfos for all the layers connected to the Reshape's OutputSlot
+ for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i)
+ {
+ Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer();
+ auto inputIndex = baseLayer->GetOutputSlot(0).GetConnection(i)->GetSlotIndex();
+ TensorInfo reshapeInfo(baseLayer->GetOutputSlot(0).GetTensorInfo());
+ reshapeInfo.SetShape(reshapeDescriptor.m_TargetShape);
+ nextLayer.GetInputSlot(inputIndex).SetTensorInfo(reshapeInfo);
+ }
+ optimizationViews.AddDeletedSubgraph(baseLayer);
+}
+
template<typename LayerType>
LayerType* FoldPadIntoAveragePool2d(OptimizationViews& optimizationViews,
Pooling2dLayer* baseLayer,
Pooling2dDescriptor& poolDescriptor,
PadLayer* padLayer)
{
- IConnectableLayer* replacement =
- optimizationViews.GetINetwork()->AddPooling2dLayer(poolDescriptor, "folded-pad-into-pool2d");
- LayerType* replacementLayer = PolymorphicDowncast<LayerType*>(replacement);
+ IConnectableLayer* replacement =
+ optimizationViews.GetINetwork()->AddPooling2dLayer(poolDescriptor, "folded-pad-into-pool2d");
+ LayerType* replacementLayer = PolymorphicDowncast<LayerType*>(replacement);
- FoldPadLayer(optimizationViews,
- baseLayer,
- replacementLayer,
- padLayer);
+ FoldPadLayer(optimizationViews,
+ baseLayer,
+ replacementLayer,
+ padLayer);
- return replacementLayer;
+ return replacementLayer;
}
} // namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 37f9382d6e..ac4bcc90f6 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -1122,7 +1122,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
case LayerType::ReverseV2:
{
auto cLayer = PolymorphicDowncast<const ReverseV2Layer*>(&layer);
- const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& input = layer.GetInputSlot(0).GetTensorInfo();
const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
result = layerSupportObject.IsReverseV2Supported(OverrideDataType(input, dataType),
OverrideDataType(output, dataType),
@@ -1413,9 +1413,9 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
// All inputs.
const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetTensorInfo(),
dataType);
- const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
+ const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetTensorInfo(),
dataType);
- const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
+ const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetTensorInfo(),
dataType);
// Outputs
const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
diff --git a/src/backends/backendsCommon/test/CMakeLists.txt b/src/backends/backendsCommon/test/CMakeLists.txt
index 0139044432..8d6891a68d 100644
--- a/src/backends/backendsCommon/test/CMakeLists.txt
+++ b/src/backends/backendsCommon/test/CMakeLists.txt
@@ -59,6 +59,7 @@ list(APPEND armnnBackendsCommonUnitTests_sources
SpaceToDepthEndToEndTestImpl.hpp
SplitterEndToEndTestImpl.hpp
StridedSliceAsyncEndToEndTest.hpp
+ SubgraphUtilsTest.hpp
SubtractionEndToEndTestImpl.hpp
TransposeEndToEndTestImpl.hpp
TensorCopyUtils.hpp
diff --git a/src/backends/backendsCommon/test/SubgraphUtilsTest.hpp b/src/backends/backendsCommon/test/SubgraphUtilsTest.hpp
new file mode 100644
index 0000000000..957726797b
--- /dev/null
+++ b/src/backends/backendsCommon/test/SubgraphUtilsTest.hpp
@@ -0,0 +1,399 @@
+//
+// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <ResolveType.hpp>
+
+#include <armnn/INetwork.hpp>
+#include <armnn/utility/NumericCast.hpp>
+#include <GraphUtils.hpp>
+#include <CommonTestUtils.hpp>
+#include <armnnTestUtils/DataLayoutUtils.hpp>
+
+#include <doctest/doctest.h>
+
+#include <vector>
+#include "backendsCommon/SubgraphUtils.hpp"
+
+namespace armnn
+{
+
+template<DataType ArmnnIType, DataType ArmnnOType,
+ typename TInput = ResolveType<ArmnnIType>, typename TOutput = ResolveType<ArmnnOType>>
+void EndToEndLayerTest(IRuntimePtr runtime,
+ IOptimizedNetworkPtr optNet,
+ const std::map<int, std::vector<TInput>>& inputTensorData,
+ const std::map<int, std::vector<TOutput>>& expectedOutputData,
+ float tolerance = 0.000001f)
+{
+ // Loads it into the runtime.
+ NetworkId netId;
+ std::string errorMessage;
+ armnn::Status loadingStatus = runtime->LoadNetwork(netId, std::move(optNet), errorMessage);
+ CHECK_MESSAGE(loadingStatus == Status::Success, errorMessage);
+
+ InputTensors inputTensors;
+ inputTensors.reserve(inputTensorData.size());
+ for (auto&& it : inputTensorData)
+ {
+ inputTensors.push_back({it.first,
+ ConstTensor(runtime->GetInputTensorInfo(netId, it.first), it.second.data())});
+ }
+ OutputTensors outputTensors;
+ outputTensors.reserve(expectedOutputData.size());
+ std::map<int, std::vector<TOutput>> outputStorage;
+ for (auto&& it : expectedOutputData)
+ {
+ std::vector<TOutput> out(it.second.size());
+ outputStorage.emplace(it.first, out);
+ outputTensors.push_back({it.first,
+ Tensor(runtime->GetOutputTensorInfo(netId, it.first),
+ outputStorage.at(it.first).data())});
+ }
+
+ // Does the inference.
+ runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
+
+ // Checks the results.
+ for (auto&& it : expectedOutputData)
+ {
+ std::vector<TOutput> out = outputStorage.at(it.first);
+ for (unsigned int i = 0; i < out.size(); ++i)
+ {
+ CHECK_MESSAGE(Compare<ArmnnOType>(it.second[i], out[i], tolerance) == true,
+ "Actual output: " << out[i] << ". Expected output:" << it.second[i]);
+
+ }
+ }
+}
+
+template<armnn::DataType ArmnnType, typename T = ResolveType<ArmnnType>>
+armnn::INetworkPtr CreateReshapeInOutNetwork(const armnn::TensorShape& inputShape,
+ const armnn::TensorShape& outputShape,
+ ReshapeDescriptor& descriptor,
+ const float qScale = 1.0f,
+ const int32_t qOffset = 0)
+{
+ armnn::INetworkPtr network(armnn::INetwork::Create());
+
+ TensorInfo inputTensorInfo(inputShape, ArmnnType, qScale, qOffset, true);
+ TensorInfo outputTensorInfo(outputShape, ArmnnType, qScale, qOffset);
+
+ IConnectableLayer* activation0 = network->AddActivationLayer(ActivationFunction::ReLu, "act0");
+ IConnectableLayer* activation1 = network->AddActivationLayer(ActivationFunction::ReLu, "act1");
+ IConnectableLayer* activation2 = network->AddActivationLayer(ActivationFunction::ReLu, "act2");
+ IConnectableLayer* activation3 = network->AddActivationLayer(ActivationFunction::ReLu, "act3");
+ IConnectableLayer* reshape = network->AddReshapeLayer(descriptor, "reshape");
+
+ IConnectableLayer* input = network->AddInputLayer(0, "input");
+ IConnectableLayer* output1 = network->AddOutputLayer(0, "output1");
+ IConnectableLayer* output2 = network->AddOutputLayer(1, "output2");
+ IConnectableLayer* output3 = network->AddOutputLayer(2, "output3");
+
+ Connect(input, activation0, inputTensorInfo, 0, 0);
+ Connect(activation0, reshape, inputTensorInfo, 0, 0);
+
+ Connect(reshape, activation1, outputTensorInfo, 0, 0);
+ Connect(reshape, activation2, outputTensorInfo, 0, 0);
+ Connect(reshape, activation3, outputTensorInfo, 0, 0);
+ Connect(activation1, output1, outputTensorInfo, 0, 0);
+ Connect(activation2, output2, outputTensorInfo, 0, 0);
+ Connect(activation3, output3, outputTensorInfo, 0, 0);
+
+ return network;
+}
+
+template<armnn::DataType ArmnnType, typename T = ResolveType<ArmnnType>>
+armnn::INetworkPtr CreateReshapeConv2dInOutNetwork(const armnn::TensorShape& inputShape,
+ const armnn::TensorShape& weightsShape,
+ const armnn::TensorShape& convOutputShape,
+ const armnn::TensorShape& outputShape,
+ std::vector<float>& weightsData,
+ ReshapeDescriptor& descriptor,
+ Convolution2dDescriptor& convolution2DDescriptor,
+ bool convFirst,
+ const float qScale = 1.0f,
+ const int32_t qOffset = 0)
+{
+ armnn::INetworkPtr network(armnn::INetwork::Create());
+ TensorInfo weightsTensorInfo(weightsShape, ArmnnType, qScale, qOffset, true);
+ ConstTensor weights(weightsTensorInfo, weightsData);
+
+ IConnectableLayer* convolution1 = network->AddConvolution2dLayer(convolution2DDescriptor, "conv2d");
+ IConnectableLayer* weightsLayer = network->AddConstantLayer(weights, "weights");
+
+ IConnectableLayer* activation1 = network->AddActivationLayer(ActivationFunction::ReLu, "act");
+ IConnectableLayer* reshape = network->AddReshapeLayer(descriptor, "reshape");
+
+ IConnectableLayer* input = network->AddInputLayer(0, "input");
+ IConnectableLayer* output = network->AddOutputLayer(0, "output");
+
+ TensorInfo inputTensorInfo(inputShape, ArmnnType, qScale, qOffset, true);
+ TensorInfo convTensorInfo(convOutputShape, ArmnnType, qScale, qOffset, true);
+ TensorInfo outputTensorInfo(outputShape, ArmnnType, qScale, qOffset);
+ TensorInfo reshapeTensorInfo(descriptor.m_TargetShape, ArmnnType, qScale, qOffset, true);
+
+ if (convFirst)
+ {
+ Connect(input, convolution1, inputTensorInfo, 0, 0);
+ Connect(weightsLayer, convolution1, weightsTensorInfo, 0, 1);
+ Connect(convolution1, reshape, convTensorInfo, 0, 0);
+ Connect(reshape, activation1, reshapeTensorInfo, 0, 0);
+ Connect(activation1, output, outputTensorInfo, 0, 0);
+ }
+ else
+ {
+ Connect(input, activation1, inputTensorInfo, 0, 0);
+ Connect(activation1, reshape, inputTensorInfo, 0, 0);
+ Connect(reshape, convolution1, reshapeTensorInfo, 0, 0);
+ Connect(weightsLayer, convolution1, weightsTensorInfo, 0, 1);
+ Connect(convolution1, output, outputTensorInfo, 0, 0);
+ }
+ return network;
+}
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+void ReshapeRemovalEndToEnd(const std::vector<armnn::BackendId>& backends)
+{
+ using namespace armnn;
+
+ const TensorShape& inputShape = { 2, 3 };
+ const TensorShape& outputShape = { 6 };
+
+ ReshapeDescriptor descriptor;
+ descriptor.m_TargetShape = outputShape;
+
+ INetworkPtr network = CreateReshapeInOutNetwork<ArmnnType>(inputShape, outputShape, descriptor);
+
+ CHECK(network);
+
+ std::vector<T> data{ 1, 2, 3,
+ 4, 5, 6 };
+
+ std::map<int, std::vector<float>> inputTensorData = { { 0, data } };
+ std::map<int, std::vector<float>> expectedOutputData = { { 0, data }, { 1, data }, { 2, data } };
+
+ // Create runtime in which test will run
+ IRuntime::CreationOptions options;
+ IRuntimePtr runtime(IRuntime::Create(options));
+
+ // optimize the network
+ IOptimizedNetworkPtr optNet = Optimize(*network, backends, runtime->GetDeviceSpec());
+
+ Graph& graph = GetGraphForTesting(optNet.get());
+ CHECK(CheckSequence(graph.cbegin(), graph.cend(),
+ LayerNameAndTypeCheck(LayerType::Input, "input"),
+ LayerNameAndTypeCheck(LayerType::Activation, "act0"),
+ LayerNameAndTypeCheck(LayerType::Activation, "act1"),
+ LayerNameAndTypeCheck(LayerType::Activation, "act2"),
+ LayerNameAndTypeCheck(LayerType::Activation, "act3"),
+ LayerNameAndTypeCheck(LayerType::Output, "output1"),
+ LayerNameAndTypeCheck(LayerType::Output, "output2"),
+ LayerNameAndTypeCheck(LayerType::Output, "output3")));
+
+ EndToEndLayerTest<ArmnnType, ArmnnType>(std::move(runtime), std::move(optNet), inputTensorData, expectedOutputData);
+}
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+void ReshapeRemovalNCHWEndToEnd(const std::vector<armnn::BackendId>& backends, bool shouldBeRemoved, bool convFirst)
+{
+ using namespace armnn;
+
+ // shapes are different if convFirst or not
+ //these are convFirst
+ TensorShape inputShape;
+ TensorShape convOutputShape;
+ TensorShape weightsShape;
+ TensorShape reshapeShape;
+ TensorShape outputShape;
+
+ if (convFirst)
+ {
+ inputShape = { 1, 1, 5, 5 };
+ convOutputShape = { 1, 1, 3, 3 };
+ weightsShape = { 1, 1, 3, 3 };
+ reshapeShape = { 9 };
+ outputShape = { 9 };
+ }
+ else
+ {
+ inputShape = { 5, 5 };
+ reshapeShape = { 1, 1, 5, 5 };
+ convOutputShape = { 1, 1, 3, 3 };
+ weightsShape = { 1, 1, 3, 3 };
+ outputShape = { 1, 1, 3, 3 };
+ }
+
+ ReshapeDescriptor descriptor;
+ descriptor.m_TargetShape = reshapeShape;
+
+ Convolution2dDescriptor convolution2DDescriptor;
+ convolution2DDescriptor.m_PadLeft = 0;
+ convolution2DDescriptor.m_PadRight = 0;
+ convolution2DDescriptor.m_PadTop = 0;
+ convolution2DDescriptor.m_PadBottom = 0;
+ convolution2DDescriptor.m_StrideX = 1;
+ convolution2DDescriptor.m_StrideY = 1;
+ convolution2DDescriptor.m_DataLayout = DataLayout::NCHW;
+ convolution2DDescriptor.m_BiasEnabled = false;
+
+ TensorInfo inputInfo(inputShape, DataType::Float32, true);
+ TensorInfo outputInfo(convOutputShape, DataType::Float32);
+ TensorInfo weightsInfo(weightsShape, DataType::Float32, true);
+
+ std::vector<float> inputData =
+ {
+ 1.0f, 8.0f, 3.0f, 4.0f, 6.0f,
+ 5.0f, 7.0f, 3.0f, 1.0f, 8.0f,
+ 2.0f, 3.0f, 9.0f, 8.0f, 1.0f,
+ 3.0f, 6.0f, 1.0f, 1.0f, 9.0f,
+ 5.0f, 3.0f, 9.0f, 3.0f, 2.0f
+ };
+
+ std::vector<float> weightsData =
+ {
+ 4.0f, 0.0f, 3.0f,
+ 5.0f, 0.0f, 2.0f,
+ 6.0f, 0.0f, 1.0f
+ };
+
+ std::vector<float> outputData =
+ {
+ 65.0f, 107.0f, 116.0f,
+ 76.0f, 99.0f, 98.0f,
+ 91.0f, 89.0f, 118.0f
+ };
+
+ INetworkPtr network = CreateReshapeConv2dInOutNetwork<DataType::Float32>(inputShape,
+ weightsShape,
+ convOutputShape,
+ outputShape,
+ weightsData,
+ descriptor,
+ convolution2DDescriptor,
+ convFirst);
+ CHECK(network);
+
+ std::map<int, std::vector<float>> inputTensorData = { { 0, inputData } };
+ std::map<int, std::vector<float>> expectedOutputData = { { 0, outputData } };
+
+ // Create runtime in which test will run
+ IRuntime::CreationOptions options;
+ IRuntimePtr runtime(IRuntime::Create(options));
+
+ // optimize the network
+ IOptimizedNetworkPtr optNet = Optimize(*network, backends, runtime->GetDeviceSpec());
+
+ Graph& graph = GetGraphForTesting(optNet.get());
+
+ if (shouldBeRemoved)
+ {
+ if (convFirst)
+ {
+ CHECK(CheckSequence(graph.cbegin(), graph.cend(),
+ LayerNameAndTypeCheck(LayerType::Input, "input"),
+ LayerNameAndTypeCheck(LayerType::Constant, "weights"),
+ LayerNameAndTypeCheck(LayerType::Convolution2d, "conv2d"),
+ LayerNameAndTypeCheck(LayerType::Activation, "act"),
+ LayerNameAndTypeCheck(LayerType::Output, "output")));
+ }
+ else
+ {
+ CHECK(CheckSequence(graph.cbegin(), graph.cend(),
+ LayerNameAndTypeCheck(LayerType::Input, "input"),
+ LayerNameAndTypeCheck(LayerType::Constant, "weights"),
+ LayerNameAndTypeCheck(LayerType::Activation, "act"),
+ LayerNameAndTypeCheck(LayerType::Convolution2d, "conv2d"),
+ LayerNameAndTypeCheck(LayerType::Output, "output")));
+ }
+ }
+ else
+ {
+ if (convFirst)
+ {
+ CHECK(CheckSequence(graph.cbegin(), graph.cend(),
+ LayerNameAndTypeCheck(LayerType::Input, "input"),
+ LayerNameAndTypeCheck(LayerType::Constant, "weights"),
+ LayerNameAndTypeCheck(LayerType::Convolution2d, "conv2d"),
+ LayerNameAndTypeCheck(LayerType::Reshape, "reshape"),
+ LayerNameAndTypeCheck(LayerType::Activation, "act"),
+ LayerNameAndTypeCheck(LayerType::Output, "output")));
+ }
+ else
+ {
+ CHECK(CheckSequence(graph.cbegin(), graph.cend(),
+ LayerNameAndTypeCheck(LayerType::Input, "input"),
+ LayerNameAndTypeCheck(LayerType::Constant, "weights"),
+ LayerNameAndTypeCheck(LayerType::Activation, "act"),
+ LayerNameAndTypeCheck(LayerType::Reshape, "reshape"),
+ LayerNameAndTypeCheck(LayerType::Convolution2d, "conv2d"),
+ LayerNameAndTypeCheck(LayerType::Output, "output")));
+ }
+ }
+
+ EndToEndLayerTest<ArmnnType, ArmnnType>(std::move(runtime), std::move(optNet), inputTensorData, expectedOutputData);
+}
+
+template<typename Descriptor, typename LayerType>
+void CheckIsNCHW()
+{
+ Graph graph;
+ Descriptor nchwDesc;
+ nchwDesc.m_DataLayout = DataLayout::NCHW;
+ auto nchwLayer = graph.AddLayer<LayerType>(nchwDesc, "");
+ CHECK(IsNCHW(*nchwLayer));
+
+ Descriptor nhwcDesc;
+ nhwcDesc.m_DataLayout = DataLayout::NHWC;
+ auto nhwcLayer = graph.AddLayer<LayerType>(nhwcDesc, "");
+ CHECK_FALSE(IsNCHW(*nhwcLayer));
+}
+
+TEST_CASE("CheckIsNCHW")
+{
+ Graph graph;
+ BatchMatMulDescriptor descriptor1;
+ descriptor1.m_DataLayoutX = DataLayout::NHWC;
+ descriptor1.m_DataLayoutY = DataLayout::NHWC;
+ auto batchMatMulLayer1 = graph.AddLayer<BatchMatMulLayer>(descriptor1, "");
+ CHECK_FALSE(IsNCHW(*batchMatMulLayer1));
+
+ BatchMatMulDescriptor descriptor2;
+ descriptor2.m_DataLayoutX = DataLayout::NCHW;
+ descriptor2.m_DataLayoutY = DataLayout::NHWC;
+ auto batchMatMulLayer2 = graph.AddLayer<BatchMatMulLayer>(descriptor2, "");
+ CHECK(IsNCHW(*batchMatMulLayer2));
+
+ BatchMatMulDescriptor descriptor3;
+ descriptor3.m_DataLayoutX = DataLayout::NHWC;
+ descriptor3.m_DataLayoutY = DataLayout::NCHW;
+ auto batchMatMulLayer3 = graph.AddLayer<BatchMatMulLayer>(descriptor3, "");
+ CHECK(IsNCHW(*batchMatMulLayer3));
+
+ BatchMatMulDescriptor descriptor4;
+ descriptor4.m_DataLayoutX = DataLayout::NCHW;
+ descriptor4.m_DataLayoutY = DataLayout::NCHW;
+ auto batchMatMulLayer4 = graph.AddLayer<BatchMatMulLayer>(descriptor4, "");
+ CHECK(IsNCHW(*batchMatMulLayer4));
+
+ CheckIsNCHW<BatchToSpaceNdDescriptor, BatchToSpaceNdLayer>();
+ CheckIsNCHW<Convolution2dDescriptor, Convolution2dLayer>();
+ CheckIsNCHW<Convolution3dDescriptor, Convolution3dLayer>();
+ CheckIsNCHW<DepthwiseConvolution2dDescriptor, DepthwiseConvolution2dLayer>();
+ CheckIsNCHW<InstanceNormalizationDescriptor, InstanceNormalizationLayer>();
+ CheckIsNCHW<L2NormalizationDescriptor, L2NormalizationLayer>();
+ CheckIsNCHW<NormalizationDescriptor, NormalizationLayer>();
+ CheckIsNCHW<Pooling2dDescriptor, Pooling2dLayer>();
+ CheckIsNCHW<Pooling3dDescriptor, Pooling3dLayer>();
+ CheckIsNCHW<SpaceToBatchNdDescriptor, SpaceToBatchNdLayer>();
+ CheckIsNCHW<SpaceToDepthDescriptor, SpaceToDepthLayer>();
+ CheckIsNCHW<StridedSliceDescriptor, StridedSliceLayer>();
+
+ // Check Default
+ auto elementwiseLayer = graph.AddLayer<ElementwiseBinaryLayer>(BinaryOperation::Add, "");
+ CHECK_FALSE(IsNCHW(*elementwiseLayer));
+}
+
+
+} // Namespace
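The CheckSequence expectations above rely on the new LayerNameAndTypeCheck helper, which this change adds outside src/backends (it is pulled in through GraphUtils.hpp) and which therefore does not appear in this diff. A hypothetical minimal form of the predicate is shown below only to make the tests above readable; the real helper may differ in detail:

    // Hypothetical sketch -- not part of this diff. Assumes the internal Layer header is visible.
    class LayerNameAndTypeCheck
    {
    public:
        LayerNameAndTypeCheck(armnn::LayerType layerType, const char* name)
            : m_LayerType(layerType)
            , m_Name(name)
        {}

        // True when the graph layer matches both the expected type and the expected name.
        bool operator()(const armnn::Layer* const layer) const
        {
            return layer->GetType() == m_LayerType &&
                   layer->GetNameStr() == m_Name;
        }

    private:
        armnn::LayerType m_LayerType;
        const char* m_Name;
    };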
diff --git a/src/backends/cl/CMakeLists.txt b/src/backends/cl/CMakeLists.txt
index 20c42061fc..cc71069910 100644
--- a/src/backends/cl/CMakeLists.txt
+++ b/src/backends/cl/CMakeLists.txt
@@ -39,6 +39,7 @@ if(ARMCOMPUTECL)
ClLayerSupport.cpp
ClLayerSupport.hpp
ClRegistryInitializer.cpp
+ ClTensorHandle.cpp
ClTensorHandle.hpp
ClTensorHandleFactory.cpp
ClTensorHandleFactory.hpp
diff --git a/src/backends/cl/ClBackend.cpp b/src/backends/cl/ClBackend.cpp
index a10b6fbb43..b018654288 100644
--- a/src/backends/cl/ClBackend.cpp
+++ b/src/backends/cl/ClBackend.cpp
@@ -455,6 +455,7 @@ OptimizationViews ClBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
replacementLayer->m_Gamma = std::move(baseLayer->m_Gamma);
replacementLayer->m_Mean = std::move(baseLayer->m_Mean);
replacementLayer->m_Variance = std::move(baseLayer->m_Variance);
+
untouched.erase(baseLayer->GetGuid());
untouched.erase(activationLayer->GetGuid());
}
@@ -476,6 +477,7 @@ OptimizationViews ClBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
activationLayer,
activationDesc,
name);
+
untouched.erase(baseLayer->GetGuid());
untouched.erase(activationLayer->GetGuid());
}
@@ -623,6 +625,8 @@ OptimizationViews ClBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
activationDesc,
BinaryOperation::Sub,
name);
+ untouched.erase(baseLayer->GetGuid());
+ untouched.erase(activationLayer->GetGuid());
}
}
// No fusion available for other BinaryOperations
@@ -678,7 +682,7 @@ OptimizationViews ClBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
}
}
- if (optimizationViews.GetSubstitutions().empty())
+ if (optimizationViews.GetSubstitutions().empty() && optimizationViews.GetDeletedSubgraphs().empty())
{
optimizationViews.AddUntouchedSubgraph(SubgraphView(subgraph));
}
diff --git a/src/backends/cl/ClTensorHandle.cpp b/src/backends/cl/ClTensorHandle.cpp
new file mode 100644
index 0000000000..ccc8f6effc
--- /dev/null
+++ b/src/backends/cl/ClTensorHandle.cpp
@@ -0,0 +1,82 @@
+//
+// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ClTensorHandle.hpp"
+
+#include <aclCommon/ArmComputeTensorUtils.hpp>
+
+namespace armnn
+{
+ std::shared_ptr<ITensorHandle> ClTensorHandle::DecorateTensorHandle(const TensorInfo& tensorInfo)
+ {
+ auto* parent = const_cast<ClTensorHandle*>(this);
+ auto decorated = std::make_shared<ClTensorHandleDecorator>(parent, tensorInfo);
+ m_Decorated.emplace_back(decorated);
+ return decorated;
+ }
+
+ ClTensorDecorator::ClTensorDecorator()
+ : m_Original(nullptr), m_TensorInfo()
+ {
+ }
+
+ ClTensorDecorator::ClTensorDecorator(arm_compute::ICLTensor* original, const TensorInfo& tensorInfo)
+ : m_Original(nullptr), m_TensorInfo()
+ {
+ m_TensorInfo = armcomputetensorutils::BuildArmComputeTensorInfo(tensorInfo);
+ m_Original = original;
+ }
+
+ arm_compute::ITensorInfo* ClTensorDecorator::info() const
+ {
+ return &m_TensorInfo;
+ }
+
+ arm_compute::ITensorInfo* ClTensorDecorator::info()
+ {
+ return &m_TensorInfo;
+ }
+
+ const cl::Buffer& ClTensorDecorator::cl_buffer() const
+ {
+ ARM_COMPUTE_ERROR_ON(m_Original == nullptr);
+ return m_Original->cl_buffer();
+ }
+
+ arm_compute::ICLTensor* ClTensorDecorator::parent()
+ {
+ return nullptr;
+ }
+
+ arm_compute::CLQuantization ClTensorDecorator::quantization() const
+ {
+ return m_Original->quantization();
+ }
+
+ void ClTensorDecorator::map(bool blocking)
+ {
+ arm_compute::ICLTensor::map(arm_compute::CLScheduler::get().queue(), blocking);
+ }
+
+ void ClTensorDecorator::unmap()
+ {
+ arm_compute::ICLTensor::unmap(arm_compute::CLScheduler::get().queue());
+ }
+
+ uint8_t* ClTensorDecorator::do_map(cl::CommandQueue& q, bool blocking)
+ {
+ if(m_Original->buffer() == nullptr)
+ {
+ m_Original->map(q, blocking);
+ }
+ return m_Original->buffer();
+ }
+
+ void ClTensorDecorator::do_unmap(cl::CommandQueue& q)
+ {
+ m_Original->unmap(q);
+ }
+
+}
\ No newline at end of file
diff --git a/src/backends/cl/ClTensorHandle.hpp b/src/backends/cl/ClTensorHandle.hpp
index 3d750f9059..42657341fd 100644
--- a/src/backends/cl/ClTensorHandle.hpp
+++ b/src/backends/cl/ClTensorHandle.hpp
@@ -1,7 +1,8 @@
//
-// Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
+
#pragma once
#include <aclCommon/ArmComputeTensorHandle.hpp>
@@ -22,6 +23,7 @@
namespace armnn
{
+class ClTensorHandleDecorator;
class ClTensorHandle : public IClTensorHandle
{
@@ -122,7 +124,7 @@ public:
virtual bool Import(void* memory, MemorySource source) override
{
armnn::IgnoreUnused(memory);
- if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+ if (m_ImportFlags& static_cast<MemorySourceFlags>(source))
{
throw MemoryImportException("ClTensorHandle::Incorrect import flag");
}
@@ -137,6 +139,8 @@ public:
return false;
}
+ virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo& tensorInfo) override;
+
private:
// Only used for testing
void CopyOutTo(void* memory) const override
@@ -227,6 +231,7 @@ private:
MemorySourceFlags m_ImportFlags;
bool m_Imported;
bool m_IsImportEnabled;
+ std::vector<std::shared_ptr<ClTensorHandleDecorator>> m_Decorated;
};
class ClSubTensorHandle : public IClTensorHandle
@@ -361,4 +366,179 @@ private:
ITensorHandle* parentHandle = nullptr;
};
+/** ClTensorDecorator wraps an existing CL tensor allowing us to override the TensorInfo for it */
+class ClTensorDecorator : public arm_compute::ICLTensor
+{
+public:
+ ClTensorDecorator();
+
+ ClTensorDecorator(arm_compute::ICLTensor* original, const TensorInfo& info);
+
+ ~ClTensorDecorator() = default;
+
+ ClTensorDecorator(const ClTensorDecorator&) = delete;
+
+ ClTensorDecorator& operator=(const ClTensorDecorator&) = delete;
+
+ ClTensorDecorator(ClTensorDecorator&&) = default;
+
+ ClTensorDecorator& operator=(ClTensorDecorator&&) = default;
+
+ arm_compute::ICLTensor* parent();
+
+ void map(bool blocking = true);
+ using arm_compute::ICLTensor::map;
+
+ void unmap();
+ using arm_compute::ICLTensor::unmap;
+
+ virtual arm_compute::ITensorInfo* info() const override;
+ virtual arm_compute::ITensorInfo* info() override;
+ const cl::Buffer& cl_buffer() const override;
+ arm_compute::CLQuantization quantization() const override;
+
+protected:
+ // Inherited methods overridden:
+ uint8_t* do_map(cl::CommandQueue& q, bool blocking) override;
+ void do_unmap(cl::CommandQueue& q) override;
+
+private:
+ arm_compute::ICLTensor* m_Original;
+ mutable arm_compute::TensorInfo m_TensorInfo;
+};
+
+class ClTensorHandleDecorator : public IClTensorHandle
+{
+public:
+ ClTensorHandleDecorator(IClTensorHandle* parent, const TensorInfo& info)
+ : m_Tensor(&parent->GetTensor(), info)
+ {
+ m_OriginalHandle = parent;
+ }
+
+ arm_compute::ICLTensor& GetTensor() override { return m_Tensor; }
+ arm_compute::ICLTensor const& GetTensor() const override { return m_Tensor; }
+
+ virtual void Allocate() override {}
+ virtual void Manage() override {}
+
+ virtual const void* Map(bool blocking = true) const override
+ {
+ m_Tensor.map(blocking);
+ return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
+ }
+
+ virtual void Unmap() const override
+ {
+ m_Tensor.unmap();
+ }
+
+ virtual ITensorHandle* GetParent() const override { return nullptr; }
+
+ virtual arm_compute::DataType GetDataType() const override
+ {
+ return m_Tensor.info()->data_type();
+ }
+
+ virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
+
+ TensorShape GetStrides() const override
+ {
+ return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
+ }
+
+ TensorShape GetShape() const override
+ {
+ return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
+ }
+
+private:
+ // Only used for testing
+ void CopyOutTo(void* memory) const override
+ {
+ const_cast<ClTensorHandleDecorator*>(this)->Map(true);
+ switch(this->GetDataType())
+ {
+ case arm_compute::DataType::F32:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<float*>(memory));
+ break;
+ case arm_compute::DataType::U8:
+ case arm_compute::DataType::QASYMM8:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<uint8_t*>(memory));
+ break;
+ case arm_compute::DataType::F16:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<armnn::Half*>(memory));
+ break;
+ case arm_compute::DataType::QSYMM8:
+ case arm_compute::DataType::QSYMM8_PER_CHANNEL:
+ case arm_compute::DataType::QASYMM8_SIGNED:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<int8_t*>(memory));
+ break;
+ case arm_compute::DataType::S16:
+ case arm_compute::DataType::QSYMM16:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<int16_t*>(memory));
+ break;
+ case arm_compute::DataType::S32:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<int32_t*>(memory));
+ break;
+ default:
+ {
+ throw armnn::UnimplementedException();
+ }
+ }
+ const_cast<ClTensorHandleDecorator*>(this)->Unmap();
+ }
+
+ // Only used for testing
+ void CopyInFrom(const void* memory) override
+ {
+ this->Map(true);
+ switch(this->GetDataType())
+ {
+ case arm_compute::DataType::F32:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::U8:
+ case arm_compute::DataType::QASYMM8:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::F16:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::QSYMM8:
+ case arm_compute::DataType::QSYMM8_PER_CHANNEL:
+ case arm_compute::DataType::QASYMM8_SIGNED:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::S16:
+ case arm_compute::DataType::QSYMM16:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::S32:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
+ this->GetTensor());
+ break;
+ default:
+ {
+ throw armnn::UnimplementedException();
+ }
+ }
+ this->Unmap();
+ }
+
+ mutable ClTensorDecorator m_Tensor;
+ IClTensorHandle* m_OriginalHandle = nullptr;
+};
+
} // namespace armnn
diff --git a/src/backends/cl/backend.mk b/src/backends/cl/backend.mk
index 03f1a9540d..f4b9fac740 100644
--- a/src/backends/cl/backend.mk
+++ b/src/backends/cl/backend.mk
@@ -23,6 +23,7 @@ BACKEND_SOURCES := \
ClImportTensorHandleFactory.cpp \
ClLayerSupport.cpp \
ClRegistryInitializer.cpp \
+ ClTensorHandle.cpp \
ClTensorHandleFactory.cpp \
ClWorkloadFactory.cpp \
OpenClTimer.cpp \
diff --git a/src/backends/cl/test/ClEndToEndTests.cpp b/src/backends/cl/test/ClEndToEndTests.cpp
index 091526fd2b..6342cbc48a 100644
--- a/src/backends/cl/test/ClEndToEndTests.cpp
+++ b/src/backends/cl/test/ClEndToEndTests.cpp
@@ -24,6 +24,7 @@
#include <backendsCommon/test/ReshapeEndToEndTestImpl.hpp>
#include <backendsCommon/test/SpaceToDepthEndToEndTestImpl.hpp>
#include <backendsCommon/test/SplitterEndToEndTestImpl.hpp>
+#include <backendsCommon/test/SubgraphUtilsTest.hpp>
#include <backendsCommon/test/TransposeConvolution2dEndToEndTestImpl.hpp>
#include <backendsCommon/test/TransposeEndToEndTestImpl.hpp>
@@ -60,18 +61,18 @@ TEST_CASE("ClAdditionEndToEndUint8Test")
}
// Power
-TEST_CASE("RefPowerEndToEndTestFloat32")
+TEST_CASE("ClPowerEndToEndTestFloat32")
{
ElementwiseBinarySimpleEndToEnd<armnn::DataType::Float32>(clDefaultBackends, BinaryOperation::Power);
}
// SqDiff
-TEST_CASE("RefSquaredDifferenceEndToEndTestFloat32")
+TEST_CASE("ClSquaredDifferenceEndToEndTestFloat32")
{
ElementwiseBinarySimpleEndToEnd<armnn::DataType::Float32>(clDefaultBackends, BinaryOperation::SqDiff);
}
-TEST_CASE("RefSquaredDifferenceEndToEndTestUint8")
+TEST_CASE("ClSquaredDifferenceEndToEndTestUint8")
{
ElementwiseBinarySimpleEndToEnd<armnn::DataType::QAsymmU8>(clDefaultBackends, BinaryOperation::SqDiff);
}
diff --git a/src/backends/neon/CMakeLists.txt b/src/backends/neon/CMakeLists.txt
index 16164de3fb..5934221ec1 100644
--- a/src/backends/neon/CMakeLists.txt
+++ b/src/backends/neon/CMakeLists.txt
@@ -16,6 +16,7 @@ if(ARMCOMPUTENEON)
NeonLayerSupport.hpp
NeonRegistryInitializer.cpp
NeonTensorHandle.hpp
+ NeonTensorHandle.cpp
NeonTensorHandleFactory.cpp
NeonTensorHandleFactory.hpp
NeonTimer.hpp
diff --git a/src/backends/neon/NeonBackend.cpp b/src/backends/neon/NeonBackend.cpp
index cea2aa3eba..098b1ff109 100644
--- a/src/backends/neon/NeonBackend.cpp
+++ b/src/backends/neon/NeonBackend.cpp
@@ -505,9 +505,39 @@ OptimizationViews NeonBackend::OptimizeSubgraphView(const SubgraphView& subgraph
untouched.erase(baseLayer->GetGuid());
}
}
+
+ // Remove Reshape where possible
+ if (base.GetType() == LayerType::Reshape)
+ {
+ ReshapeLayer* baseLayer = PolymorphicDowncast<ReshapeLayer*>(&base);
+ Layer& parentLayer = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();
+
+ // Cannot currently remove the Reshape if it's connected to any layer that has an NCHW layout
+ if (IsNCHW(parentLayer))
+ {
+ continue;
+ }
+ bool isNCHW = false;
+
+ for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i)
+ {
+ Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer();
+
+ if (IsNCHW(nextLayer))
+ {
+ isNCHW = true;
+ break;
+ }
+ }
+ if (isNCHW)
+ {
+ continue;
+ }
+ RemoveReshapeLayer(baseLayer, untouched, optimizationViews);
+ }
}
- if (optimizationViews.GetSubstitutions().empty())
+ if (optimizationViews.GetSubstitutions().empty() && optimizationViews.GetDeletedSubgraphs().empty())
{
optimizationViews.AddUntouchedSubgraph(SubgraphView(subgraph));
}
diff --git a/src/backends/neon/NeonTensorHandle.cpp b/src/backends/neon/NeonTensorHandle.cpp
new file mode 100644
index 0000000000..819805aa59
--- /dev/null
+++ b/src/backends/neon/NeonTensorHandle.cpp
@@ -0,0 +1,47 @@
+//
+// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "NeonTensorHandle.hpp"
+
+#include <aclCommon/ArmComputeTensorUtils.hpp>
+
+namespace armnn
+{
+std::shared_ptr<ITensorHandle> NeonTensorHandle::DecorateTensorHandle(const TensorInfo &tensorInfo)
+{
+ auto* parent = const_cast<NeonTensorHandle*>(this);
+ auto decorated = std::make_shared<NeonTensorHandleDecorator>(parent, tensorInfo);
+ m_Decorated.emplace_back(decorated);
+ return decorated;
+}
+
+NeonTensorDecorator::NeonTensorDecorator()
+ : m_Original(nullptr), m_TensorInfo()
+{
+}
+
+NeonTensorDecorator::NeonTensorDecorator(arm_compute::ITensor *parent, const TensorInfo& tensorInfo)
+ : m_Original(nullptr), m_TensorInfo()
+{
+ m_TensorInfo = armcomputetensorutils::BuildArmComputeTensorInfo(tensorInfo);
+ m_Original = parent;
+}
+
+arm_compute::ITensorInfo *NeonTensorDecorator::info() const
+{
+ return &m_TensorInfo;
+}
+
+arm_compute::ITensorInfo *NeonTensorDecorator::info()
+{
+ return &m_TensorInfo;
+}
+
+uint8_t *NeonTensorDecorator::buffer() const
+{
+ return m_Original->buffer();
+}
+
+}
\ No newline at end of file
diff --git a/src/backends/neon/NeonTensorHandle.hpp b/src/backends/neon/NeonTensorHandle.hpp
index fcae77cdaa..e5f210773d 100644
--- a/src/backends/neon/NeonTensorHandle.hpp
+++ b/src/backends/neon/NeonTensorHandle.hpp
@@ -1,7 +1,8 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
+
#pragma once
#include <BFloat16.hpp>
@@ -19,9 +20,11 @@
#include <arm_compute/runtime/SubTensor.h>
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Coordinates.h>
+#include "armnn/TypesUtils.hpp"
namespace armnn
{
+class NeonTensorHandleDecorator;
class NeonTensorHandle : public IAclTensorHandle
{
@@ -125,7 +128,7 @@ public:
virtual bool Import(void* memory, MemorySource source) override
{
- if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+ if (m_ImportFlags& static_cast<MemorySourceFlags>(source))
{
if (source == MemorySource::Malloc && m_IsImportEnabled)
{
@@ -181,6 +184,8 @@ public:
return false;
}
+ virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo& tensorInfo) override;
+
private:
// Only used for testing
void CopyOutTo(void* memory) const override
@@ -275,6 +280,7 @@ private:
bool m_Imported;
bool m_IsImportEnabled;
const uintptr_t m_TypeAlignment;
+ std::vector<std::shared_ptr<NeonTensorHandleDecorator>> m_Decorated;
};
class NeonSubTensorHandle : public IAclTensorHandle
@@ -283,7 +289,7 @@ public:
NeonSubTensorHandle(IAclTensorHandle* parent,
const arm_compute::TensorShape& shape,
const arm_compute::Coordinates& coords)
- : m_Tensor(&parent->GetTensor(), shape, coords)
+ : m_Tensor(&parent->GetTensor(), shape, coords, true)
{
parentHandle = parent;
}
@@ -319,6 +325,11 @@ public:
return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
}
+ virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo&) override
+ {
+ return nullptr;
+ };
+
private:
// Only used for testing
void CopyOutTo(void* memory) const override
@@ -394,4 +405,155 @@ private:
ITensorHandle* parentHandle = nullptr;
};
+/// NeonTensorDecorator wraps an existing Neon tensor allowing us to override the TensorInfo for it
+class NeonTensorDecorator : public arm_compute::ITensor
+{
+public:
+ NeonTensorDecorator();
+
+ NeonTensorDecorator(arm_compute::ITensor* original, const TensorInfo& info);
+
+ ~NeonTensorDecorator() = default;
+
+ NeonTensorDecorator(const NeonTensorDecorator&) = delete;
+
+ NeonTensorDecorator& operator=(const NeonTensorDecorator&) = delete;
+
+ NeonTensorDecorator(NeonTensorDecorator&&) = default;
+
+ NeonTensorDecorator& operator=(NeonTensorDecorator&&) = default;
+
+ // Inherited methods overridden:
+ arm_compute::ITensorInfo* info() const override;
+
+ arm_compute::ITensorInfo* info() override;
+
+ uint8_t* buffer() const override;
+
+private:
+ arm_compute::ITensor* m_Original;
+ mutable arm_compute::TensorInfo m_TensorInfo;
+};
+
+class NeonTensorHandleDecorator : public IAclTensorHandle
+{
+public:
+ NeonTensorHandleDecorator(IAclTensorHandle* parent, const TensorInfo& info)
+ : m_Tensor(&parent->GetTensor(), info)
+ {
+ parentHandle = parent;
+ }
+
+ arm_compute::ITensor& GetTensor() override { return m_Tensor; }
+ arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
+
+ virtual void Allocate() override {}
+ virtual void Manage() override {}
+
+ virtual ITensorHandle* GetParent() const override { return nullptr; }
+
+ virtual arm_compute::DataType GetDataType() const override
+ {
+ return m_Tensor.info()->data_type();
+ }
+
+ virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
+
+ virtual const void* Map(bool /* blocking = true */) const override
+ {
+ return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
+ }
+ virtual void Unmap() const override {}
+
+ TensorShape GetStrides() const override
+ {
+ return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
+ }
+
+ TensorShape GetShape() const override
+ {
+ return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
+ }
+
+ virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo&) override
+ {
+ return nullptr;
+ };
+
+private:
+ // Only used for testing
+ void CopyOutTo(void* memory) const override
+ {
+ switch (this->GetDataType())
+ {
+ case arm_compute::DataType::F32:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<float*>(memory));
+ break;
+ case arm_compute::DataType::U8:
+ case arm_compute::DataType::QASYMM8:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<uint8_t*>(memory));
+ break;
+ case arm_compute::DataType::QSYMM8:
+ case arm_compute::DataType::QASYMM8_SIGNED:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<int8_t*>(memory));
+ break;
+ case arm_compute::DataType::S16:
+ case arm_compute::DataType::QSYMM16:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<int16_t*>(memory));
+ break;
+ case arm_compute::DataType::S32:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<int32_t*>(memory));
+ break;
+ default:
+ {
+ throw armnn::UnimplementedException();
+ }
+ }
+ }
+
+ // Only used for testing
+ void CopyInFrom(const void* memory) override
+ {
+ switch (this->GetDataType())
+ {
+ case arm_compute::DataType::F32:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::U8:
+ case arm_compute::DataType::QASYMM8:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::QSYMM8:
+ case arm_compute::DataType::QASYMM8_SIGNED:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::S16:
+ case arm_compute::DataType::QSYMM16:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::S32:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
+ this->GetTensor());
+ break;
+ default:
+ {
+ throw armnn::UnimplementedException();
+ }
+ }
+ }
+
+ NeonTensorDecorator m_Tensor;
+ ITensorHandle* parentHandle = nullptr;
+};
+
+
} // namespace armnn
diff --git a/src/backends/neon/backend.mk b/src/backends/neon/backend.mk
index e2439eecb7..4150845f58 100644
--- a/src/backends/neon/backend.mk
+++ b/src/backends/neon/backend.mk
@@ -19,6 +19,7 @@ BACKEND_SOURCES := \
NeonInterceptorScheduler.cpp \
NeonLayerSupport.cpp \
NeonRegistryInitializer.cpp \
+ NeonTensorHandle.cpp \
NeonTensorHandleFactory.cpp \
NeonTimer.cpp \
NeonWorkloadFactory.cpp \
diff --git a/src/backends/neon/test/NeonEndToEndTests.cpp b/src/backends/neon/test/NeonEndToEndTests.cpp
index 071ee415de..5672f8b993 100644
--- a/src/backends/neon/test/NeonEndToEndTests.cpp
+++ b/src/backends/neon/test/NeonEndToEndTests.cpp
@@ -25,6 +25,7 @@
#include <backendsCommon/test/ReshapeEndToEndTestImpl.hpp>
#include <backendsCommon/test/SpaceToDepthEndToEndTestImpl.hpp>
#include <backendsCommon/test/SplitterEndToEndTestImpl.hpp>
+#include <backendsCommon/test/SubgraphUtilsTest.hpp>
#include <backendsCommon/test/TransposeConvolution2dEndToEndTestImpl.hpp>
#include <backendsCommon/test/TransposeEndToEndTestImpl.hpp>
@@ -147,18 +148,18 @@ TEST_CASE("NeonAdditionEndToEndUint8Test")
}
// Power
-TEST_CASE("RefPowerEndToEndTestFloat32")
+TEST_CASE("NeonPowerEndToEndTestFloat32")
{
ElementwiseBinarySimpleEndToEnd<armnn::DataType::Float32>(neonDefaultBackends, BinaryOperation::Power);
}
// SqDiff
-TEST_CASE("RefSquaredDifferenceEndToEndTestFloat32")
+TEST_CASE("NeonSquaredDifferenceEndToEndTestFloat32")
{
ElementwiseBinarySimpleEndToEnd<armnn::DataType::Float32>(neonDefaultBackends, BinaryOperation::SqDiff);
}
-TEST_CASE("RefSquaredDifferenceEndToEndTestUint8")
+TEST_CASE("NeonSquaredDifferenceEndToEndTestUint8")
{
ElementwiseBinarySimpleEndToEnd<armnn::DataType::QAsymmU8>(neonDefaultBackends, BinaryOperation::SqDiff);
}
@@ -850,4 +851,19 @@ TEST_CASE("NeonQLstmEndToEndTest")
QLstmEndToEnd(neonDefaultBackends);
}
+TEST_CASE("NeonReshapeRemovalSimpleCaseEndToEnd")
+{
+ ReshapeRemovalEndToEnd<armnn::DataType::Float32>(neonDefaultBackends);
+}
+
+TEST_CASE("NeonReshapeRemovalNCHWFirstEndToEnd")
+{
+ ReshapeRemovalNCHWEndToEnd<armnn::DataType::Float32>(neonDefaultBackends, false, true);
+}
+
+TEST_CASE("NeonReshapeRemovalNCHWSecondEndToEnd")
+{
+ ReshapeRemovalNCHWEndToEnd<armnn::DataType::Float32>(neonDefaultBackends, false, false);
+}
+
}
diff --git a/src/backends/reference/RefBackend.cpp b/src/backends/reference/RefBackend.cpp
index 8c8879c8be..02749af1f9 100644
--- a/src/backends/reference/RefBackend.cpp
+++ b/src/backends/reference/RefBackend.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -16,8 +16,6 @@
#include <backendsCommon/DefaultAllocator.hpp>
#include <backendsCommon/SubgraphUtils.hpp>
-#include <Optimizer.hpp>
-
namespace armnn
{
@@ -116,9 +114,16 @@ OptimizationViews RefBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
}
}
}
+
+ // Remove Reshape where possible
+ if (base.GetType() == LayerType::Reshape)
+ {
+ ReshapeLayer* baseLayer = PolymorphicDowncast<ReshapeLayer*>(&base);
+ RemoveReshapeLayer(baseLayer, untouched, optimizationViews);
+ }
}
- if (optimizationViews.GetSubstitutions().empty())
+ if (optimizationViews.GetSubstitutions().empty() && optimizationViews.GetDeletedSubgraphs().empty())
{
optimizationViews.AddUntouchedSubgraph(SubgraphView(subgraph));
}
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp
index dbfa374945..cce992c947 100644
--- a/src/backends/reference/RefTensorHandle.cpp
+++ b/src/backends/reference/RefTensorHandle.cpp
@@ -1,29 +1,40 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2019-2023 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
+
#include "RefTensorHandle.hpp"
namespace armnn
{
-RefTensorHandle::RefTensorHandle(const TensorInfo &tensorInfo, std::shared_ptr<RefMemoryManager> &memoryManager):
+RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<RefMemoryManager>& memoryManager):
m_TensorInfo(tensorInfo),
m_MemoryManager(memoryManager),
m_Pool(nullptr),
m_UnmanagedMemory(nullptr),
- m_ImportedMemory(nullptr)
+ m_ImportedMemory(nullptr),
+ m_Decorated()
{
-
}
RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo)
: m_TensorInfo(tensorInfo),
m_Pool(nullptr),
m_UnmanagedMemory(nullptr),
- m_ImportedMemory(nullptr)
+ m_ImportedMemory(nullptr),
+ m_Decorated()
{
+}
+RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo, const RefTensorHandle& parent)
+ : m_TensorInfo(tensorInfo),
+ m_MemoryManager(parent.m_MemoryManager),
+ m_Pool(parent.m_Pool),
+ m_UnmanagedMemory(parent.m_UnmanagedMemory),
+ m_ImportedMemory(parent.m_ImportedMemory),
+ m_Decorated()
+{
}
RefTensorHandle::~RefTensorHandle()
@@ -139,4 +150,52 @@ bool RefTensorHandle::CanBeImported(void *memory, MemorySource source)
return false;
}
+std::shared_ptr<ITensorHandle> RefTensorHandle::DecorateTensorHandle(const TensorInfo& tensorInfo)
+{
+ auto decorated = std::make_shared<RefTensorHandleDecorator>(tensorInfo, *this);
+ m_Decorated.emplace_back(decorated);
+ return decorated;
+}
+
+RefTensorHandleDecorator::RefTensorHandleDecorator(const TensorInfo& tensorInfo, const RefTensorHandle& parent)
+: RefTensorHandle(tensorInfo)
+, m_TensorInfo(tensorInfo)
+, m_Parent(parent)
+{
+}
+
+void RefTensorHandleDecorator::Manage()
+{
+}
+
+void RefTensorHandleDecorator::Allocate()
+{
+}
+
+const void* RefTensorHandleDecorator::Map(bool unused) const
+{
+ return m_Parent.Map(unused);
+}
+
+MemorySourceFlags RefTensorHandleDecorator::GetImportFlags() const
+{
+ return static_cast<MemorySourceFlags>(MemorySource::Malloc);
+}
+
+bool RefTensorHandleDecorator::Import(void*, MemorySource )
+{
+ return false;
+}
+
+bool RefTensorHandleDecorator::CanBeImported(void* , MemorySource)
+{
+ return false;
+}
+
+std::shared_ptr<ITensorHandle> RefTensorHandleDecorator::DecorateTensorHandle(const TensorInfo&)
+{
+ return nullptr;
+}
+
+
}
diff --git a/src/backends/reference/RefTensorHandle.hpp b/src/backends/reference/RefTensorHandle.hpp
index b4dedd5e77..128f623cd3 100644
--- a/src/backends/reference/RefTensorHandle.hpp
+++ b/src/backends/reference/RefTensorHandle.hpp
@@ -1,7 +1,8 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2019-2023 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
+
#pragma once
#include <armnn/backends/TensorHandle.hpp>
@@ -11,14 +12,17 @@
namespace armnn
{
+class RefTensorHandleDecorator;
// An implementation of ITensorHandle with simple "bump the pointer" memory-management behaviour
class RefTensorHandle : public ITensorHandle
{
public:
- RefTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<RefMemoryManager> &memoryManager);
+ RefTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<RefMemoryManager>& memoryManager);
RefTensorHandle(const TensorInfo& tensorInfo);
+ RefTensorHandle(const TensorInfo& tensorInfo, const RefTensorHandle& parent);
+
~RefTensorHandle();
virtual void Manage() override;
@@ -56,6 +60,8 @@ public:
virtual bool Import(void* memory, MemorySource source) override;
virtual bool CanBeImported(void* memory, MemorySource source) override;
+ virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo& tensorInfo) override;
+
private:
// Only used for testing
void CopyOutTo(void*) const override;
@@ -68,10 +74,86 @@ private:
TensorInfo m_TensorInfo;
- std::shared_ptr<RefMemoryManager> m_MemoryManager;
+ mutable std::shared_ptr<RefMemoryManager> m_MemoryManager;
RefMemoryManager::Pool* m_Pool;
mutable void* m_UnmanagedMemory;
void* m_ImportedMemory;
+ std::vector<std::shared_ptr<RefTensorHandleDecorator>> m_Decorated;
+};
+
+class RefTensorHandleDecorator : public RefTensorHandle
+{
+public:
+ RefTensorHandleDecorator(const TensorInfo& tensorInfo, const RefTensorHandle& parent);
+
+ ~RefTensorHandleDecorator() = default;
+
+ virtual void Manage() override;
+
+ virtual void Allocate() override;
+
+ virtual ITensorHandle* GetParent() const override
+ {
+ return nullptr;
+ }
+
+ virtual const void* Map(bool /* blocking = true */) const override;
+ using ITensorHandle::Map;
+
+ virtual void Unmap() const override
+ {}
+
+ TensorShape GetStrides() const override
+ {
+ return GetUnpaddedTensorStrides(m_TensorInfo);
+ }
+
+ TensorShape GetShape() const override
+ {
+ return m_TensorInfo.GetShape();
+ }
+
+ const TensorInfo& GetTensorInfo() const
+ {
+ return m_TensorInfo;
+ }
+
+ virtual MemorySourceFlags GetImportFlags() const override;
+
+ virtual bool Import(void* memory, MemorySource source) override;
+ virtual bool CanBeImported(void* memory, MemorySource source) override;
+
+ virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo& tensorInfo) override;
+
+ /// Map the tensor data for access. Must be paired with call to Unmap().
+ /// \param blocking hint to block the calling thread until all other accesses are complete. (backend dependent)
+ /// \return pointer to the first element of the mapped data.
+ void* Map(bool blocking=true)
+ {
+ return const_cast<void*>(static_cast<const ITensorHandle*>(this)->Map(blocking));
+ }
+
+ /// Unmap the tensor data that was previously mapped with call to Map().
+ void Unmap()
+ {
+ return static_cast<const ITensorHandle*>(this)->Unmap();
+ }
+
+ /// Testing support to be able to verify and set tensor data content
+ void CopyOutTo(void* /* memory */) const override
+ {};
+
+ void CopyInFrom(const void* /* memory */) override
+ {};
+
+ /// Unimport externally allocated memory
+ void Unimport() override
+ {};
+
+private:
+ TensorInfo m_TensorInfo;
+ const RefTensorHandle& m_Parent;
};
}
+
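The decorator added above shares the parent handle's storage and overrides only the reported TensorInfo, which is what allows a Reshape to be dropped without inserting a copy. A minimal sketch of that contract, assuming the unmanaged single-TensorInfo constructor path and compilation alongside the reference backend sources:

    // Sketch only -- illustrates the decorator contract; not part of the patch.
    #include "RefTensorHandle.hpp"

    void DecorateExample()
    {
        using namespace armnn;

        // Unmanaged handle sized for a 2x3 Float32 tensor.
        RefTensorHandle parent(TensorInfo({ 2, 3 }, DataType::Float32));
        parent.Allocate();

        // Override only the reported shape; the storage stays with the parent handle.
        TensorInfo flattened({ 6 }, DataType::Float32);
        std::shared_ptr<ITensorHandle> view = parent.DecorateTensorHandle(flattened);

        // view->GetShape() now reports { 6 }, while view->Map() resolves to the parent's
        // buffer (RefTensorHandleDecorator::Map forwards to the parent), so the removed
        // Reshape costs nothing at execution time.
    }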
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 894dd75ef2..13ac7fc233 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -1314,4 +1314,52 @@ TEST_CASE("ReplaceFunctionsfromUint8toFloat16ActivationWorkload")
RefCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::QAsymmU8>();
}
+bool TestRefTensorHandleInfo(armnn::RefTensorHandle* handle, const armnn::TensorInfo& expectedInfo)
+{
+ const TensorInfo handleInfo = handle->GetTensorInfo();
+ const TensorInfo expectedAclInfo = expectedInfo;
+
+ if (handleInfo.GetDataType() != expectedAclInfo.GetDataType())
+ {
+ return false;
+ }
+
+ if (handleInfo.GetNumDimensions() != expectedAclInfo.GetNumDimensions())
+ {
+ return false;
+ }
+
+ for (unsigned int d = 0; d < expectedAclInfo.GetNumDimensions(); ++d)
+ {
+ if (handleInfo.GetShape()[d] != expectedAclInfo.GetShape()[d])
+ {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+TEST_CASE("RefCreateSplitterWorkload")
+{
+ Graph graph;
+ RefWorkloadFactory factory = GetFactory();
+
+ auto workload = CreateSplitterWorkloadTest<RefSplitterWorkload, DataType::Float32>(factory, graph);
+
+ // Checks that outputs are as we expect them (see definition of CreateSplitterWorkloadTest).
+ SplitterQueueDescriptor queueDescriptor = workload->GetData();
+ auto inputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
+ CHECK(TestRefTensorHandleInfo(inputHandle, TensorInfo({5, 7, 7}, DataType::Float32)));
+
+ auto outputHandle0 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
+ CHECK(TestRefTensorHandleInfo(outputHandle0, TensorInfo({1, 7, 7}, DataType::Float32)));
+
+ auto outputHandle1 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[1]);
+ CHECK(TestRefTensorHandleInfo(outputHandle1, TensorInfo({2, 7, 7}, DataType::Float32)));
+
+ auto outputHandle2 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[2]);
+ CHECK(TestRefTensorHandleInfo(outputHandle2, TensorInfo({2, 7, 7}, DataType::Float32)));
+}
+
}
diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp
index 4bb3f2947a..eb2aabcd1e 100644
--- a/src/backends/reference/test/RefEndToEndTests.cpp
+++ b/src/backends/reference/test/RefEndToEndTests.cpp
@@ -35,6 +35,7 @@
#include <backendsCommon/test/SpaceToDepthEndToEndTestImpl.hpp>
#include <backendsCommon/test/SplitterEndToEndTestImpl.hpp>
#include <backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp>
+#include <backendsCommon/test/SubgraphUtilsTest.hpp>
#include <backendsCommon/test/TransposeConvolution2dEndToEndTestImpl.hpp>
#include <backendsCommon/test/TransposeEndToEndTestImpl.hpp>
@@ -1618,6 +1619,22 @@ TEST_CASE("RefSquaredDifferenceEndToEndTestUint8")
{
ElementwiseBinarySimpleEndToEnd<armnn::DataType::QAsymmU8>(defaultBackends, BinaryOperation::SqDiff);
}
+
#endif
+// Backend Optimization Tests
+TEST_CASE("RefReshapeRemovalSimpleCaseEndToEnd")
+{
+ ReshapeRemovalEndToEnd<armnn::DataType::Float32>(defaultBackends);
+}
+
+TEST_CASE("RefReshapeRemovalNCHWFirstEndToEnd")
+{
+ ReshapeRemovalNCHWEndToEnd<armnn::DataType::Float32>(defaultBackends, true, true);
+}
+
+TEST_CASE("RefReshapeRemovalNCHWSecondEndToEnd")
+{
+ ReshapeRemovalNCHWEndToEnd<armnn::DataType::Float32>(defaultBackends, true, false);
+}
}