diff options
Diffstat (limited to 'src/gpu/cl/operators/ClMatMul.h')
-rw-r--r-- | src/gpu/cl/operators/ClMatMul.h | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/src/gpu/cl/operators/ClMatMul.h b/src/gpu/cl/operators/ClMatMul.h index 6aba801301..9dce5288e6 100644 --- a/src/gpu/cl/operators/ClMatMul.h +++ b/src/gpu/cl/operators/ClMatMul.h @@ -21,14 +21,14 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ACL_ARM_COMPUTE_SRC_GPU_CL_OPERATORS_CLMATMUL -#define ACL_ARM_COMPUTE_SRC_GPU_CL_OPERATORS_CLMATMUL +#ifndef ACL_SRC_GPU_CL_OPERATORS_CLMATMUL +#define ACL_SRC_GPU_CL_OPERATORS_CLMATMUL #include "arm_compute/core/ActivationLayerInfo.h" #include "arm_compute/core/MatMulInfo.h" #include "src/gpu/cl/IClOperator.h" -#include "src/gpu/cl/kernels/ClMatMulNativeKernel.h" #include "src/gpu/cl/kernels/ClMatMulLowpNativeKernel.h" +#include "src/gpu/cl/kernels/ClMatMulNativeKernel.h" #include <memory> @@ -71,24 +71,26 @@ public: * @param[in] rhs Right-hand side tensor info. Data types supported: same as @p lhs. * @param[out] dst Output tensor to store the result of the batched matrix multiplication. Data types supported: same as @p lhs. * @param[in] matmul_info Contains MatMul operation information described in @ref MatMulInfo. + * @param[in] act_info Class containing information about fused activation function. */ - void configure(const CLCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &matmul_info); + void configure(const CLCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &matmul_info, + const ActivationLayerInfo &act_info = ActivationLayerInfo()); /** Static function to check if given info will lead to a valid configuration * * Similar to @ref ClMatMul::configure() * * @return a status */ - static Status validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &matmul_info); + static Status validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &matmul_info, const ActivationLayerInfo &act_info = ActivationLayerInfo()); // Inherited methods overridden: void run(ITensorPack &tensors) override; private: - std::unique_ptr<kernels::ClMatMulNativeKernel> _matmul_native_kernel{nullptr}; - std::unique_ptr<kernels::ClMatMulLowpNativeKernel> _matmul_lowp_native_kernel{nullptr}; + std::unique_ptr<kernels::ClMatMulNativeKernel> _matmul_native_kernel{ nullptr }; + std::unique_ptr<kernels::ClMatMulLowpNativeKernel> _matmul_lowp_native_kernel{ nullptr }; bool _is_quantized{ false }; }; } // namespace opencl } // namespace arm_compute -#endif /* ACL_ARM_COMPUTE_SRC_GPU_CL_OPERATORS_CLMATMUL */ +#endif /* ACL_SRC_GPU_CL_OPERATORS_CLMATMUL */ |