diff options
Diffstat (limited to 'tests/validation/NEON/PixelWiseMultiplication.cpp')
-rw-r--r-- | tests/validation/NEON/PixelWiseMultiplication.cpp | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/tests/validation/NEON/PixelWiseMultiplication.cpp b/tests/validation/NEON/PixelWiseMultiplication.cpp index a66f6f192f..1bb0588919 100644 --- a/tests/validation/NEON/PixelWiseMultiplication.cpp +++ b/tests/validation/NEON/PixelWiseMultiplication.cpp @@ -118,7 +118,7 @@ template <typename T> using NEPixelWiseMultiplicationToF32Fixture = PixelWiseMultiplicationValidationFixture<Tensor, Accessor, NEPixelWiseMultiplication, T, float>; using NEPixelWiseMultiplicationU8U8ToS16Fixture = PixelWiseMultiplicationValidationFixture<Tensor, Accessor, NEPixelWiseMultiplication, uint8_t, uint8_t, int16_t>; template <typename T> -using NEPixelWiseMultiplicationBroadcastFixture = PixelWiseMultiplicationBroadcastValidationFixture<Tensor, Accessor, NEPixelWiseMultiplication, T, float>; +using NEPixelWiseMultiplicationBroadcastFixture = PixelWiseMultiplicationBroadcastValidationFixture<Tensor, Accessor, NEPixelWiseMultiplication, T, T>; using NEPixelWiseMultiplicationBroadcastQASYMM8Fixture = PixelWiseMultiplicationBroadcastValidationQuantizedFixture<Tensor, Accessor, NEPixelWiseMultiplication, uint8_t, uint8_t>; using NEPixelWiseMultiplicationBroadcastQASYMM8SignedFixture = PixelWiseMultiplicationBroadcastValidationQuantizedFixture<Tensor, Accessor, NEPixelWiseMultiplication, int8_t, int8_t>; @@ -493,6 +493,11 @@ TEST_SUITE(ScaleOther) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS32Fixture<int32_t>, ALL, SmallShapes(), S32, S32, S32, scale_other, TO_ZERO, InPlaceDataSet, WRAP_VALIDATE(int32_t, 1)) TEST_SUITE_END() // ScaleOther +TEST_SUITE(Broadcast) +PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, BroadcastFixture<int32_t>, ALL, SmallShapesBroadcast(), S32, S32, S32, scale_unity, TO_ZERO, framework::dataset::make("InPlace", { false }), + WRAP_VALIDATE(int32_t, 1)) +TEST_SUITE_END() // Broadcast + TEST_SUITE_END() // S32toS32 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC |