diff options
Diffstat (limited to 'src/graph/Tensor.cpp')
-rw-r--r-- | src/graph/Tensor.cpp | 141 |
1 files changed, 46 insertions, 95 deletions
diff --git a/src/graph/Tensor.cpp b/src/graph/Tensor.cpp index 4db79e93ad..47fb5c65bc 100644 --- a/src/graph/Tensor.cpp +++ b/src/graph/Tensor.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -23,138 +23,89 @@ */ #include "arm_compute/graph/Tensor.h" -#include "arm_compute/core/Error.h" -#include "arm_compute/core/Helpers.h" -#include "arm_compute/core/Validate.h" -#include "arm_compute/runtime/CL/CLTensor.h" -#include "arm_compute/runtime/Tensor.h" -#include "utils/TypePrinter.h" - -using namespace arm_compute::graph; - -namespace +namespace arm_compute { -template <typename TensorType> -std::unique_ptr<arm_compute::ITensor> initialise_tensor(TensorInfo &info) +namespace graph { - auto tensor = arm_compute::support::cpp14::make_unique<TensorType>(); - tensor->allocator()->init(info); - return std::move(tensor); -} - -template <typename TensorType> -void tensor_allocate(arm_compute::ITensor &tensor) +Tensor::Tensor(TensorID id, TensorDescriptor desc) + : _id(id), _desc(desc), _handle(nullptr), _accessor(nullptr), _bound_edges() { - auto itensor = dynamic_cast<TensorType *>(&tensor); - ARM_COMPUTE_ERROR_ON_NULLPTR(itensor); - itensor->allocator()->allocate(); } -} // namespace -Tensor::Tensor(TensorInfo &&info) - : _target(TargetHint::DONT_CARE), _info(info), _accessor(nullptr), _tensor(nullptr) +TensorID Tensor::id() const { + return _id; } -Tensor::Tensor(Tensor &&src) noexcept - : _target(src._target), - _info(std::move(src._info)), - _accessor(std::move(src._accessor)), - _tensor(std::move(src._tensor)) +TensorDescriptor &Tensor::desc() { + return _desc; } -void Tensor::set_info(TensorInfo &&info) +const TensorDescriptor &Tensor::desc() const { - _info = info; -} - -bool Tensor::call_accessor() -{ - ARM_COMPUTE_ERROR_ON_NULLPTR(_accessor.get()); - auto cl_tensor = dynamic_cast<arm_compute::CLTensor *>(_tensor.get()); - if(cl_tensor != nullptr && cl_tensor->buffer() == nullptr) - { - cl_tensor->map(); - } - bool retval = _accessor->access_tensor(*_tensor); - if(cl_tensor != nullptr) - { - cl_tensor->unmap(); - } - return retval; + return _desc; } -bool Tensor::has_accessor() const +void Tensor::set_handle(std::unique_ptr<ITensorHandle> backend_tensor) { - return (_accessor != nullptr); + _handle = std::move(backend_tensor); } -arm_compute::ITensor *Tensor::tensor() +ITensorHandle *Tensor::handle() { - return _tensor.get(); + return _handle.get(); } -const arm_compute::ITensor *Tensor::tensor() const +void Tensor::set_accessor(std::unique_ptr<ITensorAccessor> accessor) { - return _tensor.get(); + _accessor = std::move(accessor); } -const TensorInfo &Tensor::info() const +ITensorAccessor *Tensor::accessor() { - return _info; + return _accessor.get(); } -arm_compute::ITensor *Tensor::set_target(TargetHint target) +bool Tensor::call_accessor() { - if(_tensor != nullptr) + // Early exit guard + if(!_accessor || !_handle) { - ARM_COMPUTE_ERROR_ON(target != _target); + return false; } - else + + // Map tensor + _handle->map(true); + + // Return in case of null backend buffer + if(_handle->tensor().buffer() == nullptr) { - switch(target) - { - case TargetHint::OPENCL: - _tensor = initialise_tensor<arm_compute::CLTensor>(_info); - break; - case TargetHint::NEON: - _tensor = initialise_tensor<arm_compute::Tensor>(_info); - break; - default: - ARM_COMPUTE_ERROR("Invalid TargetHint"); - } - _target = target; + return false; } - return _tensor.get(); + + // Call accessor + _accessor->access_tensor(_handle->tensor()); + + // Unmap tensor + _handle->unmap(); + + return true; } -void Tensor::allocate() +void Tensor::bind_edge(EdgeID eid) { - ARM_COMPUTE_ERROR_ON_NULLPTR(_tensor.get()); - switch(_target) - { - case TargetHint::OPENCL: - tensor_allocate<arm_compute::CLTensor>(*_tensor); - break; - case TargetHint::NEON: - tensor_allocate<arm_compute::Tensor>(*_tensor); - break; - default: - ARM_COMPUTE_ERROR("Invalid TargetHint"); - } + _bound_edges.insert(eid); } -void Tensor::allocate_and_fill_if_needed() +void Tensor::unbind_edge(EdgeID eid) { - allocate(); - if(_accessor != nullptr) - { - call_accessor(); - } + _bound_edges.erase(eid); } -TargetHint Tensor::target() const +const std::set<EdgeID> Tensor::bound_edges() const { - return _target; + return _bound_edges; } +} // namespace graph +} // namespace arm_compute
\ No newline at end of file |