aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRuomei Yan <ruomei.yan@arm.com>2019-05-28 16:48:20 +0100
committerRuomei Yan <ruomei.yan@arm.com>2019-05-30 15:45:40 +0100
commit25339c31a829111fce691311cb84100e8591f5da (patch)
tree9d79436f67f4570e4e4d1c52516de7f3df6eb7e0
parentedb8b2ec42624ec27c37bc1b3d345ef0c97c024c (diff)
downloadarmnn-25339c31a829111fce691311cb84100e8591f5da.tar.gz
IVGCVSW-3159 Support QSymm16 for Splitter workloads
Change-Id: I9af5d2d8ade97b9ecd2e6fbf13db9fa3bb622ed8 Signed-off-by: Ruomei Yan <ruomei.yan@arm.com>
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp21
-rw-r--r--src/backends/backendsCommon/test/LayerTests.cpp14
-rw-r--r--src/backends/backendsCommon/test/LayerTests.hpp8
-rw-r--r--src/backends/backendsCommon/test/SplitterTestImpl.hpp12
-rw-r--r--src/backends/backendsCommon/test/WorkloadDataValidation.cpp6
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp6
-rw-r--r--src/backends/reference/backend.mk6
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp12
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp2
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt7
-rw-r--r--src/backends/reference/workloads/RefSplitterFloat32Workload.cpp21
-rw-r--r--src/backends/reference/workloads/RefSplitterUint8Workload.cpp21
-rw-r--r--src/backends/reference/workloads/RefSplitterUint8Workload.hpp21
-rw-r--r--src/backends/reference/workloads/RefSplitterWorkload.cpp20
-rw-r--r--src/backends/reference/workloads/RefSplitterWorkload.hpp (renamed from src/backends/reference/workloads/RefSplitterFloat32Workload.hpp)6
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp3
-rw-r--r--src/backends/reference/workloads/Splitter.cpp94
-rw-r--r--src/backends/reference/workloads/Splitter.hpp3
18 files changed, 191 insertions, 92 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index e8e10d972a..c94fa25ac2 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -350,6 +350,27 @@ void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
ValidateNumInputs(workloadInfo, "SplitterQueueDescriptor", 1);
+ // Check the supported data types
+ std::vector<DataType> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::Boolean,
+ DataType::Signed32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ for (unsigned long i = 0; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
+ {
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[i],
+ supportedTypes,
+ "SplitterQueueDescriptor");
+ }
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ {workloadInfo.m_InputTensorInfos[0].GetDataType()},
+ "SplitterQueueDescriptor");
+
if (workloadInfo.m_OutputTensorInfos.size() <= 0)
{
throw InvalidArgumentException("SplitterQueueDescriptor: At least one output needs to be provided.");
diff --git a/src/backends/backendsCommon/test/LayerTests.cpp b/src/backends/backendsCommon/test/LayerTests.cpp
index de3c857399..34adf90379 100644
--- a/src/backends/backendsCommon/test/LayerTests.cpp
+++ b/src/backends/backendsCommon/test/LayerTests.cpp
@@ -1138,6 +1138,13 @@ std::vector<LayerTestResult<uint8_t,3>> SplitterUint8Test(
return SplitterTestCommon<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager, 1.0f, 0);
}
+std::vector<LayerTestResult<int16_t,3>> SplitterInt16Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ return SplitterTestCommon<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager, 1.0f, 0);
+}
+
LayerTestResult<float, 3> CopyViaSplitterTest(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
@@ -1152,6 +1159,13 @@ LayerTestResult<uint8_t, 3> CopyViaSplitterUint8Test(
return CopyViaSplitterTestImpl<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager, 1.0f, 0);
}
+LayerTestResult<int16_t, 3> CopyViaSplitterInt16Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ return CopyViaSplitterTestImpl<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager, 1.0f, 0);
+}
+
LayerTestResult<float, 2> LstmLayerFloat32WithCifgWithPeepholeNoProjectionTest(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
diff --git a/src/backends/backendsCommon/test/LayerTests.hpp b/src/backends/backendsCommon/test/LayerTests.hpp
index 893252b054..7607bf0720 100644
--- a/src/backends/backendsCommon/test/LayerTests.hpp
+++ b/src/backends/backendsCommon/test/LayerTests.hpp
@@ -821,10 +821,18 @@ std::vector<LayerTestResult<uint8_t, 3>> SplitterUint8Test(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+std::vector<LayerTestResult<int16_t, 3>> SplitterInt16Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
LayerTestResult<uint8_t, 3> CopyViaSplitterUint8Test(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+LayerTestResult<int16_t, 3> CopyViaSplitterInt16Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
LayerTestResult<uint8_t, 3> ConcatUint8Test(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
diff --git a/src/backends/backendsCommon/test/SplitterTestImpl.hpp b/src/backends/backendsCommon/test/SplitterTestImpl.hpp
index 004060f0b8..de677ef45d 100644
--- a/src/backends/backendsCommon/test/SplitterTestImpl.hpp
+++ b/src/backends/backendsCommon/test/SplitterTestImpl.hpp
@@ -46,15 +46,15 @@ std::vector<LayerTestResult<T,3>> SplitterTestCommon(
// Define the tensor descriptors.
- armnn::TensorInfo inputTensorInfo({ inputChannels, inputHeight, inputWidth }, ArmnnType);
+ armnn::TensorInfo inputTensorInfo({ inputChannels, inputHeight, inputWidth }, ArmnnType, qScale, qOffset);
// Outputs of the original split.
- armnn::TensorInfo outputTensorInfo1({ outputChannels1, outputHeight1, outputWidth1 }, ArmnnType);
- armnn::TensorInfo outputTensorInfo2({ outputChannels2, outputHeight2, outputWidth2 }, ArmnnType);
+ armnn::TensorInfo outputTensorInfo1({ outputChannels1, outputHeight1, outputWidth1 }, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputTensorInfo2({ outputChannels2, outputHeight2, outputWidth2 }, ArmnnType, qScale, qOffset);
// Outputs of the subsequent subtensor split.
- armnn::TensorInfo outputTensorInfo3({ outputChannels1, outputHeight1, outputWidth1 }, ArmnnType);
- armnn::TensorInfo outputTensorInfo4({ outputChannels1, outputHeight1, outputWidth1 }, ArmnnType);
+ armnn::TensorInfo outputTensorInfo3({ outputChannels1, outputHeight1, outputWidth1 }, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputTensorInfo4({ outputChannels1, outputHeight1, outputWidth1 }, ArmnnType, qScale, qOffset);
// Set quantization parameters if the requested type is a quantized type.
// The quantization doesn't really matter as the splitter operator doesn't dequantize/quantize.
@@ -251,7 +251,7 @@ LayerTestResult<T, 3> CopyViaSplitterTestImpl(
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
float qScale, int32_t qOffset)
{
- const armnn::TensorInfo tensorInfo({ 3, 6, 5 }, ArmnnType);
+ const armnn::TensorInfo tensorInfo({ 3, 6, 5 }, ArmnnType, qScale, qOffset);
auto input = MakeTensor<T, 3>(tensorInfo, QuantizedVector<T>(qScale, qOffset,
{
1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
index 795791fc5e..7d04e3220f 100644
--- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
+++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
@@ -211,14 +211,14 @@ BOOST_AUTO_TEST_CASE(SplitterQueueDescriptor_Validate_WrongWindow)
BOOST_TEST_INFO("Invalid argument exception is expected, because split window dimensionality does not "
"match input.");
- BOOST_CHECK_THROW(RefSplitterFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefSplitterWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
// Invalid, since window extends past the boundary of input tensor.
std::vector<unsigned int> wOrigin3 = {0, 0, 15, 0};
armnn::SplitterQueueDescriptor::ViewOrigin window3(wOrigin3);
invalidData.m_ViewOrigins[0] = window3;
BOOST_TEST_INFO("Invalid argument exception is expected (wOrigin3[2]+ outputHeight > inputHeight");
- BOOST_CHECK_THROW(RefSplitterFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefSplitterWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
std::vector<unsigned int> wOrigin4 = {0, 0, 0, 0};
@@ -230,7 +230,7 @@ BOOST_AUTO_TEST_CASE(SplitterQueueDescriptor_Validate_WrongWindow)
invalidData.m_ViewOrigins.push_back(window5);
BOOST_TEST_INFO("Invalid exception due to number of split windows not matching number of outputs.");
- BOOST_CHECK_THROW(RefSplitterFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefSplitterWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 50e3c0006c..71b1cb034d 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -133,7 +133,11 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSoftmax(const SoftmaxQueueD
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<RefSplitterFloat32Workload, RefSplitterUint8Workload>(descriptor, info);
+ if (IsFloat16(info))
+ {
+ return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ }
+ return std::make_unique<RefSplitterWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 57204a05ac..46aac056bf 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -61,14 +61,14 @@ BACKEND_SOURCES := \
workloads/RefSoftmaxWorkload.cpp \
workloads/RefSpaceToBatchNdWorkload.cpp \
workloads/RefStridedSliceWorkload.cpp \
- workloads/RefSplitterFloat32Workload.cpp \
- workloads/RefSplitterUint8Workload.cpp \
+ workloads/RefSplitterWorkload.cpp \
workloads/ResizeBilinear.cpp \
workloads/Rsqrt.cpp \
workloads/SpaceToBatchNd.cpp \
workloads/StridedSlice.cpp \
workloads/StringMapping.cpp \
- workloads/Softmax.cpp
+ workloads/Softmax.cpp \
+ workloads/Splitter.cpp
# BACKEND_TEST_SOURCES contains the list of files to be included
# in the Android unit test build (armnn-tests) and it is picked
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 78083fa07b..0311276f10 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -470,12 +470,12 @@ static void RefCreateSplitterWorkloadTest()
BOOST_AUTO_TEST_CASE(CreateSplitterFloat32Workload)
{
- RefCreateSplitterWorkloadTest<RefSplitterFloat32Workload, armnn::DataType::Float32>();
+ RefCreateSplitterWorkloadTest<RefSplitterWorkload, armnn::DataType::Float32>();
}
BOOST_AUTO_TEST_CASE(CreateSplitterUint8Workload)
{
- RefCreateSplitterWorkloadTest<RefSplitterUint8Workload, armnn::DataType::QuantisedAsymm8>();
+ RefCreateSplitterWorkloadTest<RefSplitterWorkload, armnn::DataType::QuantisedAsymm8>();
}
template <typename SplitterWorkloadType, typename ConcatWorkloadType, armnn::DataType DataType>
@@ -513,12 +513,12 @@ static void RefCreateSplitterConcatWorkloadTest()
BOOST_AUTO_TEST_CASE(CreateSplitterConcatFloat32)
{
- RefCreateSplitterConcatWorkloadTest<RefSplitterFloat32Workload, RefConcatWorkload, DataType::Float32>();
+ RefCreateSplitterConcatWorkloadTest<RefSplitterWorkload, RefConcatWorkload, DataType::Float32>();
}
BOOST_AUTO_TEST_CASE(CreateSplitterConcatUint8)
{
- RefCreateSplitterConcatWorkloadTest<RefSplitterUint8Workload, RefConcatWorkload, DataType::QuantisedAsymm8>();
+ RefCreateSplitterConcatWorkloadTest<RefSplitterWorkload, RefConcatWorkload, DataType::QuantisedAsymm8>();
}
template <typename SplitterWorkloadType, typename ActivationWorkloadType, armnn::DataType DataType>
@@ -561,13 +561,13 @@ static void RefCreateSingleOutputMultipleInputsTest()
BOOST_AUTO_TEST_CASE(CreateSingleOutputMultipleInputsFloat32)
{
- RefCreateSingleOutputMultipleInputsTest<RefSplitterFloat32Workload, RefActivationWorkload,
+ RefCreateSingleOutputMultipleInputsTest<RefSplitterWorkload, RefActivationWorkload,
armnn::DataType::Float32>();
}
BOOST_AUTO_TEST_CASE(CreateSingleOutputMultipleInputsUint8)
{
- RefCreateSingleOutputMultipleInputsTest<RefSplitterUint8Workload, RefActivationWorkload,
+ RefCreateSingleOutputMultipleInputsTest<RefSplitterWorkload, RefActivationWorkload,
armnn::DataType::QuantisedAsymm8>();
}
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 522d673f07..690a78c21c 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -253,9 +253,11 @@ ARMNN_AUTO_TEST_CASE(FullyConnectedLargeTransposed, FullyConnectedLargeTest, tru
// Splitter
ARMNN_AUTO_TEST_CASE(SimpleSplitter, SplitterTest)
ARMNN_AUTO_TEST_CASE(SimpleSplitterUint8, SplitterUint8Test)
+ARMNN_AUTO_TEST_CASE(SimpleSplitterInt16, SplitterInt16Test)
ARMNN_AUTO_TEST_CASE(CopyViaSplitter, CopyViaSplitterTest)
ARMNN_AUTO_TEST_CASE(CopyViaSplitterUint8, CopyViaSplitterUint8Test)
+ARMNN_AUTO_TEST_CASE(CopyViaSplitterInt16, CopyViaSplitterInt16Test)
// Concat
ARMNN_AUTO_TEST_CASE(SimpleConcat, ConcatTest)
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index e2f93d72a9..e57a40a9ba 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -103,10 +103,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefSoftmaxWorkload.hpp
RefSpaceToBatchNdWorkload.cpp
RefSpaceToBatchNdWorkload.hpp
- RefSplitterFloat32Workload.cpp
- RefSplitterFloat32Workload.hpp
- RefSplitterUint8Workload.cpp
- RefSplitterUint8Workload.hpp
+ RefSplitterWorkload.cpp
+ RefSplitterWorkload.hpp
RefStridedSliceWorkload.cpp
RefStridedSliceWorkload.hpp
RefWorkloads.hpp
@@ -120,6 +118,7 @@ list(APPEND armnnRefBackendWorkloads_sources
SpaceToBatchNd.hpp
SpaceToBatchNd.cpp
Splitter.hpp
+ Splitter.cpp
StridedSlice.hpp
StridedSlice.cpp
StringMapping.cpp
diff --git a/src/backends/reference/workloads/RefSplitterFloat32Workload.cpp b/src/backends/reference/workloads/RefSplitterFloat32Workload.cpp
deleted file mode 100644
index 75611dacf3..0000000000
--- a/src/backends/reference/workloads/RefSplitterFloat32Workload.cpp
+++ /dev/null
@@ -1,21 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefSplitterFloat32Workload.hpp"
-
-#include "Splitter.hpp"
-
-#include "Profiling.hpp"
-
-namespace armnn
-{
-
-void RefSplitterFloat32Workload::Execute() const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefSplitterFloat32Workload_Execute");
- Splitter<float>(m_Data);
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefSplitterUint8Workload.cpp b/src/backends/reference/workloads/RefSplitterUint8Workload.cpp
deleted file mode 100644
index ca9f5db850..0000000000
--- a/src/backends/reference/workloads/RefSplitterUint8Workload.cpp
+++ /dev/null
@@ -1,21 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefSplitterUint8Workload.hpp"
-
-#include "Splitter.hpp"
-
-#include "Profiling.hpp"
-
-namespace armnn
-{
-
-void RefSplitterUint8Workload::Execute() const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefSplitterUint8Workload_Execute");
- Splitter<uint8_t>(m_Data);
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefSplitterUint8Workload.hpp b/src/backends/reference/workloads/RefSplitterUint8Workload.hpp
deleted file mode 100644
index d9b6aaf639..0000000000
--- a/src/backends/reference/workloads/RefSplitterUint8Workload.hpp
+++ /dev/null
@@ -1,21 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <backendsCommon/Workload.hpp>
-#include <backendsCommon/WorkloadData.hpp>
-
-namespace armnn
-{
-
-class RefSplitterUint8Workload : public Uint8Workload<SplitterQueueDescriptor>
-{
-public:
- using Uint8Workload<SplitterQueueDescriptor>::Uint8Workload;
- virtual void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefSplitterWorkload.cpp b/src/backends/reference/workloads/RefSplitterWorkload.cpp
new file mode 100644
index 0000000000..ffe4eb880b
--- /dev/null
+++ b/src/backends/reference/workloads/RefSplitterWorkload.cpp
@@ -0,0 +1,20 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefSplitterWorkload.hpp"
+#include "Splitter.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "Profiling.hpp"
+
+namespace armnn
+{
+
+void RefSplitterWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefSplitterWorkload_Execute");
+ Split(m_Data);
+}
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefSplitterFloat32Workload.hpp b/src/backends/reference/workloads/RefSplitterWorkload.hpp
index 502eb3555f..95cc4a5db7 100644
--- a/src/backends/reference/workloads/RefSplitterFloat32Workload.hpp
+++ b/src/backends/reference/workloads/RefSplitterWorkload.hpp
@@ -7,14 +7,16 @@
#include <backendsCommon/Workload.hpp>
#include <backendsCommon/WorkloadData.hpp>
+#include "Decoders.hpp"
+#include "Encoders.hpp"
namespace armnn
{
-class RefSplitterFloat32Workload : public Float32Workload<SplitterQueueDescriptor>
+class RefSplitterWorkload : public BaseWorkload<SplitterQueueDescriptor>
{
public:
- using Float32Workload<SplitterQueueDescriptor>::Float32Workload;
+ using BaseWorkload<SplitterQueueDescriptor>::BaseWorkload;
virtual void Execute() const override;
};
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index ab3da88437..c20e2f6191 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -10,7 +10,7 @@
#include "ConvImpl.hpp"
#include "RefConstantWorkload.hpp"
#include "RefConvolution2dWorkload.hpp"
-#include "RefSplitterUint8Workload.hpp"
+#include "RefSplitterWorkload.hpp"
#include "RefResizeBilinearUint8Workload.hpp"
#include "RefL2NormalizationFloat32Workload.hpp"
#include "RefActivationWorkload.hpp"
@@ -39,7 +39,6 @@
#include "Activation.hpp"
#include "Concatenate.hpp"
#include "RefSpaceToBatchNdWorkload.hpp"
-#include "RefSplitterFloat32Workload.hpp"
#include "RefStridedSliceWorkload.hpp"
#include "Pooling2d.hpp"
#include "RefFakeQuantizationFloat32Workload.hpp"
diff --git a/src/backends/reference/workloads/Splitter.cpp b/src/backends/reference/workloads/Splitter.cpp
new file mode 100644
index 0000000000..3bddfb0cab
--- /dev/null
+++ b/src/backends/reference/workloads/Splitter.cpp
@@ -0,0 +1,94 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefWorkloadUtils.hpp"
+#include <backendsCommon/WorkloadData.hpp>
+#include <armnn/Tensor.hpp>
+
+#include <boost/assert.hpp>
+#include "Splitter.hpp"
+
+#include <cmath>
+#include <limits>
+
+#include "Decoders.hpp"
+#include "Encoders.hpp"
+
+namespace armnn
+{
+
+void Split(const SplitterQueueDescriptor& data)
+{
+ const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]);
+
+ std::unique_ptr<Decoder<float>> decoderPtr =
+ MakeDecoder<float>(inputInfo, data.m_Inputs[0]->Map());
+ Decoder<float>& decoder = *decoderPtr;
+
+ for (unsigned int index = 0; index < inputInfo.GetNumElements(); ++index)
+ {
+ unsigned int indices[MaxNumOfTensorDimensions] = { 0 };
+
+ unsigned int indexRemainder = index;
+ unsigned int dimensionStride = inputInfo.GetNumElements();
+
+ for (unsigned int i = 0; i<inputInfo.GetNumDimensions(); i++)
+ {
+ dimensionStride /= inputInfo.GetShape()[i];
+ indices[i] = indexRemainder / dimensionStride; // Use integer division to round down.
+ indexRemainder -= indices[i] * dimensionStride;
+ }
+
+ for (unsigned int viewIdx = 0; viewIdx < data.m_ViewOrigins.size(); ++viewIdx)
+ {
+ SplitterQueueDescriptor::ViewOrigin const& view = data.m_ViewOrigins[viewIdx];
+
+ //Split view extents are defined by the size of (the corresponding) input tensor.
+ const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[viewIdx]);
+ BOOST_ASSERT(outputInfo.GetNumDimensions() == inputInfo.GetNumDimensions());
+
+ // Check all dimensions to see if this element is inside the given input view.
+ bool insideView = true;
+ for (unsigned int i = 0; i<outputInfo.GetNumDimensions(); i++)
+ {
+ if (indices[i] < view.m_Origin[i])
+ {
+ insideView = false;
+ }
+ if (indices[i] >= view.m_Origin[i] + outputInfo.GetShape()[i])
+ {
+ insideView = false;
+ }
+ }
+
+ if (insideView)
+ {
+ std::unique_ptr<Encoder<float>> encoderPtr =
+ MakeEncoder<float>(outputInfo, data.m_Outputs[viewIdx]->Map());
+ Encoder<float>& encoder = *encoderPtr;
+
+ unsigned int outIndex = 0;
+ unsigned int dimensionStride = 1;
+ float inputValue = 0.f;
+
+ for (unsigned int i = outputInfo.GetNumDimensions(); i-- > 0;)
+ {
+ outIndex += dimensionStride * (indices[i] - view.m_Origin[i]);
+ dimensionStride *= outputInfo.GetShape()[i];
+ }
+
+ decoder += index;
+ inputValue = decoder.Get();
+ decoder -= index;
+
+ encoder += outIndex;
+ encoder.Set(inputValue);
+ break;
+ }
+ }
+ }
+}
+
+} \ No newline at end of file
diff --git a/src/backends/reference/workloads/Splitter.hpp b/src/backends/reference/workloads/Splitter.hpp
index 0e522d5ad5..271c6fdeb8 100644
--- a/src/backends/reference/workloads/Splitter.hpp
+++ b/src/backends/reference/workloads/Splitter.hpp
@@ -6,10 +6,8 @@
#pragma once
#include "RefWorkloadUtils.hpp"
-
#include <backendsCommon/WorkloadData.hpp>
#include <armnn/Tensor.hpp>
-
#include <boost/assert.hpp>
namespace armnn
@@ -80,4 +78,5 @@ void Splitter(const SplitterQueueDescriptor& data)
}
}
+void Split(const SplitterQueueDescriptor& data);
} //namespace armnn