diff options
Diffstat (limited to 'arm_compute/Acl.hpp')
-rw-r--r-- | arm_compute/Acl.hpp | 61 |
1 files changed, 61 insertions, 0 deletions
diff --git a/arm_compute/Acl.hpp b/arm_compute/Acl.hpp index a009894438..01f7179c2f 100644 --- a/arm_compute/Acl.hpp +++ b/arm_compute/Acl.hpp @@ -428,6 +428,20 @@ public: _cdesc.strides = nullptr; _cdesc.boffset = 0; } + /** Constructor + * + * @param[in] desc C-type descriptor + */ + explicit TensorDescriptor(const AclTensorDescriptor &desc) + { + _cdesc = desc; + _data_type = detail::as_enum<DataType>(desc.data_type); + _shape.reserve(desc.ndims); + for(int32_t d = 0; d < desc.ndims; ++d) + { + _shape.emplace_back(desc.shape[d]); + } + } /** Get underlying C tensor descriptor * * @return Underlying structure @@ -436,6 +450,29 @@ public: { return &_cdesc; } + /** Operator to compare two TensorDescriptor + * + * @param[in] other The instance to compare against + * + * @return True if two instances have the same shape and data type + */ + bool operator==(const TensorDescriptor &other) + { + bool is_same = true; + + is_same &= _data_type == other._data_type; + is_same &= _shape.size() == other._shape.size(); + + if(is_same) + { + for(uint32_t d = 0; d < _shape.size(); ++d) + { + is_same &= _shape[d] == other._shape[d]; + } + } + + return is_same; + } private: std::vector<int32_t> _shape{}; @@ -524,6 +561,30 @@ public: report_status(st, "[Arm Compute Library] Failed to import external memory to tensor!"); return st; } + /** Get the size of the tensor in byte + * + * @note The size isn't based on allocated memory, but based on information in its descriptor (dimensions, data type, etc.). + * + * @return The size of the tensor in byte + */ + uint64_t get_size() + { + uint64_t size{ 0 }; + const auto st = detail::as_enum<StatusCode>(AclGetTensorSize(_object.get(), &size)); + report_status(st, "[Arm Compute Library] Failed to get the size of the tensor"); + return size; + } + /** Get the descriptor of this tensor + * + * @return The descriptor describing the characteristics of this tensor + */ + TensorDescriptor get_descriptor() + { + AclTensorDescriptor desc; + const auto st = detail::as_enum<StatusCode>(AclGetTensorDescriptor(_object.get(), &desc)); + report_status(st, "[Arm Compute Library] Failed to get the descriptor of the tensor"); + return TensorDescriptor(desc); + } }; /** Tensor pack class |