aboutsummaryrefslogtreecommitdiff
path: root/src/graph/Tensor.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/Tensor.cpp')
-rw-r--r--src/graph/Tensor.cpp141
1 files changed, 46 insertions, 95 deletions
diff --git a/src/graph/Tensor.cpp b/src/graph/Tensor.cpp
index 4db79e93ad..47fb5c65bc 100644
--- a/src/graph/Tensor.cpp
+++ b/src/graph/Tensor.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,138 +23,89 @@
*/
#include "arm_compute/graph/Tensor.h"
-#include "arm_compute/core/Error.h"
-#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/Validate.h"
-#include "arm_compute/runtime/CL/CLTensor.h"
-#include "arm_compute/runtime/Tensor.h"
-#include "utils/TypePrinter.h"
-
-using namespace arm_compute::graph;
-
-namespace
+namespace arm_compute
{
-template <typename TensorType>
-std::unique_ptr<arm_compute::ITensor> initialise_tensor(TensorInfo &info)
+namespace graph
{
- auto tensor = arm_compute::support::cpp14::make_unique<TensorType>();
- tensor->allocator()->init(info);
- return std::move(tensor);
-}
-
-template <typename TensorType>
-void tensor_allocate(arm_compute::ITensor &tensor)
+Tensor::Tensor(TensorID id, TensorDescriptor desc)
+ : _id(id), _desc(desc), _handle(nullptr), _accessor(nullptr), _bound_edges()
{
- auto itensor = dynamic_cast<TensorType *>(&tensor);
- ARM_COMPUTE_ERROR_ON_NULLPTR(itensor);
- itensor->allocator()->allocate();
}
-} // namespace
-Tensor::Tensor(TensorInfo &&info)
- : _target(TargetHint::DONT_CARE), _info(info), _accessor(nullptr), _tensor(nullptr)
+TensorID Tensor::id() const
{
+ return _id;
}
-Tensor::Tensor(Tensor &&src) noexcept
- : _target(src._target),
- _info(std::move(src._info)),
- _accessor(std::move(src._accessor)),
- _tensor(std::move(src._tensor))
+TensorDescriptor &Tensor::desc()
{
+ return _desc;
}
-void Tensor::set_info(TensorInfo &&info)
+const TensorDescriptor &Tensor::desc() const
{
- _info = info;
-}
-
-bool Tensor::call_accessor()
-{
- ARM_COMPUTE_ERROR_ON_NULLPTR(_accessor.get());
- auto cl_tensor = dynamic_cast<arm_compute::CLTensor *>(_tensor.get());
- if(cl_tensor != nullptr && cl_tensor->buffer() == nullptr)
- {
- cl_tensor->map();
- }
- bool retval = _accessor->access_tensor(*_tensor);
- if(cl_tensor != nullptr)
- {
- cl_tensor->unmap();
- }
- return retval;
+ return _desc;
}
-bool Tensor::has_accessor() const
+void Tensor::set_handle(std::unique_ptr<ITensorHandle> backend_tensor)
{
- return (_accessor != nullptr);
+ _handle = std::move(backend_tensor);
}
-arm_compute::ITensor *Tensor::tensor()
+ITensorHandle *Tensor::handle()
{
- return _tensor.get();
+ return _handle.get();
}
-const arm_compute::ITensor *Tensor::tensor() const
+void Tensor::set_accessor(std::unique_ptr<ITensorAccessor> accessor)
{
- return _tensor.get();
+ _accessor = std::move(accessor);
}
-const TensorInfo &Tensor::info() const
+ITensorAccessor *Tensor::accessor()
{
- return _info;
+ return _accessor.get();
}
-arm_compute::ITensor *Tensor::set_target(TargetHint target)
+bool Tensor::call_accessor()
{
- if(_tensor != nullptr)
+ // Early exit guard
+ if(!_accessor || !_handle)
{
- ARM_COMPUTE_ERROR_ON(target != _target);
+ return false;
}
- else
+
+ // Map tensor
+ _handle->map(true);
+
+ // Return in case of null backend buffer
+ if(_handle->tensor().buffer() == nullptr)
{
- switch(target)
- {
- case TargetHint::OPENCL:
- _tensor = initialise_tensor<arm_compute::CLTensor>(_info);
- break;
- case TargetHint::NEON:
- _tensor = initialise_tensor<arm_compute::Tensor>(_info);
- break;
- default:
- ARM_COMPUTE_ERROR("Invalid TargetHint");
- }
- _target = target;
+ return false;
}
- return _tensor.get();
+
+ // Call accessor
+ _accessor->access_tensor(_handle->tensor());
+
+ // Unmap tensor
+ _handle->unmap();
+
+ return true;
}
-void Tensor::allocate()
+void Tensor::bind_edge(EdgeID eid)
{
- ARM_COMPUTE_ERROR_ON_NULLPTR(_tensor.get());
- switch(_target)
- {
- case TargetHint::OPENCL:
- tensor_allocate<arm_compute::CLTensor>(*_tensor);
- break;
- case TargetHint::NEON:
- tensor_allocate<arm_compute::Tensor>(*_tensor);
- break;
- default:
- ARM_COMPUTE_ERROR("Invalid TargetHint");
- }
+ _bound_edges.insert(eid);
}
-void Tensor::allocate_and_fill_if_needed()
+void Tensor::unbind_edge(EdgeID eid)
{
- allocate();
- if(_accessor != nullptr)
- {
- call_accessor();
- }
+ _bound_edges.erase(eid);
}
-TargetHint Tensor::target() const
+const std::set<EdgeID> Tensor::bound_edges() const
{
- return _target;
+ return _bound_edges;
}
+} // namespace graph
+} // namespace arm_compute \ No newline at end of file