aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CPP
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CPP')
-rw-r--r--tests/validation/CPP/ActivationLayer.cpp4
-rw-r--r--tests/validation/CPP/ArithmeticAddition.cpp6
-rw-r--r--tests/validation/CPP/ArithmeticSubtraction.cpp4
-rw-r--r--tests/validation/CPP/ConvolutionLayer.cpp5
-rw-r--r--tests/validation/CPP/DepthConcatenateLayer.cpp3
-rw-r--r--tests/validation/CPP/DepthConvert.cpp1
-rw-r--r--tests/validation/CPP/DepthwiseConvolution.cpp1
-rw-r--r--tests/validation/CPP/DepthwiseSeparableConvolutionLayer.cpp1
-rw-r--r--tests/validation/CPP/FullyConnectedLayer.cpp5
-rw-r--r--tests/validation/CPP/GEMM.cpp4
-rw-r--r--tests/validation/CPP/NormalizationLayer.cpp4
-rw-r--r--tests/validation/CPP/PoolingLayer.cpp4
-rw-r--r--tests/validation/CPP/ReshapeLayer.cpp4
-rw-r--r--tests/validation/CPP/Scale.cpp5
-rw-r--r--tests/validation/CPP/Scale.h2
-rw-r--r--tests/validation/CPP/SoftmaxLayer.cpp4
-rw-r--r--tests/validation/CPP/Utils.cpp26
-rw-r--r--tests/validation/CPP/Utils.h2
18 files changed, 44 insertions, 41 deletions
diff --git a/tests/validation/CPP/ActivationLayer.cpp b/tests/validation/CPP/ActivationLayer.cpp
index 8fcacca1e2..2243e6ff59 100644
--- a/tests/validation/CPP/ActivationLayer.cpp
+++ b/tests/validation/CPP/ActivationLayer.cpp
@@ -23,9 +23,9 @@
*/
#include "ActivationLayer.h"
+#include "arm_compute/core/Types.h"
#include "tests/validation/FixedPoint.h"
#include "tests/validation/Helpers.h"
-#include "tests/validation/half.h"
namespace arm_compute
{
@@ -155,7 +155,7 @@ SimpleTensor<T> activation_layer(const SimpleTensor<T> &src, ActivationLayerInfo
}
template SimpleTensor<float> activation_layer(const SimpleTensor<float> &src, ActivationLayerInfo info);
-template SimpleTensor<half_float::half> activation_layer(const SimpleTensor<half_float::half> &src, ActivationLayerInfo info);
+template SimpleTensor<half> activation_layer(const SimpleTensor<half> &src, ActivationLayerInfo info);
template SimpleTensor<qint8_t> activation_layer(const SimpleTensor<qint8_t> &src, ActivationLayerInfo info);
template SimpleTensor<qint16_t> activation_layer(const SimpleTensor<qint16_t> &src, ActivationLayerInfo info);
} // namespace reference
diff --git a/tests/validation/CPP/ArithmeticAddition.cpp b/tests/validation/CPP/ArithmeticAddition.cpp
index 41052c469b..82dd1437cd 100644
--- a/tests/validation/CPP/ArithmeticAddition.cpp
+++ b/tests/validation/CPP/ArithmeticAddition.cpp
@@ -23,9 +23,9 @@
*/
#include "ArithmeticAddition.h"
+#include "arm_compute/core/Types.h"
#include "tests/validation/FixedPoint.h"
#include "tests/validation/Helpers.h"
-#include "tests/validation/half.h"
namespace arm_compute
{
@@ -54,8 +54,8 @@ SimpleTensor<T> arithmetic_addition(const SimpleTensor<T> &src1, const SimpleTen
template SimpleTensor<uint8_t> arithmetic_addition(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
template SimpleTensor<int16_t> arithmetic_addition(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
template SimpleTensor<int8_t> arithmetic_addition(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
-template SimpleTensor<half_float::half> arithmetic_addition(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, DataType dst_data_type,
- ConvertPolicy convert_policy);
+template SimpleTensor<half> arithmetic_addition(const SimpleTensor<half> &src1, const SimpleTensor<half> &src2, DataType dst_data_type,
+ ConvertPolicy convert_policy);
template SimpleTensor<float> arithmetic_addition(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
} // namespace reference
} // namespace validation
diff --git a/tests/validation/CPP/ArithmeticSubtraction.cpp b/tests/validation/CPP/ArithmeticSubtraction.cpp
index fa7fec9d1b..80bdb15a49 100644
--- a/tests/validation/CPP/ArithmeticSubtraction.cpp
+++ b/tests/validation/CPP/ArithmeticSubtraction.cpp
@@ -25,7 +25,6 @@
#include "tests/validation/FixedPoint.h"
#include "tests/validation/Helpers.h"
-#include "tests/validation/half.h"
namespace arm_compute
{
@@ -54,8 +53,7 @@ SimpleTensor<T> arithmetic_subtraction(const SimpleTensor<T> &src1, const Simple
template SimpleTensor<uint8_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
template SimpleTensor<int8_t> arithmetic_subtraction(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
-template SimpleTensor<half_float::half> arithmetic_subtraction(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, DataType dst_data_type,
- ConvertPolicy convert_policy);
+template SimpleTensor<half> arithmetic_subtraction(const SimpleTensor<half> &src1, const SimpleTensor<half> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
template SimpleTensor<float> arithmetic_subtraction(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
} // namespace reference
} // namespace validation
diff --git a/tests/validation/CPP/ConvolutionLayer.cpp b/tests/validation/CPP/ConvolutionLayer.cpp
index 1824ada791..656cd2ee26 100644
--- a/tests/validation/CPP/ConvolutionLayer.cpp
+++ b/tests/validation/CPP/ConvolutionLayer.cpp
@@ -25,7 +25,6 @@
#include "tests/validation/FixedPoint.h"
#include "tests/validation/Helpers.h"
-#include "tests/validation/half.h"
namespace arm_compute
{
@@ -193,8 +192,8 @@ SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor
template SimpleTensor<float> convolution_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const SimpleTensor<float> &bias, const TensorShape &output_shape,
const PadStrideInfo &info);
-template SimpleTensor<half_float::half> convolution_layer(const SimpleTensor<half_float::half> &src, const SimpleTensor<half_float::half> &weights, const SimpleTensor<half_float::half> &bias,
- const TensorShape &output_shape, const PadStrideInfo &info);
+template SimpleTensor<half> convolution_layer(const SimpleTensor<half> &src, const SimpleTensor<half> &weights, const SimpleTensor<half> &bias, const TensorShape &output_shape,
+ const PadStrideInfo &info);
template SimpleTensor<qint8_t> convolution_layer(const SimpleTensor<qint8_t> &src, const SimpleTensor<qint8_t> &weights, const SimpleTensor<qint8_t> &bias, const TensorShape &output_shape,
const PadStrideInfo &info);
template SimpleTensor<qint16_t> convolution_layer(const SimpleTensor<qint16_t> &src, const SimpleTensor<qint16_t> &weights, const SimpleTensor<qint16_t> &bias, const TensorShape &output_shape,
diff --git a/tests/validation/CPP/DepthConcatenateLayer.cpp b/tests/validation/CPP/DepthConcatenateLayer.cpp
index 139d26f2b6..9a7248493d 100644
--- a/tests/validation/CPP/DepthConcatenateLayer.cpp
+++ b/tests/validation/CPP/DepthConcatenateLayer.cpp
@@ -25,7 +25,6 @@
#include "tests/validation/FixedPoint.h"
#include "tests/validation/Helpers.h"
-#include "tests/validation/half.h"
namespace arm_compute
{
@@ -95,7 +94,7 @@ SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs)
}
template SimpleTensor<float> depthconcatenate_layer(const std::vector<SimpleTensor<float>> &srcs);
-template SimpleTensor<half_float::half> depthconcatenate_layer(const std::vector<SimpleTensor<half_float::half>> &srcs);
+template SimpleTensor<half> depthconcatenate_layer(const std::vector<SimpleTensor<half>> &srcs);
template SimpleTensor<qint8_t> depthconcatenate_layer(const std::vector<SimpleTensor<qint8_t>> &srcs);
template SimpleTensor<qint16_t> depthconcatenate_layer(const std::vector<SimpleTensor<qint16_t>> &srcs);
} // namespace reference
diff --git a/tests/validation/CPP/DepthConvert.cpp b/tests/validation/CPP/DepthConvert.cpp
index bb34f67086..110174a73f 100644
--- a/tests/validation/CPP/DepthConvert.cpp
+++ b/tests/validation/CPP/DepthConvert.cpp
@@ -25,7 +25,6 @@
#include "tests/validation/FixedPoint.h"
#include "tests/validation/Helpers.h"
-#include "tests/validation/half.h"
#include "tests/Types.h"
diff --git a/tests/validation/CPP/DepthwiseConvolution.cpp b/tests/validation/CPP/DepthwiseConvolution.cpp
index ebca333715..ce30bed640 100644
--- a/tests/validation/CPP/DepthwiseConvolution.cpp
+++ b/tests/validation/CPP/DepthwiseConvolution.cpp
@@ -27,7 +27,6 @@
#include "Utils.h"
#include "tests/validation/Helpers.h"
-#include "tests/validation/half.h"
namespace arm_compute
{
diff --git a/tests/validation/CPP/DepthwiseSeparableConvolutionLayer.cpp b/tests/validation/CPP/DepthwiseSeparableConvolutionLayer.cpp
index 7020a854cf..3942ecf02a 100644
--- a/tests/validation/CPP/DepthwiseSeparableConvolutionLayer.cpp
+++ b/tests/validation/CPP/DepthwiseSeparableConvolutionLayer.cpp
@@ -29,7 +29,6 @@
#include "Utils.h"
#include "tests/validation/Helpers.h"
-#include "tests/validation/half.h"
namespace arm_compute
{
diff --git a/tests/validation/CPP/FullyConnectedLayer.cpp b/tests/validation/CPP/FullyConnectedLayer.cpp
index a146535bd9..2b32c4b161 100644
--- a/tests/validation/CPP/FullyConnectedLayer.cpp
+++ b/tests/validation/CPP/FullyConnectedLayer.cpp
@@ -23,8 +23,8 @@
*/
#include "FullyConnectedLayer.h"
+#include "arm_compute/core/Types.h"
#include "tests/validation/FixedPoint.h"
-#include "tests/validation/half.h"
#include <numeric>
@@ -123,8 +123,7 @@ SimpleTensor<T> fully_connected_layer(const SimpleTensor<T> &src, const SimpleTe
}
template SimpleTensor<float> fully_connected_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const SimpleTensor<float> &bias, const TensorShape &dst_shape);
-template SimpleTensor<half_float::half> fully_connected_layer(const SimpleTensor<half_float::half> &src, const SimpleTensor<half_float::half> &weights, const SimpleTensor<half_float::half> &bias,
- const TensorShape &dst_shape);
+template SimpleTensor<half> fully_connected_layer(const SimpleTensor<half> &src, const SimpleTensor<half> &weights, const SimpleTensor<half> &bias, const TensorShape &dst_shape);
template SimpleTensor<qint8_t> fully_connected_layer(const SimpleTensor<qint8_t> &src, const SimpleTensor<qint8_t> &weights, const SimpleTensor<qint8_t> &bias, const TensorShape &dst_shape);
template SimpleTensor<qint16_t> fully_connected_layer(const SimpleTensor<qint16_t> &src, const SimpleTensor<qint16_t> &weights, const SimpleTensor<qint16_t> &bias, const TensorShape &dst_shape);
} // namespace reference
diff --git a/tests/validation/CPP/GEMM.cpp b/tests/validation/CPP/GEMM.cpp
index 9b66597eb8..77d025ec8e 100644
--- a/tests/validation/CPP/GEMM.cpp
+++ b/tests/validation/CPP/GEMM.cpp
@@ -23,8 +23,8 @@
*/
#include "GEMM.h"
+#include "arm_compute/core/Types.h"
#include "tests/validation/FixedPoint.h"
-#include "tests/validation/half.h"
namespace arm_compute
{
@@ -113,7 +113,7 @@ SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const S
}
template SimpleTensor<float> gemm(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta);
-template SimpleTensor<half_float::half> gemm(const SimpleTensor<half_float::half> &a, const SimpleTensor<half_float::half> &b, const SimpleTensor<half_float::half> &c, float alpha, float beta);
+template SimpleTensor<half> gemm(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
template SimpleTensor<qint8_t> gemm(const SimpleTensor<qint8_t> &a, const SimpleTensor<qint8_t> &b, const SimpleTensor<qint8_t> &c, float alpha, float beta);
template SimpleTensor<qint16_t> gemm(const SimpleTensor<qint16_t> &a, const SimpleTensor<qint16_t> &b, const SimpleTensor<qint16_t> &c, float alpha, float beta);
} // namespace reference
diff --git a/tests/validation/CPP/NormalizationLayer.cpp b/tests/validation/CPP/NormalizationLayer.cpp
index 3c6f5e1a54..226af96fe3 100644
--- a/tests/validation/CPP/NormalizationLayer.cpp
+++ b/tests/validation/CPP/NormalizationLayer.cpp
@@ -23,8 +23,8 @@
*/
#include "NormalizationLayer.h"
+#include "arm_compute/core/Types.h"
#include "tests/validation/FixedPoint.h"
-#include "tests/validation/half.h"
namespace arm_compute
{
@@ -266,7 +266,7 @@ SimpleTensor<T> normalization_layer(const SimpleTensor<T> &src, NormalizationLay
}
template SimpleTensor<float> normalization_layer(const SimpleTensor<float> &src, NormalizationLayerInfo info);
-template SimpleTensor<half_float::half> normalization_layer(const SimpleTensor<half_float::half> &src, NormalizationLayerInfo info);
+template SimpleTensor<half> normalization_layer(const SimpleTensor<half> &src, NormalizationLayerInfo info);
template SimpleTensor<qint8_t> normalization_layer(const SimpleTensor<qint8_t> &src, NormalizationLayerInfo info);
template SimpleTensor<qint16_t> normalization_layer(const SimpleTensor<qint16_t> &src, NormalizationLayerInfo info);
} // namespace reference
diff --git a/tests/validation/CPP/PoolingLayer.cpp b/tests/validation/CPP/PoolingLayer.cpp
index c4425ca9a1..f7273f073f 100644
--- a/tests/validation/CPP/PoolingLayer.cpp
+++ b/tests/validation/CPP/PoolingLayer.cpp
@@ -23,8 +23,8 @@
*/
#include "PoolingLayer.h"
+#include "arm_compute/core/Types.h"
#include "tests/validation/FixedPoint.h"
-#include "tests/validation/half.h"
namespace arm_compute
{
@@ -234,7 +234,7 @@ SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, PoolingLayerInfo info)
}
template SimpleTensor<float> pooling_layer(const SimpleTensor<float> &src, PoolingLayerInfo info);
-template SimpleTensor<half_float::half> pooling_layer(const SimpleTensor<half_float::half> &src, PoolingLayerInfo info);
+template SimpleTensor<half> pooling_layer(const SimpleTensor<half> &src, PoolingLayerInfo info);
template SimpleTensor<qint8_t> pooling_layer(const SimpleTensor<qint8_t> &src, PoolingLayerInfo info);
template SimpleTensor<qint16_t> pooling_layer(const SimpleTensor<qint16_t> &src, PoolingLayerInfo info);
} // namespace reference
diff --git a/tests/validation/CPP/ReshapeLayer.cpp b/tests/validation/CPP/ReshapeLayer.cpp
index cc7f15e4b1..42f06e4f5a 100644
--- a/tests/validation/CPP/ReshapeLayer.cpp
+++ b/tests/validation/CPP/ReshapeLayer.cpp
@@ -23,7 +23,7 @@
*/
#include "ReshapeLayer.h"
-#include "tests/validation/half.h"
+#include "arm_compute/core/Types.h"
namespace arm_compute
{
@@ -49,7 +49,7 @@ template SimpleTensor<uint16_t> reshape_layer(const SimpleTensor<uint16_t> &src,
template SimpleTensor<int16_t> reshape_layer(const SimpleTensor<int16_t> &src, const TensorShape &output_shape);
template SimpleTensor<uint32_t> reshape_layer(const SimpleTensor<uint32_t> &src, const TensorShape &output_shape);
template SimpleTensor<int32_t> reshape_layer(const SimpleTensor<int32_t> &src, const TensorShape &output_shape);
-template SimpleTensor<half_float::half> reshape_layer(const SimpleTensor<half_float::half> &src, const TensorShape &output_shape);
+template SimpleTensor<half> reshape_layer(const SimpleTensor<half> &src, const TensorShape &output_shape);
template SimpleTensor<float> reshape_layer(const SimpleTensor<float> &src, const TensorShape &output_shape);
} // namespace reference
} // namespace validation
diff --git a/tests/validation/CPP/Scale.cpp b/tests/validation/CPP/Scale.cpp
index a1119f33b9..ba34553a99 100644
--- a/tests/validation/CPP/Scale.cpp
+++ b/tests/validation/CPP/Scale.cpp
@@ -36,7 +36,7 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> scale(const SimpleTensor<T> &in, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, uint8_t constant_border_value)
+SimpleTensor<T> scale(const SimpleTensor<T> &in, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, T constant_border_value)
{
TensorShape shape_scaled(in.shape());
shape_scaled.set(0, in.shape()[0] * scale_x);
@@ -160,6 +160,9 @@ SimpleTensor<T> scale(const SimpleTensor<T> &in, float scale_x, float scale_y, I
}
template SimpleTensor<uint8_t> scale(const SimpleTensor<uint8_t> &src, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, uint8_t constant_border_value);
+template SimpleTensor<int16_t> scale(const SimpleTensor<int16_t> &src, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, int16_t constant_border_value);
+template SimpleTensor<half> scale(const SimpleTensor<half> &src, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, half constant_border_value);
+template SimpleTensor<float> scale(const SimpleTensor<float> &src, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, float constant_border_value);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/CPP/Scale.h b/tests/validation/CPP/Scale.h
index b882915946..53183ae742 100644
--- a/tests/validation/CPP/Scale.h
+++ b/tests/validation/CPP/Scale.h
@@ -35,7 +35,7 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> scale(const SimpleTensor<T> &in, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, uint8_t constant_border_value = 0);
+SimpleTensor<T> scale(const SimpleTensor<T> &in, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, T constant_border_value = 0);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/CPP/SoftmaxLayer.cpp b/tests/validation/CPP/SoftmaxLayer.cpp
index 4fe87d07dc..eb7655078c 100644
--- a/tests/validation/CPP/SoftmaxLayer.cpp
+++ b/tests/validation/CPP/SoftmaxLayer.cpp
@@ -23,8 +23,8 @@
*/
#include "SoftmaxLayer.h"
+#include "arm_compute/core/Types.h"
#include "tests/validation/FixedPoint.h"
-#include "tests/validation/half.h"
namespace arm_compute
{
@@ -113,7 +113,7 @@ SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src)
}
template SimpleTensor<float> softmax_layer(const SimpleTensor<float> &src);
-template SimpleTensor<half_float::half> softmax_layer(const SimpleTensor<half_float::half> &src);
+template SimpleTensor<half> softmax_layer(const SimpleTensor<half> &src);
template SimpleTensor<qint8_t> softmax_layer(const SimpleTensor<qint8_t> &src);
template SimpleTensor<qint16_t> softmax_layer(const SimpleTensor<qint16_t> &src);
} // namespace reference
diff --git a/tests/validation/CPP/Utils.cpp b/tests/validation/CPP/Utils.cpp
index 15e9fc3138..2f54879818 100644
--- a/tests/validation/CPP/Utils.cpp
+++ b/tests/validation/CPP/Utils.cpp
@@ -24,7 +24,6 @@
#include "Utils.h"
#include "tests/validation/Helpers.h"
-#include "tests/validation/half.h"
namespace arm_compute
{
@@ -51,17 +50,20 @@ T tensor_elem_at(const SimpleTensor<T> &in, Coordinates coord, BorderMode border
}
else
{
- return constant_border_value;
+ return static_cast<T>(constant_border_value);
}
}
return in[coord2index(in.shape(), coord)];
}
-template float tensor_elem_at(const SimpleTensor<float> &in, Coordinates coord, BorderMode border_mode, float constant_border_value);
+
template uint8_t tensor_elem_at(const SimpleTensor<uint8_t> &in, Coordinates coord, BorderMode border_mode, uint8_t constant_border_value);
+template int16_t tensor_elem_at(const SimpleTensor<int16_t> &in, Coordinates coord, BorderMode border_mode, int16_t constant_border_value);
+template half tensor_elem_at(const SimpleTensor<half> &in, Coordinates coord, BorderMode border_mode, half constant_border_value);
+template float tensor_elem_at(const SimpleTensor<float> &in, Coordinates coord, BorderMode border_mode, float constant_border_value);
// Return the bilinear value at a specified coordinate with different border modes
template <typename T>
-T bilinear_policy(const SimpleTensor<T> &in, Coordinates id, float xn, float yn, BorderMode border_mode, uint8_t constant_border_value)
+T bilinear_policy(const SimpleTensor<T> &in, Coordinates id, float xn, float yn, BorderMode border_mode, T constant_border_value)
{
int idx = std::floor(xn);
int idy = std::floor(yn);
@@ -71,22 +73,28 @@ T bilinear_policy(const SimpleTensor<T> &in, Coordinates id, float xn, float yn,
const float dx_1 = 1.0f - dx;
const float dy_1 = 1.0f - dy;
+ const T border_value = constant_border_value;
+
id.set(0, idx);
id.set(1, idy);
- const T tl = tensor_elem_at(in, id, border_mode, constant_border_value);
+ const float tl = tensor_elem_at(in, id, border_mode, border_value);
id.set(0, idx + 1);
id.set(1, idy);
- const T tr = tensor_elem_at(in, id, border_mode, constant_border_value);
+ const float tr = tensor_elem_at(in, id, border_mode, border_value);
id.set(0, idx);
id.set(1, idy + 1);
- const T bl = tensor_elem_at(in, id, border_mode, constant_border_value);
+ const float bl = tensor_elem_at(in, id, border_mode, border_value);
id.set(0, idx + 1);
id.set(1, idy + 1);
- const T br = tensor_elem_at(in, id, border_mode, constant_border_value);
+ const float br = tensor_elem_at(in, id, border_mode, border_value);
- return tl * (dx_1 * dy_1) + tr * (dx * dy_1) + bl * (dx_1 * dy) + br * (dx * dy);
+ return static_cast<T>(tl * (dx_1 * dy_1) + tr * (dx * dy_1) + bl * (dx_1 * dy) + br * (dx * dy));
}
+
template uint8_t bilinear_policy(const SimpleTensor<uint8_t> &in, Coordinates id, float xn, float yn, BorderMode border_mode, uint8_t constant_border_value);
+template int16_t bilinear_policy(const SimpleTensor<int16_t> &in, Coordinates id, float xn, float yn, BorderMode border_mode, int16_t constant_border_value);
+template half bilinear_policy(const SimpleTensor<half> &in, Coordinates id, float xn, float yn, BorderMode border_mode, half constant_border_value);
+template float bilinear_policy(const SimpleTensor<float> &in, Coordinates id, float xn, float yn, BorderMode border_mode, float constant_border_value);
/* Apply 2D spatial filter on a single element of @p in at coordinates @p coord
*
diff --git a/tests/validation/CPP/Utils.h b/tests/validation/CPP/Utils.h
index 34ba60bed6..557d85f204 100644
--- a/tests/validation/CPP/Utils.h
+++ b/tests/validation/CPP/Utils.h
@@ -45,7 +45,7 @@ template <typename T>
T tensor_elem_at(const SimpleTensor<T> &in, Coordinates coord, BorderMode border_mode, T constant_border_value);
template <typename T>
-T bilinear_policy(const SimpleTensor<T> &in, Coordinates id, float xn, float yn, BorderMode border_mode, uint8_t constant_border_value);
+T bilinear_policy(const SimpleTensor<T> &in, Coordinates id, float xn, float yn, BorderMode border_mode, T constant_border_value);
template <typename T1, typename T2, typename T3>
void apply_2d_spatial_filter(Coordinates coord, const SimpleTensor<T1> &in, SimpleTensor<T3> &out, const TensorShape &filter_shape, const T2 *filter_itr, float scale, BorderMode border_mode,