diff options
Diffstat (limited to 'tests/validation/fixtures/ReverseFixture.h')
-rw-r--r-- | tests/validation/fixtures/ReverseFixture.h | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/tests/validation/fixtures/ReverseFixture.h b/tests/validation/fixtures/ReverseFixture.h index 3bf4c3e327..d53945191e 100644 --- a/tests/validation/fixtures/ReverseFixture.h +++ b/tests/validation/fixtures/ReverseFixture.h @@ -86,9 +86,9 @@ protected: { library->fill_tensor_uniform(tensor, 0); } - std::vector<int> generate_random_axis(bool use_negative = false) + std::vector<int32_t> generate_random_axis(bool use_negative = false) { - std::vector<int> axis_v; + std::vector<int32_t> axis_v; if(use_negative) { axis_v = { -1, -2, -3, -4 }; @@ -97,7 +97,7 @@ protected: { axis_v = { 0, 1, 2, 3 }; } - axis_v = std::vector<int>(axis_v.begin(), axis_v.begin() + _num_dims); + axis_v = std::vector<int32_t>(axis_v.begin(), axis_v.begin() + _num_dims); std::mt19937 g(library->seed()); std::shuffle(axis_v.begin(), axis_v.end(), g); @@ -134,7 +134,7 @@ protected: { auto axis_data = AccessorType(axis); auto axis_v = generate_random_axis(use_negative_axis); - std::copy(axis_v.begin(), axis_v.begin() + _num_dims, static_cast<int32_t *>(axis_data.data())); + std::copy(axis_v.begin(), axis_v.begin() + axis_shape.total_size(), static_cast<int32_t *>(axis_data.data())); } // Compute function @@ -152,7 +152,7 @@ protected: // Fill reference fill(src); auto axis_v = generate_random_axis(use_negative_axis); - std::copy(axis_v.begin(), axis_v.begin() + _num_dims, axis.data()); + std::copy(axis_v.begin(), axis_v.begin() + axis_shape.total_size(), axis.data()); return reference::reverse<T>(src, axis, use_inverted_axis); } |