Javedalam committed · Commit adfd866 · verified · Parent: 9a1e742

Create app.py

Files changed (1): app.py (+118, -16)
app.py CHANGED
@@ -1,22 +1,124 @@
-from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import gradio as gr
+
+MODEL_ID = "qvac/genesis-i-model"  # HF repo id
+
+print("Loading tokenizer...")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+
+print("Detecting device & dtype...")
+if torch.cuda.is_available():
+    # Prefer BF16 on modern GPUs, else fall back to FP16
+    try:
+        bf16_ok = torch.cuda.is_bf16_supported()
+    except AttributeError:
+        bf16_ok = False

-model_id = "qvac/genesisI-model"
+    torch_dtype = torch.bfloat16 if bf16_ok else torch.float16
+    device_map = "auto"
+else:
+    # CPU Space or no GPU: use full precision
+    torch_dtype = torch.float32
+    device_map = "cpu"

-tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+print(f"Loading model on {device_map} with dtype={torch_dtype}...")
 model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    torch_dtype=torch.bfloat16,  # or torch.float16 on T4
-    device_map="auto"
+    MODEL_ID,
+    torch_dtype=torch_dtype,
+    device_map=device_map,
 )
+model.eval()

-prompt = "Explain precision vs. recall in one paragraph."
-inputs = tok(prompt, return_tensors="pt").to(model.device)
-out = model.generate(
-    **inputs,
-    max_new_tokens=256,
-    do_sample=True,
-    top_p=0.9,
-    temperature=0.7,
-)
-print(tok.decode(out[0], skip_special_tokens=True))
+
+def generate(
+    prompt: str,
+    temperature: float = 0.7,
+    top_p: float = 0.9,
+    max_new_tokens: int = 256,
+):
+    if not prompt.strip():
+        return "Please enter a prompt."
+
+    inputs = tokenizer(prompt, return_tensors="pt")
+    # Move inputs to the same device as the model
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        output_ids = model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            temperature=temperature,
+            top_p=top_p,
+            pad_token_id=tokenizer.eos_token_id,
+        )
+
+    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+    # Return ONLY the completion after the original prompt, for cleanliness
+    if text.startswith(prompt):
+        text = text[len(prompt):].lstrip()
+
+    return text
+
+
+with gr.Blocks() as demo:
+    gr.Markdown(
+        """
+# QVAC Genesis I – Educational LLM Demo
+
+Model: **qvac/genesis-i-model**
+Trained on the QVAC Genesis I synthetic educational dataset (STEM-heavy).
+        """
+    )
+
+    with gr.Row():
+        with gr.Column(scale=3):
+            prompt = gr.Textbox(
+                label="Prompt",
+                placeholder="Ask a STEM question, e.g. 'Explain Gibbs free energy to a high school student.'",
+                lines=6,
+            )
+            temperature = gr.Slider(
+                minimum=0.1,
+                maximum=1.2,
+                value=0.7,
+                step=0.05,
+                label="Temperature (creativity)",
+            )
+            top_p = gr.Slider(
+                minimum=0.1,
+                maximum=1.0,
+                value=0.9,
+                step=0.05,
+                label="Top-p (nucleus sampling)",
+            )
+            max_new_tokens = gr.Slider(
+                minimum=16,
+                maximum=512,
+                value=256,
+                step=16,
+                label="Max new tokens",
+            )
+            submit = gr.Button("Generate")
+
+        with gr.Column(scale=4):
+            output = gr.Textbox(
+                label="Model output",
+                lines=18,
+            )
+
+    submit.click(
+        fn=generate,
+        inputs=[prompt, temperature, top_p, max_new_tokens],
+        outputs=output,
+    )
+
+    # Press Enter in the prompt box to generate
+    prompt.submit(
+        fn=generate,
+        inputs=[prompt, temperature, top_p, max_new_tokens],
+        outputs=output,
+    )
+
+demo.queue().launch()
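
Note: a Space running this app.py also needs its Python dependencies declared in a requirements.txt next to it. No requirements file is part of this commit, so the list below is only a sketch inferred from the imports (torch, transformers, gradio) and from the device_map argument to from_pretrained, which transformers delegates to accelerate; pin versions as appropriate for the target hardware.

    torch
    transformers
    accelerate  # assumed: needed because from_pretrained is called with device_map
    gradio

With these installed, running `python app.py` locally starts the Gradio UI on Gradio's default port (7860).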