diff options
Diffstat (limited to 'arm_compute/graph/SubTensor.h')
-rw-r--r-- | arm_compute/graph/SubTensor.h | 42 |
1 files changed, 15 insertions, 27 deletions
diff --git a/arm_compute/graph/SubTensor.h b/arm_compute/graph/SubTensor.h index ace93d20a3..22a0a9e27f 100644 --- a/arm_compute/graph/SubTensor.h +++ b/arm_compute/graph/SubTensor.h @@ -25,6 +25,7 @@ #define __ARM_COMPUTE_GRAPH_SUBTENSOR_H__ #include "arm_compute/graph/ITensorAccessor.h" +#include "arm_compute/graph/ITensorObject.h" #include "arm_compute/graph/Tensor.h" #include "arm_compute/graph/Types.h" #include "support/ToolchainSupport.h" @@ -36,7 +37,7 @@ namespace arm_compute namespace graph { /** SubTensor class */ -class SubTensor final +class SubTensor final : public ITensorObject { public: /** Default Constructor */ @@ -55,7 +56,7 @@ public: * @param[in] coords Starting coordinates of the sub-tensor in the parent tensor * @param[in] target Execution target */ - SubTensor(ITensor *parent, TensorShape tensor_shape, Coordinates coords, TargetHint target); + SubTensor(arm_compute::ITensor *parent, TensorShape tensor_shape, Coordinates coords, TargetHint target); /** Prevent instances of this class from being copied (As this class contains pointers) */ SubTensor(const SubTensor &) = delete; /** Prevent instances of this class from being copied (As this class contains pointers) */ @@ -67,37 +68,24 @@ public: /** Default Destructor */ ~SubTensor() = default; - /** Sets the given TensorInfo to the tensor - * - * @param[in] info TensorInfo to set - */ - void set_info(SubTensorInfo &&info); - /** Returns tensor's TensorInfo - * - * @return TensorInfo of the tensor - */ - const SubTensorInfo &info() const; - /** Returns a pointer to the internal tensor - * - * @return Tensor - */ - ITensor *tensor(); - /** Return the target that this tensor is pinned on - * - * @return Target of the tensor - */ - TargetHint target() const; + // Inherited methods overriden: + bool call_accessor() override; + bool has_accessor() const override; + arm_compute::ITensor *set_target(TargetHint target) override; + arm_compute::ITensor *tensor() override; + TargetHint target() const override; + void allocate() override; private: /** Instantiates a sub-tensor */ void instantiate_subtensor(); private: - TargetHint _target; /**< Target that this tensor is pinned on */ - Coordinates _coords; /**< SubTensor Coordinates */ - SubTensorInfo _info; /**< SubTensor metadata */ - ITensor *_parent; /**< Parent tensor */ - std::unique_ptr<ITensor> _subtensor; /**< SubTensor */ + TargetHint _target; /**< Target that this tensor is pinned on */ + TensorShape _tensor_shape; /**< SubTensor shape */ + Coordinates _coords; /**< SubTensor Coordinates */ + arm_compute::ITensor *_parent; /**< Parent tensor */ + std::unique_ptr<arm_compute::ITensor> _subtensor; /**< SubTensor */ }; } // namespace graph } // namespace arm_compute |