aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-09-18 14:34:48 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commitbf28a3cde6f77cbac3e3409d5597085ccbc71499 (patch)
tree724a88e6e4dd05e5ffb5d95407e2a292b7b0eb40 /tests/validation/fixtures
parent13a51e11680aa24a9b841a4afe4079419bc8b80c (diff)
downloadComputeLibrary-bf28a3cde6f77cbac3e3409d5597085ccbc71499.tar.gz
COMPMID-1564: Add QASYMM8 on CLPixelwiseMultiplication
Change-Id: I5f719f5b2915c18cd0ca6271db401152112863a6 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/148982 Tested-by: bsgcomp <bsgcomp@arm.com> Reviewed-by: Isabella Gottardi <isabella.gottardi@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r--tests/validation/fixtures/PixelWiseMultiplicationFixture.h58
1 files changed, 44 insertions, 14 deletions
diff --git a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
index b9f19f3e77..9927b75032 100644
--- a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
+++ b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
@@ -40,7 +40,7 @@ namespace test
namespace validation
{
template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
-class PixelWiseMultiplicationBroadcastValidationFixture : public framework::Fixture
+class PixelWiseMultiplicationGenericValidationFixture : public framework::Fixture
{
public:
template <typename...>
@@ -50,10 +50,13 @@ public:
DataType dt_in2,
float scale,
ConvertPolicy convert_policy,
- RoundingPolicy rounding_policy)
+ RoundingPolicy rounding_policy,
+ QuantizationInfo qinfo0,
+ QuantizationInfo qinfo1,
+ QuantizationInfo qinfo_out)
{
- _target = compute_target(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy);
- _reference = compute_reference(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy);
+ _target = compute_target(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out);
+ _reference = compute_reference(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out);
}
protected:
@@ -64,12 +67,13 @@ protected:
}
TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2,
- float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
+ float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
+ QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
{
// Create tensors
- TensorType src1 = create_tensor<TensorType>(shape0, dt_in1);
- TensorType src2 = create_tensor<TensorType>(shape1, dt_in2);
- TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), dt_in2);
+ TensorType src1 = create_tensor<TensorType>(shape0, dt_in1, 1, qinfo0);
+ TensorType src2 = create_tensor<TensorType>(shape1, dt_in2, 1, qinfo1);
+ TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), dt_in2, 1, qinfo_out);
// Create and configure function
FunctionType multiply;
@@ -99,17 +103,18 @@ protected:
}
SimpleTensor<T2> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2,
- float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
+ float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
+ QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
{
// Create reference
- SimpleTensor<T1> src1{ shape0, dt_in1 };
- SimpleTensor<T2> src2{ shape1, dt_in2 };
+ SimpleTensor<T1> src1{ shape0, dt_in1, 1, qinfo0 };
+ SimpleTensor<T2> src2{ shape1, dt_in2, 1, qinfo1 };
// Fill reference
fill(src1, 0);
fill(src2, 1);
- return reference::pixel_wise_multiplication<T1, T2>(src1, src2, scale, convert_policy, rounding_policy);
+ return reference::pixel_wise_multiplication<T1, T2>(src1, src2, scale, convert_policy, rounding_policy, qinfo_out);
}
TensorType _target{};
@@ -117,16 +122,41 @@ protected:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
-class PixelWiseMultiplicationValidationFixture : public PixelWiseMultiplicationBroadcastValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>
+class PixelWiseMultiplicationValidationFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>
{
public:
template <typename...>
void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
{
- PixelWiseMultiplicationBroadcastValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy);
+ PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy,
+ QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
}
};
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
+class PixelWiseMultiplicationBroadcastValidationFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>
+{
+public:
+ template <typename...>
+ void setup(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
+ {
+ PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy,
+ QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
+class PixelWiseMultiplicationValidationQuantizedFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>
+{
+public:
+ template <typename...>
+ void setup(const TensorShape &shape, DataType dt, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
+ QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
+ {
+ PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt, dt, scale, convert_policy, rounding_policy,
+ qinfo0, qinfo1, qinfo_out);
+ }
+};
} // namespace validation
} // namespace test
} // namespace arm_compute