diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2017-11-10 18:14:06 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:35:24 +0000 |
commit | 283c1790da45ab562ecfb2aa7741297191886d85 (patch) | |
tree | 45956bb79167e17aa634fd5f4d05c68ba059274c /src/core | |
parent | 624b77859dc9d0618056dad66833b9c37033337b (diff) | |
download | ComputeLibrary-283c1790da45ab562ecfb2aa7741297191886d85.tar.gz |
COMPMID-676: Rework TensorInfo building
Change-Id: Ic98f64ffe30739437a1fe31ef98d83ee900741e3
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/95512
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core')
-rw-r--r-- | src/core/CL/kernels/CLSoftmaxLayerKernel.cpp | 6 | ||||
-rw-r--r-- | src/core/SubTensorInfo.cpp | 11 | ||||
-rw-r--r-- | src/core/TensorInfo.cpp | 27 |
3 files changed, 32 insertions, 12 deletions
diff --git a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp b/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp index af4fd88593..5331f40838 100644 --- a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp +++ b/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp @@ -386,11 +386,7 @@ void CLLogits1DNormKernel::configure(const ICLTensor *input, const ICLTensor *su // Output auto initialization if not yet initialized auto_init_if_empty(*output->info(), - input->info()->tensor_shape(), - 1, - output_data_type, - input->info()->fixed_point_position(), - allowed_quantization_info); + input->info()->clone()->set_data_type(output_data_type).set_quantization_info(allowed_quantization_info)); ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output); if(!is_quantized_asymmetric) diff --git a/src/core/SubTensorInfo.cpp b/src/core/SubTensorInfo.cpp index 878283bd8e..4c558bfae9 100644 --- a/src/core/SubTensorInfo.cpp +++ b/src/core/SubTensorInfo.cpp @@ -26,6 +26,7 @@ #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/Validate.h" +#include "support/ToolchainSupport.h" using namespace arm_compute; @@ -34,7 +35,7 @@ SubTensorInfo::SubTensorInfo() { } -SubTensorInfo::SubTensorInfo(ITensorInfo *parent, const TensorShape &tensor_shape, const Coordinates &coords) +SubTensorInfo::SubTensorInfo(ITensorInfo *parent, TensorShape tensor_shape, Coordinates coords) : _parent(parent), _tensor_shape(tensor_shape), _coords(coords), _valid_region{ Coordinates(), _tensor_shape } { ARM_COMPUTE_ERROR_ON(parent == nullptr); @@ -50,7 +51,12 @@ SubTensorInfo::SubTensorInfo(ITensorInfo *parent, const TensorShape &tensor_shap _valid_region = ValidRegion{ coordinates, _tensor_shape }; } -void SubTensorInfo::set_tensor_shape(TensorShape shape) +std::unique_ptr<ITensorInfo> SubTensorInfo::clone() const +{ + return support::cpp14::make_unique<SubTensorInfo>(*this); +} + +ITensorInfo &SubTensorInfo::set_tensor_shape(TensorShape shape) { ARM_COMPUTE_ERROR_ON(_parent == nullptr); // Check if subtensor is valid if parent is configured @@ -59,6 +65,7 @@ void SubTensorInfo::set_tensor_shape(TensorShape shape) ARM_COMPUTE_ERROR_ON_INVALID_SUBTENSOR(_parent->tensor_shape(), _coords, shape); } _tensor_shape = shape; + return *this; } bool SubTensorInfo::extend_padding(const PaddingSize &padding) diff --git a/src/core/TensorInfo.cpp b/src/core/TensorInfo.cpp index f3cd776497..a49b7b7e02 100644 --- a/src/core/TensorInfo.cpp +++ b/src/core/TensorInfo.cpp @@ -28,6 +28,7 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Validate.h" +#include "support/ToolchainSupport.h" using namespace arm_compute; @@ -314,19 +315,26 @@ bool TensorInfo::extend_padding(const PaddingSize &padding) return updated; } -void TensorInfo::set_data_type(DataType data_type) +std::unique_ptr<ITensorInfo> TensorInfo::clone() const +{ + return support::cpp14::make_unique<TensorInfo>(*this); +} + +ITensorInfo &TensorInfo::set_data_type(DataType data_type) { _data_type = data_type; _format = Format::UNKNOWN; + return *this; } -void TensorInfo::set_num_channels(int num_channels) +ITensorInfo &TensorInfo::set_num_channels(int num_channels) { _num_channels = num_channels; _format = Format::UNKNOWN; + return *this; } -void TensorInfo::set_format(Format format) +ITensorInfo &TensorInfo::set_format(Format format) { _format = format; @@ -340,9 +348,10 @@ void TensorInfo::set_format(Format format) ARM_COMPUTE_ERROR_ON(num_channels_from_format(format) != _num_channels); ARM_COMPUTE_ERROR_ON(data_type_from_format(format) != _data_type); } + return *this; } -void TensorInfo::set_tensor_shape(TensorShape shape) +ITensorInfo &TensorInfo::set_tensor_shape(TensorShape shape) { _tensor_shape = shape; _offset_first_element_in_bytes = 0; @@ -361,13 +370,21 @@ void TensorInfo::set_tensor_shape(TensorShape shape) Coordinates coordinates; coordinates.set_num_dimensions(_tensor_shape.num_dimensions()); _valid_region = ValidRegion{ coordinates, _tensor_shape }; + return *this; } -void TensorInfo::set_fixed_point_position(int fixed_point_position) +ITensorInfo &TensorInfo::set_fixed_point_position(int fixed_point_position) { ARM_COMPUTE_ERROR_ON(_data_type == DataType::QS8 && (fixed_point_position < 1 || fixed_point_position > 6)); ARM_COMPUTE_ERROR_ON(_data_type == DataType::QS16 && (fixed_point_position < 1 || fixed_point_position > 14)); _fixed_point_position = fixed_point_position; + return *this; +} + +ITensorInfo &TensorInfo::set_quantization_info(QuantizationInfo quantization_info) +{ + _quantization_info = quantization_info; + return *this; } size_t TensorInfo::offset_element_in_bytes(const Coordinates &pos) const |