author    Michele Di Giorgio <michele.digiorgio@arm.com>  2020-01-08 11:33:44 +0000
committer Michele Di Giorgio <michele.digiorgio@arm.com>  2020-01-14 10:55:58 +0000
commit    9c700378f2227cb9d51455ed4a5086daaac5532a (patch)
tree      53eb4acf5a9226e941e332b93db8ef260fb2d42b
parent    ab709a0ea0f7f5c8e02c315afffc300e09c783a8 (diff)
COMPMID-2769: Add support for QASYMM8_SIGNED in NEFullyConnectedLayer
Change-Id: I4c35c522375ae5a5de78716e079ebb9ffad15956
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2581
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
 arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h | 12
 arm_compute/runtime/NEON/functions/NEGEMMLowpOutputStage.h | 36
 src/runtime/NEON/functions/NEFullyConnectedLayer.cpp       | 21
 src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp       | 93
 tests/validation/NEON/FullyConnectedLayer.cpp              | 14
 tests/validation/fixtures/FullyConnectedLayerFixture.h     | 11
 tests/validation/reference/FullyConnectedLayer.cpp         | 11
 7 files changed, 178 insertions(+), 20 deletions(-)
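
The centrepiece of this patch is the new NEGEMMLowpOutputStage runtime function, which picks the quantize-down kernel from the output data type and a GEMMLowpOutputStageInfo descriptor. Below is a minimal usage sketch built only from the API introduced in this change; the tensor setup, the example scales/offset values and the AsymmHelpers.h include path are assumptions for illustration and are not part of the patch.

    #include "arm_compute/core/utils/quantization/AsymmHelpers.h" // assumed header for calculate_quantized_multiplier
    #include "arm_compute/runtime/NEON/functions/NEGEMMLowpOutputStage.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    // S32 accumulators from NEGEMMLowpMatrixMultiplyCore; tensor info and allocation omitted in this sketch.
    Tensor gemmlowp_output; // DataType::S32
    Tensor bias;            // DataType::S32, 1D tensor with dimensions [OFM]
    Tensor output;          // DataType::QASYMM8_SIGNED

    void configure_output_stage()
    {
        // Example quantization parameters (assumed values; normally read from the tensors' QuantizationInfo).
        const float   input_scale   = 0.5f;
        const float   weights_scale = 0.25f;
        const float   output_scale  = 1.0f;
        const int32_t output_offset = 10;

        // Collapse the float rescale factor into a fixed-point multiplier/shift pair,
        // mirroring what NEFullyConnectedLayer::configure() does in this patch.
        int32_t output_multiplier = 0;
        int32_t output_shift      = 0;
        quantization::calculate_quantized_multiplier(input_scale * weights_scale / output_scale, &output_multiplier, &output_shift);

        GEMMLowpOutputStageInfo info;
        info.gemmlowp_multiplier = output_multiplier;
        info.gemmlowp_shift      = output_shift;
        info.gemmlowp_offset     = output_offset;
        info.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;

        NEGEMMLowpOutputStage output_stage;
        output_stage.configure(&gemmlowp_output, &bias, &output, info);
        output_stage.run(); // dispatches the Int32ToInt8 fixed-point kernel for QASYMM8_SIGNED output
    }
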
diff --git a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
index 784637a796..78f12daf9c 100644
--- a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -126,12 +126,12 @@ public:
NEFullyConnectedLayer &operator=(NEFullyConnectedLayer &&) = default;
/** Set the input and output tensors.
*
- * @param[in] input Source tensor. Data type supported: QASYMM8/F16/F32.
+ * @param[in] input Source tensor. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
* @param[in] weights Weights tensor. The weights must be 2 dimensional.
* If this function is called after a Convolution Layer, the (transposed) weights will have as many rows as the product of the first 3 input's dimensions.
* If it is called after another FullyConnected Layer, the (transposed) weights will have as many rows as the input's first dimension.
* Data type supported: Same as @p input.
- * @param[in] biases Bias tensor. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8.
+ * @param[in] biases Bias tensor. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED.
* @param[out] output Destination tensor. Its shape should be equal to the output of a matrix multiplication between:
* - The output of im2col on the input and the (transposed) 2D weights, if the function is called after a Convolution Layer
* - The input tensor and the (transposed) 2D weights, if the function is called after another FullyConnected Layer.
@@ -142,12 +142,12 @@ public:
FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo());
/** Static function to check if given info will lead to a valid configuration of @ref NEFullyConnectedLayer
*
- * @param[in] input Source tensor info. Data type supported: QASYMM8/F16/F32.
+ * @param[in] input Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
* @param[in] weights Weights tensor info. The weights must be 2 dimensional.
* If this function is called after a Convolution Layer, the (transposed) weights will have as many rows as the product of the first 3 input's dimensions.
* If it is called after another FullyConnected Layer, the (transposed) weights will have as many rows as the input's first dimension.
* Data type supported: Same as @p input.
- * @param[in] biases Bias tensor. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8.
+ * @param[in] biases Bias tensor. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED.
* @param[in] output Destination tensor info. Its shape should be equal to the output of a matrix multiplication between:
* - The output of im2col on the input and the (transposed) 2D weights, if the function is called after a Convolution Layer
* - The input tensor and the (transposed) 2D weights, if the function is called after another FullyConnected Layer.
@@ -177,7 +177,7 @@ private:
weights_transformations::NEFullyConnectedLayerReshapeWeightsManaged _reshape_weights_managed_function;
NEGEMM _mm_gemm;
NEGEMMLowpMatrixMultiplyCore _mm_gemmlowp;
- NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint _gemmlowp_output_stage;
+ NEGEMMLowpOutputStage _gemmlowp_output_stage;
NEGEMMMatrixAccumulateBiasesKernel _accumulate_biases_kernel;
Tensor _flatten_output;
Tensor _gemmlowp_output;
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMLowpOutputStage.h b/arm_compute/runtime/NEON/functions/NEGEMMLowpOutputStage.h
index b483d03c85..ca2cbbc268 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMLowpOutputStage.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMLowpOutputStage.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -264,5 +264,39 @@ public:
*/
static Status validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, int min = 0, int max = 0);
};
+
+/** Basic function to execute GEMMLowpQuantizeDown kernels on NEON.
+ *
+ * This function calls the following NEON kernels:
+ *
+ * -# @ref NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel
+ * -# @ref NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel
+ * -# @ref NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel
+ * -# @ref NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel
+*/
+class NEGEMMLowpOutputStage : public INESimpleFunctionNoBorder
+{
+public:
+ /** Initialise the kernel's inputs, output
+ *
+ * @param[in] input Input tensor. Data type supported: S32
+ * @param[in] bias Biases tensor. Only shared biases supported and it can be a nullptr if the biases addition is not required.
+ * Biases are 1D tensor with dimensions [OFM]. Data type supported: Same as @p input.
+ * @param[out] output Output tensor. Data type supported: QASYMM8/QASYMM8_SIGNED/QSYMM16
+ * @param[in] info GEMMLowp output stage metadata.
+ */
+ void configure(const ITensor *input, const ITensor *bias, ITensor *output, const GEMMLowpOutputStageInfo &info);
+ /** Static function to check if given info will lead to a valid configuration of @ref NEGEMMLowpOutputStage
+ *
+ * @param[in] input Input tensor info. It is the output of @ref NEGEMMLowpMatrixMultiplyCore function. Data type supported: S32
+ * @param[in] bias Biases tensor info. Only shared biases supported and it can be a nullptr if the addition of biases is not required.
+ * Biases are 1D tensor with dimensions [OFM]. Data type supported: Same as @p input.
+ * @param[in] output Output tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/QSYMM16
+ * @param[in] info GEMMLowp output stage metadata.
+ *
+ * @return a status
+ */
+ static Status validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const GEMMLowpOutputStageInfo &info);
+};
} // namespace arm_compute
#endif /*ARM_COMPUTE_NEGEMMLOWPOUTPUTSTAGE_H */
diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
index 4c264e4832..92ccd5d1cc 100644
--- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
+++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
@@ -141,9 +141,8 @@ void NEFullyConnectedLayer::configure_fc_fc(const ITensor *input, const ITensor
void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output,
FullyConnectedLayerInfo fc_info)
{
- ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
-
// Perform validate step
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
ARM_COMPUTE_ERROR_THROW_ON(NEFullyConnectedLayer::validate(input->info(),
weights->info(),
biases != nullptr ? biases->info() : nullptr,
@@ -260,7 +259,13 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh
int32_t output_multiplier;
int32_t output_shift;
quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
- _gemmlowp_output_stage.configure(&_gemmlowp_output, biases, output, output_multiplier, output_shift, oq_info.offset);
+
+ GEMMLowpOutputStageInfo gemmlowp_output_stage_info;
+ gemmlowp_output_stage_info.gemmlowp_multiplier = output_multiplier;
+ gemmlowp_output_stage_info.gemmlowp_shift = output_shift;
+ gemmlowp_output_stage_info.gemmlowp_offset = oq_info.offset;
+ gemmlowp_output_stage_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
+ _gemmlowp_output_stage.configure(&_gemmlowp_output, biases, output, gemmlowp_output_stage_info);
_gemmlowp_output.allocator()->allocate();
}
@@ -272,7 +277,7 @@ Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn
{
ARM_COMPUTE_UNUSED(fc_info.retain_internal_weights);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
@@ -361,7 +366,13 @@ Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn
int32_t output_multiplier;
int32_t output_shift;
ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(&gemmlowp_output, biases, output));
+
+ GEMMLowpOutputStageInfo gemmlowp_output_stage_info;
+ gemmlowp_output_stage_info.gemmlowp_multiplier = output_multiplier;
+ gemmlowp_output_stage_info.gemmlowp_shift = output_shift;
+ gemmlowp_output_stage_info.gemmlowp_offset = oq_info.offset;
+ gemmlowp_output_stage_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&gemmlowp_output, biases, output, gemmlowp_output_stage_info));
}
return Status{};
diff --git a/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp b/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp
index 3ef9351b78..465dddaac2 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -28,6 +28,7 @@
#include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel.h"
#include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.h"
#include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel.h"
+#include "arm_compute/core/Validate.h"
#include "support/ToolchainSupport.h"
namespace arm_compute
@@ -81,4 +82,94 @@ Status NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPoint::validate(const ITens
{
return NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel::validate(input, bias, output, min, max);
}
+
+void NEGEMMLowpOutputStage::configure(const ITensor *input, const ITensor *bias, ITensor *output, const GEMMLowpOutputStageInfo &info)
+{
+ // Perform validate step
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+ ARM_COMPUTE_ERROR_THROW_ON(NEGEMMLowpOutputStage::validate(input->info(), bias != nullptr ? bias->info() : nullptr, output->info(), info));
+
+ if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN)
+ {
+ switch(output->info()->data_type())
+ {
+ case DataType::QASYMM8:
+ {
+ auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel>();
+ k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ _kernel = std::move(k);
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("Unsupported output data type.");
+ }
+ }
+ else if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
+ {
+ switch(output->info()->data_type())
+ {
+ case DataType::QASYMM8:
+ {
+ auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel>();
+ k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ _kernel = std::move(k);
+ break;
+ }
+ case DataType::QASYMM8_SIGNED:
+ {
+ auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel>();
+ k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ _kernel = std::move(k);
+ break;
+ }
+ case DataType::QSYMM16:
+ {
+ auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel>();
+ k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ _kernel = std::move(k);
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("Unsupported output data type.");
+ }
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Unsupported output stage quantization type.");
+ }
+}
+
+Status NEGEMMLowpOutputStage::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const GEMMLowpOutputStageInfo &info)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::UNKNOWN, "NEGEMMLowpOutputStage cannot be used with UNKNOWN output data type.");
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16);
+
+ ARM_COMPUTE_RETURN_ERROR_ON((info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN) && (info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT));
+
+ if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN)
+ {
+ switch(output->data_type())
+ {
+ case DataType::QASYMM8:
+ return NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ default:
+ return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type.");
+ }
+ }
+ else
+ {
+ switch(output->data_type())
+ {
+ case DataType::QASYMM8:
+ return NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ case DataType::QASYMM8_SIGNED:
+ return NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ case DataType::QSYMM16:
+ return NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ default:
+ return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type.");
+ }
+ }
+}
} // namespace arm_compute
diff --git a/tests/validation/NEON/FullyConnectedLayer.cpp b/tests/validation/NEON/FullyConnectedLayer.cpp
index a7b837fedf..fae116aa9f 100644
--- a/tests/validation/NEON/FullyConnectedLayer.cpp
+++ b/tests/validation/NEON/FullyConnectedLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -53,6 +53,7 @@ constexpr float tolerance_num_f16 = 0.07f;
/** Tolerance for quantized asymmetric operations */
constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1);
+constexpr AbsoluteTolerance<int8_t> tolerance_qasymm8_signed(1);
/** CNN data types */
const auto CNNDataTypes = framework::dataset::make("DataType",
@@ -235,6 +236,17 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEFullyConnectedLayerQuantizedFixture<uint8_t>,
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
TEST_SUITE_END()
+TEST_SUITE(QASYMM8_SIGNED)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEFullyConnectedLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(
+ combine(datasets::SmallFullyConnectedLayerDataset(),
+ FullyConnectedParameters),
+ framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
+ QuantizationData))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
+}
+TEST_SUITE_END()
TEST_SUITE_END()
TEST_SUITE_END()
diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h
index 0449d80de8..ff6ac17744 100644
--- a/tests/validation/fixtures/FullyConnectedLayerFixture.h
+++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -49,7 +49,7 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class FullyConnectedLayerValidationGenericFixture : public framework::Fixture
{
public:
- using TBias = typename std::conditional<std::is_same<typename std::decay<T>::type, uint8_t>::value, int32_t, T>::type;
+ using TBias = typename std::conditional < std::is_same<typename std::decay<T>::type, uint8_t>::value || std::is_same<typename std::decay<T>::type, int8_t>::value, int32_t, T >::type;
public:
template <typename...>
@@ -71,11 +71,16 @@ protected:
template <typename U>
void fill(U &&tensor, int i)
{
- if(is_data_type_quantized_asymmetric(_data_type))
+ if(_data_type == DataType::QASYMM8)
{
std::uniform_int_distribution<uint8_t> distribution(0, 30);
library->fill(tensor, distribution, i);
}
+ else if(_data_type == DataType::QASYMM8_SIGNED)
+ {
+ std::uniform_int_distribution<int8_t> distribution(-15, 15);
+ library->fill(tensor, distribution, i);
+ }
else if(_data_type == DataType::S32)
{
std::uniform_int_distribution<int32_t> distribution(-50, 50);
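
The fixture change above widens the TBias alias so that both uint8_t (QASYMM8) and int8_t (QASYMM8_SIGNED) element types map to 32-bit integer biases. A small standalone check, not part of the patch, illustrating how the std::conditional resolves:

    #include <cstdint>
    #include <type_traits>

    // Same trait as in FullyConnectedLayerValidationGenericFixture, written as a standalone alias.
    template <typename T>
    using TBias = typename std::conditional<std::is_same<typename std::decay<T>::type, uint8_t>::value ||
                                            std::is_same<typename std::decay<T>::type, int8_t>::value,
                                            int32_t, T>::type;

    static_assert(std::is_same<TBias<uint8_t>, int32_t>::value, "QASYMM8 inputs use S32 biases");
    static_assert(std::is_same<TBias<int8_t>,  int32_t>::value, "QASYMM8_SIGNED inputs also use S32 biases");
    static_assert(std::is_same<TBias<float>,   float>::value,   "float inputs keep float biases");
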
diff --git a/tests/validation/reference/FullyConnectedLayer.cpp b/tests/validation/reference/FullyConnectedLayer.cpp
index 261c6453b9..9aecd6cf14 100644
--- a/tests/validation/reference/FullyConnectedLayer.cpp
+++ b/tests/validation/reference/FullyConnectedLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -58,7 +58,7 @@ void vector_matrix_multiply(const SimpleTensor<T> &src, const SimpleTensor<T> &w
}
// Vector matrix multiply for quantized type
-template < typename T, typename TB, typename std::enable_if < std::is_same<T, uint8_t>::value &&std::is_same<TB, int32_t>::value, int >::type = 0 >
+template < typename T, typename TB, typename std::enable_if < (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) &&std::is_same<TB, int32_t>::value, int >::type = 0 >
void vector_matrix_multiply(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, SimpleTensor<T> &dst, int offset_src, int offset_dst,
int cols_weights, int rows_weights)
{
@@ -83,6 +83,9 @@ void vector_matrix_multiply(const SimpleTensor<T> &src, const SimpleTensor<T> &w
const float multiplier = input_scale * weights_scale / output_scale;
arm_compute::quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
+ const int min = std::numeric_limits<T>::lowest();
+ const int max = std::numeric_limits<T>::max();
+
for(int y = 0; y < rows_weights; ++y)
{
// Reset accumulator
@@ -97,7 +100,7 @@ void vector_matrix_multiply(const SimpleTensor<T> &src, const SimpleTensor<T> &w
acc += bias_ptr[y];
// Quantize down
- acc = quantize_down_scale_by_fixedpoint(acc, output_multiplier, output_shift, output_offset, 0, 255);
+ acc = quantize_down_scale_by_fixedpoint(acc, output_multiplier, output_shift, output_offset, min, max);
// Store the result
dst_ptr[y] = static_cast<T>(acc);
@@ -160,6 +163,8 @@ template SimpleTensor<half> fully_connected_layer(const SimpleTensor<half> &src,
QuantizationInfo out_quant_info);
template SimpleTensor<uint8_t> fully_connected_layer(const SimpleTensor<uint8_t> &src, const SimpleTensor<uint8_t> &weights, const SimpleTensor<int32_t> &bias, const TensorShape &dst_shape,
QuantizationInfo out_quant_info);
+template SimpleTensor<int8_t> fully_connected_layer(const SimpleTensor<int8_t> &src, const SimpleTensor<int8_t> &weights, const SimpleTensor<int32_t> &bias, const TensorShape &dst_shape,
+ QuantizationInfo out_quant_info);
} // namespace reference
} // namespace validation
} // namespace test
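
The reference change above replaces the hard-coded [0, 255] clamp with the numeric limits of the element type, so the same template now saturates correctly for signed output. A tiny standalone illustration, not part of the patch:

    #include <cstdint>
    #include <iostream>
    #include <limits>

    int main()
    {
        // QASYMM8 (uint8_t) clamps to [0, 255]; QASYMM8_SIGNED (int8_t) clamps to [-128, 127].
        std::cout << "uint8_t: [" << int(std::numeric_limits<uint8_t>::lowest()) << ", "
                  << int(std::numeric_limits<uint8_t>::max()) << "]\n";
        std::cout << "int8_t:  [" << int(std::numeric_limits<int8_t>::lowest()) << ", "
                  << int(std::numeric_limits<int8_t>::max()) << "]\n";
        return 0;
    }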