import logging
import re
import threading
import sys
import torch
import ollama
from flask import Flask, request, render_template, jsonify, send_from_directory
import os
import json
from gevent.pywsgi import WSGIServer, WSGIHandler, LoggingLogAdapter
from logging.handlers import RotatingFileHandler
import warnings
from qdrant_client import QdrantClient

warnings.filterwarnings('ignore')

import stslib
from stslib import cfg, tool
from stslib.cfg import ROOT_DIR
from faster_whisper import WhisperModel


class CustomRequestHandler(WSGIHandler):
    def log_request(self):
        # Suppress gevent's per-request access logging.
        pass


# Qdrant vector-database lookup
def query(text):
    """
    Flow:
    first convert the input text into an embedding vector (via the ollama embeddings API),
    then call Qdrant's search API; each hit carries its vector and payload.
    """
    print("----- query start -----")
    print(text)
    client = QdrantClient("localhost", port=6333)
    sentence_embeddings = ollama.embeddings(model="smartcreation/dmeta-embedding-zh:f16", prompt=text)

    # Retrieve the top matching summaries (see the limit below).
    search_result = client.search(
        collection_name=collection_name,
        query_vector=sentence_embeddings["embedding"],
        limit=60,
        search_params={"exact": False, "hnsw_ef": 128}
    )
    answers = []

    # The prompt length is limited, so only part of each matched summary should be passed on;
    # raise the limits above if more related context is needed.
    resultArray = []
    resultKeyArray = []
    for result in search_result:
        if result.score > 0.785:
            # Convert the JSON string in the payload into a Python object
            jsonObj = json.loads(result.payload["text"])
            for key in jsonObj:
                if key in resultKeyArray:
                    continue
                else:
                    resultKeyArray.append(key)
            resultArray.append({'score': result.score, 'payload': result.payload})

    resultString = "The description is not detailed enough; please rewrite it to state the problem clearly."
    if len(resultArray) >= 1:
        resultString = {}
        for result in resultArray:
            # Convert the JSON string in the payload into a Python object
            jsonObj = json.loads(result["payload"]["text"])
            for key in jsonObj:
                resultString[key] = jsonObj[key]

        # Convert the merged object back into a JSON string
        resultString = json.dumps(resultString)

    print("---- resultString ----")
    print(resultString)

    return {
        "answer": resultString
    }

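# Usage sketch for query() (illustrative only, not executed here): it assumes a Qdrant instance on
# localhost:6333, an ollama server with the embedding model pulled, and that the collection named
# below already holds payloads whose "text" field is a JSON string.
#
#     result = query("How do I reset the device?")
#     print(result["answer"])   # merged JSON string, or the fallback message when nothing matches
#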
collection_name = "stark_prompt"

# Logging configuration
# Disable Werkzeug's default log handlers
log = logging.getLogger('werkzeug')
log.handlers[:] = []
log.setLevel(logging.WARNING)

app = Flask(__name__, static_folder=os.path.join(ROOT_DIR, 'static'), static_url_path='/static', template_folder=os.path.join(ROOT_DIR, 'templates'))

root_log = logging.getLogger()  # root logger used by Flask
root_log.handlers = []
root_log.setLevel(logging.WARNING)

# Configure the application logger
app.logger.setLevel(logging.WARNING)  # set the log level to WARNING
# Create a RotatingFileHandler with the target file path and size limit
file_handler = RotatingFileHandler(os.path.join(ROOT_DIR, 'sts.log'), maxBytes=1024 * 1024, backupCount=5)
# Log message format
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Set the file handler's level and format
file_handler.setLevel(logging.WARNING)
file_handler.setFormatter(formatter)
# Attach the file handler to the application logger
app.logger.addHandler(file_handler)


@app.route('/static/<path:filename>')
def static_files(filename):
    # Serve static files from the app's configured static folder.
    return send_from_directory(app.static_folder, filename)


@app.route('/')
def index():
    sets = cfg.parse_ini()
    return render_template("index.html",
                           devtype=sets.get('devtype'),
                           lang_code=cfg.lang_code,
                           language=cfg.LANG,
                           version=stslib.version_str,
                           root_dir=ROOT_DIR.replace('\\', '/'))


# Upload audio
@app.route('/upload', methods=['POST'])
def upload():
    try:
        # Get the uploaded file
        audio_file = request.files['audio']
        noextname, ext = os.path.splitext(audio_file.filename)
        ext = ext.lower()
        # Target wav path; if the source is a video, the audio track is extracted first
        wav_file = os.path.join(cfg.TMP_DIR, f'{noextname}.wav')
        if os.path.exists(wav_file) and os.path.getsize(wav_file) > 0:
            return jsonify({'code': 0, 'msg': cfg.transobj['lang1'], "data": os.path.basename(wav_file)})
        msg = ""
        if ext in ['.mp4', '.mov', '.avi', '.mkv', '.mpeg', '.mp3', '.flac']:
            video_file = os.path.join(cfg.TMP_DIR, f'{noextname}{ext}')
            audio_file.save(video_file)
            params = [
                "-i",
                video_file,
            ]
            if ext not in ['.mp3', '.flac']:
                params.append('-vn')
            params.append(wav_file)
            rs = tool.runffmpeg(params)
            if rs != 'ok':
                return jsonify({"code": 1, "msg": rs})
            msg = "," + cfg.transobj['lang9']
        elif ext == '.wav':
            audio_file.save(wav_file)
        else:
            return jsonify({"code": 1, "msg": f"{cfg.transobj['lang3']} {ext}"})

        # Return a success response
        return jsonify({'code': 0, 'msg': cfg.transobj['lang1'] + msg, "data": os.path.basename(wav_file)})
    except Exception as e:
        app.logger.error(f'[upload]error: {e}')
        return jsonify({'code': 2, 'msg': cfg.transobj['lang2']})


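# Usage sketch for /upload (illustrative; the host/port comes from cfg.web_address in set.ini,
# and 127.0.0.1:9977 is only an assumed example):
#
#     import requests
#     r = requests.post('http://127.0.0.1:9977/upload', files={'audio': open('demo.mp4', 'rb')})
#     print(r.json())   # {'code': 0, 'msg': ..., 'data': 'demo.wav'} on success
#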
def shibie(*, wav_name=None, model=None, language=None, data_type=None, wav_file=None, key=None):
    try:
        sets = cfg.parse_ini()
        print(f'{sets.get("initial_prompt_zh")}')
        modelobj = WhisperModel(model, device=sets.get('devtype'), compute_type=sets.get('cuda_com_type'), download_root=cfg.ROOT_DIR + "/models", local_files_only=True)
        cfg.progressbar[key] = 0
        segments, info = modelobj.transcribe(wav_file,
                                             beam_size=sets.get('beam_size'),
                                             best_of=sets.get('best_of'),
                                             temperature=0 if sets.get('temperature') == 0 else [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
                                             condition_on_previous_text=sets.get('condition_on_previous_text'),
                                             vad_filter=sets.get('vad'),
                                             vad_parameters=dict(min_silence_duration_ms=300),
                                             language=language,
                                             initial_prompt=None if language != 'zh' else sets.get('initial_prompt_zh'))
        total_duration = round(info.duration, 2)  # Same precision as the Whisper timestamps.

        raw_subtitles = []
        for segment in segments:
            cfg.progressbar[key] = round(segment.end / total_duration, 2)
            start = int(segment.start * 1000)
            end = int(segment.end * 1000)
            startTime = tool.ms_to_time_string(ms=start)
            endTime = tool.ms_to_time_string(ms=end)
            text = segment.text.strip().replace('&#39;', "'")
            text = re.sub(r'&#\d+;', '', text)

            # Skip segments with no meaningful characters
            if not text or re.match(r'^[,。、?‘’“”;:({}【】):;"\'\s \d`!@#$%^&*()_+=.,?/\\-]*$', text) or len(text) <= 1:
                continue
            if data_type == 'json':
                # Subtitle entry in the source language
                raw_subtitles.append(
                    {"line": len(raw_subtitles) + 1, "start_time": startTime, "end_time": endTime, "text": text})
            elif data_type == 'text':
                raw_subtitles.append(text)
            else:
                raw_subtitles.append(f'{len(raw_subtitles) + 1}\n{startTime} --> {endTime}\n{text}\n')
        cfg.progressbar[key] = 1
        if data_type != 'json':
            raw_subtitles = "\n".join(raw_subtitles)
        cfg.progressresult[key] = raw_subtitles
    except Exception as e:
        cfg.progressresult[key] = str(e)
        print(str(e))


# Start recognition of an uploaded wav and return immediately; poll /progressbar for the
# progress and, once finished, the result (a usage sketch follows the route below).
# params
#   wav_name: wav file under the tmp directory
#   model: model name
#   language: source language
#   data_type: response format, one of json / text / srt
@app.route('/process', methods=['GET', 'POST'])
def process():
    # Raw form fields
    wav_name = request.form.get("wav_name").strip()
    model = request.form.get("model")
    # Source language
    language = request.form.get("language")
    # Response format: json, text or srt
    data_type = request.form.get("data_type")
    wav_file = os.path.join(cfg.TMP_DIR, wav_name)
    if not os.path.exists(wav_file):
        return jsonify({"code": 1, "msg": f"{wav_file} {cfg.langlist['lang5']}"})
    if not os.path.exists(os.path.join(cfg.MODEL_DIR, f'models--Systran--faster-whisper-{model}/snapshots/')):
        return jsonify({"code": 1, "msg": f"{model} {cfg.transobj['lang4']}"})
    key = f'{wav_name}{model}{language}{data_type}'
    # Reset the result to None
    cfg.progressresult[key] = None
    # Reset the progress to 0
    cfg.progressbar[key] = 0
    # Run the actual task in a new thread
    threading.Thread(target=shibie, kwargs={"wav_name": wav_name, "model": model, "language": language, "data_type": data_type, "wav_file": wav_file, "key": key}).start()
    return jsonify({"code": 0, "msg": "ing"})

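# Usage sketch for /process (illustrative; the wav must already exist in the tmp dir, e.g. via
# /upload, and host/port 127.0.0.1:9977 is only an assumed example of cfg.web_address):
#
#     import requests
#     form = {'wav_name': 'demo.wav', 'model': 'base', 'language': 'zh', 'data_type': 'json'}
#     r = requests.post('http://127.0.0.1:9977/process', data=form)
#     print(r.json())   # {'code': 0, 'msg': 'ing'} -- recognition now runs in a background thread
#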

# Get the progress and, when finished, the result
@app.route('/progressbar', methods=['GET', 'POST'])
def progressbar():
    # Raw form fields
    wav_name = request.form.get("wav_name").strip()
    model = request.form.get("model")
    # Source language
    language = request.form.get("language")
    # Response format: json, text or srt
    data_type = request.form.get("data_type")
    key = f'{wav_name}{model}{language}{data_type}'
    progressbar = cfg.progressbar[key]
    if progressbar >= 1:
        return jsonify({"code": 0, "data": progressbar, "msg": "ok", "result": cfg.progressresult[key]})
    return jsonify({"code": 0, "data": progressbar, "msg": "ok"})

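# Polling sketch for /progressbar (illustrative; send the same four form fields that were passed
# to /process so the key matches):
#
#     import time, requests
#     form = {'wav_name': 'demo.wav', 'model': 'base', 'language': 'zh', 'data_type': 'json'}
#     while True:
#         r = requests.post('http://127.0.0.1:9977/progressbar', data=form).json()
#         if r['data'] >= 1:
#             print(r['result'])
#             break
#         time.sleep(1)
#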

@app.route('/api', methods=['GET', 'POST'])
def api():
    try:
        # Get the uploaded file
        audio_file = request.files['file']
        model = request.form.get("model")
        language = request.form.get("language")
        response_format = request.form.get("response_format")
        print(f'{model=},{language=},{response_format=}')
        if not os.path.exists(os.path.join(cfg.MODEL_DIR, f'models--Systran--faster-whisper-{model}/snapshots/')):
            return jsonify({"code": 1, "msg": f"{model} {cfg.transobj['lang4']}"})

        noextname, ext = os.path.splitext(audio_file.filename)
        ext = ext.lower()
        # Target wav path; if the source is a video, the audio track is extracted first
        wav_file = os.path.join(cfg.TMP_DIR, f'{noextname}.wav')
        if os.path.exists(wav_file):
            os.remove(wav_file)
        print(f'{wav_file=}')
        if not os.path.exists(wav_file) or os.path.getsize(wav_file) == 0:
            msg = ""
            if ext in ['.mp4', '.mov', '.avi', '.mkv', '.mpeg', '.mp3', '.flac']:
                video_file = os.path.join(cfg.TMP_DIR, f'{noextname}{ext}')
                audio_file.save(video_file)
                params = [
                    "-i",
                    video_file,
                ]
                if ext not in ['.mp3', '.flac']:
                    params.append('-vn')
                params.append(wav_file)
                rs = tool.runffmpeg(params)
                if rs != 'ok':
                    return jsonify({"code": 1, "msg": rs})
                msg = "," + cfg.transobj['lang9']
            elif ext == '.wav':
                audio_file.save(wav_file)
            else:
                return jsonify({"code": 1, "msg": f"{cfg.transobj['lang3']} {ext}"})
        print(f'{ext=}')
        sets = cfg.parse_ini()
        model = WhisperModel(model, device=sets.get('devtype'), compute_type=sets.get('cuda_com_type'), download_root=cfg.ROOT_DIR + "/models", local_files_only=True)

        segments, _ = model.transcribe(wav_file,
                                       beam_size=sets.get('beam_size'),
                                       best_of=sets.get('best_of'),
                                       temperature=0 if sets.get('temperature') == 0 else [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
                                       condition_on_previous_text=sets.get('condition_on_previous_text'),
                                       vad_filter=sets.get('vad'),
                                       vad_parameters=dict(min_silence_duration_ms=300, max_speech_duration_s=10.5),
                                       language=language,
                                       initial_prompt=None if language != 'zh' else sets.get('initial_prompt_zh'))
        raw_subtitles = []
        for segment in segments:
            start = int(segment.start * 1000)
            end = int(segment.end * 1000)
            startTime = tool.ms_to_time_string(ms=start)
            endTime = tool.ms_to_time_string(ms=end)
            text = segment.text.strip().replace('&#39;', "'")
            text = re.sub(r'&#\d+;', '', text)

            # Skip segments with no meaningful characters
            if not text or re.match(r'^[,。、?‘’“”;:({}【】):;"\'\s \d`!@#$%^&*()_+=.,?/\\-]*$', text) or len(text) <= 1:
                continue
            if response_format == 'json':
                # Subtitle entry in the source language
                raw_subtitles.append(
                    {"line": len(raw_subtitles) + 1, "start_time": startTime, "end_time": endTime, "text": text})
            elif response_format == 'text':
                raw_subtitles.append(text)
            else:
                raw_subtitles.append(f'{len(raw_subtitles) + 1}\n{startTime} --> {endTime}\n{text}\n')
        if response_format != 'json':
            raw_subtitles = "\n".join(raw_subtitles)
        # Note: the lookup below expects the json format, where raw_subtitles is a list of dicts.
        search_text = raw_subtitles[0]['text']
        search = query(search_text)
        return jsonify({"code": 0, "msg": 'ok', "data": search})
    except Exception as e:
        print(e)
        app.logger.error(f'[api]error: {e}')
        return jsonify({'code': 2, 'msg': str(e)})

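# Usage sketch for /api (illustrative; uploads a media file, transcribes it, then runs the Qdrant
# lookup on the first recognized segment; 127.0.0.1:9977 is only an assumed example address):
#
#     import requests
#     r = requests.post('http://127.0.0.1:9977/api',
#                       files={'file': open('question.wav', 'rb')},
#                       data={'model': 'base', 'language': 'zh', 'response_format': 'json'})
#     print(r.json())   # {'code': 0, 'msg': 'ok', 'data': {'answer': ...}} on success
#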

@app.route('/checkupdate', methods=['GET', 'POST'])
def checkupdate():
    return jsonify({'code': 0, "msg": cfg.updatetips})


if __name__ == '__main__':
    http_server = None
    try:
        threading.Thread(target=tool.checkupdate).start()
        try:
            if cfg.devtype == 'cpu':
                print('\nIf this machine has an NVIDIA GPU and CUDA is installed correctly, change\ndevtype=cpu to devtype=cuda in set.ini and restart to speed up recognition.\n')
            host = cfg.web_address.split(':')
            http_server = WSGIServer((host[0], int(host[1])), app, handler_class=CustomRequestHandler)
            threading.Thread(target=tool.openweb, args=(cfg.web_address,)).start()
            http_server.serve_forever()
        finally:
            if http_server:
                http_server.stop()
    except Exception as e:
        if http_server:
            http_server.stop()
        print("error:" + str(e))
        app.logger.error(f"[app]start error:{str(e)}")