diff options
Diffstat (limited to 'arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h')
-rw-r--r-- | arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h index 1317fb740e..ac065533e5 100644 --- a/arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h +++ b/arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h @@ -90,15 +90,16 @@ private: NEPermute _permute_weights; NEPermute _permute_output; Tensor _accumulator; - Tensor _input_nhwc; - Tensor _weights_hwio; - Tensor _output_nhwc; + Tensor _permuted_input; + Tensor _permuted_weights; + Tensor _permuted_output; bool _has_bias; bool _is_quantized; bool _is_optimized; bool _are_weights_reshaped; bool _is_nchw; bool _is_first_run; + bool _permute; }; /** Basic function to execute a generic depthwise convolution. This function calls the following NEON kernels: @@ -146,7 +147,7 @@ public: * * @return a status */ - static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier = 1); + static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier = 1); // Inherited methods overriden: void run() override; @@ -160,12 +161,19 @@ private: NEDirectConvolutionLayerOutputStageKernel _output_stage_kernel; NEFillBorderKernel _v2mm_input_fill_border; NEFillBorderKernel _v2mm_weights_fill_border; + NEPermute _permute_input; + NEPermute _permute_weights; + NEPermute _permute_output; Tensor _input_reshaped; Tensor _weights_reshaped; Tensor _v2mm_output; Tensor _output_reshaped; + Tensor _permuted_input; + Tensor _permuted_weights; + Tensor _permuted_output; bool _is_prepared; bool _is_quantized; + bool _is_nhwc; const ITensor *_original_weights; }; } |