from llama_cpp import Llama

import gc
import threading
import logging
import sys

log = logging.getLogger('llm_api.backend')

class LlmBackend:
    """Singleton backend around llama_cpp.Llama for Saiga-style chat models."""

    # System prompt in Russian: "You are a Russian-language automatic assistant.
    # You answer the user's requests as accurately as possible, using Russian."
    SYSTEM_PROMPT = "Ты — русскоязычный автоматический ассистент. Ты максимально точно отвечаешь на запросы пользователя, используя русский язык."
    SYSTEM_TOKEN = 1788
    USER_TOKEN = 1404
    BOT_TOKEN = 9225
    LINEBREAK_TOKEN = 13

    ROLE_TOKENS = {
        "user": USER_TOKEN,
        "bot": BOT_TOKEN,
        "system": SYSTEM_TOKEN
    }
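    # The IDs above are the Saiga special tokens: get_message_tokens() lays out every
    # turn as [BOS, role token, LINEBREAK_TOKEN, content tokens, EOS], and
    # create_chat_generator_for_saiga() primes the reply with [BOS, BOT_TOKEN, LINEBREAK_TOKEN].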
    _instance = None
    _model = None
    _model_params = None
    # Re-entrant lock: methods that hold it (e.g. generate_tokens) may call
    # ensure_model_is_loaded -> load_model, which acquires it again on the same thread.
    _lock = threading.RLock()

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(LlmBackend, cls).__new__(cls)
        return cls._instance

    def is_model_loaded(self):
        return self._model is not None
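    # LlmBackend() always returns the same instance, so the loaded model, its
    # parameters and the lock are shared by every caller in the process.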
    def load_model(self, model_path, context_size=2000, enable_gpu=True, gpu_layer_number=35, chat_format='llama-2'):
        log.info('load_model - started')
        self._model_params = {
            'model_path': model_path,
            'context_size': context_size,
            'enable_gpu': enable_gpu,
            'gpu_layer_number': gpu_layer_number,
            'chat_format': chat_format
        }
        if self._model is not None:
            self.unload_model()
        with self._lock:
            llama_kwargs = dict(
                model_path=model_path,
                chat_format=chat_format,
                n_ctx=context_size,
                n_parts=1,
                # n_batch=100,
                logits_all=True,
                # n_threads=12,
                verbose=True
            )
            if enable_gpu:
                llama_kwargs['n_gpu_layers'] = gpu_layer_number
            self._model = Llama(**llama_kwargs)
            log.info('load_model - finished')
            return self._model
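    # n_ctx is the llama.cpp context window and n_gpu_layers the number of layers
    # offloaded to the GPU when enable_gpu is set; the saved self._model_params let
    # ensure_model_is_loaded() reload the model with the same settings after an unload.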
    def set_system_prompt(self, prompt):
        with self._lock:
            self.SYSTEM_PROMPT = prompt
    def unload_model(self):
        log.info('unload_model - started')
        with self._lock:
            if self._model is not None:
                # Drop the reference and collect so llama.cpp memory is actually freed.
                self._model = None
                gc.collect()
        log.info('unload_model - finished')
    def ensure_model_is_loaded(self):
        log.info('ensure_model_is_loaded - started')
        if not self.is_model_loaded():
            log.info('ensure_model_is_loaded - model reloading')
            if self._model_params is not None:
                self.load_model(**self._model_params)
            else:
                log.warning('ensure_model_is_loaded - no model config found, reloading cannot be done')
        log.info('ensure_model_is_loaded - finished')
    def create_chat_completion(self, messages, stream=True):
        log.info('create_chat_completion called')
        with self._lock:
            log.info('create_chat_completion started')
            self.ensure_model_is_loaded()
            try:
                return self._model.create_chat_completion(messages=messages, stream=stream)
            except Exception as e:
                log.error('create_chat_completion - error')
                log.error(e)
                return None
    def get_message_tokens(self, role, content):
        log.info('get_message_tokens - started')
        self.ensure_model_is_loaded()
        # tokenize() prepends BOS, so the final layout is
        # [BOS, role token, linebreak, content tokens, EOS].
        message_tokens = self._model.tokenize(content.encode("utf-8"))
        message_tokens.insert(1, self.ROLE_TOKENS[role])
        message_tokens.insert(2, self.LINEBREAK_TOKEN)
        message_tokens.append(self._model.token_eos())
        log.info('get_message_tokens - finished')
        return message_tokens
    def get_system_tokens(self):
        return self.get_message_tokens(role="system", content=self.SYSTEM_PROMPT)
    def create_chat_generator_for_saiga(self, messages, parameters, use_system_prompt=True):
        log.info('create_chat_generator_for_saiga - started')
        with self._lock:
            self.ensure_model_is_loaded()
            tokens = self.get_system_tokens() if use_system_prompt else []
            for message in messages:
                message_tokens = self.get_message_tokens(role=message.get("from"), content=message.get("content", ""))
                tokens.extend(message_tokens)
            # Prime the assistant turn: BOS, bot role token, linebreak.
            tokens.extend([self._model.token_bos(), self.BOT_TOKEN, self.LINEBREAK_TOKEN])
            generator = self._model.generate(
                tokens,
                top_k=parameters['top_k'],
                top_p=parameters['top_p'],
                temp=parameters['temperature'],
                repeat_penalty=parameters['repetition_penalty']
            )
            log.info('create_chat_generator_for_saiga - finished')
            return generator
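    # generate_tokens() below consumes this low-level generator and streams detokenized
    # UTF-8 byte chunks; the re-entrant lock is held for the whole streaming loop, so
    # concurrent requests are served one at a time.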
    def generate_tokens(self, generator):
        log.info('generate_tokens - started')
        with self._lock:
            self.ensure_model_is_loaded()
            try:
                for token in generator:
                    if token == self._model.token_eos():
                        yield b''  # End of chunk
                        log.info('generate_tokens - finished')
                        break
                    token_str = self._model.detokenize([token])  # .decode("utf-8", errors="ignore")
                    yield token_str
            except Exception as e:
                log.error('generate_tokens - error')
                log.error(e)
                yield b''  # End of chunk
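
# A minimal usage sketch, not part of the original service: the model path, the chat
# message and the sampling parameters below are placeholder assumptions for a local
# Saiga-style model file; adjust them to your own setup.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    backend = LlmBackend()
    backend.load_model(
        model_path="./models/model-q4_K.gguf",  # hypothetical path to a local model file
        context_size=2000,
        enable_gpu=False
    )

    messages = [{"from": "user", "content": "Привет! Кто ты?"}]  # "Hi! Who are you?"
    parameters = {
        "top_k": 30,
        "top_p": 0.9,
        "temperature": 0.2,
        "repetition_penalty": 1.1
    }

    generator = backend.create_chat_generator_for_saiga(messages, parameters)
    for chunk in backend.generate_tokens(generator):
        # generate_tokens yields b'' once as an end-of-stream marker and then stops.
        if chunk:
            print(chunk.decode("utf-8", errors="ignore"), end="", flush=True)

    print()
    backend.unload_model()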