diff options
Diffstat (limited to 'src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp')
-rw-r--r-- | src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp | 30 |
1 files changed, 29 insertions, 1 deletions
diff --git a/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp b/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp index 0deff79dac..1894048788 100644 --- a/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp +++ b/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp @@ -13,6 +13,34 @@ using namespace armnn::armcomputetensorutils; namespace armnn { +namespace +{ + +bool IsNeonNormalizationDescriptorSupported(const NormalizationDescriptor& parameters, + Optional<std::string&> reasonIfUnsupported) +{ + if (parameters.m_NormMethodType != NormalizationAlgorithmMethod::LocalBrightness) + { + if (reasonIfUnsupported) + { + reasonIfUnsupported.value() = "Unsupported normalisation method type, only LocalBrightness is supported"; + } + return false; + } + if (parameters.m_NormSize % 2 == 0) + { + if (reasonIfUnsupported) + { + reasonIfUnsupported.value() = "Normalization size must be an odd number."; + } + return false; + } + + return true; +} + +} // anonymous namespace + arm_compute::Status NeonNormalizationWorkloadValidate(const TensorInfo& input, const TensorInfo& output, const NormalizationDescriptor& descriptor) @@ -33,7 +61,7 @@ NeonNormalizationFloatWorkload::NeonNormalizationFloatWorkload(const Normalizati { m_Data.ValidateInputsOutputs("NeonNormalizationFloatWorkload", 1, 1); std::string reasonIfUnsupported; - if (!IsNeonNormalizationDescParamsSupported(Optional<std::string&>(reasonIfUnsupported), m_Data.m_Parameters)) + if (!IsNeonNormalizationDescriptorSupported(m_Data.m_Parameters, Optional<std::string&>(reasonIfUnsupported))) { throw UnimplementedException(reasonIfUnsupported); } |