aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/DepthConvertLayer.cpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-11-19 11:56:51 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2018-11-22 12:41:32 +0000
commit303f0dbebf631b3db00d9d64e71018abbbe9d4fe (patch)
tree631e70c9a8141f1262752829a64b3e33c7f1ee93 /tests/validation/reference/DepthConvertLayer.cpp
parent9d3a831d4131f8a8b37f127f11d36848d33e8496 (diff)
downloadComputeLibrary-303f0dbebf631b3db00d9d64e71018abbbe9d4fe.tar.gz
COMPMID-1718: Extend DepthConvert to support Cast
Change-Id: I6ee2c0b670727fc808fa636c53ddfaec3a0036c9
Diffstat (limited to 'tests/validation/reference/DepthConvertLayer.cpp')
-rw-r--r--tests/validation/reference/DepthConvertLayer.cpp99
1 files changed, 77 insertions, 22 deletions
diff --git a/tests/validation/reference/DepthConvertLayer.cpp b/tests/validation/reference/DepthConvertLayer.cpp
index fd2e0ae378..4d5b97b478 100644
--- a/tests/validation/reference/DepthConvertLayer.cpp
+++ b/tests/validation/reference/DepthConvertLayer.cpp
@@ -25,6 +25,9 @@
#include "tests/validation/Helpers.h"
+#include "arm_compute/core/utils/misc/Rounding.h"
+#include "arm_compute/core/utils/misc/SaturateCast.h"
+
#include "tests/Types.h"
namespace arm_compute
@@ -35,13 +38,13 @@ namespace validation
{
namespace reference
{
-template < typename T1, typename T2, typename std::enable_if < std::is_integral<T1>::value &&std::is_integral<T2>::value &&!std::is_same<T1, T2>::value, int >::type >
+template < typename T1, typename T2, typename std::enable_if < std::is_integral<T1>::value &&!std::is_same<T1, T2>::value, int >::type >
SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift)
{
SimpleTensor<T2> result(src.shape(), dt_out);
// Up-casting
- if(src.data_type() <= dt_out)
+ if(element_size_from_data_type(src.data_type()) < element_size_from_data_type(dt_out))
{
for(int i = 0; i < src.num_elements(); ++i)
{
@@ -54,48 +57,100 @@ SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, Con
for(int i = 0; i < src.num_elements(); ++i)
{
T1 val = src[i] >> shift;
- result[i] = (policy == ConvertPolicy::SATURATE) ? saturate_cast<T2>(val) : static_cast<T2>(val);
+ result[i] = (policy == ConvertPolicy::SATURATE) ? utils::cast::saturate_cast<T2>(val) : static_cast<T2>(val);
}
}
return result;
}
-template < typename T1, typename T2, typename std::enable_if < is_floating_point<T1>::value &&is_floating_point<T2>::value &&!std::is_same<T1, T2>::value, int >::type >
+template < typename T1, typename T2, typename std::enable_if < is_floating_point<T1>::value &&!std::is_same<T1, T2>::value, int >::type >
SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift)
{
SimpleTensor<T2> result(src.shape(), dt_out);
+ ARM_COMPUTE_ERROR_ON(shift != 0);
+ ARM_COMPUTE_UNUSED(policy, shift);
- const uint32_t scale = 1 << shift;
-
- // Up-casting
- if(src.data_type() <= dt_out)
+ // Always saturate on floats
+ for(int i = 0; i < src.num_elements(); ++i)
{
- for(int i = 0; i < src.num_elements(); ++i)
- {
- result[i] = src[i] * static_cast<T2>(scale);
- }
- }
- // Down-casting
- else
- {
- for(int i = 0; i < src.num_elements(); ++i)
- {
- T1 val = src[i] / static_cast<T1>(scale);
- result[i] = (policy == ConvertPolicy::SATURATE) ? saturate_cast<T2>(val) : static_cast<T2>(val);
- }
+ T1 val = utils::rounding::round_half_away_from_zero(src[i]);
+ result[i] = utils::cast::saturate_cast<T2>(val);
}
return result;
}
+// U8
+template SimpleTensor<int8_t> depth_convert(const SimpleTensor<uint8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
template SimpleTensor<uint16_t> depth_convert(const SimpleTensor<uint8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
template SimpleTensor<int16_t> depth_convert(const SimpleTensor<uint8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<uint32_t> depth_convert(const SimpleTensor<uint8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
template SimpleTensor<int32_t> depth_convert(const SimpleTensor<uint8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<half> depth_convert(const SimpleTensor<uint8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<float> depth_convert(const SimpleTensor<uint8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+
+// S8
+template SimpleTensor<uint8_t> depth_convert(const SimpleTensor<int8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<uint16_t> depth_convert(const SimpleTensor<int8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<int16_t> depth_convert(const SimpleTensor<int8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<uint32_t> depth_convert(const SimpleTensor<int8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<int32_t> depth_convert(const SimpleTensor<int8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<half> depth_convert(const SimpleTensor<int8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<float> depth_convert(const SimpleTensor<int8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+
+// U16
template SimpleTensor<uint8_t> depth_convert(const SimpleTensor<uint16_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<int8_t> depth_convert(const SimpleTensor<uint16_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<int16_t> depth_convert(const SimpleTensor<uint16_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
template SimpleTensor<uint32_t> depth_convert(const SimpleTensor<uint16_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<int32_t> depth_convert(const SimpleTensor<uint16_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<half> depth_convert(const SimpleTensor<uint16_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<float> depth_convert(const SimpleTensor<uint16_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+
+// S16
template SimpleTensor<uint8_t> depth_convert(const SimpleTensor<int16_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<int8_t> depth_convert(const SimpleTensor<int16_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<uint16_t> depth_convert(const SimpleTensor<int16_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<uint32_t> depth_convert(const SimpleTensor<int16_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
template SimpleTensor<int32_t> depth_convert(const SimpleTensor<int16_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
-template SimpleTensor<half> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<half> depth_convert(const SimpleTensor<int16_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<float> depth_convert(const SimpleTensor<int16_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+
+// U32
+template SimpleTensor<uint8_t> depth_convert(const SimpleTensor<uint32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<int8_t> depth_convert(const SimpleTensor<uint32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<uint16_t> depth_convert(const SimpleTensor<uint32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<int16_t> depth_convert(const SimpleTensor<uint32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<int32_t> depth_convert(const SimpleTensor<uint32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<half> depth_convert(const SimpleTensor<uint32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<float> depth_convert(const SimpleTensor<uint32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+
+// S32
+template SimpleTensor<uint8_t> depth_convert(const SimpleTensor<int32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<int8_t> depth_convert(const SimpleTensor<int32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<uint16_t> depth_convert(const SimpleTensor<int32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<int16_t> depth_convert(const SimpleTensor<int32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<uint32_t> depth_convert(const SimpleTensor<int32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<half> depth_convert(const SimpleTensor<int32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<float> depth_convert(const SimpleTensor<int32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+
+// F16
+template SimpleTensor<uint8_t> depth_convert(const SimpleTensor<half> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<int8_t> depth_convert(const SimpleTensor<half> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<uint16_t> depth_convert(const SimpleTensor<half> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<int16_t> depth_convert(const SimpleTensor<half> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<uint32_t> depth_convert(const SimpleTensor<half> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<int32_t> depth_convert(const SimpleTensor<half> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
template SimpleTensor<float> depth_convert(const SimpleTensor<half> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+
+// F32
+template SimpleTensor<uint8_t> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<int8_t> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<uint16_t> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<int16_t> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<uint32_t> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<int32_t> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<half> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+
} // namespace reference
} // namespace validation
} // namespace test