aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/ITensorPack.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/ITensorPack.h')
-rw-r--r--arm_compute/core/ITensorPack.h36
1 files changed, 26 insertions, 10 deletions
diff --git a/arm_compute/core/ITensorPack.h b/arm_compute/core/ITensorPack.h
index 36b6aea490..f456c50769 100644
--- a/arm_compute/core/ITensorPack.h
+++ b/arm_compute/core/ITensorPack.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020 Arm Limited.
+ * Copyright (c) 2020-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,8 +24,11 @@
#ifndef ARM_COMPUTE_ITENSORPACK_H
#define ARM_COMPUTE_ITENSORPACK_H
+#include "arm_compute/core/experimental/Types.h"
+
+#include <cstddef>
#include <cstdint>
-#include <map>
+#include <unordered_map>
namespace arm_compute
{
@@ -35,26 +38,27 @@ class ITensor;
/** Tensor packing service */
class ITensorPack
{
-private:
+public:
struct PackElement
{
PackElement() = default;
- PackElement(ITensor *tensor)
- : tensor(tensor), ctensor(nullptr)
+ PackElement(int id, ITensor *tensor) : id(id), tensor(tensor), ctensor(nullptr)
{
}
- PackElement(const ITensor *ctensor)
- : tensor(nullptr), ctensor(ctensor)
+ PackElement(int id, const ITensor *ctensor) : id(id), tensor(nullptr), ctensor(ctensor)
{
}
- ITensor *tensor{ nullptr };
- const ITensor *ctensor{ nullptr };
+ int id{-1};
+ ITensor *tensor{nullptr};
+ const ITensor *ctensor{nullptr};
};
public:
/** Default Constructor */
ITensorPack() = default;
+ /** Initializer list Constructor */
+ ITensorPack(std::initializer_list<PackElement> l);
/** Add tensor to the pack
*
* @param[in] id ID/type of the tensor to add
@@ -68,6 +72,13 @@ public:
* @param[in] tensor Tensor to add
*/
void add_tensor(int id, const ITensor *tensor);
+
+ /** Add const tensor to the pack
+ *
+ * @param[in] id ID/type of the tensor to add
+ * @param[in] tensor Tensor to add
+ */
+ void add_const_tensor(int id, const ITensor *tensor);
/** Get tensor of a given id from the pac
*
* @param[in] id ID of tensor to extract
@@ -82,6 +93,11 @@ public:
* @return The pointer to the tensor if exist and is const else nullptr
*/
const ITensor *get_const_tensor(int id) const;
+ /** Remove the tensor stored with the given id
+ *
+ * @param[in] id ID of tensor to remove
+ */
+ void remove_tensor(int id);
/** Pack size accessor
*
* @return Number of tensors registered to the pack
@@ -94,7 +110,7 @@ public:
bool empty() const;
private:
- std::map<unsigned int, PackElement> _pack{}; /**< Container with the packed tensors */
+ std::unordered_map<int, PackElement> _pack{}; /**< Container with the packed tensors */
};
} // namespace arm_compute
#endif /*ARM_COMPUTE_ITENSORPACK_H */