aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h')
-rw-r--r--tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h41
1 files changed, 20 insertions, 21 deletions
diff --git a/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h b/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h
index faed610874..b0680c0e4a 100644
--- a/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h
@@ -31,12 +31,9 @@
#include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h"
-#include "tests/CL/CLAccessor.h"
#include "tests/framework/Fixture.h"
#include "tests/framework/Macros.h"
-#include "tests/validation/Validation.h"
#include "tests/validation/reference/ElementwiseOperations.h"
-#include "tests/validation/reference/Permute.h"
using namespace arm_compute::experimental::dynamic_fusion;
@@ -51,12 +48,13 @@ class DynamicFusionGpuElementwiseBinaryValidationGenericFixture : public framewo
{
public:
template <typename...>
- void setup(ArithmeticOperation op, TensorShape shape0, TensorShape shape1, TensorShape shape2, const DataType data_type, const bool is_inplace)
+ void setup(ArithmeticOperation ref_op, const TensorShape &shape0, const TensorShape &shape1, const TensorShape &shape2, DataType data_type, bool is_inplace, bool fuse_two_ops = false)
{
- _op = op;
+ _ref_op = ref_op;
_is_inplace = is_inplace;
_data_type = data_type;
- _fuse = shape2.total_size() != 0;
+ _fuse = fuse_two_ops;
+ ARM_COMPUTE_ERROR_ON_MSG(_fuse && shape2.total_size() == 0, "No shape2 provided for fusion of two ops.");
ARM_COMPUTE_ERROR_ON_MSG(_fuse && _is_inplace, "In place for fusing case not supported yet.");
_target = compute_target(shape0, shape1, shape2);
_reference = compute_reference(shape0, shape1, shape2);
@@ -68,7 +66,7 @@ protected:
{
if(is_data_type_float(tensor.data_type()))
{
- switch(_op)
+ switch(_ref_op)
{
case ArithmeticOperation::DIV:
library->fill_tensor_uniform_ranged(tensor, i, { std::pair<float, float>(-0.001f, 0.001f) });
@@ -82,7 +80,7 @@ protected:
}
else if(tensor.data_type() == DataType::S32)
{
- switch(_op)
+ switch(_ref_op)
{
case ArithmeticOperation::DIV:
library->fill_tensor_uniform_ranged(tensor, i, { std::pair<int32_t, int32_t>(-1U, 1U) });
@@ -97,7 +95,7 @@ protected:
}
}
- TensorType compute_target(TensorShape shape0, TensorShape shape1, TensorShape shape2)
+ TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, const TensorShape &shape2)
{
// Create a new workload sketch
auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
@@ -105,7 +103,7 @@ protected:
GpuWorkloadSketch sketch{ &gpu_ctx };
// Fuse first element wise binary Op
- TensorInfo lhs_info = sketch.create_tensor_info(shape0, 1, _data_type);
+ TensorInfo lhs_info = sketch.create_tensor_info(TensorInfo(shape0, 1, _data_type));
TensorInfo rhs_info = sketch.create_tensor_info(TensorInfo(shape1, 1, _data_type));
TensorInfo dst_info = sketch.create_tensor_info();
@@ -115,7 +113,7 @@ protected:
if(_fuse)
{
- rhs_info_fuse = sketch.create_tensor_info(shape2, 1, _data_type);
+ rhs_info_fuse = sketch.create_tensor_info(TensorInfo(shape2, 1, _data_type));
ITensorInfo *ans2_info = FunctionType::create_op(sketch, ans_info, &rhs_info_fuse);
GpuOutput::create_op(sketch, ans2_info, &dst_info);
}
@@ -183,7 +181,7 @@ protected:
return t_dst;
}
- SimpleTensor<T> compute_reference(TensorShape shape0, TensorShape shape1, TensorShape shape2)
+ SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, const TensorShape &shape2)
{
const TensorShape out_shape = TensorShape::broadcast_shape(shape0, shape1);
const TensorShape out_shape_fuse = TensorShape::broadcast_shape(out_shape, shape1);
@@ -194,21 +192,22 @@ protected:
SimpleTensor<T> ref_rhs_fuse{ shape2, _data_type, 1, QuantizationInfo() };
SimpleTensor<T> ref_dst{ out_shape, _data_type, 1, QuantizationInfo() };
SimpleTensor<T> ref_dst_fuse{ out_shape_fuse, _data_type, 1, QuantizationInfo() };
+
// Fill reference
fill(ref_lhs, 0);
fill(ref_rhs, 1);
- reference::arithmetic_operation<T>(_op, ref_lhs, ref_rhs, ref_dst, ConvertPolicy::WRAP);
+ reference::arithmetic_operation<T>(_ref_op, ref_lhs, ref_rhs, ref_dst, ConvertPolicy::WRAP);
if(_fuse)
{
fill(ref_rhs_fuse, 2);
- reference::arithmetic_operation<T>(_op, ref_dst, ref_rhs_fuse, ref_dst_fuse, ConvertPolicy::WRAP);
+ reference::arithmetic_operation<T>(_ref_op, ref_dst, ref_rhs_fuse, ref_dst_fuse, ConvertPolicy::WRAP);
}
SimpleTensor<T> *ret = _fuse ? &ref_dst_fuse : &ref_dst;
return *ret;
}
- ArithmeticOperation _op{ ArithmeticOperation::ADD };
+ ArithmeticOperation _ref_op{ ArithmeticOperation::ADD };
TensorType _target{};
SimpleTensor<T> _reference{};
DataType _data_type{};
@@ -222,9 +221,9 @@ class DynamicFusionGpuElementwiseBinaryOneOpValidationFixture : public DynamicFu
{
public:
template <typename...>
- void setup(ArithmeticOperation op, TensorShape shape, const DataType data_type, const bool is_inplace)
+ void setup(ArithmeticOperation ref_op, const TensorShape &shape0, DataType data_type, bool is_inplace)
{
- DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(op, shape, shape, TensorShape(), data_type, is_inplace);
+ DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ref_op, shape0, shape0, TensorShape(), data_type, is_inplace);
}
};
@@ -233,9 +232,9 @@ class DynamicFusionGpuElementwiseBinaryBroadcastOneOpValidationFixture : public
{
public:
template <typename...>
- void setup(ArithmeticOperation op, TensorShape shape0, TensorShape shape1, const DataType data_type, const bool is_inplace)
+ void setup(ArithmeticOperation ref_op, const TensorShape &shape0, const TensorShape &shape1, DataType data_type, bool is_inplace)
{
- DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(op, shape0, shape1, TensorShape(), data_type, is_inplace);
+ DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ref_op, shape0, shape1, TensorShape(), data_type, is_inplace);
}
};
@@ -244,9 +243,9 @@ class DynamicFusionGpuElementwiseBinaryTwoOpsValidationFixture : public DynamicF
{
public:
template <typename...>
- void setup(ArithmeticOperation op, TensorShape shape0, TensorShape shape1, TensorShape shape2, const DataType data_type, const bool is_inplace)
+ void setup(ArithmeticOperation ref_op, const TensorShape &shape0, const TensorShape &shape1, const TensorShape &shape2, DataType data_type, bool is_inplace, bool fuse_two_ops)
{
- DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(op, shape0, shape1, shape2, data_type, is_inplace);
+ DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ref_op, shape0, shape1, shape2, data_type, is_inplace, fuse_two_ops);
}
};