Diffstat (limited to 'tests/validation/fixtures/PixelWiseMultiplicationFixture.h')
-rw-r--r--  tests/validation/fixtures/PixelWiseMultiplicationFixture.h | 54
1 file changed, 29 insertions(+), 25 deletions(-)
diff --git a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
index 7c643bd726..4345d8a13f 100644
--- a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
+++ b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -44,7 +44,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class PixelWiseMultiplicationGenericValidationFixture : public framework::Fixture
{
public:
- template <typename...>
void setup(const TensorShape &shape0,
const TensorShape &shape1,
DataType dt_in1,
@@ -76,9 +75,29 @@ protected:
QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, ActivationLayerInfo act_info)
{
// Create tensors
- TensorType src1 = create_tensor<TensorType>(shape0, dt_in1, 1, qinfo0);
- TensorType src2 = create_tensor<TensorType>(shape1, dt_in2, 1, qinfo1);
- TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), dt_out, 1, qinfo_out);
+ const TensorShape out_shape = TensorShape::broadcast_shape(shape0, shape1);
+ TensorType src1 = create_tensor<TensorType>(shape0, dt_in1, 1, qinfo0);
+ TensorType src2 = create_tensor<TensorType>(shape1, dt_in2, 1, qinfo1);
+ TensorType dst = create_tensor<TensorType>(out_shape, dt_out, 1, qinfo_out);
+
+ // Check whether to do in-place computation and whether the inputs are broadcast-compatible
+ TensorType *actual_dst = &dst;
+ if(_is_inplace)
+ {
+ bool src1_is_inplace = !arm_compute::detail::have_different_dimensions(out_shape, shape0, 0) && (qinfo0 == qinfo_out) && (dt_in1 == dt_out);
+ bool src2_is_inplace = !arm_compute::detail::have_different_dimensions(out_shape, shape1, 0) && (qinfo1 == qinfo_out) && (dt_in2 == dt_out);
+ bool do_in_place = out_shape.total_size() != 0 && (src1_is_inplace || src2_is_inplace);
+ ARM_COMPUTE_ASSERT(do_in_place);
+
+ if(src1_is_inplace)
+ {
+ actual_dst = &src1;
+ }
+ else
+ {
+ actual_dst = &src2;
+ }
+ }
auto allocate_tensor = [](TensorType & t)
{
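
The eligibility check introduced above encodes when an input tensor can double as the destination: its shape must already equal the broadcast output shape, and its data type and quantization info must match the output's. A minimal standalone sketch of the same rule, with hypothetical names that are not part of the library:

    #include <vector>
    #include <cstddef>

    // Hypothetical, simplified stand-in for a tensor's metadata.
    struct TensorMeta
    {
        std::vector<std::size_t> shape;
        int                      dtype;  // stand-in for arm_compute::DataType
        float                    qscale; // stand-in for QuantizationInfo
    };

    // An input may be reused as the destination only if its metadata
    // already matches the would-be output exactly.
    bool can_run_inplace(const TensorMeta &in, const TensorMeta &out)
    {
        return in.shape == out.shape && in.dtype == out.dtype && in.qscale == out.qscale;
    }
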
@@ -89,11 +108,12 @@ protected:
// Create and configure function
FunctionType multiply;
- multiply.configure(&src1, &src2, (_is_inplace ? &src1 : &dst), scale, convert_policy, rounding_policy, act_info);
+ multiply.configure(&src1, &src2, actual_dst, scale, convert_policy, rounding_policy, act_info);
allocate_tensor(src1);
allocate_tensor(src2);
+ // If not computing in place, we still need to allocate the original dst
if(!_is_inplace)
{
allocate_tensor(dst);
@@ -106,12 +126,7 @@ protected:
// Compute function
multiply.run();
- if(_is_inplace)
- {
- return src1;
- }
-
- return dst;
+ return std::move(*actual_dst);
}
SimpleTensor<T3> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, DataType dt_out,
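
Returning std::move(*actual_dst) moves the result out of whichever tensor actually holds it (src1, src2, or dst), so a potentially move-only TensorType is never copied, and the two duplicated return branches of the old code collapse into one. A trimmed illustration of the pattern, with placeholder types:

    #include <utility>

    // Placeholder move-only resource standing in for TensorType.
    struct Buffer
    {
        Buffer()               = default;
        Buffer(Buffer &&)      = default;
        Buffer(const Buffer &) = delete;
    };

    Buffer pick_result(bool inplace, Buffer &src, Buffer &dst)
    {
        // Select the buffer that holds the result, then move it out of
        // the function; no copy is ever made.
        Buffer *result = inplace ? &src : &dst;
        return std::move(*result);
    }
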
@@ -122,16 +137,12 @@ protected:
SimpleTensor<T1> src1{ shape0, dt_in1, 1, qinfo0 };
SimpleTensor<T2> src2{ shape1, dt_in2, 1, qinfo1 };
- // current in-place implementation only supports same metadata of input and output tensors.
- // By ignoring output quantization information here, we can make test cases implementation much simpler.
- QuantizationInfo output_qinfo = _is_inplace ? qinfo0 : qinfo_out;
-
// Fill reference
fill(src1, 0);
fill(src2, 1);
- auto result = reference::pixel_wise_multiplication<T1, T2, T3>(src1, src2, scale, convert_policy, rounding_policy, dt_out, output_qinfo);
- return act_info.enabled() ? reference::activation_layer(result, act_info, output_qinfo) : result;
+ auto result = reference::pixel_wise_multiplication<T1, T2, T3>(src1, src2, scale, convert_policy, rounding_policy, dt_out, qinfo_out);
+ return act_info.enabled() ? reference::activation_layer(result, act_info, qinfo_out) : result;
}
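
With the removed special-casing, the reference path now honors qinfo_out directly instead of substituting qinfo0 for in-place runs; the eligibility check above already guarantees the two are equal in that case. Conceptually, quantized pixel-wise multiplication dequantizes both inputs, multiplies with the scale, and requantizes with the output's parameters. A rough per-element sketch of that flow under simplified affine quantization and fixed round-to-nearest (this mirrors the idea, not the library's implementation or its configurable rounding policy):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Simplified affine (scale/offset) quantization parameters.
    struct QInfo
    {
        float scale;
        int   offset;
    };

    // Dequantize both operands, multiply with the overall scale, then
    // requantize against the output's parameters.
    std::uint8_t qmul(std::uint8_t a, QInfo qa, std::uint8_t b, QInfo qb, float scale, QInfo qo)
    {
        const float fa  = (static_cast<int>(a) - qa.offset) * qa.scale;
        const float fb  = (static_cast<int>(b) - qb.offset) * qb.scale;
        const float res = fa * fb * scale;
        const int   q   = static_cast<int>(std::lround(res / qo.scale)) + qo.offset;
        return static_cast<std::uint8_t>(std::clamp(q, 0, 255));
    }
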
TensorType _target{};
@@ -143,7 +154,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class PixelWiseMultiplicationValidationFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>
{
public:
- template <typename...>
void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, DataType dt_out, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, bool is_inplace)
{
PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>::setup(shape, shape, dt_in1, dt_in2, dt_out, scale, convert_policy, rounding_policy,
@@ -155,7 +165,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class PixelWiseMultiplicationBroadcastValidationFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>
{
public:
- template <typename...>
void setup(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, DataType dt_out, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
bool is_inplace)
{
@@ -168,7 +177,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class PixelWiseMultiplicationValidationFloatFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>
{
public:
- template <typename...>
void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, ActivationLayerInfo act_info, bool is_inplace)
{
PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, dt_in2, scale, convert_policy, rounding_policy,
@@ -180,7 +188,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class PixelWiseMultiplicationValidationIntegerFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>
{
public:
- template <typename...>
void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, ActivationLayerInfo act_info, bool is_inplace)
{
PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, dt_in2, scale, convert_policy, rounding_policy,
@@ -192,7 +199,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class PixelWiseMultiplicationBroadcastValidationFloatFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>
{
public:
- template <typename...>
void setup(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
ActivationLayerInfo act_info, bool is_inplace)
{
@@ -205,7 +211,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class PixelWiseMultiplicationValidationQuantizedFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>
{
public:
- template <typename...>
void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, DataType dt_out, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool is_inplace)
{
@@ -218,7 +223,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class PixelWiseMultiplicationBroadcastValidationQuantizedFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>
{
public:
- template <typename...>
void setup(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, DataType dt_out, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool is_inplace)
{