aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/ScatterLayer.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/ScatterLayer.h')
-rw-r--r--tests/validation/reference/ScatterLayer.h4
1 files changed, 2 insertions, 2 deletions
diff --git a/tests/validation/reference/ScatterLayer.h b/tests/validation/reference/ScatterLayer.h
index dc441a8894..97d5e70b0d 100644
--- a/tests/validation/reference/ScatterLayer.h
+++ b/tests/validation/reference/ScatterLayer.h
@@ -37,10 +37,10 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<uint32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
+SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<int32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
template <typename T>
-SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<uint32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
+SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<int32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
} // namespace reference
} // namespace validation
} // namespace test