diff options
author | Tai Ly <tai.ly@arm.com> | 2023-03-28 22:06:56 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2023-05-05 19:23:15 +0000 |
commit | a4d748b08accce06fab93e2d2b96e499b35ae89b (patch) | |
tree | 20a3957e1f45f65f35d5d67ecce1618659e388f0 /reference_model/src/ops/scatter_gather.cc | |
parent | 0c71686875618b2e11290273b7a05b88ef8a8aae (diff) | |
download | reference_model-a4d748b08accce06fab93e2d2b96e499b35ae89b.tar.gz |
[reference model] Add precise mode
This adds --precise_mode=1 option to tosa_referece_model,
which will cause reference model to convert all floating point tensors
to FP64 tensors and compute all operators accordingly.
Also adds optional -p arguments to test runners tosa_verif_run_tests.py
and tosa_verif_framework_compiler_runner.py to run tests in precise mode
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
Diffstat (limited to 'reference_model/src/ops/scatter_gather.cc')
-rw-r--r-- | reference_model/src/ops/scatter_gather.cc | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index bcd8ce5..80b6c58 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> OpGather<Dtype>::OpGather(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -29,11 +29,11 @@ OpGather<Dtype>::OpGather(SubgraphTraverser* sgt_, setRequiredOperands(2, 1); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> OpGather<Dtype>::~OpGather() {} -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpGather<Dtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -96,7 +96,7 @@ int OpGather<Dtype>::checkTensorAttributes() return 0; } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpGather<Dtype>::eval() { for (int32_t n = 0; n < N; n++) @@ -116,7 +116,7 @@ int OpGather<Dtype>::eval() return GraphNode::eval(); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> OpScatter<Dtype>::OpScatter(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -125,11 +125,11 @@ OpScatter<Dtype>::OpScatter(SubgraphTraverser* sgt_, setRequiredOperands(3, 1); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> OpScatter<Dtype>::~OpScatter() {} -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpScatter<Dtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -199,7 +199,7 @@ int OpScatter<Dtype>::checkTensorAttributes() return 0; } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpScatter<Dtype>::eval() { // Initializes the output tensor with the input value for values that are unchanged by the scatter operation. @@ -229,6 +229,7 @@ DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16); DEF_INSTANTIATE_ONE_TYPE(OpGather, BF16); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32); +DEF_INSTANTIATE_ONE_TYPE(OpGather, FP64); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16); @@ -236,3 +237,4 @@ DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, BF16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP64); |