aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefWorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/RefWorkloadFactory.cpp')
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp33
1 files changed, 16 insertions, 17 deletions
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index cb26f2642b..72762a48e6 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -27,15 +27,16 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescripto
info);
}
-bool IsFloat16(const WorkloadInfo& info)
+template <DataType ArmnnType>
+bool IsDataType(const WorkloadInfo& info)
{
- auto checkFloat16 = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == DataType::Float16;};
- auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkFloat16);
+ auto checkType = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == ArmnnType;};
+ auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkType);
if (it != std::end(info.m_InputTensorInfos))
{
return true;
}
- it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkFloat16);
+ it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkType);
if (it != std::end(info.m_OutputTensorInfos))
{
return true;
@@ -43,20 +44,14 @@ bool IsFloat16(const WorkloadInfo& info)
return false;
}
+bool IsFloat16(const WorkloadInfo& info)
+{
+ return IsDataType<DataType::Float16>(info);
+}
+
bool IsUint8(const WorkloadInfo& info)
{
- auto checkUint8 = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == DataType::QuantisedAsymm8;};
- auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkUint8);
- if (it != std::end(info.m_InputTensorInfos))
- {
- return true;
- }
- it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkUint8);
- if (it != std::end(info.m_OutputTensorInfos))
- {
- return true;
- }
- return false;
+ return IsDataType<DataType::QuantisedAsymm8>(info);
}
RefWorkloadFactory::RefWorkloadFactory()
@@ -260,7 +255,11 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFakeQuantization(
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<RefL2NormalizationFloat32Workload, NullWorkload>(descriptor, info);
+ if (IsFloat16(info) || IsUint8(info))
+ {
+ return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ }
+ return std::make_unique<RefL2NormalizationWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,