diff options
Diffstat (limited to 'tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h')
-rw-r--r-- | tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h | 41 |
1 files changed, 20 insertions, 21 deletions
diff --git a/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h b/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h index faed610874..b0680c0e4a 100644 --- a/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h +++ b/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h @@ -31,12 +31,9 @@ #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h" #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h" -#include "tests/CL/CLAccessor.h" #include "tests/framework/Fixture.h" #include "tests/framework/Macros.h" -#include "tests/validation/Validation.h" #include "tests/validation/reference/ElementwiseOperations.h" -#include "tests/validation/reference/Permute.h" using namespace arm_compute::experimental::dynamic_fusion; @@ -51,12 +48,13 @@ class DynamicFusionGpuElementwiseBinaryValidationGenericFixture : public framewo { public: template <typename...> - void setup(ArithmeticOperation op, TensorShape shape0, TensorShape shape1, TensorShape shape2, const DataType data_type, const bool is_inplace) + void setup(ArithmeticOperation ref_op, const TensorShape &shape0, const TensorShape &shape1, const TensorShape &shape2, DataType data_type, bool is_inplace, bool fuse_two_ops = false) { - _op = op; + _ref_op = ref_op; _is_inplace = is_inplace; _data_type = data_type; - _fuse = shape2.total_size() != 0; + _fuse = fuse_two_ops; + ARM_COMPUTE_ERROR_ON_MSG(_fuse && shape2.total_size() == 0, "No shape2 provided for fusion of two ops."); ARM_COMPUTE_ERROR_ON_MSG(_fuse && _is_inplace, "In place for fusing case not supported yet."); _target = compute_target(shape0, shape1, shape2); _reference = compute_reference(shape0, shape1, shape2); @@ -68,7 +66,7 @@ protected: { if(is_data_type_float(tensor.data_type())) { - switch(_op) + switch(_ref_op) { case ArithmeticOperation::DIV: library->fill_tensor_uniform_ranged(tensor, i, { std::pair<float, float>(-0.001f, 0.001f) }); @@ -82,7 +80,7 @@ protected: } else if(tensor.data_type() == DataType::S32) { - switch(_op) + switch(_ref_op) { case ArithmeticOperation::DIV: library->fill_tensor_uniform_ranged(tensor, i, { std::pair<int32_t, int32_t>(-1U, 1U) }); @@ -97,7 +95,7 @@ protected: } } - TensorType compute_target(TensorShape shape0, TensorShape shape1, TensorShape shape2) + TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, const TensorShape &shape2) { // Create a new workload sketch auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); @@ -105,7 +103,7 @@ protected: GpuWorkloadSketch sketch{ &gpu_ctx }; // Fuse first element wise binary Op - TensorInfo lhs_info = sketch.create_tensor_info(shape0, 1, _data_type); + TensorInfo lhs_info = sketch.create_tensor_info(TensorInfo(shape0, 1, _data_type)); TensorInfo rhs_info = sketch.create_tensor_info(TensorInfo(shape1, 1, _data_type)); TensorInfo dst_info = sketch.create_tensor_info(); @@ -115,7 +113,7 @@ protected: if(_fuse) { - rhs_info_fuse = sketch.create_tensor_info(shape2, 1, _data_type); + rhs_info_fuse = sketch.create_tensor_info(TensorInfo(shape2, 1, _data_type)); ITensorInfo *ans2_info = FunctionType::create_op(sketch, ans_info, &rhs_info_fuse); GpuOutput::create_op(sketch, ans2_info, &dst_info); } @@ -183,7 +181,7 @@ protected: return t_dst; } - SimpleTensor<T> compute_reference(TensorShape shape0, TensorShape shape1, TensorShape shape2) + SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, const TensorShape &shape2) { const TensorShape out_shape = TensorShape::broadcast_shape(shape0, shape1); const TensorShape out_shape_fuse = TensorShape::broadcast_shape(out_shape, shape1); @@ -194,21 +192,22 @@ protected: SimpleTensor<T> ref_rhs_fuse{ shape2, _data_type, 1, QuantizationInfo() }; SimpleTensor<T> ref_dst{ out_shape, _data_type, 1, QuantizationInfo() }; SimpleTensor<T> ref_dst_fuse{ out_shape_fuse, _data_type, 1, QuantizationInfo() }; + // Fill reference fill(ref_lhs, 0); fill(ref_rhs, 1); - reference::arithmetic_operation<T>(_op, ref_lhs, ref_rhs, ref_dst, ConvertPolicy::WRAP); + reference::arithmetic_operation<T>(_ref_op, ref_lhs, ref_rhs, ref_dst, ConvertPolicy::WRAP); if(_fuse) { fill(ref_rhs_fuse, 2); - reference::arithmetic_operation<T>(_op, ref_dst, ref_rhs_fuse, ref_dst_fuse, ConvertPolicy::WRAP); + reference::arithmetic_operation<T>(_ref_op, ref_dst, ref_rhs_fuse, ref_dst_fuse, ConvertPolicy::WRAP); } SimpleTensor<T> *ret = _fuse ? &ref_dst_fuse : &ref_dst; return *ret; } - ArithmeticOperation _op{ ArithmeticOperation::ADD }; + ArithmeticOperation _ref_op{ ArithmeticOperation::ADD }; TensorType _target{}; SimpleTensor<T> _reference{}; DataType _data_type{}; @@ -222,9 +221,9 @@ class DynamicFusionGpuElementwiseBinaryOneOpValidationFixture : public DynamicFu { public: template <typename...> - void setup(ArithmeticOperation op, TensorShape shape, const DataType data_type, const bool is_inplace) + void setup(ArithmeticOperation ref_op, const TensorShape &shape0, DataType data_type, bool is_inplace) { - DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(op, shape, shape, TensorShape(), data_type, is_inplace); + DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ref_op, shape0, shape0, TensorShape(), data_type, is_inplace); } }; @@ -233,9 +232,9 @@ class DynamicFusionGpuElementwiseBinaryBroadcastOneOpValidationFixture : public { public: template <typename...> - void setup(ArithmeticOperation op, TensorShape shape0, TensorShape shape1, const DataType data_type, const bool is_inplace) + void setup(ArithmeticOperation ref_op, const TensorShape &shape0, const TensorShape &shape1, DataType data_type, bool is_inplace) { - DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(op, shape0, shape1, TensorShape(), data_type, is_inplace); + DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ref_op, shape0, shape1, TensorShape(), data_type, is_inplace); } }; @@ -244,9 +243,9 @@ class DynamicFusionGpuElementwiseBinaryTwoOpsValidationFixture : public DynamicF { public: template <typename...> - void setup(ArithmeticOperation op, TensorShape shape0, TensorShape shape1, TensorShape shape2, const DataType data_type, const bool is_inplace) + void setup(ArithmeticOperation ref_op, const TensorShape &shape0, const TensorShape &shape1, const TensorShape &shape2, DataType data_type, bool is_inplace, bool fuse_two_ops) { - DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(op, shape0, shape1, shape2, data_type, is_inplace); + DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ref_op, shape0, shape1, shape2, data_type, is_inplace, fuse_two_ops); } }; |