diff options
Diffstat (limited to 'src/cpu/kernels')
-rw-r--r-- | src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp | 6 | ||||
-rw-r--r-- | src/cpu/kernels/CpuKernelSelectionTypes.h | 2 | ||||
-rw-r--r-- | src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h | 16 | ||||
-rw-r--r-- | src/cpu/kernels/assembly/arm_gemm.hpp | 12 | ||||
-rw-r--r-- | src/cpu/kernels/assembly/convolution_parameters.hpp | 10 | ||||
-rw-r--r-- | src/cpu/kernels/assembly/gemm_common.hpp | 18 |
6 files changed, 38 insertions, 26 deletions
diff --git a/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp b/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp index a3ed2cd171..87340e566e 100644 --- a/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp +++ b/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -684,6 +684,10 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons DataType::U8); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::S32); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src0->data_type() == DataType::QASYMM8_SIGNED && + src1->data_type() == DataType::QASYMM8, + "QASYMM8_SIGNED input with QASYMM8 weights not supported"); + TensorShape in0_shape = src0->tensor_shape(); TensorShape in1_shape = src1->tensor_shape(); TensorShape out_shape = dst->tensor_shape(); diff --git a/src/cpu/kernels/CpuKernelSelectionTypes.h b/src/cpu/kernels/CpuKernelSelectionTypes.h index 7c1e4772a6..03a474de53 100644 --- a/src/cpu/kernels/CpuKernelSelectionTypes.h +++ b/src/cpu/kernels/CpuKernelSelectionTypes.h @@ -105,7 +105,7 @@ struct SoftmaxKernelDataTypeISASelectorData cpuinfo::CpuIsaInfo isa; bool is_log; int axis; - unsigned long sme2_vector_length; + uint64_t sme2_vector_length; }; // Selector pointer types diff --git a/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h b/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h index 6e8f32ef47..72fafca1bb 100644 --- a/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h +++ b/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022 Arm Limited. + * Copyright (c) 2018-2022, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H -#define ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H +#ifndef ACL_SRC_CPU_KERNELS_ASSEMBLY_CPUGEMMASSEMBLYWRAPPERKERNEL_H +#define ACL_SRC_CPU_KERNELS_ASSEMBLY_CPUGEMMASSEMBLYWRAPPERKERNEL_H #include "arm_compute/core/Utils.h" #include "arm_compute/core/Validate.h" @@ -52,7 +52,7 @@ namespace kernel * * */ -template <typename TypeInput, typename TypeOutput> +template <typename TypeInput, typename TypeWeight, typename TypeOutput> class CpuGemmAssemblyWrapperKernel final : public INEKernel { public: @@ -101,7 +101,7 @@ public: * @param[in] kernel Pointer to an assembly kernel implementation. * @param[in] kernel_name_tag Tag to be attacehd to the kernel's name. */ - void configure(arm_gemm::GemmCommon<TypeInput, TypeOutput> *kernel, std::string kernel_name_tag) + void configure(arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *kernel, std::string kernel_name_tag) { ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(kernel))); _kernel = kernel; @@ -131,10 +131,10 @@ public: } private: - arm_gemm::GemmCommon<TypeInput, TypeOutput> *_kernel; - std::string _name; + arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *_kernel; + std::string _name; }; } // namespace kernel } // namespace cpu } // namespace arm_compute -#endif /* ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H */ +#endif // ACL_SRC_CPU_KERNELS_ASSEMBLY_CPUGEMMASSEMBLYWRAPPERKERNEL_H diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp index 941fed0ba8..cbc8be416e 100644 --- a/src/cpu/kernels/assembly/arm_gemm.hpp +++ b/src/cpu/kernels/assembly/arm_gemm.hpp @@ -277,8 +277,8 @@ struct Nothing { }; -template <typename Top, typename Tret> -using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>; +template <typename Tlop, typename Trop, typename Tret> +using UniqueGemmCommon = std::unique_ptr<GemmCommon<Tlop, Trop, Tret>>; /* Low level API calls. * These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */ @@ -288,13 +288,13 @@ using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>; template <typename Top, typename Tret, class OutputStage = Nothing> KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {}); -template <typename Top, typename Tret, class OutputStage = Nothing> -UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & = {}); +template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing> +UniqueGemmCommon<Tlop, Trop, Tret> gemm(const GemmArgs &args, const OutputStage & = {}); -template <typename Top, typename Tret, class OutputStage = Nothing> +template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing> std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {}); -template <typename Top, typename Tret, class OutputStage = Nothing> +template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing> bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {}); } // namespace arm_gemm diff --git a/src/cpu/kernels/assembly/convolution_parameters.hpp b/src/cpu/kernels/assembly/convolution_parameters.hpp index 0c1ae58902..09b73ca409 100644 --- a/src/cpu/kernels/assembly/convolution_parameters.hpp +++ b/src/cpu/kernels/assembly/convolution_parameters.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2021, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,6 +21,10 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ + +#ifndef ACL_SRC_CPU_KERNELS_ASSEMBLY_CONVOLUTION_PARAMETERS_HPP +#define ACL_SRC_CPU_KERNELS_ASSEMBLY_CONVOLUTION_PARAMETERS_HPP + #pragma once #include <cstdint> @@ -57,9 +61,13 @@ struct ConvolutionParameters int64_t output_stride_w; int64_t output_stride_h; // output_channels not included as they do not affect the input. + int64_t dilation_w; + int64_t dilation_h; int64_t padding_top; int64_t padding_left; float padding_value; }; } // namespace arm_gemm + +#endif // ACL_SRC_CPU_KERNELS_ASSEMBLY_CONVOLUTION_PARAMETERS_HPP diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp index 45d1e43274..f693021fcb 100644 --- a/src/cpu/kernels/assembly/gemm_common.hpp +++ b/src/cpu/kernels/assembly/gemm_common.hpp @@ -189,7 +189,7 @@ public: * 'set_arrays' to capture the provided arguments in protected class * members, as essentially any implementation will need these. */ -template <typename To, typename Tr> +template <typename To, typename Tw, typename Tr> class GemmCommon : public IGemmCommon { protected: @@ -197,7 +197,7 @@ protected: int _lda = 0; int _A_batch_stride = 0; int _A_multi_stride = 0; - const To *_Bptr = nullptr; + const Tw *_Bptr = nullptr; int _ldb = 0; int _B_multi_stride = 0; Tr *_Cptr = nullptr; @@ -214,7 +214,7 @@ public: const int lda, const int A_batch_stride, const int A_multi_stride, - const To *B, + const Tw *B, const int ldb, /* batches share B */ const int B_multi_stride, Tr *C, @@ -254,7 +254,7 @@ public: const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) override { - set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const To *>(B), ldb, + set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const Tw *>(B), ldb, B_multi_stride, static_cast<Tr *>(C), ldc, C_batch_stride, C_multi_stride, static_cast<const Tr *>(bias), bias_multi_stride); } @@ -262,17 +262,17 @@ public: /*** "Pretransposed" interface ***/ /* Compute col sums over all columns */ - virtual void requantize_bias(void *, const To *, const int, const int){}; + virtual void requantize_bias(void *, const Tw *, const int, const int){}; /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */ /* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */ - virtual void pretranspose_B_array(void *, const To *, const int, const int, bool){}; + virtual void pretranspose_B_array(void *, const Tw *, const int, const int, bool){}; /* Implementation of the void * overload which casts its arguments to the appropriate type. */ void pretranspose_B_array_generic( void *out, const void *in, const int row_stride, const int multi_stride, bool transposed) override { - pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride, transposed); + pretranspose_B_array(out, static_cast<const Tw *>(in), row_stride, multi_stride, transposed); } /* Threaded versions of the above. @@ -280,7 +280,7 @@ public: * just calls the non-threaded functions to do the work. This is valid as with window size of 1 the only * legal values for start and end are 0 and 1 respectively. */ virtual void pretranspose_B_array_part( - void *out, const To *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t) + void *out, const Tw *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t) { pretranspose_B_array(out, in, row_stride, multi_stride, transposed); }; @@ -293,7 +293,7 @@ public: size_t start, size_t end) override { - pretranspose_B_array_part(out, static_cast<const To *>(in), row_stride, multi_stride, transposed, start, end); + pretranspose_B_array_part(out, static_cast<const Tw *>(in), row_stride, multi_stride, transposed, start, end); } /*** Indirect interface ***/ |