diff options
author | surmeh01 <surabhi.mehta@arm.com> | 2018-03-29 16:29:27 +0100 |
---|---|---|
committer | surmeh01 <surabhi.mehta@arm.com> | 2018-03-29 16:29:27 +0100 |
commit | bceff2fb3fc68bb0aa88b886900c34b77340c826 (patch) | |
tree | d867d3e090d58d3012dfbbac456e9ea8c7f789bc /src/armnn/backends/NeonLayerSupport.cpp | |
parent | 4fcda0101ec3d110c1d6d7bee5c83416b645528a (diff) | |
download | armnn-bceff2fb3fc68bb0aa88b886900c34b77340c826.tar.gz |
Release 18.03
Diffstat (limited to 'src/armnn/backends/NeonLayerSupport.cpp')
-rw-r--r-- | src/armnn/backends/NeonLayerSupport.cpp | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/src/armnn/backends/NeonLayerSupport.cpp b/src/armnn/backends/NeonLayerSupport.cpp index 382b15e277..d8a3366775 100644 --- a/src/armnn/backends/NeonLayerSupport.cpp +++ b/src/armnn/backends/NeonLayerSupport.cpp @@ -71,6 +71,22 @@ bool IsNeonDirectConvolutionPreferred(const TensorInfo& weightInfo, const Convol return preferDirectConvolution; } +bool IsNeonMultiplicationParamsSupported(std::string* reasonIfUnsupported, + const TensorInfo& info0, + const TensorInfo& info1) +{ + if (info0.GetShape() == info1.GetShape()) + { + return true; + } + + if (reasonIfUnsupported) + { + *reasonIfUnsupported = "Multiplication on Neon does not support implicit broadcast."; + } + return false; +} + bool IsNeonNormalizationDescParamsSupported(std::string* reasonIfUnsupported, const NormalizationDescriptor& parameters) { if (parameters.m_NormMethodType != NormalizationAlgorithmMethod::LocalBrightness) @@ -233,7 +249,7 @@ bool IsConvolution2dSupportedNeon(const TensorInfo& input, return IsSupportedForDataTypeNeon(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, - &FalseFuncU8<>); + &TrueFunc<>); } bool IsDepthwiseConvolutionSupportedNeon(const TensorInfo& input, @@ -293,11 +309,13 @@ bool IsMultiplicationSupportedNeon(const TensorInfo& input0, const TensorInfo& input1, std::string* reasonIfUnsupported) { - ignore_unused(input1); return IsSupportedForDataTypeNeon(reasonIfUnsupported, input0.GetDataType(), - &TrueFunc<>, - &FalseFuncU8<>); + &IsNeonMultiplicationParamsSupported, + &FalseFuncU8<const TensorInfo&, const TensorInfo&>, + input0, + input1 + ); } bool IsNormalizationSupportedNeon(const TensorInfo& input, |