diff options
Diffstat (limited to 'tests/validation/reference/Winograd.h')
-rw-r--r-- | tests/validation/reference/Winograd.h | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/tests/validation/reference/Winograd.h b/tests/validation/reference/Winograd.h index 62e136b09d..29181f1142 100644 --- a/tests/validation/reference/Winograd.h +++ b/tests/validation/reference/Winograd.h @@ -36,14 +36,22 @@ namespace validation { namespace reference { +/** Winograd transform type */ +enum class WinogradTransformType +{ + INPUT, /**< Winograd input transform */ + FILTER, /**< Winograd filter transform */ + OUTPUT /**< Winograd output transform */ +}; + template <typename T> -SimpleTensor<T> winograd_input_transform(const SimpleTensor<T> &src, const TensorShape &dst_shape, const PadStrideInfo &conv_info, const Size2D &kernel_dims); +SimpleTensor<T> winograd_input_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info); template <typename T> -SimpleTensor<T> winograd_filter_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const Size2D &output_tile); +SimpleTensor<T> winograd_filter_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info); template <typename T> -SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const Size2D &kernel_dims, const Size2D &num_tiles); +SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info); } // namespace reference } // namespace validation } // namespace test |