diff options
Diffstat (limited to 'tests/validation/reference/DFT.cpp')
-rw-r--r-- | tests/validation/reference/DFT.cpp | 26 |
1 files changed, 13 insertions, 13 deletions
diff --git a/tests/validation/reference/DFT.cpp b/tests/validation/reference/DFT.cpp index 6ad1b9e150..b3c2c6b0b9 100644 --- a/tests/validation/reference/DFT.cpp +++ b/tests/validation/reference/DFT.cpp @@ -237,11 +237,11 @@ void scale(SimpleTensor<T> &tensor, T scaling_factor) template <typename T> SimpleTensor<T> complex_mul_and_reduce(const SimpleTensor<T> &input, const SimpleTensor<T> &weights) { - const int W = input.shape().x(); - const int H = input.shape().y(); - const int Ci = input.shape().z(); - const int Co = weights.shape()[3]; - const int N = input.shape().total_size() / (W * H * Ci); + const uint32_t W = input.shape().x(); + const uint32_t H = input.shape().y(); + const uint32_t Ci = input.shape().z(); + const uint32_t Co = weights.shape()[3]; + const uint32_t N = input.shape().total_size() / (W * H * Ci); TensorShape output_shape = input.shape(); output_shape.set(2, Co); @@ -250,19 +250,19 @@ SimpleTensor<T> complex_mul_and_reduce(const SimpleTensor<T> &input, const Simpl // MemSet dst memory to zero std::memset(dst.data(), 0, dst.size()); - for(int b = 0; b < N; ++b) + for(uint32_t b = 0; b < N; ++b) { - for(int co = 0; co < Co; ++co) + for(uint32_t co = 0; co < Co; ++co) { - for(int ci = 0; ci < Ci; ++ci) + for(uint32_t ci = 0; ci < Ci; ++ci) { - for(int h = 0; h < H; ++h) + for(uint32_t h = 0; h < H; ++h) { - for(int w = 0; w < W; ++w) + for(uint32_t w = 0; w < W; ++w) { - size_t i_index = w + h * W + ci * H * W + b * H * W * Ci; - size_t w_index = w + h * W + ci * H * W + co * H * W * Ci; - size_t o_index = w + h * W + co * H * W + b * H * W * Co; + const uint32_t i_index = w + h * W + ci * H * W + b * H * W * Ci; + const uint32_t w_index = w + h * W + ci * H * W + co * H * W * Ci; + const uint32_t o_index = w + h * W + co * H * W + b * H * W * Co; const Coordinates i_coords = index2coords(input.shape(), i_index); const Coordinates w_coords = index2coords(weights.shape(), w_index); const Coordinates o_coords = index2coords(dst.shape(), o_index); |