-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
94 lines (70 loc) · 2.95 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from flask import Flask, request, jsonify, render_template
import torch
from torchvision import transforms, models
from torch import nn
from PIL import Image
import os
# Example request (the /predict endpoint reads the upload from the "image"
# form field, so the field name in -F must be "image"):
#   curl -X POST -F "image=@/path/to/image.jpg" http://127.0.0.1:5000/predict
app = Flask(__name__)
# Project root: the directory containing this file, so paths below resolve the
# same way regardless of the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Fine-tuned checkpoint (a state_dict, loaded below) under <BASE_DIR>/model/.
MODEL_PATH = os.path.join(BASE_DIR, "model", "2024-11-21_vgg16_finetuned-2.pth")

# Rebuild the fine-tuning architecture: an untrained torchvision VGG16 whose
# last classifier layer is replaced by a small 2-output head. This must match
# the architecture the checkpoint was trained with, otherwise
# load_state_dict() (strict by default) raises on missing/mismatched keys.
model = models.vgg16()
num_ftrs = model.classifier[-1].in_features
model.classifier[-1] = nn.Sequential(
    nn.Linear(num_ftrs, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 2),
)

# Prefer the GPU when available; the checkpoint tensors are mapped there too.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# torch.load is pickle-based; weights_only=True restricts unpickling to tensors
# and plain containers so a tampered checkpoint cannot execute code on load.
# A plain state_dict (what load_state_dict expects) loads fine under it.
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=device, weights_only=True)
)
model.to(device)
# Evaluation mode disables the Dropout layers so predictions are deterministic.
model.eval()
# Directory for uploaded images. Anchored to the project directory (like
# MODEL_PATH) instead of the current working directory, so launching the app
# from elsewhere does not create stray "uploads" folders. Created eagerly at
# import time and exposed via app.config for request handlers.
UPLOAD_FOLDER = os.path.join(BASE_DIR, 'uploads')
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# Preprocessing pipeline applied to each uploaded image before inference:
# resize to a fixed 224x224, then convert the PIL image to a tensor.
# NOTE(review): there is no transforms.Normalize step; this pipeline must
# mirror the preprocessing used during fine-tuning — confirm against the
# training code.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
@app.route('/')
def home():
    """Serve the landing page by rendering the ``index.html`` template."""
    template_name = 'index.html'
    return render_template(template_name)
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image with the module-level model.

    Expects a POST whose uploaded file is in the ``image`` field
    (multipart/form-data).

    Returns:
        200 ``{"result": <int>}`` -- index of the highest-scoring model output.
        400 ``{"error": ...}`` -- no file was sent, or it is not a decodable image.
        500 ``{"error": ...}`` -- any other failure during preprocessing/inference.
    """
    if 'image' not in request.files:
        return jsonify({"error": "No file uploaded"}), 400
    upload = request.files['image']

    try:
        # Decode straight from the in-memory upload stream instead of saving
        # to disk: the client-controlled filename was previously joined into a
        # filesystem path unsanitised (path traversal / overwrite risk),
        # concurrent uploads with the same name collided, and the saved file
        # leaked whenever an exception skipped the cleanup.
        try:
            with Image.open(upload.stream) as img:
                # Force 3 channels so grayscale/RGBA/palette inputs yield the
                # channel count the model's first conv layer expects.
                image = img.convert('RGB')
        except (OSError, Image.DecompressionBombError):
            # UnidentifiedImageError (non-image, incl. empty upload) subclasses
            # OSError; truncated/corrupt data also surfaces as OSError.
            return jsonify({"error": "Uploaded file is not a valid image"}), 400

        # Preprocess and add a leading batch dimension of size 1.
        batch = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(batch)
        result = output.argmax(dim=1).item()
    except Exception as e:
        # Top-level request boundary: log the traceback server-side and keep
        # the JSON error response clients already receive.
        app.logger.exception("Prediction failed")
        return jsonify({"error": str(e)}), 500

    return jsonify({"result": int(result)})
if __name__ == '__main__':
    # Debug mode enables the reloader and Werkzeug's interactive debugger,
    # which lets anyone who can reach the server execute arbitrary code.
    # It stays on by default (as before) for local testing; set FLASK_DEBUG
    # to 0/false/no to disable it without editing this file.
    debug_env = os.environ.get('FLASK_DEBUG', '1')
    app.run(debug=debug_env.strip().lower() not in {'0', 'false', 'no', ''})