From 245d64c60d0ea30f5080ff53225b5169927e24d6 Mon Sep 17 00:00:00 2001 From: Matthew Bentham Date: Mon, 2 Dec 2019 12:59:43 +0000 Subject: Work in progress of python bindings for Arm NN Not built or tested in any way Signed-off-by: Matthew Bentham Change-Id: Ie7f92b529aa5087130f0c5cc8c17db1581373236 --- .../src/pyarmnn/swig/modules/armnn_tensor.i | 313 +++++++++++++++++++++ 1 file changed, 313 insertions(+) create mode 100644 python/pyarmnn/src/pyarmnn/swig/modules/armnn_tensor.i (limited to 'python/pyarmnn/src/pyarmnn/swig/modules/armnn_tensor.i') diff --git a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_tensor.i b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_tensor.i new file mode 100644 index 0000000000..efa9a16352 --- /dev/null +++ b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_tensor.i @@ -0,0 +1,313 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +%{ +#include "armnn/Tensor.hpp" +%} + +%include +%include + +namespace armnn +{ + +%feature("docstring", +" +Class for holding the shape information of an Arm NN tensor. + +") TensorShape; +class TensorShape +{ +public: + %tensor_shape_typemap(unsigned int numDimensions, const unsigned int* dimensionSizes); + TensorShape(unsigned int numDimensions, const unsigned int* dimensionSizes); + %clear_tensor_shape_typemap(unsigned int numDimensions, const unsigned int* dimensionSizes); + + %feature("docstring", + " + Returns the number of dimensions in this TensorShape. + + Returns: + int: The number of dimensions in this TensorShape. + + ") GetNumDimensions; + unsigned int GetNumDimensions() const; + + %feature("docstring", + " + Returns the total number of elements for a tensor with this TensorShape. + + Returns: + int: The total number of elements for a tensor with this TensorShape. + + ") GetNumElements; + unsigned int GetNumElements() const; + +}; + +%extend TensorShape { + + unsigned int __getitem__(unsigned int i) const { + return $self->operator[](i); + } + void __setitem__(unsigned int i, unsigned int val) { + $self->operator[](i) = val; + } + + std::string __str__() { + std::string dim = "NumDimensions: " + std::to_string($self->GetNumDimensions()); + std::string elm = "NumElements: " + std::to_string($self->GetNumElements()); + + std::string shapeStr = "TensorShape{Shape("; + + auto numDimensions = $self->GetNumDimensions(); + auto sizeDims = $self->GetNumDimensions(); + for (unsigned int i = 0; i < numDimensions; i++) { + shapeStr += std::to_string($self->operator[](i)); + + if (sizeDims - 1 > 0) { + shapeStr += ", "; + } + sizeDims--; + } + shapeStr = shapeStr + "), " + dim + ", " + elm + "}"; + return shapeStr; + } + +} + + +%feature("docstring", +" +Class for holding the tensor information of an Arm NN tensor such as quantization, datatype, shape etc. + +") TensorInfo; +class TensorInfo +{ +public: + TensorInfo(); + + TensorInfo(const TensorInfo& other); + + TensorInfo(const TensorShape& shape, DataType dataType, + float quantizationScale = 0.0f, int32_t quantizationOffset = 0); + + %feature("docstring", + " + Get the tensor shape. + + Return: + TensorShape: Current shape of the tensor. + + ") GetShape; + TensorShape& GetShape(); + + %feature("docstring", + " + Set the tensor shape. Must have the same number of elements as current tensor. + + Args: + newShape (TensorShape): New tensor shape to reshape to. + + ") SetShape; + void SetShape(const TensorShape& newShape); + + %feature("docstring", + " + Returns the number of dimensions in this Tensor. + + Returns: + int: The number of dimensions in this Tensor. + + ") GetNumDimensions; + unsigned int GetNumDimensions() const; + + %feature("docstring", + " + Returns the total number of elements for this Tensor. + + Returns: + int: The total number of elements for this Tensor. + + ") GetNumElements; + unsigned int GetNumElements() const; + + %feature("docstring", + " + Get the tensor datatype. + + Returns: + DataType: Current tensor DataType. + + ") GetDataType; + DataType GetDataType() const; + + %feature("docstring", + " + Set the tensor datatype. + + Args: + type (DataType): DataType to set the tensor to. + + ") SetDataType; + void SetDataType(DataType type); + + %feature("docstring", + " + Get the value of the tensors quantization scale. + + Returns: + float: Tensor quantization scale value. + + ") GetQuantizationScale; + float GetQuantizationScale() const; + + %feature("docstring", + " + Get the value of the tensors quantization offset. + + Returns: + int: Tensor quantization offset value. + + ") GetQuantizationOffset; + int32_t GetQuantizationOffset() const; + + %feature("docstring", + " + Set the value of the tensors quantization scale. + + Args: + scale (float): Scale value to set. + + ") SetQuantizationScale; + void SetQuantizationScale(float scale); + + %feature("docstring", + " + Set the value of the tensors quantization offset. + + Args: + offset (int): Offset value to set. + + ") SetQuantizationOffset; + void SetQuantizationOffset(int32_t offset); + + %feature("docstring", + " + Returns true if the tensor is a quantized data type. + + Returns: + bool: True if the tensor is a quantized data type. + + ") IsQuantized; + bool IsQuantized() const; + + + + %feature("docstring", + " + Check that the types are the same and, if quantize, that the quantization parameters are the same. + + Returns: + bool: True if matched, else False. + + ") IsTypeSpaceMatch; + bool IsTypeSpaceMatch(const TensorInfo& other) const; + + %feature("docstring", + " + Get the number of bytes needed for this tensor. + + Returns: + int: Number of bytes consumed by this tensor. + + ") GetNumBytes; + unsigned int GetNumBytes() const; + +}; + +%extend TensorInfo { + + std::string __str__() { + const std::string tmp = "TensorInfo{DataType: " + std::to_string(static_cast($self->GetDataType())) + + ", IsQuantized: " + std::to_string($self->IsQuantized()) + + ", QuantizationScale: " + std::to_string( $self->GetQuantizationScale()) + + ", QuantizationOffset: " + std::to_string($self->GetQuantizationOffset()) + + ", NumDimensions: " + std::to_string($self->GetNumDimensions()) + + ", NumElements: " + std::to_string($self->GetNumElements()) + "}"; + return tmp; + } + +} + +class Tensor +{ +public: + ~Tensor(); + Tensor(); + Tensor(const Tensor& other); + + %mutable_memory(void* memory); + Tensor(const TensorInfo& info, void* memory); + %clear_mutable_memory(void* memory); + + const TensorInfo& GetInfo() const; + const TensorShape& GetShape() const; + + DataType GetDataType() const; + unsigned int GetNumDimensions() const; + unsigned int GetNumBytes() const; + unsigned int GetNumElements() const; + + /* we want to disable getting the memory area from here - forcing use of get_memory_area() in public api. + void* GetMemoryArea() const;*/ +}; + +%extend Tensor { + + std::string __str__() { + const std::string tmp = "Tensor{DataType: " + std::to_string(static_cast($self->GetDataType())) + + ", NumBytes: " + std::to_string($self->GetNumBytes()) + + ", NumDimensions: " + std::to_string( $self->GetNumDimensions()) + + ", NumElements: " + std::to_string($self->GetNumElements()) + "}"; + return tmp; + } +} + +class ConstTensor +{ +public: + ~ConstTensor(); + ConstTensor(); + ConstTensor(const Tensor& other); + ConstTensor(const ConstTensor& other); + + %const_memory(const void* memory); + ConstTensor(const TensorInfo& info, const void* memory); + %clear_const_memory(const void* memory); + + const TensorInfo& GetInfo() const; + const TensorShape& GetShape() const; + + DataType GetDataType() const; + unsigned int GetNumDimensions() const; + unsigned int GetNumBytes() const; + unsigned int GetNumElements() const; + + /* we want to disable getting the memory area from here - forcing use of get_memory_area() in public api. + void* GetMemoryArea() const;*/ +}; + +%extend ConstTensor { + + std::string __str__() { + const std::string tmp = "ConstTensor{DataType: " + std::to_string(static_cast($self->GetDataType())) + + ", NumBytes: " + std::to_string($self->GetNumBytes()) + + ", NumDimensions: " + std::to_string( $self->GetNumDimensions()) + + ", NumElements: " + std::to_string($self->GetNumElements()) + "}"; + return tmp; + } +} + +} -- cgit v1.2.1