diff options
Diffstat (limited to 'src/graph/backends/NEON/NEFunctionFactory.cpp')
-rw-r--r-- | src/graph/backends/NEON/NEFunctionFactory.cpp | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/src/graph/backends/NEON/NEFunctionFactory.cpp b/src/graph/backends/NEON/NEFunctionFactory.cpp index 454215e7ec..0b3036cb4e 100644 --- a/src/graph/backends/NEON/NEFunctionFactory.cpp +++ b/src/graph/backends/NEON/NEFunctionFactory.cpp @@ -53,7 +53,7 @@ struct NETargetInfo Target NETargetInfo::TargetType = Target::NEON; -/** Collection of CL convolution functions */ +/** Collection of NEON convolution functions */ struct NEConvolutionLayerFunctions { using GenericConvolutionLayer = NEConvolutionLayer; @@ -62,7 +62,7 @@ struct NEConvolutionLayerFunctions using WinogradConvolutionLayer = NEWinogradConvolutionLayer; }; -/** Collection of CL element-wise functions */ +/** Collection of NEON element-wise functions */ struct NEEltwiseFunctions { using Addition = NEArithmeticAddition; @@ -70,6 +70,12 @@ struct NEEltwiseFunctions using Multiplication = NEPixelWiseMultiplication; }; +/** Collection of NEON unary element-wise functions */ +struct NEUnaryEltwiseFunctions +{ + using Exp = NEExpLayer; +}; + /** Function and tensor types to be used inside a NEON fused convolution/batch normalization layer */ struct NEFusedLayerTypes { @@ -143,6 +149,8 @@ std::unique_ptr<IFunction> NEFunctionFactory::create(INode *node, GraphContext & return detail::create_detection_post_process_layer<NEDetectionPostProcessLayer, NETargetInfo>(*polymorphic_downcast<DetectionPostProcessLayerNode *>(node)); case NodeType::EltwiseLayer: return detail::create_eltwise_layer<NEEltwiseFunctions, NETargetInfo>(*polymorphic_downcast<EltwiseLayerNode *>(node)); + case NodeType::UnaryEltwiseLayer: + return detail::create_unary_eltwise_layer<NEUnaryEltwiseFunctions, NETargetInfo>(*polymorphic_downcast<UnaryEltwiseLayerNode *>(node)); case NodeType::FlattenLayer: return detail::create_flatten_layer<NEFlattenLayer, NETargetInfo>(*polymorphic_downcast<FlattenLayerNode *>(node)); case NodeType::FullyConnectedLayer: |