aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuGemmDirectConv2d.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuGemmDirectConv2d.cpp')
-rw-r--r--src/cpu/operators/CpuGemmDirectConv2d.cpp23
1 files changed, 17 insertions, 6 deletions
diff --git a/src/cpu/operators/CpuGemmDirectConv2d.cpp b/src/cpu/operators/CpuGemmDirectConv2d.cpp
index fd1a042bb4..ee47a17d64 100644
--- a/src/cpu/operators/CpuGemmDirectConv2d.cpp
+++ b/src/cpu/operators/CpuGemmDirectConv2d.cpp
@@ -57,11 +57,11 @@ GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src,
ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
};
- PixelValue type_min{};
- PixelValue type_max{};
+ PixelValue type_min{};
+ PixelValue type_max{};
std::tie(type_min, type_max) = get_min_max(data_type);
- int32_t min_activation = type_min.get<int32_t>();
- int32_t max_activation = type_max.get<int32_t>();
+ int32_t min_activation = type_min.get<int32_t>();
+ int32_t max_activation = type_max.get<int32_t>();
if(supported_acts.count(act.activation()) != 0)
{
std::tie(min_activation, max_activation) = get_quantized_activation_min_max(act, data_type, uoqinfo);
@@ -88,6 +88,8 @@ cpu::AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect
asm_info.padding_value = 0.f;
asm_info.negated_offsets = false;
asm_info.fast_mode = info.enable_fast_math;
+ asm_info.fixed_format = info.weights_info.weight_format() != WeightFormat::UNSPECIFIED;
+ asm_info.weight_format = info.weights_info.weight_format();
return asm_info;
}
} // namespace
@@ -146,7 +148,9 @@ void CpuGemmDirectConv2d::configure(const ITensorInfo *src, const ITensorInfo *w
}
else
{
- _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size());
+ // We must permute weights if they are WeightFormat::UNSPECIFIED
+ if(info.weights_info.weight_format() == WeightFormat::UNSPECIFIED)
+ _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size());
}
}
Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &info)
@@ -203,6 +207,13 @@ void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
{
if(!_is_prepared)
{
+ // If we are using fixed-format kernel the weights are already reshaped
+ if(_gemm_asm_func && _gemm_asm_func->isVarWeightsKernel())
+ {
+ _gemm_asm_func->prepare(tensors);
+ _is_prepared = true;
+ return;
+ }
const ITensor *weights = tensors.get_const_tensor(ACL_SRC_1);
ITensor *weights_aux = utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(PermutedWeights)));
ARM_COMPUTE_ERROR_ON_NULLPTR(weights, weights_aux);
@@ -224,4 +235,4 @@ experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const
return _aux_mem;
}
} // namespace cpu
-} // namespace arm_compute \ No newline at end of file
+} // namespace arm_compute