From 82fbe7c0b82f7adadd5120ac4b4f779d0da7c0d5 Mon Sep 17 00:00:00 2001 From: Jim Flynn Date: Tue, 2 Apr 2019 15:19:08 +0100 Subject: IVGCVSW-2912 Make get_compute_library.sh sole source for clframework pin * Also incorporated fix for break in master build 32-bit NEDepthwiseConvolution errors in clframework * Fixed a failure in the Float16 workloads for ElementwiseOperations !android-nn-driver:963 Change-Id: Ic2cdb0e6c9399fa42b56001c6f4b46b7f150f143 Signed-off-by: Jim Flynn --- scripts/get_compute_library.sh | 47 +++++++++++++++------------ src/backends/backendsCommon/WorkloadData.cpp | 12 ++++--- src/backends/reference/RefWorkloadFactory.cpp | 32 ++++++++++++++++++ 3 files changed, 67 insertions(+), 24 deletions(-) diff --git a/scripts/get_compute_library.sh b/scripts/get_compute_library.sh index cf117d68b9..6ad98ca1b6 100755 --- a/scripts/get_compute_library.sh +++ b/scripts/get_compute_library.sh @@ -6,11 +6,24 @@ CMD=$( basename $0 ) +# For pinnning to a ref use this: +# DEFAULT_CLFRAMEWORKREVISION="branches/arm_compute_19_02" # Release 19.02 +# +# For pinning to a revision use this: +DEFAULT_CLFRAMEWORKREVISION="a4bba9c594c4022c9f85192bb8fd3593ad1a8d3c" # COMPMID-1995: Fix 32-bit NEDepthwiseConvolution errors. + usage() { - echo "Usage: $CMD -g " + echo "Usage: $CMD (Use the default clframework SHA)" + echo "Usage: $CMD -s " + echo "Usage: $CMD -p (Print current default clframework SHA)" exit 1 } +PrintDefaultClframeworkSha() { + echo $DEFAULT_CLFRAMEWORKREVISION + exit 2; +} + function AssertZeroExitCode { EXITCODE=$? if [ $EXITCODE -ne 0 ]; then @@ -21,9 +34,11 @@ function AssertZeroExitCode { } # process the options given -while getopts "g:h" opt; do +while getopts "s:phg:" opt; do case "$opt" in - g) GITHUB_USERNAME="$OPTARG";; + s) CLFRAMEWORK_SHA="$OPTARG";; + p) PrintDefaultClframeworkSha;; + g) DUMMY="$OPTARG";; # continue to accept -g for backward compatibility h|\?) usage;; esac done @@ -46,36 +61,28 @@ done DIR="$( cd -P "$( dirname "$SRC" )" >/dev/null && pwd )" pushd ${DIR} > /dev/null cd ../.. -if [ -z "$USERNAME" ]; then - USERNAME=$USER -fi -if [ -z "$GITHUB_USERNAME" ]; then - GITHUB_USERNAME=$USERNAME - echo "setting GITHUB_USERNAME: ${GITHUB_USERNAME} use -g command line option to change" -fi if [ ! -d clframework ]; then -echo "+++ Cloning clframework" git clone https://review.mlplatform.org/ml/ComputeLibrary clframework AssertZeroExitCode "Cloning CL Framework failed" fi pushd clframework > /dev/null -# Use the latest pinned version of the CL framework - -# For pinnning to a ref use this: -# CLFRAMEWORKREVISION="branches/arm_compute_19_02" # Release 19.02 -# git fetch https://review.mlplatform.org/ml/ComputeLibrary $CLFRAMEWORKREVISION && git checkout FETCH_HEAD +CLFRAMEWORKREVISION=$DEFAULT_CLFRAMEWORKREVISION +if [ ! -z "$CLFRAMEWORK_SHA" ]; then + CLFRAMEWORKREVISION=$CLFRAMEWORK_SHA +fi -# For pinning to a revision use this: -CLFRAMEWORKREVISION="b4a44ff3aa98d2b51f1621a7525db3f81108a1bd" # COMPMID-1995: Removed layout checks from Reduction ops git fetch https://review.mlplatform.org/ml/ComputeLibrary && git checkout ${CLFRAMEWORKREVISION} -AssertZeroExitCode +AssertZeroExitCode "Fetching and checking out ${CLFRAMEWORKREVISION} failed" # Set commit hook so we can submit reviews to gerrit (curl -Lo `git rev-parse --git-dir`/hooks/commit-msg https://review.mlplatform.org/tools/hooks/commit-msg; chmod +x `git rev-parse --git-dir`/hooks/commit-msg) -AssertZeroExitCode +AssertZeroExitCode "Setting commit hooks failed" popd > /dev/null # out of clframework popd > /dev/null # back to wherever we were when called +# Make sure the SHA of the revision that was checked out is the last line +# of output from the script... just in case we ever need it. +echo $CLFRAMEWORKREVISION exit 0 diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 528e1faefc..ec163b59c3 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -494,7 +494,8 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector supportedTypes = { DataType::Float32, DataType::QuantisedAsymm8, - DataType::QuantisedSymm16 + DataType::QuantisedSymm16, + DataType::Float16 }; ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], @@ -526,7 +527,8 @@ void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c std::vector supportedTypes = { DataType::Float32, DataType::QuantisedAsymm8, - DataType::QuantisedSymm16 + DataType::QuantisedSymm16, + DataType::Float16 }; ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], @@ -895,7 +897,8 @@ void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector supportedTypes = { DataType::Float32, DataType::QuantisedAsymm8, - DataType::QuantisedSymm16 + DataType::QuantisedSymm16, + DataType::Float16 }; ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], @@ -926,7 +929,8 @@ void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons std::vector supportedTypes = { DataType::Float32, DataType::QuantisedAsymm8, - DataType::QuantisedSymm16 + DataType::QuantisedSymm16, + DataType::Float16 }; ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 8ea923d599..b1c3ad79ac 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -27,6 +27,22 @@ std::unique_ptr RefWorkloadFactory::MakeWorkload(const QueueDescripto info); } +bool IsFloat16(const WorkloadInfo& info) +{ + auto checkFloat16 = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == DataType::Float16;}; + auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkFloat16); + if (it != std::end(info.m_InputTensorInfos)) + { + return true; + } + it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkFloat16); + if (it != std::end(info.m_OutputTensorInfos)) + { + return true; + } + return false; +} + RefWorkloadFactory::RefWorkloadFactory() { } @@ -174,12 +190,20 @@ std::unique_ptr RefWorkloadFactory::CreateNormalization( std::unique_ptr RefWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor, const WorkloadInfo& info) const { + if (IsFloat16(info)) + { + return MakeWorkload(descriptor, info); + } return std::make_unique(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreateMultiplication( const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info) const { + if (IsFloat16(info)) + { + return MakeWorkload(descriptor, info); + } return std::make_unique(descriptor, info); } @@ -266,12 +290,20 @@ std::unique_ptr RefWorkloadFactory::CreateConvertFp32ToFp16( std::unique_ptr RefWorkloadFactory::CreateDivision( const DivisionQueueDescriptor& descriptor, const WorkloadInfo& info) const { + if (IsFloat16(info)) + { + return MakeWorkload(descriptor, info); + } return std::make_unique(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreateSubtraction( const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info) const { + if (IsFloat16(info)) + { + return MakeWorkload(descriptor, info); + } return std::make_unique(descriptor, info); } -- cgit v1.2.1