From e851b3da2ba51edc69c7b3dbfad06c4e22a63595 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Tue, 28 May 2019 14:31:20 +0100 Subject: IVGCVSW-3170 Refactor the Strided Slice Ref workload for Float32 and QAsymm8 types * RefStridedSliceWorkload is no longer a template class * Refactoring of the ref StridedSlice implementation * Added ValidateTensorQuantizationSpace function Change-Id: Ifa182a33d79d42137731f48b995a7973c9d92152 Signed-off-by: Matteo Martincigh --- .../workloads/RefStridedSliceWorkload.cpp | 36 +++++++++++++--------- 1 file changed, 21 insertions(+), 15 deletions(-) (limited to 'src/backends/reference/workloads/RefStridedSliceWorkload.cpp') diff --git a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp index bcc3520f45..8bb1670a48 100644 --- a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp +++ b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp @@ -4,31 +4,37 @@ // #include "RefStridedSliceWorkload.hpp" +#include "RefWorkloadUtils.hpp" #include "StridedSlice.hpp" -#include "RefWorkloadUtils.hpp" -#include +#include namespace armnn { -template -void RefStridedSliceWorkload::Execute() const -{ - using T = ResolveType; +RefStridedSliceWorkload::RefStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload(descriptor, info) +{} - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, GetName() + "_Execute"); +void RefStridedSliceWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefStridedSliceWorkload_Execute"); - const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); - const T* inputData = GetInputTensorData(0, m_Data); - T* outputData = GetOutputTensorData(0, m_Data); + DataType inputDataType = inputInfo.GetDataType(); + DataType outputDataType = outputInfo.GetDataType(); - StridedSlice(inputInfo, outputInfo, m_Data.m_Parameters, inputData, outputData); -} + BOOST_ASSERT(inputDataType == outputDataType); + boost::ignore_unused(outputDataType); -template class RefStridedSliceWorkload; -template class RefStridedSliceWorkload; + StridedSlice(inputInfo, + m_Data.m_Parameters, + m_Data.m_Inputs[0]->Map(), + m_Data.m_Outputs[0]->Map(), + GetDataTypeSize(inputDataType)); +} -} //namespace armnn +} // namespace armnn -- cgit v1.2.1