aboutsummaryrefslogtreecommitdiff
path: root/tests/SimpleTensor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/SimpleTensor.h')
-rw-r--r--tests/SimpleTensor.h8
1 files changed, 4 insertions, 4 deletions
diff --git a/tests/SimpleTensor.h b/tests/SimpleTensor.h
index dd4a8bee2c..f0e9b15021 100644
--- a/tests/SimpleTensor.h
+++ b/tests/SimpleTensor.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -280,7 +280,7 @@ SimpleTensor<T>::SimpleTensor(TensorShape shape, DataType data_type, int num_cha
_quantization_info(quantization_info),
_data_layout(data_layout)
{
- _buffer = support::cpp14::make_unique<T[]>(num_elements() * this->num_channels());
+ _buffer = support::cpp14::make_unique<T[]>(this->_shape.total_size() * _num_channels);
}
template <typename T>
@@ -293,8 +293,8 @@ SimpleTensor<T>::SimpleTensor(const SimpleTensor &tensor)
_quantization_info(tensor.quantization_info()),
_data_layout(tensor.data_layout())
{
- _buffer = support::cpp14::make_unique<T[]>(tensor.num_elements() * num_channels());
- std::copy_n(tensor.data(), num_elements() * num_channels(), _buffer.get());
+ _buffer = support::cpp14::make_unique<T[]>(tensor.num_elements() * _num_channels);
+ std::copy_n(tensor.data(), this->_shape.total_size() * _num_channels, _buffer.get());
}
template <typename T>