|
|
@ -1,7 +1,10 @@
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
|
|
|
|
|
|
import sys
|
|
|
|
import sys
|
|
|
|
from fastapi import FastAPI, File, UploadFile, Form, Response
|
|
|
|
import os
|
|
|
|
|
|
|
|
import uuid
|
|
|
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
|
|
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
|
|
|
|
from fastapi import Request
|
|
|
|
from fastapi import Request
|
|
|
|
import requests
|
|
|
|
import requests
|
|
|
|
import cv2
|
|
|
|
import cv2
|
|
|
@ -23,6 +26,12 @@ default_model = list(model_paths.keys())[0]
|
|
|
|
default_model_info = model_paths[default_model]
|
|
|
|
default_model_info = model_paths[default_model]
|
|
|
|
loaded_models = {default_model: pipeline(default_model_info['task'], model=default_model_info['path'])}
|
|
|
|
loaded_models = {default_model: pipeline(default_model_info['task'], model=default_model_info['path'])}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
UPLOAD_FOLDER = "./upload"
|
|
|
|
|
|
|
|
OUTPUT_FOLDER = "./output"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
|
|
|
|
|
|
|
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelLoader:
|
|
|
|
class ModelLoader:
|
|
|
|
def __init__(self):
|
|
|
|
def __init__(self):
|
|
|
@ -62,13 +71,26 @@ async def matting(image: UploadFile = File(...), model: str = Form(default=defau
|
|
|
|
|
|
|
|
|
|
|
|
selected_model = model_loader.load_model(model)
|
|
|
|
selected_model = model_loader.load_model(model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
filename = uuid.uuid4()
|
|
|
|
|
|
|
|
original_image_filename = f"original_{filename}.png"
|
|
|
|
|
|
|
|
image_filename = f"image_{filename}.png"
|
|
|
|
|
|
|
|
mask_filename = f"mask_{filename}.png"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cv2.imwrite(os.path.join(UPLOAD_FOLDER, original_image_filename), img)
|
|
|
|
|
|
|
|
|
|
|
|
result = selected_model(img)
|
|
|
|
result = selected_model(img)
|
|
|
|
|
|
|
|
|
|
|
|
output_img = result[OutputKeys.OUTPUT_IMG]
|
|
|
|
cv2.imwrite(os.path.join(OUTPUT_FOLDER, image_filename), result[OutputKeys.OUTPUT_IMG])
|
|
|
|
|
|
|
|
cv2.imwrite(os.path.join(OUTPUT_FOLDER, mask_filename), result['output_img'][:, :, 3])
|
|
|
|
|
|
|
|
|
|
|
|
output_bytes = cv2.imencode('.png', output_img)[1].tobytes()
|
|
|
|
response_data = {
|
|
|
|
|
|
|
|
"result_image_url": f"/output/{image_filename}",
|
|
|
|
|
|
|
|
"mask_image_url": f"/output/{mask_filename}",
|
|
|
|
|
|
|
|
"original_image_size": {"width": img.shape[1], "height": img.shape[0]},
|
|
|
|
|
|
|
|
"generation_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return Response(content=output_bytes, media_type='image/png')
|
|
|
|
return response_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/matting/url")
|
|
|
|
@app.post("/matting/url")
|
|
|
@ -77,34 +99,50 @@ async def matting_url(request: Request, model: str = Form(default=default_model,
|
|
|
|
json_data = await request.json()
|
|
|
|
json_data = await request.json()
|
|
|
|
image_url = json_data.get("image_url")
|
|
|
|
image_url = json_data.get("image_url")
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
return {"content": f"Error parsing JSON data: {str(e)}"}, 400
|
|
|
|
raise HTTPException(status_code=400, detail=f"Error parsing JSON data: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
if not image_url:
|
|
|
|
if not image_url:
|
|
|
|
return {"content": "Image URL is required"}, 400
|
|
|
|
raise HTTPException(status_code=400, detail="Image URL is required")
|
|
|
|
|
|
|
|
|
|
|
|
response = requests.get(image_url)
|
|
|
|
|
|
|
|
if response.status_code != 200:
|
|
|
|
|
|
|
|
return {"content": "Failed to fetch image from URL"}, 400
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
img_array = np.frombuffer(response.content, dtype=np.uint8)
|
|
|
|
try:
|
|
|
|
img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
|
|
|
|
response = requests.get(image_url)
|
|
|
|
|
|
|
|
response.raise_for_status()
|
|
|
|
|
|
|
|
img_array = np.frombuffer(response.content, dtype=np.uint8)
|
|
|
|
|
|
|
|
img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
|
|
|
|
|
|
|
|
except requests.RequestException as e:
|
|
|
|
|
|
|
|
raise HTTPException(status_code=400, detail=f"Failed to fetch image from URL: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
if model not in model_paths:
|
|
|
|
if model not in model_paths:
|
|
|
|
return {"content": "Invalid model selection"}, 400
|
|
|
|
raise HTTPException(status_code=400, detail="Invalid model selection")
|
|
|
|
|
|
|
|
|
|
|
|
selected_model = model_loader.load_model(model)
|
|
|
|
selected_model = model_loader.load_model(model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
filename = uuid.uuid4()
|
|
|
|
|
|
|
|
original_image_filename = f"original_{filename}.png"
|
|
|
|
|
|
|
|
image_filename = f"image_{filename}.png"
|
|
|
|
|
|
|
|
mask_filename = f"mask_{filename}.png"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cv2.imwrite(os.path.join(UPLOAD_FOLDER, original_image_filename), img)
|
|
|
|
|
|
|
|
|
|
|
|
result = selected_model(img)
|
|
|
|
result = selected_model(img)
|
|
|
|
|
|
|
|
|
|
|
|
output_img = result[OutputKeys.OUTPUT_IMG]
|
|
|
|
cv2.imwrite(os.path.join(OUTPUT_FOLDER, image_filename), result[OutputKeys.OUTPUT_IMG])
|
|
|
|
|
|
|
|
cv2.imwrite(os.path.join(OUTPUT_FOLDER, mask_filename), result['output_img'][:, :, 3])
|
|
|
|
|
|
|
|
|
|
|
|
output_bytes = cv2.imencode('.png', output_img)[1].tobytes()
|
|
|
|
response_data = {
|
|
|
|
|
|
|
|
"result_image_url": f"/output/{image_filename}",
|
|
|
|
|
|
|
|
"mask_image_url": f"/output/{mask_filename}",
|
|
|
|
|
|
|
|
"original_image_size": {"width": img.shape[1], "height": img.shape[0]},
|
|
|
|
|
|
|
|
"generation_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return Response(content=output_bytes, media_type='image/png')
|
|
|
|
return response_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
templates = Jinja2Templates(directory="web")
|
|
|
|
templates = Jinja2Templates(directory="web")
|
|
|
|
app.mount("/static", StaticFiles(directory="web/static"), name="static")
|
|
|
|
app.mount("/static", StaticFiles(directory="./web/static"), name="static")
|
|
|
|
|
|
|
|
app.mount("/output", StaticFiles(directory="./output"), name="output")
|
|
|
|
|
|
|
|
app.mount("/upload", StaticFiles(directory="./upload"), name="upload")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
|
|
|
|
@app.get("/")
|
|
|
@ -115,5 +153,6 @@ async def read_index(request: Request):
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
if __name__ == "__main__":
|
|
|
|
import uvicorn
|
|
|
|
import uvicorn
|
|
|
|
|
|
|
|
|
|
|
|
defult_bind_host = "0.0.0.0" if sys.platform != "win32" else "127.0.0.1"
|
|
|
|
defult_bind_host = "0.0.0.0" if sys.platform != "win32" else "127.0.0.1"
|
|
|
|
uvicorn.run(app, host=defult_bind_host, port=8000)
|
|
|
|
uvicorn.run(app, host=defult_bind_host, port=8000)
|
|
|
|