From 464ed2087c2ce2d2e741cc1e1dc4bd49d06e7d26 Mon Sep 17 00:00:00 2001
From: Jonathan Deakin
Date: Thu, 12 Jan 2023 11:41:14 +0000
Subject: Remove fixed format strides hack

- Remove hack in CpuGemmAssemblyDispatch.cpp which tried to guess
  strides for fixed format kernels. Instead, expect that strides will
  have been correctly set on weights externally
- Update fixed format test fixtures to set the strides
- If the fixed format uses fast math mode, then weights should be of
  type BFLOAT16. Change the validation logic to accept this.

Resolves: [ONCPUML-1131]

Co-authored-by: Milos Puzovic
Change-Id: I0f18d8b86b0f639be25fd122fa06a591e90645f2
Signed-off-by: Jonathan Deakin
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8985
Comments-Addressed: Arm Jenkins
Tested-by: Arm Jenkins
Reviewed-by: Gunes Bayir
Benchmark: Arm Jenkins
---
 .../runtime/NEON/functions/NEFullyConnectedLayer.h |  6 +-
 src/cpu/operators/CpuFullyConnected.cpp            | 26 ++++++--
 src/cpu/operators/CpuFullyConnected.h              |  6 +-
 src/cpu/operators/CpuGemm.cpp                      | 14 ++++-
 src/cpu/operators/CpuGemmConv2d.cpp                | 11 +++-
 .../operators/internal/CpuGemmAssemblyDispatch.cpp | 73 ++++++----------------
 .../NEON/functions/NEFullyConnectedLayer.cpp       |  9 +--
 tests/validation/NEON/ConvolutionLayer.cpp         |  7 +--
 .../validation/fixtures/ConvolutionLayerFixture.h  | 17 ++++-
 9 files changed, 87 insertions(+), 82 deletions(-)

diff --git a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
index 2b4f848b22..6a4de2e311 100644
--- a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -129,12 +129,12 @@ public:
                    FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
     /** Static function to check if given info will lead to a valid configuration of @ref NEFullyConnectedLayer
      *
-     * Similar to @ref NEFullyConnectedLayer
+     * Similar to @ref NEFullyConnectedLayer::configure()
      *
      * @return a status
      */
     static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
-                           FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo());
+                           FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
 
     /** Static function that queries whether fixed-format kernel exists for a given problem description
      *
diff --git a/src/cpu/operators/CpuFullyConnected.cpp b/src/cpu/operators/CpuFullyConnected.cpp
index 3172644488..1e1598a8ee 100644
--- a/src/cpu/operators/CpuFullyConnected.cpp
+++ b/src/cpu/operators/CpuFullyConnected.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -109,7 +109,7 @@ Status get_gemmlowp_output_stage_info(const ITensorInfo *src, const ITensorInfo
     return Status{};
 }
 
-Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ActivationLayerInfo &act, bool enable_fast_math)
+Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ActivationLayerInfo &act, bool enable_fast_math, WeightFormat weight_format)
 {
     if(is_data_type_quantized_asymmetric(src->data_type()))
     {
@@ -137,6 +137,8 @@ Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITe
     else
     {
         GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */);
+        gemm_info.set_weight_format(weight_format);
+        gemm_info.set_fixed_format(weight_format != WeightFormat::UNSPECIFIED);
         gemm_info.set_fast_math(enable_fast_math);
         ARM_COMPUTE_RETURN_ON_ERROR(CpuGemm::validate(src, weights, biases, dst, 1.f, 1.0f, gemm_info));
     }
@@ -240,7 +242,8 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei
                                                   weights,
                                                   biases != nullptr ? biases : nullptr,
                                                   dst,
-                                                  fc_info));
+                                                  fc_info,
+                                                  weights_info));
     ARM_COMPUTE_LOG_PARAMS(src, weights, biases, dst, fc_info);
 
     _needs_weights_conversion = false;
@@ -352,12 +355,23 @@ Status CpuFullyConnected::has_opt_impl(arm_compute::WeightFormat &expected_weigh
 }
 
 Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
-                                   FullyConnectedLayerInfo fc_info)
+                                   FullyConnectedLayerInfo fc_info, const WeightsInfo &weights_info)
 {
     ARM_COMPUTE_UNUSED(fc_info.retain_internal_weights);
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights, dst);
+
+    if (is_fixed_format_fast_math(weights_info.weight_format()))
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(src, DataType::F32);
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(weights, DataType::BFLOAT16);
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32);
+    }
+    else
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights, dst);
+    }
+
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
     ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
                                 && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
@@ -436,7 +450,7 @@ Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *we
         ARM_COMPUTE_RETURN_ERROR_ON(src->dimension(0) != weights_to_use->dimension(1));
     }
 
     // Validate matrix multiply kernel
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(src_to_use, weights_to_use, biases, dst, fc_info.activation_info, fc_info.enable_fast_math));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(src_to_use, weights_to_use, biases, dst, fc_info.activation_info, fc_info.enable_fast_math, weights_info.weight_format()));
 
     return Status{};
 }
diff --git a/src/cpu/operators/CpuFullyConnected.h b/src/cpu/operators/CpuFullyConnected.h
index 36511e9d32..9cd67f2ca6 100644
--- a/src/cpu/operators/CpuFullyConnected.h
+++ b/src/cpu/operators/CpuFullyConnected.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -89,12 +89,12 @@ public:
                    FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
     /** Static function to check if given info will lead to a valid configuration of @ref CpuFullyConnected
      *
-     * Similar to @ref CpuFullyConnected
+     * Similar to @ref CpuFullyConnected::configure()
      *
      * @return a status
      */
     static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
-                           FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo());
+                           FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
 
     /** Static function that queries whether there exists fixed-format kernel and if it exists it will return in the first argument in what format
      * weights are expected to be reshaped as defined by WeightFormat class. Apart from the first argument the rest of the arguments are the same
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index a17e4f31d5..545d59f410 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -158,7 +158,17 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
+
+    if (is_fixed_format_fast_math(gemm_info.weight_format()))
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
+    }
+    else
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
+    }
+
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index f3a16f104f..9bf6ed1e85 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -427,7 +427,12 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights_info.are_reshaped(), "Weights already reshaped are not supported!");
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::BFLOAT16, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::BFLOAT16, DataType::F16, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, weights);
+
+    if (!is_fixed_format(weights_info.weight_format()))
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, weights);
+    }
+
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups > 1, "Grouping (num_groups != 1) is not supported");
 
     const DataLayout data_layout = src->data_layout();
@@ -493,7 +498,7 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight
         unsigned int mat_weights_cols = weights->dimension(idx_kernels);
         unsigned int mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel);
 
-        weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, append_bias), 1, data_type);
+        weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, append_bias), 1, weights->data_type());
         weights_reshaped_info.set_quantization_info(weights->quantization_info());
         weights_to_use = &weights_reshaped_info;
 
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 8ff81afe54..bf3ec5a1ac 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2022 Arm Limited.
+ * Copyright (c) 2018-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -430,9 +430,9 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
     {
         // Fixed format kernels need no pretranspose.
         ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
-        const int  ldb            = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+        const int  ldb            = b->info()->strides_in_bytes().y() / b->info()->element_size();
         const auto in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
-        const int  multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+        const int  multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
 
         CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
         ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
@@ -470,21 +470,21 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
     auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
     auto d = tensors.get_tensor(TensorType::ACL_DST);
 
-    int       lda = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
+    int       lda = a->info()->strides_in_bytes().y() / a->info()->element_size();
     int       ldb = 0;
-    const int ldd = d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
+    const int ldd = d->info()->strides_in_bytes().y() / d->info()->element_size();
 
     const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d != 0 ? 3 : 2;
     const size_t a_multi_idx = a_batch_idx + 1;
     const size_t d_batch_idx = _gemm_info.depth_output_gemm3d != 0 ? 3 : 2;
     const size_t d_multi_idx = d_batch_idx + 1;
 
-    int       batch_stride_a = a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
-    const int batch_stride_d = d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput);
+    int       batch_stride_a = a->info()->strides_in_bytes()[a_batch_idx] / a->info()->element_size();
+    const int batch_stride_d = d->info()->strides_in_bytes()[d_batch_idx] / d->info()->element_size();
 
-    int       multi_stride_a = a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
+    int       multi_stride_a = a->info()->strides_in_bytes()[a_multi_idx] / a->info()->element_size();
     int       multi_stride_b = 0;
-    const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput);
+    const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / d->info()->element_size();
 
     auto in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
     const TypeInput *in1_ptr = nullptr;
@@ -493,50 +493,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
     // Check if B is pre-tranposed and de-reference if not
     if(!_gemm_kernel_asm->B_is_pretransposed())
     {
-        ldb            = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
-        multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
-        const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
-        if(is_fixed_format(wf))
-        {
-            // The 4D tensor of dimension O'HWI' created for the
-            // OHWIo<interleave_by>i<block_by> format is in reality seen
-            // as a 2D tensor at arm_gemm level, where the rows are
-            // O'/<interleave_by> and the columns are <interleave_by> *
-            // H * W * I'.
-            ITensorInfo *tensor_info       = b->info();
-            const DataLayout data_layout   = tensor_info->data_layout();
-            const TensorShape tensor_shape = tensor_info->tensor_shape();
-            const int tensor_height        = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
-            const int tensor_width         = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
-            int tensor_channels            = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
-            const int interleave_by        = arm_compute::interleave_by(wf);
-            const int blocked_by           = arm_compute::block_by(wf);
-            // We need to find a new stride that is distance from the data for one
-            // set of output channels to the next
-            if(ldb == tensor_channels && multi_stride_b == tensor_channels * tensor_width)
-            {
-                // In this case dimensions that are packed are height, width and channel
-                // so we need to stride it by interleave_by
-                if(tensor_channels % blocked_by != 0)
-                {
-                    // We need to pad
-                    tensor_channels = arm_gemm::iceildiv(tensor_channels, blocked_by) * blocked_by;
-                }
-                ldb = interleave_by * tensor_height * tensor_width * tensor_channels;
-            }
-            else if(multi_stride_b == 0 || (ldb == tensor_width && multi_stride_b == tensor_height * tensor_width))
-            {
-                // In this case dimension that is packed is only height
-                // so we need to stride only height by interleave_by
-                ldb = interleave_by * tensor_height;
-            }
-            else
-            {
-                // If dimensions are not packed as above error is thrown
-                // as at the moment other forms of packing are not supported
-                ARM_COMPUTE_ERROR("Unsupported packing for fixed format kernel");
-            }
-        }
+        ldb            = b->info()->strides_in_bytes().y() / b->info()->element_size();
+        multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
         in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
     }
@@ -551,9 +509,9 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
     // Pretranspose B if required
     if(_B_pretranspose_required)
     {
-        const int  ldb            = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+        const int  ldb            = b->info()->strides_in_bytes().y() / b->info()->element_size();
         const auto b_ptr          = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
-        const int  multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+        const int  multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
 
         CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
         ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
@@ -780,6 +738,11 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
     {
         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QASYMM8_SIGNED, DataType::S8);
     }
+    else if(is_fixed_format_fast_math(info.weight_format))
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
+    }
     else
     {
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
index 4f858fb54b..919e5ed84f 100644
--- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
+++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -69,7 +69,8 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh
                                                weights->info(),
                                                biases != nullptr ? biases->info() : nullptr,
                                                output->info(),
-                                               fc_info));
+                                               fc_info,
+                                               weights_info));
     ARM_COMPUTE_LOG_PARAMS(input, weights, biases, output, fc_info);
 
     _impl->op = std::make_unique<cpu::CpuFullyConnected>();
@@ -96,9 +97,9 @@ Status NEFullyConnectedLayer::has_opt_impl(arm_compute::WeightFormat &expected_w
 }
 
 Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
-                                       FullyConnectedLayerInfo fc_info)
+                                       FullyConnectedLayerInfo fc_info, const WeightsInfo &weights_info)
 {
-    return cpu::CpuFullyConnected::validate(input, weights, biases, output, fc_info);
+    return cpu::CpuFullyConnected::validate(input, weights, biases, output, fc_info, weights_info);
 }
 
 void NEFullyConnectedLayer::run()
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index f03ac45bfc..08b6a02375 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -756,9 +756,8 @@ DATA_TEST_CASE(PrepareWeightShape, framework::DatasetMode::ALL,
     const DataType   DT            = DataType::F32;
     const DataLayout DL            = DataLayout::NHWC;
     const auto       TI            = TensorInfo(input_shape, 1 /*num_channels, deprecated*/, DT, DL);
-    const TensorInfo computed      = ::arm_compute::test::validation::prepare_weights(TI, wf);
-    const TensorInfo expected      = TensorInfo(expected_shape, 1 /*num_channels, deprecated*/, DT, DL);
-    ARM_COMPUTE_EXPECT_EQUAL(computed, expected, framework::LogLevel::ERRORS);
+    const TensorInfo computed_info = ::arm_compute::test::validation::prepare_weights(TI, wf);
+    ARM_COMPUTE_EXPECT_EQUAL(computed_info.tensor_shape(), expected_shape, framework::LogLevel::ERRORS);
 }
 
 TEST_SUITE_END() // VariableWeightUtils
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index 63e6dc9377..5b8963ebfe 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -416,8 +416,21 @@ inline TensorInfo prepare_weights(const TensorInfo tensor_info, const arm_comput
     const int Ip = arm_gemm::roundup(C, block_by);      // C'=I'
     const int Op = arm_gemm::roundup(N, interleave_by); // O'=N'
 
+    arm_compute::Strides strides_in_bytes = tensor_info.strides_in_bytes();
+    strides_in_bytes.set(1, Ip * interleave_by * H * W * tensor_info.element_size());
+    strides_in_bytes.set(2, Ip * Op * tensor_info.element_size());
+
+    const size_t offset_first_element_in_bytes = tensor_info.offset_first_element_in_bytes();
+
+    // Total size needs to include padded dimensions
+    const size_t total_size_in_bytes = Op * H * W * Ip * tensor_info.element_size();
+
     const TensorShape TS(Ip, W, H, Op);
-    return TensorInfo(TS, 1 /*num_channels*/, data_type, data_layout);
+
+    TensorInfo new_tensor_info = tensor_info;
+    new_tensor_info.init(TS, 1 /*num_channels, deprecated*/, data_type, strides_in_bytes,
+                         offset_first_element_in_bytes, total_size_in_bytes);
+    return new_tensor_info;
 }
 
 template
-- 
cgit v1.2.1
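
Note on the stride handling: with the guessing logic removed from
CpuGemmAssemblyDispatch.cpp, ldb and multi_stride_b are now obtained by
dividing the weight tensor's byte strides by its element size, and setting
those strides correctly becomes the caller's job (which is what the updated
prepare_weights() fixture does). Dividing by element_size() rather than
sizeof(TypeInput) matters for the fast-math case, where the tensor holds
BFLOAT16 data whose size can differ from the kernel's template input type.
Below is a minimal standalone sketch of the stride arithmetic; it is not
ACL code, and the round_up helper and example sizes are illustrative only:

    #include <cstddef>
    #include <cstdio>

    // Round x up to the next multiple of m (stand-in for arm_gemm::roundup).
    static std::size_t round_up(std::size_t x, std::size_t m)
    {
        return ((x + m - 1) / m) * m;
    }

    int main()
    {
        // Hypothetical OHWIo<interleave_by>i<block_by> weights:
        // N output channels, H x W kernel, C input channels, bfloat16 (2 bytes).
        const std::size_t N = 10, H = 3, W = 3, C = 7;
        const std::size_t interleave_by = 8, block_by = 4, element_size = 2;

        const std::size_t Ip = round_up(C, block_by);      // C' = I'
        const std::size_t Op = round_up(N, interleave_by); // O' = N'

        // Byte strides as the fixture sets them: one interleaved group of
        // output channels spans interleave_by * H * W * Ip elements.
        const std::size_t stride_y = Ip * interleave_by * H * W * element_size;
        const std::size_t stride_z = Ip * Op * element_size;
        const std::size_t total    = Op * H * W * Ip * element_size;

        // What the dispatch now derives from those strides; this reproduces
        // the ldb the removed hack used to reconstruct by guessing.
        const std::size_t ldb            = stride_y / element_size;
        const std::size_t multi_stride_b = stride_z / element_size;

        std::printf("ldb=%zu multi_stride_b=%zu total_bytes=%zu\n",
                    ldb, multi_stride_b, total);
        return 0;
    }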
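Note on the type validation: the rule introduced in CpuGemm::validate(),
CpuFullyConnected::validate() and CpuGemmAssemblyDispatch::validate() is that
a fixed-format fast-math kernel takes F32 inputs and outputs with BFLOAT16
weights, while every other path still requires matching data types. A
compilable sketch of just that decision, using stand-in enums rather than
ACL's types (the OHWIo8i4_bf16 enumerator here merely plays the role of a
fast-math fixed format):

    #include <cassert>

    enum class DataType { F32, BFLOAT16 };
    enum class WeightFormat { UNSPECIFIED, OHWIo8, OHWIo8i4_bf16 };

    // Stand-in for arm_compute::is_fixed_format_fast_math().
    static bool is_fixed_format_fast_math(WeightFormat wf)
    {
        return wf == WeightFormat::OHWIo8i4_bf16;
    }

    static bool types_valid(DataType src, DataType weights, WeightFormat wf)
    {
        if(is_fixed_format_fast_math(wf))
        {
            // Fast math: F32 activations, weights already reordered to BFLOAT16.
            return src == DataType::F32 && weights == DataType::BFLOAT16;
        }
        // Otherwise the pre-existing rule applies: data types must match.
        return src == weights;
    }

    int main()
    {
        assert(types_valid(DataType::F32, DataType::BFLOAT16, WeightFormat::OHWIo8i4_bf16));
        assert(!types_valid(DataType::F32, DataType::BFLOAT16, WeightFormat::OHWIo8));
        assert(types_valid(DataType::F32, DataType::F32, WeightFormat::UNSPECIFIED));
        return 0;
    }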