diff options
Diffstat (limited to 'src/common/IOperator.h')
-rw-r--r-- | src/common/IOperator.h | 72 |
1 files changed, 29 insertions, 43 deletions
diff --git a/src/common/IOperator.h b/src/common/IOperator.h index 7fdb443acb..1b65a09e0d 100644 --- a/src/common/IOperator.h +++ b/src/common/IOperator.h @@ -27,6 +27,11 @@ #include "src/common/IContext.h" #include "src/common/IQueue.h" +// TODO: Remove when all functions have been ported +#include "arm_compute/core/experimental/Types.h" +#include "arm_compute/runtime/IOperator.h" +#include "src/common/utils/Validate.h" + #include <vector> struct AclOperator_ @@ -42,21 +47,12 @@ namespace arm_compute { // Forward declarations class ITensorPack; - -/** Structure to capture internally memory requirements */ -struct MemoryInfo +namespace experimental { - MemoryInfo(AclTensorSlot slot_id, size_t size, size_t alignment) noexcept - : slot_id(slot_id), - size(size), - alignment(alignment) - { - } - AclTensorSlot slot_id; - size_t size; - size_t alignment; -}; -using MemoryRequirements = std::vector<MemoryInfo>; +class IOperator; +} // namespace experimental + +using MemoryRequirements = experimental::MemoryRequirements; /** Base class specifying the operator interface */ class IOperator : public AclOperator_ @@ -66,55 +62,45 @@ public: * * @param[in] ctx Context to be used by the operator */ - explicit IOperator(IContext *ctx) - { - this->header.ctx = ctx; - this->header.ctx->inc_ref(); - } - + explicit IOperator(IContext *ctx); /** Destructor */ - virtual ~IOperator() - { - this->header.ctx->dec_ref(); - this->header.type = detail::ObjectType::Invalid; - }; - + virtual ~IOperator(); /** Checks if an operator is valid * * @return True if successful otherwise false */ - bool is_valid() const - { - return this->header.type == detail::ObjectType::Operator; - }; - + bool is_valid() const; /** Run the kernels contained in the function * - * @param[in] queue Queue to run a kernel on + * @param[in] queue Queue to use * @param[in] tensors Vector that contains the tensors to operate on */ - virtual StatusCode run(IQueue &queue, ITensorPack &tensors) = 0; - + virtual StatusCode run(IQueue &queue, ITensorPack &tensors); + /** Run the kernels contained in the function + * + * @param[in] tensors Vector that contains the tensors to operate on + */ + virtual StatusCode run(ITensorPack &tensors); /** Prepare the operator for execution * * Any one off pre-processing step required by the function is handled here * - * @param[in] constants Vector that contains the constants tensors. + * @param[in] tensors Vector that contains the preparation tensors. * * @note Prepare stage might not need all the function's buffers' backing memory to be available in order to execute */ - virtual StatusCode prepare(ITensorPack &constants) - { - ARM_COMPUTE_UNUSED(constants); - return StatusCode::Success; - } - + virtual StatusCode prepare(ITensorPack &tensors); /** Return the memory requirements required by the workspace */ - virtual MemoryRequirements workspace() const + virtual MemoryRequirements workspace() const; + + void set_internal_operator(std::unique_ptr<experimental::IOperator> op) { - return {}; + _op = std::move(op); } + +private: + std::unique_ptr<experimental::IOperator> _op{ nullptr }; }; /** Extract internal representation of an Operator |