From 2f2778f36e59537bbd47fb8b21e73c6c5a949584 Mon Sep 17 00:00:00 2001 From: Nina Drozd Date: Mon, 27 May 2019 10:37:05 +0100 Subject: IVGCVSW-3145 Refactor Reference Reshape workloads * Removed reference reshape workloads for float32 and uint8 * Added RefReshapeWorkload * Added check for supported datatypes for reshape in WorkloadData * Added check for supported datatypes for reshape in RefLayerSupport * Updated CMakeLists.txt * Updated references to reshape workloads Signed-off-by: Nina Drozd Change-Id: I9941659067b022f8f7686ab0ff14776944dca3e5 --- src/backends/backendsCommon/WorkloadData.cpp | 25 ++++++++++++++++------ .../backendsCommon/test/WorkloadDataValidation.cpp | 2 +- 2 files changed, 19 insertions(+), 8 deletions(-) (limited to 'src/backends/backendsCommon') diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index d9779e4e37..ea84c0b9f2 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -850,13 +850,13 @@ void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const // Check the supported data types std::vector supportedTypes = - { - DataType::Float32, - DataType::Float16, - DataType::Signed32, - DataType::QuantisedAsymm8, - DataType::QuantisedSymm16 - }; + { + DataType::Float32, + DataType::Float16, + DataType::Signed32, + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 + }; ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], supportedTypes, "ConstantQueueDescriptor"); } @@ -872,6 +872,17 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const to_string(workloadInfo.m_InputTensorInfos[0].GetNumElements()) + " but output tensor has " + to_string(workloadInfo.m_OutputTensorInfos[0].GetNumElements()) + " elements."); } + + // Check the supported data types + std::vector supportedTypes = + { + DataType::Float32, + DataType::Float16, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, "ReshapeQueueDescriptor"); + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], supportedTypes, "ReshapeQueueDescriptor"); } void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp index 119eb7df90..067cca8319 100644 --- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp +++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp @@ -447,7 +447,7 @@ BOOST_AUTO_TEST_CASE(ReshapeQueueDescriptor_Validate_MismatchingNumElements) AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr); // InvalidArgumentException is expected, because the number of elements don't match. - BOOST_CHECK_THROW(RefReshapeFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException); + BOOST_CHECK_THROW(RefReshapeWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException); } -- cgit v1.2.1