author    Jonathan Deakin <jonathan.deakin@arm.com>  2023-01-12 11:41:14 +0000
committer Jonathan Deakin <jonathan.deakin@arm.com>  2023-02-01 08:05:35 +0000
commit    464ed2087c2ce2d2e741cc1e1dc4bd49d06e7d26
tree      da07a18be246742773a729e264080d9a9b314d59
parent    7594f989963724e127c3e28210d60fed590b0524
Remove fixed format strides hack
- Remove the hack in CpuGemmAssemblyDispatch.cpp which tried to guess
  strides for fixed format kernels. Instead, expect that strides will
  have been set correctly on the weights externally.
- Update the fixed format test fixtures to set the strides.
- If the fixed format uses fast math mode, then the weights should be
  of type BFLOAT16. Change the validation logic to accept this.

Resolves: [ONCPUML-1131]

Co-authored-by: Milos Puzovic <Milos.Puzovic@arm.com>
Change-Id: I0f18d8b86b0f639be25fd122fa06a591e90645f2
Signed-off-by: Jonathan Deakin <jonathan.deakin@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8985
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
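What callers must now do instead of relying on the removed hack: compute the fixed format strides themselves. A minimal sketch, lifted from the prepare_weights() fixture change later in this patch (H, W, C, N, weight_format and tensor_info describe the original OHWI weights and are assumptions of the sketch):

    // Strides for weights laid out as OHWIo<interleave_by>i<block_by>.
    const int interleave_by = arm_compute::interleave_by(weight_format);
    const int block_by      = arm_compute::block_by(weight_format);
    const int Ip = arm_gemm::roundup<unsigned int>(C, block_by);      // C' = I' (padded input channels)
    const int Op = arm_gemm::roundup<unsigned int>(N, interleave_by); // O' = N' (padded output channels)

    arm_compute::Strides strides_in_bytes = tensor_info.strides_in_bytes();
    strides_in_bytes.set(1, Ip * interleave_by * H * W * tensor_info.element_size());
    strides_in_bytes.set(2, Ip * Op * tensor_info.element_size());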
 arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h |  6
 src/cpu/operators/CpuFullyConnected.cpp                     | 26
 src/cpu/operators/CpuFullyConnected.h                       |  6
 src/cpu/operators/CpuGemm.cpp                               | 14
 src/cpu/operators/CpuGemmConv2d.cpp                         | 11
 src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp      | 73
 src/runtime/NEON/functions/NEFullyConnectedLayer.cpp        |  9
 tests/validation/NEON/ConvolutionLayer.cpp                  |  7
 tests/validation/fixtures/ConvolutionLayerFixture.h         | 17
 9 files changed, 87 insertions(+), 82 deletions(-)
diff --git a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
index 2b4f848b22..6a4de2e311 100644
--- a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -129,12 +129,12 @@ public:
FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
/** Static function to check if given info will lead to a valid configuration of @ref NEFullyConnectedLayer
*
- * Similar to @ref NEFullyConnectedLayer
+ * Similar to @ref NEFullyConnectedLayer::configure()
*
* @return a status
*/
static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
- FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo());
+ FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
/** Static function that queries whether fixed-format kernel exists for a given problem description
*
diff --git a/src/cpu/operators/CpuFullyConnected.cpp b/src/cpu/operators/CpuFullyConnected.cpp
index 3172644488..1e1598a8ee 100644
--- a/src/cpu/operators/CpuFullyConnected.cpp
+++ b/src/cpu/operators/CpuFullyConnected.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -109,7 +109,7 @@ Status get_gemmlowp_output_stage_info(const ITensorInfo *src, const ITensorInfo
return Status{};
}
-Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ActivationLayerInfo &act, bool enable_fast_math)
+Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ActivationLayerInfo &act, bool enable_fast_math, WeightFormat weight_format)
{
if(is_data_type_quantized_asymmetric(src->data_type()))
{
@@ -137,6 +137,8 @@ Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITe
else
{
GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */);
+ gemm_info.set_weight_format(weight_format);
+ gemm_info.set_fixed_format(weight_format != WeightFormat::UNSPECIFIED);
gemm_info.set_fast_math(enable_fast_math);
ARM_COMPUTE_RETURN_ON_ERROR(CpuGemm::validate(src, weights, biases, dst, 1.f, 1.0f, gemm_info));
}
@@ -240,7 +242,8 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei
weights,
biases != nullptr ? biases : nullptr,
dst,
- fc_info));
+ fc_info,
+ weights_info));
ARM_COMPUTE_LOG_PARAMS(src, weights, biases, dst, fc_info);
_needs_weights_conversion = false;
@@ -352,12 +355,23 @@ Status CpuFullyConnected::has_opt_impl(arm_compute::WeightFormat &expected_weigh
}
Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
- FullyConnectedLayerInfo fc_info)
+ FullyConnectedLayerInfo fc_info, const WeightsInfo &weights_info)
{
ARM_COMPUTE_UNUSED(fc_info.retain_internal_weights);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights, dst);
+
+ if (is_fixed_format_fast_math(weights_info.weight_format()))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(src, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(weights, DataType::BFLOAT16);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32);
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights, dst);
+ }
+
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
&& fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
@@ -436,7 +450,7 @@ Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *we
ARM_COMPUTE_RETURN_ERROR_ON(src->dimension(0) != weights_to_use->dimension(1));
}
// Validate matrix multiply kernel
- ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(src_to_use, weights_to_use, biases, dst, fc_info.activation_info, fc_info.enable_fast_math));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(src_to_use, weights_to_use, biases, dst, fc_info.activation_info, fc_info.enable_fast_math, weights_info.weight_format()));
return Status{};
}
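In short, the new branch in CpuFullyConnected::validate() deliberately accepts mixed types when weights_info carries a fixed-format fast-math weight format (the OHWIo*i*_bf16 family). A hedged sketch of the accepted combination; the shapes are illustrative assumptions:

    TensorInfo src(TensorShape(128U, 1U), 1, DataType::F32);
    TensorInfo weights(TensorShape(128U, 64U), 1, DataType::BFLOAT16); // no longer required to match src
    TensorInfo dst(TensorShape(64U, 1U), 1, DataType::F32);
    // With any other weight format, src, weights and dst must still share one data type.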
diff --git a/src/cpu/operators/CpuFullyConnected.h b/src/cpu/operators/CpuFullyConnected.h
index 36511e9d32..9cd67f2ca6 100644
--- a/src/cpu/operators/CpuFullyConnected.h
+++ b/src/cpu/operators/CpuFullyConnected.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -89,12 +89,12 @@ public:
FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
/** Static function to check if given info will lead to a valid configuration of @ref CpuFullyConnected
*
- * Similar to @ref CpuFullyConnected
+ * Similar to @ref CpuFullyConnected::configure()
*
* @return a status
*/
static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
- FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo());
+ FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
/** Static function that queries whether there exists fixed-format kernel and if it exists it will return in the first argument in what format
* weights are expected to be reshaped as defined by WeightFormat class. Apart from the first argument the rest of the arguments are the same
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index a17e4f31d5..545d59f410 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -158,7 +158,17 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
+
+ if (is_fixed_format_fast_math(gemm_info.weight_format()))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
+ }
+
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
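The same rule reaches CpuGemm through GEMMInfo. A minimal sketch of the configuration that takes the new branch, mirroring the validate_mm() change above (the concrete weight format is illustrative):

    GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */);
    gemm_info.set_weight_format(WeightFormat::OHWIo8i4_bf16); // illustrative fast-math fixed format
    gemm_info.set_fixed_format(true);
    gemm_info.set_fast_math(true);
    // With this info, CpuGemm::validate() now expects a in F32 and b in BFLOAT16:
    ARM_COMPUTE_RETURN_ON_ERROR(CpuGemm::validate(a, b, nullptr, d, 1.f, 1.0f, gemm_info));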
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index f3a16f104f..9bf6ed1e85 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -427,7 +427,12 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight
ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights_info.are_reshaped(), "Weights already reshaped are not supported!");
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::BFLOAT16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::BFLOAT16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, weights);
+
+ if (!is_fixed_format(weights_info.weight_format()))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, weights);
+ }
+
ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups > 1, "Grouping (num_groups != 1) is not supported");
const DataLayout data_layout = src->data_layout();
@@ -493,7 +498,7 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight
unsigned int mat_weights_cols = weights->dimension(idx_kernels);
unsigned int mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel);
- weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, append_bias), 1, data_type);
+ weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, append_bias), 1, weights->data_type());
weights_reshaped_info.set_quantization_info(weights->quantization_info());
weights_to_use = &weights_reshaped_info;
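The second hunk is part of the same fix: with fixed-format fast math, src is F32 while the weights are BFLOAT16, so the reshaped-weights info must be typed from the weights tensor itself rather than from the src-derived data_type:

    // weights->data_type() keeps the reshaped info BFLOAT16 when fast math is in use;
    // data_type (derived from src) would mistype it as F32.
    weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, append_bias),
                                       1, weights->data_type());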
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 8ff81afe54..bf3ec5a1ac 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2022 Arm Limited.
+ * Copyright (c) 2018-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -430,9 +430,9 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
{
// Fixed format kernels need no pretranspose.
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
- const int ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ const int ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
const auto in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
- const int multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+ const int multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
@@ -470,21 +470,21 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
auto d = tensors.get_tensor(TensorType::ACL_DST);
- int lda = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ int lda = a->info()->strides_in_bytes().y() / a->info()->element_size();
int ldb = 0;
- const int ldd = d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
+ const int ldd = d->info()->strides_in_bytes().y() / d->info()->element_size();
const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d != 0 ? 3 : 2;
const size_t a_multi_idx = a_batch_idx + 1;
const size_t d_batch_idx = _gemm_info.depth_output_gemm3d != 0 ? 3 : 2;
const size_t d_multi_idx = d_batch_idx + 1;
- int batch_stride_a = a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
- const int batch_stride_d = d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput);
+ int batch_stride_a = a->info()->strides_in_bytes()[a_batch_idx] / a->info()->element_size();
+ const int batch_stride_d = d->info()->strides_in_bytes()[d_batch_idx] / d->info()->element_size();
- int multi_stride_a = a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
+ int multi_stride_a = a->info()->strides_in_bytes()[a_multi_idx] / a->info()->element_size();
int multi_stride_b = 0;
- const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput);
+ const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / d->info()->element_size();
auto in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
const TypeInput *in1_ptr = nullptr;
@@ -493,50 +493,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
// Check if B is pre-tranposed and de-reference if not
if(!_gemm_kernel_asm->B_is_pretransposed())
{
- ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
- multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
- const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
- if(is_fixed_format(wf))
- {
- // The 4D tensor of dimension O'HWI' created for the
- // OHWIo<interleave_by>i<block_by> format is in reality seen
- // as a 2D tensor at arm_gemm level, where the rows are
- // O'/<interleave_by> and the columns are <interleave_by> *
- // H * W * I'.
- ITensorInfo *tensor_info = b->info();
- const DataLayout data_layout = tensor_info->data_layout();
- const TensorShape tensor_shape = tensor_info->tensor_shape();
- const int tensor_height = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
- const int tensor_width = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
- int tensor_channels = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
- const int interleave_by = arm_compute::interleave_by(wf);
- const int blocked_by = arm_compute::block_by(wf);
- // We need to find a new stride that is distance from the data for one
- // set of output channels to the next
- if(ldb == tensor_channels && multi_stride_b == tensor_channels * tensor_width)
- {
- // In this case dimensions that are packed are height, width and channel
- // so we need to stride it by interleave_by
- if(tensor_channels % blocked_by != 0)
- {
- // We need to pad
- tensor_channels = arm_gemm::iceildiv(tensor_channels, blocked_by) * blocked_by;
- }
- ldb = interleave_by * tensor_height * tensor_width * tensor_channels;
- }
- else if(multi_stride_b == 0 || (ldb == tensor_width && multi_stride_b == tensor_height * tensor_width))
- {
- // In this case dimension that is packed is only height
- // so we need to stride only height by interleave_by
- ldb = interleave_by * tensor_height;
- }
- else
- {
- // If dimensions are not packed as above error is thrown
- // as at the moment other forms of packing are not supported
- ARM_COMPUTE_ERROR("Unsupported packing for fixed format kernel");
- }
- }
+ ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
+ multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
}
@@ -551,9 +509,9 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
// Pretranspose B if required
if(_B_pretranspose_required)
{
- const int ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ const int ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
const auto b_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
- const int multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+ const int multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
@@ -780,6 +738,11 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QASYMM8_SIGNED, DataType::S8);
}
+ else if(is_fixed_format_fast_math(info.weight_format))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
+ }
else
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
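Why element_size() replaces sizeof(TypeInput) throughout this file: with a fixed-format fast-math kernel the Fallback may be instantiated for float input while the weights tensor b actually stores BFLOAT16, so the template parameter no longer reflects the bytes in memory. Illustratively:

    // sizeof(TypeInput)         == 4  -> wrong divisor for b's byte strides
    // b->info()->element_size() == 2  -> matches the BFLOAT16 data actually stored in b
    const int ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();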
diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
index 4f858fb54b..919e5ed84f 100644
--- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
+++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -69,7 +69,8 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh
weights->info(),
biases != nullptr ? biases->info() : nullptr,
output->info(),
- fc_info));
+ fc_info,
+ weights_info));
ARM_COMPUTE_LOG_PARAMS(input, weights, biases, output, fc_info);
_impl->op = std::make_unique<cpu::CpuFullyConnected>();
@@ -96,9 +97,9 @@ Status NEFullyConnectedLayer::has_opt_impl(arm_compute::WeightFormat &expected_w
}
Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
- FullyConnectedLayerInfo fc_info)
+ FullyConnectedLayerInfo fc_info, const WeightsInfo &weights_info)
{
- return cpu::CpuFullyConnected::validate(input, weights, biases, output, fc_info);
+ return cpu::CpuFullyConnected::validate(input, weights, biases, output, fc_info, weights_info);
}
void NEFullyConnectedLayer::run()
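End to end, the intended flow is: query the fixed weight format the kernels want, lay out the weights in that format (including strides, per this patch), then validate with the same WeightsInfo. A sketch under the assumption that has_opt_impl() takes the argument list truncated in the declaration above:

    arm_compute::WeightFormat expected_wf = arm_compute::WeightFormat::UNSPECIFIED;
    Status q = NEFullyConnectedLayer::has_opt_impl(expected_wf, input, weights, biases, output,
                                                   fc_info, weights_info); // parameter order assumed
    if(bool(q) && arm_compute::is_fixed_format(expected_wf))
    {
        // Reorder the weights into expected_wf, set their strides, and record
        // expected_wf in weights_info before validating:
        Status s = NEFullyConnectedLayer::validate(input, weights, biases, output, fc_info, weights_info);
    }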
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index f03ac45bfc..08b6a02375 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -756,9 +756,8 @@ DATA_TEST_CASE(PrepareWeightShape, framework::DatasetMode::ALL,
const DataType DT = DataType::F32;
const DataLayout DL = DataLayout::NHWC;
const auto TI = TensorInfo(input_shape, 1 /*num_channels, deprecated*/, DT, DL);
- const TensorInfo computed = ::arm_compute::test::validation::prepare_weights(TI, wf);
- const TensorInfo expected = TensorInfo(expected_shape, 1 /*num_channels, deprecated*/, DT, DL);
- ARM_COMPUTE_EXPECT_EQUAL(computed, expected, framework::LogLevel::ERRORS);
+ const TensorInfo computed_info = ::arm_compute::test::validation::prepare_weights(TI, wf);
+ ARM_COMPUTE_EXPECT_EQUAL(computed_info.tensor_shape(), expected_shape, framework::LogLevel::ERRORS);
}
TEST_SUITE_END() // VariableWeightUtils
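The narrowed assertion is a consequence of the fixture change below: prepare_weights() now returns a TensorInfo with explicit fixed-format strides, so whole-TensorInfo equality against a default-stride expected TensorInfo would fail even when the shape is correct:

    // Compare only the shape; the strides intentionally differ from the default.
    ARM_COMPUTE_EXPECT_EQUAL(computed_info.tensor_shape(), expected_shape, framework::LogLevel::ERRORS);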
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index 63e6dc9377..5b8963ebfe 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -416,8 +416,21 @@ inline TensorInfo prepare_weights(const TensorInfo tensor_info, const arm_comput
const int Ip = arm_gemm::roundup<unsigned int>(C, block_by); // C'=I'
const int Op = arm_gemm::roundup<unsigned int>(N, interleave_by); // O'=N'
+ arm_compute::Strides strides_in_bytes = tensor_info.strides_in_bytes();
+ strides_in_bytes.set(1, Ip * interleave_by * H * W * tensor_info.element_size());
+ strides_in_bytes.set(2, Ip * Op * tensor_info.element_size());
+
+ const size_t offset_first_element_in_bytes = tensor_info.offset_first_element_in_bytes();
+
+ // Total size needs to include padded dimensions
+ const size_t total_size_in_bytes = Op * H * W * Ip * tensor_info.element_size();
+
const TensorShape TS(Ip, W, H, Op);
- return TensorInfo(TS, 1 /*num_channels*/, data_type, data_layout);
+
+ TensorInfo new_tensor_info = tensor_info;
+ new_tensor_info.init(TS, 1 /*num_channels, deprecated*/, data_type, strides_in_bytes,
+ offset_first_element_in_bytes, total_size_in_bytes);
+ return new_tensor_info;
}
template <typename ScalarType, typename AccessorType>
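For reference, a hedged sketch of consuming the updated prepare_weights() in a test; the shapes and the OHWIo8 weight format are illustrative assumptions:

    const TensorInfo src_weights(TensorShape(64U /*C*/, 3U /*W*/, 3U /*H*/, 32U /*N*/),
                                 1 /*num_channels, deprecated*/, DataType::F32, DataLayout::NHWC);
    const TensorInfo fixed_weights = prepare_weights(src_weights, arm_compute::WeightFormat::OHWIo8);
    // fixed_weights now carries the padded shape (I', W, H, O') plus the strides and
    // total size the assembly kernels expect, so nothing downstream needs to guess them.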