Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import cv2 | |
| import numpy as np | |
| import onnxruntime as ort | |
| from PIL import Image | |
| import tempfile | |
| import torch | |
| from ultralytics import YOLO | |
| # Load models | |
def load_models():
    """Instantiate the three models the app relies on.

    Returns:
        A ``(license_plate_detector, vehicle_detector, ort_session)`` tuple:
        a YOLO model built from ``license_plate_detector.pt``, a YOLO model
        built from ``yolov8n.pt``, and an ONNX Runtime inference session
        created from ``model.onnx``. All three paths are relative to the
        current working directory.
    """
    # Tuple elements are evaluated left to right, so the load order is the
    # same as building each model in its own statement.
    return (
        YOLO('license_plate_detector.pt'),
        YOLO('yolov8n.pt'),
        ort.InferenceSession("model.onnx"),
    )
def draw_border(img, top_left, bottom_right, color=(0, 255, 0), thickness=10, line_length_x=200, line_length_y=200):
    """Draw L-shaped corner marks for a rectangle onto ``img`` in place.

    Args:
        img: Image passed straight to ``cv2.line``; it is modified in place.
        top_left: ``(x, y)`` of the rectangle's top-left corner.
        bottom_right: ``(x, y)`` of the rectangle's bottom-right corner.
        color: Line colour tuple handed to ``cv2.line``.
        thickness: Line thickness in pixels.
        line_length_x: Length of each horizontal corner stroke.
        line_length_y: Length of each vertical corner stroke.

    Returns:
        The same ``img`` object that was passed in.
    """
    left, top = top_left
    right, bottom = bottom_right

    # Each entry is (corner_x, corner_y, x_direction, y_direction); the
    # directions point from the corner towards the inside of the rectangle.
    corners = (
        (left, top, 1, 1),        # top-left
        (left, bottom, 1, -1),    # bottom-left
        (right, top, -1, 1),      # top-right
        (right, bottom, -1, -1),  # bottom-right
    )
    for corner_x, corner_y, step_x, step_y in corners:
        corner = (corner_x, corner_y)
        vertical_end = (corner_x, corner_y + step_y * line_length_y)
        horizontal_end = (corner_x + step_x * line_length_x, corner_y)
        cv2.line(img, corner, vertical_end, color, thickness)
        cv2.line(img, corner, horizontal_end, color, thickness)
    return img
def _run_plate_ocr(ort_session, license_crop):
    """Run the ONNX model on one plate crop and return the plate text."""
    # Resize to 640x640, move the channel axis first, scale to [0, 1] and add
    # a leading batch axis before handing the array to the session.
    resized = cv2.resize(license_crop, (640, 640))
    tensor = np.transpose(resized, (2, 0, 1)).astype(np.float32) / 255.0
    tensor = np.expand_dims(tensor, axis=0)

    input_name = ort_session.get_inputs()[0].name
    ort_session.run(None, {input_name: tensor})

    # Placeholder: the model outputs are not decoded yet, so a fixed string is
    # returned. Replace with real decoding for the model's output format.
    return "ABC123"


def _overlay_plate(frame, license_crop, license_number, x1, y1, x2):
    """Paste an enlarged plate crop and its text above the vehicle box."""
    crop_h, crop_w = license_crop.shape[:2]
    # Scale the crop to a height of 400 px; never let the width collapse to 0,
    # which cv2.resize rejects.
    display = cv2.resize(license_crop, (max(1, int(crop_w * 400 / crop_h)), 400))
    disp_h, disp_w = display.shape[:2]

    center_x = int((x1 + x2) / 2)
    left = int(center_x - disp_w / 2)
    right = int(center_x + disp_w / 2)
    bottom = y1 - 100
    top = bottom - disp_h

    # Skip the overlay when it does not fit inside the frame. A negative slice
    # bound would otherwise wrap around to the opposite edge of the image or
    # produce a region whose shape does not match the crop.
    if top < 0 or left < 0 or right > frame.shape[1] or right - left != disp_w:
        return

    frame[top:bottom, left:right] = display

    # White background for the text, directly above the pasted crop.
    cv2.rectangle(frame, (left, top - 300), (right, top), (255, 255, 255), -1)

    (text_width, text_height), _ = cv2.getTextSize(
        license_number, cv2.FONT_HERSHEY_SIMPLEX, 4.3, 17)
    cv2.putText(
        frame,
        license_number,
        (int(center_x - text_width / 2), int(top - 150 + text_height / 2)),
        cv2.FONT_HERSHEY_SIMPLEX,
        4.3,
        (0, 0, 0),
        17)


def process_frame(frame, license_plate_detector, vehicle_detector, ort_session):
    """Annotate one frame with vehicle borders, plate boxes and plate text.

    Vehicles are detected on the frame, then plates are detected inside each
    vehicle crop, and each plate crop is passed to the ONNX session. All
    detections with a score of 0.5 or lower are ignored.

    Args:
        frame: Image array to annotate; it is modified in place.
        license_plate_detector: Callable model returning results whose first
            element exposes ``boxes.data`` rows of six values.
        vehicle_detector: Callable model with the same result layout,
            invoked with ``classes=[2, 3, 5, 7]``.
        ort_session: ONNX Runtime session used for the plate OCR step.

    Returns:
        The same ``frame`` object, annotated.
    """
    # Detection and OCR must see unannotated pixels, so every crop is taken
    # from a pristine copy while all drawing happens on ``frame``.
    clean = frame.copy()

    # Detect vehicles
    vehicle_results = vehicle_detector(clean, classes=[2, 3, 5, 7])  # cars, motorcycles, bus, trucks

    for vehicle in vehicle_results[0].boxes.data:
        x1, y1, x2, y2, score, _class_id = vehicle.tolist()
        if score <= 0.5:  # Confidence threshold
            continue
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)

        draw_border(frame,
                    (x1, y1),
                    (x2, y2),
                    color=(0, 255, 0),
                    thickness=25,
                    line_length_x=200,
                    line_length_y=200)

        # Detect license plates in the vehicle region
        vehicle_crop = clean[y1:y2, x1:x2]
        if vehicle_crop.size == 0:
            continue
        license_results = license_plate_detector(vehicle_crop)

        for license_plate in license_results[0].boxes.data:
            lp_x1, lp_y1, lp_x2, lp_y2, lp_score, _ = license_plate.tolist()
            if lp_score <= 0.5:
                continue

            # Plate coordinates are relative to the vehicle crop, whose origin
            # is the integer (x1, y1) used for the slice above.
            abs_lp_x1 = int(x1 + lp_x1)
            abs_lp_y1 = int(y1 + lp_y1)
            abs_lp_x2 = int(x1 + lp_x2)
            abs_lp_y2 = int(y1 + lp_y2)

            cv2.rectangle(frame,
                          (abs_lp_x1, abs_lp_y1),
                          (abs_lp_x2, abs_lp_y2),
                          (0, 0, 255), 12)

            license_crop = clean[abs_lp_y1:abs_lp_y2, abs_lp_x1:abs_lp_x2]
            if license_crop.size == 0:
                continue

            # Inference can fail in model-specific ways; report the failure in
            # the UI and carry on with the next plate.
            try:
                license_number = _run_plate_ocr(ort_session, license_crop)
            except Exception as e:
                st.error(f"Error in OCR processing: {str(e)}")
                continue

            try:
                _overlay_plate(frame, license_crop, license_number, x1, y1, x2)
            except Exception as e:
                st.error(f"Error displaying results: {str(e)}")

    return frame
def process_video(video_path, license_plate_detector, vehicle_detector, ort_session):
    """Run ``process_frame`` over every frame of a video and save the result.

    Args:
        video_path: Path of the input video, opened with ``cv2.VideoCapture``.
        license_plate_detector: Forwarded to ``process_frame``.
        vehicle_detector: Forwarded to ``process_frame``.
        ort_session: Forwarded to ``process_frame``.

    Returns:
        Path of a new ``.mp4`` temporary file holding the annotated video.
        The file is not deleted automatically.

    Raises:
        ValueError: If the input video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video: {video_path}")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep the fractional frame rate (e.g. 29.97) so playback speed is not
    # altered; fall back to 30 when the container reports no usable value.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps or fps <= 0:
        fps = 30.0
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Only the path is needed: close the handle so the writer is the sole
    # owner of the file.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
    temp_file.close()

    # NOTE(review): 'mp4v' output may not play in every browser - confirm that
    # st.video can display it on the target deployment.
    out = cv2.VideoWriter(temp_file.name,
                          cv2.VideoWriter_fourcc(*'mp4v'),
                          fps,
                          (width, height))
    progress_bar = st.progress(0)
    frame_count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            processed_frame = process_frame(frame, license_plate_detector, vehicle_detector, ort_session)
            out.write(processed_frame)
            frame_count += 1
            # The reported frame count can be zero or an underestimate, so
            # guard the division and clamp to the 1.0 maximum of st.progress.
            if total_frames > 0:
                progress_bar.progress(min(frame_count / total_frames, 1.0))
    finally:
        # Release the capture and writer even when a frame fails to process.
        cap.release()
        out.release()
        progress_bar.empty()
    return temp_file.name
# Streamlit UI
st.title("Advanced Vehicle and License Plate Detection")

# Only the model loading is covered by the "loading models" message; upload
# and processing failures are reported separately further down.
try:
    license_plate_detector, vehicle_detector, ort_session = load_models()
except Exception as e:
    st.error(f"Error loading models: {str(e)}")
else:
    uploaded_file = st.file_uploader("Choose an image or video file", type=["jpg", "jpeg", "png", "mp4"])
    if uploaded_file is not None:
        # The MIME type looks like "image/png" or "video/mp4"; keep the family.
        file_type = uploaded_file.type.split('/')[0]
        try:
            if file_type == "image":
                image = Image.open(uploaded_file)
                st.image(image, caption="Uploaded Image", use_column_width=True)
                if st.button("Detect"):
                    with st.spinner("Processing image..."):
                        # Force three channels first: RGBA, palette and
                        # grayscale uploads would break the RGB->BGR conversion.
                        image_cv = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
                        processed_image = process_frame(image_cv, license_plate_detector, vehicle_detector, ort_session)
                        processed_image = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
                        st.image(processed_image, caption="Processed Image", use_column_width=True)
            elif file_type == "video":
                # Leaving the with-block flushes and closes the file, so the
                # readers below see the complete upload on disk.
                with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tfile:
                    tfile.write(uploaded_file.read())
                st.video(tfile.name)
                if st.button("Detect"):
                    with st.spinner("Processing video..."):
                        processed_video = process_video(tfile.name, license_plate_detector, vehicle_detector, ort_session)
                        st.video(processed_video)
        except Exception as e:
            st.error(f"Error processing file: {str(e)}")