aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ElementwiseUnaryFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ElementwiseUnaryFixture.h')
-rw-r--r--tests/validation/fixtures/ElementwiseUnaryFixture.h44
1 files changed, 31 insertions, 13 deletions
diff --git a/tests/validation/fixtures/ElementwiseUnaryFixture.h b/tests/validation/fixtures/ElementwiseUnaryFixture.h
index 7221226fd1..1dc4f03e99 100644
--- a/tests/validation/fixtures/ElementwiseUnaryFixture.h
+++ b/tests/validation/fixtures/ElementwiseUnaryFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -44,11 +44,12 @@ class ElementWiseUnaryValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape input_shape, DataType input_data_type, bool in_place, ElementWiseUnary op, bool use_dynamic_shape = false)
+ void setup(TensorShape input_shape, DataType input_data_type, bool in_place, ElementWiseUnary op,
+ bool use_dynamic_shape = false, QuantizationInfo qinfo = QuantizationInfo(), QuantizationInfo qinfo_out = QuantizationInfo())
{
_op = op;
- _target = compute_target(input_shape, input_data_type, in_place);
- _reference = compute_reference(input_shape, input_data_type);
+ _target = compute_target(input_shape, input_data_type, in_place, qinfo, qinfo_out);
+ _reference = compute_reference(input_shape, input_data_type, qinfo, qinfo_out);
_use_dynamic_shape = use_dynamic_shape;
}
@@ -69,8 +70,15 @@ protected:
}
case ElementWiseUnary::RSQRT:
{
- FloatDistributionType distribution{ FloatType(1.0f), FloatType(2.0f) };
- library->fill(tensor, distribution, i);
+ if(data_type == DataType::F32 || data_type == DataType::F16)
+ {
+ FloatDistributionType distribution{ FloatType(1.0f), FloatType(2.0f) };
+ library->fill(tensor, distribution, i);
+ }
+ else
+ {
+ library->fill_tensor_uniform(tensor, i);
+ }
break;
}
case ElementWiseUnary::ABS:
@@ -124,12 +132,11 @@ protected:
}
}
- TensorType compute_target(const TensorShape &shape, DataType data_type, bool in_place)
+ TensorType compute_target(const TensorShape &shape, DataType data_type, bool in_place, QuantizationInfo qinfo, QuantizationInfo qinfo_out)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(shape, data_type);
- TensorType dst = create_tensor<TensorType>(shape, data_type);
-
+ TensorType src = create_tensor<TensorType>(shape, data_type, 1, qinfo);
+ TensorType dst = create_tensor<TensorType>(shape, data_type, 1, qinfo_out);
TensorType *actual_dst = in_place ? &src : &dst;
// if _use_dynamic_shape is true, this fixture will test scenario for dynamic shapes.
@@ -176,15 +183,16 @@ protected:
}
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type)
+ SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo, QuantizationInfo qinfo_out)
{
// Create reference
- SimpleTensor<T> src{ shape, data_type };
+ SimpleTensor<T> src{ shape, data_type, 1, qinfo };
+ SimpleTensor<T> dst{ shape, data_type, 1, qinfo_out };
// Fill reference
fill(src, 0, data_type);
- return reference::elementwise_unary<T>(src, _op);
+ return reference::elementwise_unary<T>(src, dst, _op);
}
TensorType _target{};
@@ -192,6 +200,16 @@ protected:
ElementWiseUnary _op{};
bool _use_dynamic_shape{ false };
};
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class RsqrtQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo, QuantizationInfo qinfo_out)
+ {
+ ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::RSQRT, false, qinfo, qinfo_out);
+ }
+};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class RsqrtValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>