diff options
Diffstat (limited to 'arm_compute/runtime/NEON/functions/NEWinogradLayer.h')
-rw-r--r-- | arm_compute/runtime/NEON/functions/NEWinogradLayer.h | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/arm_compute/runtime/NEON/functions/NEWinogradLayer.h b/arm_compute/runtime/NEON/functions/NEWinogradLayer.h index a939f82854..61a4caae3a 100644 --- a/arm_compute/runtime/NEON/functions/NEWinogradLayer.h +++ b/arm_compute/runtime/NEON/functions/NEWinogradLayer.h @@ -30,6 +30,7 @@ #include "arm_compute/core/Types.h" #include "arm_compute/runtime/CPP/functions/CPPPermute.h" #include "arm_compute/runtime/MemoryGroup.h" +#include "arm_compute/runtime/NEON/functions/NEActivationLayer.h" #include "arm_compute/runtime/Tensor.h" #include <memory> @@ -61,8 +62,9 @@ public: * @param[out] output Destination tensor. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs. * Data types supported: Same as @p input. * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo. Currently only unit strides are supported. + * @param[in] act_info (Optional) Activation layer information in case of a fused activation. */ - void configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info); + void configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info = ActivationLayerInfo()); // Inherited methods overridden: void run() override; @@ -94,6 +96,7 @@ private: std::unique_ptr<INEKernel> _transform_input_kernel; std::unique_ptr<INEKernel> _transform_output_kernel; std::unique_ptr<INEKernel> _transform_weights_kernel; + NEActivationLayer _activationlayer_function; CPPPermute _permute_input; CPPPermute _permute_weights; @@ -108,6 +111,7 @@ private: const ITensor *_weights; ITensor *_output; bool _reshaped_kernel; + bool _is_activationlayer_enabled; }; } #endif /* __ARM_COMPUTE_NEWINOGRADLAYER_H__ */ |