From 6259e5f9204abf31b811b1d002f68ce6504197bd Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Wed, 17 Jan 2018 17:29:33 +0000 Subject: COMPMID-787: Add CL support for broadcast multiply Change-Id: I71f67789648ef05ccdedce77c7427bc0127b3a69 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/116741 Tested-by: Jenkins Reviewed-by: Georgios Pinitas Reviewed-by: Anthony Barbier --- .../fixtures/PixelWiseMultiplicationFixture.h | 49 ++++++++++++++-------- 1 file changed, 32 insertions(+), 17 deletions(-) (limited to 'tests/validation/fixtures') diff --git a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h index 7428fb5cb7..b9f19f3e77 100644 --- a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h +++ b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -40,19 +40,20 @@ namespace test namespace validation { template -class PixelWiseMultiplicationValidationFixture : public framework::Fixture +class PixelWiseMultiplicationBroadcastValidationFixture : public framework::Fixture { public: template - void setup(TensorShape shape, - DataType dt_in1, - DataType dt_in2, - float scale, - ConvertPolicy convert_policy, - RoundingPolicy rounding_policy) + void setup(const TensorShape &shape0, + const TensorShape &shape1, + DataType dt_in1, + DataType dt_in2, + float scale, + ConvertPolicy convert_policy, + RoundingPolicy rounding_policy) { - _target = compute_target(shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy); - _reference = compute_reference(shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy); + _target = compute_target(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy); + _reference = compute_reference(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy); } protected: @@ -62,12 +63,13 @@ protected: library->fill_tensor_uniform(tensor, seed_offset); } - TensorType compute_target(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) + TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, + float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) { // Create tensors - TensorType src1 = create_tensor(shape, dt_in1); - TensorType src2 = create_tensor(shape, dt_in2); - TensorType dst = create_tensor(shape, dt_in2); + TensorType src1 = create_tensor(shape0, dt_in1); + TensorType src2 = create_tensor(shape1, dt_in2); + TensorType dst = create_tensor(TensorShape::broadcast_shape(shape0, shape1), dt_in2); // Create and configure function FunctionType multiply; @@ -96,11 +98,12 @@ protected: return dst; } - SimpleTensor compute_reference(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) + SimpleTensor compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, + float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) { // Create reference - SimpleTensor src1{ shape, dt_in1 }; - SimpleTensor src2{ shape, dt_in2 }; + SimpleTensor src1{ shape0, dt_in1 }; + SimpleTensor src2{ shape1, dt_in2 }; // Fill reference fill(src1, 0); @@ -112,6 +115,18 @@ protected: TensorType _target{}; SimpleTensor _reference{}; }; + +template +class PixelWiseMultiplicationValidationFixture : public PixelWiseMultiplicationBroadcastValidationFixture +{ +public: + template + void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) + { + PixelWiseMultiplicationBroadcastValidationFixture::setup(shape, shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy); + } +}; + } // namespace validation } // namespace test } // namespace arm_compute -- cgit v1.2.1