diff options
Diffstat (limited to 'tests/SimpleTensor.h')
-rw-r--r-- | tests/SimpleTensor.h | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/tests/SimpleTensor.h b/tests/SimpleTensor.h index dd4a8bee2c..f0e9b15021 100644 --- a/tests/SimpleTensor.h +++ b/tests/SimpleTensor.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -280,7 +280,7 @@ SimpleTensor<T>::SimpleTensor(TensorShape shape, DataType data_type, int num_cha _quantization_info(quantization_info), _data_layout(data_layout) { - _buffer = support::cpp14::make_unique<T[]>(num_elements() * this->num_channels()); + _buffer = support::cpp14::make_unique<T[]>(this->_shape.total_size() * _num_channels); } template <typename T> @@ -293,8 +293,8 @@ SimpleTensor<T>::SimpleTensor(const SimpleTensor &tensor) _quantization_info(tensor.quantization_info()), _data_layout(tensor.data_layout()) { - _buffer = support::cpp14::make_unique<T[]>(tensor.num_elements() * num_channels()); - std::copy_n(tensor.data(), num_elements() * num_channels(), _buffer.get()); + _buffer = support::cpp14::make_unique<T[]>(tensor.num_elements() * _num_channels); + std::copy_n(tensor.data(), this->_shape.total_size() * _num_channels, _buffer.get()); } template <typename T> |