from flask import Flask, request, Response
import logging
import threading
from huggingface_hub import snapshot_download  #, Repository
import huggingface_hub
import gc
import os.path
import xml.etree.ElementTree as ET
from apscheduler.schedulers.background import BackgroundScheduler
from datetime import datetime, timedelta
from llm_backend import LlmBackend
import json
import sys

llm = LlmBackend()
_lock = threading.Lock()
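# The default system prompt below is in Russian; it translates roughly to:
# "You are a Russian-language automatic assistant. You answer user requests as
# accurately as possible, using the Russian language."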
SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', default="Ты — русскоязычный автоматический ассистент. Ты максимально точно отвечаешь на запросы пользователя, используя русский язык.")
CONTEXT_SIZE = int(os.environ.get('CONTEXT_SIZE', default='500'))
HF_CACHE_DIR = os.environ.get('HF_CACHE_DIR', default='/home/user/app/.cache')
USE_SYSTEM_PROMPT = os.environ.get('USE_SYSTEM_PROMPT', default='False').lower() == 'true'
ENABLE_GPU = os.environ.get('ENABLE_GPU', default='False').lower() == 'true'
GPU_LAYERS = int(os.environ.get('GPU_LAYERS', default='0'))
CHAT_FORMAT = os.environ.get('CHAT_FORMAT', default='llama-2')
REPO_NAME = os.environ.get('REPO_NAME', default='IlyaGusev/saiga2_7b_gguf')
MODEL_NAME = os.environ.get('MODEL_NAME', default='model-q4_K.gguf')
DATASET_REPO_URL = os.environ.get('DATASET_REPO_URL', default="https://huggingface.co/datasets/muryshev/saiga-chat")
DATA_FILENAME = os.environ.get('DATA_FILENAME', default="data-saiga-cuda-release.xml")
HF_TOKEN = os.environ.get("HF_TOKEN")
APP_HOST = os.environ.get('APP_HOST', default='0.0.0.0')
APP_PORT = int(os.environ.get('APP_PORT', default='7860'))
FLASK_THREADED = os.environ.get('FLASK_THREADED', default='False').lower() == 'true'
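# Example: overriding the defaults above via environment variables before launch
# (illustrative values only; any GGUF repo/file that llm_backend can load should
# work, e.g. the 13B variant mentioned in the commented-out block further down):
#
#   export REPO_NAME=IlyaGusev/saiga2_13b_gguf
#   export MODEL_NAME=model-q4_K.gguf
#   export ENABLE_GPU=True
#   export GPU_LAYERS=35
#   export CONTEXT_SIZE=2000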
# Create a lock object
lock = threading.Lock()

app = Flask('llm_api')
app.logger.handlers.clear()
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
app.logger.addHandler(handler)
app.logger.setLevel(logging.DEBUG)

# Variable to store the last request time
last_request_time = datetime.now()
# Initialize the model when the application starts
#model_path = "../models/model-q4_K.gguf"  # Replace with the actual model path
#MODEL_NAME = "model/ggml-model-q4_K.gguf"
#REPO_NAME = "IlyaGusev/saiga2_13b_gguf"
#MODEL_NAME = "model-q4_K.gguf"
#REPO_NAME = "IlyaGusev/saiga2_70b_gguf"
#MODEL_NAME = "ggml-model-q4_1.gguf"
local_dir = '.'
if os.path.isdir('/data'):
    app.logger.info('Persistent storage enabled')

model = None

MODEL_PATH = os.path.join(snapshot_download(repo_id=REPO_NAME, allow_patterns=MODEL_NAME, cache_dir=HF_CACHE_DIR), MODEL_NAME)
app.logger.info('Model path: ' + MODEL_PATH)

DATA_FILE = os.path.join("dataset", DATA_FILENAME)
app.logger.info("hfh: " + huggingface_hub.__version__)
# repo = Repository(
#     local_dir="dataset", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
# )

# def log(req: str = '', resp: str = ''):
#     if req or resp:
#         element = ET.Element("row", {"time": str(datetime.now())})
#         req_element = ET.SubElement(element, "request")
#         req_element.text = req
#         resp_element = ET.SubElement(element, "response")
#         resp_element.text = resp
#         with open(DATA_FILE, "ab+") as xml_file:
#             xml_file.write(ET.tostring(element, encoding="utf-8"))
#         commit_url = repo.push_to_hub()
#         app.logger.info(commit_url)
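# NOTE: the handlers below are plain functions with no @app.route decorators in this
# file; presumably they are registered as Flask routes elsewhere (for example via
# app.add_url_rule).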
def handler_change_context_size():
    global stop_generation, model
    stop_generation = True
    new_size = int(request.args.get('size', CONTEXT_SIZE))
    init_model(new_size, ENABLE_GPU, GPU_LAYERS)
    return Response('Size changed', content_type='text/plain')
def handler_stop_generation():
    global stop_generation
    stop_generation = True
    return Response('Stopped', content_type='text/plain')
def generate_unknown_response():
    app.logger.info('unknown method: ' + request.method)
    try:
        request_payload = request.get_json()
        app.logger.info('payload: ' + json.dumps(request_payload))
    except Exception as e:
        app.logger.info('payload empty')
    return Response('What do you want?', content_type='text/plain')
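# Module-level buffer that accumulates the streamed tokens so the full response can
# be logged once generation finishes (see the commented-out log() call below).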
response_tokens = bytearray()

def generate_and_log_tokens(user_request, generator):
    global response_tokens, last_request_time
    for token in llm.generate_tokens(generator):
        if token == b'':  # or (max_new_tokens is not None and i >= max_new_tokens):
            last_request_time = datetime.now()
            # log(json.dumps(user_request), response_tokens.decode("utf-8", errors="ignore"))
            response_tokens = bytearray()
            break
        response_tokens.extend(token)
        yield token
def generate_response():
    app.logger.info('generate_response called')
    data = request.get_json()
    app.logger.info(data)
    messages = data.get("messages", [])
    preprompt = data.get("preprompt", "")
    parameters = data.get("parameters", {})

    # Extract parameters from the request
    p = {
        'temperature': parameters.get("temperature", 0.01),
        'truncate': parameters.get("truncate", 1000),
        'max_new_tokens': parameters.get("max_new_tokens", 1024),
        'top_p': parameters.get("top_p", 0.85),
        'repetition_penalty': parameters.get("repetition_penalty", 1.2),
        'top_k': parameters.get("top_k", 30),
        'return_full_text': parameters.get("return_full_text", False)
    }

    generator = llm.create_chat_generator_for_saiga(messages=messages, parameters=p, use_system_prompt=USE_SYSTEM_PROMPT)
    app.logger.info('Generator created')

    # Use Response to stream tokens
    return Response(generate_and_log_tokens(user_request='1', generator=generator), content_type='text/plain', status=200, direct_passthrough=True)
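# Example request body expected by generate_response (illustrative; the exact message
# schema is defined by llm_backend, and the route this handler is mounted on is not
# shown in this file):
#
#   {
#     "messages": [{"role": "user", "content": "Привет!"}],
#     "preprompt": "",
#     "parameters": {"temperature": 0.01, "max_new_tokens": 1024, "top_p": 0.85,
#                    "repetition_penalty": 1.2, "top_k": 30, "return_full_text": false}
#   }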
def init_model(context_size=CONTEXT_SIZE, enable_gpu=ENABLE_GPU, gpu_layer_number=GPU_LAYERS):
    # Accept overrides so handler_change_context_size() can reload the model with a new context size.
    llm.load_model(model_path=MODEL_PATH, context_size=context_size, enable_gpu=enable_gpu, gpu_layer_number=gpu_layer_number)
# Function to check if no requests were made in the last 5 minutes
def check_last_request_time():
    global last_request_time
    current_time = datetime.now()
    if (current_time - last_request_time).total_seconds() > 300:  # 5 minutes in seconds
        llm.unload_model()
        app.logger.info(f"Model unloaded at {current_time}")
    else:
        app.logger.info(f"No action needed at {current_time}")
if __name__ == "__main__":
    init_model()

    # scheduler = BackgroundScheduler()
    # scheduler.add_job(check_last_request_time, trigger='interval', minutes=1)
    # scheduler.start()

    app.run(host=APP_HOST, port=APP_PORT, debug=False, threaded=FLASK_THREADED)