aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEConvertFullyConnectedWeightsKernel.cpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-07-17 12:28:42 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commit7d66a8e3f603f2cd363f04a750847e3f9eabdfd4 (patch)
tree0d7e1ad5bf0ecd32cd919074f756d27c351d7638 /src/core/NEON/kernels/NEConvertFullyConnectedWeightsKernel.cpp
parentae54e026c86aec7d6819ee3ef76372c1a3c92467 (diff)
downloadComputeLibrary-7d66a8e3f603f2cd363f04a750847e3f9eabdfd4.tar.gz
COMPMID-1386: Add support for converting weights for CL.
Change-Id: I62e3ead903366baeeb1488f233a9b8b0c388c9de Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140403 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEConvertFullyConnectedWeightsKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEConvertFullyConnectedWeightsKernel.cpp41
1 files changed, 25 insertions, 16 deletions
diff --git a/src/core/NEON/kernels/NEConvertFullyConnectedWeightsKernel.cpp b/src/core/NEON/kernels/NEConvertFullyConnectedWeightsKernel.cpp
index be5e6436b3..198565b1d5 100644
--- a/src/core/NEON/kernels/NEConvertFullyConnectedWeightsKernel.cpp
+++ b/src/core/NEON/kernels/NEConvertFullyConnectedWeightsKernel.cpp
@@ -37,25 +37,26 @@ void NEConvertFullyConnectedWeightsKernel::configure(const ITensor *input, ITens
DataLayout data_layout)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+
+ // Output tensor auto initialisation if not yet initialized
+ auto_init_if_empty(*output->info(), *input->info()->clone());
+
ARM_COMPUTE_ERROR_THROW_ON(NEConvertFullyConnectedWeightsKernel::validate(input->info(), output->info(), original_input_shape, data_layout));
_input = input;
_output = output;
- const unsigned int num_elems_per_input_plane = original_input_shape.x() * original_input_shape.y();
- const unsigned int num_channels = original_input_shape.z();
+ const DataLayout input_data_layout = (data_layout == DataLayout::NCHW) ? DataLayout::NHWC : DataLayout::NCHW;
- // Set build options
- if(data_layout == DataLayout::NCHW)
- {
- _factor1 = num_elems_per_input_plane;
- _factor2 = num_channels;
- }
- else
- {
- _factor1 = num_channels;
- _factor2 = num_elems_per_input_plane;
- }
+ const int width_idx = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::WIDTH);
+ const int height_idx = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::HEIGHT);
+ const int channel_idx = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::CHANNEL);
+
+ const unsigned int num_elems_per_input_plane = original_input_shape[width_idx] * original_input_shape[height_idx];
+ const unsigned int num_channels = original_input_shape[channel_idx];
+
+ _factor1 = (data_layout == DataLayout::NCHW) ? num_elems_per_input_plane : num_channels;
+ _factor2 = (data_layout == DataLayout::NCHW) ? num_channels : num_elems_per_input_plane;
// Configure kernel window
Window win = calculate_max_window(*input->info(), Steps());
@@ -65,14 +66,22 @@ void NEConvertFullyConnectedWeightsKernel::configure(const ITensor *input, ITens
Status NEConvertFullyConnectedWeightsKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const TensorShape &original_input_shape,
DataLayout data_layout)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QASYMM8, DataType::U16, DataType::S16, DataType::U32, DataType::S32,
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1,
+ DataType::U8, DataType::S8, DataType::QASYMM8,
+ DataType::U16, DataType::S16,
+ DataType::U32, DataType::S32,
DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() != 2);
ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(1) != original_input_shape.total_size_lower(3));
ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::UNKNOWN);
+ // Checks performed when output is configured
+ if((output != nullptr) && (output->total_size() != 0))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ }
+
return Status{};
}