aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/CpuTensorHandle.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/CpuTensorHandle.hpp')
-rw-r--r--src/backends/backendsCommon/CpuTensorHandle.hpp21
1 files changed, 7 insertions, 14 deletions
diff --git a/src/backends/backendsCommon/CpuTensorHandle.hpp b/src/backends/backendsCommon/CpuTensorHandle.hpp
index dd6413f2e7..5fefc125c1 100644
--- a/src/backends/backendsCommon/CpuTensorHandle.hpp
+++ b/src/backends/backendsCommon/CpuTensorHandle.hpp
@@ -16,6 +16,10 @@
namespace armnn
{
+// Get a TensorShape representing the strides (in bytes) for each dimension
+// of a tensor, assuming fully packed data with no padding
+TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo);
+
// Abstract tensor handles wrapping a CPU-readable region of memory, interpreting it as tensor data.
class ConstCpuTensorHandle : public ITensorHandle
{
@@ -41,18 +45,7 @@ public:
TensorShape GetStrides() const override
{
- TensorShape shape(m_TensorInfo.GetShape());
- auto size = GetDataTypeSize(m_TensorInfo.GetDataType());
- auto runningSize = size;
- std::vector<unsigned int> strides(shape.GetNumDimensions());
- auto lastIdx = shape.GetNumDimensions()-1;
- for (unsigned int i=0; i < lastIdx ; i++)
- {
- strides[lastIdx-i] = runningSize;
- runningSize *= shape[lastIdx-i];
- }
- strides[0] = runningSize;
- return TensorShape(shape.GetNumDimensions(), strides.data());
+ return GetUnpaddedTensorStrides(m_TensorInfo);
}
TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
@@ -63,8 +56,8 @@ protected:
private:
// Only used for testing
- void CopyOutTo(void *) const override {}
- void CopyInFrom(const void*) override {}
+ void CopyOutTo(void *) const override { BOOST_ASSERT_MSG(false, "Unimplemented"); }
+ void CopyInFrom(const void*) override { BOOST_ASSERT_MSG(false, "Unimplemented"); }
ConstCpuTensorHandle(const ConstCpuTensorHandle& other) = delete;
ConstCpuTensorHandle& operator=(const ConstCpuTensorHandle& other) = delete;