diff options
author | Milos Puzovic <milos.puzovic@arm.com> | 2022-07-27 17:53:21 +0000 |
---|---|---|
committer | Gunes Bayir <gunes.bayir@arm.com> | 2022-08-03 16:47:58 +0000 |
commit | 13b623e575ed2f1096c70560a2db4a9e03cf22f9 (patch) | |
tree | a517d94c55cfc803c7e13dc89090bf2b3be4dc41 /src/cpu/operators/CpuGemmDirectConv2d.cpp | |
parent | 3c4d085da54c3d9727cb31718c5b407c18ff646a (diff) | |
download | ComputeLibrary-13b623e575ed2f1096c70560a2db4a9e03cf22f9.tar.gz |
[ONCPUML-968] Fixed format kernel support in additional APIs
Implements required plumbing in order to be able to ask and execute
fixed format kernels from NEFullyConnected, NEGEMM and NEGEMMConv2d.
These APIs are used to accelerate oneDNN primitives (inner product, matrix
multiplication and indirect GEMM respectively) and without changes it
would not be possible to call fixed format kernels from those oneDNN
primitives.
Change-Id: I27534f0491ce28d0ccb98c19f318bd33dcdf2ff5
Signed-off-by: Milos Puzovic <milos.puzovic@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7999
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/operators/CpuGemmDirectConv2d.cpp')
-rw-r--r-- | src/cpu/operators/CpuGemmDirectConv2d.cpp | 23 |
1 files changed, 17 insertions, 6 deletions
diff --git a/src/cpu/operators/CpuGemmDirectConv2d.cpp b/src/cpu/operators/CpuGemmDirectConv2d.cpp index fd1a042bb4..ee47a17d64 100644 --- a/src/cpu/operators/CpuGemmDirectConv2d.cpp +++ b/src/cpu/operators/CpuGemmDirectConv2d.cpp @@ -57,11 +57,11 @@ GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src, ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU }; - PixelValue type_min{}; - PixelValue type_max{}; + PixelValue type_min{}; + PixelValue type_max{}; std::tie(type_min, type_max) = get_min_max(data_type); - int32_t min_activation = type_min.get<int32_t>(); - int32_t max_activation = type_max.get<int32_t>(); + int32_t min_activation = type_min.get<int32_t>(); + int32_t max_activation = type_max.get<int32_t>(); if(supported_acts.count(act.activation()) != 0) { std::tie(min_activation, max_activation) = get_quantized_activation_min_max(act, data_type, uoqinfo); @@ -88,6 +88,8 @@ cpu::AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect asm_info.padding_value = 0.f; asm_info.negated_offsets = false; asm_info.fast_mode = info.enable_fast_math; + asm_info.fixed_format = info.weights_info.weight_format() != WeightFormat::UNSPECIFIED; + asm_info.weight_format = info.weights_info.weight_format(); return asm_info; } } // namespace @@ -146,7 +148,9 @@ void CpuGemmDirectConv2d::configure(const ITensorInfo *src, const ITensorInfo *w } else { - _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size()); + // We must permute weights if they are WeightFormat::UNSPECIFIED + if(info.weights_info.weight_format() == WeightFormat::UNSPECIFIED) + _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size()); } } Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &info) @@ -203,6 +207,13 @@ void CpuGemmDirectConv2d::prepare(ITensorPack &tensors) { if(!_is_prepared) { + // If we are using fixed-format kernel the weights are already reshaped + if(_gemm_asm_func && _gemm_asm_func->isVarWeightsKernel()) + { + _gemm_asm_func->prepare(tensors); + _is_prepared = true; + return; + } const ITensor *weights = tensors.get_const_tensor(ACL_SRC_1); ITensor *weights_aux = utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(PermutedWeights))); ARM_COMPUTE_ERROR_ON_NULLPTR(weights, weights_aux); @@ -224,4 +235,4 @@ experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const return _aux_mem; } } // namespace cpu -} // namespace arm_compute
\ No newline at end of file +} // namespace arm_compute |