path: root/src/cpu/operators/CpuGemmDirectConv2d.cpp
author     Milos Puzovic <milos.puzovic@arm.com>   2022-07-27 17:53:21 +0000
committer  Gunes Bayir <gunes.bayir@arm.com>       2022-08-03 16:47:58 +0000
commit     13b623e575ed2f1096c70560a2db4a9e03cf22f9 (patch)
tree       a517d94c55cfc803c7e13dc89090bf2b3be4dc41 /src/cpu/operators/CpuGemmDirectConv2d.cpp
parent     3c4d085da54c3d9727cb31718c5b407c18ff646a (diff)
download   ComputeLibrary-13b623e575ed2f1096c70560a2db4a9e03cf22f9.tar.gz
[ONCPUML-968] Fixed format kernel support in additional APIs
Implements the required plumbing to be able to ask for and execute fixed format kernels from NEFullyConnected, NEGEMM and NEGEMMConv2d. These APIs are used to accelerate oneDNN primitives (inner product, matrix multiplication and indirect GEMM respectively), and without these changes it would not be possible to call fixed format kernels from those oneDNN primitives.

Change-Id: I27534f0491ce28d0ccb98c19f318bd33dcdf2ff5
Signed-off-by: Milos Puzovic <milos.puzovic@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7999
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
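As a caller-side illustration of the "ask for fixed format kernels" part: the request is carried by WeightsInfo::weight_format(), and anything other than WeightFormat::UNSPECIFIED means the weights already arrive in the blocked layout the assembly kernel expects, so the runtime permutation in the default path can be skipped. A minimal sketch, using the WeightsInfo and WeightFormat types that appear in the hunks below; the helper name and the include path are illustrative, not part of the patch.

#include "arm_compute/core/Types.h" // WeightsInfo, WeightFormat (header location assumed for this release)

// Illustrative helper, not in the patch: true when the caller has requested a
// fixed-format kernel by setting an explicit weight format on WeightsInfo.
static bool fixed_format_requested(const arm_compute::WeightsInfo &weights_info)
{
    // WeightFormat::UNSPECIFIED is the default; any other value (e.g. one of the
    // blocked OHWI formats) selects a fixed-format kernel and disables the
    // runtime weight permutation performed by the default path.
    return weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
}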
Diffstat (limited to 'src/cpu/operators/CpuGemmDirectConv2d.cpp')
-rw-r--r--  src/cpu/operators/CpuGemmDirectConv2d.cpp  23
1 file changed, 17 insertions(+), 6 deletions(-)
diff --git a/src/cpu/operators/CpuGemmDirectConv2d.cpp b/src/cpu/operators/CpuGemmDirectConv2d.cpp
index fd1a042bb4..ee47a17d64 100644
--- a/src/cpu/operators/CpuGemmDirectConv2d.cpp
+++ b/src/cpu/operators/CpuGemmDirectConv2d.cpp
@@ -57,11 +57,11 @@ GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src,
ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
};
- PixelValue type_min{};
- PixelValue type_max{};
+ PixelValue type_min{};
+ PixelValue type_max{};
std::tie(type_min, type_max) = get_min_max(data_type);
- int32_t min_activation = type_min.get<int32_t>();
- int32_t max_activation = type_max.get<int32_t>();
+ int32_t min_activation = type_min.get<int32_t>();
+ int32_t max_activation = type_max.get<int32_t>();
if(supported_acts.count(act.activation()) != 0)
{
std::tie(min_activation, max_activation) = get_quantized_activation_min_max(act, data_type, uoqinfo);
@@ -88,6 +88,8 @@ cpu::AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect
asm_info.padding_value = 0.f;
asm_info.negated_offsets = false;
asm_info.fast_mode = info.enable_fast_math;
+ asm_info.fixed_format = info.weights_info.weight_format() != WeightFormat::UNSPECIFIED;
+ asm_info.weight_format = info.weights_info.weight_format();
return asm_info;
}
} // namespace
@@ -146,7 +148,9 @@ void CpuGemmDirectConv2d::configure(const ITensorInfo *src, const ITensorInfo *w
}
else
{
- _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size());
+ // We must permute weights if they are WeightFormat::UNSPECIFIED
+ if(info.weights_info.weight_format() == WeightFormat::UNSPECIFIED)
+ _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size());
}
}
Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &info)
@@ -203,6 +207,13 @@ void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
{
if(!_is_prepared)
{
+ // If we are using fixed-format kernel the weights are already reshaped
+ if(_gemm_asm_func && _gemm_asm_func->isVarWeightsKernel())
+ {
+ _gemm_asm_func->prepare(tensors);
+ _is_prepared = true;
+ return;
+ }
const ITensor *weights = tensors.get_const_tensor(ACL_SRC_1);
ITensor *weights_aux = utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(PermutedWeights)));
ARM_COMPUTE_ERROR_ON_NULLPTR(weights, weights_aux);
@@ -224,4 +235,4 @@ experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const
return _aux_mem;
}
} // namespace cpu
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
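Taken together, the three hunks form a single path: init_assembly_metadata() forwards the requested format into cpu::AsmGemmInfo, configure() stops reserving the PermutedWeights workspace when a format is given, and prepare() defers entirely to the assembly function, which reports such kernels through isVarWeightsKernel() ("variable weights" appears to be the assembly dispatcher's term for fixed-format kernels). A condensed, illustrative view of the propagation step is sketched below; AsmGemmInfoSketch is a stand-in for cpu::AsmGemmInfo (which has more members), and only fields visible in the hunks above are used.

#include "arm_compute/core/Types.h" // WeightsInfo, WeightFormat (header location assumed)

// Stand-in for cpu::AsmGemmInfo, reduced to the two fields the patch touches.
struct AsmGemmInfoSketch
{
    bool                      fixed_format{ false };
    arm_compute::WeightFormat weight_format{ arm_compute::WeightFormat::UNSPECIFIED };
};

// Mirrors the init_assembly_metadata() hunk: both pieces come from the same
// WeightsInfo. The boolean gives later stages a cheap "skip the permutation
// path" check, while weight_format tells the assembly backend which blocked
// layout the caller-provided weights arrive in.
AsmGemmInfoSketch make_asm_info_sketch(const arm_compute::WeightsInfo &weights_info)
{
    AsmGemmInfoSketch asm_info{};
    asm_info.fixed_format  = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
    asm_info.weight_format = weights_info.weight_format();
    return asm_info;
}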