aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ArithmeticOperationsFixture.h
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2020-06-04 15:05:38 +0100
committerMichele Di Giorgio <michele.digiorgio@arm.com>2020-06-15 13:59:04 +0000
commit4a61653202afb018f4f259d3c144a735d73f0a20 (patch)
tree082fd42e91cc0914dcacc0746bbe3e117d74210c /tests/validation/fixtures/ArithmeticOperationsFixture.h
parentccd94966cc58ef5148577e71ba1a4ff5aae1f3bb (diff)
downloadComputeLibrary-4a61653202afb018f4f259d3c144a735d73f0a20.tar.gz
COMPMID-3480: Perform in-place computations in NEArithmeticAdditionKernel
Change-Id: I0089657dd95d7c7b8592984def8e8de1d7e6d085 Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3308 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures/ArithmeticOperationsFixture.h')
-rw-r--r--tests/validation/fixtures/ArithmeticOperationsFixture.h50
1 files changed, 33 insertions, 17 deletions
diff --git a/tests/validation/fixtures/ArithmeticOperationsFixture.h b/tests/validation/fixtures/ArithmeticOperationsFixture.h
index 4a6b0bd3f3..1019e60233 100644
--- a/tests/validation/fixtures/ArithmeticOperationsFixture.h
+++ b/tests/validation/fixtures/ArithmeticOperationsFixture.h
@@ -48,8 +48,10 @@ public:
template <typename...>
void setup(reference::ArithmeticOperation op, const TensorShape &shape0, const TensorShape &shape1,
DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy,
- QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, ActivationLayerInfo act_info)
+ QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, ActivationLayerInfo act_info,
+ bool in_place)
{
+ _in_place = in_place;
_op = op;
_act_info = act_info;
_target = compute_target(shape0, shape1, data_type0, data_type1, output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out);
@@ -71,9 +73,11 @@ protected:
TensorType ref_src2 = create_tensor<TensorType>(shape1, data_type1, 1, qinfo1);
TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, qinfo_out);
+ TensorType *dst_ptr = _in_place ? nullptr : &dst;
+
// Create and configure function
FunctionType arith_op;
- arith_op.configure(&ref_src1, &ref_src2, &dst, convert_policy, _act_info);
+ arith_op.configure(&ref_src1, &ref_src2, dst_ptr, convert_policy, _act_info);
ARM_COMPUTE_EXPECT(ref_src1.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(ref_src2.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -82,11 +86,15 @@ protected:
// Allocate tensors
ref_src1.allocator()->allocate();
ref_src2.allocator()->allocate();
- dst.allocator()->allocate();
ARM_COMPUTE_EXPECT(!ref_src1.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!ref_src2.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+ if(!_in_place)
+ {
+ dst.allocator()->allocate();
+ ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+ }
// Fill tensors
fill(AccessorType(ref_src1), 0);
@@ -95,7 +103,14 @@ protected:
// Compute function
arith_op.run();
- return dst;
+ if(_in_place)
+ {
+ return ref_src1;
+ }
+ else
+ {
+ return dst;
+ }
}
SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1,
@@ -119,6 +134,7 @@ protected:
SimpleTensor<T> _reference{};
reference::ArithmeticOperation _op{ reference::ArithmeticOperation::ADD };
ActivationLayerInfo _act_info{};
+ bool _in_place{};
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
@@ -126,10 +142,10 @@ class ArithmeticAdditionBroadcastValidationFixture : public ArithmeticOperationG
{
public:
template <typename...>
- void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy)
+ void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, bool in_place)
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type0, data_type1,
- output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo());
+ output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), in_place);
}
};
@@ -138,10 +154,10 @@ class ArithmeticAdditionValidationFixture : public ArithmeticOperationGenericFix
{
public:
template <typename...>
- void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy)
+ void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, bool in_place)
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1,
- output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo());
+ output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), in_place);
}
};
@@ -153,7 +169,7 @@ public:
void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info)
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type0, data_type1,
- output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
+ output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false);
}
};
@@ -165,7 +181,7 @@ public:
void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info)
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1,
- output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
+ output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false);
}
};
@@ -179,7 +195,7 @@ public:
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1,
- output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo());
+ output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), false);
}
};
@@ -192,7 +208,7 @@ public:
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1,
data_type0, data_type1, output_data_type, convert_policy,
- QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo());
+ QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false);
}
};
@@ -205,7 +221,7 @@ public:
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1,
data_type0, data_type1, output_data_type, convert_policy,
- QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
+ QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false);
}
};
@@ -218,7 +234,7 @@ public:
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape,
data_type0, data_type1, output_data_type, convert_policy,
- QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo());
+ QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false);
}
};
@@ -231,7 +247,7 @@ public:
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape,
data_type0, data_type1, output_data_type, convert_policy,
- QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
+ QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false);
}
};
@@ -246,7 +262,7 @@ public:
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape,
data_type0, data_type1, output_data_type,
- convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo());
+ convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), false);
}
};
} // namespace validation