author     Milos Puzovic <milos.puzovic@arm.com>   2022-07-27 17:53:21 +0000
committer  Gunes Bayir <gunes.bayir@arm.com>       2022-08-03 16:47:58 +0000
commit     13b623e575ed2f1096c70560a2db4a9e03cf22f9 (patch)
tree       a517d94c55cfc803c7e13dc89090bf2b3be4dc41 /src/cpu
parent     3c4d085da54c3d9727cb31718c5b407c18ff646a (diff)
[ONCPUML-968] Fixed format kernel support in additional APIs
Implements the plumbing required to query and execute fixed format kernels from NEFullyConnected, NEGEMM and NEGEMMConv2d. These APIs are used to accelerate oneDNN primitives (inner product, matrix multiplication and indirect GEMM respectively); without these changes it would not be possible to call fixed format kernels from those oneDNN primitives.

Change-Id: I27534f0491ce28d0ccb98c19f318bd33dcdf2ff5
Signed-off-by: Milos Puzovic <milos.puzovic@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7999
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu')
-rw-r--r--   src/cpu/operators/CpuFullyConnected.cpp                  43
-rw-r--r--   src/cpu/operators/CpuFullyConnected.h                    55
-rw-r--r--   src/cpu/operators/CpuGemmDirectConv2d.cpp                23
-rw-r--r--   src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp   55
4 files changed, 115 insertions, 61 deletions
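
Before the per-file diffs, a caller-side sketch may help. The snippet below is not part of the patch: try_fixed_format_fc is a hypothetical helper, and the WeightsInfo constructor carrying a WeightFormat is assumed from ACL headers of this period. It shows how an integration such as oneDNN might use the new plumbing: query the expected weight format first, reorder the weights accordingly, then configure with that format.

#include "arm_compute/core/Types.h"
#include "src/cpu/operators/CpuFullyConnected.h"

using namespace arm_compute;

Status try_fixed_format_fc(const ITensorInfo *src, const ITensorInfo *weights,
                           const ITensorInfo *biases, ITensorInfo *dst)
{
    FullyConnectedLayerInfo fc_info{};
    // Ask for *any* fixed format; the runtime reports the one it expects
    WeightsInfo query_info(false, 1U, 1U, weights->dimension(1), false, WeightFormat::ANY);

    WeightFormat expected_wf = WeightFormat::UNSPECIFIED;
    const Status status      = cpu::CpuFullyConnected::has_opt_impl(expected_wf, src, weights,
                                                                    biases, dst, fc_info, query_info);
    if(status.error_code() != ErrorCode::OK || expected_wf == WeightFormat::UNSPECIFIED)
    {
        return status; // no fixed-format kernel: fall back to the default path
    }

    // The caller reorders its weights into expected_wf, then configures the
    // operator with that format so ACL performs no further reshaping
    WeightsInfo            fixed_info(false, 1U, 1U, weights->dimension(1), false, expected_wf);
    cpu::CpuFullyConnected fc;
    fc.configure(src, weights, biases, dst, fc_info, fixed_info);
    return Status{};
}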
diff --git a/src/cpu/operators/CpuFullyConnected.cpp b/src/cpu/operators/CpuFullyConnected.cpp
index 6d77c614f7..3172644488 100644
--- a/src/cpu/operators/CpuFullyConnected.cpp
+++ b/src/cpu/operators/CpuFullyConnected.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -53,7 +53,7 @@ std::pair<PixelValue, PixelValue> get_quantized_asymmetric_output_min_max(const
{
PixelValue type_min{};
PixelValue type_max{};
- std::tie(type_min, type_max) = get_min_max(data_type);
+ std::tie(type_min, type_max) = get_min_max(data_type);
const UniformQuantizationInfo q_unif = q_info.uniform();
if(act_info.enabled())
@@ -162,8 +162,9 @@ CpuFullyConnected::CpuFullyConnected()
_is_fc_after_conv(false),
_is_quantized_asymmetric(false),
_is_prepared(false),
- _enable_fast_math(false)
-
+ _enable_fast_math(false),
+ _fixed_format(false),
+ _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
{
}
@@ -199,6 +200,8 @@ void CpuFullyConnected::configure_mm(const ITensorInfo *src, const ITensorInfo *
GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */);
gemm_info.set_activation_info(act);
gemm_info.set_fast_math(_enable_fast_math);
+ gemm_info.set_fixed_format(_fixed_format);
+ gemm_info.set_weight_format(_weight_format);
_mm_gemm = std::make_unique<CpuGemm>();
_mm_gemm->configure(src, weights, biases, dst, 1.f, 1.0f, gemm_info);
}
@@ -229,7 +232,7 @@ void CpuFullyConnected::configure_fc_fc(const ITensorInfo *src, const ITensorInf
}
void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst,
- FullyConnectedLayerInfo fc_info)
+ FullyConnectedLayerInfo fc_info, const WeightsInfo &weights_info)
{
// Perform validate step
ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst);
@@ -248,6 +251,8 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei
_is_prepared = false;
_trans_weights_idx = AuxTensorIdx::Count;
_enable_fast_math = fc_info.enable_fast_math;
+ _fixed_format = weights_info.weight_format() != WeightFormat::UNSPECIFIED;
+ _weight_format = weights_info.weight_format();
// With the Fully Connected layer we can have 4 different cases:
// 1) Convolution layer -> Fully Connected layer without batches
@@ -261,9 +266,7 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei
const bool is_batched_fc_layer = dst->dimension(1) > 1;
if(is_batched_fc_layer)
{
- _is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3,
- src->tensor_shape().cend(),
- dst->tensor_shape().cbegin() + 1));
+ _is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3, src->tensor_shape().cend(), dst->tensor_shape().cbegin() + 1));
}
else
{
@@ -323,12 +326,10 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei
{
// Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch
// Do not release them if biases are dynamic and data type is quantized, since the weights tensor will be used for biases offset calculation
- _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), (_is_quantized_asymmetric
- && biases && !(biases->are_values_constant())) ?
- MemoryLifetime::Persistent :
- MemoryLifetime::Prepare,
+ _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), (_is_quantized_asymmetric && biases
+ && !(biases->are_values_constant())) ? MemoryLifetime::Persistent : MemoryLifetime::Prepare,
_reshaped_weights.total_size());
- _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size());
+ _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size());
}
else
{
@@ -338,6 +339,18 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei
_aux_mem[FlattenedSrc] = MemoryInfo(offset_int_vec(FlattenedSrc), MemoryLifetime::Temporary, _flattened_src.total_size());
}
+Status CpuFullyConnected::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights,
+ const ITensorInfo *biases, const ITensorInfo *dst, FullyConnectedLayerInfo fc_info, WeightsInfo weights_info)
+{
+ GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */);
+ gemm_info.set_activation_info(fc_info.activation_info);
+ gemm_info.set_fast_math(fc_info.enable_fast_math);
+ gemm_info.set_fixed_format(weights_info.weight_format() != WeightFormat::UNSPECIFIED);
+ gemm_info.set_weight_format(weights_info.weight_format());
+
+ return CpuGemm::has_opt_impl(expected_weight_format, src, weights, biases, dst, gemm_info);
+}
+
Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
FullyConnectedLayerInfo fc_info)
{
@@ -384,9 +397,7 @@ Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *we
if(is_batched_fc_layer)
{
- is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3,
- src->tensor_shape().cend(),
- dst->tensor_shape().cbegin() + 1));
+ is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3, src->tensor_shape().cend(), dst->tensor_shape().cbegin() + 1));
}
else
{
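
Read together, the hunks above funnel the caller's weight format into GEMMInfo in two places, configure_mm() and the new has_opt_impl(). A condensed sketch of that shared plumbing, using the setters visible in the patch (make_gemm_info is a hypothetical helper):

GEMMInfo make_gemm_info(const FullyConnectedLayerInfo &fc_info, const WeightsInfo &weights_info)
{
    GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */);
    gemm_info.set_activation_info(fc_info.activation_info);
    gemm_info.set_fast_math(fc_info.enable_fast_math);
    // A concrete weight format is what routes CpuGemm onto the fixed-format path
    gemm_info.set_fixed_format(weights_info.weight_format() != WeightFormat::UNSPECIFIED);
    gemm_info.set_weight_format(weights_info.weight_format());
    return gemm_info;
}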
diff --git a/src/cpu/operators/CpuFullyConnected.h b/src/cpu/operators/CpuFullyConnected.h
index 44fa21f9f8..36511e9d32 100644
--- a/src/cpu/operators/CpuFullyConnected.h
+++ b/src/cpu/operators/CpuFullyConnected.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -72,20 +72,21 @@ public:
* |QASYMM8 |QASYMM8 |S32 |QASYMM8 |
* |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED |
*
- * @param[in] src Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
- * @param[in] weights Weights tensor info. The weights must be 2 dimensional.
- * If this function is called after a Convolution Layer, the (transposed) weights will have as many rows as the product of the first 3 input's dimensions.
- * If it is called after another FullyConnected Layer, the (transposed) weights will have as many rows as the input's first dimension.
- * Data type supported: Same as @p src.
- * @param[in] biases Bias tensor info. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED.
- * @param[out] dst Destination tensor info. Its shape should be equal to the output of a matrix multiplication between:
- * - The output of im2col on the input and the (transposed) 2D weights, if the function is called after a Convolution Layer
- * - The input tensor and the (transposed) 2D weights, if the function is called after another FullyConnected Layer.
- * Data type supported: Same as @p src.
- * @param[in] fc_info (Optional) Fully connected layer additional info
+ * @param[in] src Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
+ * @param[in] weights Weights tensor info. The weights must be 2 dimensional.
+ * If this function is called after a Convolution Layer, the (transposed) weights will have as many rows as the product of the input's first 3 dimensions.
+ * If it is called after another FullyConnected Layer, the (transposed) weights will have as many rows as the input's first dimension.
+ * Data type supported: Same as @p src.
+ * @param[in] biases Bias tensor info. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED.
+ * @param[out] dst Destination tensor info. Its shape should be equal to the output of a matrix multiplication between:
+ * - The output of im2col on the input and the (transposed) 2D weights, if the function is called after a Convolution Layer
+ * - The input tensor and the (transposed) 2D weights, if the function is called after another FullyConnected Layer.
+ * Data type supported: Same as @p src.
+ * @param[in] fc_info (Optional) Fully connected layer additional info
+ * @param[in] weights_info (Optional) Stores necessary compute information when weights are already reshaped
*/
void configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst,
- FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo());
+ FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
/** Static function to check if given info will lead to a valid configuration of @ref CpuFullyConnected
*
* Similar to @ref CpuFullyConnected
@@ -95,9 +96,19 @@ public:
static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo());
+ /** Static function that queries whether a fixed-format kernel exists and, if so, returns in the first argument the format in which
+ * the weights are expected to be reshaped, as defined by the WeightFormat class. Apart from the first argument, the remaining
+ * arguments are the same as in @ref CpuFullyConnected::validate(), except that all arguments are required.
+ *
+ * @return a status
+ */
+ static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights,
+ const ITensorInfo *biases, const ITensorInfo *dst,
+ FullyConnectedLayerInfo fc_info, WeightsInfo weights_info);
+
//Inherited methods override
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
experimental::MemoryRequirements workspace() const override;
private:
@@ -136,12 +147,14 @@ private:
experimental::MemoryRequirements _aux_mem;
- bool _needs_weights_conversion;
- bool _needs_weights_reshape;
- bool _is_fc_after_conv;
- bool _is_quantized_asymmetric;
- bool _is_prepared;
- bool _enable_fast_math;
+ bool _needs_weights_conversion;
+ bool _needs_weights_reshape;
+ bool _is_fc_after_conv;
+ bool _is_quantized_asymmetric;
+ bool _is_prepared;
+ bool _enable_fast_math;
+ bool _fixed_format;
+ arm_compute::WeightFormat _weight_format;
};
} // namespace cpu
} // namespace arm_compute
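
Note that weights_info defaults to WeightsInfo(), whose weight format is WeightFormat::UNSPECIFIED, so existing call sites compile and behave exactly as before; only callers passing a concrete format opt into fixed-format kernels. A minimal sketch (tensor infos assumed configured elsewhere):

cpu::CpuFullyConnected fc_legacy, fc_fixed;
fc_legacy.configure(src, weights, biases, dst, fc_info);              // unchanged behaviour
fc_fixed.configure(src, weights, biases, dst, fc_info, weights_info); // opts into fixed-format kernels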
diff --git a/src/cpu/operators/CpuGemmDirectConv2d.cpp b/src/cpu/operators/CpuGemmDirectConv2d.cpp
index fd1a042bb4..ee47a17d64 100644
--- a/src/cpu/operators/CpuGemmDirectConv2d.cpp
+++ b/src/cpu/operators/CpuGemmDirectConv2d.cpp
@@ -57,11 +57,11 @@ GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src,
ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
};
- PixelValue type_min{};
- PixelValue type_max{};
+ PixelValue type_min{};
+ PixelValue type_max{};
std::tie(type_min, type_max) = get_min_max(data_type);
- int32_t min_activation = type_min.get<int32_t>();
- int32_t max_activation = type_max.get<int32_t>();
+ int32_t min_activation = type_min.get<int32_t>();
+ int32_t max_activation = type_max.get<int32_t>();
if(supported_acts.count(act.activation()) != 0)
{
std::tie(min_activation, max_activation) = get_quantized_activation_min_max(act, data_type, uoqinfo);
@@ -88,6 +88,8 @@ cpu::AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect
asm_info.padding_value = 0.f;
asm_info.negated_offsets = false;
asm_info.fast_mode = info.enable_fast_math;
+ asm_info.fixed_format = info.weights_info.weight_format() != WeightFormat::UNSPECIFIED;
+ asm_info.weight_format = info.weights_info.weight_format();
return asm_info;
}
} // namespace
@@ -146,7 +148,9 @@ void CpuGemmDirectConv2d::configure(const ITensorInfo *src, const ITensorInfo *w
}
else
{
- _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size());
+ // We must permute the weights if their format is WeightFormat::UNSPECIFIED
+ if(info.weights_info.weight_format() == WeightFormat::UNSPECIFIED)
+ _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size());
}
}
Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &info)
@@ -203,6 +207,13 @@ void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
{
if(!_is_prepared)
{
+ // If we are using a fixed-format kernel, the weights are already reshaped
+ if(_gemm_asm_func && _gemm_asm_func->isVarWeightsKernel())
+ {
+ _gemm_asm_func->prepare(tensors);
+ _is_prepared = true;
+ return;
+ }
const ITensor *weights = tensors.get_const_tensor(ACL_SRC_1);
ITensor *weights_aux = utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(PermutedWeights)));
ARM_COMPUTE_ERROR_ON_NULLPTR(weights, weights_aux);
@@ -224,4 +235,4 @@ experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const
return _aux_mem;
}
} // namespace cpu
-} // namespace arm_compute \ No newline at end of file
+} // namespace arm_compute
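
Taken together, the two hunks above mean a fixed-format run never materialises the PermutedWeights buffer: configure() skips reserving it when the weight format is not UNSPECIFIED, and prepare() returns before the permute. A condensed sketch of the resulting control flow (member names as in this file):

void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
{
    if(!_is_prepared)
    {
        // Fixed-format ("variable weights") kernels consume the caller's weights
        // directly, so the permute step and its auxiliary buffer are skipped
        if(_gemm_asm_func && _gemm_asm_func->isVarWeightsKernel())
        {
            _gemm_asm_func->prepare(tensors);
            _is_prepared = true;
            return;
        }
        // ... otherwise permute the weights into PermutedWeights as before ...
    }
}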
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 45b3232423..df02d649f8 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -156,8 +156,8 @@ public:
const std::vector<int32_t> &multipliers);
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
bool is_configured() const override;
experimental::MemoryRequirements workspace() const override;
bool isVarWeightsKernel() const override
@@ -210,12 +210,12 @@ private:
/** Indirect buffer */
std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{};
- std::vector<TypeInput> _indirect_pad{};
- arm_gemm::ConvolutionParameters _cp{};
- experimental::MemoryRequirements _aux_mem{ Count };
- bool _B_pretranspose_required{ false };
- bool _is_b_constant{ true };
- bool _is_c_constant{ true };
+ std::vector<TypeInput> _indirect_pad{};
+ arm_gemm::ConvolutionParameters _cp{};
+ experimental::MemoryRequirements _aux_mem{ Count };
+ bool _B_pretranspose_required{ false };
+ bool _is_b_constant{ true };
+ bool _is_c_constant{ true };
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -493,6 +493,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
if(!_gemm_kernel_asm->B_is_pretransposed())
{
ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
if(is_fixed_format(wf))
{
@@ -501,17 +502,35 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
// as a 2D tensor at arm_gemm level, where the rows are
// O'/<interleave_by> and the columns are <interleave_by> *
// H * W * I'.
- ITensorInfo *tensor_info = b->info();
- const DataLayout data_layout = tensor_info->data_layout();
- const TensorShape tensor_shape = tensor_info->tensor_shape();
- const int H = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
- const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
- const int Ip = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
- const int interleave_by = arm_compute::interleave_by(wf);
- ldb = (interleave_by * H * W * Ip);
+ ITensorInfo *tensor_info = b->info();
+ const DataLayout data_layout = tensor_info->data_layout();
+ const TensorShape tensor_shape = tensor_info->tensor_shape();
+ const int tensor_height = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
+ const int tensor_width = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
+ const int tensor_channels = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
+ const int interleave_by = arm_compute::interleave_by(wf);
+ // We need to find a new stride that is the distance from the data for one
+ // set of output channels to the next
+ if(ldb == tensor_channels && multi_stride_b == tensor_channels * tensor_width)
+ {
+ // In this case the packed dimensions are height, width and channel,
+ // so we need to stride by interleave_by
+ ldb = interleave_by * tensor_height * tensor_width * tensor_channels;
+ }
+ else if(multi_stride_b == 0 || (ldb == tensor_width && multi_stride_b == tensor_height * tensor_width))
+ {
+ // In this case the only packed dimension is height,
+ // so we only need to stride the height by interleave_by
+ ldb = interleave_by * tensor_height;
+ }
+ else
+ {
+ // If the dimensions are not packed as above, an error is thrown,
+ // as other forms of packing are not supported at the moment
+ ARM_COMPUTE_ERROR("Unsupported packing for fixed format kernel");
+ }
}
- multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
- in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
+ in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
}
// If necessary, run pretranspose every time if either weights or biases are non-constant
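
To make the stride selection above concrete, here is a small self-contained worked example (values illustrative; an OHWIo8-style format with interleave_by == 8 is assumed) of the two packings the code accepts:

#include <cassert>
#include <cstdio>

int main()
{
    const int interleave_by   = 8;  // from arm_compute::interleave_by(wf)
    const int tensor_height   = 3;
    const int tensor_width    = 3;
    const int tensor_channels = 64;

    // Case 1: height, width and channel are packed together
    // (ldb == C and multi_stride_b == C * W before the fix)
    int ldb = interleave_by * tensor_height * tensor_width * tensor_channels;
    assert(ldb == 4608); // distance between consecutive blocks of 8 output channels

    // Case 2: only the height dimension is packed
    // (multi_stride_b == 0, or ldb == W with multi_stride_b == H * W)
    int ldb_height_only = interleave_by * tensor_height;
    assert(ldb_height_only == 24);

    std::printf("ldb=%d ldb_height_only=%d\n", ldb, ldb_height_only);
    return 0;
}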