diff options
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp | 14 |
1 file changed, 8 insertions(+), 6 deletions(-)
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp index 5701d60208..f6dc3a8f43 100644 --- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp @@ -62,7 +62,7 @@ Status NEConvolutionLayerReshapeWeights::validate(const ITensorInfo *weights, co ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(weights); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, - DataType::F16, DataType::F32); + DataType::BFLOAT16, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4); if(biases != nullptr) @@ -330,6 +330,7 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig } // Create temporary GEMM output tensor in case we cannot skip col2im + const DataType output_data_type = data_type == DataType::BFLOAT16 ? DataType::F32 : data_type; if(!_skip_col2im) { TensorShape shape_gemm; @@ -340,7 +341,7 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig shape_gemm.set(1, conv_w * conv_h); // FIXME: input->clone() doesn't work with subtensors for grouped convolutions. 
- TensorInfo info_gemm(shape_gemm, 1, data_type); + TensorInfo info_gemm(shape_gemm, 1, output_data_type); info_gemm.set_quantization_info(output->info()->quantization_info()).set_data_layout(input->info()->data_layout()); _gemm_output.allocator()->init(info_gemm); _memory_group.manage(&_gemm_output); @@ -392,8 +393,8 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output); ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights_info.are_reshaped(), "Weights already reshaped are not supported!"); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::BFLOAT16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::BFLOAT16, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, weights); ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups > 1, "Grouping (num_groups != 1) is not supported on NEON"); @@ -497,16 +498,17 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI } // Create temporary GEMM output tensor in case we cannot skip col2im + const DataType output_data_type = data_type == DataType::BFLOAT16 ? 
DataType::F32 : data_type; if(!skip_col2im) { TensorShape shape_gemm = gemm_input_to_use->tensor_shape(); shape_gemm.set(0, mat_weights_cols); shape_gemm.set(1, conv_w * conv_h); - info_gemm = TensorInfo(shape_gemm, 1, data_type); + info_gemm = TensorInfo(shape_gemm, 1, output_data_type); } else { - info_gemm = TensorInfo(output->tensor_shape(), 1, data_type); + info_gemm = TensorInfo(output->tensor_shape(), 1, output_data_type); } info_gemm.set_quantization_info(output->quantization_info()).set_data_layout(input->data_layout()); gemm_output_to_use = &info_gemm; |