如何部署自己的SSD检测模型到Android TFLite上

简介: TensorFlow Object Detection API 上提供了使用SSD部署到TFLite运行上去的方法, 可是这套API封装太死板, 如果你要自己实现了一套SSD的训练算法,应该怎么才能部署到TFLite上呢?   首先,抛开后处理的部分,你的SSD模型(无论是VGG-SSD和Mobilenet-SSD), 你最终的模型的输出是对class_predictions和bbo
TensorFlow Object Detection API 上提供了使用SSD部署到TFLite运行上去的方法, 可是这套API封装太死板, 如果你要自己实现了一套SSD的训练算法,应该怎么才能部署到TFLite上呢?
 
首先,抛开后处理的部分,你的SSD模型(无论是VGG-SSD和Mobilenet-SSD), 你最终的模型的输出是对class_predictions和bbox_predictions; 并且是encoded的
 
 

Encoding的方式:

class_predictions: M个Feature Layer, Feature Layer的大小(宽高)视网络结构而定; 每个Feature Layer有Num_Anchor_Depth_of_this_layer x Num_classes个channels
 
box_predictions:   M个Feature Layer; 每个Feature Layer有Num_Anchor_Depth_of_this_layer x 4个channes 这4个channel分别代表dy,dx,h,w, 即bbox中心距离anchor中心坐标的偏移量和宽高
注:通常,为了平衡loss之间的大小, 不会直接编码dy,dx,w,h的原始值,而是dy/anchor_h*scale0, dx/anchor_w*scale0, log(h/anchor_h)*scale1, log(w/anchor_w)*scale1, 也就是偏移量的绝对值除anchor宽高得到相对值,然后再乘上一个scale, 经验值 scale0取5,scale1取10; 对于h,w是对得到相对值后先取log再乘以scale, h/anchor_h的范围在1附近, 取log后可以转换到0附近;所以在解码的时候需要做对应相反的变换;
在后面TFLite_Detection_PostProcess的Op实现里就有这么一段:
 
 
然后我们需要的是做的是decode出来对每个class的confidence和location的预测值
 
 

后处理

在Object Detection API的 export_tflite_ssd_graph_lib.py文件中,你可以看到,它区别与直接freeze pb的操作就在于最后替换了后处理的部分;
 
Plain Text
Plain Text
Bash
Basic
C
C++
C#
CSS
C++
Diff
Git
go
GraphQL
HTML
HTTP
Java
JavaScript
JSON
JSX
Kotlin
Less
Makefile
Markdown
MATLAB
Nginx
Objective-C
Pascal
Perl
PHP
PowerShell
Ruby
Protobuf
Python
R
Ruby
Scala
Shell
SQL
Swift
TypeScript
VB.net
XML
YAML
KaTeX
Mermaid
PlantUML
Flow
Graphviz
AI 代码解读
frozen_graph_def = exporter.freeze_graph_with_def_protos(
AI 代码解读
input_graph_def=tf.get_default_graph().as_graph_def(),
AI 代码解读
input_saver_def=input_saver_def,
AI 代码解读
input_checkpoint=checkpoint_to_use,
AI 代码解读
output_node_names=','.join([
AI 代码解读
'raw_outputs/box_encodings', 'raw_outputs/class_predictions',
AI 代码解读
'anchors'
AI 代码解读
]),
AI 代码解读
restore_op_name='save/restore_all',
AI 代码解读
filename_tensor_name='save/Const:0',
AI 代码解读
clear_devices=True,
AI 代码解读
output_graph='',
AI 代码解读
initializer_nodes='')
AI 代码解读
 
AI 代码解读
# Add new operation to do post processing in a custom op (TF Lite only)
AI 代码解读
if add_postprocessing_op:
AI 代码解读
transformed_graph_def = append_postprocessing_op(
AI 代码解读
frozen_graph_def, max_detections, max_classes_per_detection,
AI 代码解读
nms_score_threshold, nms_iou_threshold, num_classes, scale_values)
AI 代码解读
else:
AI 代码解读
# Return frozen without adding post-processing custom op
AI 代码解读
transformed_graph_def = frozen_graph_def
AI 代码解读
 
后处理的部分,其实看代码也很简单,就是增加了一个叫TFLite_Detection_PostProcess的node,用于解码和非极大抑制. 这个node的输入就是上面提到的box_predictions和class_predictions, 还有anchors的编码; 用这个node的目的只TFLite并不支持tf.contrib.image.non_max_surpression操作
 
 

Reshape过程:

这里需要明确,TFLite_Detection_PostProcess 这个op对raw_outputs/box_encodings, raw_outputs/class_predictions, anchors的Shape是有一个定制要求的
raw_outputs/box_encodings.shape=[1, num_anchors,4]
raw_outputs/class_predictions.shape=[1, num_anchors,num_classes+1]
anchors.shape=[1,num_anchors,4]
这里需要注意:1, 这三个都必须是3维的Tensor; 2.raw_outputs/class_predictions.shape的最后一个维度是包含background的classes, 也就是是num_classes+1; TFLite_Detection_PostProcess还有一个参数num_classes, 这个参数值是不包含background的, 所以也就导致TFLite_Detection_PostProcess的输出的class index是从0计数的;
 
 
Plain Text
Plain Text
Bash
Basic
C
C++
C#
CSS
C++
Diff
Git
go
GraphQL
HTML
HTTP
Java
JavaScript
JSON
JSX
Kotlin
Less
Makefile
Markdown
MATLAB
Nginx
Objective-C
Pascal
Perl
PHP
PowerShell
Ruby
Protobuf
Python
R
Ruby
Scala
Shell
SQL
Swift
TypeScript
VB.net
XML
YAML
KaTeX
Mermaid
PlantUML
Flow
Graphviz
AI 代码解读
with tf.variable_scope('raw_outputs'):
AI 代码解读
cls_pred = [tf.reshape(pred, [-1, num_classes]) for pred in cls_pred]
AI 代码解读
location_pred = [tf.reshape(pred, [-1, 4]) for pred in location_pred]
AI 代码解读
cls_pred = tf.concat(cls_pred, axis=0)
AI 代码解读
location_pred = tf.expand_dims(tf.concat(location_pred, axis=0),0, name='box_encodings')
AI 代码解读
 
AI 代码解读
cls_pred=tf.nn.softmax(cls_pred)
AI 代码解读
 
AI 代码解读
tf.identity(tf.expand_dims(cls_pred,0), name='class_predictions')
AI 代码解读
 
AI 代码解读
 
 
这段代码就是用来reshape成要求的输入的, 需要注意的是对class_prediction需要做依次softmax或者sigmoid, 具体选择哪种取决于你是否允许一个anchor对应多个类;
对于anchors, 这其实是一constant的值:
 
Plain Text
Plain Text
Bash
Basic
C
C++
C#
CSS
C++
Diff
Git
go
GraphQL
HTML
HTTP
Java
JavaScript
JSON
JSX
Kotlin
Less
Makefile
Markdown
MATLAB
Nginx
Objective-C
Pascal
Perl
PHP
PowerShell
Ruby
Protobuf
Python
R
Ruby
Scala
Shell
SQL
Swift
TypeScript
VB.net
XML
YAML
KaTeX
Mermaid
PlantUML
Flow
Graphviz
AI 代码解读
num_anchors = anchor_cy.get_shape().as_list()
AI 代码解读
with tf.Session() as sess:
AI 代码解读
y_out, x_out, h_out, w_out = sess.run([anchor_cy, anchor_cx, anchor_h, anchor_w])
AI 代码解读
encoded_anchors = tf.constant(
AI 代码解读
np.transpose(np.stack((y_out, x_out, h_out, w_out))),
AI 代码解读
dtype=tf.float32,
AI 代码解读
shape=[num_anchors[0], 4])
AI 代码解读
 
注意: 之前我使用tf.stack合成这个值的时候发现,TFLite只支持axis=0的时候的tf.stack, 否则就会转换是吧
 
 

导出pb

添加完后处理,既可以导出一个带有后处理功能的pb文件了; 如果你不添加后处理,把它放在CPU上后续去做,其实也可以省去不少麻烦;
 
Plain Text
Plain Text
Bash
Basic
C
C++
C#
CSS
C++
Diff
Git
go
GraphQL
HTML
HTTP
Java
JavaScript
JSON
JSX
Kotlin
Less
Makefile
Markdown
MATLAB
Nginx
Objective-C
Pascal
Perl
PHP
PowerShell
Ruby
Protobuf
Python
R
Ruby
Scala
Shell
SQL
Swift
TypeScript
VB.net
XML
YAML
KaTeX
Mermaid
PlantUML
Flow
Graphviz
AI 代码解读
binary_graph = os.path.join(output_dir, 'tflite_graph.pb')
AI 代码解读
with tf.gfile.GFile(binary_graph, 'wb') as f:
AI 代码解读
f.write(transformed_graph_def.SerializeToString())
AI 代码解读
 
AI 代码解读
txt_graph = os.path.join(output_dir, 'tflite_graph.pbtxt')
AI 代码解读
with tf.gfile.GFile(txt_graph, 'w') as f:
AI 代码解读
f.write(str(transformed_graph_def))
AI 代码解读
 
注意: 导出的pb如果包含后处理, 是没办法用正常的TF执行的,必须转成tflite执行
 
 
 

导出tflite

 
Plain Text
Plain Text
Bash
Basic
C
C++
C#
CSS
C++
Diff
Git
go
GraphQL
HTML
HTTP
Java
JavaScript
JSON
JSX
Kotlin
Less
Makefile
Markdown
MATLAB
Nginx
Objective-C
Pascal
Perl
PHP
PowerShell
Ruby
Protobuf
Python
R
Ruby
Scala
Shell
SQL
Swift
TypeScript
VB.net
XML
YAML
KaTeX
Mermaid
PlantUML
Flow
Graphviz
AI 代码解读
bazel run --config=opt tensorflow/contrib/lite/toco:toco -- \
AI 代码解读
--input_file=$OUTPUT_DIR/tflite_graph.pb \
AI 代码解读
--output_file=$OUTPUT_DIR/detect.tflite \
AI 代码解读
--input_shapes=1,300,300,3 \
AI 代码解读
--input_arrays=normalized_input_image_tensor \
AI 代码解读
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
AI 代码解读
--inference_type=QUANTIZED_UINT8 \
AI 代码解读
--mean_values=128 \
AI 代码解读
--std_values=128 \
AI 代码解读
--change_concat_input_ranges=false \
AI 代码解读
--allow_custom_ops
AI 代码解读
 
AI 代码解读
or
AI 代码解读
 
AI 代码解读
bazel run -c opt tensorflow/lite/toco:toco -- \
AI 代码解读
--input_file=$OUTPUT_DIR/tflite_graph.pb \
AI 代码解读
--output_file=$OUTPUT_DIR/detect.tflite \
AI 代码解读
--input_shapes=1,300,300,3 \
AI 代码解读
--input_arrays=normalized_input_image_tensor \
AI 代码解读
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
AI 代码解读
--inference_type=FLOAT \
AI 代码解读
--allow_custom_ops
AI 代码解读
 
 
导出的过程中,可能遇到Converting unsupported operation: TFLite_Detection_PostProcess 这个提示, 正常如果是TF在1.10以上就忽略这个提示好了
然后你可以先用python的程序加载这个tflite去测试一下
注意: 这时候会发现一个问题, TFLite_Detection_PostProcess的NMS操作是忽略类标签的,如果你设置max_classes_per_detection=1; 但是如果你设置成>1的值, 会发现它吧background的标签也算进来了, 导致出来很多误检测的bbox;
 
 

部署Android

然后,你可以尝试部署到Android上, 在不使用NNAPI的时候正常,但是如果是NNAPI就需要自己实现相关操作了,否则会crash掉

 

烁凡
+关注
目录
打赏
0
0
0
2
49
分享
相关文章
安卓与iOS开发中的跨平台策略:一次编码,多平台部署
在移动应用开发的广阔天地中,安卓和iOS两大阵营各占一方。随着技术的发展,跨平台开发框架应运而生,它们承诺着“一次编码,到处运行”的便捷。本文将深入探讨跨平台开发的现状、挑战以及未来趋势,同时通过代码示例揭示跨平台工具的实际运用。
194 3
安卓应用开发中的内存泄漏检测与修复
【9月更文挑战第30天】在安卓应用开发过程中,内存泄漏是一个常见而又棘手的问题。它不仅会导致应用运行缓慢,还可能引发应用崩溃,严重影响用户体验。本文将深入探讨如何检测和修复内存泄漏,以提升应用性能和稳定性。我们将通过一个具体的代码示例,展示如何使用Android Studio的Memory Profiler工具来定位内存泄漏,并介绍几种常见的内存泄漏场景及其解决方案。无论你是初学者还是有经验的开发者,这篇文章都将为你提供实用的技巧和方法,帮助你打造更优质的安卓应用。
惊爆!Uno Platform 调试与性能分析终极攻略,从工具运用到代码优化,带你攻克开发难题成就完美应用
【8月更文挑战第31天】在 Uno Platform 中,调试可通过 Visual Studio 设置断点和逐步执行代码实现,同时浏览器开发者工具有助于 Web 版本调试。性能分析则利用 Visual Studio 的性能分析器检查 CPU 和内存使用情况,还可通过记录时间戳进行简单分析。优化性能涉及代码逻辑优化、资源管理和用户界面简化,综合利用平台提供的工具和技术,确保应用高效稳定运行。
146 0
打造个性化安卓应用:从设计到部署的全栈之旅
【8月更文挑战第31天】在数字化时代的浪潮中,移动应用已成为人们日常生活的一部分。本文将带你走进安卓应用的开发世界,从设计理念到实际编码,再到最终的用户手中,我们将一起探索如何将一个想法转变为现实中触手可及的应用。你将学习到如何利用安卓开发工具包(SDK)和编程语言(如Kotlin或Java),结合Material Design设计原则,创建出既美观又实用的应用。此外,我们还将讨论如何通过Google Play将应用发布给全球用户,并确保应用的安全性与维护性。无论你是初学者还是有一定经验的开发者,这篇文章都将为你提供宝贵的知识和启发。
打造个性化安卓应用:从设计到部署的全攻略
【8月更文挑战第31天】在这篇文章中,我们将一起探索如何从零开始构建一个安卓应用,并为其添加个人特色。我们将通过实际的代码示例,学习如何使用Android Studio进行开发,以及如何将应用发布到Google Play商店。无论你是编程新手还是有经验的开发者,这篇文章都将为你提供有价值的见解和技巧,帮助你打造独一无二的安卓应用。
探究Android应用开发中的内存泄漏检测与修复
在移动应用的开发过程中,优化用户体验和提升性能是至关重要的。对于Android平台而言,内存泄漏是一个常见且棘手的问题,它可能导致应用运行缓慢甚至崩溃。本文将深入探讨如何有效识别和解决内存泄漏问题,通过具体案例分析,揭示内存泄漏的成因,并提出相应的检测工具和方法。我们还将讨论一些最佳实践,帮助开发者预防内存泄漏,确保应用稳定高效地运行。
基于ssm+vue.js+uniapp小程序的安卓的微博客系统附带文章和源代码部署视频讲解等
基于ssm+vue.js+uniapp小程序的安卓的微博客系统附带文章和源代码部署视频讲解等
68 2
基于springboot+vue.js+uniapp的高校后勤网上报修系统安卓app附带文章源码部署视频讲解等
基于springboot+vue.js+uniapp的高校后勤网上报修系统安卓app附带文章源码部署视频讲解等
84 0
【转】Android线程模型(AsyncTask的使用)
【转】Android线程模型(AsyncTask的使用)
56 1
Termux安卓终端美化与开发实战:从下载到插件优化,小白也能玩转Linux
Termux是一款安卓平台上的开源终端模拟器,支持apt包管理、SSH连接及Python/Node.js/C++开发环境搭建,被誉为“手机上的Linux系统”。其特点包括零ROOT权限、跨平台开发和强大扩展性。本文详细介绍其安装准备、基础与高级环境配置、必备插件推荐、常见问题解决方法以及延伸学习资源,帮助用户充分利用Termux进行开发与学习。适用于Android 7+设备,原创内容转载请注明来源。
45 18

热门文章

最新文章

  • 1
    Android历史版本与APK文件结构
    18
  • 2
    【01】噩梦终结flutter配安卓android鸿蒙harmonyOS 以及next调试环境配鸿蒙和ios真机调试环境-flutter项目安卓环境配置-gradle-agp-ndkVersion模拟器运行真机测试环境-本地环境搭建-如何快速搭建android本地运行环境-优雅草卓伊凡-很多人在这步就被难倒了
    14
  • 3
    【03】仿站技术之python技术,看完学会再也不用去购买收费工具了-修改整体页面做好安卓下载发给客户-并且开始提交网站公安备案-作为APP下载落地页文娱产品一定要备案-包括安卓android下载(简单)-ios苹果plist下载(稍微麻烦一丢丢)-优雅草卓伊凡
    12
  • 4
    【03】微信支付商户申请下户到配置完整流程-微信开放平台创建APP应用-填写上传基础资料-生成安卓证书-获取Apk签名-申请+配置完整流程-优雅草卓伊凡
    27
  • 5
    【01】仿站技术之python技术,看完学会再也不用去购买收费工具了-用python扒一个app下载落地页-包括安卓android下载(简单)-ios苹果plist下载(稍微麻烦一丢丢)-客户的麻将软件需要下载落地页并且要做搜索引擎推广-本文用python语言快速开发爬取落地页下载-优雅草卓伊凡
    17
  • 6
    Cellebrite UFED 4PC 7.71 (Windows) - Android 和 iOS 移动设备取证软件
    5
  • 7
    【02】仿站技术之python技术,看完学会再也不用去购买收费工具了-本次找了小影-感觉页面很好看-本次是爬取vue需要用到Puppeteer库用node.js扒一个app下载落地页-包括安卓android下载(简单)-ios苹果plist下载(稍微麻烦一丢丢)-优雅草卓伊凡
    4
  • 8
    escrcpy:【技术党必看】Android开发,Escrcpy 让你无线投屏新体验!图形界面掌控 Android,30-120fps 超流畅!🔥
    6
  • 9
    Android实战经验之Kotlin中快速实现MVI架构
    8
  • 10
    即时通讯安全篇(一):正确地理解和使用Android端加密算法
    5