diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-05-28 14:31:20 +0100 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-05-31 11:57:50 +0100 |
commit | e851b3da2ba51edc69c7b3dbfad06c4e22a63595 (patch) | |
tree | 5ead856b8c4de5198170f8ff3fdb2541eb6676d9 /src/backends/reference/workloads/RefStridedSliceWorkload.cpp | |
parent | 01961a7df1c4357981a33b9c1eb80fb51888a8fa (diff) | |
download | armnn-e851b3da2ba51edc69c7b3dbfad06c4e22a63595.tar.gz |
IVGCVSW-3170 Refactor the Strided Slice Ref workload for Float32 and
QAsymm8 types
* RefStridedSliceWorkload is no longer a template class
* Refactoring of the ref StridedSlice implementation
* Added ValidateTensorQuantizationSpace function
Change-Id: Ifa182a33d79d42137731f48b995a7973c9d92152
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Diffstat (limited to 'src/backends/reference/workloads/RefStridedSliceWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefStridedSliceWorkload.cpp | 36 |
1 files changed, 21 insertions, 15 deletions
diff --git a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp index bcc3520f45..8bb1670a48 100644 --- a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp +++ b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp @@ -4,31 +4,37 @@ // #include "RefStridedSliceWorkload.hpp" +#include "RefWorkloadUtils.hpp" #include "StridedSlice.hpp" -#include "RefWorkloadUtils.hpp" -#include <ResolveType.hpp> +#include <boost/format.hpp> namespace armnn { -template<armnn::DataType DataType> -void RefStridedSliceWorkload<DataType>::Execute() const -{ - using T = ResolveType<DataType>; +RefStridedSliceWorkload::RefStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload(descriptor, info) +{} - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, GetName() + "_Execute"); +void RefStridedSliceWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefStridedSliceWorkload_Execute"); - const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); - const T* inputData = GetInputTensorData<T>(0, m_Data); - T* outputData = GetOutputTensorData<T>(0, m_Data); + DataType inputDataType = inputInfo.GetDataType(); + DataType outputDataType = outputInfo.GetDataType(); - StridedSlice(inputInfo, outputInfo, m_Data.m_Parameters, inputData, outputData); -} + BOOST_ASSERT(inputDataType == outputDataType); + boost::ignore_unused(outputDataType); -template class RefStridedSliceWorkload<DataType::Float32>; -template class RefStridedSliceWorkload<DataType::QuantisedAsymm8>; + StridedSlice(inputInfo, + m_Data.m_Parameters, + m_Data.m_Inputs[0]->Map(), + m_Data.m_Outputs[0]->Map(), + GetDataTypeSize(inputDataType)); +} -} //namespace armnn +} // namespace armnn |