//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefComparisonWorkload.hpp"
#include "ElementwiseFunction.hpp"
#include "RefWorkloadUtils.hpp"
#include "Profiling.hpp"

#include <functional>

namespace armnn
{

template <typename ParentDescriptor, typename Functor>
void RefFloat32ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const
{
    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString);

    auto data = BaseFloat32ComparisonWorkload<ParentDescriptor>::GetData();
    const TensorShape& inShape0    = GetTensorInfo(data.m_Inputs[0]).GetShape();
    const TensorShape& inShape1    = GetTensorInfo(data.m_Inputs[1]).GetShape();
    const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape();

    const float* inData0 = GetInputTensorDataFloat(0, data);
    const float* inData1 = GetInputTensorDataFloat(1, data);
    uint8_t*     outData = GetOutputTensorData<uint8_t>(0, data);

    // Apply the comparison functor elementwise (with broadcasting) and write the boolean results as uint8.
    ElementwiseFunction<Functor, float, uint8_t>(inShape0,
                                                 inShape1,
                                                 outputShape,
                                                 inData0,
                                                 inData1,
                                                 outData);
}

template <typename ParentDescriptor, typename Functor>
void RefUint8ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const
{
    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString);

    auto data = BaseUint8ComparisonWorkload<ParentDescriptor>::GetData();
    const TensorShape& inputInfo0  = GetTensorInfo(data.m_Inputs[0]).GetShape();
    const TensorShape& inputInfo1  = GetTensorInfo(data.m_Inputs[1]).GetShape();
    const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape();

    const uint8_t* inData0 = GetInputTensorData<uint8_t>(0, data);
    const uint8_t* inData1 = GetInputTensorData<uint8_t>(1, data);
    uint8_t*       outData = GetOutputTensorData<uint8_t>(0, data);

    // Apply the comparison functor elementwise (with broadcasting) and write the boolean results as uint8.
    ElementwiseFunction<Functor, uint8_t, uint8_t>(inputInfo0,
                                                   inputInfo1,
                                                   outputShape,
                                                   inData0,
                                                   inData1,
                                                   outData);
}

} // namespace armnn

// Explicit instantiations for the supported comparison layers (Equal and Greater).
template class armnn::RefFloat32ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>;
template class armnn::RefUint8ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<uint8_t>>;

template class armnn::RefFloat32ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>;
template class armnn::RefUint8ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<uint8_t>>;