aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/Tensor.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/graph/Tensor.h')
-rw-r--r--arm_compute/graph/Tensor.h124
1 files changed, 58 insertions, 66 deletions
diff --git a/arm_compute/graph/Tensor.h b/arm_compute/graph/Tensor.h
index e5821dc812..5199ac2328 100644
--- a/arm_compute/graph/Tensor.h
+++ b/arm_compute/graph/Tensor.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,99 +24,91 @@
#ifndef __ARM_COMPUTE_GRAPH_TENSOR_H__
#define __ARM_COMPUTE_GRAPH_TENSOR_H__
-#include "arm_compute/graph/ITensorAccessor.h"
-#include "arm_compute/graph/ITensorObject.h"
#include "arm_compute/graph/Types.h"
-#include "support/ToolchainSupport.h"
+
+#include "arm_compute/graph/ITensorAccessor.h"
+#include "arm_compute/graph/ITensorHandle.h"
+#include "arm_compute/graph/TensorDescriptor.h"
#include <memory>
+#include <set>
namespace arm_compute
{
namespace graph
{
-/** Tensor class */
-class Tensor final : public ITensorObject
+/** Tensor object **/
+class Tensor final
{
public:
- /** Constructor
+ /** Default constructor
*
- * @param[in] info Tensor info to use
+ * @param[in] id Tensor ID
+ * @param[in] desc Tensor information
*/
- Tensor(TensorInfo &&info);
- /** Constructor
+ Tensor(TensorID id, TensorDescriptor desc);
+ /** Tensor ID accessor
*
- * @param[in] accessor Tensor accessor
+ * @return Tensor ID
*/
- template <typename AccessorType>
- Tensor(std::unique_ptr<AccessorType> accessor)
- : _target(TargetHint::DONT_CARE), _info(), _accessor(std::move(accessor)), _tensor(nullptr)
- {
- }
- /** Constructor
+ TensorID id() const;
+ /** TensorInfo metadata accessor
*
- * @param[in] accessor Tensor accessor
+ * @return Tensor descriptor metadata
*/
- template <typename AccessorType>
- Tensor(AccessorType &&accessor)
- : _target(TargetHint::DONT_CARE), _info(), _accessor(arm_compute::support::cpp14::make_unique<AccessorType>(std::forward<AccessorType>(accessor))), _tensor(nullptr)
- {
- }
- /** Constructor
+ TensorDescriptor &desc();
+ /** TensorInfo metadata accessor
*
- * @param[in] info Tensor info to use
- * @param[in] accessor Tensor accessor
+ * @return Tensor descriptor metadata
*/
- template <typename AccessorType>
- Tensor(TensorInfo &&info, std::unique_ptr<AccessorType> &&accessor)
- : _target(TargetHint::DONT_CARE), _info(info), _accessor(std::move(accessor)), _tensor(nullptr)
- {
- }
- /** Constructor
+ const TensorDescriptor &desc() const;
+ /** Sets the backend tensor
*
- * @param[in] info Tensor info to use
- * @param[in] accessor Tensor accessor
+ * @param[in] backend_tensor Backend tensor to set
*/
- template <typename AccessorType>
- Tensor(TensorInfo &&info, AccessorType &&accessor)
- : _target(TargetHint::DONT_CARE), _info(info), _accessor(arm_compute::support::cpp14::make_unique<AccessorType>(std::forward<AccessorType>(accessor))), _tensor(nullptr)
- {
- }
- /** Default Destructor */
- ~Tensor() = default;
- /** Move Constructor
+ void set_handle(std::unique_ptr<ITensorHandle> backend_tensor);
+ /** Backend tensor handle accessor
*
- * @param[in] src Tensor to move
+ * @return Backend tensor handle
*/
- Tensor(Tensor &&src) noexcept;
-
- /** Sets the given TensorInfo to the tensor
+ ITensorHandle *handle();
+ /** Sets the backend tensor accessor
*
- * @param[in] info TensorInfo to set
+ * @param[in] accessor Accessor to set
*/
- void set_info(TensorInfo &&info);
- /** Returns tensor's TensorInfo
+ void set_accessor(std::unique_ptr<ITensorAccessor> accessor);
+ /** Backend tensor accessor
*
- * @return TensorInfo of the tensor
+ * @return Backend tensor accessor
*/
- const TensorInfo &info() const;
- /** Allocates and fills the tensor if needed */
- void allocate_and_fill_if_needed();
-
- // Inherited methods overriden:
- bool call_accessor() override;
- bool has_accessor() const override;
- arm_compute::ITensor *set_target(TargetHint target) override;
- arm_compute::ITensor *tensor() override;
- const arm_compute::ITensor *tensor() const override;
- TargetHint target() const override;
- void allocate() override;
+ ITensorAccessor *accessor();
+ /** Calls accessor on tensor
+ *
+ * @return True if the accessor was called else false
+ */
+ bool call_accessor();
+ /** Binds the tensor with an edge
+ *
+ * @param[in] eid Edge ID that is bound to the tensor
+ */
+ void bind_edge(EdgeID eid);
+ /** Unbinds an edge from a tensor
+ *
+ * @param[in] eid Edge to unbind
+ */
+ void unbind_edge(EdgeID eid);
+ /** Accessor the edges that are bound with the tensor
+ *
+ * @return Bound edges
+ */
+ const std::set<EdgeID> bound_edges() const;
private:
- TargetHint _target; /**< Target that this tensor is pinned on */
- TensorInfo _info; /**< Tensor metadata */
- std::unique_ptr<ITensorAccessor> _accessor; /**< Tensor Accessor */
- std::unique_ptr<arm_compute::ITensor> _tensor; /**< Tensor */
+ TensorID _id; /**< Tensor id */
+ TensorDescriptor _desc; /**< Tensor metadata */
+ std::unique_ptr<ITensorHandle> _handle; /**< Tensor Handle */
+ std::unique_ptr<ITensorAccessor> _accessor; /**< Tensor Accessor */
+ std::set<EdgeID> _bound_edges; /**< Edges bound to this tensor */
};
} // namespace graph
} // namespace arm_compute