Pytorch Flask服务部署图片识别(学习笔记) 电脑版发表于:2024/1/5 14:16 ![](https://img.tnblog.net/arcimg/hb/21f086c80c5d4afda1bc1029dadd8f3a.png) >#Pytorch Flask服务部署图片识别(学习笔记) [TOC] ## Flask 简介 tn2>Flask是一个用Python编写的轻量级Web应用框架。 它简单易用,但同时也足够灵活和强大,能够支持复杂的Web应用。 由于其轻量级的特性,Flask非常适合用作在Web上部署机器学习模型的工具。 tn>简单来讲:启动一个服务,根据传上来的东西进行预测并返回结果。 ## 安装Flask ```python python -m pip install flask ``` ## 实践目录 ![](https://img.tnblog.net/arcimg/hb/46fd7cddc24b4568a29fd34886586152.png) | 文件或文件夹 | 描述 | | ------------ | ------------ | | `flower_data` | 训练的图像数据 | | `best.pth` | 训练好的模型 | | `flask_server.py` | 服务器端代码 | | `flask_predict.py` | 客户端请求代码 | ## 服务器端 tn2>服务器对需要预处理的图片流程如下图所示: ![](https://img.tnblog.net/arcimg/hb/7a2167bd6e97406e9b6ef14875ad5005.png) tn2>`flask_server.py`代码如下所示: ```python import io import json # flask 服务 import flask import torch import torch import torch.nn.functional as F from PIL import Image from torch import nn #from torchvision import transforms as T from torchvision import transforms, models, datasets from torch.autograd import Variable # 初始化Flask app app = flask.Flask(__name__) model = None use_gpu = False # 加载模型进来 def load_model(): """Load the pre-trained model, you can use your model just as easily. """ # 定义一个全局变量 global model #这里我们直接加载官方工具包里提供的训练好的模型(代码会自动下载)括号内参数为是否下载模型对应的配置信息 model = models.resnet18() num_ftrs = model.fc.in_features model.fc = nn.Sequential(nn.Linear(num_ftrs, 102)) # 102类的分类任务 #print(model) 加载模型 checkpoint = torch.load('best.pth') # 加载权重参数 model.load_state_dict(checkpoint['state_dict']) #将模型指定为测试格式 model.eval() #是否使用gpu if use_gpu: model.cuda() # 数据预处理 def prepare_image(image, target_size): """Do image preprocessing before prediction on any data. :param image: original image :param target_size: target image size :return: preprocessed image """ #针对不同模型,image的格式不同,但需要统一至RGB格式 if image.mode != 'RGB': image = image.convert("RGB") # Resize the input image and preprocess it.(按照所使用的模型将输入图片的尺寸修改,并转为tensor) # 图片与训练尺寸大小一致 image = transforms.Resize(target_size)(image) # 转tensor格式 image = transforms.ToTensor()(image) # Convert to Torch.Tensor and normalize. mean与std (RGB三通道)这里的参数和数据集中是对应的,训练过程中一致 # 设置均值和标准差 image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image) # Add batch_size axis.增加一个维度,用于按batch测试 本次这里一次测试一张 # 举例:1*3*64*64 image = image[None] if use_gpu: image = image.cuda() return Variable(image, volatile=True) #不需要求导 # 开启服务 这里的predict是API路径、使用POST请求 @app.route("/predict", methods=["POST"]) def predict(): # Initialize the data dictionary that will be returned from the view. #做一个标志,刚开始无图像传入时为false,传入图像时为true data = {"success": False} # 如果收到请求 if flask.request.method == 'POST': #判断是否为图像 if flask.request.files.get("image"): # Read the image in PIL format # 将收到的图像进行读取 image = flask.request.files["image"].read() image = Image.open(io.BytesIO(image)) #二进制数据 # 利用上面的预处理函数将读入的图像进行预处理 image = prepare_image(image, target_size=(64, 64)) # 放入模型中进行预测,softmax得到各个类别的概率 preds = F.softmax(model(image), dim=1) # k找出类别前3高的 results = torch.topk(preds.cpu().data, k=3, dim=1) # 结果转成cpu最后转成numpy results = (results[0].cpu().numpy(), results[1].cpu().numpy()) #将data字典增加一个key,value,其中value为list格式 data['predictions'] = list() # 遍历每一个预测结果 for prob, label in zip(results[0][0], results[1][0]): #label_name = idx2label[str(label)] # label真实值,和probability概率值 r = {"label": str(label), "probability": float(prob)} # 将预测结果添加至data字典 data['predictions'].append(r) # Indicate that the request was a success. data["success"] = True # 将最终结果以json格式文件传出 return flask.jsonify(data) """ test_json = { "status_code": 200, "success": { "message": "image uploaded", "code": 200 }, "video":{ "video_name":opt['source'].split('/')[-1], "video_path":opt['source'], "description":"1", "length": str(hour)+','+str(minute)+','+str(round(second,4)), "model_object_completed":model_flag } "status_txt": "OK" } response = requests.post( 'http://xxx.xxx.xxx.xxx:8090/api/ObjectToKafka/',, data={'json': str(test_json)}) """ if __name__ == '__main__': print("Loading PyTorch model and Flask starting server ...") print("Please wait until server has fully started") #先加载模型 load_model() #再开启服务 app.run(port='5012') ``` tn2>这里我开放的端口是`5012`,通过请求`/predict`链接,通过执行如下命令将程序跑起来: ```python python flask_server.py ``` ![](https://img.tnblog.net/arcimg/hb/9ff98d7c5d80431485ced42218f59147.png) tn>只要把Flask关了模型就没了,如果Flask一直开着的模型就一直都在跑。 ## 客户端 tn2>客户端主要是上传一张`image_06998.jpg`的图片到服务器中去预测,代码如下: ```python import requests import argparse # url和端口携程自己的 flask_url = 'http://127.0.0.1:5012/predict' def predict_result(image_path): #传入本地图片 image = open(image_path, 'rb').read() payload = {'image': image} #request发给server. r = requests.post(flask_url, files=payload).json() # 成功的话在返回. if r['success']: # 输出结果. for (i, result) in enumerate(r['predictions']): print('{}. {}: {:.4f}'.format(i + 1, result['label'], result['probability'])) # 失败了就打印. else: print('Request failed') if __name__ == '__main__': parser = argparse.ArgumentParser(description='Classification demo') # 添加参数 parser.add_argument('--file', default='./flower_data/train_filelist/image_06998.jpg', type=str, help='test image file') args = parser.parse_args() # 开始请求 predict_result(args.file) ``` ```bash python flask_predict.py ``` tn2>预测结果如下所示: ![](https://img.tnblog.net/arcimg/hb/5d69d414c6d54d05bf372795de3bee4b.png) tn2>我们可以看到预测得最相似的label是`34`,准确率`97%`,我们去图片数据中找找这张图片的训练集验证一下。 ![](https://img.tnblog.net/arcimg/hb/e9edae8bec7c47d88497857ac7be15d3.png) tn2>训练的结果与预期的结果一致。 tn><a href="https://download.tnblog.net/resource/index/a6f1480a5ea54461854604818dea347a">代码链接</a>