diff options
Diffstat (limited to 'src/backends/reference/workloads/RefStridedSliceWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefStridedSliceWorkload.cpp | 36 |
1 files changed, 21 insertions, 15 deletions
diff --git a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp index bcc3520f45..8bb1670a48 100644 --- a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp +++ b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp @@ -4,31 +4,37 @@ // #include "RefStridedSliceWorkload.hpp" +#include "RefWorkloadUtils.hpp" #include "StridedSlice.hpp" -#include "RefWorkloadUtils.hpp" -#include <ResolveType.hpp> +#include <boost/format.hpp> namespace armnn { -template<armnn::DataType DataType> -void RefStridedSliceWorkload<DataType>::Execute() const -{ - using T = ResolveType<DataType>; +RefStridedSliceWorkload::RefStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload(descriptor, info) +{} - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, GetName() + "_Execute"); +void RefStridedSliceWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefStridedSliceWorkload_Execute"); - const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); - const T* inputData = GetInputTensorData<T>(0, m_Data); - T* outputData = GetOutputTensorData<T>(0, m_Data); + DataType inputDataType = inputInfo.GetDataType(); + DataType outputDataType = outputInfo.GetDataType(); - StridedSlice(inputInfo, outputInfo, m_Data.m_Parameters, inputData, outputData); -} + BOOST_ASSERT(inputDataType == outputDataType); + boost::ignore_unused(outputDataType); -template class RefStridedSliceWorkload<DataType::Float32>; -template class RefStridedSliceWorkload<DataType::QuantisedAsymm8>; + StridedSlice(inputInfo, + m_Data.m_Parameters, + m_Data.m_Inputs[0]->Map(), + m_Data.m_Outputs[0]->Map(), + GetDataTypeSize(inputDataType)); +} -} //namespace armnn +} // namespace armnn |