aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/validation/CL/PixelWiseMultiplication.cpp76
-rw-r--r--tests/validation/NEON/PixelWiseMultiplication.cpp147
-rw-r--r--tests/validation/fixtures/LSTMLayerFixture.h22
-rw-r--r--tests/validation/fixtures/PixelWiseMultiplicationFixture.h46
-rw-r--r--tests/validation/reference/PixelWiseMultiplication.cpp57
-rw-r--r--tests/validation/reference/PixelWiseMultiplication.h9
-rw-r--r--tests/validation/reference/QLSTMLayerNormalization.cpp2
7 files changed, 174 insertions, 185 deletions
diff --git a/tests/validation/CL/PixelWiseMultiplication.cpp b/tests/validation/CL/PixelWiseMultiplication.cpp
index 3b55e25f37..ff9101a997 100644
--- a/tests/validation/CL/PixelWiseMultiplication.cpp
+++ b/tests/validation/CL/PixelWiseMultiplication.cpp
@@ -127,14 +127,17 @@ using CLPixelWiseMultiplicationQuantizedFixture = PixelWiseMultiplicationValidat
TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8)),
- framework::dataset::make("Scale", { 1.f, 2.f })),
- framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
- framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)),
- framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })),
- framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })),
- framework::dataset::make("OUtQInfo", { QuantizationInfo(1.f / 255.f, 5) })))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8)),
+ framework::dataset::make("Scale", { 1.f, 2.f })),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
+ framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)),
+ framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })),
+ framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })),
+ framework::dataset::make("OUtQInfo", { QuantizationInfo(1.f / 255.f, 5) })))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
@@ -142,14 +145,17 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<uint8
TEST_SUITE_END() // QASYMM8
TEST_SUITE(QASYMM8_SIGNED)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
- framework::dataset::make("Scale", { 1.f, 2.f })),
- framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
- framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)),
- framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })),
- framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })),
- framework::dataset::make("OUtQInfo", { QuantizationInfo(1.f / 255.f, 5) })))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("Scale", { 1.f, 2.f })),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
+ framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)),
+ framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })),
+ framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })),
+ framework::dataset::make("OUtQInfo", { QuantizationInfo(1.f / 255.f, 5) })))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
@@ -157,26 +163,32 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<int8_
TEST_SUITE_END() // QASYMM8_SIGNED
TEST_SUITE(QSYMM16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QSYMM16)),
- framework::dataset::make("Scale", { 1.f, 2.f })),
- framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
- framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)),
- framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0) })),
- framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0) })),
- framework::dataset::make("OutQInfo", { QuantizationInfo(5.f / 32768.f, 0) })))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<int16_t>, framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::QSYMM16)),
+ framework::dataset::make("Scale", { 1.f, 2.f })),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
+ framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)),
+ framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0) })),
+ framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0) })),
+ framework::dataset::make("OutQInfo", { QuantizationInfo(5.f / 32768.f, 0) })))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_qsymm16);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLPixelWiseMultiplicationQuantizedFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::QSYMM16)),
- framework::dataset::make("Scale", { 1.f, 2.f })),
- framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
- framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)),
- framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0) })),
- framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0) })),
- framework::dataset::make("OutQInfo", { QuantizationInfo(5.f / 32768.f, 0) })))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLPixelWiseMultiplicationQuantizedFixture<int16_t>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(datasets::LargeShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::QSYMM16)),
+ framework::dataset::make("Scale", { 1.f, 2.f })),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
+ framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)),
+ framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0) })),
+ framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0) })),
+ framework::dataset::make("OutQInfo", { QuantizationInfo(5.f / 32768.f, 0) })))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_qsymm16);
diff --git a/tests/validation/NEON/PixelWiseMultiplication.cpp b/tests/validation/NEON/PixelWiseMultiplication.cpp
index fd54e42083..6a75b00b9b 100644
--- a/tests/validation/NEON/PixelWiseMultiplication.cpp
+++ b/tests/validation/NEON/PixelWiseMultiplication.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -70,20 +70,6 @@ const auto PixelWiseMultiplicationPolicySTZDataset = combine(
// *INDENT-OFF*
// clang-format off
-#define PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(DT1, DT2, SCALE, RP) \
- DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, \
- combine(combine(combine(combine(combine( \
- concat(datasets::SmallShapes(), datasets::LargeShapes()), \
- framework::dataset::make("DataType1", DataType::DT1)), \
- framework::dataset::make("DataType2", DataType::DT2)), \
- framework::dataset::make("Scale", std::move(SCALE))), \
- datasets::ConvertPolicies()), \
- framework::dataset::make("RoundingPolicy", RoundingPolicy::RP)), \
- shape, dt1, dt2, scale, convert_policy, rounding_policy) \
- { \
- validate_configuration(shape, dt1, dt2, scale, convert_policy, rounding_policy); \
- }
-
#define PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(TEST_NAME, FIXTURE, MODE, SHAPES, DT1, DT2, SCALE, RP, VALIDATE) \
FIXTURE_DATA_TEST_CASE(TEST_NAME, NEPixelWiseMultiplication##FIXTURE, framework::DatasetMode::MODE, \
combine(combine(combine(combine(combine( \
@@ -99,38 +85,12 @@ const auto PixelWiseMultiplicationPolicySTZDataset = combine(
// *INDENT-ON*
// clang-format on
-
-void validate_configuration(TensorShape shape, DataType dt1, DataType dt2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
-{
- Tensor src1 = create_tensor<Tensor>(shape, dt1);
- Tensor src2 = create_tensor<Tensor>(shape, dt2);
- Tensor dst = create_tensor<Tensor>(shape, dt2);
-
- ARM_COMPUTE_EXPECT(src1.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(src2.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
-
- // Create and configure function
- NEPixelWiseMultiplication multiply;
- multiply.configure(&src1, &src2, &dst, scale, convert_policy, rounding_policy);
-
- // Validate valid region
- const ValidRegion valid_region = shape_to_valid_region(shape);
- validate(src1.info()->valid_region(), valid_region);
- validate(src2.info()->valid_region(), valid_region);
- validate(dst.info()->valid_region(), valid_region);
-
- // Validate padding
- const PaddingSize padding = PaddingCalculator(shape.x(), 16).required_padding();
- validate(src1.info()->padding(), padding);
- validate(src2.info()->padding(), padding);
- validate(dst.info()->padding(), padding);
-}
} // namespace
using NEPixelWiseMultiplicationQASYMM8Fixture = PixelWiseMultiplicationValidationQuantizedFixture<Tensor, Accessor, NEPixelWiseMultiplication, uint8_t, uint8_t>;
using NEPixelWiseMultiplicationQASYMM8SignedFixture = PixelWiseMultiplicationValidationQuantizedFixture<Tensor, Accessor, NEPixelWiseMultiplication, int8_t, int8_t>;
using NEPixelWiseMultiplicationQSYMM16Fixture = PixelWiseMultiplicationValidationQuantizedFixture<Tensor, Accessor, NEPixelWiseMultiplication, int16_t, int16_t>;
+using NEPixelWiseMultiplicationQSYMM16ToS32Fixture = PixelWiseMultiplicationValidationQuantizedFixture<Tensor, Accessor, NEPixelWiseMultiplication, int16_t, int16_t, int32_t>;
template <typename T>
using NEPixelWiseMultiplicationToU8Fixture = PixelWiseMultiplicationValidationFixture<Tensor, Accessor, NEPixelWiseMultiplication, T, uint8_t>;
template <typename T>
@@ -231,8 +191,10 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(
TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8_SIGNED)
TEST_SUITE(Scale255)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8SignedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8SignedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8_SIGNED)),
framework::dataset::make("Scale", { scale_unity })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQASYMM8QuantDataset))
@@ -245,8 +207,10 @@ TEST_SUITE_END() // QASYMM8
TEST_SUITE(QASYMM8)
TEST_SUITE(Scale255)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8)),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8)),
framework::dataset::make("Scale", { scale_255 })),
PixelWiseMultiplicationPolicySTNUDataset),
PixelWiseMultiplicationQASYMM8QuantDataset))
@@ -254,8 +218,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framew
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8)),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8)),
framework::dataset::make("Scale", { scale_255 })),
PixelWiseMultiplicationPolicySTNUDataset),
PixelWiseMultiplicationQASYMM8QuantDataset))
@@ -265,8 +231,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framew
}
TEST_SUITE_END() // Scale255
TEST_SUITE(ScaleUnity)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8)),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8)),
framework::dataset::make("Scale", { scale_unity })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQASYMM8QuantDataset))
@@ -274,8 +242,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framew
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8)),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8)),
framework::dataset::make("Scale", { scale_unity })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQASYMM8QuantDataset))
@@ -285,8 +255,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framew
}
TEST_SUITE_END() // ScaleUnity
TEST_SUITE(ScaleOther)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8)),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8)),
framework::dataset::make("Scale", { scale_other })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQASYMM8QuantDataset))
@@ -294,8 +266,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framew
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8)),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8)),
framework::dataset::make("Scale", { scale_other })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQASYMM8QuantDataset))
@@ -307,8 +281,10 @@ TEST_SUITE_END() // ScaleOther
TEST_SUITE_END() // QASYMM8
TEST_SUITE(QSYMM16)
TEST_SUITE(Scale255)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QSYMM16)),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::QSYMM16)),
framework::dataset::make("Scale", { scale_255 })),
PixelWiseMultiplicationPolicySTNUDataset),
PixelWiseMultiplicationQSYMM16QuantDataset))
@@ -316,8 +292,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framew
// Validate output
validate(Accessor(_target), _reference, tolerance_qsymm16);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::QSYMM16)),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::QSYMM16)),
framework::dataset::make("Scale", { scale_255 })),
PixelWiseMultiplicationPolicySTNUDataset),
PixelWiseMultiplicationQSYMM16QuantDataset))
@@ -327,8 +305,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framew
}
TEST_SUITE_END() // Scale255
TEST_SUITE(ScaleUnity)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QSYMM16)),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::QSYMM16)),
framework::dataset::make("Scale", { scale_unity })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQSYMM16QuantDataset))
@@ -336,8 +316,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framew
// Validate output
validate(Accessor(_target), _reference, tolerance_qsymm16);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::QSYMM16)),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::QSYMM16)),
framework::dataset::make("Scale", { scale_unity })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQSYMM16QuantDataset))
@@ -347,8 +329,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framew
}
TEST_SUITE_END() // ScaleUnity
TEST_SUITE(ScaleOther)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QSYMM16)),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::QSYMM16)),
framework::dataset::make("Scale", { scale_other })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQSYMM16QuantDataset))
@@ -356,8 +340,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framew
// Validate output
validate(Accessor(_target), _reference, tolerance_qsymm16);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::QSYMM16)),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::QSYMM16)),
framework::dataset::make("Scale", { scale_other })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQSYMM16QuantDataset))
@@ -367,24 +353,34 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framew
}
TEST_SUITE_END() // ScaleOther
TEST_SUITE_END() // QSYMM16
+TEST_SUITE(QSYMM16toS32)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16ToS32Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::S32)),
+ framework::dataset::make("Scale", { scale_unity })),
+ PixelWiseMultiplicationPolicySTZDataset),
+ PixelWiseMultiplicationQSYMM16QuantDataset))
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END() // QSYMM16toS32
TEST_SUITE_END() // Quantized
TEST_SUITE(U8toU8)
TEST_SUITE(Scale255)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, U8, scale_255, TO_NEAREST_UP)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToU8Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, U8, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(uint8_t, 1))
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToU8Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, U8, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(uint8_t, 1))
TEST_SUITE_END() // Scale255
TEST_SUITE(ScaleUnity)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, U8, scale_unity, TO_ZERO)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToU8Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, U8, scale_unity, TO_ZERO, DEFAULT_VALIDATE)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToU8Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, U8, scale_unity, TO_ZERO, DEFAULT_VALIDATE)
TEST_SUITE_END() // ScaleUnity
TEST_SUITE(ScaleOther)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, U8, scale_other, TO_ZERO)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToU8Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, U8, scale_other, TO_ZERO, DEFAULT_VALIDATE)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToU8Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, U8, scale_other, TO_ZERO, DEFAULT_VALIDATE)
TEST_SUITE_END() // ScaleOther
@@ -394,19 +390,16 @@ TEST_SUITE_END() // U8toU8
TEST_SUITE(U8toS16)
TEST_SUITE(Scale255)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, S16, scale_255, TO_NEAREST_UP)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, S16, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(int16_t, 2))
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, S16, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(int16_t, 2))
TEST_SUITE_END() // Scale255
TEST_SUITE(ScaleUnity)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, S16, scale_unity, TO_ZERO)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, S16, scale_unity, TO_ZERO, DEFAULT_VALIDATE)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, S16, scale_unity, TO_ZERO, DEFAULT_VALIDATE)
TEST_SUITE_END() // ScaleUnity
TEST_SUITE(ScaleOther)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, S16, scale_other, TO_ZERO)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, S16, scale_other, TO_ZERO, DEFAULT_VALIDATE)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, S16, scale_other, TO_ZERO, DEFAULT_VALIDATE)
TEST_SUITE_END() // ScaleOther
@@ -416,19 +409,16 @@ TEST_SUITE_END() // U8toS16
TEST_SUITE(S16toS16)
TEST_SUITE(Scale255)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(S16, S16, scale_255, TO_NEAREST_UP)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<int16_t>, PRECOMMIT, SmallShapes(), S16, S16, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(int16_t, 2))
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<int16_t>, NIGHTLY, LargeShapes(), S16, S16, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(int16_t, 2))
TEST_SUITE_END() // Scale255
TEST_SUITE(ScaleUnity)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(S16, S16, scale_unity, TO_ZERO)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<int16_t>, PRECOMMIT, SmallShapes(), S16, S16, scale_unity, TO_ZERO, DEFAULT_VALIDATE)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<int16_t>, NIGHTLY, LargeShapes(), S16, S16, scale_unity, TO_ZERO, DEFAULT_VALIDATE)
TEST_SUITE_END() // ScaleUnity
TEST_SUITE(ScaleOther)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(S16, S16, scale_other, TO_ZERO)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<int16_t>, PRECOMMIT, SmallShapes(), S16, S16, scale_other, TO_ZERO, DEFAULT_VALIDATE)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<int16_t>, NIGHTLY, LargeShapes(), S16, S16, scale_other, TO_ZERO, DEFAULT_VALIDATE)
TEST_SUITE_END() // ScaleOther
@@ -448,19 +438,16 @@ TEST_SUITE_END() // F16toF16
TEST_SUITE(F32toF32)
TEST_SUITE(Scale255)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(F32, F32, scale_255, TO_NEAREST_UP)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToF32Fixture<float>, PRECOMMIT, SmallShapes(), F32, F32, scale_255, TO_NEAREST_UP, VALIDATE(float, 1.f))
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToF32Fixture<float>, NIGHTLY, LargeShapes(), F32, F32, scale_255, TO_NEAREST_UP, VALIDATE(float, 1.f))
TEST_SUITE_END() // Scale255
TEST_SUITE(ScaleUnity)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(F32, F32, scale_unity, TO_ZERO)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToF32Fixture<float>, PRECOMMIT, SmallShapes(), F32, F32, scale_unity, TO_ZERO, DEFAULT_VALIDATE)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToF32Fixture<float>, NIGHTLY, LargeShapes(), F32, F32, scale_unity, TO_ZERO, DEFAULT_VALIDATE)
TEST_SUITE_END() // ScaleUnity
TEST_SUITE(ScaleOther)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(F32, F32, scale_other, TO_ZERO)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToF32Fixture<float>, PRECOMMIT, SmallShapes(), F32, F32, scale_other, TO_ZERO, DEFAULT_VALIDATE)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToF32Fixture<float>, NIGHTLY, LargeShapes(), F32, F32, scale_other, TO_ZERO, DEFAULT_VALIDATE)
TEST_SUITE_END() // ScaleOther
diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h
index 9260686d56..858ee07d3e 100644
--- a/tests/validation/fixtures/LSTMLayerFixture.h
+++ b/tests/validation/fixtures/LSTMLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -407,7 +407,7 @@ protected:
if(peephole_opt)
{
- SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication(cell_state_in, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
+ SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO, data_type);
forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE);
}
@@ -416,7 +416,7 @@ protected:
SimpleTensor<T> forget_layer_norm_w{ cell_bias_shape, data_type };
fill(forget_layer_norm_w, 23);
forget_gate = reference::mean_std_normalization_layer(forget_gate);
- forget_gate = reference::pixel_wise_multiplication(forget_gate, forget_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ forget_gate = reference::pixel_wise_multiplication<T, T, T>(forget_gate, forget_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
fill(forget_gate_bias, 7);
forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, forget_gate_bias, data_type, ConvertPolicy::SATURATE);
}
@@ -438,7 +438,7 @@ protected:
input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE);
if(peephole_opt)
{
- SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication(cell_state_in, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE);
}
if(use_layer_norm)
@@ -446,7 +446,7 @@ protected:
SimpleTensor<T> input_layer_norm_w{ cell_bias_shape, data_type };
fill(input_layer_norm_w, 22);
input_gate = reference::mean_std_normalization_layer(input_gate);
- input_gate = reference::pixel_wise_multiplication(input_gate, input_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ input_gate = reference::pixel_wise_multiplication<T, T, T>(input_gate, input_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
fill(input_gate_bias, 17);
input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, input_gate_bias, data_type, ConvertPolicy::SATURATE);
}
@@ -457,19 +457,19 @@ protected:
SimpleTensor<T> fully_connected_cell_state = reference::fully_connected_layer(input, input_to_cell_w, cell_bias, output_cell_shape);
transposed_weights = reference::transpose(recurrent_to_cell_w);
gemm = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f);
- SimpleTensor<T> pixelwise_mul = reference::pixel_wise_multiplication(cell_state_in, forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ SimpleTensor<T> pixelwise_mul = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_cell_state, gemm, data_type, ConvertPolicy::SATURATE);
if(use_layer_norm)
{
SimpleTensor<T> cell_layer_norm_w{ cell_bias_shape, data_type };
fill(cell_layer_norm_w, 24);
cell_state_out = reference::mean_std_normalization_layer(cell_state_out);
- cell_state_out = reference::pixel_wise_multiplication(cell_state_out, cell_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ cell_state_out = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, cell_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
fill(cell_bias, 8);
cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, cell_bias, data_type, ConvertPolicy::SATURATE);
}
cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
- cell_state_out = reference::pixel_wise_multiplication(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ cell_state_out = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
if(cell_threshold != 0.f)
{
@@ -483,7 +483,7 @@ protected:
output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_output, gemm, data_type, ConvertPolicy::SATURATE);
if(peephole_opt)
{
- pixelwise_mul = reference::pixel_wise_multiplication(cell_state_out, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ pixelwise_mul = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
}
if(use_layer_norm)
@@ -491,7 +491,7 @@ protected:
SimpleTensor<T> output_layer_norm_w{ cell_bias_shape, data_type };
fill(output_layer_norm_w, 25);
output = reference::mean_std_normalization_layer(output);
- output = reference::pixel_wise_multiplication(output, output_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ output = reference::pixel_wise_multiplication<T, T, T>(output, output_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
fill(output_gate_bias, 9);
output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, output_gate_bias, data_type, ConvertPolicy::SATURATE);
}
@@ -499,7 +499,7 @@ protected:
// Compute output state
SimpleTensor<T> cell_state_activation = reference::activation_layer(cell_state_out, info);
- output_state_out = reference::pixel_wise_multiplication(output, cell_state_activation, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ output_state_out = reference::pixel_wise_multiplication<T, T, T>(output, cell_state_activation, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
if(projection_opt)
{
diff --git a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
index efdf5d078e..37359f421b 100644
--- a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
+++ b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -39,7 +39,7 @@ namespace test
{
namespace validation
{
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2, typename T3 = T2>
class PixelWiseMultiplicationGenericValidationFixture : public framework::Fixture
{
public:
@@ -48,6 +48,7 @@ public:
const TensorShape &shape1,
DataType dt_in1,
DataType dt_in2,
+ DataType dt_out,
float scale,
ConvertPolicy convert_policy,
RoundingPolicy rounding_policy,
@@ -55,8 +56,8 @@ public:
QuantizationInfo qinfo1,
QuantizationInfo qinfo_out)
{
- _target = compute_target(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out);
- _reference = compute_reference(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out);
+ _target = compute_target(shape0, shape1, dt_in1, dt_in2, dt_out, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out);
+ _reference = compute_reference(shape0, shape1, dt_in1, dt_in2, dt_out, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out);
}
protected:
@@ -66,14 +67,14 @@ protected:
library->fill_tensor_uniform(tensor, seed_offset);
}
- TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2,
+ TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, DataType dt_out,
float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
{
// Create tensors
TensorType src1 = create_tensor<TensorType>(shape0, dt_in1, 1, qinfo0);
TensorType src2 = create_tensor<TensorType>(shape1, dt_in2, 1, qinfo1);
- TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), dt_in2, 1, qinfo_out);
+ TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), dt_out, 1, qinfo_out);
// Create and configure function
FunctionType multiply;
@@ -102,7 +103,7 @@ protected:
return dst;
}
- SimpleTensor<T2> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2,
+ SimpleTensor<T3> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, DataType dt_out,
float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
{
@@ -114,24 +115,11 @@ protected:
fill(src1, 0);
fill(src2, 1);
- return reference::pixel_wise_multiplication<T1, T2>(src1, src2, scale, convert_policy, rounding_policy, qinfo_out);
+ return reference::pixel_wise_multiplication<T1, T2, T3>(src1, src2, scale, convert_policy, rounding_policy, dt_out, qinfo_out);
}
TensorType _target{};
- SimpleTensor<T2> _reference{};
-};
-
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
-class PixelWiseMultiplicationQuatizedValidationFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>
-{
-public:
- template <typename...>
- void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
- QuantizationInfo in1_qua_info, QuantizationInfo in2_qua_info, QuantizationInfo out_qua_info)
- {
- PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy,
- in1_qua_info, in2_qua_info, out_qua_info);
- }
+ SimpleTensor<T3> _reference{};
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
@@ -141,7 +129,7 @@ public:
template <typename...>
void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
{
- PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy,
+ PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, dt_in2, scale, convert_policy, rounding_policy,
QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
}
};
@@ -153,21 +141,21 @@ public:
template <typename...>
void setup(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
{
- PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy,
+ PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape0, shape1, dt_in1, dt_in2, dt_in2, scale, convert_policy, rounding_policy,
QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
}
};
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
-class PixelWiseMultiplicationValidationQuantizedFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2, typename T3 = T2>
+class PixelWiseMultiplicationValidationQuantizedFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>
{
public:
template <typename...>
- void setup(const TensorShape &shape, DataType dt, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
+ void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, DataType dt_out, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
{
- PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt, dt, scale, convert_policy, rounding_policy,
- qinfo0, qinfo1, qinfo_out);
+ PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>::setup(shape, shape, dt_in1, dt_in2, dt_out, scale, convert_policy,
+ rounding_policy, qinfo0, qinfo1, qinfo_out);
}
};
} // namespace validation
diff --git a/tests/validation/reference/PixelWiseMultiplication.cpp b/tests/validation/reference/PixelWiseMultiplication.cpp
index 2b4c849c39..3e21fca72a 100644
--- a/tests/validation/reference/PixelWiseMultiplication.cpp
+++ b/tests/validation/reference/PixelWiseMultiplication.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -52,16 +52,16 @@ namespace
* @param[in] convert_policy Overflow policy. Supported overflow policies: Wrap, Saturate
* @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even.
*/
-template <typename T1, typename T2>
-T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
+template <typename T1, typename T2, typename T3>
+T3 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
{
- using intermediate_type = typename common_promoted_signed_type<T1, T2, T2>::intermediate_type;
+ using intermediate_type = typename common_promoted_signed_type<T1, T2, T3>::intermediate_type;
const double val = static_cast<intermediate_type>(src1) * static_cast<intermediate_type>(src2) * static_cast<double>(scale);
- if(is_floating_point<T2>::value)
+ if(is_floating_point<T3>::value)
{
- const auto result = static_cast<T2>(val);
+ const auto result = static_cast<T3>(val);
return result;
}
@@ -83,7 +83,7 @@ T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy,
ARM_COMPUTE_ERROR("Unsupported rounding policy");
}
- const auto result = static_cast<T2>((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T2>(rounded_val) : rounded_val);
+ const auto result = static_cast<T3>((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T3>(rounded_val) : rounded_val);
return result;
}
@@ -92,8 +92,8 @@ T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy,
template <size_t dim>
struct BroadcastUnroll
{
- template <typename T1, typename T2>
- static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst,
+ template <typename T1, typename T2, typename T3>
+ static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T3> &dst,
float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
{
@@ -117,23 +117,23 @@ struct BroadcastUnroll
template <>
struct BroadcastUnroll<0>
{
- template <typename T1, typename T2>
- static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst,
+ template <typename T1, typename T2, typename T3>
+ static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T3> &dst,
float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
{
- dst[coord2index(dst.shape(), id_dst)] = mul(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy);
+ dst[coord2index(dst.shape(), id_dst)] = mul<T1, T2, T3>(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy);
}
};
} // namespace
-template <typename T1, typename T2>
-SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
- const QuantizationInfo &qout)
+template <typename T1, typename T2, typename T3>
+SimpleTensor<T3> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
+ DataType dt_out, const QuantizationInfo &qout)
{
ARM_COMPUTE_UNUSED(qout);
- SimpleTensor<T2> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type());
+ SimpleTensor<T3> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out);
if(scale < 0)
{
@@ -151,15 +151,15 @@ SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const S
template <>
SimpleTensor<uint8_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
- const QuantizationInfo &qout)
+ DataType dt_out, const QuantizationInfo &qout)
{
- SimpleTensor<uint8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type(), 1, qout);
+ SimpleTensor<uint8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout);
if(src1.data_type() == DataType::QASYMM8 && src2.data_type() == DataType::QASYMM8)
{
SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1);
SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2);
- SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, qout);
+ SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout);
dst = convert_to_asymmetric<uint8_t>(dst_tmp, qout);
}
else
@@ -179,15 +179,15 @@ SimpleTensor<uint8_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src
template <>
SimpleTensor<int8_t> pixel_wise_multiplication(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
- const QuantizationInfo &qout)
+ DataType dt_out, const QuantizationInfo &qout)
{
- SimpleTensor<int8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type(), 1, qout);
+ SimpleTensor<int8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout);
if(src1.data_type() == DataType::QASYMM8_SIGNED && src2.data_type() == DataType::QASYMM8_SIGNED)
{
SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1);
SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2);
- SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, qout);
+ SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout);
dst = convert_to_asymmetric<int8_t>(dst_tmp, qout);
}
else
@@ -207,15 +207,15 @@ SimpleTensor<int8_t> pixel_wise_multiplication(const SimpleTensor<int8_t> &src1,
template <>
SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
- const QuantizationInfo &qout)
+ DataType dt_out, const QuantizationInfo &qout)
{
- SimpleTensor<int16_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type(), 1, qout);
+ SimpleTensor<int16_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout);
if(src1.data_type() == DataType::QSYMM16 && src2.data_type() == DataType::QSYMM16)
{
SimpleTensor<float> src1_tmp = convert_from_symmetric<int16_t>(src1);
SimpleTensor<float> src2_tmp = convert_from_symmetric<int16_t>(src2);
- SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, qout);
+ SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout);
dst = convert_to_symmetric<int16_t>(dst_tmp, qout);
}
else
@@ -234,9 +234,10 @@ SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src
}
// *INDENT-OFF*
// clang-format off
-template SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout);
-template SimpleTensor<float> pixel_wise_multiplication(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout);
-template SimpleTensor<half_float::half> pixel_wise_multiplication(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout);
+template SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
+template SimpleTensor<int32_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
+template SimpleTensor<float> pixel_wise_multiplication(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
+template SimpleTensor<half_float::half> pixel_wise_multiplication(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
// clang-format on
// *INDENT-ON*
} // namespace reference
diff --git a/tests/validation/reference/PixelWiseMultiplication.h b/tests/validation/reference/PixelWiseMultiplication.h
index f5b8e777fb..f8afa0384b 100644
--- a/tests/validation/reference/PixelWiseMultiplication.h
+++ b/tests/validation/reference/PixelWiseMultiplication.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -34,9 +34,10 @@ namespace validation
{
namespace reference
{
-template <typename T1, typename T2>
-SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale,
- ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout = QuantizationInfo());
+template <typename T1, typename T2, typename T3>
+SimpleTensor<T3> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale,
+ ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out,
+ const QuantizationInfo &qout = QuantizationInfo());
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/QLSTMLayerNormalization.cpp b/tests/validation/reference/QLSTMLayerNormalization.cpp
index 90d59b93ad..0e24de6584 100644
--- a/tests/validation/reference/QLSTMLayerNormalization.cpp
+++ b/tests/validation/reference/QLSTMLayerNormalization.cpp
@@ -41,7 +41,7 @@ namespace reference
SimpleTensor<float> qlstm_layer_normalization_float_compute(SimpleTensor<float> src, SimpleTensor<float> weight, SimpleTensor<float> bias)
{
SimpleTensor<float> output = mean_std_normalization_layer(src);
- output = pixel_wise_multiplication(output, weight, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
+ output = pixel_wise_multiplication<float, float, float>(output, weight, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO, DataType::F32);
return arithmetic_operation(ArithmeticOperation::ADD, output, bias, DataType::F32, ConvertPolicy::SATURATE);
}