aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/NeonLayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/NeonLayerSupport.cpp')
-rw-r--r--src/armnn/backends/NeonLayerSupport.cpp26
1 files changed, 22 insertions, 4 deletions
diff --git a/src/armnn/backends/NeonLayerSupport.cpp b/src/armnn/backends/NeonLayerSupport.cpp
index 382b15e277..d8a3366775 100644
--- a/src/armnn/backends/NeonLayerSupport.cpp
+++ b/src/armnn/backends/NeonLayerSupport.cpp
@@ -71,6 +71,22 @@ bool IsNeonDirectConvolutionPreferred(const TensorInfo& weightInfo, const Convol
return preferDirectConvolution;
}
+bool IsNeonMultiplicationParamsSupported(std::string* reasonIfUnsupported,
+ const TensorInfo& info0,
+ const TensorInfo& info1)
+{
+ if (info0.GetShape() == info1.GetShape())
+ {
+ return true;
+ }
+
+ if (reasonIfUnsupported)
+ {
+ *reasonIfUnsupported = "Multiplication on Neon does not support implicit broadcast.";
+ }
+ return false;
+}
+
bool IsNeonNormalizationDescParamsSupported(std::string* reasonIfUnsupported, const NormalizationDescriptor& parameters)
{
if (parameters.m_NormMethodType != NormalizationAlgorithmMethod::LocalBrightness)
@@ -233,7 +249,7 @@ bool IsConvolution2dSupportedNeon(const TensorInfo& input,
return IsSupportedForDataTypeNeon(reasonIfUnsupported,
input.GetDataType(),
&TrueFunc<>,
- &FalseFuncU8<>);
+ &TrueFunc<>);
}
bool IsDepthwiseConvolutionSupportedNeon(const TensorInfo& input,
@@ -293,11 +309,13 @@ bool IsMultiplicationSupportedNeon(const TensorInfo& input0,
const TensorInfo& input1,
std::string* reasonIfUnsupported)
{
- ignore_unused(input1);
return IsSupportedForDataTypeNeon(reasonIfUnsupported,
input0.GetDataType(),
- &TrueFunc<>,
- &FalseFuncU8<>);
+ &IsNeonMultiplicationParamsSupported,
+ &FalseFuncU8<const TensorInfo&, const TensorInfo&>,
+ input0,
+ input1
+ );
}
bool IsNormalizationSupportedNeon(const TensorInfo& input,