diff options
Diffstat (limited to 'reference_model/src/ops/scatter_gather.cc')
-rw-r--r-- | reference_model/src/ops/scatter_gather.cc | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index bcd8ce5..80b6c58 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> OpGather<Dtype>::OpGather(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -29,11 +29,11 @@ OpGather<Dtype>::OpGather(SubgraphTraverser* sgt_, setRequiredOperands(2, 1); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> OpGather<Dtype>::~OpGather() {} -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpGather<Dtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -96,7 +96,7 @@ int OpGather<Dtype>::checkTensorAttributes() return 0; } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpGather<Dtype>::eval() { for (int32_t n = 0; n < N; n++) @@ -116,7 +116,7 @@ int OpGather<Dtype>::eval() return GraphNode::eval(); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> OpScatter<Dtype>::OpScatter(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -125,11 +125,11 @@ OpScatter<Dtype>::OpScatter(SubgraphTraverser* sgt_, setRequiredOperands(3, 1); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> OpScatter<Dtype>::~OpScatter() {} -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpScatter<Dtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -199,7 +199,7 @@ int OpScatter<Dtype>::checkTensorAttributes() return 0; } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpScatter<Dtype>::eval() { // Initializes the output tensor with the input value for values that are unchanged by the scatter operation. @@ -229,6 +229,7 @@ DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16); DEF_INSTANTIATE_ONE_TYPE(OpGather, BF16); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32); +DEF_INSTANTIATE_ONE_TYPE(OpGather, FP64); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16); @@ -236,3 +237,4 @@ DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, BF16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP64); |