aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/Acl.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/Acl.hpp')
-rw-r--r--arm_compute/Acl.hpp61
1 files changed, 61 insertions, 0 deletions
diff --git a/arm_compute/Acl.hpp b/arm_compute/Acl.hpp
index a009894438..01f7179c2f 100644
--- a/arm_compute/Acl.hpp
+++ b/arm_compute/Acl.hpp
@@ -428,6 +428,20 @@ public:
_cdesc.strides = nullptr;
_cdesc.boffset = 0;
}
+ /** Constructor
+ *
+ * @param[in] desc C-type descriptor
+ */
+ explicit TensorDescriptor(const AclTensorDescriptor &desc)
+ {
+ _cdesc = desc;
+ _data_type = detail::as_enum<DataType>(desc.data_type);
+ _shape.reserve(desc.ndims);
+ for(int32_t d = 0; d < desc.ndims; ++d)
+ {
+ _shape.emplace_back(desc.shape[d]);
+ }
+ }
/** Get underlying C tensor descriptor
*
* @return Underlying structure
@@ -436,6 +450,29 @@ public:
{
return &_cdesc;
}
+ /** Operator to compare two TensorDescriptor
+ *
+ * @param[in] other The instance to compare against
+ *
+ * @return True if two instances have the same shape and data type
+ */
+ bool operator==(const TensorDescriptor &other)
+ {
+ bool is_same = true;
+
+ is_same &= _data_type == other._data_type;
+ is_same &= _shape.size() == other._shape.size();
+
+ if(is_same)
+ {
+ for(uint32_t d = 0; d < _shape.size(); ++d)
+ {
+ is_same &= _shape[d] == other._shape[d];
+ }
+ }
+
+ return is_same;
+ }
private:
std::vector<int32_t> _shape{};
@@ -524,6 +561,30 @@ public:
report_status(st, "[Arm Compute Library] Failed to import external memory to tensor!");
return st;
}
+ /** Get the size of the tensor in byte
+ *
+ * @note The size isn't based on allocated memory, but based on information in its descriptor (dimensions, data type, etc.).
+ *
+ * @return The size of the tensor in byte
+ */
+ uint64_t get_size()
+ {
+ uint64_t size{ 0 };
+ const auto st = detail::as_enum<StatusCode>(AclGetTensorSize(_object.get(), &size));
+ report_status(st, "[Arm Compute Library] Failed to get the size of the tensor");
+ return size;
+ }
+ /** Get the descriptor of this tensor
+ *
+ * @return The descriptor describing the characteristics of this tensor
+ */
+ TensorDescriptor get_descriptor()
+ {
+ AclTensorDescriptor desc;
+ const auto st = detail::as_enum<StatusCode>(AclGetTensorDescriptor(_object.get(), &desc));
+ report_status(st, "[Arm Compute Library] Failed to get the descriptor of the tensor");
+ return TensorDescriptor(desc);
+ }
};
/** Tensor pack class