From 60e98253f1e3df1723e7b8f4c996b544aa7c7205 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Mon, 22 Oct 2018 16:17:20 +0100 Subject: COMPMID-1451: Fuse activation in DepthwiseConvolution. Change-Id: Id964d9068e18aaa13ab8adcbf7a9375b034ea6c3 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/154651 Tested-by: bsgcomp Reviewed-by: Gian Marco Iodice --- arm_compute/graph/backends/FunctionHelpers.h | 11 +++-- .../graph/nodes/DepthwiseConvolutionLayerNode.h | 14 ++++++ .../CL/functions/CLDepthwiseConvolutionLayer.h | 11 ++++- .../functions/GCDepthwiseConvolutionLayer.h | 8 ++- .../NEON/functions/NEDepthwiseConvolutionLayer.h | 21 ++++++-- .../cl_kernels/depthwise_convolution_quantized.cl | 18 +++---- .../CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp | 7 ++- .../CLDepthwiseConvolutionLayer3x3NHWCKernel.cpp | 7 ++- src/graph/backends/GLES/GCFunctionsFactory.cpp | 9 ++-- src/graph/mutators/NodeFusionMutator.cpp | 28 +++++++++-- src/graph/nodes/DepthwiseConvolutionLayerNode.cpp | 14 +++++- .../CL/functions/CLDepthwiseConvolutionLayer.cpp | 27 ++++++++-- .../functions/GCDepthwiseConvolutionLayer.cpp | 19 +++++++- .../NEON/functions/NEDepthwiseConvolutionLayer.cpp | 57 +++++++++++++++++++--- 14 files changed, 200 insertions(+), 51 deletions(-) diff --git a/arm_compute/graph/backends/FunctionHelpers.h b/arm_compute/graph/backends/FunctionHelpers.h index a1cadcbf4c..1968ec3923 100644 --- a/arm_compute/graph/backends/FunctionHelpers.h +++ b/arm_compute/graph/backends/FunctionHelpers.h @@ -397,8 +397,10 @@ std::unique_ptr create_depthwise_convolution_layer(DepthwiseConvoluti biases->info()->set_data_type(DataType::S32); } - const PadStrideInfo conv_info = node.convolution_info(); - const DepthwiseConvolutionMethod dwc_algorithm = node.depthwise_convolution_method(); + const PadStrideInfo conv_info = node.convolution_info(); + const DepthwiseConvolutionMethod dwc_algorithm = node.depthwise_convolution_method(); + const unsigned int depth_multiplier = 1; + const ActivationLayerInfo fused_act = node.fused_activation(); // Create and configure function (we assume that functions have been validated before creation) std::unique_ptr func; @@ -407,13 +409,13 @@ std::unique_ptr create_depthwise_convolution_layer(DepthwiseConvoluti { std::tie(func, func_name) = create_named_function( std::string("DepthwiseConvolutionLayer3x3"), - input, weights, biases, output, conv_info); + input, weights, biases, output, conv_info, depth_multiplier, fused_act); } else { std::tie(func, func_name) = create_named_function( std::string("DepthwiseConvolutionLayer"), - input, weights, biases, output, conv_info); + input, weights, biases, output, conv_info, depth_multiplier, fused_act); } // Log info @@ -431,6 +433,7 @@ std::unique_ptr create_depthwise_convolution_layer(DepthwiseConvoluti << " Input shape: " << input->info()->tensor_shape() << " Weights shape: " << weights->info()->tensor_shape() << " Output shape: " << output->info()->tensor_shape() + << (fused_act.enabled() ? 
" " + to_string(fused_act.activation()) : "") << std::endl); return func; } diff --git a/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h b/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h index 1a173c5421..7fa44b798f 100644 --- a/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h +++ b/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h @@ -58,6 +58,16 @@ public: * @return Convolution information */ PadStrideInfo convolution_info() const; + /** Returns fused activation + * + * @return Fused activation + */ + ActivationLayerInfo fused_activation() const; + /** Sets fused activation + * + * @param[in] fused_activation Fused activation to set + */ + void set_fused_activation(ActivationLayerInfo fused_activation); /** Computes depthwise convolution output descriptor * * @param[in] input_descriptor Input descriptor @@ -76,9 +86,13 @@ public: TensorDescriptor configure_output(size_t idx) const override; void accept(INodeVisitor &v) override; +public: + static constexpr NodeType node_type = NodeType::DepthwiseConvolutionLayer; + private: PadStrideInfo _info; DepthwiseConvolutionMethod _method; + ActivationLayerInfo _fused_activation; }; } // namespace graph } // namespace arm_compute diff --git a/arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h b/arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h index 0547a6a6a8..96a0d236f5 100644 --- a/arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h +++ b/arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h @@ -35,6 +35,7 @@ #include "arm_compute/core/CL/kernels/ICLDepthwiseConvolutionLayer3x3Kernel.h" #include "arm_compute/core/Types.h" #include "arm_compute/runtime/CL/CLTensor.h" +#include "arm_compute/runtime/CL/functions/CLActivationLayer.h" #include "arm_compute/runtime/IFunction.h" namespace arm_compute @@ -121,8 +122,10 @@ public: * @param[out] output Destination tensor. Data type supported: same as @p input. * @param[in] conv_info Padding and stride information to use for the convolution. * @param[in] depth_multiplier (Optional) Multiplier to apply to the input's depth in order to retrieve the output's depth. Defaults to 1. + * @param[in] act_info (Optional) Activation layer information in case of a fused activation. */ - void configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier = 1); + void configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, + unsigned int depth_multiplier = 1, const ActivationLayerInfo &act_info = ActivationLayerInfo()); /** Static function to check if given info will lead to a valid configuration of @ref CLDepthwiseConvolutionLayer * @@ -133,10 +136,12 @@ public: * @param[in] output Destination tensor. Data type supported: same as @p input. * @param[in] conv_info Padding and stride information to use for the convolution. * @param[in] depth_multiplier (Optional) Multiplier to apply to the input's depth in order to retrieve the output's depth. Defaults to 1. + * @param[in] act_info (Optional) Activation layer information in case of a fused activation. 
* * @return a status */ - static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier = 1); + static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, + unsigned int depth_multiplier = 1, const ActivationLayerInfo &act_info = ActivationLayerInfo()); // Inherited methods overriden: void run() override; @@ -148,6 +153,7 @@ private: CLGEMMMatrixVectorMultiplyKernel _v2mm_kernel; CLDepthwiseVectorToTensorKernel _vector_to_tensor_kernel; CLDirectConvolutionLayerOutputStageKernel _output_stage_kernel; + CLActivationLayer _activationlayer_function; CLFillBorderKernel _v2mm_input_fill_border; CLFillBorderKernel _v2mm_weights_fill_border; CLTensor _input_reshaped; @@ -156,6 +162,7 @@ private: CLTensor _output_reshaped; bool _is_prepared; bool _is_quantized; + bool _is_activationlayer_enabled; const ICLTensor *_original_weights; }; } diff --git a/arm_compute/runtime/GLES_COMPUTE/functions/GCDepthwiseConvolutionLayer.h b/arm_compute/runtime/GLES_COMPUTE/functions/GCDepthwiseConvolutionLayer.h index c99485634c..5eccc4d9e8 100644 --- a/arm_compute/runtime/GLES_COMPUTE/functions/GCDepthwiseConvolutionLayer.h +++ b/arm_compute/runtime/GLES_COMPUTE/functions/GCDepthwiseConvolutionLayer.h @@ -28,6 +28,7 @@ #include "arm_compute/core/GLES_COMPUTE/kernels/GCFillBorderKernel.h" #include "arm_compute/core/GLES_COMPUTE/kernels/GCTensorShiftKernel.h" #include "arm_compute/core/Types.h" +#include "arm_compute/runtime/GLES_COMPUTE/functions/GCActivationLayer.h" #include "arm_compute/runtime/IFunction.h" namespace arm_compute @@ -54,8 +55,10 @@ public: * @param[out] output Destination tensor. Data type supported: same as @p input. * @param[in] conv_info Padding and stride information to use for the convolution. * @param[in] depth_multiplier (Optional) Multiplier to apply to the input's depth in order to retrieve the output's depth. Defaults to 1. + * @param[in] act_info (Optional) Activation layer information in case of a fused activation. 
*/ - void configure(IGCTensor *input, const IGCTensor *weights, const IGCTensor *biases, IGCTensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier = 1); + void configure(IGCTensor *input, const IGCTensor *weights, const IGCTensor *biases, IGCTensor *output, const PadStrideInfo &conv_info, + unsigned int depth_multiplier = 1, const ActivationLayerInfo &act_info = ActivationLayerInfo()); // Inherited methods overridden: void run() override final; @@ -64,6 +67,9 @@ private: std::unique_ptr _kernel; GCFillBorderKernel _border_handler; GCTensorShiftKernel _shift_handler; + GCActivationLayer _activationlayer_function; + + bool _is_activationlayer_enabled; }; } #endif /*__ARM_COMPUTE_GCDEPTHWISECONVOLUTION_H__ */ diff --git a/arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h index b7398f628a..288d5136d2 100644 --- a/arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h +++ b/arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h @@ -35,6 +35,7 @@ #include "arm_compute/runtime/IFunction.h" #include "arm_compute/runtime/IMemoryManager.h" #include "arm_compute/runtime/MemoryGroup.h" +#include "arm_compute/runtime/NEON/functions/NEActivationLayer.h" #include "arm_compute/runtime/NEON/functions/NEPermute.h" #include "arm_compute/runtime/Tensor.h" @@ -62,8 +63,10 @@ public: * @param[out] output Destination tensor. Data type supported: same as @p input. * @param[in] conv_info Padding and stride information to use for the convolution. * @param[in] depth_multiplier (Optional) Multiplier to apply to the input's depth in order to retrieve the output's depth. Defaults to 1. + * @param[in] act_info (Optional) Activation layer information in case of a fused activation. */ - void configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier = 1); + void configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, + unsigned int depth_multiplier = 1, const ActivationLayerInfo &act_info = ActivationLayerInfo()); /** Static function to check if given info will lead to a valid configuration of @ref NEDepthwiseConvolutionLayer3x3 * @@ -74,10 +77,12 @@ public: * @param[in] output Destination tensor. Data type supported: same as @p input. * @param[in] conv_info Padding and stride information to use for the convolution. * @param[in] depth_multiplier (Optional) Multiplier to apply to the input's depth in order to retrieve the output's depth. Defaults to 1. + * @param[in] act_info (Optional) Activation layer information in case of a fused activation. 
* * @return a status */ - static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier = 1); + static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, + unsigned int depth_multiplier = 1, const ActivationLayerInfo &act_info = ActivationLayerInfo()); // Inherited methods overriden: void run() override; @@ -89,6 +94,7 @@ private: NEPermute _permute_input; NEPermute _permute_weights; NEPermute _permute_output; + NEActivationLayer _activationlayer_function; Tensor _accumulator; Tensor _permuted_input; Tensor _permuted_weights; @@ -100,6 +106,7 @@ private: bool _is_nchw; bool _is_first_run; bool _permute; + bool _is_activationlayer_enabled; }; /** Basic function to execute a generic depthwise convolution. This function calls the following NEON kernels: @@ -132,8 +139,10 @@ public: * Data type supported: Same as @p input, S32 when input is QASYMM8. * @param[in] conv_info Padding and stride information to use for the convolution. * @param[in] depth_multiplier (Optional) Multiplier to apply to the input's depth in order to retrieve the output's depth. Defaults to 1. + * @param[in] act_info (Optional) Activation layer information in case of a fused activation. */ - void configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier = 1); + void configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, + unsigned int depth_multiplier = 1, const ActivationLayerInfo &act_info = ActivationLayerInfo()); /** Static function to check if given info will lead to a valid configuration of @ref NEDepthwiseConvolutionLayer * @@ -144,10 +153,12 @@ public: * Data type supported: Same as @p input, S32 when input is QASYMM8. * @param[in] conv_info Padding and stride information to use for the convolution. * @param[in] depth_multiplier (Optional) Multiplier to apply to the input's depth in order to retrieve the output's depth. Defaults to 1. + * @param[in] act_info (Optional) Activation layer information in case of a fused activation. 
* * @return a status */ - static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier = 1); + static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, + unsigned int depth_multiplier = 1, const ActivationLayerInfo &act_info = ActivationLayerInfo()); // Inherited methods overriden: void run() override; @@ -164,6 +175,7 @@ private: NEPermute _permute_input; NEPermute _permute_weights; NEPermute _permute_output; + NEActivationLayer _activationlayer_function; Tensor _input_reshaped; Tensor _weights_reshaped; Tensor _v2mm_output; @@ -174,6 +186,7 @@ private: bool _is_prepared; bool _is_quantized; bool _is_nhwc; + bool _is_activationlayer_enabled; const ITensor *_original_weights; }; } diff --git a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl index 7cd48790c6..3239885abc 100644 --- a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl +++ b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl @@ -720,7 +720,7 @@ __kernel void depthwise_convolution_3x3_quantized_nhwc( Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst); VSTORE(VEC_SIZE) - (res, 0, dst.ptr); + (ACTIVATION_FUNC(res), 0, dst.ptr); } #endif // defined(CONV_STRIDE_X) && defined(CONV_STRIDE_Y) @@ -953,18 +953,18 @@ __kernel void depthwise_convolution_3x3_quantized_nhwc_stride1( __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * dst_step_x + y * dst_step_y + (z * NUM_PLANES_PROCESSED) * dst_step_z; VSTORE(VEC_SIZE) - (res0, 0, dst_addr + 0 * dst_stride_y); + (ACTIVATION_FUNC(res0), 0, dst_addr + 0 * dst_stride_y); VSTORE(VEC_SIZE) - (res1, 0, dst_addr + 1 * dst_stride_y); + (ACTIVATION_FUNC(res1), 0, dst_addr + 1 * dst_stride_y); #if((DST_DIM_2 % NUM_PLANES_PROCESSED) != 0) if((z * NUM_PLANES_PROCESSED + 1) < DST_DIM_2) #endif // ((DST_DIM_2 % NUM_PLANES_PROCESSED) != 0) { VSTORE(VEC_SIZE) - (res2, 0, dst_addr + 0 * dst_stride_y + 1 * dst_stride_z); + (ACTIVATION_FUNC(res2), 0, dst_addr + 0 * dst_stride_y + 1 * dst_stride_z); VSTORE(VEC_SIZE) - (res3, 0, dst_addr + 1 * dst_stride_y + 1 * dst_stride_z); + (ACTIVATION_FUNC(res3), 0, dst_addr + 1 * dst_stride_y + 1 * dst_stride_z); } } @@ -1159,18 +1159,18 @@ __kernel void depthwise_convolution_3x3_quantized_dot8_nhwc_stride1( __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * dst_step_x + y * dst_step_y + (z * NUM_PLANES_PROCESSED) * dst_step_z; VSTORE(VEC_SIZE) - (res0, 0, dst_addr + 0 * dst_stride_y); + (ACTIVATION_FUNC(res0), 0, dst_addr + 0 * dst_stride_y); VSTORE(VEC_SIZE) - (res1, 0, dst_addr + 1 * dst_stride_y); + (ACTIVATION_FUNC(res1), 0, dst_addr + 1 * dst_stride_y); #if((DST_DIM_2 % NUM_PLANES_PROCESSED) != 0) if((z * NUM_PLANES_PROCESSED + 1) < DST_DIM_2) #endif // ((DST_DIM_2 % NUM_PLANES_PROCESSED) != 0) { VSTORE(VEC_SIZE) - (res2, 0, dst_addr + 0 * dst_stride_y + 1 * dst_stride_z); + (ACTIVATION_FUNC(res2), 0, dst_addr + 0 * dst_stride_y + 1 * dst_stride_z); VSTORE(VEC_SIZE) - (res3, 0, dst_addr + 1 * dst_stride_y + 1 * dst_stride_z); + (ACTIVATION_FUNC(res3), 0, dst_addr + 1 * dst_stride_y + 1 * dst_stride_z); } } #endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) diff --git a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp 
b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp index de7e2b8737..eb561faf77 100644 --- a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp +++ b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp @@ -207,8 +207,7 @@ BorderSize CLDepthwiseConvolutionLayer3x3NCHWKernel::border_size() const } void CLDepthwiseConvolutionLayer3x3NCHWKernel::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, - unsigned int depth_multiplier, - ActivationLayerInfo act_info) + unsigned int depth_multiplier, ActivationLayerInfo act_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, output->info(), conv_info, depth_multiplier, act_info)); @@ -272,11 +271,11 @@ void CLDepthwiseConvolutionLayer3x3NCHWKernel::configure(const ICLTensor *input, const float s2 = output->info()->quantization_info().scale; const int o2 = output->info()->quantization_info().offset; + build_opts.add_option("-DS1_VAL=" + float_to_string_with_full_precision(s1)); + build_opts.add_option("-DO1_VAL=" + support::cpp11::to_string(o1)); if(o1 != o2 || s1 != s2) { - build_opts.add_option("-DS1_VAL=" + float_to_string_with_full_precision(s1)); build_opts.add_option("-DS2_VAL=" + float_to_string_with_full_precision(s2)); - build_opts.add_option("-DO1_VAL=" + support::cpp11::to_string(o1)); build_opts.add_option("-DO2_VAL=" + support::cpp11::to_string(o2)); } } diff --git a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NHWCKernel.cpp b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NHWCKernel.cpp index d56ac01a83..d3bed87037 100644 --- a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NHWCKernel.cpp +++ b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NHWCKernel.cpp @@ -139,8 +139,7 @@ BorderSize CLDepthwiseConvolutionLayer3x3NHWCKernel::border_size() const } void CLDepthwiseConvolutionLayer3x3NHWCKernel::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, - unsigned int depth_multiplier, - ActivationLayerInfo act_info) + unsigned int depth_multiplier, ActivationLayerInfo act_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); @@ -213,11 +212,11 @@ void CLDepthwiseConvolutionLayer3x3NHWCKernel::configure(const ICLTensor *input, const float s2 = output->info()->quantization_info().scale; const int o2 = output->info()->quantization_info().offset; + build_opts.add_option("-DS1_VAL=" + float_to_string_with_full_precision(s1)); + build_opts.add_option("-DO1_VAL=" + support::cpp11::to_string(o1)); if(o1 != o2 || s1 != s2) { - build_opts.add_option("-DS1_VAL=" + float_to_string_with_full_precision(s1)); build_opts.add_option("-DS2_VAL=" + float_to_string_with_full_precision(s2)); - build_opts.add_option("-DO1_VAL=" + support::cpp11::to_string(o1)); build_opts.add_option("-DO2_VAL=" + support::cpp11::to_string(o2)); } } diff --git a/src/graph/backends/GLES/GCFunctionsFactory.cpp b/src/graph/backends/GLES/GCFunctionsFactory.cpp index 02a05679a3..7df659e7b3 100644 --- a/src/graph/backends/GLES/GCFunctionsFactory.cpp +++ b/src/graph/backends/GLES/GCFunctionsFactory.cpp @@ -171,8 +171,10 @@ std::unique_ptr create_depthwise_convolution_layerinfo()->set_data_type(DataType::S32); } - const PadStrideInfo conv_info = node.convolution_info(); - const DepthwiseConvolutionMethod dwc_algorithm = 
node.depthwise_convolution_method(); + const PadStrideInfo conv_info = node.convolution_info(); + const DepthwiseConvolutionMethod dwc_algorithm = node.depthwise_convolution_method(); + const unsigned int depth_multiplier = 1; + const ActivationLayerInfo fused_act = node.fused_activation(); // Create and configure function (we assume that functions have been validated before creation) std::unique_ptr func; @@ -181,7 +183,7 @@ std::unique_ptr create_depthwise_convolution_layer( std::string("DepthwiseConvolutionLayer3x3"), - input, weights, biases, output, conv_info); + input, weights, biases, output, conv_info, depth_multiplier, fused_act); } else { @@ -197,6 +199,7 @@ std::unique_ptr create_depthwise_convolution_layerinfo()->tensor_shape() << " Weights shape: " << weights->info()->tensor_shape() << " Output shape: " << output->info()->tensor_shape() + << (fused_act.enabled() ? " " + to_string(fused_act.activation()) : "") << std::endl); return func; } diff --git a/src/graph/mutators/NodeFusionMutator.cpp b/src/graph/mutators/NodeFusionMutator.cpp index 7e66ce0757..98c3a56018 100644 --- a/src/graph/mutators/NodeFusionMutator.cpp +++ b/src/graph/mutators/NodeFusionMutator.cpp @@ -39,12 +39,14 @@ namespace graph namespace detail { template -void fuse_node_with_activation(Graph &g, const std::set &supported_fused_activations) +void fuse_node_with_activation(Graph &g, + const std::set &supported_fused_activations, + std::function const &prec) { // Not interested in the order of nodes for(auto &node : g.nodes()) { - // Check if the node is batch norm and not a branching node + // Check if the node is of type N and not a branching node if(node && node->type() == N::node_type && node->output_edges().size() == 1) { auto output_edge_id = *node->output_edges().begin(); @@ -57,6 +59,11 @@ void fuse_node_with_activation(Graph &g, const std::set &supported_f ARM_COMPUTE_ERROR_ON(act_node->output(0) == nullptr || n_node->output(0) == nullptr); + // Check given precondition + if(!prec(*n_node)) + { + continue; + } // Check if activation is supported for fusion if(supported_fused_activations.count(act_node->activation_info().activation()) == 0) { @@ -110,8 +117,21 @@ void NodeFusionMutator::mutate(Graph &g) // Supported activations when fusing const std::set supported_fused_activations = { Activation::RELU, Activation::BOUNDED_RELU, Activation::LU_BOUNDED_RELU }; - detail::fuse_node_with_activation(g, supported_fused_activations); - detail::fuse_node_with_activation(g, supported_fused_activations); + // Preconditions + auto empty_prec = [](INode & n) + { + return true; + }; + auto qs8_prec = [](INode & n) + { + ARM_COMPUTE_ERROR_ON(n.output(0) == nullptr); + return n.output(0)->desc().data_type == DataType::QASYMM8; + }; + + // Fusion mutations + detail::fuse_node_with_activation(g, supported_fused_activations, empty_prec); + detail::fuse_node_with_activation(g, supported_fused_activations, empty_prec); + detail::fuse_node_with_activation(g, supported_fused_activations, qs8_prec); } } // namespace graph } // namespace arm_compute diff --git a/src/graph/nodes/DepthwiseConvolutionLayerNode.cpp b/src/graph/nodes/DepthwiseConvolutionLayerNode.cpp index 1a6f8d398d..02d16328b1 100644 --- a/src/graph/nodes/DepthwiseConvolutionLayerNode.cpp +++ b/src/graph/nodes/DepthwiseConvolutionLayerNode.cpp @@ -33,7 +33,7 @@ namespace arm_compute namespace graph { DepthwiseConvolutionLayerNode::DepthwiseConvolutionLayerNode(PadStrideInfo info, DepthwiseConvolutionMethod method) - : _info(std::move(info)), _method(method) + : 
_info(std::move(info)), _method(method), _fused_activation() { _input_edges.resize(3, EmptyEdgeID); _outputs.resize(1, NullTensorID); @@ -54,6 +54,16 @@ PadStrideInfo DepthwiseConvolutionLayerNode::convolution_info() const return _info; } +ActivationLayerInfo DepthwiseConvolutionLayerNode::fused_activation() const +{ + return _fused_activation; +} + +void DepthwiseConvolutionLayerNode::set_fused_activation(ActivationLayerInfo fused_activation) +{ + _fused_activation = fused_activation; +} + TensorDescriptor DepthwiseConvolutionLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor, const TensorDescriptor &weights_descriptor, const PadStrideInfo &info) @@ -100,7 +110,7 @@ TensorDescriptor DepthwiseConvolutionLayerNode::configure_output(size_t idx) con NodeType DepthwiseConvolutionLayerNode::type() const { - return NodeType::DepthwiseConvolutionLayer; + return DepthwiseConvolutionLayerNode::node_type; } void DepthwiseConvolutionLayerNode::accept(INodeVisitor &v) diff --git a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp index 76451af9b1..497cdae85c 100644 --- a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp @@ -90,12 +90,13 @@ void CLDepthwiseConvolutionLayer3x3::run() } CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayer() - : _im2col_kernel(), _weights_reshape_kernel(), _v2mm_kernel(), _vector_to_tensor_kernel(), _output_stage_kernel(), _v2mm_input_fill_border(), _v2mm_weights_fill_border(), _input_reshaped(), - _weights_reshaped(), _v2mm_output(), _output_reshaped(), _is_prepared(false), _is_quantized(false), _original_weights(nullptr) + : _im2col_kernel(), _weights_reshape_kernel(), _v2mm_kernel(), _vector_to_tensor_kernel(), _output_stage_kernel(), _activationlayer_function(), _v2mm_input_fill_border(), _v2mm_weights_fill_border(), + _input_reshaped(), _weights_reshaped(), _v2mm_output(), _output_reshaped(), _is_prepared(false), _is_quantized(false), _is_activationlayer_enabled(false), _original_weights(nullptr) { } -void CLDepthwiseConvolutionLayer::configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier) +void CLDepthwiseConvolutionLayer::configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, + unsigned int depth_multiplier, const ActivationLayerInfo &act_info) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); @@ -188,10 +189,18 @@ void CLDepthwiseConvolutionLayer::configure(ICLTensor *input, const ICLTensor *w // Allocate intermediate tensors _input_reshaped.allocator()->allocate(); _v2mm_output.allocator()->allocate(); + + //Configure Activation Layer + _is_activationlayer_enabled = act_info.enabled(); + + if(_is_activationlayer_enabled) + { + _activationlayer_function.configure(output, nullptr, act_info); + } } Status CLDepthwiseConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, - unsigned int depth_multiplier) + unsigned int depth_multiplier, const ActivationLayerInfo &act_info) { const size_t idx_w = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH); const size_t idx_h = 
get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT); @@ -238,6 +247,12 @@ Status CLDepthwiseConvolutionLayer::validate(const ITensorInfo *input, const ITe ARM_COMPUTE_RETURN_ON_ERROR(CLDirectConvolutionLayerOutputStageKernel::validate(&output_reshaped, biases, output)); } + // Validate Activation Layer + if(act_info.enabled()) + { + ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(output, nullptr, act_info)); + } + return Status{}; } @@ -253,6 +268,10 @@ void CLDepthwiseConvolutionLayer::run() { CLScheduler::get().enqueue(_output_stage_kernel); } + if(_is_activationlayer_enabled) + { + _activationlayer_function.run(); + } } void CLDepthwiseConvolutionLayer::prepare() diff --git a/src/runtime/GLES_COMPUTE/functions/GCDepthwiseConvolutionLayer.cpp b/src/runtime/GLES_COMPUTE/functions/GCDepthwiseConvolutionLayer.cpp index 7121654a75..d9aa50d9e1 100644 --- a/src/runtime/GLES_COMPUTE/functions/GCDepthwiseConvolutionLayer.cpp +++ b/src/runtime/GLES_COMPUTE/functions/GCDepthwiseConvolutionLayer.cpp @@ -31,11 +31,12 @@ using namespace arm_compute; GCDepthwiseConvolutionLayer3x3::GCDepthwiseConvolutionLayer3x3() - : _kernel(nullptr), _border_handler(), _shift_handler() + : _kernel(nullptr), _border_handler(), _shift_handler(), _activationlayer_function(), _is_activationlayer_enabled(false) { } -void GCDepthwiseConvolutionLayer3x3::configure(IGCTensor *input, const IGCTensor *weights, const IGCTensor *biases, IGCTensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier) +void GCDepthwiseConvolutionLayer3x3::configure(IGCTensor *input, const IGCTensor *weights, const IGCTensor *biases, IGCTensor *output, const PadStrideInfo &conv_info, + unsigned int depth_multiplier, const ActivationLayerInfo &act_info) { auto k = arm_compute::support::cpp14::make_unique(); k->configure(input, weights, biases, output, conv_info, depth_multiplier); @@ -45,6 +46,14 @@ void GCDepthwiseConvolutionLayer3x3::configure(IGCTensor *input, const IGCTensor _border_handler.configure(input, _kernel->border_size(), BorderMode::CONSTANT, PixelValue(0)); _shift_handler.configure(input); + + //Configure Activation Layer + _is_activationlayer_enabled = act_info.enabled(); + + if(_is_activationlayer_enabled) + { + _activationlayer_function.configure(output, nullptr, act_info); + } } void GCDepthwiseConvolutionLayer3x3::run() @@ -54,4 +63,10 @@ void GCDepthwiseConvolutionLayer3x3::run() GCScheduler::get().dispatch(_border_handler, false); GCScheduler::get().memory_barrier(); GCScheduler::get().dispatch(*_kernel); + + // Run Activation Layer + if(_is_activationlayer_enabled) + { + _activationlayer_function.run(); + } } diff --git a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp index 9dcbc99332..a2f0094f9d 100644 --- a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp @@ -36,12 +36,14 @@ using namespace arm_compute::misc; using namespace arm_compute::misc::shape_calculator; NEDepthwiseConvolutionLayer3x3::NEDepthwiseConvolutionLayer3x3() - : _dwc_kernel(), _output_stage_kernel(), _border_handler(), _permute_input(), _permute_weights(), _permute_output(), _accumulator(), _permuted_input(), _permuted_weights(), _permuted_output(), - _has_bias(false), _is_quantized(false), _is_optimized(false), _are_weights_reshaped(false), _is_nchw(true), _is_first_run(true), _permute(false) + : _dwc_kernel(), _output_stage_kernel(), _border_handler(), 
_permute_input(), _permute_weights(), _permute_output(), _activationlayer_function(), _accumulator(), _permuted_input(), + _permuted_weights(), _permuted_output(), _has_bias(false), _is_quantized(false), _is_optimized(false), _are_weights_reshaped(false), _is_nchw(true), _is_first_run(true), _permute(false), + _is_activationlayer_enabled(false) { } -void NEDepthwiseConvolutionLayer3x3::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier) +void NEDepthwiseConvolutionLayer3x3::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, + unsigned int depth_multiplier, const ActivationLayerInfo &act_info) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); @@ -159,10 +161,18 @@ void NEDepthwiseConvolutionLayer3x3::configure(ITensor *input, const ITensor *we _permute_output.configure(&_permuted_output, output, PermutationVector(2U, 0U, 1U)); _permuted_output.allocator()->allocate(); } + + //Configure Activation Layer + _is_activationlayer_enabled = act_info.enabled(); + + if(_is_activationlayer_enabled) + { + _activationlayer_function.configure(output, nullptr, act_info); + } } Status NEDepthwiseConvolutionLayer3x3::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, - unsigned int depth_multiplier) + unsigned int depth_multiplier, const ActivationLayerInfo &act_info) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output); ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN); @@ -184,6 +194,12 @@ Status NEDepthwiseConvolutionLayer3x3::validate(const ITensorInfo *input, const ARM_COMPUTE_RETURN_ON_ERROR(NEDirectConvolutionLayerOutputStageKernel::validate(&accumulator, biases, output)); } + //Validate Activation Layer + if(act_info.enabled()) + { + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output, nullptr, act_info)); + } + return Status{}; } @@ -235,16 +251,22 @@ void NEDepthwiseConvolutionLayer3x3::run() { _permute_output.run(); } + + if(_is_activationlayer_enabled) + { + _activationlayer_function.run(); + } } NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayer() : _im2col_kernel(), _weights_reshape_kernel(), _v2mm_kernel(), _vector_to_tensor_kernel(), _output_stage_kernel(), _v2mm_input_fill_border(), _v2mm_weights_fill_border(), _permute_input(), - _permute_weights(), _permute_output(), _input_reshaped(), _weights_reshaped(), _v2mm_output(), _output_reshaped(), _permuted_input(), _permuted_weights(), _permuted_output(), _is_prepared(false), - _is_quantized(false), _is_nhwc(false), _original_weights(nullptr) + _permute_weights(), _permute_output(), _activationlayer_function(), _input_reshaped(), _weights_reshaped(), _v2mm_output(), _output_reshaped(), _permuted_input(), _permuted_weights(), + _permuted_output(), _is_prepared(false), _is_quantized(false), _is_nhwc(false), _is_activationlayer_enabled(false), _original_weights(nullptr) { } -void NEDepthwiseConvolutionLayer::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier) +void NEDepthwiseConvolutionLayer::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, + 
unsigned int depth_multiplier, const ActivationLayerInfo &act_info) { const unsigned int channel_idx = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::CHANNEL); ARM_COMPUTE_UNUSED(channel_idx); @@ -366,10 +388,18 @@ void NEDepthwiseConvolutionLayer::configure(ITensor *input, const ITensor *weigh // Allocate intermediate tensors _input_reshaped.allocator()->allocate(); _v2mm_output.allocator()->allocate(); + + //Configure Activation Layer + _is_activationlayer_enabled = act_info.enabled(); + + if(_is_activationlayer_enabled) + { + _activationlayer_function.configure(output, nullptr, act_info); + } } Status NEDepthwiseConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, - unsigned int depth_multiplier) + unsigned int depth_multiplier, const ActivationLayerInfo &act_info) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output); ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN); @@ -454,6 +484,12 @@ Status NEDepthwiseConvolutionLayer::validate(const ITensorInfo *input, const ITe ARM_COMPUTE_RETURN_ON_ERROR(NEDirectConvolutionLayerOutputStageKernel::validate(&output_reshaped, biases, output_to_use)); } + // Validate Activation Layer + if(act_info.enabled()) + { + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output, nullptr, act_info)); + } + return Status{}; } @@ -479,6 +515,11 @@ void NEDepthwiseConvolutionLayer::run() { _permute_output.run(); } + + if(_is_activationlayer_enabled) + { + _activationlayer_function.run(); + } } void NEDepthwiseConvolutionLayer::prepare() -- cgit v1.2.1
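
Usage note (not part of the patch): below is a minimal sketch of how the extended configure() signature introduced by this change could be driven on the NEON backend. The tensor shapes, F32 data type, padding and the ReLU cap value are illustrative assumptions; only the trailing depth_multiplier/act_info parameters come from this commit. As the runtime changes above show, when act_info is enabled the function either folds the activation into the kernel (the quantized NHWC OpenCL path via ACTIVATION_FUNC) or appends a separate activation-layer run after the convolution.

    // Hedged sketch: fused activation with NEDepthwiseConvolutionLayer3x3 (assumed shapes/values).
    #include "arm_compute/core/Types.h"
    #include "arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    int main()
    {
        Tensor src, weights, biases, dst;

        // Hypothetical 32x32x16 NCHW input with 3x3 depthwise weights (depth_multiplier = 1)
        src.allocator()->init(TensorInfo(TensorShape(32U, 32U, 16U), 1, DataType::F32));
        weights.allocator()->init(TensorInfo(TensorShape(3U, 3U, 16U), 1, DataType::F32));
        biases.allocator()->init(TensorInfo(TensorShape(16U), 1, DataType::F32));
        dst.allocator()->init(TensorInfo(TensorShape(32U, 32U, 16U), 1, DataType::F32));

        const PadStrideInfo       conv_info(1, 1, 1, 1); // stride 1, pad 1 -> same spatial size
        const ActivationLayerInfo act_info(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f);

        // New optional arguments introduced by this commit: depth_multiplier and act_info
        NEDepthwiseConvolutionLayer3x3 dwc;
        dwc.configure(&src, &weights, &biases, &dst, conv_info, 1 /* depth_multiplier */, act_info);

        src.allocator()->allocate();
        weights.allocator()->allocate();
        biases.allocator()->allocate();
        dst.allocator()->allocate();

        // ... fill src/weights/biases ...

        dwc.run(); // depthwise convolution followed by the fused activation stage
        return 0;
    }

Passing a default-constructed ActivationLayerInfo() keeps the previous, activation-free behaviour, which is why existing call sites that omit the new arguments still compile unchanged.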