You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
120 lines
3.8 KiB
Python
120 lines
3.8 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
import sys

import cv2
import numpy as np
import requests
from fastapi import FastAPI, File, UploadFile, Form, Response
from fastapi import Request
from fastapi.responses import JSONResponse
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates
|
|
|
|
app = FastAPI()

# Selectable matting models: public name -> modelscope model id ('path') and
# the modelscope task the pipeline is created for ('task').
model_paths = {
    "universal": {'path': 'damo/cv_unet_universal-matting', 'task': Tasks.universal_matting},
    "people": {'path': 'damo/cv_unet_image-matting', 'task': Tasks.portrait_matting},
}

# The first registered entry (dicts preserve insertion order) is the default.
default_model, default_model_info = next(iter(model_paths.items()))

# The default pipeline is constructed eagerly, so importing this module blocks
# until it is ready; other models are loaded on demand by ModelLoader.
loaded_models = {
    default_model: pipeline(default_model_info['task'], model=default_model_info['path']),
}
|
|
|
|
|
|
class ModelLoader:
    """Cache of modelscope pipelines keyed by the names in ``model_paths``.

    The default pipeline (built at import time) is seeded into the cache;
    every other pipeline is built the first time it is requested and reused
    on later requests.
    """

    def __init__(self):
        # Reuse the default pipeline that was already built at import time
        # instead of constructing it a second time.
        preloaded = loaded_models[default_model]
        self.loaded_models = {default_model: preloaded}

    def load_model(self, model_name):
        """Return the pipeline registered as *model_name*, building it if needed.

        Raises ``KeyError`` when *model_name* is neither cached nor a key of
        ``model_paths``; the request handlers validate the name before calling.
        """
        try:
            return self.loaded_models[model_name]
        except KeyError:
            pass

        spec = model_paths[model_name]
        built = pipeline(spec['task'], model=spec['path'])
        self.loaded_models[model_name] = built
        return built
|
|
|
|
|
|
# Module-level loader shared by the request handlers; it starts out holding
# only the default pipeline and caches further pipelines as they are loaded.
model_loader = ModelLoader()
|
|
|
|
|
|
@app.post("/switch_model/{new_model}")
async def switch_model(new_model: str):
    """Make sure the pipeline for *new_model* is loaded and cached.

    This does not change the model used by ``/matting`` or ``/matting/url``;
    those endpoints pick their model per request.

    Returns ``{"content": ...}`` with status 200 on success, or status 400
    when *new_model* is not a key of ``model_paths``.
    """
    if new_model not in model_paths:
        # FastAPI does not interpret a ``(body, status)`` tuple the way Flask
        # does, so the status code has to be set on an explicit response.
        return JSONResponse(status_code=400, content={"content": "Invalid model selection"})

    # Go through the shared loader so that an already-built pipeline is reused
    # and pipelines cached by earlier matting requests are kept. The
    # module-level registry is updated as well, so it stays in sync.
    loaded_models[new_model] = model_loader.load_model(new_model)

    return {"content": f"Switched to model: {new_model}"}
|
|
|
|
|
|
# Seconds to wait when connecting to, or reading from, a client-supplied image URL.
_URL_FETCH_TIMEOUT = 10


def _bad_request(message):
    """Return a 400 response whose JSON body is ``{"content": message}``."""
    return JSONResponse(status_code=400, content={"content": message})


def _decode_image(data):
    """Decode encoded image bytes into an OpenCV colour image.

    Returns ``None`` when *data* is empty or cannot be decoded.
    ``cv2.imdecode`` reports a decode failure by returning ``None``.
    """
    if not data:
        # Report an empty payload directly rather than relying on how
        # cv2.imdecode treats an empty buffer.
        return None
    return cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)


def _matting_response(img, model):
    """Run the *model* pipeline on *img* and return the output image as PNG.

    *model* must already be validated against ``model_paths``.
    """
    selected_model = model_loader.load_model(model)
    result = selected_model(img)
    output_img = result[OutputKeys.OUTPUT_IMG]

    # imencode returns (success_flag, buffer); don't ship a buffer from a failed encode.
    ok, encoded = cv2.imencode('.png', output_img)
    if not ok:
        return JSONResponse(status_code=500, content={"content": "Failed to encode result image"})
    return Response(content=encoded.tobytes(), media_type='image/png')


@app.post("/matting")
async def matting(image: UploadFile = File(...), model: str = Form(default=default_model, alias="model")):
    """Run matting on an uploaded image.

    Form fields:
        image: the encoded image file to process.
        model: a key of ``model_paths`` (defaults to ``default_model``).

    Returns the result as ``image/png``. An unknown model or an undecodable
    image produces a 400 JSON response ``{"content": ...}``.
    """
    if model not in model_paths:
        return _bad_request("Invalid model selection")

    img = _decode_image(await image.read())
    if img is None:
        return _bad_request("Could not decode image")

    return _matting_response(img, model)


@app.post("/matting/url")
async def matting_url(request: Request, model: str = Form(default=default_model, alias="model")):
    """Fetch an image from a URL and run matting on it.

    Expects a JSON object body with a required ``"image_url"`` key and an
    optional ``"model"`` key (a key of ``model_paths``). The ``model`` form
    parameter is not filled from a JSON body, so it only supplies the default.

    Returns the result as ``image/png``. Malformed input, an unknown model,
    a failed fetch, or an undecodable image produces a 400 JSON response
    ``{"content": ...}``.
    """
    try:
        json_data = await request.json()
    except ValueError as e:  # includes json.JSONDecodeError / UnicodeDecodeError
        return _bad_request(f"Error parsing JSON data: {str(e)}")
    if not isinstance(json_data, dict):
        return _bad_request("Error parsing JSON data: expected a JSON object")

    image_url = json_data.get("image_url")
    if not image_url:
        return _bad_request("Image URL is required")

    model = json_data.get("model") or model
    # Check the type first: an unhashable JSON value (e.g. a list) would make
    # the dict membership test raise TypeError.
    if not isinstance(model, str) or model not in model_paths:
        return _bad_request("Invalid model selection")

    # NOTE(review): fetching an arbitrary client-supplied URL from the server
    # is a server-side request forgery vector; restrict schemes/hosts if this
    # service is exposed to untrusted clients. This call is also blocking
    # inside an async handler, so the timeout bounds how long it can stall
    # the event loop.
    try:
        response = requests.get(image_url, timeout=_URL_FETCH_TIMEOUT)
    except requests.RequestException:
        return _bad_request("Failed to fetch image from URL")
    if response.status_code != 200:
        return _bad_request("Failed to fetch image from URL")

    img = _decode_image(response.content)
    if img is None:
        return _bad_request("Could not decode image")

    return _matting_response(img, model)
|
|
|
|
|
|
# Jinja2 templates are looked up in the "web" directory. This is a relative
# path, so it is resolved against the process's current working directory.
templates = Jinja2Templates(directory="web")
|
|
# Serve the files in web/static (relative to the working directory) under the /static URL prefix.
app.mount("/static", StaticFiles(directory="web/static"), name="static")
|
|
|
|
|
|
@app.get("/")
async def read_index(request: Request):
    """Render the ``index.html`` template for the root URL.

    The template receives the request object, the name of the default model,
    and the list of all selectable model names.
    """
    context = {
        "request": request,
        "default_model": default_model,
        "available_models": list(model_paths),
    }
    return templates.TemplateResponse("index.html", context)
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported here so that importing this module (e.g. under another ASGI
    # server) does not require uvicorn.
    import uvicorn

    # Listen on all interfaces by default, but only on loopback on Windows.
    default_bind_host = "0.0.0.0" if sys.platform != "win32" else "127.0.0.1"
    uvicorn.run(app, host=default_bind_host, port=8000)
|