aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefElementwiseWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/RefElementwiseWorkload.cpp')
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.cpp20
1 files changed, 7 insertions, 13 deletions
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index 13d6e70a96..c9b93c8524 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -26,7 +26,7 @@ void BaseFloat32ElementwiseWorkload<ParentDescriptor, Functor>::ExecuteImpl(cons
const float* inData1 = GetInputTensorDataFloat(1, data);
float* outData = GetOutputTensorDataFloat(0, data);
- ElementwiseFunction<Functor>(inShape0, inShape1, outShape, inData0, inData1, outData);
+ ElementwiseFunction<Functor, float, float>(inShape0, inShape1, outShape, inData0, inData1, outData);
}
template <typename ParentDescriptor, typename Functor>
@@ -44,12 +44,12 @@ void BaseUint8ElementwiseWorkload<ParentDescriptor, Functor>::ExecuteImpl(const
std::vector<float> results(outputInfo.GetNumElements());
- ElementwiseFunction<Functor>(inputInfo0.GetShape(),
- inputInfo1.GetShape(),
- outputInfo.GetShape(),
- dequant0.data(),
- dequant1.data(),
- results.data());
+ ElementwiseFunction<Functor, float, float>(inputInfo0.GetShape(),
+ inputInfo1.GetShape(),
+ outputInfo.GetShape(),
+ dequant0.data(),
+ dequant1.data(),
+ results.data());
Quantize(GetOutputTensorDataU8(0, data), results.data(), outputInfo);
}
@@ -73,9 +73,3 @@ template class armnn::BaseUint8ElementwiseWorkload<armnn::MaximumQueueDescriptor
template class armnn::BaseFloat32ElementwiseWorkload<armnn::MinimumQueueDescriptor, armnn::minimum<float>>;
template class armnn::BaseUint8ElementwiseWorkload<armnn::MinimumQueueDescriptor, armnn::minimum<float>>;
-
-template class armnn::BaseFloat32ElementwiseWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>;
-template class armnn::BaseUint8ElementwiseWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>;
-
-template class armnn::BaseFloat32ElementwiseWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>;
-template class armnn::BaseUint8ElementwiseWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>;