import logging
import re
import threading
import sys
import torch
import ollama
from flask import Flask, request, render_template, jsonify, send_from_directory
import os
import json
from gevent.pywsgi import WSGIServer, WSGIHandler, LoggingLogAdapter
from logging.handlers import RotatingFileHandler
import warnings
from qdrant_client import QdrantClient

warnings.filterwarnings('ignore')

import stslib
from stslib import cfg, tool
from stslib.cfg import ROOT_DIR
from faster_whisper import WhisperModel


class CustomRequestHandler(WSGIHandler):
    """gevent WSGI handler that suppresses per-request access logging."""

    def log_request(self):
        pass


# Name of the Qdrant collection holding the prompt snippets searched by query().
collection_name = "stark_prompt"


def query(text):
    """Semantic-search the Qdrant collection for snippets relevant to *text*.

    Flow:
      1. Embed the input text with the ollama embedding model
         ``smartcreation/dmeta-embedding-zh:f16``.
      2. Search the Qdrant collection with the resulting vector
         (up to 60 candidates; approximate HNSW search).
      3. For every hit scoring above 0.785, merge its JSON payload into a
         single object, first-seen key wins (duplicate keys from later,
         lower-scored hits are ignored).

    :param text: free-form query text.
    :return: dict ``{"answer": <merged JSON string or fallback message>}``.
    """
    print("----- query start -----")
    print(text)
    client = QdrantClient("localhost", port=6333)
    sentence_embeddings = ollama.embeddings(
        model="smartcreation/dmeta-embedding-zh:f16", prompt=text)
    # Fetch up to 60 candidate matches; results come back ordered by score.
    search_result = client.search(
        collection_name=collection_name,
        query_vector=sentence_embeddings["embedding"],
        limit=60,
        search_params={"exact": False, "hnsw_ef": 128}
    )
    # Collect hits above the similarity threshold, keeping each result at
    # most once even when it contributes several previously unseen keys.
    # (The original appended the same result once per new key; the duplicates
    # were harmless only because the dict merge below deduplicated them.)
    resultArray = []
    resultKeyArray = []
    for result in search_result:
        if result.score <= 0.785:
            continue
        jsonObj = json.loads(result.payload["text"])
        new_keys = [key for key in jsonObj if key not in resultKeyArray]
        if new_keys:
            resultKeyArray.extend(new_keys)
            resultArray.append({
                'score': result.score,
                'payload': result.payload
            })
    # Fallback answer when nothing scored above the threshold.
    resultString = "描述信息不充分,请修改明确描述问题。"
    if resultArray:
        merged = {}
        for result in resultArray:
            jsonObj = json.loads(result["payload"]["text"])
            for key in jsonObj:
                merged[key] = jsonObj[key]
        # Serialize the merged object back to a JSON string.
        resultString = json.dumps(merged)
    print("---- resultString ----")
    print(resultString)
    return {
        "answer": resultString
    }


# --- logging configuration ---
# Disable Werkzeug's default request-log handler.
log = logging.getLogger('werkzeug')
log.handlers[:] = []
log.setLevel(logging.WARNING)

app = Flask(__name__,
            static_folder=os.path.join(ROOT_DIR, 'static'),
            static_url_path='/static',
            template_folder=os.path.join(ROOT_DIR, 'templates'))

# Strip handlers from the root logger so nothing but our file handler logs.
root_log = logging.getLogger()
root_log.handlers = []
root_log.setLevel(logging.WARNING)

app.logger.setLevel(logging.WARNING)
# Rotating log file: 1 MiB per file, 5 backups.
file_handler = RotatingFileHandler(os.path.join(ROOT_DIR, 'sts.log'),
                                   maxBytes=1024 * 1024, backupCount=5)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setLevel(logging.WARNING)
file_handler.setFormatter(formatter)
app.logger.addHandler(file_handler)


@app.route('/static/<path:filename>')
def static_files(filename):
    """Serve a file from the static folder.

    NOTE(review): Flask's built-in static handler (static_url_path='/static')
    is registered first and normally serves these URLs, so this route is
    effectively a fallback. The original rule lacked the '<path:filename>'
    converter and read the undefined config key 'STATIC_FOLDER'.
    """
    return send_from_directory(app.static_folder, filename)


@app.route('/')
def index():
    """Render the main page with current settings and version info."""
    sets = cfg.parse_ini()
    return render_template("index.html",
                           devtype=sets.get('devtype'),
                           lang_code=cfg.lang_code,
                           language=cfg.LANG,
                           version=stslib.version_str,
                           root_dir=ROOT_DIR.replace('\\', '/'))


# Upload an audio/video file and normalize it to a WAV in cfg.TMP_DIR.
@app.route('/upload', methods=['POST'])
def upload():
    """Accept an uploaded media file and extract/convert its audio to WAV.

    Video containers and mp3/flac are converted with ffmpeg; .wav files are
    stored as-is. Returns ``{"code": 0, "data": <wav basename>}`` on success.
    """
    try:
        audio_file = request.files['audio']
        noextname, ext = os.path.splitext(audio_file.filename)
        ext = ext.lower()
        wav_file = os.path.join(cfg.TMP_DIR, f'{noextname}.wav')
        # Already converted earlier: reuse the cached WAV.
        if os.path.exists(wav_file) and os.path.getsize(wav_file) > 0:
            return jsonify({'code': 0, 'msg': cfg.transobj['lang1'],
                            "data": os.path.basename(wav_file)})
        msg = ""
        if ext in ['.mp4', '.mov', '.avi', '.mkv', '.mpeg', '.mp3', '.flac']:
            video_file = os.path.join(cfg.TMP_DIR, f'{noextname}{ext}')
            audio_file.save(video_file)
            params = [
                "-i",
                video_file,
            ]
            # '-vn' drops the video stream; pure-audio inputs don't need it.
            if ext not in ['.mp3', '.flac']:
                params.append('-vn')
            params.append(wav_file)
            rs = tool.runffmpeg(params)
            if rs != 'ok':
                return jsonify({"code": 1, "msg": rs})
            msg = "," + cfg.transobj['lang9']
        elif ext == '.wav':
            audio_file.save(wav_file)
        else:
            return jsonify({"code": 1, "msg": f"{cfg.transobj['lang3']} {ext}"})
        return jsonify({'code': 0, 'msg': cfg.transobj['lang1'] + msg,
                        "data": os.path.basename(wav_file)})
    except Exception as e:
        app.logger.error(f'[upload]error: {e}')
        return jsonify({'code': 2, 'msg': cfg.transobj['lang2']})


def shibie(*, wav_name=None, model=None, language=None, data_type=None,
           wav_file=None, key=None):
    """Background transcription worker (runs in a thread started by /process).

    Transcribes *wav_file* with faster-whisper, updating
    ``cfg.progressbar[key]`` (0..1) as segments are produced and finally
    storing the result — or the error string — in ``cfg.progressresult[key]``.

    :param data_type: 'json' (list of dicts), 'text' (plain lines) or
                      anything else (SRT-formatted string).
    """
    try:
        sets = cfg.parse_ini()
        print(f'{sets.get("initial_prompt_zh")}')
        modelobj = WhisperModel(model, device=sets.get('devtype'),
                                compute_type=sets.get('cuda_com_type'),
                                download_root=cfg.ROOT_DIR + "/models",
                                local_files_only=True)
        cfg.progressbar[key] = 0
        segments, info = modelobj.transcribe(
            wav_file,
            beam_size=sets.get('beam_size'),
            best_of=sets.get('best_of'),
            temperature=0 if sets.get('temperature') == 0 else [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
            condition_on_previous_text=sets.get('condition_on_previous_text'),
            vad_filter=sets.get('vad'),
            vad_parameters=dict(min_silence_duration_ms=300),
            language=language,
            # The Chinese initial prompt only makes sense for zh audio.
            initial_prompt=None if language != 'zh' else sets.get('initial_prompt_zh'))
        # Same precision as the Whisper timestamps.
        total_duration = round(info.duration, 2)
        raw_subtitles = []
        for segment in segments:
            cfg.progressbar[key] = round(segment.end / total_duration, 2)
            start = int(segment.start * 1000)
            end = int(segment.end * 1000)
            startTime = tool.ms_to_time_string(ms=start)
            endTime = tool.ms_to_time_string(ms=end)
            # Normalize curly apostrophes and strip HTML numeric entities.
            text = segment.text.strip().replace('’', "'")
            text = re.sub(r'&#\d+;', '', text)
            # Skip segments with no meaningful characters.
            if not text or re.match(r'^[,。、?‘’“”;:({}【】):;"\'\s \d`!@#$%^&*()_+=.,?/\\-]*$', text) or len(
                    text) <= 1:
                continue
            if data_type == 'json':
                # Source-language subtitle entry.
                raw_subtitles.append(
                    {"line": len(raw_subtitles) + 1, "start_time": startTime,
                     "end_time": endTime, "text": text})
            elif data_type == 'text':
                raw_subtitles.append(text)
            else:
                raw_subtitles.append(f'{len(raw_subtitles) + 1}\n{startTime} --> {endTime}\n{text}\n')
        cfg.progressbar[key] = 1
        if data_type != 'json':
            raw_subtitles = "\n".join(raw_subtitles)
        cfg.progressresult[key] = raw_subtitles
    except Exception as e:
        # Surface the error through the result channel so /progressbar sees it.
        cfg.progressresult[key] = str(e)
        print(str(e))


# Start a transcription job for a previously uploaded WAV.
# params:
#   wav_name: WAV file name inside cfg.TMP_DIR
#   model: whisper model name
@app.route('/process', methods=['GET', 'POST'])
def process():
    """Kick off transcription in a background thread; progress is polled
    via /progressbar using the same (wav_name, model, language, data_type) key."""
    # Default to "" so a missing form field cannot AttributeError on .strip().
    wav_name = request.form.get("wav_name", "").strip()
    model = request.form.get("model")
    # Audio language code.
    language = request.form.get("language")
    # Output format: json / txt / srt.
    data_type = request.form.get("data_type")
    wav_file = os.path.join(cfg.TMP_DIR, wav_name)
    if not os.path.exists(wav_file):
        return jsonify({"code": 1, "msg": f"{wav_file} {cfg.langlist['lang5']}"})
    if not os.path.exists(os.path.join(cfg.MODEL_DIR, f'models--Systran--faster-whisper-{model}/snapshots/')):
        return jsonify({"code": 1, "msg": f"{model} {cfg.transobj['lang4']}"})
    key = f'{wav_name}{model}{language}{data_type}'
    # Reset result to None and progress to 0 for this job key.
    cfg.progressresult[key] = None
    cfg.progressbar[key] = 0
    # Run the actual transcription in a new thread.
    threading.Thread(target=shibie, kwargs={
        "wav_name": wav_name, "model": model, "language": language,
        "data_type": data_type, "wav_file": wav_file, "key": key}).start()
    return jsonify({"code": 0, "msg": "ing"})


# Poll job progress and, once finished, fetch the result.
@app.route('/progressbar', methods=['GET', 'POST'])
def progressbar():
    """Report progress (0..1) for a job started by /process; when the job is
    done (progress >= 1) the transcription result is included."""
    # Default to "" so a missing form field cannot AttributeError on .strip().
    wav_name = request.form.get("wav_name", "").strip()
    model = request.form.get("model")
    # Audio language code.
    language = request.form.get("language")
    # Output format: json / txt / srt.
    data_type = request.form.get("data_type")
    # Must match the key built in /process.
    key = f'{wav_name}{model}{language}{data_type}'
    progress = cfg.progressbar[key]
    if progress >= 1:
        return jsonify({"code": 0, "data": progress, "msg": "ok",
                        "result": cfg.progressresult[key]})
    return jsonify({"code": 0, "data": progress, "msg": "ok"})


@app.route('/api', methods=['GET', 'POST'])
def api():
    """Synchronous API: upload a media file, transcribe it, then run a
    Qdrant semantic search on the first recognized line and return the hits.

    Fixes over the original:
      * the local holding the WhisperModel no longer shadows the ``model``
        form field (a str), and
      * the search text is taken from the first recognized segment for every
        response_format — the original indexed ``raw_subtitles[0]['text']``,
        which crashed for 'text'/'srt' (string indices) and IndexErrored
        when nothing was recognized.
    """
    try:
        audio_file = request.files['file']
        model = request.form.get("model")
        language = request.form.get("language")
        response_format = request.form.get("response_format")
        print(f'{model=},{language=},{response_format=}')
        if not os.path.exists(os.path.join(cfg.MODEL_DIR, f'models--Systran--faster-whisper-{model}/snapshots/')):
            return jsonify({"code": 1, "msg": f"{model} {cfg.transobj['lang4']}"})
        noextname, ext = os.path.splitext(audio_file.filename)
        ext = ext.lower()
        wav_file = os.path.join(cfg.TMP_DIR, f'{noextname}.wav')
        # Unlike /upload, this endpoint always reconverts: drop any stale WAV.
        if os.path.exists(wav_file):
            os.remove(wav_file)
        print(f'{wav_file=}')
        if not os.path.exists(wav_file) or os.path.getsize(wav_file) == 0:
            msg = ""
            if ext in ['.mp4', '.mov', '.avi', '.mkv', '.mpeg', '.mp3', '.flac']:
                video_file = os.path.join(cfg.TMP_DIR, f'{noextname}{ext}')
                audio_file.save(video_file)
                params = [
                    "-i",
                    video_file,
                ]
                # '-vn' drops the video stream; pure-audio inputs don't need it.
                if ext not in ['.mp3', '.flac']:
                    params.append('-vn')
                params.append(wav_file)
                rs = tool.runffmpeg(params)
                if rs != 'ok':
                    return jsonify({"code": 1, "msg": rs})
                msg = "," + cfg.transobj['lang9']
            elif ext == '.wav':
                audio_file.save(wav_file)
            else:
                return jsonify({"code": 1, "msg": f"{cfg.transobj['lang3']} {ext}"})
        print(f'{ext=}')
        sets = cfg.parse_ini()
        whisper_model = WhisperModel(model, device=sets.get('devtype'),
                                     compute_type=sets.get('cuda_com_type'),
                                     download_root=cfg.ROOT_DIR + "/models",
                                     local_files_only=True)
        segments, _ = whisper_model.transcribe(
            wav_file,
            beam_size=sets.get('beam_size'),
            best_of=sets.get('best_of'),
            temperature=0 if sets.get('temperature') == 0 else [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
            condition_on_previous_text=sets.get('condition_on_previous_text'),
            vad_filter=sets.get('vad'),
            vad_parameters=dict(min_silence_duration_ms=300, max_speech_duration_s=10.5),
            language=language,
            # The Chinese initial prompt only makes sense for zh audio.
            initial_prompt=None if language != 'zh' else sets.get('initial_prompt_zh'))
        raw_subtitles = []
        first_text = None
        for segment in segments:
            start = int(segment.start * 1000)
            end = int(segment.end * 1000)
            startTime = tool.ms_to_time_string(ms=start)
            endTime = tool.ms_to_time_string(ms=end)
            # Normalize curly apostrophes and strip HTML numeric entities.
            text = segment.text.strip().replace('’', "'")
            text = re.sub(r'&#\d+;', '', text)
            # Skip segments with no meaningful characters.
            if not text or re.match(r'^[,。、?‘’“”;:({}【】):;"\'\s \d`!@#$%^&*()_+=.,?/\\-]*$', text) or len(text) <= 1:
                continue
            if first_text is None:
                first_text = text
            if response_format == 'json':
                # Source-language subtitle entry.
                raw_subtitles.append(
                    {"line": len(raw_subtitles) + 1, "start_time": startTime,
                     "end_time": endTime, "text": text})
            elif response_format == 'text':
                raw_subtitles.append(text)
            else:
                raw_subtitles.append(f'{len(raw_subtitles) + 1}\n{startTime} --> {endTime}\n{text}\n')
        if response_format != 'json':
            raw_subtitles = "\n".join(raw_subtitles)
        if first_text is None:
            # Nothing recognized — previously this IndexErrored into the
            # generic code-2 handler.
            return jsonify({"code": 1, "msg": "no speech recognized"})
        search = query(first_text)
        return jsonify({"code": 0, "msg": 'ok', "data": search})
    except Exception as e:
        print(e)
        app.logger.error(f'[api]error: {e}')
        return jsonify({'code': 2, 'msg': str(e)})


@app.route('/checkupdate', methods=['GET', 'POST'])
def checkupdate():
    """Return the cached update-check message."""
    return jsonify({'code': 0, "msg": cfg.updatetips})


if __name__ == '__main__':
    http_server = None
    try:
        # Fire off the update check in the background.
        threading.Thread(target=tool.checkupdate).start()
        try:
            if cfg.devtype == 'cpu':
                print('\n如果设备使用英伟达显卡并且CUDA环境已正确安装,可修改set.ini中\ndevtype=cpu 为 devtype=cuda, 然后重新启动以加快识别速度\n')
            # cfg.web_address is "host:port".
            host = cfg.web_address.split(':')
            http_server = WSGIServer((host[0], int(host[1])), app,
                                     handler_class=CustomRequestHandler)
            # Open the browser once the server address is known.
            threading.Thread(target=tool.openweb, args=(cfg.web_address,)).start()
            http_server.serve_forever()
        finally:
            if http_server:
                http_server.stop()
    except Exception as e:
        if http_server:
            http_server.stop()
        print("error:" + str(e))
        app.logger.error(f"[app]start error:{str(e)}")