Diffstat (limited to 'src')
-rw-r--r--  src/armnnTfLiteParser/TfLiteParser.cpp                     46
-rw-r--r--  src/armnnTfLiteParser/TfLiteParser.hpp                      1
-rw-r--r--  src/backends/aclCommon/ArmComputeTensorUtils.cpp            2
-rw-r--r--  src/backends/backendsCommon/WorkloadData.cpp                8
-rw-r--r--  src/backends/reference/test/ArgMinMaxTests.cpp             12
-rw-r--r--  src/backends/reference/workloads/ArgMinMax.cpp             12
-rw-r--r--  src/backends/reference/workloads/ArgMinMax.hpp              3
-rw-r--r--  src/backends/reference/workloads/CMakeLists.txt             2
-rw-r--r--  src/backends/reference/workloads/RefArgMinMaxWorkload.cpp  13
9 files changed, 80 insertions(+), 19 deletions(-)
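
Taken together, the patch teaches the TfLite parser the builtin ARG_MAX operator and threads the new Signed64 data type through workload validation, the ACL type mapping, and the reference backend. As a hedged usage sketch (the model path is a placeholder, assuming ArmNN's public ITfLiteParser API):

    // Sketch: loading a .tflite model that contains a builtin ARG_MAX operator.
    #include <armnnTfLiteParser/ITfLiteParser.hpp>

    int main()
    {
        auto parser = armnnTfLiteParser::ITfLiteParser::Create();
        // Before this patch the parser had no handler registered for
        // BuiltinOperator_ARG_MAX; with ParseArgMax wired up, the model loads.
        armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile("model.tflite");
        return network ? 0 : 1;
    }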
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 6143f4af6a..0aad048970 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -345,7 +345,9 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr,
case tflite::TensorType_INT32:
type = armnn::DataType::Signed32;
break;
-
+ case tflite::TensorType_INT64:
+ type = armnn::DataType::Signed64;
+ break;
default:
{
CheckLocation location = CHECK_LOCATION();
@@ -598,6 +600,7 @@ TfLiteParser::TfLiteParser(const Optional<ITfLiteParser::TfLiteParserOptions>& o
m_ParserFunctions[tflite::BuiltinOperator_TRANSPOSE_CONV] = &TfLiteParser::ParseTransposeConv;
m_ParserFunctions[tflite::BuiltinOperator_UNPACK] = &TfLiteParser::ParseUnpack;
m_ParserFunctions[tflite::BuiltinOperator_DIV] = &TfLiteParser::ParseDiv;
+ m_ParserFunctions[tflite::BuiltinOperator_ARG_MAX] = &TfLiteParser::ParseArgMax;
// register supported custom operators
m_CustomParserFunctions["TFLite_Detection_PostProcess"] = &TfLiteParser::ParseDetectionPostProcess;
}
@@ -2847,6 +2850,47 @@ void TfLiteParser::ParseSplitV(size_t subgraphIndex, size_t operatorIndex)
RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
}
+void TfLiteParser::ParseArgMax(size_t subgraphIndex, size_t operatorIndex)
+{
+ CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+
+ const auto &operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
+ const auto *options = operatorPtr->builtin_options.AsArgMaxOptions();
+ auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(inputs.size(), 2);
+
+ auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ auto layerName = boost::str(boost::format("ArgMax:%1%:%2%") % subgraphIndex % operatorIndex);
+
+ armnn::TensorInfo sizeTensorInfo0 = ToTensorInfo(inputs[0]);
+ armnn::TensorInfo sizeTensorInfo1 = ToTensorInfo(inputs[1]);
+
+ // Read the constant axis value (an int32 scalar) from the model buffer and set it on the descriptor.
+ BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
+
+ ArgMinMaxDescriptor desc;
+ desc.m_Axis = *reinterpret_cast<const int32_t*>(axisBufferPtr->data.data());
+ // If output_type is int32, use Signed32; otherwise use the TfLite default, Signed64.
+ desc.m_Output_Type = options->output_type == tflite::TensorType_INT32 ? armnn::DataType::Signed32 : armnn::DataType::Signed64;
+ desc.m_Function = ArgMinMaxFunction::Max;
+
+ // Register an ArgMax layer.
+ IConnectableLayer *layer = m_Network->AddArgMinMaxLayer(desc, layerName.c_str());
+
+ armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ // Register the input tensor with the layer; the axis input was consumed above as a constant.
+ auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
+
+ // Register the output tensor with the layer.
+ auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
+}
+
armnn::IConnectableLayer* TfLiteParser::AddFusedActivationLayer(armnn::IConnectableLayer* prevLayer,
unsigned int outputSlot,
tflite::ActivationFunctionType activationType)
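
ParseArgMax boils down to filling an ArgMinMaxDescriptor and calling AddArgMinMaxLayer. A minimal sketch of the equivalent direct INetwork construction, assuming an already-created producer layer (the helper name, layer name, and axis are illustrative):

    // Sketch: what ParseArgMax effectively builds, via the INetwork API.
    #include <armnn/INetwork.hpp>
    #include <armnn/Descriptors.hpp>

    armnn::IConnectableLayer* AddArgMax(armnn::INetwork& net,
                                        armnn::IConnectableLayer* producer,
                                        int axis)
    {
        armnn::ArgMinMaxDescriptor desc;
        desc.m_Axis        = axis;
        desc.m_Function    = armnn::ArgMinMaxFunction::Max;
        desc.m_Output_Type = armnn::DataType::Signed64; // the TfLite default

        armnn::IConnectableLayer* argMax = net.AddArgMinMaxLayer(desc, "ArgMax");
        producer->GetOutputSlot(0).Connect(argMax->GetInputSlot(0));
        return argMax;
    }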
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index 6a611509f4..9b081a5db9 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -137,6 +137,7 @@ private:
void ParseTranspose(size_t subgraphIndex, size_t operatorIndex);
void ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex);
void ParseUnpack(size_t subgraphIndex, size_t operatorIndex);
+ void ParseArgMax(size_t subgraphIndex, size_t operatorIndex);
void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot);
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
index f9335058c2..98b5adafbc 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
@@ -31,6 +31,8 @@ arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType, bool multi
return arm_compute::DataType::QASYMM8;
case armnn::DataType::QSymmS16:
return arm_compute::DataType::QSYMM16;
+ case armnn::DataType::Signed64:
+ return arm_compute::DataType::S64;
case armnn::DataType::QSymmS8:
{
return multiScales ? arm_compute::DataType::QSYMM8_PER_CHANNEL : arm_compute::DataType::QSYMM8;
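
The new case can be exercised directly. A hedged check, assuming the armcomputetensorutils namespace and the backends' include path:

    // Sketch: round-trip of the new Signed64 -> S64 mapping.
    #include <aclCommon/ArmComputeTensorUtils.hpp>
    #include <arm_compute/core/Types.h>
    #include <cassert>

    void CheckSigned64Mapping()
    {
        arm_compute::DataType aclType =
            armnn::armcomputetensorutils::GetArmComputeDataType(
                armnn::DataType::Signed64, /*multiScales=*/false);
        assert(aclType == arm_compute::DataType::S64);
    }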
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 07ce14b763..ff97fc7f41 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -623,9 +623,10 @@ void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
- if (outputTensorInfo.GetDataType() != DataType::Signed32)
+ if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
+ outputTensorInfo.GetDataType() != DataType::Signed64)
{
- throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32.");
+ throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
}
std::vector<DataType> supportedInputTypes =
@@ -636,7 +637,8 @@ void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16,
- DataType::Signed32
+ DataType::Signed32,
+ DataType::Signed64
};
ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
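
With the relaxed check, an ArgMinMax workload whose output TensorInfo is Signed64 now passes validation. A standalone sketch mirroring the new predicate (the helper name is hypothetical, not ArmNN API):

    // Sketch: the output types ArgMinMaxQueueDescriptor::Validate now accepts.
    #include <armnn/Tensor.hpp>
    #include <armnn/Types.hpp>

    bool IsValidArgMinMaxOutputType(const armnn::TensorInfo& info)
    {
        return info.GetDataType() == armnn::DataType::Signed32 ||
               info.GetDataType() == armnn::DataType::Signed64;
    }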
diff --git a/src/backends/reference/test/ArgMinMaxTests.cpp b/src/backends/reference/test/ArgMinMaxTests.cpp
index 201a2c0c2e..dce15b29ef 100644
--- a/src/backends/reference/test/ArgMinMaxTests.cpp
+++ b/src/backends/reference/test/ArgMinMaxTests.cpp
@@ -12,11 +12,11 @@ BOOST_AUTO_TEST_SUITE(RefArgMinMax)
BOOST_AUTO_TEST_CASE(ArgMinTest)
{
const armnn::TensorInfo inputInfo({ 1, 2, 3 } , armnn::DataType::Float32);
- const armnn::TensorInfo outputInfo({ 1, 3 }, armnn::DataType::Float32);
+ const armnn::TensorInfo outputInfo({ 1, 3 }, armnn::DataType::Signed64);
std::vector<float> inputValues({ 1.0f, 5.0f, 3.0f, 4.0f, 2.0f, 6.0f});
- std::vector<int32_t> outputValues(outputInfo.GetNumElements());
- std::vector<int32_t> expectedValues({ 0, 1, 0 });
+ std::vector<int64_t> outputValues(outputInfo.GetNumElements());
+ std::vector<int64_t> expectedValues({ 0, 1, 0 });
ArgMinMax(*armnn::MakeDecoder<float>(inputInfo, inputValues.data()),
outputValues.data(),
@@ -35,11 +35,11 @@ BOOST_AUTO_TEST_CASE(ArgMinTest)
BOOST_AUTO_TEST_CASE(ArgMaxTest)
{
const armnn::TensorInfo inputInfo({ 1, 2, 3 } , armnn::DataType::Float32);
- const armnn::TensorInfo outputInfo({ 1, 3 }, armnn::DataType::Float32);
+ const armnn::TensorInfo outputInfo({ 1, 3 }, armnn::DataType::Signed64);
std::vector<float> inputValues({ 1.0f, 5.0f, 3.0f, 4.0f, 2.0f, 6.0f });
- std::vector<int32_t> outputValues(outputInfo.GetNumElements());
- std::vector<int32_t> expectedValues({ 1, 0, 1 });
+ std::vector<int64_t> outputValues(outputInfo.GetNumElements());
+ std::vector<int64_t> expectedValues({ 1, 0, 1 });
ArgMinMax(*armnn::MakeDecoder<float>(inputInfo, inputValues.data()),
outputValues.data(),
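
The expected values follow from reducing over axis 1, the dimension of size 2: the input decodes as rows [1, 5, 3] and [4, 2, 6], and for each of the three inner positions the test records which row holds the minimum or maximum. A standalone check of that arithmetic (no ArmNN dependency):

    // Sketch: the arithmetic behind expectedValues in the two test cases.
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        float in[2][3] = { { 1.0f, 5.0f, 3.0f }, { 4.0f, 2.0f, 6.0f } };
        for (int inner = 0; inner < 3; ++inner)
        {
            int64_t argMin = in[0][inner] <= in[1][inner] ? 0 : 1;
            int64_t argMax = in[0][inner] >= in[1][inner] ? 0 : 1;
            std::printf("inner %d: argmin=%lld argmax=%lld\n",
                        inner, (long long)argMin, (long long)argMax);
        }
        // Prints argmin 0,1,0 and argmax 1,0,1, matching expectedValues.
        return 0;
    }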
diff --git a/src/backends/reference/workloads/ArgMinMax.cpp b/src/backends/reference/workloads/ArgMinMax.cpp
index c455c52e5a..3bf2853a20 100644
--- a/src/backends/reference/workloads/ArgMinMax.cpp
+++ b/src/backends/reference/workloads/ArgMinMax.cpp
@@ -12,7 +12,8 @@
namespace armnn
{
-void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo,
+template <typename OUT>
+void ArgMinMax(Decoder<float>& in, OUT* out, const TensorInfo& inputTensorInfo,
const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis)
{
IgnoreUnused(outputTensorInfo);
@@ -39,9 +40,16 @@ void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorIn
tmpIndex = i;
}
}
- out[outer * innerElements + inner] = armnn::numeric_cast<int32_t>(tmpIndex);
+
+ out[outer * innerElements + inner] = armnn::numeric_cast<OUT>(tmpIndex);
}
}
}
+template void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo,
+ const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis);
+
+template void ArgMinMax(Decoder<float>& in, int64_t* out, const TensorInfo& inputTensorInfo,
+ const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis);
+
} //namespace armnn
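
Keeping the template definition in the .cpp file links only because both instantiations are spelled out explicitly; any other OUT type would compile against the header but fail at link time. The same pattern reduced to a standalone sketch (Fill is an illustrative name):

    // Sketch: explicit instantiation keeps a template's definition out of the header.
    #include <cstdint>

    // The header would declare:
    template <typename OUT>
    void Fill(OUT* out, int count);

    // The source would define:
    template <typename OUT>
    void Fill(OUT* out, int count)
    {
        for (int i = 0; i < count; ++i) { out[i] = static_cast<OUT>(i); }
    }

    // ...and instantiate exactly the types other translation units may use:
    template void Fill(int32_t* out, int count);
    template void Fill(int64_t* out, int count);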
diff --git a/src/backends/reference/workloads/ArgMinMax.hpp b/src/backends/reference/workloads/ArgMinMax.hpp
index 5a9c6a8a2a..3958ed7afd 100644
--- a/src/backends/reference/workloads/ArgMinMax.hpp
+++ b/src/backends/reference/workloads/ArgMinMax.hpp
@@ -13,7 +13,8 @@
namespace armnn
{
-void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo,
+template <typename OUT>
+void ArgMinMax(Decoder<float>& in, OUT* out, const TensorInfo& inputTensorInfo,
const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis);
} //namespace armnn
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 937a32029e..cd9efc96af 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -5,8 +5,6 @@
list(APPEND armnnRefBackendWorkloads_sources
Abs.hpp
- ArgMinMax.cpp
- ArgMinMax.hpp
Activation.cpp
Activation.hpp
ArgMinMax.cpp
diff --git a/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp b/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp
index 5f1eb73b61..b7246d5b93 100644
--- a/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp
+++ b/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp
@@ -29,10 +29,15 @@ void RefArgMinMaxWorkload::Execute() const
const TensorInfo &outputTensorInfo = GetTensorInfo(m_Data.m_Outputs[0]);
- int32_t* output = GetOutputTensorData<int32_t>(0, m_Data);
-
- ArgMinMax(decoder, output, inputTensorInfo, outputTensorInfo, m_Data.m_Parameters.m_Function,
- m_Data.m_Parameters.m_Axis);
+ if (m_Data.m_Parameters.m_Output_Type == armnn::DataType::Signed32) {
+ int32_t *output = GetOutputTensorData<int32_t>(0, m_Data);
+ ArgMinMax(decoder, output, inputTensorInfo, outputTensorInfo, m_Data.m_Parameters.m_Function,
+ m_Data.m_Parameters.m_Axis);
+ } else {
+ int64_t *output = GetOutputTensorData<int64_t>(0, m_Data);
+ ArgMinMax(decoder, output, inputTensorInfo, outputTensorInfo, m_Data.m_Parameters.m_Function,
+ m_Data.m_Parameters.m_Axis);
+ }
}
} //namespace armnn
\ No newline at end of file
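
The workload selects the instantiation at run time from m_Output_Type, so int32_t and int64_t outputs route to the same templated ArgMinMax. The dispatch reduced to a standalone sketch (OutputType, WriteArgMax, and Dispatch are illustrative names, not ArmNN API):

    // Sketch: run-time dispatch from a descriptor field to a template instantiation.
    #include <cstdint>

    enum class OutputType { Signed32, Signed64 };

    template <typename OUT>
    void WriteArgMax(const float* in, int count, OUT* out)
    {
        int best = 0;
        for (int i = 1; i < count; ++i)
        {
            if (in[i] > in[best]) { best = i; }
        }
        out[0] = static_cast<OUT>(best);
    }

    void Dispatch(OutputType type, const float* in, int count, void* out)
    {
        if (type == OutputType::Signed32)
        {
            WriteArgMax(in, count, static_cast<int32_t*>(out));
        }
        else
        {
            WriteArgMax(in, count, static_cast<int64_t*>(out));
        }
    }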