diff options
author | Pablo Marquez Tello <pablo.tello@arm.com> | 2023-08-03 14:47:31 +0100 |
---|---|---|
committer | Pablo Marquez Tello <pablo.tello@arm.com> | 2023-08-08 15:49:54 +0000 |
commit | 29e27b0544d99e5d98f044a9e606db8abcfb8900 (patch) | |
tree | 3749d3f3640d55fceda4dcd04a2916c87414b045 /tests/validation | |
parent | 66b4a6a8ca1ee55e5b7f05bae2543cf99fe22d6d (diff) | |
download | ComputeLibrary-29e27b0544d99e5d98f044a9e606db8abcfb8900.tar.gz |
Add support for S64 output in NEArgMinMaxLayer
* NEArgMinMaxLayer uses NEReductionOperation to compute its result in S32
* We need to call NECast to convert from S32 to S64
* Resolves MLCE-1089
Change-Id: I6fded869b6076d7af1b9b3e70eb384f4ee82fd8a
Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10054
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation')
-rw-r--r-- | tests/validation/NEON/ArgMinMax.cpp | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tests/validation/NEON/ArgMinMax.cpp b/tests/validation/NEON/ArgMinMax.cpp index 2e21a7db7b..c80c936b6d 100644 --- a/tests/validation/NEON/ArgMinMax.cpp +++ b/tests/validation/NEON/ArgMinMax.cpp @@ -97,6 +97,8 @@ using NEArgMinMaxValidationFixture = ArgMinMaxValidationFixture<Tensor, Accessor using NEArgMinMaxValidationFixture_S32_S32 = NEArgMinMaxValidationFixture<int32_t, int32_t>; using NEArgMinMaxValidationFixture_F16_S32 = NEArgMinMaxValidationFixture<half, int32_t>; using NEArgMinMaxValidationFixture_F32_S32 = NEArgMinMaxValidationFixture<float, int32_t>; +using NEArgMinMaxValidationFixture_F32_S64 = NEArgMinMaxValidationFixture<float, int64_t>; + TEST_SUITE(S32) FIXTURE_DATA_TEST_CASE(RunSmallAxis0, NEArgMinMaxValidationFixture_S32_S32, @@ -182,6 +184,19 @@ FIXTURE_DATA_TEST_CASE(RunSmall, validate(Accessor(_target), _reference); } +FIXTURE_DATA_TEST_CASE(RunSmall_F32_S64, + NEArgMinMaxValidationFixture_F32_S64, + framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(ArgMinMaxSmallDataset(), + framework::dataset::make("DataTypeIn", DataType::F32)), + framework::dataset::make("DataTypeOut", DataType::S64)), + AxisDataset), + OpsDataset)) +{ + // Validate output + validate(Accessor(_target), _reference); +} + FIXTURE_DATA_TEST_CASE(RunLarge, NEArgMinMaxValidationFixture_F32_S32, framework::DatasetMode::NIGHTLY, |