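# GLM-4-Voice Gradio demo with a local RAG (retrieval-augmented generation) layer.
# The script loads the GLM-4-Voice tokenizer, 9B model, and CosyVoice flow decoder,
# transcribes spoken input with Whisper, retrieves supporting passages from a FAISS
# vector store, and streams interleaved text and audio back through a Gradio UI.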
import os
import sys
import re
import uuid
import tempfile
import json
from argparse import ArgumentParser
from threading import Thread
from queue import Queue

import torch
import torchaudio
import gradio as gr
import whisper
from transformers import (
    WhisperFeatureExtractor,
    AutoTokenizer,
    AutoModel,
    AutoModelForCausalLM
)
from transformers.generation.streamers import BaseStreamer

from speech_tokenizer.modeling_whisper import WhisperVQEncoder
from speech_tokenizer.utils import extract_speech_token

# Add local paths
sys.path.insert(0, "./cosyvoice")
sys.path.insert(0, "./third_party/Matcha-TTS")
from flow_inference import AudioDecoder
# RAG imports
from langchain_community.document_loaders import (
    PyPDFLoader,
    TextLoader,
    UnstructuredMarkdownLoader,
    CSVLoader,
    UnstructuredImageLoader,
    JSONLoader,
    BSHTMLLoader,
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores.faiss import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from tqdm import tqdm
import joblib
import spaces
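# Streaming generation: model.generate() runs on a background thread and pushes
# each new token id into a Queue; the Gradio handler consumes the streamer as an
# iterator, blocking until the next id (or the stop signal from end()) arrives.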
# Token streamer for generation
class TokenStreamer(BaseStreamer):
    def __init__(self, skip_prompt: bool = False, timeout=None):
        self.skip_prompt = skip_prompt
        self.token_queue = Queue()
        self.stop_signal = None
        self.next_tokens_are_prompt = True
        self.timeout = timeout

    def put(self, value):
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TokenStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        for token in value.tolist():
            self.token_queue.put(token)

    def end(self):
        self.token_queue.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.token_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value
# File loader mapping
LOADER_MAPPING = {
    '.pdf': PyPDFLoader,
    '.txt': TextLoader,
    '.md': UnstructuredMarkdownLoader,
    '.csv': CSVLoader,
    '.jpg': UnstructuredImageLoader,
    '.jpeg': UnstructuredImageLoader,
    '.png': UnstructuredImageLoader,
    '.json': JSONLoader,
    '.html': BSHTMLLoader,
    '.htm': BSHTMLLoader
}
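# RAG document pipeline: each file is read with the loader matching its extension,
# split into overlapping chunks, embedded, and indexed in a FAISS store that is
# persisted to disk together with the list of source file paths so the index can
# be reused when the same files are supplied again.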
def load_single_file(file_path):
    _, ext = os.path.splitext(file_path)
    ext = ext.lower()
    loader_class = LOADER_MAPPING.get(ext)
    if not loader_class:
        print(f"Unsupported file type: {ext}")
        return None
    loader = loader_class(file_path)
    docs = list(loader.lazy_load())
    return docs

def load_files(file_paths: list):
    if not file_paths:
        return []
    docs = []
    for file_path in tqdm(file_paths):
        print("Loading docs:", file_path)
        loaded_docs = load_single_file(file_path)
        if loaded_docs:
            docs.extend(loaded_docs)
    return docs
def split_text(txt, chunk_size=200, overlap=20):
    if not txt:
        return None
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
    docs = splitter.split_documents(txt)
    return docs

def create_embedding_model(model_file):
    embedding = HuggingFaceEmbeddings(model_name=model_file, model_kwargs={'trust_remote_code': True})
    return embedding

def save_file_paths(store_path, file_paths):
    joblib.dump(file_paths, f'{store_path}/file_paths.pkl')

def load_file_paths(store_path):
    file_paths_file = f'{store_path}/file_paths.pkl'
    if os.path.exists(file_paths_file):
        return joblib.load(file_paths_file)
    return None

def file_paths_match(store_path, file_paths):
    saved_file_paths = load_file_paths(store_path)
    return saved_file_paths == file_paths
def create_vector_store(docs, store_file, embeddings):
    vector_store = FAISS.from_documents(docs, embeddings)
    vector_store.save_local(store_file)
    return vector_store

def load_vector_store(store_path, embeddings):
    if os.path.exists(store_path):
        vector_store = FAISS.load_local(store_path, embeddings, allow_dangerous_deserialization=True)
        return vector_store
    else:
        return None

def load_or_create_store(store_path, file_paths, embeddings):
    if os.path.exists(store_path) and file_paths_match(store_path, file_paths):
        print("Vector database is consistent with last use, no need to rewrite")
        vector_store = load_vector_store(store_path, embeddings)
        if vector_store:
            return vector_store
    print("Rewriting database")
    pages = load_files(file_paths)
    docs = split_text(pages)
    vector_store = create_vector_store(docs, store_path, embeddings)
    save_file_paths(store_path, file_paths)
    return vector_store

def query_vector_store(vector_store: FAISS, query, k=4, relevance_threshold=0.8):
    retriever = vector_store.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": relevance_threshold, "k": k}
    )
    similar_docs = retriever.invoke(query)
    context = [doc.page_content for doc in similar_docs]
    return context
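# ModelWorker runs the GLM-4-Voice language model in this process and streams raw
# token ids through TokenStreamer; the inference loop below consumes them directly
# instead of going through a separate model-server API.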
class ModelWorker:
    def __init__(self, model_path, device='cuda'):
        self.device = device
        self.glm_model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True
        ).to(device).eval()
        self.glm_tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True
        )
    def generate_stream(self, params):
        prompt = params["prompt"]
        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_new_tokens = int(params.get("max_new_tokens", 256))

        inputs = self.glm_tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.device)
        streamer = TokenStreamer(skip_prompt=True)
        thread = Thread(
            target=self.glm_model.generate,
            kwargs=dict(
                **inputs,
                max_new_tokens=int(max_new_tokens),
                temperature=float(temperature),
                top_p=float(top_p),
                streamer=streamer
            )
        )
        thread.start()
        for token_id in streamer:
            yield token_id

    def generate_stream_gate(self, params):
        try:
            for x in self.generate_stream(params):
                yield x
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = "Server Error"
            yield ret
def initialize_embedding_model_and_vector_store(Embedding_Model, store_path, file_paths):
    embedding_model = create_embedding_model(Embedding_Model)
    vector_store = load_or_create_store(store_path, file_paths, embedding_model)
    return vector_store, embedding_model

def handle_file_upload(files):
    if not files:
        return None
    file_paths = [file.name for file in files]
    return file_paths

def reinitialize_database(files, progress=gr.Progress()):
    global vector_store, embedding_model
    if not files:
        return "No files uploaded. Please upload files first."
    file_paths = [file.name for file in files]
    progress(0, desc="Initializing embedding model...")
    embedding_model = create_embedding_model(Embedding_Model)
    progress(0.3, desc="Loading documents...")
    pages = load_files(file_paths)
    progress(0.5, desc="Splitting text...")
    docs = split_text(pages)
    progress(0.7, desc="Creating vector store...")
    vector_store = create_vector_store(docs, store_path, embedding_model)
    save_file_paths(store_path, file_paths)
    return "Database reinitialized successfully!"
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--flow-path", type=str, default="THUDM/glm-4-voice-decoder")
    parser.add_argument("--model-path", type=str, default="THUDM/glm-4-voice-9b")
    parser.add_argument("--tokenizer-path", type=str, default="THUDM/glm-4-voice-tokenizer")
    parser.add_argument("--whisper_model", type=str, default="base")
    parser.add_argument("--share", action='store_true')
    args = parser.parse_args()

    # Define model configurations
    flow_config = os.path.join(args.flow_path, "config.yaml")
    flow_checkpoint = os.path.join(args.flow_path, 'flow.pt')
    hift_checkpoint = os.path.join(args.flow_path, 'hift.pt')
    device = "cuda"
    # Global variables
    audio_decoder = None
    whisper_model = None
    feature_extractor = None
    glm_model = None
    glm_tokenizer = None
    vector_store = None
    embedding_model = None
    whisper_transcribe_model = None
    model_worker = None

    # RAG configuration
    Embedding_Model = '/root/autodl-tmp/rag/multilingual-e5-large-instruct'
    file_paths = ['/root/autodl-tmp/rag/me.txt', "/root/autodl-tmp/rag/2024-Wealth-Outlook-MidYear-Edition.pdf"]
    store_path = '/root/autodl-tmp/rag/me.faiss'
    def initialize_fn():
        global audio_decoder, feature_extractor, whisper_model, glm_model, glm_tokenizer
        global vector_store, embedding_model, whisper_transcribe_model, model_worker
        if audio_decoder is not None:
            return

        model_worker = ModelWorker(args.model_path, device)
        glm_tokenizer = model_worker.glm_tokenizer

        audio_decoder = AudioDecoder(
            config_path=flow_config,
            flow_ckpt_path=flow_checkpoint,
            hift_ckpt_path=hift_checkpoint,
            device=device
        )

        whisper_model = WhisperVQEncoder.from_pretrained(args.tokenizer_path).eval().to(device)
        feature_extractor = WhisperFeatureExtractor.from_pretrained(args.tokenizer_path)

        embedding_model = create_embedding_model(Embedding_Model)
        vector_store = load_or_create_store(store_path, file_paths, embedding_model)
        whisper_transcribe_model = whisper.load_model("/root/autodl-tmp/whisper/base/base.pt")

    def clear_fn():
        return [], [], '', '', '', None, None
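    # inference_fn builds a GLM-4-Voice prompt of the form
    #   <|system|>\n...<|context|>...<|user|>\n...<|assistant|>streaming_transcription\n
    # where spoken input is encoded as <|begin_of_audio|><|audio_k|>...<|end_of_audio|>.
    # It is a generator: each yield updates (history, input tokens, completion text,
    # error text, streaming audio chunk, final audio) in the Gradio outputs.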
    def inference_fn(
        temperature: float,
        top_p: float,
        max_new_token: int,
        input_mode,
        audio_path: str | None,
        input_text: str | None,
        history: list[dict],
        previous_input_tokens: str,
        previous_completion_tokens: str,
    ):
        global whisper_transcribe_model, vector_store

        using_context = False
        if input_mode == "audio":
            assert audio_path is not None
            history.append({"role": "user", "content": {"path": audio_path}})
            audio_tokens = extract_speech_token(
                whisper_model, feature_extractor, [audio_path]
            )[0]
            if len(audio_tokens) == 0:
                raise gr.Error("No audio tokens extracted")
            audio_tokens = "".join([f"<|audio_{x}|>" for x in audio_tokens])
            audio_tokens = "<|begin_of_audio|>" + audio_tokens + "<|end_of_audio|>"
            user_input = audio_tokens
            system_prompt = "User will provide you with a speech instruction. Do it step by step."
            # Transcribe the audio so the retrieval query is plain text
            whisper_result = whisper_transcribe_model.transcribe(audio_path)
            transcribed_text = whisper_result['text']
            context = query_vector_store(vector_store, transcribed_text, 4, 0.7)
        else:
            assert input_text is not None
            history.append({"role": "user", "content": input_text})
            user_input = input_text
            system_prompt = "User will provide you with a text instruction. Do it step by step."
            context = query_vector_store(vector_store, input_text, 4, 0.7)

        # query_vector_store returns a (possibly empty) list, so only inject
        # retrieved context when at least one passage came back.
        if context:
            using_context = True

        inputs = previous_input_tokens + previous_completion_tokens
        inputs = inputs.strip()
        if "<|system|>" not in inputs:
            inputs += f"<|system|>\n{system_prompt}"
        if using_context and "<|context|>" not in inputs:
            inputs += f"<|context|> According to the following content: {context}, please answer the question"
        inputs += f"<|user|>\n{user_input}<|assistant|>streaming_transcription\n"
        with torch.no_grad():
            text_tokens, audio_tokens = [], []
            audio_offset = glm_tokenizer.convert_tokens_to_ids('<|audio_0|>')
            end_token_id = glm_tokenizer.convert_tokens_to_ids('<|user|>')
            complete_tokens = []
            prompt_speech_feat = torch.zeros(1, 0, 80).to(device)
            flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int64).to(device)
            this_uuid = str(uuid.uuid4())
            tts_speechs = []
            tts_mels = []
            prev_mel = None
            is_finalize = False
            block_size = 10

            # Generate tokens using ModelWorker directly instead of API
            for token_id in model_worker.generate_stream_gate({
                "prompt": inputs,
                "temperature": temperature,
                "top_p": top_p,
                "max_new_tokens": max_new_token,
            }):
                if isinstance(token_id, str):  # Error case
                    yield history, inputs, '', token_id, None, None
                    return
                if token_id == end_token_id:
                    is_finalize = True
                if len(audio_tokens) >= block_size or (is_finalize and audio_tokens):
                    block_size = 20
                    tts_token = torch.tensor(audio_tokens, device=device).unsqueeze(0)
                    if prev_mel is not None:
                        prompt_speech_feat = torch.cat(tts_mels, dim=-1).transpose(1, 2)
                    tts_speech, tts_mel = audio_decoder.token2wav(
                        tts_token,
                        uuid=this_uuid,
                        prompt_token=flow_prompt_speech_token.to(device),
                        prompt_feat=prompt_speech_feat.to(device),
                        finalize=is_finalize
                    )
                    prev_mel = tts_mel
                    tts_speechs.append(tts_speech.squeeze())
                    tts_mels.append(tts_mel)
                    yield history, inputs, '', '', (22050, tts_speech.squeeze().cpu().numpy()), None
                    flow_prompt_speech_token = torch.cat((flow_prompt_speech_token, tts_token), dim=-1)
                    audio_tokens = []
                if not is_finalize:
                    complete_tokens.append(token_id)
                    if token_id >= audio_offset:
                        audio_tokens.append(token_id - audio_offset)
                    else:
                        text_tokens.append(token_id)

        # Generate final audio and save
        tts_speech = torch.cat(tts_speechs, dim=-1).cpu()
        complete_text = glm_tokenizer.decode(complete_tokens, spaces_between_special_tokens=False)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            torchaudio.save(f, tts_speech.unsqueeze(0), 22050, format="wav")
        history.append({"role": "assistant", "content": {"path": f.name, "type": "audio/wav"}})
        history.append({"role": "assistant", "content": glm_tokenizer.decode(text_tokens, skip_special_tokens=False)})
        yield history, inputs, complete_text, '', None, (22050, tts_speech.numpy())
    def update_input_interface(input_mode):
        if input_mode == "audio":
            return [gr.update(visible=True), gr.update(visible=False)]
        else:
            return [gr.update(visible=False), gr.update(visible=True)]
    # Create Gradio interface with new layout
    with gr.Blocks(title="GLM-4-Voice Demo", fill_height=True) as demo:
        with gr.Row():
            # Left column for chat interface
            with gr.Column(scale=2):
                gr.Markdown("## Chat Interface")
                with gr.Row():
                    temperature = gr.Number(label="Temperature", value=0.2, minimum=0, maximum=1)
                    top_p = gr.Number(label="Top p", value=0.8, minimum=0, maximum=1)
                    max_new_token = gr.Number(label="Max new tokens", value=2000, minimum=1)
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    bubble_full_width=False,
                    type="messages",
                    scale=1,
                    height=500
                )
                with gr.Row():
                    input_mode = gr.Radio(
                        ["audio", "text"],
                        label="Input Mode",
                        value="audio"
                    )
                with gr.Row():
                    audio = gr.Audio(
                        label="Input audio",
                        type='filepath',
                        show_download_button=True,
                        visible=True
                    )
                    text_input = gr.Textbox(
                        label="Input text",
                        placeholder="Enter your text here...",
                        lines=2,
                        visible=False
                    )
                with gr.Row():
                    submit_btn = gr.Button("Submit", variant="primary")
                    reset_btn = gr.Button("Clear")
                output_audio = gr.Audio(
                    label="Play",
                    streaming=True,
                    autoplay=True,
                    show_download_button=False
                )
                complete_audio = gr.Audio(
                    label="Last Output Audio (If Any)",
                    show_download_button=True
                )

            # Right column for database management
            with gr.Column(scale=1):
                gr.Markdown("## Database Management")
                file_upload = gr.Files(
                    label="Upload Database Files",
                    file_types=[".txt", ".pdf", ".md", ".csv", ".json", ".html", ".htm"],
                    file_count="multiple"
                )
                reinit_btn = gr.Button("Reinitialize Database", variant="secondary")
                status_text = gr.Textbox(label="Status", interactive=False)
        history_state = gr.State([])
        # Hidden state carried between turns so inference_fn can rebuild the prompt,
        # plus a slot for its error output (matching the six values it yields).
        input_tokens_state = gr.State("")
        completion_tokens_state = gr.State("")
        error_state = gr.State("")
        # Setup interaction handlers
        respond = submit_btn.click(
            inference_fn,
            inputs=[
                temperature,
                top_p,
                max_new_token,
                input_mode,
                audio,
                text_input,
                history_state,
                input_tokens_state,
                completion_tokens_state,
            ],
            outputs=[
                history_state,
                input_tokens_state,
                completion_tokens_state,
                error_state,
                output_audio,
                complete_audio
            ]
        )
        respond.then(lambda s: s, [history_state], chatbot)

        reset_btn.click(
            clear_fn,
            outputs=[
                chatbot,
                history_state,
                input_tokens_state,
                completion_tokens_state,
                error_state,
                output_audio,
                complete_audio
            ]
        )

        input_mode.change(
            update_input_interface,
            inputs=[input_mode],
            outputs=[audio, text_input]
        )

        # Database reinitialization handler
        reinit_btn.click(
            reinitialize_database,
            inputs=[file_upload],
            outputs=[status_text]
        )
    # Initialize models and launch interface
    initialize_fn()
    demo.launch(
        server_port=args.port,
        server_name=args.host,
        share=args.share
    )