diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2020-11-24 10:26:32 -0800 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2020-11-24 14:06:04 -0800 |
commit | 77d0f76b42c6353459d3271ccba9e41728d066cd (patch) | |
tree | 70f95376de80e249a15b15dba6faa27fb4d48f45 /reference_model/src/ops/scatter_gather.cc | |
parent | aee1facbde25caf27cc34e5ec08eb8bba6af8e18 (diff) | |
download | reference_model-77d0f76b42c6353459d3271ccba9e41728d066cd.tar.gz |
Update reference model/serialization library to 0.21.0 with unit tests added/updated
- update tosa.GATHER
- update tosa.RESIZE
- add tosa.SCATTER
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I1c3247b831a64e35a85c4044b24c6c29b8e18d25
Diffstat (limited to 'reference_model/src/ops/scatter_gather.cc')
-rw-r--r-- | reference_model/src/ops/scatter_gather.cc | 210 |
1 files changed, 160 insertions, 50 deletions
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index c54204a..2d1026f 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -20,31 +20,61 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <int InRank, int IndexRank, DType Dtype> -OpGather<InRank, IndexRank, Dtype>::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) +template <DType Dtype> +OpGather<Dtype>::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(Op_GATHER, id_) { setRequiredOperands(2, 1); - setRequiredRank(1, 6); - - INIT_ATTRIBUTE(Axis); } -template <int InRank, int IndexRank, DType Dtype> -OpGather<InRank, IndexRank, Dtype>::~OpGather() -{ - if (attribute) - delete attribute; -} +template <DType Dtype> +OpGather<Dtype>::~OpGather() +{} -template <int InRank, int IndexRank, DType Dtype> -int OpGather<InRank, IndexRank, Dtype>::checkTensorAttributes() +template <DType Dtype> +int OpGather<Dtype>::checkTensorAttributes() { if (validateRequiredOperands()) return 1; - if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) + if (inputs[0]->getRank() != 3) + { + printNodeValidationError("OpGather: values needs to be rank 3 tensor"); + return 1; + } + + if (inputs[1]->getRank() != 2) { + printNodeValidationError("OpGather: indices needs to be rank 2 tensor"); + return 1; + } + + if (outputs[0]->getRank() != 3) + { + printNodeValidationError("OpGather: output needs to be rank 3 tensor"); + return 1; + } + + K = inputs[0]->getShape()[1]; + N = outputs[0]->getShape()[0]; + W = outputs[0]->getShape()[1]; + C = outputs[0]->getShape()[2]; + + if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0]) + { + printNodeValidationError("OpGather: dimension N mismatch"); + return 1; + } + + if (W != inputs[1]->getShape()[1]) + { + printNodeValidationError("OpGather: dimension W mismatch"); + return 1; + } + + if (C != inputs[0]->getShape()[2]) + { + printNodeValidationError("OpGather: dimension C mismatch"); return 1; } @@ -55,59 +85,133 @@ int OpGather<InRank, IndexRank, Dtype>::checkTensorAttributes() return 1; } - in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); - index = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]); - out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + values = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[0]); + indices = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]); + output = dynamic_cast<TosaReference::TensorTemplate<TOutput>*>(outputs[0]); - ASSERT_MEM(in && index && out); + ASSERT_MEM(values && indices && output); return 0; } -template <int InRank, int IndexRank, DType Dtype> -int OpGather<InRank, IndexRank, Dtype>::eval() +template <DType Dtype> +int OpGather<Dtype>::eval() +{ + for (int32_t n = 0; n < N; n++) + { + for (int32_t w = 0; w < W; w++) + { + int32_t k = this->indices->getTensor()(n, w); + ASSERT_MSG_NODE(k >= 0 && k < K, "OpGather: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K); + for (int32_t c = 0; c < C; c++) + { + EigenType value = this->values->getTensor()(n, k, c); + this->output->getTensor()(n, w, c) = value; + } + } + } + + return GraphNode::eval(); +} + +template <DType Dtype> +OpScatter<Dtype>::OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_SCATTER, id_) { - int axis = attribute->axis(); + setRequiredOperands(3, 1); +} + +template <DType Dtype> +OpScatter<Dtype>::~OpScatter() +{} - // calculate size left and right to axis - int left_size = 1; - for (int i = 0; i < axis; ++i) +template <DType Dtype> +int OpScatter<Dtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (inputs[0]->getRank() != 3) + { + printNodeValidationError("OpGather: values_in needs to be rank 3 tensor"); + return 1; + } + + if (inputs[1]->getRank() != 2) + { + printNodeValidationError("OpGather: indices needs to be rank 2 tensor"); + return 1; + } + + if (inputs[2]->getRank() != 3) { - left_size *= in->getShape()[i]; + printNodeValidationError("OpGather: input needs to be rank 3 tensor"); + return 1; } - int right_size = 1; - for (int i = axis + 1; i < in->getRank(); ++i) + if (outputs[0]->getRank() != 3) { - right_size *= in->getShape()[i]; + printNodeValidationError("OpGather: values_out needs to be rank 3 tensor"); + return 1; } - InEigenType* input_data = in->getTensor().data(); - int32_t* index_data = index->getTensor().data(); - OutEigenType* output_data = out->getTensor().data(); + W = inputs[2]->getShape()[1]; + N = outputs[0]->getShape()[0]; + K = outputs[0]->getShape()[1]; + C = outputs[0]->getShape()[2]; - int32_t axis_size = in->getShape()[axis]; - int32_t index_count = index->getElementCount(); + if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0] || N != inputs[2]->getShape()[0]) + { + printNodeValidationError("OpScatter: dimension N mismatch"); + return 1; + } - // sanity check if index is valid - // need to check until this point since index is not known until runtime - for (size_t i = 0; i < index->getElementCount(); i++) + if (W != inputs[1]->getShape()[1]) { - if (index_data[i] >= axis_size) - { - FATAL_ERROR_NODE("OpGather: index[%lu]=%i can't exceed axis_size=%i", i, index_data[i], axis_size); - } + printNodeValidationError("OpGather: dimension W mismatch"); + return 1; } - // Eigen stores tensor in column-major - // so we iterate through dimension right to axis and the index array - // do memory copy with size of left size each time - for (int right = 0; right < right_size; ++right) + if (C != inputs[0]->getShape()[2] || C != inputs[2]->getShape()[2]) + { + printNodeValidationError("OpGather: dimension C mismatch"); + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchType(*outputs[0])) + { + printNodeValidationError("Failure to match input and output type"); + return 1; + } + + values_in = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[0]); + indices = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]); + input = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[2]); + values_out = dynamic_cast<TosaReference::TensorTemplate<TOutput>*>(outputs[0]); + + ASSERT_MEM(values_in && indices && input && values_out); + + return 0; +} + +template <DType Dtype> +int OpScatter<Dtype>::eval() +{ + // Initializes the output tensor with the input value for values that are unchanged by the scatter operation. + this->values_out->getTensor() = this->values_in->getTensor(); + + for (int n = 0; n < N; n++) { - for (int i = 0; i < index_count; ++i) + for (int w = 0; w < W; w++) { - std::memcpy(output_data + (right * index_count + i) * left_size, - input_data + (right * axis_size + index_data[i]) * left_size, sizeof(InEigenType) * left_size); + int32_t k = this->indices->getTensor()(n, w); + ASSERT_MSG_NODE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K); + for (int c = 0; c < C; c++) + { + EigenType value = this->input->getTensor()(n, w, c); + this->values_out->getTensor()(n, k, c) = value; + } } } @@ -115,6 +219,12 @@ int OpGather<InRank, IndexRank, Dtype>::eval() } // template explicit instantiation -DEF_INSTANTIATE_GATHER(OpGather, AINT8); -DEF_INSTANTIATE_GATHER(OpGather, INT16); -DEF_INSTANTIATE_GATHER(OpGather, INT32); +DEF_INSTANTIATE_ONE_TYPE(OpGather, AINT8); +DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16); +DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32); +DEF_INSTANTIATE_ONE_TYPE(OpGather, FLOAT); + +DEF_INSTANTIATE_ONE_TYPE(OpScatter, AINT8); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, FLOAT); |