aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/NEON/FullyConnectedLayer.cpp
diff options
context:
space:
mode:
authorGiorgio Arena <giorgio.arena@arm.com>2018-07-16 17:20:38 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commita855af10a486c53c2271361cb87f349eca64b749 (patch)
treeb326b63bdcaf76c9620b1bbf22942d4683503a65 /tests/validation/NEON/FullyConnectedLayer.cpp
parent5a3ee4f708a9e1642b0211955ff905e7b67e831d (diff)
downloadComputeLibrary-a855af10a486c53c2271361cb87f349eca64b749.tar.gz
COMPMID-1401 Implement NEFullyConnectedLayer for QASYMM8
Change-Id: I0404df6d369855e2f458f2db8f26e81c80a1ee87 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140148 Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/NEON/FullyConnectedLayer.cpp')
-rw-r--r--tests/validation/NEON/FullyConnectedLayer.cpp59
1 files changed, 44 insertions, 15 deletions
diff --git a/tests/validation/NEON/FullyConnectedLayer.cpp b/tests/validation/NEON/FullyConnectedLayer.cpp
index 80fdf1784e..3aeba7a969 100644
--- a/tests/validation/NEON/FullyConnectedLayer.cpp
+++ b/tests/validation/NEON/FullyConnectedLayer.cpp
@@ -48,6 +48,9 @@ constexpr RelativeTolerance<float> tolerance_f32(0.01f);
constexpr RelativeTolerance<float> tolerance_f16(0.01f);
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC*/
+/** Tolerance for quantized asymmetric operations */
+constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1);
+
/** CNN data types */
const auto CNNDataTypes = framework::dataset::make("DataType",
{
@@ -68,6 +71,9 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(frame
CNNDataTypes),
src_shape, weights_shape, bias_shape, dst_shape, transpose_weights, reshape_weights, data_type)
{
+ const DataType bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type;
+ const QuantizationInfo quantization_info = is_data_type_quantized_asymmetric(data_type) ? QuantizationInfo(2.f / 255.f, 127) : QuantizationInfo();
+
TensorShape ws(weights_shape);
// Transpose weights if not done in the function
@@ -76,23 +82,13 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(frame
const size_t shape_x = ws.x();
ws.set(0, ws.y());
ws.set(1, shape_x);
-
- // Weights have to be passed reshaped
- // Transpose 1xW for batched version
- if(!reshape_weights && dst_shape.y() > 1)
- {
- const float transpose_width = 16.0f / data_size_from_type(data_type);
- const size_t shape_x = ws.x();
- ws.set(0, ws.y() * static_cast<unsigned int>(transpose_width));
- ws.set(1, static_cast<unsigned int>(std::ceil(shape_x / transpose_width)));
- }
}
// Create tensors
- Tensor src = create_tensor<Tensor>(src_shape, data_type, 1);
- Tensor weights = create_tensor<Tensor>(ws, data_type, 1);
- Tensor bias = create_tensor<Tensor>(bias_shape, data_type, 1);
- Tensor dst = create_tensor<Tensor>(dst_shape, data_type, 1);
+ Tensor src = create_tensor<Tensor>(src_shape, data_type, 1, quantization_info);
+ Tensor weights = create_tensor<Tensor>(ws, data_type, 1, quantization_info);
+ Tensor bias = create_tensor<Tensor>(bias_shape, bias_data_type, 1, quantization_info);
+ Tensor dst = create_tensor<Tensor>(dst_shape, data_type, 1, quantization_info);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -104,6 +100,9 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(frame
fc_info.transpose_weights = transpose_weights;
fc_info.are_weights_reshaped = !reshape_weights;
+ const QuantizationInfo src_quantization_info = src.info()->quantization_info();
+ const QuantizationInfo weights_quantization_info = weights.info()->quantization_info();
+
// Create and configure function.
NEFullyConnectedLayer fc;
fc.configure(&src, &weights, &bias, &dst, fc_info);
@@ -111,6 +110,10 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(frame
// Validate valid region
const ValidRegion dst_valid_region = shape_to_valid_region(dst_shape);
validate(dst.info()->valid_region(), dst_valid_region);
+
+ // Validate QuantizationInfo
+ ARM_COMPUTE_EXPECT(src.info()->quantization_info() == src_quantization_info, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(weights.info()->quantization_info() == weights_quantization_info, framework::LogLevel::ERRORS);
}
// *INDENT-OFF*
@@ -161,7 +164,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(
// *INDENT-ON*
template <typename T>
-using NEFullyConnectedLayerFixture = FullyConnectedLayerValidationFixture<Tensor, Accessor, NEFullyConnectedLayer, T, true>;
+using NEFullyConnectedLayerFixture = FullyConnectedLayerValidationFixture<Tensor, Accessor, NEFullyConnectedLayer, T>;
TEST_SUITE(Float)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -199,6 +202,32 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEFullyConnectedLayerFixture<float>, framework:
TEST_SUITE_END()
TEST_SUITE_END()
+template <typename T>
+using NEFullyConnectedLayerQuantizedFixture = FullyConnectedLayerValidationQuantizedFixture<Tensor, Accessor, NEFullyConnectedLayer, T>;
+
+TEST_SUITE(Quantized)
+TEST_SUITE(QASYMM8)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEFullyConnectedLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(
+ combine(datasets::SmallFullyConnectedLayerDataset(),
+ FullyConnectedParameters),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(1.f / 255.f, 10) })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, NEFullyConnectedLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(
+ combine(datasets::LargeFullyConnectedLayerDataset(),
+ FullyConnectedParameters),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(1.f / 256.f, 10) })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+TEST_SUITE_END()
+TEST_SUITE_END()
+
TEST_SUITE_END()
TEST_SUITE_END()
} // namespace validation