You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

349 lines
14 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import logging
import re
import threading
import sys
import torch
import ollama
from flask import Flask, request, render_template, jsonify, send_from_directory
import os
import json
from gevent.pywsgi import WSGIServer, WSGIHandler, LoggingLogAdapter
from logging.handlers import RotatingFileHandler
import warnings
from qdrant_client import QdrantClient
warnings.filterwarnings('ignore')
import stslib
from stslib import cfg, tool
from stslib.cfg import ROOT_DIR
from faster_whisper import WhisperModel
class CustomRequestHandler(WSGIHandler):
    """Request handler passed to the WSGI server with request logging disabled."""

    def log_request(self):
        """Do nothing.

        Presumably overrides gevent's per-request access-log hook so that
        requests are not written to the log -- confirm against gevent.
        """
        return None
# Qdrant knowledge-base lookup
def query(text):
    """Search the Qdrant collection for entries related to ``text``.

    Steps:
      1. Embed ``text`` with the Ollama embedding model.
      2. Run a vector search against ``collection_name`` (up to 60 hits).
      3. Keep hits whose score is above 0.785, decode the JSON stored in each
         hit's ``payload["text"]`` and merge the decoded objects into one.

    A hit is merged only when it contributes at least one key that has not
    been seen yet; when it is merged, all of its keys are written, so a later
    hit may overwrite the value of a key set by an earlier one.

    Args:
        text: The sentence to look up.

    Returns:
        dict: ``{"answer": value}`` where ``value`` is the merged object
        serialised as a JSON string, or a fixed fallback message when no hit
        passed the score threshold.
    """
    print("----- query start -----")
    print(text)
    client = QdrantClient("localhost", port=6333)
    sentence_embeddings = ollama.embeddings(model="smartcreation/dmeta-embedding-zh:f16", prompt=text)
    # NOTE(review): the previous comment said "top six" while the code asks
    # for 60 hits -- confirm which one is intended.
    search_result = client.search(
        collection_name=collection_name,
        query_vector=sentence_embeddings["embedding"],
        limit=60,
        search_params={"exact": False, "hnsw_ef": 128},
    )
    merged = {}
    for hit in search_result:
        if hit.score > 0.785:
            # The payload's "text" field holds a JSON-encoded object.
            entry = json.loads(hit.payload["text"])
            # Skip hits that add nothing new; otherwise take every key.
            if any(key not in merged for key in entry):
                merged.update(entry)
    if merged:
        resultString = json.dumps(merged)
    else:
        resultString = "描述信息不充分,请修改明确描述问题。"
    print("---- resultString ----")
    print(resultString)
    return {
        "answer": resultString
    }
# Name of the Qdrant collection that query() searches.
collection_name = "stark_prompt"
# Logging configuration.
# Remove Werkzeug's existing log handlers and only let WARNING and above through.
log = logging.getLogger('werkzeug')
log.handlers[:] = []
log.setLevel(logging.WARNING)
# Flask application; static files and templates are looked up under ROOT_DIR.
app = Flask(__name__, static_folder=os.path.join(ROOT_DIR, 'static'), static_url_path='/static', template_folder=os.path.join(ROOT_DIR, 'templates'))
root_log = logging.getLogger() # the root logger
root_log.handlers = []
root_log.setLevel(logging.WARNING)
# Application logger configuration.
app.logger.setLevel(logging.WARNING) # set the log level to WARNING
# Rotating log file ROOT_DIR/sts.log: 1 MiB per file, 5 backups kept.
file_handler = RotatingFileHandler(os.path.join(ROOT_DIR, 'sts.log'), maxBytes=1024 * 1024, backupCount=5)
# Log record format.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Level and format of the file handler.
file_handler.setLevel(logging.WARNING)
file_handler.setFormatter(formatter)
# Attach the file handler to the application logger.
app.logger.addHandler(file_handler)
@app.route('/static/<path:filename>')
def static_files(filename):
    """Serve ``filename`` from the application's static directory.

    Args:
        filename: Path of the requested file relative to the static directory.

    Returns:
        The response produced by ``send_from_directory``.
    """
    # 'STATIC_FOLDER' is not assigned to app.config anywhere in this file, so
    # indexing it directly raised KeyError. Honour the key if something else
    # sets it, otherwise use the folder the Flask app was created with.
    static_dir = app.config.get('STATIC_FOLDER', app.static_folder)
    return send_from_directory(static_dir, filename)
@app.route('/')
def index():
    """Render ``index.html`` with the current settings and version info."""
    settings = cfg.parse_ini()
    template_context = {
        "devtype": settings.get('devtype'),
        "lang_code": cfg.lang_code,
        "language": cfg.LANG,
        "version": stslib.version_str,
        # Use forward slashes in the path handed to the template.
        "root_dir": ROOT_DIR.replace('\\', '/'),
    }
    return render_template("index.html", **template_context)
# Upload an audio or video file
@app.route('/upload', methods=['POST'])
def upload():
    """Store the uploaded ``audio`` file in ``cfg.TMP_DIR`` as a WAV file.

    ``.wav`` uploads are saved as they are; the other accepted extensions are
    saved and then converted to WAV through ``tool.runffmpeg``. If a non-empty
    WAV with the same base name already exists it is reused without saving
    the upload again.

    Returns:
        A JSON response: ``code`` 0 with the WAV file name in ``data`` on
        success, ``code`` 1 for a conversion failure or unsupported
        extension, ``code`` 2 when an exception was raised.
    """
    try:
        # Uploaded file from the multipart form.
        audio_file = request.files['audio']
        # The file name comes from the client: keep only its last path
        # component so it cannot point outside cfg.TMP_DIR.
        safe_name = os.path.basename((audio_file.filename or '').replace('\\', '/'))
        noextname, ext = os.path.splitext(safe_name)
        ext = ext.lower()
        wav_file = os.path.join(cfg.TMP_DIR, f'{noextname}.wav')
        # Reuse an already converted file.
        if os.path.exists(wav_file) and os.path.getsize(wav_file) > 0:
            return jsonify({'code': 0, 'msg': cfg.transobj['lang1'], "data": os.path.basename(wav_file)})
        msg = ""
        if ext in ('.mp4', '.mov', '.avi', '.mkv', '.mpeg', '.mp3', '.flac'):
            # Save the original upload, then convert it to WAV.
            video_file = os.path.join(cfg.TMP_DIR, f'{noextname}{ext}')
            audio_file.save(video_file)
            params = [
                "-i",
                video_file,
            ]
            # -vn is added for everything except the audio-only formats.
            if ext not in ('.mp3', '.flac'):
                params.append('-vn')
            params.append(wav_file)
            rs = tool.runffmpeg(params)
            if rs != 'ok':
                return jsonify({"code": 1, "msg": rs})
            msg = "," + cfg.transobj['lang9']
        elif ext == '.wav':
            audio_file.save(wav_file)
        else:
            return jsonify({"code": 1, "msg": f"{cfg.transobj['lang3']} {ext}"})
        # Success response.
        return jsonify({'code': 0, 'msg': cfg.transobj['lang1'] + msg, "data": os.path.basename(wav_file)})
    except Exception as e:
        app.logger.error(f'[upload]error: {e}')
        return jsonify({'code': 2, 'msg': cfg.transobj['lang2']})
def shibie(*, wav_name=None, model=None, language=None, data_type=None, wav_file=None, key=None):
    """Transcribe ``wav_file`` and publish progress and result under ``key``.

    Progress is written to ``cfg.progressbar[key]`` and the finished
    transcript (or the error text) to ``cfg.progressresult[key]``. The result
    is stored *before* progress is set to 1, and in-loop progress is capped
    below 1, so a reader that sees progress >= 1 always finds the result.

    Args:
        wav_name: Unused here; accepted because callers pass it.
        model: Model name handed to ``WhisperModel``.
        language: Language code handed to ``transcribe``; the Chinese initial
            prompt from the settings is used only when this is ``'zh'``.
        data_type: ``'json'`` for a list of dicts, ``'text'`` for plain lines,
            anything else for SRT-style numbered entries.
        wav_file: Path of the audio file to transcribe.
        key: Key under which progress and result are stored.

    Returns:
        None. Exceptions are caught; their text is stored as the result.
    """
    try:
        sets = cfg.parse_ini()
        print(f'{sets.get("initial_prompt_zh")}')
        modelobj = WhisperModel(model, device=sets.get('devtype'), compute_type=sets.get('cuda_com_type'), download_root=cfg.ROOT_DIR + "/models", local_files_only=True)
        cfg.progressbar[key] = 0
        segments, info = modelobj.transcribe(
            wav_file,
            beam_size=sets.get('beam_size'),
            best_of=sets.get('best_of'),
            temperature=0 if sets.get('temperature') == 0 else [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
            condition_on_previous_text=sets.get('condition_on_previous_text'),
            vad_filter=sets.get('vad'),
            vad_parameters=dict(min_silence_duration_ms=300),
            language=language,
            initial_prompt=None if language != 'zh' else sets.get('initial_prompt_zh'),
        )
        total_duration = round(info.duration, 2)  # Same precision as the Whisper timestamps.
        # Matches text made only of punctuation, whitespace and digits.
        noise_only = re.compile(r'^[,。、?‘’“”;:({}【】):;"\'\s \d`!@#$%^&*()_+=.,?/\\-]*$')
        raw_subtitles = []
        for segment in segments:
            if total_duration > 0:
                # Never report completion from inside the loop: 1 is reserved
                # for "result is available".
                cfg.progressbar[key] = min(round(segment.end / total_duration, 2), 0.99)
            start = int(segment.start * 1000)
            end = int(segment.end * 1000)
            startTime = tool.ms_to_time_string(ms=start)
            endTime = tool.ms_to_time_string(ms=end)
            text = segment.text.strip().replace('&#39;', "'")
            text = re.sub(r'&#\d+;', '', text)
            # Skip segments without meaningful characters.
            if not text or noise_only.match(text) or len(text) <= 1:
                continue
            if data_type == 'json':
                # Subtitle entry in the source language.
                raw_subtitles.append(
                    {"line": len(raw_subtitles) + 1, "start_time": startTime, "end_time": endTime, "text": text})
            elif data_type == 'text':
                raw_subtitles.append(text)
            else:
                raw_subtitles.append(f'{len(raw_subtitles) + 1}\n{startTime} --> {endTime}\n{text}\n')
        if data_type != 'json':
            raw_subtitles = "\n".join(raw_subtitles)
        # Store the result first, then flag completion.
        cfg.progressresult[key] = raw_subtitles
        cfg.progressbar[key] = 1
    except Exception as e:
        # NOTE(review): progress is not set to 1 on failure, so a client that
        # only polls the progress value never learns about the error --
        # confirm how the front end should be told before changing this.
        cfg.progressresult[key] = str(e)
        print(str(e))
# Start transcribing a previously uploaded WAV file in a background thread.
# Form parameters:
#   wav_name:  name of the WAV file inside cfg.TMP_DIR
#   model:     model name
#   language:  language code
#   data_type: output format, forwarded to shibie()
@app.route('/process', methods=['GET', 'POST'])
def process():
    """Validate the request and launch ``shibie`` in a new thread.

    Progress and result are reset under the key
    ``f'{wav_name}{model}{language}{data_type}'`` before the thread starts.

    Returns:
        A JSON response: ``code`` 1 when the WAV file or the model directory
        is missing, otherwise ``code`` 0 with ``msg`` ``"ing"``.
    """
    # A missing field must not crash on .strip().
    wav_name = (request.form.get("wav_name") or "").strip()
    model = request.form.get("model")
    # Language code.
    language = request.form.get("language")
    # Output format.
    data_type = request.form.get("data_type")
    # wav_name comes from the client: use only its last path component so the
    # lookup stays inside cfg.TMP_DIR.
    wav_file = os.path.join(cfg.TMP_DIR, os.path.basename(wav_name.replace('\\', '/')))
    # isfile (not exists): an empty name would otherwise match the directory.
    if not os.path.isfile(wav_file):
        # NOTE(review): every other message in this file is read from
        # cfg.transobj; confirm that cfg.langlist really has a 'lang5' entry.
        return jsonify({"code": 1, "msg": f"{wav_file} {cfg.langlist['lang5']}"})
    if not os.path.exists(os.path.join(cfg.MODEL_DIR, f'models--Systran--faster-whisper-{model}/snapshots/')):
        return jsonify({"code": 1, "msg": f"{model} {cfg.transobj['lang4']}"})
    key = f'{wav_name}{model}{language}{data_type}'
    # Reset the stored result.
    cfg.progressresult[key] = None
    # Reset the progress to 0.
    cfg.progressbar[key] = 0
    # Run the actual work in a new thread.
    threading.Thread(target=shibie, kwargs={"wav_name": wav_name, "model": model, "language": language, "data_type": data_type, "wav_file": wav_file, "key": key}).start()
    return jsonify({"code": 0, "msg": "ing"})
# Report the progress of a task and, once finished, its result.
@app.route('/progressbar', methods=['GET', 'POST'])
def progressbar():
    """Return the stored progress for the task identified by the form fields.

    The task key is ``f'{wav_name}{model}{language}{data_type}'``.

    Returns:
        A JSON response: ``code`` 1 when no progress is stored for the key;
        otherwise ``code`` 0 with the progress in ``data``, plus ``result``
        once the progress is at least 1.
    """
    # A missing field must not crash on .strip().
    wav_name = (request.form.get("wav_name") or "").strip()
    model = request.form.get("model")
    # Language code.
    language = request.form.get("language")
    # Output format.
    data_type = request.form.get("data_type")
    key = f'{wav_name}{model}{language}{data_type}'
    try:
        progress = cfg.progressbar[key]
    except KeyError:
        # Unknown task: answer with an error payload instead of an HTTP 500.
        return jsonify({"code": 1, "msg": "task not found"})
    if progress >= 1:
        return jsonify({"code": 0, "data": progress, "msg": "ok", "result": cfg.progressresult[key]})
    return jsonify({"code": 0, "data": progress, "msg": "ok"})
@app.route('/api', methods=['GET', 'POST'])
def api():
    """Transcribe the uploaded ``file`` and look up its first sentence.

    The upload is stored in ``cfg.TMP_DIR`` (converted to WAV through
    ``tool.runffmpeg`` unless it already is one) and transcribed. The first
    segment containing meaningful text is passed to ``query`` and the lookup
    result is returned.

    ``response_format`` is still read from the form for compatibility, but it
    does not influence the response: only the lookup result was ever
    returned, and formats other than ``json`` used to fail with a TypeError
    when the first subtitle was read back from the joined string.

    Returns:
        A JSON response: ``code`` 0 with the ``query`` result in ``data``;
        ``code`` 1 for a missing model, failed conversion, unsupported
        extension or an audio without recognisable speech; ``code`` 2 with the
        exception text when an exception was raised.
    """
    try:
        # Uploaded file and form fields.
        audio_file = request.files['file']
        model = request.form.get("model")
        language = request.form.get("language")
        response_format = request.form.get("response_format")
        print(f'{model=},{language=},{response_format=}')
        if not os.path.exists(os.path.join(cfg.MODEL_DIR, f'models--Systran--faster-whisper-{model}/snapshots/')):
            return jsonify({"code": 1, "msg": f"{model} {cfg.transobj['lang4']}"})
        # The file name comes from the client: keep only its last path
        # component so it cannot point outside cfg.TMP_DIR.
        safe_name = os.path.basename((audio_file.filename or '').replace('\\', '/'))
        noextname, ext = os.path.splitext(safe_name)
        ext = ext.lower()
        wav_file = os.path.join(cfg.TMP_DIR, f'{noextname}.wav')
        needs_conversion = ext in ('.mp4', '.mov', '.avi', '.mkv', '.mpeg', '.mp3', '.flac')
        # Reject unsupported uploads before touching any existing file.
        if not needs_conversion and ext != '.wav':
            return jsonify({"code": 1, "msg": f"{cfg.transobj['lang3']} {ext}"})
        # This endpoint never reuses an earlier WAV with the same name.
        if os.path.exists(wav_file):
            os.remove(wav_file)
        print(f'{wav_file=}')
        if needs_conversion:
            # Save the original upload, then convert it to WAV.
            video_file = os.path.join(cfg.TMP_DIR, f'{noextname}{ext}')
            audio_file.save(video_file)
            params = [
                "-i",
                video_file,
            ]
            # -vn is added for everything except the audio-only formats.
            if ext not in ('.mp3', '.flac'):
                params.append('-vn')
            params.append(wav_file)
            rs = tool.runffmpeg(params)
            if rs != 'ok':
                return jsonify({"code": 1, "msg": rs})
        else:
            audio_file.save(wav_file)
        print(f'{ext=}')
        sets = cfg.parse_ini()
        whisper_model = WhisperModel(model, device=sets.get('devtype'), compute_type=sets.get('cuda_com_type'), download_root=cfg.ROOT_DIR + "/models", local_files_only=True)
        segments, _ = whisper_model.transcribe(
            wav_file,
            beam_size=sets.get('beam_size'),
            best_of=sets.get('best_of'),
            temperature=0 if sets.get('temperature') == 0 else [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
            condition_on_previous_text=sets.get('condition_on_previous_text'),
            vad_filter=sets.get('vad'),
            vad_parameters=dict(min_silence_duration_ms=300, max_speech_duration_s=10.5),
            language=language,
            initial_prompt=None if language != 'zh' else sets.get('initial_prompt_zh'),
        )
        # Matches text made only of punctuation, whitespace and digits.
        noise_only = re.compile(r'^[,。、?‘’“”;:({}【】):;"\'\s \d`!@#$%^&*()_+=.,?/\\-]*$')
        search_text = None
        for segment in segments:
            text = segment.text.strip().replace('&#39;', "'")
            text = re.sub(r'&#\d+;', '', text)
            # Skip segments without meaningful characters.
            if not text or noise_only.match(text) or len(text) <= 1:
                continue
            # Only the first meaningful sentence is looked up, so there is no
            # need to consume the remaining segments.
            search_text = text
            break
        if search_text is None:
            return jsonify({"code": 1, "msg": "no speech recognized"})
        search = query(search_text)
        return jsonify({"code": 0, "msg": 'ok', "data": search})
    except Exception as e:
        print(e)
        app.logger.error(f'[api]error: {e}')
        return jsonify({'code': 2, 'msg': str(e)})
@app.route('/checkupdate', methods=['GET', 'POST'])
def checkupdate():
    """Return the update notice currently held in ``cfg.updatetips``."""
    payload = {'code': 0, "msg": cfg.updatetips}
    return jsonify(payload)
if __name__ == '__main__':
    # Script entry point: start the update check, then serve the Flask app
    # with gevent's WSGIServer until it stops or fails.
    http_server = None
    try:
        threading.Thread(target=tool.checkupdate).start()
        if cfg.devtype == 'cpu':
            print('\n如果设备使用英伟达显卡并且CUDA环境已正确安装可修改set.ini中\ndevtype=cpu 为 devtype=cuda, 然后重新启动以加快识别速度\n')
        # Split on the last colon only, so the port is always the final part.
        host, port = cfg.web_address.rsplit(':', 1)
        http_server = WSGIServer((host, int(port)), app, handler_class=CustomRequestHandler)
        threading.Thread(target=tool.openweb, args=(cfg.web_address,)).start()
        http_server.serve_forever()
    except Exception as e:
        print("error:" + str(e))
        app.logger.error(f"[app]start error:{str(e)}")
    finally:
        # Single cleanup point: runs after normal exit, errors and interrupts.
        if http_server:
            http_server.stop()