diff options
author | Jonathan Deakin <jonathan.deakin@arm.com> | 2023-01-12 11:41:14 +0000 |
---|---|---|
committer | Jonathan Deakin <jonathan.deakin@arm.com> | 2023-02-01 08:05:35 +0000 |
commit | 464ed2087c2ce2d2e741cc1e1dc4bd49d06e7d26 (patch) | |
tree | da07a18be246742773a729e264080d9a9b314d59 /src/cpu/operators/CpuFullyConnected.cpp | |
parent | 7594f989963724e127c3e28210d60fed590b0524 (diff) | |
download | ComputeLibrary-464ed2087c2ce2d2e741cc1e1dc4bd49d06e7d26.tar.gz |
Remove fixed format strides hack
- Remove hack in CpuGemmAssemblyDispatch.cpp that tried to guess
strides for fixed-format kernels. Instead, expect that strides will
have been correctly set on weights externally.
- Update fixed-format test fixtures to set the strides.
- If the fixed format uses fast math mode, then weights should be of
type BFLOAT16. Change the validation logic to accept this.
Resolves: [ONCPUML-1131]
Co-authored-by: Milos Puzovic <Milos.Puzovic@arm.com>
Change-Id: I0f18d8b86b0f639be25fd122fa06a591e90645f2
Signed-off-by: Jonathan Deakin <jonathan.deakin@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8985
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/operators/CpuFullyConnected.cpp')
-rw-r--r-- | src/cpu/operators/CpuFullyConnected.cpp | 26 |
1 file changed, 20 insertions, 6 deletions
diff --git a/src/cpu/operators/CpuFullyConnected.cpp b/src/cpu/operators/CpuFullyConnected.cpp index 3172644488..1e1598a8ee 100644 --- a/src/cpu/operators/CpuFullyConnected.cpp +++ b/src/cpu/operators/CpuFullyConnected.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022 Arm Limited. + * Copyright (c) 2021-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -109,7 +109,7 @@ Status get_gemmlowp_output_stage_info(const ITensorInfo *src, const ITensorInfo return Status{}; } -Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ActivationLayerInfo &act, bool enable_fast_math) +Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ActivationLayerInfo &act, bool enable_fast_math, WeightFormat weight_format) { if(is_data_type_quantized_asymmetric(src->data_type())) { @@ -137,6 +137,8 @@ Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITe else { GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */); + gemm_info.set_weight_format(weight_format); + gemm_info.set_fixed_format(weight_format != WeightFormat::UNSPECIFIED); gemm_info.set_fast_math(enable_fast_math); ARM_COMPUTE_RETURN_ON_ERROR(CpuGemm::validate(src, weights, biases, dst, 1.f, 1.0f, gemm_info)); } @@ -240,7 +242,8 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei weights, biases != nullptr ? 
biases : nullptr, dst, - fc_info)); + fc_info, + weights_info)); ARM_COMPUTE_LOG_PARAMS(src, weights, biases, dst, fc_info); _needs_weights_conversion = false; @@ -352,12 +355,23 @@ Status CpuFullyConnected::has_opt_impl(arm_compute::WeightFormat &expected_weigh } Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, - FullyConnectedLayerInfo fc_info) + FullyConnectedLayerInfo fc_info, const WeightsInfo &weights_info) { ARM_COMPUTE_UNUSED(fc_info.retain_internal_weights); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights, dst); + + if (is_fixed_format_fast_math(weights_info.weight_format())) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(src, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(weights, DataType::BFLOAT16); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights, dst); + } + ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2); ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU); @@ -436,7 +450,7 @@ Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *we ARM_COMPUTE_RETURN_ERROR_ON(src->dimension(0) != weights_to_use->dimension(1)); } // Validate matrix multiply kernel - ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(src_to_use, weights_to_use, biases, dst, fc_info.activation_info, 
fc_info.enable_fast_math)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(src_to_use, weights_to_use, biases, dst, fc_info.activation_info, fc_info.enable_fast_math, weights_info.weight_format())); return Status{}; } |