aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNina Drozd <nina.drozd@arm.com>2019-05-27 10:37:05 +0100
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-05-28 11:43:10 +0000
commit2f2778f36e59537bbd47fb8b21e73c6c5a949584 (patch)
treee91943ef038ef02f12eccf69fb94f2b25ab865e5
parent0be43386864ad3f86ad6c569520c9374883776f7 (diff)
downloadarmnn-2f2778f36e59537bbd47fb8b21e73c6c5a949584.tar.gz
IVGCVSW-3145 Refactor Reference Reshape workloads
* Removed reference reshape workloads for float32 and uint8 * Added RefReshapeWorkload * Added check for supported datatypes for reshape in WorkloadData * Added check for supported datatypes for reshape in RefLayerSupport * Updated CMakeLists.txt * Updated references to reshape workloads Signed-off-by: Nina Drozd <nina.drozd@arm.com> Change-Id: I9941659067b022f8f7686ab0ff14776944dca3e5
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp25
-rw-r--r--src/backends/backendsCommon/test/WorkloadDataValidation.cpp2
-rw-r--r--src/backends/reference/RefLayerSupport.cpp13
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp2
-rw-r--r--src/backends/reference/backend.mk3
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp4
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt6
-rw-r--r--src/backends/reference/workloads/RefReshapeFloat32Workload.cpp27
-rw-r--r--src/backends/reference/workloads/RefReshapeFloat32Workload.hpp21
-rw-r--r--src/backends/reference/workloads/RefReshapeWorkload.cpp (renamed from src/backends/reference/workloads/RefReshapeUint8Workload.cpp)10
-rw-r--r--src/backends/reference/workloads/RefReshapeWorkload.hpp (renamed from src/backends/reference/workloads/RefReshapeUint8Workload.hpp)6
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp3
12 files changed, 42 insertions, 80 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index d9779e4e37..ea84c0b9f2 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -850,13 +850,13 @@ void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
// Check the supported data types
std::vector<DataType> supportedTypes =
- {
- DataType::Float32,
- DataType::Float16,
- DataType::Signed32,
- DataType::QuantisedAsymm8,
- DataType::QuantisedSymm16
- };
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::Signed32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], supportedTypes, "ConstantQueueDescriptor");
}
@@ -872,6 +872,17 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
to_string(workloadInfo.m_InputTensorInfos[0].GetNumElements()) + " but output tensor has " +
to_string(workloadInfo.m_OutputTensorInfos[0].GetNumElements()) + " elements.");
}
+
+ // Check the supported data types
+ std::vector<DataType> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, "ReshapeQueueDescriptor");
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], supportedTypes, "ReshapeQueueDescriptor");
}
void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
index 119eb7df90..067cca8319 100644
--- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
+++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
@@ -447,7 +447,7 @@ BOOST_AUTO_TEST_CASE(ReshapeQueueDescriptor_Validate_MismatchingNumElements)
AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr);
// InvalidArgumentException is expected, because the number of elements don't match.
- BOOST_CHECK_THROW(RefReshapeFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefReshapeWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 9be1ed6d74..2adcb1099d 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -1021,10 +1021,15 @@ bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Optional<std::string&> reasonIfUnsupported) const
{
ignore_unused(descriptor);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ // Define supported output types.
+ std::array<DataType,3> supportedOutputTypes =
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QuantisedAsymm8
+ };
+ return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
+ "Reference reshape: input type not supported.");
}
bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 6abcf9cd08..1243328852 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -264,7 +264,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConstant(const ConstantQueu
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<RefReshapeFloat32Workload, RefReshapeUint8Workload>(descriptor, info);
+ return std::make_unique<RefReshapeWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 50cfbf68cc..1c7f8dc22c 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -54,8 +54,7 @@ BACKEND_SOURCES := \
workloads/RefPooling2dFloat32Workload.cpp \
workloads/RefPooling2dUint8Workload.cpp \
workloads/RefQuantizeWorkload.cpp \
- workloads/RefReshapeFloat32Workload.cpp \
- workloads/RefReshapeUint8Workload.cpp \
+ workloads/RefReshapeWorkload.cpp \
workloads/RefResizeBilinearFloat32Workload.cpp \
workloads/RefResizeBilinearUint8Workload.cpp \
workloads/RefRsqrtFloat32Workload.cpp \
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 95da7abad1..3f4cc75fea 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -663,12 +663,12 @@ static void RefCreateReshapeWorkloadTest()
BOOST_AUTO_TEST_CASE(CreateReshapeFloat32Workload)
{
- RefCreateReshapeWorkloadTest<RefReshapeFloat32Workload, armnn::DataType::Float32>();
+ RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::Float32>();
}
BOOST_AUTO_TEST_CASE(CreateReshapeUint8Workload)
{
- RefCreateReshapeWorkloadTest<RefReshapeUint8Workload, armnn::DataType::QuantisedAsymm8>();
+ RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::QuantisedAsymm8>();
}
template <typename MergerWorkloadType, armnn::DataType DataType>
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 7f26d78c7e..508dfdc293 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -91,10 +91,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefPooling2dUint8Workload.hpp
RefQuantizeWorkload.cpp
RefQuantizeWorkload.hpp
- RefReshapeFloat32Workload.cpp
- RefReshapeFloat32Workload.hpp
- RefReshapeUint8Workload.cpp
- RefReshapeUint8Workload.hpp
+ RefReshapeWorkload.cpp
+ RefReshapeWorkload.hpp
RefResizeBilinearFloat32Workload.cpp
RefResizeBilinearFloat32Workload.hpp
RefResizeBilinearUint8Workload.cpp
diff --git a/src/backends/reference/workloads/RefReshapeFloat32Workload.cpp b/src/backends/reference/workloads/RefReshapeFloat32Workload.cpp
deleted file mode 100644
index 99c94a49a1..0000000000
--- a/src/backends/reference/workloads/RefReshapeFloat32Workload.cpp
+++ /dev/null
@@ -1,27 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefReshapeFloat32Workload.hpp"
-
-#include "RefWorkloadUtils.hpp"
-
-#include "Profiling.hpp"
-
-#include <cstring>
-
-namespace armnn
-{
-
-void RefReshapeFloat32Workload::Execute() const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefReshapeFloat32Workload_Execute");
-
- void* output = GetOutputTensorData<void>(0, m_Data);
- const void* input = GetInputTensorData<void>(0, m_Data);
- unsigned int numBytes = GetTensorInfo(m_Data.m_Inputs[0]).GetNumBytes();
- memcpy(output, input, numBytes);
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefReshapeFloat32Workload.hpp b/src/backends/reference/workloads/RefReshapeFloat32Workload.hpp
deleted file mode 100644
index 75024b3db6..0000000000
--- a/src/backends/reference/workloads/RefReshapeFloat32Workload.hpp
+++ /dev/null
@@ -1,21 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <backendsCommon/Workload.hpp>
-#include <backendsCommon/WorkloadData.hpp>
-
-namespace armnn
-{
-
-class RefReshapeFloat32Workload : public Float32Workload<ReshapeQueueDescriptor>
-{
-public:
- using Float32Workload<ReshapeQueueDescriptor>::Float32Workload;
- virtual void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefReshapeUint8Workload.cpp b/src/backends/reference/workloads/RefReshapeWorkload.cpp
index 8f475f3db3..6d29781937 100644
--- a/src/backends/reference/workloads/RefReshapeUint8Workload.cpp
+++ b/src/backends/reference/workloads/RefReshapeWorkload.cpp
@@ -1,12 +1,10 @@
-//
+//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
-#include "RefReshapeUint8Workload.hpp"
-
+#include "RefReshapeWorkload.hpp"
#include "RefWorkloadUtils.hpp"
-
#include "Profiling.hpp"
#include <cstring>
@@ -14,9 +12,9 @@
namespace armnn
{
-void RefReshapeUint8Workload::Execute() const
+void RefReshapeWorkload::Execute() const
{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefReshapeUint8Workload_Execute");
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefReshapeWorkload_Execute");
void* output = GetOutputTensorData<void>(0, m_Data);
const void* input = GetInputTensorData<void>(0, m_Data);
diff --git a/src/backends/reference/workloads/RefReshapeUint8Workload.hpp b/src/backends/reference/workloads/RefReshapeWorkload.hpp
index c3d31f8a73..7359ff9cde 100644
--- a/src/backends/reference/workloads/RefReshapeUint8Workload.hpp
+++ b/src/backends/reference/workloads/RefReshapeWorkload.hpp
@@ -1,4 +1,4 @@
-//
+//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -11,10 +11,10 @@
namespace armnn
{
-class RefReshapeUint8Workload : public Uint8Workload<ReshapeQueueDescriptor>
+class RefReshapeWorkload : public BaseWorkload<ReshapeQueueDescriptor>
{
public:
- using Uint8Workload<ReshapeQueueDescriptor>::Uint8Workload;
+ using BaseWorkload<ReshapeQueueDescriptor>::BaseWorkload;
virtual void Execute() const override;
};
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 54bc5c7f01..20649d93ce 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -23,14 +23,12 @@
#include "TensorBufferArrayView.hpp"
#include "RefBatchNormalizationFloat32Workload.hpp"
#include "Splitter.hpp"
-#include "RefReshapeFloat32Workload.hpp"
#include "RefDepthwiseConvolution2dWorkload.hpp"
#include "FullyConnected.hpp"
#include "Gather.hpp"
#include "RefFloorFloat32Workload.hpp"
#include "RefSoftmaxFloat32Workload.hpp"
#include "RefSoftmaxUint8Workload.hpp"
-#include "RefReshapeUint8Workload.hpp"
#include "RefResizeBilinearFloat32Workload.hpp"
#include "RefBatchNormalizationUint8Workload.hpp"
#include "ResizeBilinear.hpp"
@@ -59,3 +57,4 @@
#include "RefRsqrtFloat32Workload.hpp"
#include "RefDequantizeWorkload.hpp"
#include "RefQuantizeWorkload.hpp"
+#include "RefReshapeWorkload.hpp"