diff --git a/README.md b/README.md new file mode 100644 index 0000000..c22bc08 --- /dev/null +++ b/README.md @@ -0,0 +1,35 @@ +## Imgae matting + +Here are a few effects: + +![image-20240104235235787](C:\Users\96153\AppData\Roaming\Typora\typora-user-images\image-20240104235235787.png) + +![image-20240104235305391](C:\Users\96153\AppData\Roaming\Typora\typora-user-images\image-20240104235305391.png) + +  + +## How to Run + +Firstly, you need to download the project code and install the required dependencies + +``` +git clone https://github.com/ihmily/image-matting.git +cd image-matting +pip install -r requirements.txt +``` + +Next, you can use the following command to run the web interface + +``` +python app.py +``` + +Finally, you can visit http://127.0.0.1:8000 + +  + +## References + +[https://modelscope.cn/models/damo/cv_unet_universal-matting/summary](https://modelscope.cn/models/damo/cv_unet_universal-matting/summary) + +[https://modelscope.cn/models/damo/cv_unet_image-matting/summary](https://modelscope.cn/models/damo/cv_unet_image-matting/summary) \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000..7b0f1e1 --- /dev/null +++ b/app.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- + +import sys +from fastapi import FastAPI, File, UploadFile, Form, Response +from fastapi import Request +import requests +import cv2 +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.outputs import OutputKeys +import numpy as np +from starlette.staticfiles import StaticFiles +from starlette.templating import Jinja2Templates + +app = FastAPI() + +model_paths = { + "universal": {'path': 'damo/cv_unet_universal-matting', 'task': Tasks.universal_matting}, + "people": {'path': 'damo/cv_unet_image-matting', 'task': Tasks.portrait_matting}, +} + +default_model = list(model_paths.keys())[0] +default_model_info = model_paths[default_model] +loaded_models = {default_model: pipeline(default_model_info['task'], model=default_model_info['path'])} + + +class ModelLoader: + def __init__(self): + self.loaded_models = {default_model: loaded_models[default_model]} + + def load_model(self, model_name): + if model_name not in self.loaded_models: + model_info = model_paths[model_name] + model_path = model_info['path'] + task_group = model_info['task'] + + self.loaded_models[model_name] = pipeline(task_group, model=model_path) + return self.loaded_models[model_name] + + +model_loader = ModelLoader() + + +@app.post("/switch_model/{new_model}") +async def switch_model(new_model: str): + if new_model not in model_paths: + return {"content": "Invalid model selection"}, 400 + model_info = model_paths[new_model] + + loaded_models[new_model] = pipeline(model_info['task'], model=model_info['path']) + model_loader.loaded_models = loaded_models + return {"content": f"Switched to model: {new_model}"}, 200 + + +@app.post("/matting") +async def matting(image: UploadFile = File(...), model: str = Form(default=default_model, alias="model")): + image_bytes = await image.read() + img = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR) + + if model not in model_paths: + return {"content": "Invalid model selection"}, 400 + + selected_model = model_loader.load_model(model) + + result = selected_model(img) + + output_img = result[OutputKeys.OUTPUT_IMG] + + output_bytes = cv2.imencode('.png', output_img)[1].tobytes() + + return Response(content=output_bytes, media_type='image/png') + + +@app.post("/matting/url") +async def matting_url(request: Request, model: str = Form(default=default_model, alias="model")): + try: + json_data = await request.json() + image_url = json_data.get("image_url") + except Exception as e: + return {"content": f"Error parsing JSON data: {str(e)}"}, 400 + + if not image_url: + return {"content": "Image URL is required"}, 400 + + response = requests.get(image_url) + if response.status_code != 200: + return {"content": "Failed to fetch image from URL"}, 400 + + img_array = np.frombuffer(response.content, dtype=np.uint8) + img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) + + if model not in model_paths: + return {"content": "Invalid model selection"}, 400 + + selected_model = model_loader.load_model(model) + + result = selected_model(img) + + output_img = result[OutputKeys.OUTPUT_IMG] + + output_bytes = cv2.imencode('.png', output_img)[1].tobytes() + + return Response(content=output_bytes, media_type='image/png') + + +templates = Jinja2Templates(directory="web") +app.mount("/static", StaticFiles(directory="web/static"), name="static") + + +@app.get("/") +async def read_index(request: Request): + return templates.TemplateResponse("index.html", {"request": request, "default_model": default_model, + "available_models": list(model_paths.keys())}) + + +if __name__ == "__main__": + import uvicorn + defult_bind_host = "0.0.0.0" if sys.platform != "win32" else "127.0.0.1" + uvicorn.run(app, host=defult_bind_host, port=8000) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c6421a5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +fastapi==0.108.0 +opencv-python +requests +modelscope>=1.10.0 +torch==2.1.2 +transformers==4.36.2 +sentencepiece==0.1.* +tensorflow +python-multipart +uvicorn \ No newline at end of file diff --git a/web/index.html b/web/index.html new file mode 100644 index 0000000..37f9b8a --- /dev/null +++ b/web/index.html @@ -0,0 +1,145 @@ + + + + + + + Image Matting + + + + +
+

Image Matting

+
+ Fork me on GitHub +
+
+ + + + + + + +
+ +
+  +  +
+
+ + + + + + diff --git a/web/static/images/forkme_right_gray_6d6d6d.png b/web/static/images/forkme_right_gray_6d6d6d.png new file mode 100644 index 0000000..bbccd4b Binary files /dev/null and b/web/static/images/forkme_right_gray_6d6d6d.png differ