Diffstat (limited to 'src/backends/reference')
 src/backends/reference/RefLayerSupport.cpp                        |   1
 src/backends/reference/test/RefEndToEndTests.cpp                  | 324
 src/backends/reference/test/RefLayerTests.cpp                     |  37
 src/backends/reference/workloads/Broadcast.cpp                    |  24
 src/backends/reference/workloads/ElementwiseFunction.cpp          |   4
 src/backends/reference/workloads/Maximum.hpp                      |  22
 src/backends/reference/workloads/Pad.cpp                          |  82
 src/backends/reference/workloads/RefCastWorkload.cpp              |  67
 src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp |  20
 src/backends/reference/workloads/Slice.cpp                        |  26
 src/backends/reference/workloads/StridedSlice.cpp                 | 103
 11 files changed, 643 insertions(+), 67 deletions(-)
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 3e04a19df4..7b7d1563bc 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -1547,6 +1547,7 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
DataType::Float16,
DataType::QAsymmS8,
DataType::QAsymmU8,
+ DataType::QSymmS8,
DataType::QSymmS16
};
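
QSymmS8 is per-tensor symmetric quantization: the zero point is fixed at 0 and only a scale applies, which is why FullyConnected can support it alongside QAsymmS8 with no new arithmetic. A minimal standalone sketch of the symmetric round trip (helper names are illustrative, not ArmNN API):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Symmetric int8: no zero point, and the range is clamped to +/-127.
    int8_t QuantizeSymmS8(float value, float scale)
    {
        int q = static_cast<int>(std::round(value / scale));
        return static_cast<int8_t>(std::clamp(q, -127, 127));
    }

    float DequantizeSymmS8(int8_t value, float scale)
    {
        return static_cast<float>(value) * scale; // no offset term, unlike QAsymmS8
    }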
diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp
index 6f57236dd5..67f805cf5e 100644
--- a/src/backends/reference/test/RefEndToEndTests.cpp
+++ b/src/backends/reference/test/RefEndToEndTests.cpp
@@ -28,6 +28,7 @@
#include <backendsCommon/test/GatherNdEndToEndTestImpl.hpp>
#include <backendsCommon/test/InstanceNormalizationEndToEndTestImpl.hpp>
#include <backendsCommon/test/LogSoftmaxEndToEndTestImpl.hpp>
+#include <backendsCommon/test/PadEndToEndTestImpl.hpp>
#include "backendsCommon/test/Pooling2dEndToEndTestImpl.hpp"
#include <backendsCommon/test/PreluEndToEndTestImpl.hpp>
#include <backendsCommon/test/QLstmEndToEndTestImpl.hpp>
@@ -93,6 +94,56 @@ TEST_CASE("RefRsqrtEndToEndTestInt16")
UnaryOperation::Rsqrt);
}
+// Exp
+TEST_CASE("RefExpEndToEndTestFloat32")
+{
+ ElementwiseUnarySimpleEndToEnd<armnn::DataType::Float32>(defaultBackends,
+ UnaryOperation::Exp);
+}
+
+TEST_CASE("RefExpEndToEndTestUint8")
+{
+ ElementwiseUnarySimpleEndToEnd<armnn::DataType::QAsymmU8>(defaultBackends,
+ UnaryOperation::Exp);
+}
+
+TEST_CASE("RefExpEndToEndTestInt8")
+{
+ ElementwiseUnarySimpleEndToEnd<armnn::DataType::QAsymmS8>(defaultBackends,
+ UnaryOperation::Exp);
+}
+
+TEST_CASE("RefExpEndToEndTestInt16")
+{
+ ElementwiseUnarySimpleEndToEnd<armnn::DataType::QSymmS16>(defaultBackends,
+ UnaryOperation::Exp);
+}
+
+// Log
+TEST_CASE("RefLogEndToEndTestFloat32")
+{
+ ElementwiseUnarySimpleEndToEnd<armnn::DataType::Float32>(defaultBackends,
+ UnaryOperation::Log);
+}
+
+TEST_CASE("RefLogEndToEndTestUint8")
+{
+ ElementwiseUnarySimpleEndToEnd<armnn::DataType::QAsymmU8>(defaultBackends,
+ UnaryOperation::Log);
+}
+
+TEST_CASE("RefLogEndToEndTestSint8")
+{
+ ElementwiseUnarySimpleEndToEnd<armnn::DataType::QAsymmS8>(defaultBackends,
+ UnaryOperation::Log);
+}
+
+TEST_CASE("RefLogEndToEndTestInt16")
+{
+ ElementwiseUnarySimpleEndToEnd<armnn::DataType::QSymmS16>(defaultBackends,
+ UnaryOperation::Log);
+}
+
// Addition
TEST_CASE("RefAdditionEndtoEndFloat32")
{
@@ -501,6 +552,46 @@ TEST_CASE("RefBatchMatMulEndToEndInt8Test")
BatchMatMulEndToEnd<armnn::DataType::QAsymmS8>(defaultBackends);
}
+TEST_CASE("RefBatchMatMulNoTransposeEndToEndFloat32Test")
+{
+ BatchMatMulNoTransposeEndToEnd<armnn::DataType::Float32>(defaultBackends);
+}
+
+TEST_CASE("RefBatchMatMulNoTransposeEndToEndInt8Test")
+{
+ BatchMatMulNoTransposeEndToEnd<armnn::DataType::QAsymmS8>(defaultBackends);
+}
+
+TEST_CASE("RefBatchMatMulSimple4DEndToEndFloat32Test")
+{
+ BatchMatMulSimple4DEndToEnd<armnn::DataType::Float32>(defaultBackends);
+}
+
+TEST_CASE("RefBatchMatMulSimple4DEndToEndInt8Test")
+{
+ BatchMatMulSimple4DEndToEnd<armnn::DataType::QAsymmS8>(defaultBackends);
+}
+
+TEST_CASE("RefBatchMatMul4DEndToEndFloat32Test")
+{
+ BatchMatMul4DEndToEnd<armnn::DataType::Float32>(defaultBackends);
+}
+
+TEST_CASE("RefBatchMatMul4DEndToEndInt8Test")
+{
+ BatchMatMul4DEndToEnd<armnn::DataType::QAsymmS8>(defaultBackends);
+}
+
+TEST_CASE("RefBatchMatMulNotSquareEndToEndFloat32Test")
+{
+ BatchMatMulNotSquareEndToEnd<armnn::DataType::Float32>(defaultBackends);
+}
+
+TEST_CASE("RefBatchMatMulNotSquareEndToEndInt8Test")
+{
+ BatchMatMulNotSquareEndToEnd<armnn::DataType::QAsymmS8>(defaultBackends);
+}
+
TEST_CASE("RefBatchToSpaceNdEndToEndFloat32NHWCTest")
{
BatchToSpaceNdEndToEnd<armnn::DataType::Float32>(defaultBackends, armnn::DataLayout::NHWC);
@@ -669,6 +760,13 @@ TEST_CASE("RefDepthwiseConvolution2dEndtoEndFloat32Test")
armnn::DataLayout::NHWC);
}
+TEST_CASE("RefDepthwiseConvolution2dEndtoEndFloat32TestBiasDisabled")
+{
+ DepthwiseConvolution2dEndToEnd<armnn::DataType::Float32, armnn::DataType::Float32>(defaultBackends,
+ armnn::DataLayout::NHWC,
+ false);
+}
+
TEST_CASE("RefFillEndToEndTest")
{
FillEndToEnd<armnn::DataType::Float32>(defaultBackends);
@@ -684,8 +782,67 @@ TEST_CASE("RefFillEndToEndTestInt32")
FillEndToEnd<armnn::DataType::Signed32>(defaultBackends);
}
+// Fully Connected
TEST_CASE("RefFullyConnectedEndToEndTestFloat32")
{
+ FullyConnectedConstantWeightsAndBiasEndToEnd<armnn::DataType::Float32>(defaultBackends, true);
+}
+
+TEST_CASE("RefFullyConnectedEndToEndTestNoBiasFloat32")
+{
+ FullyConnectedConstantWeightsAndBiasEndToEnd<armnn::DataType::Float32>(defaultBackends, false);
+}
+
+TEST_CASE("RefFullyConnectedEndToEndTestInt8")
+{
+ FullyConnectedConstantWeightsAndBiasEndToEnd<armnn::DataType::QAsymmS8,
+ armnn::DataType::QAsymmS8,
+ armnn::DataType::Signed32,
+ armnn::DataType::QAsymmS8>(defaultBackends, true);
+}
+
+TEST_CASE("RefFullyConnectedEndToEndTestNoBiasInt8")
+{
+ FullyConnectedConstantWeightsAndBiasEndToEnd<armnn::DataType::QAsymmS8,
+ armnn::DataType::QAsymmS8,
+ armnn::DataType::Signed32,
+ armnn::DataType::QAsymmS8>(defaultBackends, false);
+}
+
+TEST_CASE("RefFullyConnectedEndToEndTestInt8Symm")
+{
+ FullyConnectedConstantWeightsAndBiasEndToEnd<armnn::DataType::QSymmS8,
+ armnn::DataType::QSymmS8,
+ armnn::DataType::Signed32,
+ armnn::DataType::QSymmS8>(defaultBackends, true);
+}
+
+TEST_CASE("RefFullyConnectedEndToEndTestNoBiasInt8Symm")
+{
+ FullyConnectedConstantWeightsAndBiasEndToEnd<armnn::DataType::QSymmS8,
+ armnn::DataType::QSymmS8,
+ armnn::DataType::Signed32,
+ armnn::DataType::QSymmS8>(defaultBackends, false);
+}
+
+TEST_CASE("RefFullyConnectedEndToEndTestUint8")
+{
+ FullyConnectedConstantWeightsAndBiasEndToEnd<armnn::DataType::QAsymmU8,
+ armnn::DataType::QAsymmU8,
+ armnn::DataType::Signed32,
+ armnn::DataType::QAsymmU8>(defaultBackends, true);
+}
+
+TEST_CASE("RefFullyConnectedEndToEndTestNoBiasUint8")
+{
+ FullyConnectedConstantWeightsAndBiasEndToEnd<armnn::DataType::QAsymmU8,
+ armnn::DataType::QAsymmU8,
+ armnn::DataType::Signed32,
+ armnn::DataType::QAsymmU8>(defaultBackends, false);
+}
+
+TEST_CASE("RefFullyConnectedEndToEndTestNoBiasOtherFloat32")
+{
FullyConnectedWithDynamicWeightsEndToEnd<armnn::DataType::Float32>(defaultBackends);
}
@@ -724,6 +881,7 @@ TEST_CASE("RefFullyConnectedEndToEndTestBiasDisabledConnectBias")
FullyConnectedErrorChecking<armnn::DataType::Float32>(defaultBackends, true, false, false, true, true);
}
+// Gather
TEST_CASE("RefGatherFloatTest")
{
GatherEndToEnd<armnn::DataType::Float32>(defaultBackends);
@@ -1095,6 +1253,32 @@ TEST_CASE("RefReLuEndToEndTestQSymmS16")
ActivationEndToEndTest<armnn::DataType::QSymmS16>(defaultBackends, ActivationFunction::ReLu);
}
+// GeLu
+TEST_CASE("RefGeluEndToEndTestFloat32")
+{
+ ActivationEndToEndTest<armnn::DataType::Float32>(defaultBackends, ActivationFunction::Gelu);
+}
+
+TEST_CASE("RefGeluEndToEndTestFloat16")
+{
+ ActivationEndToEndTest<armnn::DataType::Float16>(defaultBackends, ActivationFunction::Gelu);
+}
+
+TEST_CASE("RefGeluEndToEndTestQAsymmS8")
+{
+ ActivationEndToEndTest<armnn::DataType::QAsymmS8>(defaultBackends, ActivationFunction::Gelu);
+}
+
+TEST_CASE("RefGeluEndToEndTestQAsymmU8")
+{
+ ActivationEndToEndTest<armnn::DataType::QAsymmU8>(defaultBackends, ActivationFunction::Gelu);
+}
+
+TEST_CASE("RefGeluEndToEndTestQSymmS16")
+{
+ ActivationEndToEndTest<armnn::DataType::QSymmS16>(defaultBackends, ActivationFunction::Gelu);
+}
+
// BoundedReLu
TEST_CASE("RefBoundedReLuEndToEndTestFloat32")
{
@@ -1322,6 +1506,27 @@ TEST_CASE("RefMaxPool2DThreeLayerEndtoEndTestFloat32")
MaxPool2dThreeLayerEndToEnd<DataType::Float32>(defaultBackends);
}
+// Pad
+TEST_CASE("RefPadEndToEndFloat32Test")
+{
+ PadEndToEnd<armnn::DataType::Float32>(defaultBackends);
+}
+
+TEST_CASE("RefPadEndToEndInt8Test")
+{
+ PadEndToEnd<armnn::DataType::QAsymmS8>(defaultBackends);
+}
+
+TEST_CASE("RefPad4dEndToEndFloat32Test")
+{
+ Pad4dEndToEnd<armnn::DataType::Float32>(defaultBackends);
+}
+
+TEST_CASE("RefPad4dEndToEndInt8Test")
+{
+ Pad4dEndToEnd<armnn::DataType::QAsymmS8>(defaultBackends);
+}
+
// Quantization
TEST_CASE("QuantizationEndToEndFloat32_U8Test")
{
@@ -1920,14 +2125,125 @@ TEST_CASE("RefRankEndToEndTestQSymmS8")
}
// Reduce
-TEST_CASE("RefReduceEndToEndTest")
+// Reduce Sum
+TEST_CASE("RefReduceSum2dEndtoEndTestSigned32")
+{
+ ReduceEndToEnd2d<DataType::Signed32>(defaultBackends, ReduceOperation::Sum);
+}
+
+TEST_CASE("RefReduceSum2dEndtoEndTestSigned32WithKeepDims")
+{
+ ReduceEndToEnd2d<DataType::Signed32>(defaultBackends, ReduceOperation::Sum, true);
+}
+
+TEST_CASE("RefReduceSum2dEndtoEndTestFloat16")
+{
+ ReduceEndToEnd2d<DataType::Float16>(defaultBackends, ReduceOperation::Sum);
+}
+
+TEST_CASE("RefReduceSum2dEndtoEndTestFloat16WithKeepDims")
+{
+ ReduceEndToEnd2d<DataType::Float16>(defaultBackends, ReduceOperation::Sum, true);
+}
+
+TEST_CASE("RefReduceSum2dEndtoEndTestFloat32")
+{
+ ReduceEndToEnd2d<DataType::Float32>(defaultBackends, ReduceOperation::Sum);
+}
+
+TEST_CASE("RefReduceSum2dEndtoEndTestFloat32WithKeepDims")
+{
+ ReduceEndToEnd2d<DataType::Float32>(defaultBackends, ReduceOperation::Sum, true);
+}
+
+TEST_CASE("RefReduceSum2dEndtoEndTestInt8")
+{
+ ReduceEndToEnd2d<DataType::QAsymmS8>(defaultBackends, ReduceOperation::Sum);
+}
+
+TEST_CASE("RefReduceSum2dEndtoEndTestInt8WithKeepDims")
+{
+ ReduceEndToEnd2d<DataType::QAsymmS8>(defaultBackends, ReduceOperation::Sum, true);
+}
+
+TEST_CASE("RefReduceSum3dEndtoEndTestSigned32")
+{
+ ReduceEndToEnd3d<DataType::Signed32>(defaultBackends, ReduceOperation::Sum);
+}
+
+TEST_CASE("RefReduceSum3dEndtoEndTestSigned32WithKeepDims")
+{
+ ReduceEndToEnd3d<DataType::Signed32>(defaultBackends, ReduceOperation::Sum, true);
+}
+
+TEST_CASE("RefReduceSum3dEndtoEndTestFloat16")
+{
+ ReduceEndToEnd3d<DataType::Float16>(defaultBackends, ReduceOperation::Sum);
+}
+
+TEST_CASE("RefReduceSum3dEndtoEndTestFloat16WithKeepDims")
+{
+ ReduceEndToEnd3d<DataType::Float16>(defaultBackends, ReduceOperation::Sum, true);
+}
+
+TEST_CASE("RefReduceSum3dEndtoEndTestFloat32")
+{
+ ReduceEndToEnd3d<DataType::Float32>(defaultBackends, ReduceOperation::Sum);
+}
+
+TEST_CASE("RefReduceSum3dEndtoEndTestFloat32WithKeepDims")
+{
+ ReduceEndToEnd3d<DataType::Float32>(defaultBackends, ReduceOperation::Sum, true);
+}
+
+TEST_CASE("RefReduceSum3dEndtoEndTestInt8")
+{
+ ReduceEndToEnd3d<DataType::QAsymmS8>(defaultBackends, ReduceOperation::Sum);
+}
+
+TEST_CASE("RefReduceSum3dEndtoEndTestInt8WithKeepDims")
+{
+ ReduceEndToEnd3d<DataType::QAsymmS8>(defaultBackends, ReduceOperation::Sum, true);
+}
+
+TEST_CASE("RefReduceSum4dEndtoEndTestSigned32")
+{
+ ReduceEndToEnd4d<DataType::Signed32>(defaultBackends, ReduceOperation::Sum);
+}
+
+TEST_CASE("RefReduceSum4dEndtoEndTestSigned32WithKeepDims")
+{
+ ReduceEndToEnd4d<DataType::Signed32>(defaultBackends, ReduceOperation::Sum, true);
+}
+
+TEST_CASE("RefReduceSum4dEndtoEndTestFloat16")
+{
+ ReduceEndToEnd4d<DataType::Float16>(defaultBackends, ReduceOperation::Sum);
+}
+
+TEST_CASE("RefReduceSum4dEndtoEndTestFloat16WithKeepDims")
+{
+ ReduceEndToEnd4d<DataType::Float16>(defaultBackends, ReduceOperation::Sum, true);
+}
+
+TEST_CASE("RefReduceSum4dEndtoEndTestFloat32")
+{
+ ReduceEndToEnd4d<DataType::Float32>(defaultBackends, ReduceOperation::Sum);
+}
+
+TEST_CASE("RefReduceSum4dEndtoEndTestFloat32WithKeepDims")
+{
+ ReduceEndToEnd4d<DataType::Float32>(defaultBackends, ReduceOperation::Sum, true);
+}
+
+TEST_CASE("RefReduceSum4dEndtoEndTestInt8")
{
- ReduceEndToEnd<armnn::DataType::Float32>(defaultBackends);
+ ReduceEndToEnd4d<DataType::QAsymmS8>(defaultBackends, ReduceOperation::Sum);
}
-TEST_CASE("RefReduceEndToEndTestFloat16")
+TEST_CASE("RefReduceSum4dEndtoEndTestInt8WithKeepDims")
{
- ReduceEndToEnd<armnn::DataType::Float16>(defaultBackends);
+ ReduceEndToEnd4d<DataType::QAsymmS8>(defaultBackends, ReduceOperation::Sum, true);
}
// Reshape
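
The Reduce changes above replace the two generic float tests with Sum-specific 2d/3d/4d variants, each with an optional keep-dims form. A hedged sketch of the descriptor those helpers presumably configure (field names follow armnn::ReduceDescriptor; the axis value is illustrative):

    #include <armnn/Descriptors.hpp>

    armnn::ReduceDescriptor MakeSumDescriptor(bool keepDims)
    {
        armnn::ReduceDescriptor desc;
        desc.m_ReduceOperation = armnn::ReduceOperation::Sum;
        desc.m_vAxis           = { 1 };    // reduce across axis 1 (illustrative)
        desc.m_KeepDims        = keepDims; // the *WithKeepDims tests pass true
        return desc;
    }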
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 078338163f..eef70a9b10 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -1693,6 +1693,7 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(PadFloat322d, PadFloat322dTest)
ARMNN_AUTO_TEST_CASE_WITH_THF(PadFloat322dCustomPadding, PadFloat322dCustomPaddingTest)
ARMNN_AUTO_TEST_CASE_WITH_THF(PadFloat323d, PadFloat323dTest)
ARMNN_AUTO_TEST_CASE_WITH_THF(PadFloat324d, PadFloat324dTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(PadFloat325d, PadFloat325dTest)
ARMNN_AUTO_TEST_CASE_WITH_THF(PadUint82d, PadUint82dTest)
ARMNN_AUTO_TEST_CASE_WITH_THF(PadUint82dCustomPadding, PadUint82dCustomPaddingTest)
@@ -2239,6 +2240,30 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dFloat32, StridedSlice3dFloat32Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dReverseFloat32, StridedSlice3dReverseFloat32Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice2dFloat32, StridedSlice2dFloat32Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice2dReverseFloat32, StridedSlice2dReverseFloat32Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dNewAxisMask1Float32, StridedSlice3dNewAxisMask1Float32Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dNewAxisMask2Float32, StridedSlice3dNewAxisMask2Float32Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dNewAxisMask4Float32, StridedSlice3dNewAxisMask4Float32Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dEllipsisMask1Float32, StridedSlice3dEllipsisMask1Float32Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dEllipsisMask2Float32, StridedSlice3dEllipsisMask2Float32Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dEllipsisMask4Float32, StridedSlice3dEllipsisMask4Float32Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(
+ StridedSlice3dNewAxisMask1EllipsisMask1Float32, StridedSlice3dNewAxisMask1EllipsisMask1Float32Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(
+ StridedSlice3dNewAxisMask1EllipsisMask2Float32, StridedSlice3dNewAxisMask1EllipsisMask2Float32Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(
+ StridedSlice3dNewAxisMask1EllipsisMask4Float32, StridedSlice3dNewAxisMask1EllipsisMask4Float32Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(
+ StridedSlice3dNewAxisMask2EllipsisMask1Float32, StridedSlice3dNewAxisMask2EllipsisMask1Float32Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(
+ StridedSlice3dNewAxisMask2EllipsisMask2Float32, StridedSlice3dNewAxisMask2EllipsisMask2Float32Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(
+ StridedSlice3dNewAxisMask2EllipsisMask4Float32, StridedSlice3dNewAxisMask2EllipsisMask4Float32Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(
+ StridedSlice3dNewAxisMask4EllipsisMask1Float32, StridedSlice3dNewAxisMask4EllipsisMask1Float32Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(
+ StridedSlice3dNewAxisMask4EllipsisMask2Float32, StridedSlice3dNewAxisMask4EllipsisMask2Float32Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(
+ StridedSlice3dNewAxisMask4EllipsisMask4Float32, StridedSlice3dNewAxisMask4EllipsisMask4Float32Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice4dUint8, StridedSlice4dUint8Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice4dReverseUint8, StridedSlice4dReverseUint8Test)
@@ -2267,6 +2292,12 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dUint8, StridedSlice3dUint8Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dReverseUint8, StridedSlice3dReverseUint8Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice2dUint8, StridedSlice2dUint8Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice2dReverseUint8, StridedSlice2dReverseUint8Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dNewAxisMask1Uint8, StridedSlice3dNewAxisMask1Uint8Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dNewAxisMask2Uint8, StridedSlice3dNewAxisMask2Uint8Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dNewAxisMask4Uint8, StridedSlice3dNewAxisMask4Uint8Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dEllipsisMask1Uint8, StridedSlice3dEllipsisMask1Uint8Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dEllipsisMask2Uint8, StridedSlice3dEllipsisMask2Uint8Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dEllipsisMask4Uint8, StridedSlice3dEllipsisMask4Uint8Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice4dInt16, StridedSlice4dInt16Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice4dReverseInt16, StridedSlice4dReverseInt16Test)
@@ -2277,6 +2308,12 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dInt16, StridedSlice3dInt16Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dReverseInt16, StridedSlice3dReverseInt16Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice2dInt16, StridedSlice2dInt16Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice2dReverseInt16, StridedSlice2dReverseInt16Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dNewAxisMask1Int16, StridedSlice3dNewAxisMask1Int16Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dNewAxisMask2Int16, StridedSlice3dNewAxisMask2Int16Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dNewAxisMask4Int16, StridedSlice3dNewAxisMask4Int16Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dEllipsisMask1Int16, StridedSlice3dEllipsisMask1Int16Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dEllipsisMask2Int16, StridedSlice3dEllipsisMask2Int16Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(StridedSlice3dEllipsisMask4Int16, StridedSlice3dEllipsisMask4Int16Test)
// Debug
ARMNN_AUTO_TEST_CASE(Debug4dFloat32, Debug4dFloat32Test, /*toFile*/ false)
diff --git a/src/backends/reference/workloads/Broadcast.cpp b/src/backends/reference/workloads/Broadcast.cpp
index 24af0fc4b1..f17ec6b311 100644
--- a/src/backends/reference/workloads/Broadcast.cpp
+++ b/src/backends/reference/workloads/Broadcast.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2019 Arm Ltd. All rights reserved.
+// Copyright © 2019,2024 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -38,13 +38,31 @@ BroadcastLoop::BroadcastLoop(const TensorShape& inShape, const TensorShape& outS
unsigned int sIn = 1;
unsigned int sOut = 1;
+ // Get the difference between the output dimension and input dimension
+ const unsigned int dimDifference = numDims - inShape.GetNumDimensions();
+
for (unsigned int j = numDims - 1, k = 0; k < numDims ; k++, j--)
{
+
m_DimData[j].m_DimSize = outShape[j];
- m_DimData[j].m_Stride1 = (inShape[j] > 1) ? sIn : 0;
+ // Pretend there are extra 1-dimensional tensors prepended
+ if (dimDifference > 0 && j < dimDifference)
+ {
+ m_DimData[j].m_Stride1 = 0;
+ sIn *= 1;
+ }
+ else if (dimDifference > 0)
+ {
+ m_DimData[j].m_Stride1 = (inShape[j - dimDifference] > 1) ? sIn : 0;
+ sIn *= inShape[j - dimDifference];
+ }
+ else
+ {
+ m_DimData[j].m_Stride1 = (inShape[j] > 1) ? sIn : 0;
+ sIn *= inShape[j];
+ }
m_DimData[j].m_StrideOut = sOut;
- sIn *= inShape[j];
sOut *= outShape[j];
}
}
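
The new branches let a lower-rank input broadcast against the output by pretending that 1-sized dimensions are prepended: those implicit dimensions get a stride of 0, so the same input element is reused along them. A standalone sketch of the resulting stride computation, mirroring the loop above for inShape {3} against outShape {2, 3}:

    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<unsigned int> inShape{3};      // rank-1 input
        std::vector<unsigned int> outShape{2, 3};  // rank-2 output
        unsigned int numDims = static_cast<unsigned int>(outShape.size());
        unsigned int dimDifference = numDims - static_cast<unsigned int>(inShape.size());

        std::vector<unsigned int> stride(numDims);
        unsigned int sIn = 1;
        for (unsigned int j = numDims - 1, k = 0; k < numDims; ++k, --j)
        {
            if (j < dimDifference)
            {
                stride[j] = 0; // implicit leading 1-dim: input never advances
            }
            else
            {
                stride[j] = (inShape[j - dimDifference] > 1) ? sIn : 0;
                sIn *= inShape[j - dimDifference];
            }
        }
        std::printf("strides: %u %u\n", stride[0], stride[1]); // prints: strides: 0 1
    }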
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index 4044f06ac4..1d1ca5a856 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017-2021,2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2021,2023-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -67,6 +67,7 @@ template struct armnn::ElementwiseBinaryFunction<std::plus<float>>;
template struct armnn::ElementwiseBinaryFunction<std::minus<float>>;
template struct armnn::ElementwiseBinaryFunction<std::multiplies<float>>;
template struct armnn::ElementwiseBinaryFunction<std::divides<float>>;
+template struct armnn::ElementwiseBinaryFunction<armnn::floorDiv<float>>;
template struct armnn::ElementwiseBinaryFunction<armnn::maximum<float>>;
template struct armnn::ElementwiseBinaryFunction<armnn::minimum<float>>;
template struct armnn::ElementwiseBinaryFunction<armnn::power<float>>;
@@ -76,6 +77,7 @@ template struct armnn::ElementwiseBinaryFunction<std::plus<int32_t>>;
template struct armnn::ElementwiseBinaryFunction<std::minus<int32_t>>;
template struct armnn::ElementwiseBinaryFunction<std::multiplies<int32_t>>;
template struct armnn::ElementwiseBinaryFunction<std::divides<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<armnn::floorDiv<int32_t>>;
template struct armnn::ElementwiseBinaryFunction<armnn::maximum<int32_t>>;
template struct armnn::ElementwiseBinaryFunction<armnn::minimum<int32_t>>;
template struct armnn::ElementwiseBinaryFunction<armnn::power<int32_t>>;
diff --git a/src/backends/reference/workloads/Maximum.hpp b/src/backends/reference/workloads/Maximum.hpp
index ca4b480b51..1e1f02d68a 100644
--- a/src/backends/reference/workloads/Maximum.hpp
+++ b/src/backends/reference/workloads/Maximum.hpp
@@ -1,24 +1,36 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017, 2024 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
-#include <iostream>
+
namespace armnn
{
template<typename T>
struct maximum
+{
+ typedef T result_type;
+ typedef T first_argument_type;
+
+ T operator () (const T& inputData0, const T& inputData1) const
+ {
+ return std::max(inputData0, inputData1);
+ }
+};
+
+template<typename T>
+struct floorDiv
{
typedef T result_type;
typedef T first_argument_type;
- T
- operator () (const T& inputData0, const T& inputData1) const
+ T operator () (const T& inputData0, const T& inputData1) const
{
- return std::max(inputData0, inputData1);
+ double result = static_cast<double>(inputData0)/static_cast<double>(inputData1);
+ return static_cast<T>(std::floor(result));
}
};
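
floorDiv rounds the quotient toward negative infinity, whereas C++ integer division truncates toward zero; the two only differ when the operands have opposite signs. A standalone check of that distinction (the helper repeats the functor's arithmetic):

    #include <cmath>
    #include <cstdio>

    template<typename T>
    T FloorDiv(T a, T b) // same arithmetic as armnn::floorDiv above
    {
        return static_cast<T>(std::floor(static_cast<double>(a) / static_cast<double>(b)));
    }

    int main()
    {
        std::printf("%d %d\n", -7 / 2, FloorDiv(-7, 2)); // prints: -3 -4
    }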
diff --git a/src/backends/reference/workloads/Pad.cpp b/src/backends/reference/workloads/Pad.cpp
index f58dbaea61..8273d34365 100644
--- a/src/backends/reference/workloads/Pad.cpp
+++ b/src/backends/reference/workloads/Pad.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017,2024 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -63,7 +63,9 @@ void Pad(const TensorInfo& inputInfo,
unsigned int inputChannels = 0;
unsigned int inputHeight = 0;
unsigned int inputWidth = 0;
+ unsigned int inputDim5 = 0;
+ unsigned int outputBatches = 0;
unsigned int outputChannels = 0;
unsigned int outputHeight = 0;
unsigned int outputWidth = 0;
@@ -76,6 +78,7 @@ void Pad(const TensorInfo& inputInfo,
{
// For Quantized types Pad Value should not be quantized with scale and offset of the tensor info
auto temporaryInfo = TensorInfo(outputInfo.GetShape(), outputInfo.GetDataType(), 1.0f, 0);
+
auto outputData = MakeEncoder<float>(temporaryInfo, outputHandle->Map());
FillOutputWithPadValue(*outputData, padValue, numOutputElements);
}
@@ -95,13 +98,13 @@ void Pad(const TensorInfo& inputInfo,
{
input[w];
auto inputValue = input.Get();
- auto outputIndex = w + std::get<0>(padList[0]);
+ auto outputIndex = w + padList[0].first;
output[outputIndex];
output.Set(inputValue);
}
break;
- case 2 :
+ case 2:
inputHeight = inputShape[0];
inputWidth = inputShape[1];
outputWidth = outputShape[1];
@@ -112,14 +115,14 @@ void Pad(const TensorInfo& inputInfo,
{
input[h * inputWidth + w];
auto inputValue = input.Get();
- auto outputIndex = (h + std::get<0>(padList[0])) * outputWidth + (w + std::get<0>(padList[1]));
+ auto outputIndex = (h + padList[0].first) * outputWidth + (w + padList[1].first);
output[outputIndex];
output.Set(inputValue);
}
}
break;
- case 3 :
+ case 3:
inputChannels = inputShape[0];
inputHeight = inputShape[1];
inputWidth = inputShape[2];
@@ -134,9 +137,9 @@ void Pad(const TensorInfo& inputInfo,
{
input[c * inputHeight * inputWidth + h * inputWidth + w];
auto inputValue = input.Get();
- auto outputIndex = (c + std::get<0>(padList[0])) * outputHeight * outputWidth
- + (h + std::get<0>(padList[1])) * outputWidth
- + (w + std::get<0>(padList[2]));
+ auto outputIndex = (c + padList[0].first) * outputHeight * outputWidth
+ + (h + padList[1].first) * outputWidth
+ + (w + padList[2].first);
output[outputIndex];
output.Set(inputValue);
}
@@ -144,7 +147,7 @@ void Pad(const TensorInfo& inputInfo,
}
break;
- case 4 :
+ case 4:
inputBatches = inputShape[0];
inputChannels = inputShape[1];
inputHeight = inputShape[2];
@@ -162,24 +165,69 @@ void Pad(const TensorInfo& inputInfo,
for (unsigned int w = 0; w < inputWidth ; w++)
{
input[b * inputChannels * inputHeight * inputWidth
- + c * inputHeight * inputWidth
- + h * inputWidth
- + w];
+ + c * inputHeight * inputWidth
+ + h * inputWidth
+ + w];
auto inputValue = input.Get();
- auto outputIndex = (b + std::get<0>(padList[0]))
+ auto outputIndex = (b + padList[0].first)
* outputChannels * outputHeight * outputWidth
- + (c + std::get<0>(padList[1])) * outputHeight * outputWidth
- + (h + std::get<0>(padList[2])) * outputWidth
- + (w + std::get<0>(padList[3]));
+ + (c + padList[1].first) * outputHeight * outputWidth
+ + (h + padList[2].first) * outputWidth
+ + (w + padList[3].first);
output[outputIndex];
output.Set(inputValue);
}
}
}
}
+ break;
+ case 5:
+ inputBatches = inputShape[0];
+ inputChannels = inputShape[1];
+ inputHeight = inputShape[2];
+ inputWidth = inputShape[3];
+ inputDim5 = inputShape[4];
+
+ outputBatches = outputShape[1];
+ outputChannels = outputShape[2];
+ outputHeight = outputShape[3];
+ outputWidth = outputShape[4];
+
+ for (unsigned int b = 0; b < inputBatches; ++b)
+ {
+ for (unsigned int c = 0; c < inputChannels; ++c)
+ {
+ for (unsigned int h = 0; h < inputHeight; ++h)
+ {
+ for (unsigned int w = 0; w < inputWidth ; ++w)
+ {
+ for (unsigned int d = 0; d < inputDim5 ; ++d)
+ {
+ input[b * inputChannels * inputHeight * inputWidth * inputDim5
+ + c * inputHeight * inputWidth * inputDim5
+ + h * inputWidth * inputDim5
+                                + w * inputDim5
+                                + d];
+
+ auto inputValue = input.Get();
+
+ auto outputIndex = (b + padList[0].first)
+ * outputBatches * outputChannels * outputHeight * outputWidth
+                                + (c + padList[1].first) * outputChannels * outputHeight * outputWidth
+ + (h + padList[2].first) * outputHeight * outputWidth
+ + (w + padList[3].first) * outputWidth
+ + (d + padList[4].first);
+
+ output[outputIndex];
+ output.Set(inputValue);
+ }
+ }
+ }
+ }
+ }
break;
- default :
+
+ default:
break;
}
}
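
The 5D case uses the same row-major linearization as the lower ranks, with one more stride term and the pad offsets added per axis before flattening. A worked example under illustrative shapes (plain C++, not the workload itself):

    #include <array>
    #include <cstdio>

    // Flatten a 5D coordinate against a 5D shape, row-major.
    unsigned int Flatten5d(const std::array<unsigned int, 5>& shape,
                           const std::array<unsigned int, 5>& coord)
    {
        unsigned int index = 0;
        for (unsigned int i = 0; i < 5; ++i)
        {
            index = index * shape[i] + coord[i];
        }
        return index;
    }

    int main()
    {
        std::array<unsigned int, 5> outShape{2, 3, 4, 5, 6};
        std::array<unsigned int, 5> pad{0, 1, 1, 0, 0}; // padList[i].first per axis
        std::array<unsigned int, 5> in{1, 0, 2, 3, 4};  // input coordinate
        std::array<unsigned int, 5> shifted{};
        for (unsigned int i = 0; i < 5; ++i)
        {
            shifted[i] = in[i] + pad[i]; // input element lands at coordinate + leading pad
        }
        std::printf("output index: %u\n", Flatten5d(outShape, shifted)); // prints: output index: 592
    }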
diff --git a/src/backends/reference/workloads/RefCastWorkload.cpp b/src/backends/reference/workloads/RefCastWorkload.cpp
index 40fbce6f4e..c8484d9672 100644
--- a/src/backends/reference/workloads/RefCastWorkload.cpp
+++ b/src/backends/reference/workloads/RefCastWorkload.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2021-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2021-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -12,17 +12,54 @@
namespace
{
- void Cast(armnn::Decoder<float>& in, armnn::Encoder<float>& out, const uint32_t numElements )
+ void Cast(armnn::Decoder<float>& in, armnn::Encoder<float>& out,
+ const uint32_t numElements, const armnn::DataType OutputDataType)
{
- for (unsigned int i = 0; i < numElements; i++)
+ for (unsigned int i = 0; i < numElements; ++i)
+ {
+ switch (OutputDataType)
+ {
+ case armnn::DataType::Float32:
+ case armnn::DataType::Float16:
+ case armnn::DataType::BFloat16:
+ out.Set(in.Get());
+ break;
+ default:
+ out.Set(std::floor(in.Get()));
+ break;
+ }
+ ++in;
+ ++out;
+ }
+ }
+
+
+ // Cast Float to Int64
+ void Cast(armnn::Decoder<float>& in, armnn::Encoder<double_t>& out,
+ const uint32_t numElements, const armnn::DataType)
+ {
+ for (unsigned int i = 0; i < numElements; ++i)
{
out.Set(in.Get());
++in;
++out;
}
}
+
+ // Cast Int64 To Float
+ void Cast(armnn::Decoder<double_t>& in, armnn::Encoder<float>& out,
+ const uint32_t numElements, const armnn::DataType)
+ {
+ for (unsigned int i = 0; i < numElements; ++i)
+ {
+ out.Set(static_cast<float>(in.Get()));
+ ++in;
+ ++out;
+ }
+ }
}
+
namespace armnn
{
@@ -56,9 +93,27 @@ void RefCastWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<IT
outputTensorInfo.SetQuantizationOffset(0);
}
- Cast(*MakeDecoder<float>(inputTensorInfo, inputs[0]->Map()),
- *MakeEncoder<float>(outputTensorInfo, outputs[0]->Map()),
- inputTensorInfo.GetNumElements());
+ if(inputTensorInfo.GetDataType() == DataType::Signed64)
+ {
+ Cast(*MakeDecoder<double_t>(inputTensorInfo, inputs[0]->Map()),
+ *MakeEncoder<float>(outputTensorInfo, outputs[0]->Map()),
+ inputTensorInfo.GetNumElements(),
+ outputTensorInfo.GetDataType());
+ }
+ else if(outputTensorInfo.GetDataType() == DataType::Signed64)
+ {
+ Cast(*MakeDecoder<float>(inputTensorInfo, inputs[0]->Map()),
+ *MakeEncoder<double_t>(outputTensorInfo, outputs[0]->Map()),
+ inputTensorInfo.GetNumElements(),
+ outputTensorInfo.GetDataType());
+ }
+ else
+ {
+ Cast(*MakeDecoder<float>(inputTensorInfo, inputs[0]->Map()),
+ *MakeEncoder<float>(outputTensorInfo, outputs[0]->Map()),
+ inputTensorInfo.GetNumElements(),
+ outputTensorInfo.GetDataType());
+ }
}
} //namespace armnn
\ No newline at end of file
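
The float-to-integer path now floors before encoding rather than letting the conversion truncate, which only changes results for negative fractional inputs. A standalone comparison of the two conventions:

    #include <cmath>
    #include <cstdio>

    int main()
    {
        float v = -2.5f;
        std::printf("truncate: %d, floor: %d\n",
                    static_cast<int>(v),              // -2: rounds toward zero
                    static_cast<int>(std::floor(v))); // -3: rounds toward negative infinity
    }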
diff --git a/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp
index 2f30dff211..0cefe0f20d 100644
--- a/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2023-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -26,7 +26,8 @@ namespace armnn
template<typename DataType>
void ExecuteFunction(std::vector<ITensorHandle*> inputs,
std::vector<ITensorHandle*> outputs,
- BinaryOperation operation)
+ BinaryOperation operation,
+ const std::string& layerName = "")
{
const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]);
const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]);
@@ -42,6 +43,7 @@ void ExecuteFunction(std::vector<ITensorHandle*> inputs,
using AddFunction = ElementwiseBinaryFunction<std::plus<DataType>>;
using DivFunction = ElementwiseBinaryFunction<std::divides<DataType>>;
+ using FloorDivFunction = ElementwiseBinaryFunction<armnn::floorDiv<DataType>>;
using MaximumFunction = ElementwiseBinaryFunction<armnn::maximum<DataType>>;
using MinimumFunction = ElementwiseBinaryFunction<armnn::minimum<DataType>>;
using MulFunction = ElementwiseBinaryFunction<std::multiplies<DataType>>;
@@ -49,6 +51,7 @@ void ExecuteFunction(std::vector<ITensorHandle*> inputs,
using SqDiffFunction = ElementwiseBinaryFunction<armnn::squaredDifference<DataType>>;
using PowerFunction = ElementwiseBinaryFunction<armnn::power<DataType>>;
+
switch (operation)
{
case BinaryOperation::Add:
@@ -58,7 +61,14 @@ void ExecuteFunction(std::vector<ITensorHandle*> inputs,
}
case BinaryOperation::Div:
{
- DivFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
+ if(!layerName.empty() && layerName.find("FloorDiv") != std::string::npos)
+ {
+ FloorDivFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
+ }
+ else
+ {
+ DivFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
+ }
break;
}
case BinaryOperation::Maximum:
@@ -123,11 +133,11 @@ void RefElementwiseBinaryWorkload::Execute(std::vector<ITensorHandle*> inputs,
if (GetTensorInfo(inputs[0]).GetDataType() == DataType::Signed32)
{
- ExecuteFunction<int32_t>(inputs, outputs, m_Data.m_Parameters.m_Operation);
+ ExecuteFunction<int32_t>(inputs, outputs, m_Data.m_Parameters.m_Operation, m_Name);
}
else
{
- ExecuteFunction<float>(inputs, outputs, m_Data.m_Parameters.m_Operation);
+ ExecuteFunction<float>(inputs, outputs, m_Data.m_Parameters.m_Operation, m_Name);
}
}
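
At this point ArmNN's BinaryOperation enum has no dedicated FloorDiv entry, so the workload keys off the layer name: a Div layer whose name contains the substring "FloorDiv" is routed to the flooring implementation. The matching rule, isolated:

    #include <string>

    // Mirrors the check in ExecuteFunction above.
    bool UsesFloorDiv(const std::string& layerName)
    {
        return !layerName.empty() && layerName.find("FloorDiv") != std::string::npos;
    }
    // UsesFloorDiv("FloorDiv_0") == true; UsesFloorDiv("Div_0") == false.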
diff --git a/src/backends/reference/workloads/Slice.cpp b/src/backends/reference/workloads/Slice.cpp
index 534a063ed5..1232e9f373 100644
--- a/src/backends/reference/workloads/Slice.cpp
+++ b/src/backends/reference/workloads/Slice.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2019 Arm Ltd. All rights reserved.
+// Copyright © 2019,2024 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -20,7 +20,7 @@ void Slice(const TensorInfo& inputInfo,
const TensorShape& inputShape = inputInfo.GetShape();
const unsigned int numDims = inputShape.GetNumDimensions();
- constexpr unsigned int maxNumDims = 4;
+ constexpr unsigned int maxNumDims = 5;
if (descriptor.m_Begin.size() != numDims)
{
std::stringstream msg;
@@ -43,9 +43,9 @@ void Slice(const TensorInfo& inputInfo,
throw InvalidArgumentException(msg.str());
}
- std::vector<unsigned int> paddedInput(4);
- std::vector<unsigned int> paddedBegin(4);
- std::vector<unsigned int> paddedSize (4);
+ std::vector<unsigned int> paddedInput(5);
+ std::vector<unsigned int> paddedBegin(5);
+ std::vector<unsigned int> paddedSize (5);
const unsigned int numPaddingDims = maxNumDims - numDims;
for (unsigned int i = 0u; i < maxNumDims; ++i)
@@ -69,16 +69,19 @@ void Slice(const TensorInfo& inputInfo,
unsigned int dim1 = paddedInput[1];
unsigned int dim2 = paddedInput[2];
unsigned int dim3 = paddedInput[3];
+ unsigned int dim4 = paddedInput[4];
unsigned int begin0 = paddedBegin[0];
unsigned int begin1 = paddedBegin[1];
unsigned int begin2 = paddedBegin[2];
unsigned int begin3 = paddedBegin[3];
+ unsigned int begin4 = paddedBegin[4];
unsigned int size0 = paddedSize[0];
unsigned int size1 = paddedSize[1];
unsigned int size2 = paddedSize[2];
unsigned int size3 = paddedSize[3];
+ unsigned int size4 = paddedSize[4];
if (begin0 + size0 > dim0)
{
@@ -129,11 +132,14 @@ void Slice(const TensorInfo& inputInfo,
{
for (unsigned int idx3 = begin3; idx3 < begin3 + size3; ++idx3)
{
- const unsigned int inputOffset =
- (((idx0 * dim1 + idx1) * dim2 + idx2) * dim3 + idx3) * dataTypeSize;
-
- ::memcpy(output, input + inputOffset, dataTypeSize);
- output += dataTypeSize;
+ for (unsigned int idx4 = begin4; idx4 < begin4 + size4; ++idx4)
+ {
+ const unsigned int inputOffset =
+ ((((idx0 * dim1 + idx1) * dim2 + idx2) * dim3 + idx3) * dim4 + idx4) * dataTypeSize;
+
+ ::memcpy(output, input + inputOffset, dataTypeSize);
+ output += dataTypeSize;
+ }
}
}
}
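
As in Pad, lower-rank inputs are left-padded with 1-sized dimensions up to maxNumDims, so a rank-3 slice flows through the same five nested loops with the two outer dimensions pinned to size 1. A small sketch of that padding step:

    #include <cstdio>
    #include <vector>

    int main()
    {
        const unsigned int maxNumDims = 5;
        std::vector<unsigned int> shape{4, 6, 8}; // rank-3 input
        std::vector<unsigned int> padded(maxNumDims, 1);
        const unsigned int numPaddingDims = maxNumDims - static_cast<unsigned int>(shape.size());
        for (unsigned int i = numPaddingDims; i < maxNumDims; ++i)
        {
            padded[i] = shape[i - numPaddingDims]; // original dims shift right
        }
        std::printf("%u %u %u %u %u\n", padded[0], padded[1], padded[2], padded[3], padded[4]);
        // prints: 1 1 4 6 8
    }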
diff --git a/src/backends/reference/workloads/StridedSlice.cpp b/src/backends/reference/workloads/StridedSlice.cpp
index fcd1c357f8..a8828fdfbe 100644
--- a/src/backends/reference/workloads/StridedSlice.cpp
+++ b/src/backends/reference/workloads/StridedSlice.cpp
@@ -108,34 +108,105 @@ void StridedSlice(const TensorInfo& inputInfo,
// Pad parameters to 4 dimensions
PadParams(paddedParams, 4);
- const int start0 = paddedParams.GetStartForAxis(inputShape, 0);
- const int stop0 = paddedParams.GetStopForAxis (inputShape, 0, start0);
+ // Arrays containing the start and stop index for each axis (adjusted by set params/flags)
+ int startArray [4] = {0};
+ int stopArray [4] = {0};
- const int start1 = paddedParams.GetStartForAxis(inputShape, 1);
- const int stop1 = paddedParams.GetStopForAxis (inputShape, 1, start1);
+ // Getting paddedParams stop and start values for each axis
+ for(unsigned int i = 0; i < 4; ++i)
+ {
+ startArray[i] = paddedParams.GetStartForAxis(inputShape, i);
+ stopArray[i] = paddedParams.GetStopForAxis(inputShape, i, startArray[i]);
+ }
- const int start2 = paddedParams.GetStartForAxis(inputShape, 2);
- const int stop2 = paddedParams.GetStopForAxis (inputShape, 2, start2);
+ // Adjusting the EllipsisMask based on the NewAxisMask
+ // (if NewAxisMask extends an axis, the ellipsis flag is extended as well)
+ if(paddedParams.m_NewAxisMask > 0 && paddedParams.m_EllipsisMask > 0)
+ {
+ // Iterate until the current EllipsisMask 1-bit found
+ for(unsigned int i = 0; i < 4; ++i)
+ {
+ // If EllipsisMask bit found, adjust based on NewAxisMask and exit loop
+ if(paddedParams.m_EllipsisMask & (1 << i) && !(paddedParams.m_NewAxisMask & (1 << i)))
+ {
+ // If the previous bit is the NewAxisMask, set the EllipsisMask there
+ // (this condition was determined based on the unit tests expected data)
+ if(paddedParams.m_NewAxisMask & (1 << (i-1)))
+ {
+ paddedParams.m_EllipsisMask |= (1 << (i-1));
+ }
+ // Otherwise, extend the EllipsisMask by one bit
+ else
+ {
+ paddedParams.m_EllipsisMask |= (1 << (i+1));
+ }
+ break;
+ }
+ }
+ }
- const int start3 = paddedParams.GetStartForAxis(inputShape, 3);
- const int stop3 = paddedParams.GetStopForAxis (inputShape, 3, start3);
+ // Processing start and stop values based on the EllipsisMask and NewAxisMask
+ for(unsigned int i = 0, dimIdx = 0; i < 4; ++i)
+ {
+ // If the EllipsisMask is set, extend the start/stop to the input dimension size
+ if(paddedParams.m_EllipsisMask & (1 << dimIdx))
+ {
+ startArray[i] = 0;
+ stopArray[i] = armnn::numeric_cast<int>(inputShape[i]);
+ }
+ // Otherwise, if the NewAxisMask is set, shift all following start/stop values to the left
+ else if(paddedParams.m_NewAxisMask & (1 << dimIdx))
+ {
+ // Increment dimIdx - skip the current dimension for which NewAxisMask is set
+ ++dimIdx;
+ }
+
+ // If the index of the currently processed dimension is higher than
+ // the index of the current start/stop array position, shift start/stop values
+ if(dimIdx > i && !(paddedParams.m_EllipsisMask & (1 << dimIdx)))
+ {
+ if(dimIdx < 4)
+ {
+ startArray[i] = startArray[dimIdx];
+ stopArray[i] = stopArray[dimIdx];
+ }
+ else
+ {
+ // If dimIdx is greater than the amount of available dimensions,
+ // instead of shifting the next ones, create new start/stop values
+ if(paddedParams.m_EllipsisMask > 0)
+ {
+ // The new values are 0,1 if there is an EllipsisMask bit present
+ startArray[i] = 0;
+ stopArray[i] = 1;
+ }
+ else
+ {
+ // Otherwise, select the entire inputTensor dimension size
+ startArray[i] = 0;
+ stopArray[i] = armnn::numeric_cast<int>(inputShape[i]);
+ }
+ }
+ }
+ ++dimIdx;
+ }
const int step = armnn::numeric_cast<int>(dataTypeSize);
- for (int in0 = start0;
- !LoopCondition(in0, stop0, paddedParams.m_Stride[0]);
+ for (int in0 = startArray[0];
+ !LoopCondition(in0, stopArray[0], paddedParams.m_Stride[0]);
in0 += paddedParams.m_Stride[0])
{
- for (int in1 = start1;
- !LoopCondition(in1, stop1, paddedParams.m_Stride[1]);
+ for (int in1 = startArray[1];
+ !LoopCondition(in1, stopArray[1], paddedParams.m_Stride[1]);
in1 += paddedParams.m_Stride[1])
{
- for (int in2 = start2;
- !LoopCondition(in2, stop2, paddedParams.m_Stride[2]);
+ for (int in2 = startArray[2];
+ !LoopCondition(in2, stopArray[2], paddedParams.m_Stride[2]);
in2 += paddedParams.m_Stride[2])
{
- for (int in3 = start3;
- !LoopCondition(in3, stop3, paddedParams.m_Stride[3]);
+ for (int in3 = startArray[3];
+ !LoopCondition(in3, stopArray[3], paddedParams.m_Stride[3]);
in3 += paddedParams.m_Stride[3])
{
int dim1 = armnn::numeric_cast<int>(inputShape[1]);