aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/TensorOperations.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/TensorOperations.h')
-rw-r--r--tests/validation/TensorOperations.h90
1 files changed, 27 insertions, 63 deletions
diff --git a/tests/validation/TensorOperations.h b/tests/validation/TensorOperations.h
index 0502f53186..e90635f0d4 100644
--- a/tests/validation/TensorOperations.h
+++ b/tests/validation/TensorOperations.h
@@ -518,94 +518,58 @@ void box3x3(const Tensor<T> &in, Tensor<T> &out, BorderMode border_mode, T const
}
// Depth conversion
-template <typename T1, typename T2>
+template < typename T1, typename T2, typename std::enable_if < std::is_integral<T1>::value &&std::is_floating_point<T2>::value, int >::type = 0 >
void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, uint32_t shift)
{
- ARM_COMPUTE_ERROR("The conversion is not supported");
-}
-
-template <>
-void depth_convert<int8_t, float>(const Tensor<int8_t> &in, Tensor<float> &out, ConvertPolicy policy, uint32_t shift)
-{
- const int8_t fixed_point_position = static_cast<int8_t>(in.fixed_point_position());
- for(int i = 0; i < in.num_elements(); ++i)
- {
- out[i] = static_cast<float>(in[i]) * (1.0f / (1 << fixed_point_position));
- }
-}
-
-template <>
-void depth_convert<float, int8_t>(const Tensor<float> &in, Tensor<int8_t> &out, ConvertPolicy policy, uint32_t shift)
-{
- const int8_t fixed_point_position = static_cast<int8_t>(in.fixed_point_position());
- for(int i = 0; i < in.num_elements(); ++i)
- {
- float val = in[i] * (1 << fixed_point_position) + 0.5f;
- out[i] = ((policy == ConvertPolicy::SATURATE) ? saturate_cast<int8_t>(val) : static_cast<int8_t>(val));
- }
-}
+ using namespace fixed_point_arithmetic;
-template <>
-void depth_convert<uint8_t, uint16_t>(const Tensor<uint8_t> &in, Tensor<uint16_t> &out, ConvertPolicy policy, uint32_t shift)
-{
+ const int fixed_point_position = in.fixed_point_position();
for(int i = 0; i < in.num_elements(); ++i)
{
- out[i] = static_cast<uint16_t>(in[i]) << shift;
+ out[i] = static_cast<float>(fixed_point<T1>(in[i], fixed_point_position, true));
}
}
-template <>
-void depth_convert<uint8_t, int16_t>(const Tensor<uint8_t> &in, Tensor<int16_t> &out, ConvertPolicy policy, uint32_t shift)
+template < typename T1, typename T2, typename std::enable_if < std::is_floating_point<T1>::value &&std::is_integral<T2>::value, int >::type = 0 >
+void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, uint32_t shift)
{
- for(int i = 0; i < in.num_elements(); ++i)
- {
- out[i] = static_cast<int16_t>(in[i]) << shift;
- }
-}
+ using namespace fixed_point_arithmetic;
-template <>
-void depth_convert<uint8_t, int32_t>(const Tensor<uint8_t> &in, Tensor<int32_t> &out, ConvertPolicy policy, uint32_t shift)
-{
+ const int fixed_point_position = out.fixed_point_position();
for(int i = 0; i < in.num_elements(); ++i)
{
- out[i] = static_cast<int32_t>(in[i]) << shift;
+ out[i] = fixed_point<T2>(in[i], fixed_point_position).raw();
}
}
-template <>
-void depth_convert<uint16_t, uint8_t>(const Tensor<uint16_t> &in, Tensor<uint8_t> &out, ConvertPolicy policy, uint32_t shift)
+template < typename T1, typename T2, typename std::enable_if < std::is_integral<T1>::value &&std::is_integral<T2>::value, int >::type = 0 >
+void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, uint32_t shift)
{
- for(int i = 0; i < in.num_elements(); ++i)
+ // Up-casting
+ if(std::numeric_limits<T1>::digits <= std::numeric_limits<T2>::digits)
{
- uint16_t val = in[i] >> shift;
- out[i] = ((policy == ConvertPolicy::SATURATE) ? saturate_cast<uint8_t>(val) : static_cast<uint8_t>(val));
+ for(int i = 0; i < in.num_elements(); ++i)
+ {
+ out[i] = static_cast<T2>(in[i]) << shift;
+ }
}
-}
-
-template <>
-void depth_convert<uint16_t, uint32_t>(const Tensor<uint16_t> &in, Tensor<uint32_t> &out, ConvertPolicy policy, uint32_t shift)
-{
- for(int i = 0; i < in.num_elements(); ++i)
+ // Down-casting
+ else
{
- out[i] = static_cast<uint32_t>(in[i]) << shift;
+ for(int i = 0; i < in.num_elements(); ++i)
+ {
+ T1 val = in[i] >> shift;
+ out[i] = ((policy == ConvertPolicy::SATURATE) ? saturate_cast<T2>(val) : static_cast<T2>(val));
+ }
}
}
-template <>
-void depth_convert<int16_t, uint8_t>(const Tensor<int16_t> &in, Tensor<uint8_t> &out, ConvertPolicy policy, uint32_t shift)
-{
- for(int i = 0; i < in.num_elements(); ++i)
- {
- int16_t val = in[i] >> shift;
- out[i] = ((policy == ConvertPolicy::SATURATE) ? saturate_cast<uint8_t>(val) : static_cast<uint8_t>(val));
- }
-}
-template <>
-void depth_convert<int16_t, int32_t>(const Tensor<int16_t> &in, Tensor<int32_t> &out, ConvertPolicy policy, uint32_t shift)
+template < typename T1, typename T2, typename std::enable_if < std::is_floating_point<T1>::value &&std::is_floating_point<T2>::value, int >::type = 0 >
+void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, uint32_t shift)
{
for(int i = 0; i < in.num_elements(); ++i)
{
- out[i] = static_cast<int32_t>(in[i]) << shift;
+ out[i] = static_cast<T2>(in[i]);
}
}