#! /usr/bin/env python
import ast
import io
import os
import time
from binascii import a2b_base64

import imageio
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from flask import Flask, Response, render_template, request
from PIL import Image
from torchvision import datasets, transforms
# Module-level state shared with the route handlers below.
# (A `global` statement at module scope is a no-op, so it is not needed here.)
model_states = ['Not Trained']  # NOTE(review): not referenced in the visible code; confirm it is still used
nb_epoch = 5  # epoch count handed to the to_train.html template

app = Flask(__name__)
# Handle to the loaded network; stays None until the /loadmodel/ route runs.
model = None


# Landing page ("to_train").
@app.route('/')
def to_train():
    """Render the landing page, exposing the configured epoch count to the template."""
    template_context = {'nb_epoch': nb_epoch}
    return render_template('to_train.html', **template_context)
class NN(nn.Module):
    """Small CNN for 28 x 28 single-channel digit images (MNIST, per the checkpoint name).

    Attribute names (conv1L, conv2L, FC1, FC2) must stay as they are: they are
    the state_dict keys that ``load`` restores from the checkpoint.
    """

    def __init__(self):
        super(NN, self).__init__()
        self.conv1L = nn.Conv2d(1, 20, 3, 1)
        self.conv2L = nn.Conv2d(20, 50, 3, 1)
        # Must equal the flattened size produced in forward(): 50 channels x 5 x 5.
        self.FC1 = nn.Linear(50 * 5 * 5, 500)
        self.FC2 = nn.Linear(500, 10)

    def forward(self, x):
        """Return log-probabilities over the 10 classes, shape (batch, 10)."""
        x = F.relu(self.conv1L(x))    # -> 20 x 26 x 26 (for a 28 x 28 input)
        x = F.max_pool2d(x, (2, 2))   # -> 20 x 13 x 13
        x = F.relu(self.conv2L(x))    # -> 50 x 11 x 11
        x = F.max_pool2d(x, (2, 2))   # -> 50 x 5 x 5
        x = x.view(-1, 50 * 5 * 5)    # flatten
        # NOTE(review): no activation between FC1 and FC2; kept unchanged so the
        # network matches the architecture the saved weights were produced with.
        x = self.FC1(x)               # -> 500
        x = self.FC2(x)               # -> 10
        return F.log_softmax(x, dim=1)


# Load the pretrained weights (this route does not train anything).
@app.route("/loadmodel/", methods=['GET'])
def load():
    """Load ``Meetup_MNIST.pt`` into a fresh ``NN`` and publish it as the global ``model``.

    Returns:
        The plain-text response ``"Loading done"``.
    """
    global model
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    checkpoint = torch.load("Meetup_MNIST.pt", map_location="cpu")
    net = NN()
    net.load_state_dict(checkpoint)
    net.eval()
    # Publish only after the weights loaded successfully, so a failed load
    # never leaves a randomly initialised network serving predictions.
    model = net
    print("model loaded")
    return "Loading done"
# Page where you draw the number.
@app.route('/index/', methods=['GET', 'POST'])
def index():
    """Render the drawing page; on POST, store the submitted drawing.

    A POST body is expected to be a base64 data URL. The original code stripped
    exactly ``data:image/png;base64,``; any ``...;base64,`` header is now
    accepted. The decoded image is resized to the 28 x 28 input the network
    needs and saved to ``data_img/draw.png``, where ``/result/`` reads it.
    """
    prediction = '?'
    if request.method == 'POST':
        data_url = request.get_data()
        # Drop everything up to the last comma (the data-URL header), whatever
        # its length; base64 itself never contains a comma. With no header the
        # whole body is kept.
        _, _, encoded = data_url.rpartition(b',')
        binary_data = a2b_base64(encoded)
        img = Image.open(io.BytesIO(binary_data))
        # resize (not thumbnail): thumbnail keeps the aspect ratio, so a
        # non-square canvas would not produce the 28 x 28 the model requires.
        img = img.resize((28, 28))
        os.makedirs("data_img", exist_ok=True)
        img.save("data_img/draw.png")
    return render_template('index.html', prediction=prediction)
# Display the prediction.
@app.route('/result/')
def result():
    """Classify the last saved drawing and render it on the index page.

    Returns a 503 with an explanatory message if the model has not been
    loaded yet (via ``/loadmodel/``), instead of crashing on a None model.
    """
    if model is None:
        return "Model not loaded: call /loadmodel/ first", 503
    # NOTE(review): presumably gives a just-POSTed drawing time to be written
    # to disk before it is read; a fixed sleep does not guarantee that.
    time.sleep(0.2)
    # `with` closes the underlying file; convert() returns an independent copy.
    with Image.open("data_img/draw.png") as raw:
        img = raw.convert("1")
    transform = transforms.Compose([transforms.ToTensor()])
    img = transform(img)
    img = torch.unsqueeze(img, 0)  # add the batch dimension
    prediction = inference(model, img)
    print(prediction)
    return render_template("index.html", prediction=prediction)
def inference(model, img):
    """Return the predicted class index for a single-item batch.

    Args:
        model: Callable mapping ``img`` to per-class log-probability scores,
            with classes on dim 1.
        img: Input batch; must contain exactly one item, since the result is
            read with ``.item()``.

    Returns:
        The index (int) of the highest-scoring class.
    """
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(img)
    # exp() is monotonic, so the argmax of log-probabilities equals the argmax
    # of probabilities; no need to exponentiate first.
    return output.argmax(dim=1).item()
if __name__ == "__main__":
    # NOTE(review): debug=True enables the Werkzeug interactive debugger, which
    # allows arbitrary code execution; never expose this server publicly.
    # With threaded=True, concurrent users also share data_img/draw.png.
    app.run(debug=True, threaded=True)