aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/SubTensor.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/graph/SubTensor.h')
-rw-r--r--arm_compute/graph/SubTensor.h42
1 files changed, 15 insertions, 27 deletions
diff --git a/arm_compute/graph/SubTensor.h b/arm_compute/graph/SubTensor.h
index ace93d20a3..22a0a9e27f 100644
--- a/arm_compute/graph/SubTensor.h
+++ b/arm_compute/graph/SubTensor.h
@@ -25,6 +25,7 @@
#define __ARM_COMPUTE_GRAPH_SUBTENSOR_H__
#include "arm_compute/graph/ITensorAccessor.h"
+#include "arm_compute/graph/ITensorObject.h"
#include "arm_compute/graph/Tensor.h"
#include "arm_compute/graph/Types.h"
#include "support/ToolchainSupport.h"
@@ -36,7 +37,7 @@ namespace arm_compute
namespace graph
{
/** SubTensor class */
-class SubTensor final
+class SubTensor final : public ITensorObject
{
public:
/** Default Constructor */
@@ -55,7 +56,7 @@ public:
* @param[in] coords Starting coordinates of the sub-tensor in the parent tensor
* @param[in] target Execution target
*/
- SubTensor(ITensor *parent, TensorShape tensor_shape, Coordinates coords, TargetHint target);
+ SubTensor(arm_compute::ITensor *parent, TensorShape tensor_shape, Coordinates coords, TargetHint target);
/** Prevent instances of this class from being copied (As this class contains pointers) */
SubTensor(const SubTensor &) = delete;
/** Prevent instances of this class from being copied (As this class contains pointers) */
@@ -67,37 +68,24 @@ public:
/** Default Destructor */
~SubTensor() = default;
- /** Sets the given TensorInfo to the tensor
- *
- * @param[in] info TensorInfo to set
- */
- void set_info(SubTensorInfo &&info);
- /** Returns tensor's TensorInfo
- *
- * @return TensorInfo of the tensor
- */
- const SubTensorInfo &info() const;
- /** Returns a pointer to the internal tensor
- *
- * @return Tensor
- */
- ITensor *tensor();
- /** Return the target that this tensor is pinned on
- *
- * @return Target of the tensor
- */
- TargetHint target() const;
+ // Inherited methods overriden:
+ bool call_accessor() override;
+ bool has_accessor() const override;
+ arm_compute::ITensor *set_target(TargetHint target) override;
+ arm_compute::ITensor *tensor() override;
+ TargetHint target() const override;
+ void allocate() override;
private:
/** Instantiates a sub-tensor */
void instantiate_subtensor();
private:
- TargetHint _target; /**< Target that this tensor is pinned on */
- Coordinates _coords; /**< SubTensor Coordinates */
- SubTensorInfo _info; /**< SubTensor metadata */
- ITensor *_parent; /**< Parent tensor */
- std::unique_ptr<ITensor> _subtensor; /**< SubTensor */
+ TargetHint _target; /**< Target that this tensor is pinned on */
+ TensorShape _tensor_shape; /**< SubTensor shape */
+ Coordinates _coords; /**< SubTensor Coordinates */
+ arm_compute::ITensor *_parent; /**< Parent tensor */
+ std::unique_ptr<arm_compute::ITensor> _subtensor; /**< SubTensor */
};
} // namespace graph
} // namespace arm_compute