aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLConvertFullyConnectedWeightsKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLConvertFullyConnectedWeightsKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLConvertFullyConnectedWeightsKernel.cpp43
1 files changed, 28 insertions, 15 deletions
diff --git a/src/core/CL/kernels/CLConvertFullyConnectedWeightsKernel.cpp b/src/core/CL/kernels/CLConvertFullyConnectedWeightsKernel.cpp
index 86858d0c03..69ab590540 100644
--- a/src/core/CL/kernels/CLConvertFullyConnectedWeightsKernel.cpp
+++ b/src/core/CL/kernels/CLConvertFullyConnectedWeightsKernel.cpp
@@ -41,27 +41,32 @@ void CLConvertFullyConnectedWeightsKernel::configure(const ICLTensor *input, ICL
DataLayout data_layout)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+
+ // Output tensor auto initialisation if not yet initialized
+ auto_init_if_empty(*output->info(), *input->info()->clone());
+
ARM_COMPUTE_ERROR_THROW_ON(CLConvertFullyConnectedWeightsKernel::validate(input->info(), output->info(), original_input_shape, data_layout));
_input = input;
_output = output;
- const unsigned int num_elems_per_input_plane = original_input_shape.x() * original_input_shape.y();
- const unsigned int num_channels = original_input_shape.z();
+ const DataLayout input_data_layout = (data_layout == DataLayout::NCHW) ? DataLayout::NHWC : DataLayout::NCHW;
+
+ const int width_idx = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::WIDTH);
+ const int height_idx = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::HEIGHT);
+ const int channel_idx = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::CHANNEL);
+
+ const unsigned int num_elems_per_input_plane = original_input_shape[width_idx] * original_input_shape[height_idx];
+ const unsigned int num_channels = original_input_shape[channel_idx];
+
+ const unsigned int factor_1 = (data_layout == DataLayout::NCHW) ? num_elems_per_input_plane : num_channels;
+ const unsigned int factor_2 = (data_layout == DataLayout::NCHW) ? num_channels : num_elems_per_input_plane;
// Set build options
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
- if(data_layout == DataLayout::NCHW)
- {
- build_opts.add_option("-DFACTOR_1=" + support::cpp11::to_string(num_elems_per_input_plane));
- build_opts.add_option("-DFACTOR_2=" + support::cpp11::to_string(num_channels));
- }
- else
- {
- build_opts.add_option("-DFACTOR_1=" + support::cpp11::to_string(num_channels));
- build_opts.add_option("-DFACTOR_2=" + support::cpp11::to_string(num_elems_per_input_plane));
- }
+ build_opts.add_option("-DFACTOR_1=" + support::cpp11::to_string(factor_1));
+ build_opts.add_option("-DFACTOR_2=" + support::cpp11::to_string(factor_2));
// Create kernel
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("convert_fc_weights", build_opts.options()));
@@ -75,14 +80,22 @@ Status CLConvertFullyConnectedWeightsKernel::validate(const ITensorInfo *input,
DataLayout data_layout)
{
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QASYMM8, DataType::U16, DataType::S16, DataType::U32, DataType::S32,
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1,
+ DataType::U8, DataType::S8, DataType::QASYMM8,
+ DataType::U16, DataType::S16,
+ DataType::U32, DataType::S32,
DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() != 2);
ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(1) != original_input_shape.total_size_lower(3));
ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::UNKNOWN);
+ // Checks performed when output is configured
+ if((output != nullptr) && (output->total_size() != 0))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ }
+
return Status{};
}