aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp')
-rw-r--r--src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp30
1 files changed, 29 insertions, 1 deletions
diff --git a/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp b/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp
index 0deff79dac..1894048788 100644
--- a/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp
+++ b/src/backends/neon/workloads/NeonNormalizationFloatWorkload.cpp
@@ -13,6 +13,34 @@ using namespace armnn::armcomputetensorutils;
namespace armnn
{
+namespace
+{
+
+bool IsNeonNormalizationDescriptorSupported(const NormalizationDescriptor& parameters,
+ Optional<std::string&> reasonIfUnsupported)
+{
+ if (parameters.m_NormMethodType != NormalizationAlgorithmMethod::LocalBrightness)
+ {
+ if (reasonIfUnsupported)
+ {
+ reasonIfUnsupported.value() = "Unsupported normalisation method type, only LocalBrightness is supported";
+ }
+ return false;
+ }
+ if (parameters.m_NormSize % 2 == 0)
+ {
+ if (reasonIfUnsupported)
+ {
+ reasonIfUnsupported.value() = "Normalization size must be an odd number.";
+ }
+ return false;
+ }
+
+ return true;
+}
+
+} // anonymous namespace
+
arm_compute::Status NeonNormalizationWorkloadValidate(const TensorInfo& input,
const TensorInfo& output,
const NormalizationDescriptor& descriptor)
@@ -33,7 +61,7 @@ NeonNormalizationFloatWorkload::NeonNormalizationFloatWorkload(const Normalizati
{
m_Data.ValidateInputsOutputs("NeonNormalizationFloatWorkload", 1, 1);
std::string reasonIfUnsupported;
- if (!IsNeonNormalizationDescParamsSupported(Optional<std::string&>(reasonIfUnsupported), m_Data.m_Parameters))
+ if (!IsNeonNormalizationDescriptorSupported(m_Data.m_Parameters, Optional<std::string&>(reasonIfUnsupported)))
{
throw UnimplementedException(reasonIfUnsupported);
}