# grammar-correction-api / local_test.py
# Last commit caaf797 by Enoch Jason J: "Modified app.py"
import torch
import re
import os
import textract
from fpdf import FPDF
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# --- Configuration ---
# Input document and generated report; both paths are local (relative) files.
INPUT_DOC_PATH = "Doreen.doc"
OUTPUT_PDF_PATH = "Doreen_DeFio_Report_Local_Test.pdf"

# --- Model identifiers ---
# Hugging Face model IDs, expected to be present in the local HF cache already
# (see download_models.py).
GENDER_MODEL_PATH = "google/gemma-3-270m-qat-q4_0-unquantized"
BASE_MODEL_PATH = "unsloth/gemma-2b-it"
# Local folder holding the fine-tuned LoRA adapter applied on top of BASE_MODEL_PATH.
LORA_ADAPTER_PATH = "gemma-grammar-lora"

# --- Runtime state ---
# Torch device every model and input tensor is moved to.
device = "cpu"
# Filled in by load_all_models(); None until loading succeeds.
grammar_model = None
grammar_tokenizer = None
gender_model = None
gender_tokenizer = None
# --- 1. Model Loading Logic (from main.py) ---
def load_all_models():
    """Load the gender-verifier model and the LoRA grammar model into module globals.

    Sets ``gender_tokenizer``/``gender_model`` from GENDER_MODEL_PATH, and
    ``grammar_tokenizer``/``grammar_model`` from BASE_MODEL_PATH with the LoRA
    adapter at LORA_ADAPTER_PATH applied. All models are moved to ``device``.
    Tokenizers that have no pad token get their EOS token as the pad token.

    Returns:
        True when every model loaded, False otherwise. Exceptions raised while
        loading are reported (message plus traceback) instead of propagated.
    """
    import traceback

    global grammar_model, grammar_tokenizer, gender_model, gender_tokenizer
    print("--- Starting Model Loading ---")
    try:
        print(f"Loading gender model from cache: {GENDER_MODEL_PATH}")
        gender_tokenizer = AutoTokenizer.from_pretrained(GENDER_MODEL_PATH)
        gender_model = AutoModelForCausalLM.from_pretrained(GENDER_MODEL_PATH).to(device)
        print("✅ Gender verifier model loaded successfully!")

        print(f"Loading base model for grammar correction from cache: {BASE_MODEL_PATH}")
        # NOTE: `dtype=` is the newer spelling; older transformers releases only
        # accept `torch_dtype=`.
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_PATH, dtype=torch.float32
        ).to(device)
        grammar_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
        print(f"Applying LoRA adapter from local folder: {LORA_ADAPTER_PATH}")
        grammar_model = PeftModel.from_pretrained(base_model, LORA_ADAPTER_PATH).to(device)
        print("✅ Grammar correction model loaded successfully!")

        # Some tokenizers ship without a pad token; reuse EOS so padding has a token.
        for tokenizer in (grammar_tokenizer, gender_tokenizer):
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
    except Exception as e:
        print(f"❌ Critical error during model loading: {e}")
        # The bare message often hides the cause (missing cache entry, bad adapter
        # path, ...); print the full traceback for debugging.
        traceback.print_exc()
        return False
    print("--- Model Loading Complete ---")
    return True
# --- 2. Correction Functions (adapted from main.py) ---
def run_grammar_correction(text: str, *, max_new_tokens: int = 64) -> str:
    """Return *text* with its grammar corrected by the LoRA-adapted model.

    The model is prompted as ``"Prompt: <text>\\nResponse:"`` and decoded
    greedily. Only the newly generated tokens are decoded, so a literal
    "Response:" inside *text* is never mistaken for the start of the answer.
    The answer is cut at any further "Response:" the model emits.

    Args:
        text: Sentence or paragraph to correct.
        max_new_tokens: Upper bound on generated tokens (64 by default). Raise
            it for long passages, which are otherwise truncated.

    Returns:
        The corrected text. *text* is returned unchanged when it is blank,
        when the grammar model is not loaded, or when the model produces no
        output (an empty answer must not wipe the original content).
    """
    # Blank input has nothing to correct; skip the (expensive) generate() call.
    if not text.strip() or grammar_model is None:
        return text
    prompt = f"Prompt: {text}\nResponse:"
    inputs = grammar_tokenizer(prompt, return_tensors="pt").to(device)
    output_ids = grammar_model.generate(
        **inputs, max_new_tokens=max_new_tokens, do_sample=False
    )
    # For a causal LM, generate() returns prompt + continuation: decode only the
    # continuation instead of re-parsing the echoed prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    answer = grammar_tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
    corrected = answer.split("Response:", 1)[0].strip()
    return corrected or text
# Regex safety net applied after the gender model: (compiled pattern, replacement).
# NOTE(review): rules such as "her wife" -> "her husband" also rewrite legitimate
# same-sex references; confirm these heuristics are intended for your documents.
_GENDER_FIXES = (
    (re.compile(r'\bher wife\b', re.IGNORECASE), 'her husband'),
    (re.compile(r'\bhis husband\b', re.IGNORECASE), 'his wife'),
    (re.compile(r'\bhe is a girl\b', re.IGNORECASE), 'he is a boy'),
    (re.compile(r'\bshe is a boy\b', re.IGNORECASE), 'she is a girl'),
)
# Labels the model sometimes echoes in front of its answer.
_ECHO_PREFIX = re.compile(r'^(Corrected sentence:|Correct:|Prompt:)\s*', re.IGNORECASE)


def _match_case(matched: str, replacement: str) -> str:
    """Re-case *replacement* to follow *matched* (ALL CAPS or a leading capital)."""
    if matched.isupper():
        return replacement.upper()
    if matched[:1].isupper():
        return replacement[:1].upper() + replacement[1:]
    return replacement


def run_gender_correction(text: str, *, max_new_tokens: int = 64) -> str:
    """Rewrite *text* with the gender-verifier model, then apply regex fixes.

    The model is asked to output only the corrected sentence. Only the newly
    generated tokens are decoded (cut at any further "Response:"), echoed
    labels such as "Corrected sentence:" and surrounding quotes are stripped,
    and the _GENDER_FIXES rules are applied. The rules keep the capitalization
    of the matched text, so "Her wife" becomes "Her husband".

    Args:
        text: Sentence or paragraph to check.
        max_new_tokens: Upper bound on generated tokens (64 by default).

    Returns:
        The corrected text. *text* is returned unchanged when it is blank,
        when the gender model is not loaded, or when nothing usable was produced.
    """
    if not text.strip() or gender_model is None:
        return text
    prompt = (
        "Prompt: Please rewrite the sentence with correct grammar and gender. "
        f"Output ONLY the corrected sentence:\n{text}\nResponse:"
    )
    inputs = gender_tokenizer(prompt, return_tensors="pt").to(device)
    # Greedy decoding. `temperature` is omitted because it is ignored (and only
    # warned about) when do_sample=False.
    output_ids = gender_model.generate(
        **inputs, max_new_tokens=max_new_tokens,
        do_sample=False, eos_token_id=gender_tokenizer.eos_token_id,
    )
    # Decode only the continuation so text containing "Response:" cannot confuse parsing.
    prompt_len = inputs["input_ids"].shape[-1]
    answer = gender_tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
    cleaned = answer.split("Response:", 1)[0].strip()
    cleaned = _ECHO_PREFIX.sub('', cleaned).strip().strip('"')
    for pattern, replacement in _GENDER_FIXES:
        cleaned = pattern.sub(
            lambda m, repl=replacement: _match_case(m.group(0), repl), cleaned
        )
    return cleaned or text
# --- 3. Document Processing Logic (from document_pipeline.py) ---
def extract_text_from_doc(filepath):
    """Read a document with textract and return its contents as a string.

    Args:
        filepath: Path of the document to read.

    Returns:
        The extracted text decoded as UTF-8, or None when reading or decoding
        raises (the error is printed rather than raised).
    """
    try:
        extracted = textract.process(filepath)
        text = extracted.decode('utf-8')
    except Exception as exc:
        print(f"Error reading document with textract: {exc}")
        text = None
    return text
# Field labels recognised by parse_and_correct_text.
# TODO: the original list was abridged; add the remaining labels from the full
# parsing script (document_pipeline.py).
_FIELD_LABELS = ("Client Name", "Date of Exam")
# "<label> : <value>" at the start of a line. Labels are escaped so they match literally.
_FIELD_PATTERN = re.compile(
    r'^\s*(' + '|'.join(map(re.escape, _FIELD_LABELS)) + r')\s*:\s*(.*)',
    re.IGNORECASE,
)


def parse_and_correct_text(raw_text):
    """Extract "Label: value" fields from *raw_text* and correct each value.

    Each line starting with one of _FIELD_LABELS (case-insensitive) followed by
    a colon is parsed. Its value goes through run_grammar_correction and then
    run_gender_correction, and is stored under the canonical label spelling.
    A later occurrence of a label overrides an earlier one. Blank values are
    stored as "" without calling the models.

    If no line matches, a small hard-coded demo record is corrected instead so
    the rest of the pipeline (PDF generation) can still be exercised.

    Args:
        raw_text: Full document text, e.g. from extract_text_from_doc().

    Returns:
        dict mapping field label -> corrected value, in document order.
    """
    canonical = {label.lower(): label for label in _FIELD_LABELS}
    structured_data = {}
    # splitlines() also handles "\r\n" and form-feed page breaks in extracted text.
    for line in raw_text.splitlines():
        match = _FIELD_PATTERN.match(line)
        if match is None:
            continue
        label = canonical[match.group(1).lower()]
        value = match.group(2).strip()
        structured_data[label] = (
            run_gender_correction(run_grammar_correction(value)) if value else ""
        )

    if not structured_data:
        # Demo data so the PDF step has something to render when nothing parsed.
        structured_data['Client Name'] = run_grammar_correction("Morgan & Morgan")
        structured_data['Intake'] = run_gender_correction(run_grammar_correction(
            "The IME physician asked the examinee if he has any issues sleeping. The examinee replied yes."
        ))
    return structured_data
class PDF(FPDF):
    """FPDF subclass that adds a centered report title and page numbers.

    The title is set in the bold DejaVu Sans TrueType face (loaded from
    'DejaVuSans-Bold.ttf' with uni=True) so non-Latin-1 text can render.
    """

    def header(self):
        """Draw the 'IME WatchDog Report' title, then leave a 10-unit gap."""
        # FPDF invokes header() for every new page. Register the bold font only
        # once instead of re-adding it on each page (fpdf2 warns on re-adds).
        if not getattr(self, "_dejavu_bold_registered", False):
            self.add_font('DejaVu', 'B', 'DejaVuSans-Bold.ttf', uni=True)
            self._dejavu_bold_registered = True
        self.set_font('DejaVu', 'B', 15)
        self.cell(0, 10, 'IME WatchDog Report', 0, 1, 'C')
        self.ln(10)

    def footer(self):
        """Draw a centered 'Page N' label 15 units above the bottom edge."""
        self.set_y(-15)
        self.set_font('Helvetica', 'I', 8)
        self.cell(0, 10, f'Page {self.page_no()}', 0, 0, 'C')
def generate_pdf(data, output_path):
    """Render *data* as a simple PDF report and write it to *output_path*.

    Each entry is written as a bold "key:" line followed by str(value) in
    regular weight, with 4 units of vertical space after each entry. Both
    weights use the DejaVu Sans TrueType font.

    Args:
        data: Mapping of section label -> content, rendered in iteration order.
        output_path: Destination path of the PDF file.
    """
    pdf = PDF()
    pdf.add_font('DejaVu', '', 'DejaVuSans.ttf', uni=True)
    # add_page() triggers PDF.header(), which registers the bold 'DejaVu' face used below.
    pdf.add_page()
    pdf.set_font('DejaVu', '', 12)
    for key, value in data.items():
        pdf.set_font('DejaVu', 'B', 12)
        pdf.multi_cell(0, 8, f"{key}:")
        pdf.set_font('DejaVu', '', 12)
        pdf.multi_cell(0, 8, str(value))
        pdf.ln(4)
    pdf.output(output_path)
    print(f"✅ Successfully generated PDF report at: {output_path}")
# --- Main Execution ---
if __name__ == "__main__":
    print("--- Starting Local Test Pipeline ---")
    # Prerequisite: the models are already in the local Hugging Face cache
    # (run download_models.py first).
    if not load_all_models():
        print("❌ Model loading failed; no report was generated.")
        raise SystemExit(1)

    # None means extraction raised; "" means the document had no text.
    raw_text = extract_text_from_doc(INPUT_DOC_PATH)
    if not raw_text:
        print(f"❌ No text could be extracted from {INPUT_DOC_PATH}; no report was generated.")
        raise SystemExit(1)

    corrected_data = parse_and_correct_text(raw_text)
    generate_pdf(corrected_data, OUTPUT_PDF_PATH)
    print("--- Pipeline Finished ---")