aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/ITensorPack.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/ITensorPack.h')
-rw-r--r--arm_compute/core/ITensorPack.h19
1 files changed, 12 insertions, 7 deletions
diff --git a/arm_compute/core/ITensorPack.h b/arm_compute/core/ITensorPack.h
index 8aea880bb6..2f41d4d51e 100644
--- a/arm_compute/core/ITensorPack.h
+++ b/arm_compute/core/ITensorPack.h
@@ -24,9 +24,11 @@
#ifndef ARM_COMPUTE_ITENSORPACK_H
#define ARM_COMPUTE_ITENSORPACK_H
+#include "arm_compute/core/experimental/Types.h"
+
#include <cstddef>
#include <cstdint>
-#include <map>
+#include <unordered_map>
namespace arm_compute
{
@@ -36,19 +38,20 @@ class ITensor;
/** Tensor packing service */
class ITensorPack
{
-private:
+public:
struct PackElement
{
PackElement() = default;
- PackElement(ITensor *tensor)
- : tensor(tensor), ctensor(nullptr)
+ PackElement(int id, ITensor *tensor)
+ : id(id), tensor(tensor), ctensor(nullptr)
{
}
- PackElement(const ITensor *ctensor)
- : tensor(nullptr), ctensor(ctensor)
+ PackElement(int id, const ITensor *ctensor)
+ : id(id), tensor(nullptr), ctensor(ctensor)
{
}
+ int id{ -1 };
ITensor *tensor{ nullptr };
const ITensor *ctensor{ nullptr };
};
@@ -56,6 +59,8 @@ private:
public:
/** Default Constructor */
ITensorPack() = default;
+ /** Initializer list Constructor */
+ ITensorPack(std::initializer_list<PackElement> l);
/** Add tensor to the pack
*
* @param[in] id ID/type of the tensor to add
@@ -102,7 +107,7 @@ public:
bool empty() const;
private:
- std::map<unsigned int, PackElement> _pack{}; /**< Container with the packed tensors */
+ std::unordered_map<int, PackElement> _pack{}; /**< Container with the packed tensors */
};
} // namespace arm_compute
#endif /*ARM_COMPUTE_ITENSORPACK_H */