aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/ConvolutionLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/ConvolutionLayer.cpp')
-rw-r--r--tests/validation/reference/ConvolutionLayer.cpp20
1 files changed, 10 insertions, 10 deletions
diff --git a/tests/validation/reference/ConvolutionLayer.cpp b/tests/validation/reference/ConvolutionLayer.cpp
index 69090117fe..4d2c1acb6f 100644
--- a/tests/validation/reference/ConvolutionLayer.cpp
+++ b/tests/validation/reference/ConvolutionLayer.cpp
@@ -45,8 +45,8 @@ namespace
{
} // namespace
-template <typename T, typename TB>
-SimpleTensor<T> convolution_layer_nchw(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, SimpleTensor<T> &dst, const PadStrideInfo &info,
+template <typename T, typename TW, typename TB>
+SimpleTensor<T> convolution_layer_nchw(const SimpleTensor<T> &src, const SimpleTensor<TW> &weights, const SimpleTensor<TB> &bias, SimpleTensor<T> &dst, const PadStrideInfo &info,
const Size2D &dilation, unsigned int num_groups)
{
ARM_COMPUTE_ERROR_ON((src.shape()[2] / num_groups) != weights.shape()[2]);
@@ -73,7 +73,6 @@ SimpleTensor<T> convolution_layer_nchw(const SimpleTensor<T> &src, const SimpleT
const int end_xi = output_wh.first * stride_xi;
const int end_yi = output_wh.second * stride_yi;
const int num_batches = src.shape().total_size() / (width_in * height_in * depth_in);
-
for(int r = 0; r < num_batches; ++r)
{
for(int yi = start_yi; yi < start_yi + end_yi; yi += stride_yi)
@@ -100,17 +99,16 @@ SimpleTensor<T> convolution_layer_nchw(const SimpleTensor<T> &src, const SimpleT
offset_in, offset_w, offset_b, offset_out,
xi, yi,
width_in, height_in, (depth_in / num_groups),
- width_weights, height_weights, dilation.x(), dilation.y());
+ width_weights, height_weights, dilation.x(), dilation.y(), ofm);
}
}
}
}
}
-
return dst;
}
-template <typename T, typename TB>
-SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, const TensorShape &output_shape, const PadStrideInfo &info,
+template <typename T, typename TW, typename TB>
+SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor<TW> &weights, const SimpleTensor<TB> &bias, const TensorShape &output_shape, const PadStrideInfo &info,
const Size2D &dilation, unsigned int num_groups, QuantizationInfo out_quant_info)
{
// if no explicit quantization has been set you the same as src
@@ -123,9 +121,9 @@ SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor
if(src.data_layout() == DataLayout::NHWC)
{
- SimpleTensor<T> src_nchw = reference::permute<T>(src, PermutationVector(1U, 2U, 0U));
- SimpleTensor<T> weights_nchw = reference::permute<T>(weights, PermutationVector(1U, 2U, 0U));
- SimpleTensor<T> dst_nchw = reference::permute<T>(dst, PermutationVector(1U, 2U, 0U));
+ SimpleTensor<T> src_nchw = reference::permute<T>(src, PermutationVector(1U, 2U, 0U));
+ SimpleTensor<TW> weights_nchw = reference::permute<TW>(weights, PermutationVector(1U, 2U, 0U));
+ SimpleTensor<T> dst_nchw = reference::permute<T>(dst, PermutationVector(1U, 2U, 0U));
return reference::permute<T>(convolution_layer_nchw(src_nchw, weights_nchw, bias, dst_nchw, info, dilation, num_groups), PermutationVector(2U, 0U, 1U));
}
@@ -141,6 +139,8 @@ template SimpleTensor<half> convolution_layer(const SimpleTensor<half> &src, con
const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups, QuantizationInfo out_quant_info);
template SimpleTensor<uint8_t> convolution_layer(const SimpleTensor<uint8_t> &src, const SimpleTensor<uint8_t> &weights, const SimpleTensor<int32_t> &bias, const TensorShape &output_shape,
const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups, QuantizationInfo out_quant_info);
+template SimpleTensor<uint8_t> convolution_layer(const SimpleTensor<uint8_t> &src, const SimpleTensor<int8_t> &weights, const SimpleTensor<int32_t> &bias, const TensorShape &output_shape,
+ const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups, QuantizationInfo out_quant_info);
} // namespace reference
} // namespace validation
} // namespace test