about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Usama Arif <usama.arif@arm.com> 2019-05-20 13:44:34 +0100
committer Usama Arif <usama.arif@arm.com> 2019-05-22 15:13:21 +0000
commit 28f0dd99fba11ed9b7165eca17d801bdfb421576 (patch)
tree cb8bb464cdcea9946179ccf8add3158e50eefa48
parent a4a08ad5e33867f9938a3fbaf9b6dcc56ad8f7b5 (diff)
download ComputeLibrary-28f0dd99fba11ed9b7165eca17d801bdfb421576.tar.gz
COMPMID-2279: Implement REDUCE_MAX operator for NEON
Change-Id: Iccd25b8aab1dd871c0d86ec3816b1cbf48370066
Signed-off-by: Usama Arif <usama.arif@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1193
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
-rw-r--r-- arm_compute/core/Types.h 1
-rw-r--r-- src/core/NEON/kernels/NEReductionOperationKernel.cpp 34
-rw-r--r-- src/runtime/NEON/functions/NEReductionOperation.cpp 26
-rw-r--r-- tests/validation/NEON/ReductionOperation.cpp 1
-rw-r--r-- tests/validation/reference/ReductionOperation.cpp 13
-rw-r--r-- utils/TypePrinter.h 3
6 files changed, 77 insertions, 1 deletion
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 241c1fe1f4..65db06b878 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -561,6 +561,7 @@ enum class ReductionOperation
SUM_SQUARE, /**< Sum of squares */
SUM, /**< Sum */
MIN, /**< Min */
+ MAX, /**< Max */
};
/** Available element-wise operations */
diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
index b51d4b311f..e6edf22083 100644
--- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp
+++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
@@ -417,6 +417,7 @@ struct RedOpX
case ReductionOperation::ARG_IDX_MAX:
case ReductionOperation::ARG_IDX_MIN:
case ReductionOperation::MIN:
+ case ReductionOperation::MAX:
{
init_res_value = *reinterpret_cast<T *>(input.ptr());
break;
@@ -468,6 +469,11 @@ struct RedOpX
vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
break;
}
+ case ReductionOperation::MAX:
+ {
+ vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Not supported");
}
@@ -518,6 +524,11 @@ struct RedOpX
*(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_min(vec_res_value), 0);
break;
}
+ case ReductionOperation::MAX:
+ {
+ *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_max(vec_res_value), 0);
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Not supported");
}
@@ -541,7 +552,7 @@ struct RedOpX_qasymm8
uint8x16_t vec_res_value = { 0 };
- if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN)
+ if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX)
{
vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{});
}
@@ -618,6 +629,11 @@ struct RedOpX_qasymm8
vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
break;
}
+ case ReductionOperation::MAX:
+ {
+ vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Not supported");
}
@@ -638,6 +654,11 @@ struct RedOpX_qasymm8
*(output.ptr()) = static_cast<uint8_t>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
break;
}
+ case ReductionOperation::MAX:
+ {
+ *(output.ptr()) = static_cast<uint8_t>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
+ break;
+ }
case ReductionOperation::PROD:
{
auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
@@ -694,6 +715,7 @@ struct RedOpYZW
case ReductionOperation::ARG_IDX_MAX:
case ReductionOperation::ARG_IDX_MIN:
case ReductionOperation::MIN:
+ case ReductionOperation::MAX:
{
vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
break;
@@ -761,6 +783,11 @@ struct RedOpYZW
vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
break;
}
+ case ReductionOperation::MAX:
+ {
+ vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Not supported");
}
@@ -950,6 +977,11 @@ struct RedOpYZW_qasymm8
vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
break;
}
+ case ReductionOperation::MAX:
+ {
+ vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Not supported");
}
diff --git a/src/runtime/NEON/functions/NEReductionOperation.cpp b/src/runtime/NEON/functions/NEReductionOperation.cpp
index 81bb32f5dc..dc6cf59019 100644
--- a/src/runtime/NEON/functions/NEReductionOperation.cpp
+++ b/src/runtime/NEON/functions/NEReductionOperation.cpp
@@ -112,6 +112,32 @@ void NEReductionOperation::configure(ITensor *input, ITensor *output, unsigned i
}
break;
}
+ case ReductionOperation::MAX:
+ {
+ switch(input->info()->data_type())
+ {
+ case DataType::F32:
+ {
+ pixelValue = PixelValue(-std::numeric_limits<float>::max());
+ break;
+ }
+ case DataType::F16:
+ {
+ pixelValue = PixelValue(static_cast<half>(-65504.0f));
+ break;
+ }
+ case DataType::QASYMM8:
+ {
+ pixelValue = PixelValue(0, input->info()->data_type(), input->info()->quantization_info());
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Unsupported DataType");
+ }
+ }
+ break;
+ }
case ReductionOperation::ARG_IDX_MAX:
case ReductionOperation::ARG_IDX_MIN:
case ReductionOperation::MEAN_SUM:
diff --git a/tests/validation/NEON/ReductionOperation.cpp b/tests/validation/NEON/ReductionOperation.cpp
index 074689d678..5b697a5efa 100644
--- a/tests/validation/NEON/ReductionOperation.cpp
+++ b/tests/validation/NEON/ReductionOperation.cpp
@@ -53,6 +53,7 @@ const auto ReductionOperations = framework::dataset::make("ReductionOperation",
ReductionOperation::SUM,
ReductionOperation::PROD,
ReductionOperation::MIN,
+ ReductionOperation::MAX,
});
const auto QuantizationInfos = framework::dataset::make("QuantizationInfo",
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp
index 1f825f0e0f..571b991b92 100644
--- a/tests/validation/reference/ReductionOperation.cpp
+++ b/tests/validation/reference/ReductionOperation.cpp
@@ -51,6 +51,7 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in
}
break;
case ReductionOperation::MIN:
+ case ReductionOperation::MAX:
{
res = *ptr;
}
@@ -88,6 +89,12 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in
int_res = elem;
}
break;
+ case ReductionOperation::MAX:
+ if(static_cast<T>(int_res) < elem)
+ {
+ int_res = elem;
+ }
+ break;
case ReductionOperation::SUM_SQUARE:
int_res += elem * elem;
break;
@@ -133,6 +140,12 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in
res = elem;
}
break;
+ case ReductionOperation::MAX:
+ if(res < elem)
+ {
+ res = elem;
+ }
+ break;
case ReductionOperation::SUM_SQUARE:
res += elem * elem;
break;
diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h
index 9b8efe5a23..74dd0bbc35 100644
--- a/utils/TypePrinter.h
+++ b/utils/TypePrinter.h
@@ -1449,6 +1449,9 @@ inline ::std::ostream &operator<<(::std::ostream &os, const ReductionOperation &
case ReductionOperation::MIN:
os << "MIN";
break;
+ case ReductionOperation::MAX:
+ os << "MAX";
+ break;
default:
ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
}