diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2018-07-17 12:28:42 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:54:54 +0000 |
commit | 7d66a8e3f603f2cd363f04a750847e3f9eabdfd4 (patch) | |
tree | 0d7e1ad5bf0ecd32cd919074f756d27c351d7638 /tests/validation/reference/ConvertFullyConnectedWeights.cpp | |
parent | ae54e026c86aec7d6819ee3ef76372c1a3c92467 (diff) | |
download | ComputeLibrary-7d66a8e3f603f2cd363f04a750847e3f9eabdfd4.tar.gz |
COMPMID-1386: Add support for converting weights for CL.
Change-Id: I62e3ead903366baeeb1488f233a9b8b0c388c9de
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140403
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'tests/validation/reference/ConvertFullyConnectedWeights.cpp')
-rw-r--r-- | tests/validation/reference/ConvertFullyConnectedWeights.cpp | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/tests/validation/reference/ConvertFullyConnectedWeights.cpp b/tests/validation/reference/ConvertFullyConnectedWeights.cpp index b0f537fa0c..e27846c726 100644 --- a/tests/validation/reference/ConvertFullyConnectedWeights.cpp +++ b/tests/validation/reference/ConvertFullyConnectedWeights.cpp @@ -36,9 +36,15 @@ SimpleTensor<T> convert_fully_connected_weights(const SimpleTensor<T> &src, cons { SimpleTensor<T> dst(src.shape(), src.data_type()); + const DataLayout original_input_data_layout = (training_data_layout == DataLayout::NCHW) ? DataLayout::NHWC : DataLayout::NCHW; + + const int width_idx = get_data_layout_dimension_index(original_input_data_layout, DataLayoutDimension::WIDTH); + const int height_idx = get_data_layout_dimension_index(original_input_data_layout, DataLayoutDimension::HEIGHT); + const int channel_idx = get_data_layout_dimension_index(original_input_data_layout, DataLayoutDimension::CHANNEL); + const bool is_nchw_to_nhwc = training_data_layout == DataLayout::NCHW; - const unsigned int num_elems_per_input_plane = original_input_shape.x() * original_input_shape.y(); - const unsigned int num_channels = original_input_shape.z(); + const unsigned int num_elems_per_input_plane = original_input_shape[width_idx] * original_input_shape[height_idx]; + const unsigned int num_channels = original_input_shape[channel_idx]; const unsigned int factor_1 = is_nchw_to_nhwc ? num_elems_per_input_plane : num_channels; const unsigned int factor_2 = is_nchw_to_nhwc ? num_channels : num_elems_per_input_plane; |