Add mask download

main
ihmily 11 months ago
parent 2139ff5b3e
commit 8d566e3559

@ -1,7 +1,10 @@
# -*- coding: utf-8 -*-
import sys
from fastapi import FastAPI, File, UploadFile, Form, Response
import os
import uuid
from datetime import datetime
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi import Request
import requests
import cv2
@ -23,6 +26,12 @@ default_model = list(model_paths.keys())[0]
default_model_info = model_paths[default_model]
loaded_models = {default_model: pipeline(default_model_info['task'], model=default_model_info['path'])}
UPLOAD_FOLDER = "./upload"
OUTPUT_FOLDER = "./output"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
class ModelLoader:
def __init__(self):
@ -62,13 +71,26 @@ async def matting(image: UploadFile = File(...), model: str = Form(default=defau
selected_model = model_loader.load_model(model)
filename = uuid.uuid4()
original_image_filename = f"original_{filename}.png"
image_filename = f"image_{filename}.png"
mask_filename = f"mask_{filename}.png"
cv2.imwrite(os.path.join(UPLOAD_FOLDER, original_image_filename), img)
result = selected_model(img)
output_img = result[OutputKeys.OUTPUT_IMG]
cv2.imwrite(os.path.join(OUTPUT_FOLDER, image_filename), result[OutputKeys.OUTPUT_IMG])
cv2.imwrite(os.path.join(OUTPUT_FOLDER, mask_filename), result['output_img'][:, :, 3])
output_bytes = cv2.imencode('.png', output_img)[1].tobytes()
response_data = {
"result_image_url": f"/output/{image_filename}",
"mask_image_url": f"/output/{mask_filename}",
"original_image_size": {"width": img.shape[1], "height": img.shape[0]},
"generation_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
}
return Response(content=output_bytes, media_type='image/png')
return response_data
@app.post("/matting/url")
@ -77,34 +99,50 @@ async def matting_url(request: Request, model: str = Form(default=default_model,
json_data = await request.json()
image_url = json_data.get("image_url")
except Exception as e:
return {"content": f"Error parsing JSON data: {str(e)}"}, 400
raise HTTPException(status_code=400, detail=f"Error parsing JSON data: {str(e)}")
if not image_url:
return {"content": "Image URL is required"}, 400
raise HTTPException(status_code=400, detail="Image URL is required")
try:
response = requests.get(image_url)
if response.status_code != 200:
return {"content": "Failed to fetch image from URL"}, 400
response.raise_for_status()
img_array = np.frombuffer(response.content, dtype=np.uint8)
img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
except requests.RequestException as e:
raise HTTPException(status_code=400, detail=f"Failed to fetch image from URL: {str(e)}")
if model not in model_paths:
return {"content": "Invalid model selection"}, 400
raise HTTPException(status_code=400, detail="Invalid model selection")
selected_model = model_loader.load_model(model)
filename = uuid.uuid4()
original_image_filename = f"original_{filename}.png"
image_filename = f"image_{filename}.png"
mask_filename = f"mask_{filename}.png"
cv2.imwrite(os.path.join(UPLOAD_FOLDER, original_image_filename), img)
result = selected_model(img)
output_img = result[OutputKeys.OUTPUT_IMG]
cv2.imwrite(os.path.join(OUTPUT_FOLDER, image_filename), result[OutputKeys.OUTPUT_IMG])
cv2.imwrite(os.path.join(OUTPUT_FOLDER, mask_filename), result['output_img'][:, :, 3])
output_bytes = cv2.imencode('.png', output_img)[1].tobytes()
response_data = {
"result_image_url": f"/output/{image_filename}",
"mask_image_url": f"/output/{mask_filename}",
"original_image_size": {"width": img.shape[1], "height": img.shape[0]},
"generation_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
}
return Response(content=output_bytes, media_type='image/png')
return response_data
templates = Jinja2Templates(directory="web")
app.mount("/static", StaticFiles(directory="web/static"), name="static")
app.mount("/static", StaticFiles(directory="./web/static"), name="static")
app.mount("/output", StaticFiles(directory="./output"), name="output")
app.mount("/upload", StaticFiles(directory="./upload"), name="upload")
@app.get("/")
@ -115,5 +153,6 @@ async def read_index(request: Request):
if __name__ == "__main__":
import uvicorn
defult_bind_host = "0.0.0.0" if sys.platform != "win32" else "127.0.0.1"
uvicorn.run(app, host=defult_bind_host, port=8000)

@ -4,54 +4,13 @@
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Image Matting</title>
<style>
body {
font-family: 'Arial', sans-serif;
background-color: #f4f4f4;
margin: 0;
padding: 0;
}
header {
background-color: #333;
color: white;
padding: 1em;
text-align: center;
}
main {
max-width: 800px;
margin: 2em auto;
background-color: white;
padding: 2em;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
text-align: center;
}
#upload-form {
text-align: center;
margin-bottom: 2em;
}
#images-container {
display: flex;
justify-content: space-between;
align-items: center;
}
#original-img,
#result-img {
max-width: 45%;
height: auto;
border: 1px solid #ccc;
}
</style>
<title>Simple Image Matting</title>
<link rel="stylesheet" href="./static/css/style.css">
</head>
<body>
<header>
<h1>Image Matting</h1>
<h1>Simple Image Matting</h1>
</header>
<a href="https://github.com/ihmily/image-matting"><img style="position: absolute; top: 0; right: 0; border: 0;" decoding="async" width="149" height="149" src="/static/images/forkme_right_gray_6d6d6d.png" class="attachment-full size-full" alt="Fork me on GitHub" loading="lazy" data-recalc-dims="1"></a>
<main>
@ -71,10 +30,15 @@
<div id="images-container">
<img id="original-img" alt=" " src="data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7">
<img id="mask-img" alt=" " src="data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7">
<img id="result-img" alt=" " src="data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7">
</div>
</main>
<footer>
<p>&copy; 2024 Hmily. All rights reserved.</p>
</footer>
<script>
document.getElementById("image-upload").addEventListener("change", function() {
const inputElement = this;
@ -84,6 +48,7 @@
const originalImgElement = document.getElementById("original-img");
originalImgElement.src = URL.createObjectURL(file);
document.getElementById("result-img").src = "data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7";
document.getElementById("mask-img").src = "data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7";
}
});
@ -106,9 +71,15 @@
});
if (response.ok) {
const resultBlob = await response.blob();
const responseData = await response.json();
const resultImgElement = document.getElementById("result-img");
resultImgElement.src = URL.createObjectURL(resultBlob);
resultImgElement.src = responseData.result_image_url;
const maskImgElement = document.getElementById("mask-img");
maskImgElement.src = responseData.mask_image_url;
console.log(`Matting successful!\nOriginal Image Size: ${responseData.original_image_size.width} x ${responseData.original_image_size.height}\nGeneration Time: ${responseData.generation_time}`);
} else {
alert("Matting failed. Please try again.");
}

@ -0,0 +1,53 @@
body {
font-family: 'Arial', sans-serif;
background-color: #f4f4f4;
margin: 0;
padding: 0;
}
header {
background-color: #333;
color: white;
padding: 1em;
text-align: center;
}
main {
max-width: 1200px;
margin: 2em auto;
background-color: white;
padding: 2em;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
text-align: center;
}
#upload-form {
text-align: center;
margin-bottom: 2em;
}
#images-container {
display: flex;
justify-content: space-between;
align-items: center;
margin-top: 2em; /* Add spacing */
}
#original-img,
#mask-img,
#result-img {
max-width: 30%;
height: auto;
border: 1px solid #ccc;
margin: 0.5em; /* Add spacing */
}
footer {
background-color: #333;
color: white;
text-align: center;
padding: 1em;
position: fixed;
bottom: 0;
width: 100%;
}
Loading…
Cancel
Save