From d368df381a63feaaa13d94cab5dae47846b67489 Mon Sep 17 00:00:00 2001
From: Georgios Pinitas
Date: Tue, 4 Jul 2017 11:06:15 +0100
Subject: COMPMID-417: Auto initialize for SoftmaxLayer NEON/CL.

Change-Id: I6f35ac7a15fecab93deec4c6266e5c9632e599f0
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79628
Reviewed-by: Moritz Pflanzer
Tested-by: Kaizen
Reviewed-by: Gian Marco Iodice
---
 src/core/NEON/kernels/NESoftmaxLayerKernel.cpp | 38 ++++++++++++++++++++------
 1 file changed, 29 insertions(+), 9 deletions(-)

(limited to 'src/core/NEON/kernels/NESoftmaxLayerKernel.cpp')

diff --git a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
index 942662e84b..854fd84845 100644
--- a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
+++ b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
@@ -120,9 +120,19 @@ BorderSize NELogits1DMaxKernel::border_size() const
 
 void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32, DataType::QS8);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F32, DataType::QS8);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(output);
+
+    // Softmax across the x dimension
+    TensorShape output_shape{ input->info()->tensor_shape() };
+    output_shape.set(0, 1);
+
+    // Output auto initialization if not yet initialized
+    auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
+
     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
 
     const int input_width = input->info()->valid_region().shape.x();
     unsigned int num_elems_processed_per_iteration = 0;
@@ -302,11 +312,16 @@ NELogits1DShiftExpSumKernel::NELogits1DShiftExpSumKernel()
 
 void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32, DataType::QS8);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(max, 1, DataType::F32, DataType::QS8);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F32, DataType::QS8);
-    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, max, output);
-    ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, max, output);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(max, sum, output);
+
+    // Output auto initialization if not yet initialized
+    auto_init_if_empty(*sum->info(), max->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
+    auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
+
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output, max, sum);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output, max, sum);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(max, sum);
 
     unsigned int num_elems_processed_per_iteration = input->info()->valid_region().shape.x();
@@ -426,8 +441,13 @@ NELogits1DNormKernel::NELogits1DNormKernel()
 void NELogits1DNormKernel::configure(const ITensor *input, const ITensor *sum, ITensor *output)
 {
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32, DataType::QS8);
-    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output, sum);
-    ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output, sum);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(sum, output);
+
+    // Output auto initialization if not yet initialized
+    auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
+
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum, output);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, sum, output);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
 
     _input = input;
-- 
cgit v1.2.1