aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/ConvertFullyConnectedWeights.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/ConvertFullyConnectedWeights.cpp')
-rw-r--r--tests/validation/reference/ConvertFullyConnectedWeights.cpp10
1 files changed, 8 insertions, 2 deletions
diff --git a/tests/validation/reference/ConvertFullyConnectedWeights.cpp b/tests/validation/reference/ConvertFullyConnectedWeights.cpp
index b0f537fa0c..e27846c726 100644
--- a/tests/validation/reference/ConvertFullyConnectedWeights.cpp
+++ b/tests/validation/reference/ConvertFullyConnectedWeights.cpp
@@ -36,9 +36,15 @@ SimpleTensor<T> convert_fully_connected_weights(const SimpleTensor<T> &src, cons
{
SimpleTensor<T> dst(src.shape(), src.data_type());
+ const DataLayout original_input_data_layout = (training_data_layout == DataLayout::NCHW) ? DataLayout::NHWC : DataLayout::NCHW;
+
+ const int width_idx = get_data_layout_dimension_index(original_input_data_layout, DataLayoutDimension::WIDTH);
+ const int height_idx = get_data_layout_dimension_index(original_input_data_layout, DataLayoutDimension::HEIGHT);
+ const int channel_idx = get_data_layout_dimension_index(original_input_data_layout, DataLayoutDimension::CHANNEL);
+
const bool is_nchw_to_nhwc = training_data_layout == DataLayout::NCHW;
- const unsigned int num_elems_per_input_plane = original_input_shape.x() * original_input_shape.y();
- const unsigned int num_channels = original_input_shape.z();
+ const unsigned int num_elems_per_input_plane = original_input_shape[width_idx] * original_input_shape[height_idx];
+ const unsigned int num_channels = original_input_shape[channel_idx];
const unsigned int factor_1 = is_nchw_to_nhwc ? num_elems_per_input_plane : num_channels;
const unsigned int factor_2 = is_nchw_to_nhwc ? num_channels : num_elems_per_input_plane;