diff options
Diffstat (limited to 'src/cpu')
-rw-r--r-- | src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp | 15 | ||||
-rw-r--r-- | src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | 17 | ||||
-rw-r--r-- | src/cpu/operators/internal/CpuGemmAssemblyDispatch.h | 1 |
3 files changed, 18 insertions, 15 deletions
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp index 8faa3c217a..aec9da193b 100644 --- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp +++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -59,12 +59,13 @@ namespace cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info) { cpu::AsmGemmInfo asm_info; - asm_info.method = cpu::AsmConvMethod::Im2Col; - asm_info.reinterpret_input_as_3d = info.reinterpret_input_as_3d(); - asm_info.depth_output_gemm3d = info.depth_output_gemm3d(); - asm_info.activation_info = info.activation_info(); - asm_info.output_stage = info.gemmlowp_output_stage(); - asm_info.fast_mode = info.fast_math(); + asm_info.method = cpu::AsmConvMethod::Im2Col; + asm_info.reinterpret_input_as_3d = info.reinterpret_input_as_3d(); + asm_info.depth_output_gemm3d = info.depth_output_gemm3d(); + asm_info.activation_info = info.activation_info(); + asm_info.output_stage = info.gemmlowp_output_stage(); + asm_info.fast_mode = info.fast_math(); + asm_info.reshape_b_only_on_first_run = info.reshape_b_only_on_first_run(); return asm_info; } diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index ab668681ad..8ff81afe54 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -157,8 +157,8 @@ public: const std::vector<int32_t> &multipliers); // Inherited methods overridden: - void run(ITensorPack &tensors) override; - void prepare(ITensorPack &tensors) override; + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; bool is_configured() const override; experimental::MemoryRequirements workspace() const override; bool isVarWeightsKernel() const override @@ -211,12 +211,12 @@ private: /** Indirect buffer */ std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{}; std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{}; - std::vector<TypeInput> _indirect_pad{}; - arm_gemm::ConvolutionParameters _cp{}; - experimental::MemoryRequirements _aux_mem{ Count }; - bool _B_pretranspose_required{ false }; - bool _is_b_constant{ true }; - bool _is_c_constant{ true }; + std::vector<TypeInput> _indirect_pad{}; + arm_gemm::ConvolutionParameters _cp{}; + experimental::MemoryRequirements _aux_mem{ Count }; + bool _B_pretranspose_required{ false }; + bool _is_b_constant{ true }; + bool _is_c_constant{ true }; }; template <typename TypeInput, typename TypeOutput, class OutputStage> @@ -767,6 +767,7 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a); ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(info.reshape_b_only_on_first_run), "Assembly kernel will not be executed when reshape_b_only_on_first_run is false"); #ifndef __aarch64__ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->element_size() == 1, "8bit integer types only supported for aarch64"); diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h index 691eeff8d2..0c51c92359 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h @@ -54,6 +54,7 @@ struct AsmGemmInfo bool fast_mode{ false }; bool fixed_format{ false }; arm_compute::WeightFormat weight_format{ arm_compute::WeightFormat::UNSPECIFIED }; + bool reshape_b_only_on_first_run{ true }; }; /** Assembly kernel glue */ |