From 82e70a12dc3bf8309c43620c08a2c4ff05a6d13e Mon Sep 17 00:00:00 2001 From: Moritz Pflanzer Date: Tue, 8 Aug 2017 16:20:45 +0100 Subject: COMPMID-415: Improve SimpleTensor and RawTensor Change-Id: I7a5f970b3c04b925682fd9f0ece3254478dc96f7 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/83343 Reviewed-by: Anthony Barbier Tested-by: Kaizen --- tests/RawTensor.cpp | 140 ++++------------------------------------------------ 1 file changed, 10 insertions(+), 130 deletions(-) (limited to 'tests/RawTensor.cpp') diff --git a/tests/RawTensor.cpp b/tests/RawTensor.cpp index e6b320fcb2..bc2747d2a1 100644 --- a/tests/RawTensor.cpp +++ b/tests/RawTensor.cpp @@ -23,48 +23,28 @@ */ #include "RawTensor.h" -#include "Utils.h" - -#include "arm_compute/core/Utils.h" -#include "support/ToolchainSupport.h" - -#include -#include -#include -#include -#include - namespace arm_compute { namespace test { RawTensor::RawTensor(TensorShape shape, Format format, int fixed_point_position) - : _buffer(nullptr), - _shape(shape), - _format(format), - _fixed_point_position(fixed_point_position) + : SimpleTensor(shape, format, fixed_point_position) { - _buffer = support::cpp14::make_unique(size()); + _buffer = support::cpp14::make_unique(SimpleTensor::num_elements() * SimpleTensor::num_channels() * SimpleTensor::element_size()); } RawTensor::RawTensor(TensorShape shape, DataType data_type, int num_channels, int fixed_point_position) - : _buffer(nullptr), - _shape(shape), - _data_type(data_type), - _num_channels(num_channels), - _fixed_point_position(fixed_point_position) + : SimpleTensor(shape, data_type, num_channels, fixed_point_position) { - _buffer = support::cpp14::make_unique(size()); + _buffer = support::cpp14::make_unique(SimpleTensor::num_elements() * SimpleTensor::num_channels() * SimpleTensor::element_size()); } RawTensor::RawTensor(const RawTensor &tensor) - : _buffer(nullptr), - _shape(tensor.shape()), - _format(tensor.format()), - _fixed_point_position(tensor.fixed_point_position()) + : SimpleTensor(tensor.shape(), tensor.data_type(), tensor.num_channels(), tensor.fixed_point_position()) { - _buffer = support::cpp14::make_unique(tensor.size()); - std::copy(tensor.data(), tensor.data() + size(), _buffer.get()); + _format = tensor.format(); + _buffer = support::cpp14::make_unique(num_elements() * num_channels() * element_size()); + std::copy_n(tensor.data(), num_elements() * num_channels() * element_size(), _buffer.get()); } RawTensor &RawTensor::operator=(RawTensor tensor) @@ -74,114 +54,14 @@ RawTensor &RawTensor::operator=(RawTensor tensor) return *this; } -RawTensor::BufferType &RawTensor::operator[](size_t offset) -{ - return _buffer[offset]; -} - -const RawTensor::BufferType &RawTensor::operator[](size_t offset) const -{ - return _buffer[offset]; -} - -TensorShape RawTensor::shape() const -{ - return _shape; -} - -size_t RawTensor::element_size() const -{ - return num_channels() * element_size_from_data_type(data_type()); -} - -int RawTensor::fixed_point_position() const -{ - return _fixed_point_position; -} - -size_t RawTensor::size() const -{ - const size_t size = std::accumulate(_shape.cbegin(), _shape.cend(), 1, std::multiplies()); - return size * element_size(); -} - -Format RawTensor::format() const -{ - return _format; -} - -DataType RawTensor::data_type() const -{ - if(_format != Format::UNKNOWN) - { - return data_type_from_format(_format); - } - else - { - return _data_type; - } -} - -int RawTensor::num_channels() const -{ - switch(_format) - { - case Format::U8: - case Format::S16: - case Format::U16: - case Format::S32: - case Format::U32: - case Format::F32: - return 1; - case Format::RGB888: - return 3; - case Format::UNKNOWN: - return _num_channels; - default: - ARM_COMPUTE_ERROR("NOT SUPPORTED!"); - } -} - -int RawTensor::num_elements() const -{ - return _shape.total_size(); -} - -PaddingSize RawTensor::padding() const -{ - return PaddingSize(0); -} - -const RawTensor::BufferType *RawTensor::data() const -{ - return _buffer.get(); -} - -RawTensor::BufferType *RawTensor::data() -{ - return _buffer.get(); -} - -const RawTensor::BufferType *RawTensor::operator()(const Coordinates &coord) const +const void *RawTensor::operator()(const Coordinates &coord) const { return _buffer.get() + coord2index(_shape, coord) * element_size(); } -RawTensor::BufferType *RawTensor::operator()(const Coordinates &coord) +void *RawTensor::operator()(const Coordinates &coord) { return _buffer.get() + coord2index(_shape, coord) * element_size(); } - -void swap(RawTensor &tensor1, RawTensor &tensor2) -{ - // Use unqualified call to swap to enable ADL. But make std::swap available - // as backup. - using std::swap; - swap(tensor1._shape, tensor2._shape); - swap(tensor1._format, tensor2._format); - swap(tensor1._data_type, tensor2._data_type); - swap(tensor1._num_channels, tensor2._num_channels); - swap(tensor1._buffer, tensor2._buffer); -} } // namespace test } // namespace arm_compute -- cgit v1.2.1