aboutsummaryrefslogtreecommitdiff
path: root/src/graph/SubTensor.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/SubTensor.cpp')
-rw-r--r--src/graph/SubTensor.cpp40
1 files changed, 25 insertions, 15 deletions
diff --git a/src/graph/SubTensor.cpp b/src/graph/SubTensor.cpp
index abf8506c33..da8de956d7 100644
--- a/src/graph/SubTensor.cpp
+++ b/src/graph/SubTensor.cpp
@@ -27,7 +27,9 @@
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/runtime/CL/CLSubTensor.h"
+#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/SubTensor.h"
+#include "arm_compute/runtime/Tensor.h"
#include "utils/TypePrinter.h"
using namespace arm_compute::graph;
@@ -35,7 +37,7 @@ using namespace arm_compute::graph;
namespace
{
template <typename SubTensorType, typename ParentTensorType>
-std::unique_ptr<ITensor> initialise_subtensor(ITensor *parent, TensorShape shape, Coordinates coords)
+std::unique_ptr<arm_compute::ITensor> initialise_subtensor(arm_compute::ITensor *parent, TensorShape shape, Coordinates coords)
{
auto ptensor = dynamic_cast<ParentTensorType *>(parent);
auto subtensor = arm_compute::support::cpp14::make_unique<SubTensorType>(ptensor, shape, coords);
@@ -44,41 +46,44 @@ std::unique_ptr<ITensor> initialise_subtensor(ITensor *parent, TensorShape shape
} // namespace
SubTensor::SubTensor()
- : _target(TargetHint::DONT_CARE), _coords(), _info(), _parent(nullptr), _subtensor(nullptr)
+ : _target(TargetHint::DONT_CARE), _tensor_shape(), _coords(), _parent(nullptr), _subtensor(nullptr)
{
}
SubTensor::SubTensor(Tensor &parent, TensorShape tensor_shape, Coordinates coords)
- : _target(TargetHint::DONT_CARE), _coords(coords), _info(), _parent(nullptr), _subtensor(nullptr)
+ : _target(TargetHint::DONT_CARE), _tensor_shape(tensor_shape), _coords(coords), _parent(nullptr), _subtensor(nullptr)
{
ARM_COMPUTE_ERROR_ON(parent.tensor() == nullptr);
_parent = parent.tensor();
- _info = SubTensorInfo(parent.tensor()->info(), tensor_shape, coords);
_target = parent.target();
instantiate_subtensor();
}
-SubTensor::SubTensor(ITensor *parent, TensorShape tensor_shape, Coordinates coords, TargetHint target)
- : _target(target), _coords(coords), _info(), _parent(parent), _subtensor(nullptr)
+SubTensor::SubTensor(arm_compute::ITensor *parent, TensorShape tensor_shape, Coordinates coords, TargetHint target)
+ : _target(target), _tensor_shape(tensor_shape), _coords(coords), _parent(parent), _subtensor(nullptr)
{
ARM_COMPUTE_ERROR_ON(parent == nullptr);
- _info = SubTensorInfo(parent->info(), tensor_shape, coords);
-
instantiate_subtensor();
}
-void SubTensor::set_info(SubTensorInfo &&info)
+bool SubTensor::call_accessor()
+{
+ return true;
+}
+
+bool SubTensor::has_accessor() const
{
- _info = info;
+ return false;
}
-const SubTensorInfo &SubTensor::info() const
+arm_compute::ITensor *SubTensor::set_target(TargetHint target)
{
- return _info;
+ ARM_COMPUTE_ERROR_ON(target != _target);
+ return (target == _target) ? _subtensor.get() : nullptr;
}
-ITensor *SubTensor::tensor()
+arm_compute::ITensor *SubTensor::tensor()
{
return _subtensor.get();
}
@@ -88,15 +93,20 @@ TargetHint SubTensor::target() const
return _target;
}
+void SubTensor::allocate()
+{
+ // NOP for sub-tensors
+}
+
void SubTensor::instantiate_subtensor()
{
switch(_target)
{
case TargetHint::OPENCL:
- _subtensor = initialise_subtensor<arm_compute::CLSubTensor, arm_compute::ICLTensor>(_parent, _info.tensor_shape(), _coords);
+ _subtensor = initialise_subtensor<arm_compute::CLSubTensor, arm_compute::ICLTensor>(_parent, _tensor_shape, _coords);
break;
case TargetHint::NEON:
- _subtensor = initialise_subtensor<arm_compute::SubTensor, arm_compute::ITensor>(_parent, _info.tensor_shape(), _coords);
+ _subtensor = initialise_subtensor<arm_compute::SubTensor, arm_compute::ITensor>(_parent, _tensor_shape, _coords);
break;
default:
ARM_COMPUTE_ERROR("Invalid TargetHint");