aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/UpsampleLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/UpsampleLayer.cpp')
-rw-r--r--tests/validation/reference/UpsampleLayer.cpp33
1 files changed, 4 insertions, 29 deletions
diff --git a/tests/validation/reference/UpsampleLayer.cpp b/tests/validation/reference/UpsampleLayer.cpp
index d77f9ae348..a81a601057 100644
--- a/tests/validation/reference/UpsampleLayer.cpp
+++ b/tests/validation/reference/UpsampleLayer.cpp
@@ -23,6 +23,7 @@
*/
#include "UpsampleLayer.h"
+#include "arm_compute/core/utils/misc/Requires.h"
#include "tests/validation/Helpers.h"
namespace arm_compute
@@ -33,10 +34,8 @@ namespace validation
{
namespace reference
{
-namespace
-{
template <typename T>
-SimpleTensor<T> upsample_function(const SimpleTensor<T> &src, const Size2D &info, const InterpolationPolicy policy)
+SimpleTensor<T> upsample_layer(const SimpleTensor<T> &src, const Size2D &info, const InterpolationPolicy policy)
{
ARM_COMPUTE_ERROR_ON(policy != InterpolationPolicy::NEAREST_NEIGHBOR);
ARM_COMPUTE_UNUSED(policy);
@@ -76,36 +75,12 @@ SimpleTensor<T> upsample_function(const SimpleTensor<T> &src, const Size2D &info
return out;
}
-} // namespace
-
-template <typename T>
-SimpleTensor<T> upsample_layer(const SimpleTensor<T> &src, const Size2D &info, const InterpolationPolicy policy)
-{
- return upsample_function<T>(src, info, policy);
-}
-
-template <>
-SimpleTensor<uint8_t> upsample_layer(const SimpleTensor<uint8_t> &src, const Size2D &info, const InterpolationPolicy policy)
-{
- SimpleTensor<uint8_t> dst(src.shape(), src.data_type(), 1, src.quantization_info());
-
- if(is_data_type_quantized_asymmetric(src.data_type()))
- {
- SimpleTensor<float> src_tmp = convert_from_asymmetric(src);
- SimpleTensor<float> dst_tmp = upsample_function<float>(src_tmp, info, policy);
- dst = convert_to_asymmetric<uint8_t>(dst_tmp, src.quantization_info());
- }
- else
- {
- dst = upsample_function<uint8_t>(src, info, policy);
- }
- return dst;
-}
-
template SimpleTensor<float> upsample_layer(const SimpleTensor<float> &src,
const Size2D &info, const InterpolationPolicy policy);
template SimpleTensor<half> upsample_layer(const SimpleTensor<half> &src,
const Size2D &info, const InterpolationPolicy policy);
+template SimpleTensor<uint8_t> upsample_layer(const SimpleTensor<uint8_t> &src,
+ const Size2D &info, const InterpolationPolicy policy);
template SimpleTensor<int8_t> upsample_layer(const SimpleTensor<int8_t> &src,
const Size2D &info, const InterpolationPolicy policy);
} // namespace reference