Add mask download

main
ihmily 11 months ago
parent 2139ff5b3e
commit 8d566e3559

@ -1,7 +1,10 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import sys import sys
from fastapi import FastAPI, File, UploadFile, Form, Response import os
import uuid
from datetime import datetime
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi import Request from fastapi import Request
import requests import requests
import cv2 import cv2
@ -23,6 +26,12 @@ default_model = list(model_paths.keys())[0]
default_model_info = model_paths[default_model] default_model_info = model_paths[default_model]
loaded_models = {default_model: pipeline(default_model_info['task'], model=default_model_info['path'])} loaded_models = {default_model: pipeline(default_model_info['task'], model=default_model_info['path'])}
UPLOAD_FOLDER = "./upload"
OUTPUT_FOLDER = "./output"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
class ModelLoader: class ModelLoader:
def __init__(self): def __init__(self):
@ -62,13 +71,26 @@ async def matting(image: UploadFile = File(...), model: str = Form(default=defau
selected_model = model_loader.load_model(model) selected_model = model_loader.load_model(model)
filename = uuid.uuid4()
original_image_filename = f"original_{filename}.png"
image_filename = f"image_{filename}.png"
mask_filename = f"mask_{filename}.png"
cv2.imwrite(os.path.join(UPLOAD_FOLDER, original_image_filename), img)
result = selected_model(img) result = selected_model(img)
output_img = result[OutputKeys.OUTPUT_IMG] cv2.imwrite(os.path.join(OUTPUT_FOLDER, image_filename), result[OutputKeys.OUTPUT_IMG])
cv2.imwrite(os.path.join(OUTPUT_FOLDER, mask_filename), result['output_img'][:, :, 3])
output_bytes = cv2.imencode('.png', output_img)[1].tobytes() response_data = {
"result_image_url": f"/output/{image_filename}",
"mask_image_url": f"/output/{mask_filename}",
"original_image_size": {"width": img.shape[1], "height": img.shape[0]},
"generation_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
}
return Response(content=output_bytes, media_type='image/png') return response_data
@app.post("/matting/url") @app.post("/matting/url")
@ -77,34 +99,50 @@ async def matting_url(request: Request, model: str = Form(default=default_model,
json_data = await request.json() json_data = await request.json()
image_url = json_data.get("image_url") image_url = json_data.get("image_url")
except Exception as e: except Exception as e:
return {"content": f"Error parsing JSON data: {str(e)}"}, 400 raise HTTPException(status_code=400, detail=f"Error parsing JSON data: {str(e)}")
if not image_url: if not image_url:
return {"content": "Image URL is required"}, 400 raise HTTPException(status_code=400, detail="Image URL is required")
try:
response = requests.get(image_url) response = requests.get(image_url)
if response.status_code != 200: response.raise_for_status()
return {"content": "Failed to fetch image from URL"}, 400
img_array = np.frombuffer(response.content, dtype=np.uint8) img_array = np.frombuffer(response.content, dtype=np.uint8)
img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
except requests.RequestException as e:
raise HTTPException(status_code=400, detail=f"Failed to fetch image from URL: {str(e)}")
if model not in model_paths: if model not in model_paths:
return {"content": "Invalid model selection"}, 400 raise HTTPException(status_code=400, detail="Invalid model selection")
selected_model = model_loader.load_model(model) selected_model = model_loader.load_model(model)
filename = uuid.uuid4()
original_image_filename = f"original_{filename}.png"
image_filename = f"image_{filename}.png"
mask_filename = f"mask_{filename}.png"
cv2.imwrite(os.path.join(UPLOAD_FOLDER, original_image_filename), img)
result = selected_model(img) result = selected_model(img)
output_img = result[OutputKeys.OUTPUT_IMG] cv2.imwrite(os.path.join(OUTPUT_FOLDER, image_filename), result[OutputKeys.OUTPUT_IMG])
cv2.imwrite(os.path.join(OUTPUT_FOLDER, mask_filename), result['output_img'][:, :, 3])
output_bytes = cv2.imencode('.png', output_img)[1].tobytes() response_data = {
"result_image_url": f"/output/{image_filename}",
"mask_image_url": f"/output/{mask_filename}",
"original_image_size": {"width": img.shape[1], "height": img.shape[0]},
"generation_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
}
return Response(content=output_bytes, media_type='image/png') return response_data
templates = Jinja2Templates(directory="web") templates = Jinja2Templates(directory="web")
app.mount("/static", StaticFiles(directory="web/static"), name="static") app.mount("/static", StaticFiles(directory="./web/static"), name="static")
app.mount("/output", StaticFiles(directory="./output"), name="output")
app.mount("/upload", StaticFiles(directory="./upload"), name="upload")
@app.get("/") @app.get("/")
@ -115,5 +153,6 @@ async def read_index(request: Request):
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
defult_bind_host = "0.0.0.0" if sys.platform != "win32" else "127.0.0.1" defult_bind_host = "0.0.0.0" if sys.platform != "win32" else "127.0.0.1"
uvicorn.run(app, host=defult_bind_host, port=8000) uvicorn.run(app, host=defult_bind_host, port=8000)

@ -4,54 +4,13 @@
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Image Matting</title> <title>Simple Image Matting</title>
<style> <link rel="stylesheet" href="./static/css/style.css">
body {
font-family: 'Arial', sans-serif;
background-color: #f4f4f4;
margin: 0;
padding: 0;
}
header {
background-color: #333;
color: white;
padding: 1em;
text-align: center;
}
main {
max-width: 800px;
margin: 2em auto;
background-color: white;
padding: 2em;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
text-align: center;
}
#upload-form {
text-align: center;
margin-bottom: 2em;
}
#images-container {
display: flex;
justify-content: space-between;
align-items: center;
}
#original-img,
#result-img {
max-width: 45%;
height: auto;
border: 1px solid #ccc;
}
</style>
</head> </head>
<body> <body>
<header> <header>
<h1>Image Matting</h1> <h1>Simple Image Matting</h1>
</header> </header>
<a href="https://github.com/ihmily/image-matting"><img style="position: absolute; top: 0; right: 0; border: 0;" decoding="async" width="149" height="149" src="/static/images/forkme_right_gray_6d6d6d.png" class="attachment-full size-full" alt="Fork me on GitHub" loading="lazy" data-recalc-dims="1"></a> <a href="https://github.com/ihmily/image-matting"><img style="position: absolute; top: 0; right: 0; border: 0;" decoding="async" width="149" height="149" src="/static/images/forkme_right_gray_6d6d6d.png" class="attachment-full size-full" alt="Fork me on GitHub" loading="lazy" data-recalc-dims="1"></a>
<main> <main>
@ -71,10 +30,15 @@
<div id="images-container"> <div id="images-container">
<img id="original-img" alt=" " src=""> <img id="original-img" alt=" " src="">
<img id="mask-img" alt=" " src="">
<img id="result-img" alt=" " src=""> <img id="result-img" alt=" " src="">
</div> </div>
</main> </main>
<footer>
<p>&copy; 2024 Hmily. All rights reserved.</p>
</footer>
<script> <script>
document.getElementById("image-upload").addEventListener("change", function() { document.getElementById("image-upload").addEventListener("change", function() {
const inputElement = this; const inputElement = this;
@ -84,6 +48,7 @@
const originalImgElement = document.getElementById("original-img"); const originalImgElement = document.getElementById("original-img");
originalImgElement.src = URL.createObjectURL(file); originalImgElement.src = URL.createObjectURL(file);
document.getElementById("result-img").src = ""; document.getElementById("result-img").src = "";
document.getElementById("mask-img").src = "";
} }
}); });
@ -106,9 +71,15 @@
}); });
if (response.ok) { if (response.ok) {
const resultBlob = await response.blob(); const responseData = await response.json();
const resultImgElement = document.getElementById("result-img"); const resultImgElement = document.getElementById("result-img");
resultImgElement.src = URL.createObjectURL(resultBlob); resultImgElement.src = responseData.result_image_url;
const maskImgElement = document.getElementById("mask-img");
maskImgElement.src = responseData.mask_image_url;
console.log(`Matting successful!\nOriginal Image Size: ${responseData.original_image_size.width} x ${responseData.original_image_size.height}\nGeneration Time: ${responseData.generation_time}`);
} else { } else {
alert("Matting failed. Please try again."); alert("Matting failed. Please try again.");
} }

@ -0,0 +1,53 @@
body {
font-family: 'Arial', sans-serif;
background-color: #f4f4f4;
margin: 0;
padding: 0;
}
header {
background-color: #333;
color: white;
padding: 1em;
text-align: center;
}
main {
max-width: 1200px;
margin: 2em auto;
background-color: white;
padding: 2em;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
text-align: center;
}
#upload-form {
text-align: center;
margin-bottom: 2em;
}
#images-container {
display: flex;
justify-content: space-between;
align-items: center;
margin-top: 2em; /* Add spacing */
}
#original-img,
#mask-img,
#result-img {
max-width: 30%;
height: auto;
border: 1px solid #ccc;
margin: 0.5em; /* Add spacing */
}
footer {
background-color: #333;
color: white;
text-align: center;
padding: 1em;
position: fixed;
bottom: 0;
width: 100%;
}
Loading…
Cancel
Save