import gradio as gr
import spaces
from transformers import pipeline
import torch
import json
from typing import Optional

# Global variable to store pipelines
model_cache = {}

# Available models
AVAILABLE_MODELS = {
    "Apollo-1-4B": "Loom-Labs/Apollo-1-4B",
    "Apollo-1-8B": "Loom-Labs/Apollo-1-8B",
    "Apollo-1-2B": "Loom-Labs/Apollo-1-2B",
    "Daedalus-1-2B": "Loom-Labs/Daedalus-1-2B",
    "Daedalus-1-8B": "Loom-Labs/Daedalus-1-8B",
}

@spaces.GPU
def initialize_model(model_name):
    global model_cache

    if model_name not in AVAILABLE_MODELS:
        raise ValueError(f"Model {model_name} not found in available models")

    model_id = AVAILABLE_MODELS[model_name]

    # Check if model is already cached
    if model_id not in model_cache:
        try:
            model_cache[model_id] = pipeline(
                "text-generation",
                model=model_id,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )
        except Exception:
            # Fallback to CPU if GPU fails
            model_cache[model_id] = pipeline(
                "text-generation",
                model=model_id,
                torch_dtype=torch.float32,
                device_map="cpu",
                trust_remote_code=True
            )

    return model_cache[model_id]
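
# A minimal usage sketch (not part of the app's control flow): calling
# initialize_model twice with the same name should hit the cache the second
# time, assuming the model weights can be downloaded in the current environment.
#
#   pipe = initialize_model("Apollo-1-2B")   # downloads and caches the pipeline
#   pipe = initialize_model("Apollo-1-2B")   # returns the cached pipeline
#   out = pipe("Hello", max_new_tokens=16)
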
@spaces.GPU
def generate_response(message, history, model_name, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a response using the selected model."""
    # Initialize the model inside the GPU-decorated function
    try:
        model_pipe = initialize_model(model_name)
    except Exception as e:
        return f"Error loading model {model_name}: {str(e)}"

    # Format the conversation history
    messages = []

    # Add conversation history
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Add current message
    messages.append({"role": "user", "content": message})

    # Generate response
    try:
        # Some models may not support the chat-messages format, so try it first
        # and fall back to a plain-text prompt if it fails
        try:
            response = model_pipe(
                messages,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=model_pipe.tokenizer.eos_token_id,
                return_full_text=False
            )
        except Exception:
            # Fallback to simple text format
            conversation_text = ""
            for msg in messages:
                if msg["role"] == "user":
                    conversation_text += f"User: {msg['content']}\n"
                else:
                    conversation_text += f"Assistant: {msg['content']}\n"
            conversation_text += "Assistant:"

            response = model_pipe(
                conversation_text,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=model_pipe.tokenizer.eos_token_id,
                return_full_text=False
            )

        # Extract the generated text
        if isinstance(response, list) and len(response) > 0:
            generated_text = response[0]['generated_text']
        else:
            generated_text = str(response)

        # Clean up the response
        if isinstance(generated_text, list):
            # Chat-style output: take the content of the last message
            assistant_response = generated_text[-1]['content']
        else:
            # Remove the prompt and extract the assistant response
            assistant_response = str(generated_text).strip()
            if "Assistant:" in assistant_response:
                assistant_response = assistant_response.split("Assistant:")[-1].strip()

        return assistant_response
    except Exception as e:
        return f"Error generating response: {str(e)}"
@spaces.GPU
def generate(
    model: str,
    user_input: str,
    history: Optional[str] = "",
    temperature: float = 0.7,
    system_prompt: Optional[str] = "",
    max_tokens: int = 512
):
    """
    API endpoint for LLM generation.

    Args:
        model: Model name to use (one of the keys in AVAILABLE_MODELS,
            e.g. Apollo-1-4B or Daedalus-1-8B)
        user_input: Current user message/input
        history: JSON string of conversation history in the format
            [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
        temperature: Temperature for generation (0.1-2.0)
        system_prompt: System prompt to guide the model
        max_tokens: Maximum tokens to generate (1-8192)

    Returns:
        Generated response from the model
    """
    # Validate model
    if model not in AVAILABLE_MODELS:
        return f"Error: Model {model} not available. Available models: {list(AVAILABLE_MODELS.keys())}"

    # Initialize (and cache) the model up front so load errors are reported early
    try:
        model_pipe = initialize_model(model)
    except Exception as e:
        return f"Error loading model {model}: {str(e)}"

    # Parse history if provided and convert it to the Gradio pair format
    gradio_history = []
    if history and history.strip():
        try:
            history_list = json.loads(history)
            current_pair = [None, None]
            for msg in history_list:
                if isinstance(msg, dict) and "role" in msg and "content" in msg:
                    if msg["role"] == "user":
                        if current_pair[0] is not None:
                            gradio_history.append([current_pair[0], current_pair[1]])
                        current_pair = [msg["content"], None]
                    elif msg["role"] == "assistant":
                        current_pair[1] = msg["content"]
            if current_pair[0] is not None:
                gradio_history.append([current_pair[0], current_pair[1]])
        except Exception:
            # If history parsing fails, continue without history
            pass

    # Prepend the system prompt to the user input if provided
    final_user_input = user_input
    if system_prompt and system_prompt.strip():
        final_user_input = f"System: {system_prompt}\n\nUser: {user_input}"

    # Delegate to the chat generation function
    return generate_response(final_user_input, gradio_history, model, max_tokens, temperature, 0.9)
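
# A hypothetical direct call to `generate` (e.g. for a local smoke test), with
# `history` passed as the JSON string described in the docstring. Parameter
# values are illustrative only.
#
#   reply = generate(
#       model="Apollo-1-2B",
#       user_input="Summarize what you can do.",
#       history='[{"role": "user", "content": "Hi"}, '
#               '{"role": "assistant", "content": "Hello! How can I help?"}]',
#       temperature=0.7,
#       system_prompt="You are a helpful assistant.",
#       max_tokens=256,
#   )
#   print(reply)
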
# Create the Gradio interface
def create_interface():
    with gr.Blocks(title="Multi-Model Chat") as demo:
        gr.Markdown("""
# 🚀 Loom Labs Model Chat Interface
Chat with the models by Loom Labs.

**Available Models:**
- Apollo-1-4B (4 billion parameters)
- Apollo-1-8B (8 billion parameters)
- Apollo-1-2B (2 billion parameters)
- Daedalus-1-2B (2 billion parameters)
- Daedalus-1-8B (8 billion parameters)
""")
        with gr.Row():
            model_selector = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value="Apollo-1-4B",
                label="Select Model",
                info="Choose which model to use for generation"
            )

        chatbot = gr.Chatbot(
            height=400,
            placeholder="Select a model and start chatting...",
            label="Chat"
        )

        msg = gr.Textbox(
            placeholder="Type your message here...",
            label="Message",
            lines=2
        )

        with gr.Row():
            submit_btn = gr.Button("Send", variant="primary")
            clear_btn = gr.Button("Clear Chat", variant="secondary")
        with gr.Accordion("Advanced Settings", open=False):
            max_length = gr.Slider(
                minimum=200,
                maximum=8192,
                value=2048,
                step=50,
                label="Max Length",
                info="Maximum length of generated response"
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Controls randomness in generation"
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.1,
                label="Top P",
                info="Controls diversity via nucleus sampling"
            )
        # Event handlers
        def user_message(message, history):
            return "", history + [[message, None]]

        def bot_response(history, model_name, max_len, temp, top_p):
            if history:
                user_message = history[-1][0]
                bot_message = generate_response(
                    user_message,
                    history[:-1],
                    model_name,
                    max_len,
                    temp,
                    top_p
                )
                history[-1][1] = bot_message
            return history

        def model_changed(model_name):
            return gr.update(placeholder=f"Chat with {model_name}...")

        # Wire up the events
        msg.submit(user_message, [msg, chatbot], [msg, chatbot]).then(
            bot_response, [chatbot, model_selector, max_length, temperature, top_p], chatbot
        )
        submit_btn.click(user_message, [msg, chatbot], [msg, chatbot]).then(
            bot_response, [chatbot, model_selector, max_length, temperature, top_p], chatbot
        )
        clear_btn.click(lambda: None, None, chatbot, queue=False)
        model_selector.change(model_changed, model_selector, chatbot)

    return demo

# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    # Enable API and launch
    demo.launch(share=True)