aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CPP/DepthwiseConvolution.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CPP/DepthwiseConvolution.cpp')
-rw-r--r--tests/validation/CPP/DepthwiseConvolution.cpp7
1 files changed, 4 insertions, 3 deletions
diff --git a/tests/validation/CPP/DepthwiseConvolution.cpp b/tests/validation/CPP/DepthwiseConvolution.cpp
index b57c2686f6..e29d014f77 100644
--- a/tests/validation/CPP/DepthwiseConvolution.cpp
+++ b/tests/validation/CPP/DepthwiseConvolution.cpp
@@ -45,7 +45,7 @@ namespace reference
*
*/
template <typename T>
-SimpleTensor<T> depthwise_convolution(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const TensorShape &dst_shape, const PadStrideInfo &conv_info)
+SimpleTensor<T> depthwise_convolution(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<T> &biases, const TensorShape &dst_shape, const PadStrideInfo &conv_info)
{
// Create reference
SimpleTensor<T> dst{ dst_shape, src.data_type(), 1, src.fixed_point_position() };
@@ -97,7 +97,7 @@ SimpleTensor<T> depthwise_convolution(const SimpleTensor<T> &src, const SimpleTe
}
coords.set(0, x);
coords.set(1, y);
- dst[out_pos++] = saturate_cast<T>(val);
+ dst[out_pos++] = saturate_cast<T>(val + *static_cast<const T *>(biases(Coordinates(z))));
}
}
}
@@ -106,7 +106,8 @@ SimpleTensor<T> depthwise_convolution(const SimpleTensor<T> &src, const SimpleTe
return dst;
}
-template SimpleTensor<float> depthwise_convolution(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const TensorShape &dst_shape, const PadStrideInfo &conv_info);
+template SimpleTensor<float> depthwise_convolution(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const SimpleTensor<float> &biases, const TensorShape &dst_shape,
+ const PadStrideInfo &conv_info);
} // namespace reference
} // namespace validation
} // namespace test