aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/Winograd.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/Winograd.h')
-rw-r--r--tests/validation/reference/Winograd.h14
1 files changed, 11 insertions, 3 deletions
diff --git a/tests/validation/reference/Winograd.h b/tests/validation/reference/Winograd.h
index 62e136b09d..29181f1142 100644
--- a/tests/validation/reference/Winograd.h
+++ b/tests/validation/reference/Winograd.h
@@ -36,14 +36,22 @@ namespace validation
{
namespace reference
{
+/** Winograd transform type */
+enum class WinogradTransformType
+{
+ INPUT, /**< Winograd input transform */
+ FILTER, /**< Winograd filter transform */
+ OUTPUT /**< Winograd output transform */
+};
+
template <typename T>
-SimpleTensor<T> winograd_input_transform(const SimpleTensor<T> &src, const TensorShape &dst_shape, const PadStrideInfo &conv_info, const Size2D &kernel_dims);
+SimpleTensor<T> winograd_input_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info);
template <typename T>
-SimpleTensor<T> winograd_filter_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const Size2D &output_tile);
+SimpleTensor<T> winograd_filter_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info);
template <typename T>
-SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const Size2D &kernel_dims, const Size2D &num_tiles);
+SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info);
} // namespace reference
} // namespace validation
} // namespace test