aboutsummaryrefslogtreecommitdiff
path: root/tests/SimpleTensor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/SimpleTensor.h')
-rw-r--r--tests/SimpleTensor.h9
1 files changed, 5 insertions, 4 deletions
diff --git a/tests/SimpleTensor.h b/tests/SimpleTensor.h
index 6091991e66..902f5b51b5 100644
--- a/tests/SimpleTensor.h
+++ b/tests/SimpleTensor.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017, 2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -191,7 +191,8 @@ SimpleTensor<T>::SimpleTensor(TensorShape shape, Format format, int fixed_point_
_fixed_point_position(fixed_point_position),
_quantization_info()
{
- _buffer = support::cpp14::make_unique<T[]>(num_elements() * num_channels());
+ _num_channels = num_channels();
+ _buffer = support::cpp14::make_unique<T[]>(num_elements() * _num_channels);
}
template <typename T>
@@ -338,13 +339,13 @@ T *SimpleTensor<T>::data()
template <typename T>
const void *SimpleTensor<T>::operator()(const Coordinates &coord) const
{
- return _buffer.get() + coord2index(_shape, coord);
+ return _buffer.get() + coord2index(_shape, coord) * _num_channels;
}
template <typename T>
void *SimpleTensor<T>::operator()(const Coordinates &coord)
{
- return _buffer.get() + coord2index(_shape, coord);
+ return _buffer.get() + coord2index(_shape, coord) * _num_channels;
}
template <typename U>