添加链接
link之家
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接
模型部署3/3-手把手实现利用flask深度学习模型部署

模型部署3/3-手把手实现利用flask深度学习模型部署

2 年前 · 来自专栏 机器学习&深度学习模型应用和部署

什么是Flask?

Flask是基于Python编写的轻量级Web应用框架 ,用他可以实现 深度学习模型部署为web应用 ,阅读完本文最终能够实现:

首先一个 简单的例子 了解一下 flask的基本框架

1. 下载flask

在anaconda prompt输入:

python -m pip --default-timeout=100 install -i https://pypi.tuna.tsinghua.edu.cn/simple flask

此方法来自于 《万能pip install大法》

2. 复制以下代码:

from flask import Flask
# 1. 定义app
app = Flask(__name__)
# 2. 定义函数
@app.route('/')
def hello_world():
 return 'hello,word!'
# 3. 定义ip和端口
if __name__ == "__main__":
    app.run(host='127.0.0.1', port=8080)

3. 运行后复制该地址到chrome浏览器

web端展示结果:

了解完flask后,下面将进行 基于flask的MNIST手写数字预测模型部署

完成后终端会创建以下 文件结构 。这些文件包含:

  • 前端web 页面的框架结构(.html)
  • 前端 页面的样式(.css)
  • 前端 页面按钮的交互 eg.predict和clear按钮 (.js)
  • 后台 flask应用程序(.py)
  • keras 深度学习模型文件(.h5)

STEP 1:复制model.h5文件到py文件所在文件夹

Model.h5文件 是在 模型部署1/3-构建MNIST手写字深度学习模型 生成的,同时也包含模型建立的过程。

更多的 keras模型保存和加载 的方式查看 模型部署2/3-保存和加载Keras模型的三种方式

STEP 2:创建index.js文件

创建Js文件到static文件夹, js文件能够解决前端交互的问题 ,复制以下代码。

(function() {
  var canvas = document.querySelector("#canvas");
  var context = canvas.getContext("2d");
  canvas.width = 280;
  canvas.height = 280;
  var Mouse = { x: 0, y: 0 };
  var lastMouse = { x: 0, y: 0 };
  context.fillStyle = "black";
  context.fillRect(0, 0, canvas.width, canvas.height);
  context.color = "white";
  context.lineWidth = 15;
  context.lineJoin = context.lineCap = "round";
clearCanvas();
canvas.addEventListener( "mousemove",
  function(e) {
    lastMouse.x = Mouse.x;
    lastMouse.y = Mouse.y;
    Mouse.x = e.pageX - this.offsetLeft;
    Mouse.y = e.pageY - this.offsetTop;
}, false);
canvas.addEventListener("mousedown",
  function(e) {
    canvas.addEventListener("mousemove", onDraw, false);
}, false);
canvas.addEventListener("mouseup",
  function() {
    canvas.removeEventListener("mousemove", onDraw, false);
}, false);
/* Canvas Draw */
var onDraw = function() {
  context.lineWidth = context.lineWidth;
  context.lineJoin = "round";
  context.lineCap = "round";
  context.strokeStyle = context.color;
  context.beginPath();
  context.moveTo(lastMouse.x, lastMouse.y);
  context.lineTo(Mouse.x, Mouse.y);
  context.closePath();
  context.stroke();
/* This function clears the box */
function clearCanvas() {
  var clearButton = $("#clearButton");
  clearButton.on("click", function() {
    context.clearRect(0, 0, 280, 280);
    context.fillStyle = "black";
    context.fillRect(0, 0, canvas.width, canvas.height);
/* Slider control */
var slider = document.getElementById("myRange");
var output = document.getElementById("sliderValue");
output.innerHTML = slider.value;
slider.oninput = function() {
  output.innerHTML = this.value;
  context.lineWidth = $(this).val();
$("#lineWidth").change(function() {
  context.lineWidth = $(this).val();
}})();

STEP 3:创建style.css

创建css文件到static文件夹, css文件创建了前端页面的元素 ,复制以下代码。

 body {
  padding-top: 20px;
  padding-bottom: 20px;
.header, .footer {
  padding-right: 15px;
  padding-left: 15px;
.header {
  padding-bottom: 20px;
  border-bottom: 1px solid #e5e5e5;
.header h3 {
  margin-top: 0;
  margin-bottom: 0;
  line-height: 40px;
.footer {
  padding-top: 19px;
  color: #777;
  border-top: 1px solid #e5e5e5;
@media (min-width: 768px) {
  .container {
    max-width: 730px;
.container-narrow > hr {
  margin: 30px 0;
.jumbotron {
  text-align: center;
  border-bottom: 1px solid #e5e5e5;
  padding-top: 20px;
  padding-bottom: 20px;
.bodyDiv{
  text-align: center;
@media screen and (min-width: 768px) {
  .header,
  .footer {
    padding-right: 0;
    padding-left: 0;
.header {
  margin-bottom: 30px;
.jumbotron {
  border-bottom: 0;
@media screen and (max-width: 500px) {
  .slidecontainer{
    display: none;
.slidecontainer{
  float: left;
  width: 30%;
.jumbotronHeading{
  margin-bottom: 7vh;
.canvasDiv{
  display: flow-root;
  text-align: center;
}

STEP 4:创建index.html文件

创建html文件到templates文件夹, html文件创建了web的基本框架 ,复制以下代码。

更多的 前端知识 查看 第一章 产品经理必懂的前端技术

 <!DOCTYPE html>
<html lang="en">
<meta charset="utf-8" />
<meta http-equiv="X-UA-Compatible" content="IE=edge" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<!-- The above 3 meta tags *must* come first in the head; any other head content must come *after* these tags -->
<title>MNIST Handwritten text recognition using keras</title>
<!-- Bootstrap core CSS -->
<link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/css/bootstrap.min.css" integrity="sha384-BVYiiSIFeK1dGmJRAkycuHAHRg32OmUcww7on3RYdg4Va+PmSTsz/K68vbdEjh4u" crossorigin="anonymous"/>
<link rel="stylesheet" href="{{ url_for('static',filename='style.css') }}"/>
</head>
    <div class="container">
      <div class="header clearfix">
        <h3 class="text-muted">MNIST Handwritten CNN</h3>
      <div class="jumbotron">
        <h3 class="jumbotronHeading">Draw the digit inside this Box   </h3>
        <div class="slidecontainer">
          <p>Drag the slider to change the line width.</p>
          <input type="range" min="10" max="50" value="15" id="myRange" />
          <p>Value: <span id="sliderValue"></span></p>
        <div class="canvasDiv">
          <canvas id="canvas" width="280" height="280" style="padding-bottom: 20px">
          </canvas>
          <br />
          <p style="text-align:center;">
          <a class="btn btn-success myButton" href="#" role="button">Predict</a>
          <a class="btn btn-primary" href="#" id="clearButton" role="button">Clear</a>
      <div class="jumbotron">
        <p id="result">Get your prediction here!!!</p>
      <footer class="footer">
        <p>Keras MNIST</p>
      </footer>
<!-- /container -->
<script src="http://cdnjs.cloudflare.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
<script src="{{ url_for('static',filename='index.js') }}"></script><script type="text/javascript">
$(".myButton").click(function() {
  var $SCRIPT_ROOT = {{ request.script_root|tojson|safe }};
  var canvasObj = document.getElementById("canvas");
  var img = canvasObj.toDataURL();
  $.ajax({
    type: "POST",
    url: $SCRIPT_ROOT + "/predict/",
    data: img,
    success: function(data){
      $('#result').text(' Predicted Output: '+data);
</script>
</body>
</html>

STEP 5:创建keras_flask.py文件

该文件包含了 调用flask函数 ,具体的查看代码注释。

from flask import Flask, render_template, request
from scipy.misc import imread, imresize, imsave
import tensorflow as tf
import numpy as np
import re
import base64
from tensorflow.keras.models import load_model
from tensorflow.python.keras.backend import set_session
# 1. 初始化 flask app
app = Flask(__name__)
# 2. 初始化global variables
sess = tf.Session()
graph = tf.get_default_graph()
# 3. 将用户画的图输出成output.png
def convertImage(imgData1):
  imgstr = re.search(r'base64,(.*)', str(imgData1)).group(1)
  with open('output.png', 'wb') as output:
    output.write(base64.b64decode(imgstr))
# 4. 搭建前端框架
@app.route('/')
def index():
  return render_template("index.html")
# 5. 定义预测函数
@app.route('/predict/', methods=['GET', 'POST'])
def predict():
  # 这个函数会在用户点击‘predict’按钮时触发
  # 会将输出的output.png放入模型中进行预测
  # 同时在页面上输出预测结果
  imgData = request.get_data()
  convertImage(imgData)
  # 读取图片
  x = imread('output.png', mode='L')
  # 设置图片的规格
  x = imresize(x, (28, 28))/255
  # 可以保存最终处理好的图片
  imsave('final_image.jpg', x)
  x = x.reshape(1, 28, 28, 1)
  # 调用训练好的模型和并进行预测
  global graph
  global sess
  with graph.as_default():
    set_session(sess)
    model = load_model('model.h5')
    out = model.predict(x)
    response = np.argmax(out, axis=1)