aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/backends/ValidateHelpers.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/graph/backends/ValidateHelpers.h')
-rw-r--r--arm_compute/graph/backends/ValidateHelpers.h18
1 files changed, 12 insertions, 6 deletions
diff --git a/arm_compute/graph/backends/ValidateHelpers.h b/arm_compute/graph/backends/ValidateHelpers.h
index ca01295d15..68a718ab97 100644
--- a/arm_compute/graph/backends/ValidateHelpers.h
+++ b/arm_compute/graph/backends/ValidateHelpers.h
@@ -70,12 +70,18 @@ Status validate_convolution_layer(ConvolutionLayerNode &node)
ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
// Extract IO and info
- arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
- arm_compute::ITensorInfo *weights = get_backing_tensor_info(node.input(1));
- arm_compute::ITensorInfo *biases = get_backing_tensor_info(node.input(2));
- arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
- const PadStrideInfo conv_info = node.convolution_info();
- const ConvolutionMethod conv_algorithm = node.convolution_method();
+ arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
+ arm_compute::ITensorInfo *weights = get_backing_tensor_info(node.input(1));
+ arm_compute::ITensorInfo *biases = get_backing_tensor_info(node.input(2));
+ arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
+
+ if(is_data_type_quantized_asymmetric(input->data_type()))
+ {
+ biases->set_data_type(DataType::S32);
+ }
+
+ const PadStrideInfo conv_info = node.convolution_info();
+ const ConvolutionMethod conv_algorithm = node.convolution_method();
// Validate function
Status status{};