From 91780021e25575086c6c31d014d34b6513649a9d Mon Sep 17 00:00:00 2001 From: Ramy Elgammal Date: Wed, 20 Jul 2022 14:57:37 +0100 Subject: Fix for inclusion of "arm_gemm" from src into "Types.h" from core - Added arm_compute::WeightFormat and converted to/from arm_gemm::WeightFormat when needed through two map function. - Moved to_string(WeightFormat) to TypePrinter.h Resolves: COMPMID-5415 Signed-off-by: Ramy Elgammal Change-Id: I65f7942100bcd4dbf2c5cf6c07f26c8e1e3bf86e Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/438511 Tested-by: bsgcomp Reviewed-by: Pablo Tello Reviewed-by: Sicong Li Comments-Addressed: bsgcomp Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7985 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Reviewed-by: Michalis Spyrou Benchmark: Arm Jenkins --- .../operators/internal/CpuGemmAssemblyDispatch.cpp | 43 ++++++++++---------- .../operators/internal/CpuGemmAssemblyDispatch.h | 46 +++++++++++----------- 2 files changed, 45 insertions(+), 44 deletions(-) (limited to 'src/cpu/operators/internal') diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index 558ff41a5c..c969c9f4f6 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -164,8 +164,8 @@ public: { if(!_gemm_kernel_asm) return false; - const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format; - return wf != arm_gemm::WeightFormat::UNSPECIFIED && wf != arm_gemm::WeightFormat::ANY; + const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format); + return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY; } private: @@ -428,7 +428,7 @@ void Fallback::prepare(ITensorPack &tensors) if(_gemm_kernel_asm->B_pretranspose_required()) { // Fixed format kernels need no pretranspose. - ARM_COMPUTE_ERROR_ON(arm_gemm::is_fixed_format(_gemm_kernel_asm->get_config().weight_format)); + ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format))); const int ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput); const auto in1_ptr = reinterpret_cast(b->buffer() + b->info()->offset_first_element_in_bytes()); const int multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput); @@ -492,8 +492,8 @@ void Fallback::run(ITensorPack &tensors) // Check if B is pre-tranposed and de-reference if not if(!_gemm_kernel_asm->B_is_pretransposed()) { - ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput); - const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format; + ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput); + const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format); if(is_fixed_format(wf)) { // The 4D tensor of dimension O'HWI' created for the @@ -507,7 +507,7 @@ void Fallback::run(ITensorPack &tensors) const int H = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)]; const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)]; const int Ip = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; - const int interleave_by = arm_gemm::interleave_by(wf); + const int interleave_by = arm_compute::interleave_by(wf); ldb = (interleave_by * H * W * Ip); } multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput); @@ -603,7 +603,7 @@ void create_arm_gemm(std::unique_ptr &arm_ge unsigned int num_threads = NEScheduler::get().num_threads(); arm_gemm::GemmConfig cfg; - cfg.weight_format = info.weight_format; + cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg); // Create arm_gemm fallback @@ -623,7 +623,7 @@ void create_arm_gemm_quant(std::unique_ptr & const unsigned int num_threads = NEScheduler::get().num_threads(); arm_gemm::GemmConfig cfg; - cfg.weight_format = info.weight_format; + cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg); // Create arm_gemm fallback @@ -665,7 +665,7 @@ CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch() { } -Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, +Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info) { ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); @@ -675,13 +675,13 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we const CPUInfo &ci = NEScheduler::get().cpu_info(); unsigned int num_threads = NEScheduler::get().num_threads(); arm_gemm::GemmConfig cfg; - cfg.weight_format = info.weight_format; - - arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg); + cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); + arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format); + arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg); switch(a->data_type()) { case DataType::F32: - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(expected_weight_format, args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for F32 input"); break; #ifdef __aarch64__ @@ -689,12 +689,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we case DataType::QASYMM8: if(d->data_type() == DataType::S32) { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(expected_weight_format, args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for U8/QASYMM8 input and S32 output"); } else { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(expected_weight_format, args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for U8 input and U8 output"); } break; @@ -702,12 +702,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we case DataType::QASYMM8_SIGNED: if(d->data_type() == DataType::S32) { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(expected_weight_format, args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output"); } else { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(expected_weight_format, args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for S8 input and S32 output"); } break; @@ -722,7 +722,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(expected_weight_format, args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for BFLOAT16 input and F32 output"); break; #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ @@ -730,6 +730,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Usupported type. Could not find a kernel"); break; } + expected_weight_format = assembly_utils::map_to_arm_compute_weight_format(arm_gemm_expected_wf); return Status{}; } @@ -762,9 +763,9 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input"); - arm_gemm::WeightFormat expected_weight_format; - const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info); - if((bool)ret && expected_weight_format != arm_gemm::WeightFormat::ANY) + arm_compute::WeightFormat expected_weight_format; + const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info); + if((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY) { // Correctness check: if the format expected by the kernel is // not "any", make sure that the one found matches the format diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h index 4ef108d430..691eeff8d2 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h @@ -41,19 +41,19 @@ enum class AsmConvMethod struct AsmGemmInfo { - AsmConvMethod method{ AsmConvMethod::Im2Col }; - PadStrideInfo ps_info{}; - ActivationLayerInfo activation_info{}; - GEMMLowpOutputStageInfo output_stage{}; - bool negated_offsets{ true }; - bool reinterpret_input_as_3d{ false }; - bool depth_output_gemm3d{ false }; - int64_t padding_top{ 0 }; - int64_t padding_left{ 0 }; - float padding_value{ 0.f }; - bool fast_mode{ false }; - bool fixed_format{ false }; - arm_gemm::WeightFormat weight_format{ arm_gemm::WeightFormat::UNSPECIFIED }; + AsmConvMethod method{ AsmConvMethod::Im2Col }; + PadStrideInfo ps_info{}; + ActivationLayerInfo activation_info{}; + GEMMLowpOutputStageInfo output_stage{}; + bool negated_offsets{ true }; + bool reinterpret_input_as_3d{ false }; + bool depth_output_gemm3d{ false }; + int64_t padding_top{ 0 }; + int64_t padding_left{ 0 }; + float padding_value{ 0.f }; + bool fast_mode{ false }; + bool fixed_format{ false }; + arm_compute::WeightFormat weight_format{ arm_compute::WeightFormat::UNSPECIFIED }; }; /** Assembly kernel glue */ @@ -70,12 +70,12 @@ public: class IFallback { public: - virtual void run(ITensorPack &tensors) = 0; - virtual void prepare(ITensorPack &tensors) = 0; - virtual experimental::MemoryRequirements workspace() const = 0; - virtual bool is_configured() const = 0; - virtual bool isVarWeightsKernel() const = 0; - virtual ~IFallback() = default; + virtual void run(ITensorPack &tensors) = 0; + virtual void prepare(ITensorPack &tensors) = 0; + virtual experimental::MemoryRequirements workspace() const = 0; + virtual bool is_configured() const = 0; + virtual bool isVarWeightsKernel() const = 0; + virtual ~IFallback() = default; }; public: @@ -105,12 +105,12 @@ public: * * This method has the same use of @ref * NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that - * the value of arm_gemm::WeightFormat need to be passed via the + * the value of arm_compute::WeightFormat need to be passed via the * parameter info. * * @return a status. */ - static Status has_opt_impl(arm_gemm::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info); + static Status has_opt_impl(arm_compute::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info); /** Checks if activation is supported by the gemm assembly dispatcher * * @param[in] activation Activation to check @@ -133,8 +133,8 @@ public: } // Inherited methods overridden: - void prepare(ITensorPack &tensors) override; - void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; + void run(ITensorPack &tensors) override; experimental::MemoryRequirements workspace() const override; private: -- cgit v1.2.1