Spaces:
Running
Running
| import os | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| from ultralytics import YOLO | |
| import easyocr | |
| import pytesseract | |
| import keras_ocr | |
| import pandas as pd | |
| from PIL import Image | |
| import io | |
| import re | |
| from typing import List, Tuple, Union | |
| from datetime import datetime | |
| from transformers import TrOCRProcessor, VisionEncoderDecoderModel | |
| import torch | |
| from datetime import datetime | |
| import time | |
| from paddleocr import PaddleOCR | |
# Initialisation of models
def load_models():
    """Load every detector and OCR back-end and publish them as module globals.

    Called once at import time (right below), so the request handlers can use
    the models directly. EasyOCR and PaddleOCR are explicitly set to CPU mode.
    """
    global model_vehicle, model_plate
    global reader_easyocr, pipeline_kerasocr
    global processor_trocr, model_trocr, ocr_paddle

    # YOLOv8 detectors: COCO-pretrained vehicle model, then the custom plate model.
    model_vehicle = YOLO('models/yolov8n.pt')
    model_plate = YOLO('models/best.pt')

    # OCR back-ends selectable from the UI.
    reader_easyocr = easyocr.Reader(['en'], gpu=False)
    pipeline_kerasocr = keras_ocr.pipeline.Pipeline()
    trocr_checkpoint = 'microsoft/trocr-base-handwritten'
    processor_trocr = TrOCRProcessor.from_pretrained(trocr_checkpoint)
    model_trocr = VisionEncoderDecoderModel.from_pretrained(trocr_checkpoint)
    ocr_paddle = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False)


load_models()
# Regex patterns for European license plate layouts, keyed by country code.
# validate_european_plate() tries them in this insertion order and returns the
# FIRST match, so order matters:
#   - overlapping layouts resolve to the earlier entry, e.g. 'AB12345' matches
#     PL, NO, DK and CH but is reported as 'PL';
#   - the generic 'EU' pattern must stay last as the catch-all.
# Text is passed through clean_plate_text() first, which removes all
# whitespace, so the optional `\s?` pieces never actually see a space.
EUROPEAN_PATTERNS = {
    'FR': r'^(?:[A-Z]{2}-\d{3}-[A-Z]{2}|\d{2,4}\s?[A-Z]{2,3}\s?\d{2,4})$',  # France (current AA-123-AA and older digit/letter/digit layout)
    'DE': r'^[A-Z]{1,3}-[A-Z]{1,2}\s?\d{1,4}[EH]?$',  # Germany (requires a hyphen after the district code)
    'ES': r'^(\d{4}[A-Z]{3}|[A-Z]{1,2}\d{4}[A-Z]{2,3})$',  # Spain
    'IT': r'^[A-Z]{2}\s?\d{3}\s?[A-Z]{2}$',  # Italy
    'GB': r'^[A-Z]{2}\d{2}\s?[A-Z]{3}$',  # Great Britain
    'NL': r'^[A-Z]{2}-\d{3}-[A-Z]$',  # Netherlands
    'BE': r'^(1-[A-Z]{3}-\d{3}|\d-[A-Z]{3}-\d{3})$',  # Belgium (the '1-' alternative is already covered by '\d-')
    'PL': r'^[A-Z]{2,3}\s?\d{4,5}$',  # Poland
    'SE': r'^[A-Z]{3}\s?\d{3}$',  # Sweden
    'NO': r'^[A-Z]{2}\s?\d{5}$',  # Norway
    'FI': r'^[A-Z]{3}-\d{3}$',  # Finland
    'DK': r'^[A-Z]{2}\s?\d{2}\s?\d{3}$',  # Denmark
    'CH': r'^[A-Z]{2}\s?\d{1,6}$',  # Switzerland
    'AT': r'^[A-Z]{1,2}\s?\d{1,5}[A-Z]$',  # Austria
    'PT': r'^[A-Z]{2}-\d{2}-[A-Z]{2}$',  # Portugal
    'EU': r'^[A-Z0-9]{2,4}[-\s]?[A-Z0-9]{1,4}[-\s]?[A-Z0-9]{1,4}$'  # Generic European plate: 4-12 alphanumerics, optional '-' separators
}
def preprocess_image(image):
    """Binarize a plate crop to help OCR.

    The crop is converted to grayscale, smoothed with a 5x5 Gaussian blur and
    thresholded with Otsu's method. The binary result is returned as a
    3-channel RGB image because several OCR engines expect colour input.

    Args:
        image: numpy image. In this app crops come from the Gradio/PIL input,
            which is RGB; single-channel (2-D or HxWx1) and RGBA inputs are
            also accepted.

    Returns:
        np.ndarray of shape (H, W, 3) containing only the values 0 and 255.
    """
    if image.ndim == 2:
        gray = image
    elif image.shape[2] == 1:
        # Drop the singleton channel; keep memory contiguous for OpenCV.
        gray = np.ascontiguousarray(image[:, :, 0])
    elif image.shape[2] == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    else:
        # Inputs here are RGB (Gradio / PIL), not OpenCV's default BGR order;
        # BGR2GRAY would apply the red/blue luminance weights to the wrong channels.
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    blur = cv2.GaussianBlur(gray, (5, 5), 0)
    # With THRESH_OTSU the threshold argument (0) is ignored and chosen from the histogram.
    _, thresh = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return cv2.cvtColor(thresh, cv2.COLOR_GRAY2RGB)
def trocr_ocr(image):
    """Transcribe ``image`` with the TrOCR processor and model created in load_models().

    Returns:
        The decoded text of the first generated sequence, special tokens removed.
    """
    encoded = processor_trocr(image, return_tensors="pt")
    token_ids = model_trocr.generate(encoded.pixel_values)
    texts = processor_trocr.batch_decode(token_ids, skip_special_tokens=True)
    return texts[0]
# OCR engines accepted by read_license_plate().
_SUPPORTED_OCR_ENGINES = ('easyocr', 'pytesseract', 'kerasocr', 'trocr', 'paddleocr')


def _join_detections(detections):
    """Merge OCR fragments into one string.

    Each detection is a ``(box, text, confidence)`` triple; only ``text`` is used.
    Every fragment is upper-cased with its spaces removed, and fragments are
    joined with a single space. Returns None when there are no fragments.
    """
    fragments = [text.upper().replace(' ', '') for _, text, _ in detections]
    return " ".join(fragments) if fragments else None


def _run_ocr(image, ocr_engine):
    """Run one OCR engine on one image.

    Returns a list of ``(box, text, confidence)`` triples. Engines that only
    yield a single string are wrapped as ``[(None, text, None)]``.
    """
    if ocr_engine == 'easyocr':
        return reader_easyocr.readtext(image)
    if ocr_engine == 'pytesseract':
        # --psm 8: let Tesseract treat the image as a single word.
        text = pytesseract.image_to_string(image, config='--psm 8')
        return [(None, text.strip(), None)]
    if ocr_engine == 'kerasocr':
        words = pipeline_kerasocr.recognize([image])[0]
        # Each entry is (word, box). Sort by the box's leftmost x so the words are
        # concatenated left-to-right rather than in detector order.
        words = sorted(words, key=lambda word: float(np.asarray(word[1])[:, 0].min()))
        return [(None, ''.join(text for text, _ in words), None)]
    if ocr_engine == 'trocr':
        return [(None, trocr_ocr(image).strip(), None)]
    if ocr_engine == 'paddleocr':
        result = ocr_paddle.ocr(image)
        if result and result[0]:
            # Keep every recognised line (not just the first) so plates printed on
            # two lines are read completely, like with the other engines.
            return [(None, line[1][0], line[1][1]) for line in result[0]]
        return [(None, '', 0.0)]
    raise ValueError(f"OCR engine '{ocr_engine}' not supported.")


def read_license_plate(license_plate_crop, ocr_engine='easyocr'):
    """OCR a plate crop twice: once as-is and once after preprocess_image().

    Args:
        license_plate_crop: numpy image of the cropped plate.
        ocr_engine: one of 'easyocr', 'pytesseract', 'kerasocr', 'trocr', 'paddleocr'.

    Returns:
        (raw_text, preprocessed_text). Each is the engine's text fragments,
        upper-cased, de-spaced and joined with single spaces, or None when the
        engine returned no fragments. Both are None when the crop is None or empty.

    Raises:
        ValueError: if ``ocr_engine`` is not supported.
    """
    if ocr_engine not in _SUPPORTED_OCR_ENGINES:
        raise ValueError(f"OCR engine '{ocr_engine}' not supported.")
    if license_plate_crop is None or license_plate_crop.size == 0:
        # Degenerate detection boxes give zero-size crops, which the OCR engines
        # and cv2 (in preprocess_image) reject.
        return None, None
    if ocr_engine == 'kerasocr' and (license_plate_crop.ndim == 2 or license_plate_crop.shape[2] == 1):
        # keras-ocr needs a 3-channel image.
        license_plate_crop = cv2.cvtColor(license_plate_crop, cv2.COLOR_GRAY2RGB)
    detections_raw = _run_ocr(license_plate_crop, ocr_engine)
    detections_preprocessed = _run_ocr(preprocess_image(license_plate_crop), ocr_engine)
    return _join_detections(detections_raw), _join_detections(detections_preprocessed)
def clean_plate_text(text):
    """Normalize raw OCR output into a compact plate string.

    Upper-cases the text and keeps only ASCII letters, digits and hyphens;
    every whitespace character is removed. ``None`` becomes an empty string.
    """
    if text is None:
        return ''
    upper = text.upper()
    # First discard anything that is not a letter, digit, hyphen or whitespace...
    without_noise = re.sub(r'[^A-Z0-9\-\s]', '', upper)
    # ...then squeeze out the whitespace itself.
    compact = re.sub(r'\s+', '', without_noise)
    return compact.strip()
def validate_european_plate(text):
    """Match ``text`` against the known European plate layouts.

    Patterns are tried in EUROPEAN_PATTERNS insertion order and the first hit wins.

    Returns:
        ``(text, country_code)`` for the first matching pattern, else ``(None, None)``.
    """
    hits = (
        (text, country)
        for country, pattern in EUROPEAN_PATTERNS.items()
        if re.match(pattern, text)
    )
    # The generator is lazy, so evaluation stops at the first matching pattern.
    return next(hits, (None, None))
def post_process_ocr(raw_text, preprocessed_text):
    """Pick the best plate reading from the raw and preprocessed OCR results.

    Both readings are cleaned with clean_plate_text() and checked against the
    European layouts. A valid raw reading is preferred over a valid
    preprocessed one.

    Args:
        raw_text: OCR text from the unprocessed crop (may be None).
        preprocessed_text: OCR text from the preprocessed crop (may be None).

    Returns:
        ``(text, country, is_valid)``. For a match this is the validated text,
        its country code and True. Otherwise it is the cleaned raw reading,
        'Unknown' and False; the cleaned preprocessed reading is used instead
        when the raw one is empty.
    """
    cleaned_raw = clean_plate_text(raw_text)
    validated_raw, country_raw = validate_european_plate(cleaned_raw)
    if validated_raw:
        return validated_raw, country_raw, True

    cleaned_preprocessed = clean_plate_text(preprocessed_text)
    validated_preprocessed, country_preprocessed = validate_european_plate(cleaned_preprocessed)
    if validated_preprocessed:
        return validated_preprocessed, country_preprocessed, True

    # Neither reading matches a known layout. Prefer the raw reading, but do not
    # throw away the preprocessed one when the raw OCR produced nothing usable.
    return cleaned_raw or cleaned_preprocessed, 'Unknown', False
def _read_plates_in_region(region, offset_x, offset_y, annotated, ocr_engine, confidence_threshold):
    """Detect, OCR and annotate every license plate inside one image region.

    Args:
        region: image array the plate detector runs on (a car crop or the whole image).
        offset_x, offset_y: integer position of ``region``'s top-left corner in the
            full image, used to translate plate boxes into full-image coordinates.
        annotated: full-size image that boxes and labels are drawn on (modified in place).
        ocr_engine: OCR engine name forwarded to read_license_plate.
        confidence_threshold: minimum plate-detector score for a box to be kept.

    Returns:
        ``(plates_detected, cropped_plates)`` for this region.
    """
    plates_detected = []
    cropped_plates = []
    for result_plate in model_plate(region):
        for px1, py1, px2, py2, pscore, _pclass_id in result_plate.boxes.data.tolist():
            if pscore < confidence_threshold:
                continue  # Skip detections below the confidence threshold
            plate = region[int(py1):int(py2), int(px1):int(px2)]
            if plate.size == 0:
                continue  # Degenerate box: nothing to crop or read
            cropped_plates.append(plate)
            raw_result, preprocessed_result = read_license_plate(plate, ocr_engine=ocr_engine)
            if not (raw_result or preprocessed_result):
                continue
            validated_text, country, is_valid = post_process_ocr(raw_result, preprocessed_result)
            bx1, by1 = int(offset_x + px1), int(offset_y + py1)
            bx2, by2 = int(offset_x + px2), int(offset_y + py2)
            plates_detected.append({
                'raw_text': raw_result,
                'preprocessed_text': preprocessed_result,
                'validated_text': validated_text,
                'country': country,
                'is_valid': is_valid,
                'bbox': [bx1, by1, bx2, by2],
            })
            cv2.rectangle(annotated, (bx1, by1), (bx2, by2), (0, 255, 0), 2)
            if validated_text:
                cv2.putText(annotated, f"{validated_text} ({country})", (bx1, by1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
    return plates_detected, cropped_plates


def detect_and_recognize_plates(image, ocr_engine='easyocr', confidence_threshold=0.5):
    """Find license plates in ``image``, OCR them and draw the results.

    Cars (COCO class 2) are detected first and the plate detector runs inside
    each car box. If no car passes the threshold, the plate detector runs on the
    whole image instead.

    Args:
        image: numpy image (presumably RGB, as passed by process_image from
            Gradio/PIL). It is not modified.
        ocr_engine: OCR engine name forwarded to read_license_plate.
        confidence_threshold: minimum score for both vehicle and plate detections.

    Returns:
        ``(annotated_image, plates_detected, cropped_plates)``:
        - annotated_image: an annotated copy of ``image``.
        - plates_detected: one dict per plate whose OCR produced text.
        - cropped_plates: every non-empty plate crop.
    """
    # Draw on a copy. The crops below are views into `image`, so annotating it in
    # place would paint boxes/labels into the gallery crops, into later overlapping
    # crops that get OCR'd, and into the caller's array (which compare_ocr_engines
    # reuses across engines).
    annotated = image.copy()
    plates_detected = []
    cropped_plates = []
    vehicles_found = False

    # NOTE(review): Ultralytics treats numpy inputs as BGR while Gradio supplies RGB;
    # confirm whether a channel swap is wanted before inference.
    for result in model_vehicle(image):
        for x1, y1, x2, y2, score, class_id in result.boxes.data.tolist():
            if score < confidence_threshold:
                continue  # Skip detections below the confidence threshold
            if int(class_id) != 2:  # Class ID 2 represents cars in COCO dataset
                continue
            vehicles_found = True
            # Use the integer crop origin so plate boxes map back to the exact pixels.
            vx1, vy1 = int(x1), int(y1)
            vehicle = image[vy1:int(y2), vx1:int(x2)]
            if vehicle.size == 0:
                continue
            plates, crops = _read_plates_in_region(
                vehicle, vx1, vy1, annotated, ocr_engine, confidence_threshold
            )
            plates_detected.extend(plates)
            cropped_plates.extend(crops)

    if not vehicles_found:
        # Fallback: no car detected, so look for plates in the whole image.
        plates, crops = _read_plates_in_region(
            image, 0, 0, annotated, ocr_engine, confidence_threshold
        )
        plates_detected.extend(plates)
        cropped_plates.extend(crops)

    return annotated, plates_detected, cropped_plates
def process_image(input_image, ocr_engine='easyocr', confidence_threshold=0.5) -> Tuple[Union[np.ndarray, None], pd.DataFrame, List[np.ndarray]]:
    """Gradio callback for the single-image tab.

    Args:
        input_image: numpy array or PIL image supplied by the UI.
        ocr_engine: OCR engine name passed through to the pipeline.
        confidence_threshold: minimum detector score for vehicles and plates.

    Returns:
        ``(annotated_image, results_table, cropped_plates)``. If anything raises,
        the error is printed and ``(None, DataFrame({"Error": [message]}), [])``
        is returned, so the UI shows the message instead of failing.
    """
    try:
        if isinstance(input_image, Image.Image):
            image_np = np.array(input_image)
        elif isinstance(input_image, np.ndarray):
            image_np = input_image
        else:
            raise ValueError("Unsupported image type")

        annotated_image, plates, cropped_plates = detect_and_recognize_plates(
            image_np, ocr_engine=ocr_engine, confidence_threshold=confidence_threshold
        )

        rows = [
            {
                "Plate Number": number,
                "Validated Text": plate['validated_text'],
                "Country": plate['country'],
                "Valid": "Yes" if plate['is_valid'] else "No",
                "Raw OCR": plate['raw_text'],
                "Preprocessed OCR": plate['preprocessed_text'],
            }
            for number, plate in enumerate(plates, start=1)
        ]
        if not rows:
            return annotated_image, pd.DataFrame({"Message": ["No license plates detected"]}), cropped_plates
        return annotated_image, pd.DataFrame(rows), cropped_plates
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return None, pd.DataFrame({"Error": [str(e)]}), []
def compare_ocr_engines(image):
    """Run the full ALPR pipeline once per OCR engine and tabulate the results.

    Args:
        image: numpy array or PIL image from the comparison tab.

    Returns:
        pd.DataFrame with one row per engine and the columns 'OCR Engine',
        'Processing Time (s)', 'Plates Detected' and 'Detected Texts'. An engine
        whose run failed shows 0 plates and no texts.
    """
    # Same engines, in the same order, as the single-image tab's selector.
    ocr_engines = ['easyocr', 'paddleocr', 'pytesseract', 'kerasocr', 'trocr']
    rows = []
    for engine in ocr_engines:
        # Give each run its own copy: the pipeline may annotate the array it
        # receives, and those drawings must not leak into the next engine's input.
        engine_input = image.copy() if isinstance(image, np.ndarray) else image
        start_time = time.perf_counter()  # monotonic clock, meant for durations
        _, df, _ = process_image(engine_input, ocr_engine=engine)
        elapsed = time.perf_counter() - start_time
        texts = df['Validated Text'].tolist() if 'Validated Text' in df.columns else []
        rows.append({
            'OCR Engine': engine,
            'Processing Time (s)': elapsed,
            'Plates Detected': len(df) if 'Plate Number' in df.columns else 0,
            'Detected Texts': ', '.join(str(text) for text in texts),
        })
    return pd.DataFrame(rows)
# Gradio app: a single-image tab and an OCR-engine comparison tab.
# Launched only when the file is run as a script.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🚗 ALPR YOLOv8 and Multi-OCR 🚗
        Test this ALPR solution using YOLOv8 and various OCR engines!
        > Best results come from high-quality images in which the plate is horizontal and clearly visible.
        """
    )
    with gr.Tabs():
        with gr.TabItem("Single Image Processing"):
            with gr.Accordion("How It Works", open=False):
                gr.Markdown(
                    """
                    This ALPR (Automatic License Plate Recognition) system works in several steps:
                    1. Vehicle Detection: a YOLOv8 model pretrained on the MS-COCO dataset detects cars in the image. If no car is found, the whole image is searched for plates instead.
                    2. License Plate Detection: a custom YOLOv8 model locates the license plate regions within each detected car, which are then cropped.
                    3. Preprocessing: each cropped plate is also converted to grayscale, blurred and binarized, which can give better results in some situations.
                    4. OCR: the selected OCR engine reads the text on both the original and the preprocessed crop.
                    5. Post-processing: the detected text is cleaned and validated against known European license plate patterns.
                    """
                )
            with gr.Accordion("OCR Engines", open=False):
                gr.Markdown(
                    """
                    The system supports multiple OCR engines:
                    - [EasyOCR](https://github.com/JaidedAI/EasyOCR): General-purpose OCR library with good accuracy.
                    - [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR): OCR toolkit from PaddlePaddle combining text detection, angle classification and recognition.
                    - [Pytesseract](https://github.com/madmaze/pytesseract): Python wrapper for the open-source Tesseract OCR engine.
                    - [Keras-OCR](https://github.com/faustomorales/keras-ocr): Deep learning-based OCR solution.
                    - [TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr): Transformer-based OCR model for handwritten and printed text.
                    Each engine has its strengths and may perform differently depending on the image quality and license plate style.
                    """
                )
            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(type="numpy", label="Input image")
                    ocr_selector = gr.Radio(choices=['easyocr', 'paddleocr', 'pytesseract', 'kerasocr', 'trocr'], value='easyocr', label="Select OCR Engine")
                    confidence_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.01, label="Detection Confidence Threshold")
                    submit_btn = gr.Button("Detect License Plates", variant="primary")
                with gr.Column(scale=1):
                    output_image = gr.Image(type="numpy", label="Annotated image")
                    cropped_plate_gallery = gr.Gallery(label="Cropped plates")
            output_table = gr.Dataframe(label="Detection results")
            with gr.Accordion("Understanding the Results", open=False):
                gr.Markdown(
                    """
                    The results table provides the following information:
                    - Plate Number: Sequential number assigned to each detected plate.
                    - Validated Text: The cleaned license plate text (validated when it matches a known format).
                    - Country: Estimated country of origin based on the plate format ("Unknown" when no format matches).
                    - Valid: Indicates whether the plate matches a known format.
                    - Raw OCR: The text read by the OCR engine from the unprocessed plate crop.
                    - Preprocessed OCR: The text read from the plate crop after preprocessing.
                    The confidence threshold is the minimum detector score a vehicle or plate detection needs in order to be kept.
                    """
                )
        with gr.TabItem("OCR Engine Comparison"):
            with gr.Row():
                comparison_input = gr.Image(type="numpy", label="Input Image for Comparison")
                compare_btn = gr.Button("Compare OCR Engines")
            comparison_output = gr.Dataframe(label="OCR Engine Comparison Results")

    # Event handlers
    submit_btn.click(
        fn=process_image,
        inputs=[input_image, ocr_selector, confidence_slider],
        outputs=[output_image, output_table, cropped_plate_gallery]
    )
    compare_btn.click(
        fn=compare_ocr_engines,
        inputs=[comparison_input],
        outputs=[comparison_output]
    )

if __name__ == "__main__":
    demo.launch()