Spaces: Running
handrix committed
Commit · ae4e2a6
1 Parent(s): bc7497c
Initial deployment - Toxic Detection API
- .dockerignore +63 -0
- .env.example +16 -0
- .gitattributes +1 -0
- Dockerfile +49 -0
- app/__init__.py +0 -0
- app/api/__init__.py +0 -0
- app/api/routes.py +113 -0
- app/core/__init__.py +0 -0
- app/core/config.py +46 -0
- app/core/exceptions.py +44 -0
- app/main.py +115 -0
- app/models/__init__.py +0 -0
- app/models/model_loader.py +123 -0
- app/models/phobert_model.py +105 -0
- app/schemas/__init__.py +0 -0
- app/schemas/requests.py +52 -0
- app/schemas/responses.py +138 -0
- app/services/__init__.py +0 -0
- app/services/analysis_service.py +318 -0
- app/services/gradient_service.py +167 -0
- app/services/html_generator.py +130 -0
- app/services/text_processor.py +142 -0
- models/PhoBERTFineTuned_best.pth +3 -0
- requirements.txt +9 -0
.dockerignore
ADDED
@@ -0,0 +1,63 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual environments
+venv/
+ENV/
+env/
+.venv
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Git
+.git/
+.gitignore
+
+# Tests
+tests/
+.pytest_cache/
+.coverage
+htmlcov/
+
+# Documentation
+*.md
+!README.md
+
+# Environment
+.env
+.env.local
+
+# Logs
+*.log
+
+# Other
+*.tar.gz
+*.zip
.env.example
ADDED
@@ -0,0 +1,16 @@
+# Model Configuration
+MODEL_NAME=vinai/phobert-base
+MODEL_PATH=./models/PhoBERTFineTuned_best.pth
+MAX_LENGTH=128
+DEVICE=cuda
+
+# API Configuration
+API_HOST=0.0.0.0
+API_PORT=8000
+API_RELOAD=True
+
+# CORS
+ALLOWED_ORIGINS=*
+
+# Logging
+LOG_LEVEL=INFO
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+models/*.pth filter=lfs diff=lfs merge=lfs -text
Dockerfile
ADDED
@@ -0,0 +1,49 @@
+# Multi-stage build for smaller image size
+FROM python:3.10-slim as builder
+
+# Set working directory
+WORKDIR /app
+
+# Install build dependencies
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements
+COPY requirements.txt .
+
+# Install Python dependencies
+RUN pip install --no-cache-dir --user -r requirements.txt
+
+# Final stage
+FROM python:3.10-slim
+
+# Set working directory
+WORKDIR /app
+
+# Copy Python packages from builder
+COPY --from=builder /root/.local /root/.local
+
+# Make sure scripts in .local are usable
+ENV PATH=/root/.local/bin:$PATH
+
+# Copy application code
+COPY ./app ./app
+
+# Create models directory
+RUN mkdir -p ./models
+
+# Set environment variables
+ENV PYTHONUNBUFFERED=1
+ENV PYTHONDONTWRITEBYTECODE=1
+ENV MODEL_PATH=/app/models/PhoBERTFineTuned_best.pth
+
+# Expose port
+EXPOSE 7860
+
+# Health check
+HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
+    CMD python -c "import requests; requests.get('http://localhost:7860/api/v1/health')"
+
+# Run the application
+CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
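The HEALTHCHECK instruction probes GET /api/v1/health from inside the container. A minimal host-side sketch of the same probe, assuming the image has been built and started with port 7860 published on localhost and that the requests package is available on the host:

# Host-side probe mirroring the container HEALTHCHECK (sketch; assumes the
# container is running with port 7860 published on localhost).
import requests

def check_health(base_url: str = "http://localhost:7860") -> bool:
    resp = requests.get(f"{base_url}/api/v1/health", timeout=10)
    resp.raise_for_status()
    payload = resp.json()
    # Field names follow HealthResponse in app/schemas/responses.py
    print(payload["status"], payload["model_loaded"], payload["device"])
    return payload["status"] == "healthy"

if __name__ == "__main__":
    check_health()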
app/__init__.py
ADDED
File without changes
app/api/__init__.py
ADDED
File without changes
app/api/routes.py
ADDED
@@ -0,0 +1,113 @@
+"""
+API Routes
+==========
+FastAPI routes (Interface Segregation)
+"""
+
+from fastapi import APIRouter, Depends, HTTPException, status
+from typing import Dict
+
+from app.schemas.requests import AnalysisRequest
+from app.schemas.responses import AnalysisResponse, HealthResponse, ErrorResponse
+from app.services.analysis_service import analysis_service
+from app.models.model_loader import model_loader
+from app.core.config import settings
+from app.core.exceptions import ModelNotLoadedException, AnalysisException
+
+
+router = APIRouter()
+
+
+@router.get(
+    "/",
+    response_model=Dict[str, str],
+    summary="Root endpoint",
+    tags=["General"]
+)
+async def root():
+    """Root endpoint - API information"""
+    return {
+        "message": "Toxic Text Detection API",
+        "version": settings.API_VERSION,
+        "docs": "/docs",
+        "health": "/api/v1/health"
+    }
+
+
+@router.get(
+    "/health",
+    response_model=HealthResponse,
+    summary="Health check",
+    tags=["General"]
+)
+async def health_check():
+    """
+    Health check endpoint
+
+    Returns service status and model information
+    """
+    return HealthResponse(
+        status="healthy" if model_loader.is_loaded() else "unhealthy",
+        model_loaded=model_loader.is_loaded(),
+        device=str(model_loader.device) if model_loader.is_loaded() else "unknown",
+        model_name=settings.MODEL_NAME,
+        version=settings.API_VERSION
+    )
+
+
+@router.post(
+    "/analyze",
+    response_model=AnalysisResponse,
+    responses={
+        200: {"description": "Analysis successful"},
+        400: {"model": ErrorResponse, "description": "Invalid input"},
+        500: {"model": ErrorResponse, "description": "Analysis failed"},
+        503: {"model": ErrorResponse, "description": "Model not loaded"}
+    },
+    summary="Analyze text for toxicity",
+    tags=["Analysis"]
+)
+async def analyze_text(request: AnalysisRequest):
+    """
+    Analyze text for toxic content
+
+    This endpoint analyzes Vietnamese text to detect toxic content using
+    a fine-tuned PhoBERT model with gradient-based explainability.
+
+    **Features:**
+    - Sentence-level toxicity detection
+    - Word-level importance scores
+    - HTML highlighting of toxic content
+    - Detailed statistics
+
+    **Parameters:**
+    - **text**: Text to analyze (required)
+    - **include_html**: Include HTML highlighting (default: true)
+    - **include_word_scores**: Include word-level scores (default: true)
+    - **include_summary_table**: Include summary table (default: false)
+
+    **Returns:**
+    - Overall toxicity label (toxic/clean)
+    - Sentence-level analysis
+    - Word-level scores and toxic words summary
+    - HTML with highlighted toxic content
+    - Statistical information
+    """
+    # Check if model is loaded
+    if not model_loader.is_loaded():
+        raise ModelNotLoadedException()
+
+    # Perform analysis
+    try:
+        result = analysis_service.analyze(request)
+        return result
+    except AnalysisException as e:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=str(e)
+        )
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Unexpected error: {str(e)}"
+        )
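A hedged client sketch for the POST /analyze route, assuming the service is reachable at http://localhost:7860 (the port exposed by the Dockerfile); the request fields mirror AnalysisRequest and the printed keys mirror AnalysisResponse:

# Hypothetical client call against a locally running instance.
import requests

payload = {
    "text": "Đồ ngu ngốc, mất dạy! Cảm ơn bạn đã chia sẻ.",
    "include_html": True,
    "include_word_scores": True,
    "include_summary_table": False,
}
resp = requests.post("http://localhost:7860/api/v1/analyze", json=payload, timeout=30)
resp.raise_for_status()
result = resp.json()
print(result["overall_label"], result["toxic_sentence_count"], "of", result["total_sentences"])
for item in result["toxic_words_summary"]:
    print(item["word"], round(item["score"], 3), item["sentences"])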
app/core/__init__.py
ADDED
File without changes
app/core/config.py
ADDED
@@ -0,0 +1,46 @@
+"""
+Core Configuration
+==================
+Application settings using Pydantic Settings
+"""
+
+from pydantic_settings import BaseSettings
+from typing import List
+import torch
+
+
+class Settings(BaseSettings):
+    """Application settings"""
+
+    # Model Configuration
+    MODEL_NAME: str = "vinai/phobert-base"
+    MODEL_PATH: str = "./models/PhoBERTFineTuned_best.pth"
+    MAX_LENGTH: int = 128
+    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # API Configuration
+    API_TITLE: str = "Toxic Text Detection API"
+    API_VERSION: str = "1.0.0"
+    API_DESCRIPTION: str = "Vietnamese toxic text detection with gradient-based explainability"
+    API_HOST: str = "0.0.0.0"
+    API_PORT: int = 8000
+    API_RELOAD: bool = True
+
+    # CORS
+    ALLOWED_ORIGINS: List[str] = ["*"]
+
+    # Analysis Settings
+    GRADIENT_STEPS: int = 20
+    PERCENTILE_THRESHOLD: int = 75
+    MIN_WORD_LENGTH: int = 2
+
+    # Logging
+    LOG_LEVEL: str = "INFO"
+
+    class Config:
+        env_file = ".env"
+        case_sensitive = True
+
+
+# Singleton instance
+settings = Settings()
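Because Settings subclasses pydantic's BaseSettings, values come from class defaults, the .env file, or the process environment, with environment variables taking precedence. A small sketch (assumes pydantic-settings v2 and the project dependencies installed; the override path is hypothetical):

# Sketch of overriding settings via the environment (pydantic-settings v2).
import os

os.environ["MODEL_PATH"] = "/tmp/PhoBERTFineTuned_best.pth"  # hypothetical path
os.environ["LOG_LEVEL"] = "DEBUG"

from app.core.config import Settings

settings = Settings()  # environment values win over .env and class defaults
print(settings.MODEL_PATH)  # /tmp/PhoBERTFineTuned_best.pth
print(settings.LOG_LEVEL)   # DEBUG
print(settings.DEVICE)      # "cuda" only if torch.cuda.is_available()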
app/core/exceptions.py
ADDED
@@ -0,0 +1,44 @@
+"""
+Custom Exceptions
+=================
+Application-specific exceptions
+"""
+
+from fastapi import HTTPException, status
+
+
+class ToxicDetectionException(HTTPException):
+    """Base exception for toxic detection"""
+
+    def __init__(self, detail: str, status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR):
+        super().__init__(status_code=status_code, detail=detail)
+
+
+class ModelNotLoadedException(ToxicDetectionException):
+    """Raised when model is not loaded"""
+
+    def __init__(self):
+        super().__init__(
+            detail="Model not loaded. Please check server logs.",
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE
+        )
+
+
+class InvalidTextException(ToxicDetectionException):
+    """Raised when input text is invalid"""
+
+    def __init__(self, detail: str = "Invalid text input"):
+        super().__init__(
+            detail=detail,
+            status_code=status.HTTP_400_BAD_REQUEST
+        )
+
+
+class AnalysisException(ToxicDetectionException):
+    """Raised when analysis fails"""
+
+    def __init__(self, detail: str = "Analysis failed"):
+        super().__init__(
+            detail=detail,
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
+        )
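Since these exceptions subclass FastAPI's HTTPException, each carries its own status code and detail, which the handler in app/main.py later reshapes into the ErrorResponse format. A quick illustration, no server required:

# Quick look at what the custom exceptions carry.
from app.core.exceptions import ModelNotLoadedException, InvalidTextException

try:
    raise ModelNotLoadedException()
except ModelNotLoadedException as exc:
    print(exc.status_code, exc.detail)  # 503, "Model not loaded. Please check server logs."

try:
    raise InvalidTextException()
except InvalidTextException as exc:
    print(exc.status_code, exc.detail)  # 400, "Invalid text input"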
app/main.py
ADDED
@@ -0,0 +1,115 @@
+"""
+Main FastAPI Application
+=========================
+Application entry point
+"""
+
+import logging
+from contextlib import asynccontextmanager
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+
+from app.core.config import settings
+from app.core.exceptions import ToxicDetectionException
+from app.models.model_loader import model_loader
+from app.api.routes import router
+
+
+# Configure logging
+logging.basicConfig(
+    level=getattr(logging, settings.LOG_LEVEL),
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """
+    Lifespan events
+
+    Startup: Load model
+    Shutdown: Cleanup
+    """
+    # Startup
+    logger.info("Starting up...")
+    try:
+        logger.info("Loading model...")
+        model_loader.load()
+        logger.info("Model loaded successfully")
+    except Exception as e:
+        logger.error(f"Failed to load model: {str(e)}")
+        # Continue anyway - health endpoint will show model not loaded
+
+    yield
+
+    # Shutdown
+    logger.info("Shutting down...")
+
+
+# Create FastAPI app
+app = FastAPI(
+    title=settings.API_TITLE,
+    description=settings.API_DESCRIPTION,
+    version=settings.API_VERSION,
+    lifespan=lifespan,
+    docs_url="/docs",
+    redoc_url="/redoc",
+    openapi_url="/openapi.json"
+)
+
+# CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=settings.ALLOWED_ORIGINS,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+# Exception handlers
+@app.exception_handler(ToxicDetectionException)
+async def toxic_detection_exception_handler(request, exc: ToxicDetectionException):
+    """Handle custom exceptions"""
+    return JSONResponse(
+        status_code=exc.status_code,
+        content={
+            "success": False,
+            "error": exc.detail,
+            "detail": None
+        }
+    )
+
+
+@app.exception_handler(Exception)
+async def general_exception_handler(request, exc: Exception):
+    """Handle general exceptions"""
+    logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
+    return JSONResponse(
+        status_code=500,
+        content={
+            "success": False,
+            "error": "Internal server error",
+            "detail": str(exc) if settings.LOG_LEVEL == "DEBUG" else None
+        }
+    )
+
+
+# Include routers
+app.include_router(router, prefix="/api/v1", tags=["v1"])
+app.include_router(router, prefix="", tags=["root"])
+
+
+# For direct run
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(
+        "app.main:app",
+        host=settings.API_HOST,
+        port=settings.API_PORT,
+        reload=settings.API_RELOAD,
+        log_level=settings.LOG_LEVEL.lower()
+    )
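A sketch of an in-process smoke test with FastAPI's TestClient (assumes the test client's httpx dependency is installed; if the .pth weights are absent, startup continues per the lifespan handler and /health simply reports unhealthy):

# In-process smoke test (sketch).
from fastapi.testclient import TestClient

from app.main import app

with TestClient(app) as client:  # context manager runs the lifespan startup/shutdown
    print(client.get("/").json()["message"])          # Toxic Text Detection API
    health = client.get("/api/v1/health").json()
    print(health["status"], health["model_loaded"])   # "unhealthy", False without weights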
app/models/__init__.py
ADDED
File without changes
app/models/model_loader.py
ADDED
@@ -0,0 +1,123 @@
+"""
+Model Loader
+============
+Responsible for loading and initializing models (Single Responsibility)
+"""
+
+import torch
+from transformers import AutoModel, AutoTokenizer
+from typing import Tuple
+import logging
+
+from app.models.phobert_model import PhoBERTFineTuned
+from app.core.config import settings
+from app.core.exceptions import ModelNotLoadedException
+
+
+logger = logging.getLogger(__name__)
+
+
+class ModelLoader:
+    """
+    Model loader service
+
+    Responsibilities:
+    - Load tokenizer
+    - Load base model
+    - Load fine-tuned weights
+    - Initialize model on correct device
+    """
+
+    def __init__(self):
+        self._model: PhoBERTFineTuned | None = None
+        self._tokenizer: AutoTokenizer | None = None
+        self._device: torch.device | None = None
+
+    def load(self) -> Tuple[PhoBERTFineTuned, AutoTokenizer, torch.device]:
+        """
+        Load model, tokenizer, and set device
+
+        Returns:
+            model: Loaded model
+            tokenizer: Loaded tokenizer
+            device: Device (CPU/CUDA)
+
+        Raises:
+            ModelNotLoadedException: If loading fails
+        """
+        try:
+            # Set device
+            self._device = torch.device(settings.DEVICE)
+            logger.info(f"Using device: {self._device}")
+
+            # Load tokenizer
+            logger.info(f"Loading tokenizer: {settings.MODEL_NAME}")
+            self._tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
+
+            # Load base model
+            logger.info(f"Loading base model: {settings.MODEL_NAME}")
+            phobert = AutoModel.from_pretrained(settings.MODEL_NAME)
+
+            # Initialize fine-tuned model
+            logger.info("Initializing fine-tuned model")
+            self._model = PhoBERTFineTuned(
+                embedding_model=phobert,
+                hidden_dim=768,
+                dropout=0.3,
+                num_classes=2,
+                num_layers_to_finetune=4,
+                pooling='mean'
+            )
+
+            # Load weights
+            logger.info(f"Loading weights from: {settings.MODEL_PATH}")
+            state_dict = torch.load(
+                settings.MODEL_PATH,
+                map_location=self._device
+            )
+            self._model.load_state_dict(state_dict)
+
+            # Move to device and set eval mode
+            self._model = self._model.to(self._device)
+            self._model.eval()
+
+            logger.info("Model loaded successfully")
+
+            return self._model, self._tokenizer, self._device
+
+        except Exception as e:
+            logger.error(f"Failed to load model: {str(e)}")
+            raise ModelNotLoadedException()
+
+    @property
+    def model(self) -> PhoBERTFineTuned:
+        """Get loaded model"""
+        if self._model is None:
+            raise ModelNotLoadedException()
+        return self._model
+
+    @property
+    def tokenizer(self) -> AutoTokenizer:
+        """Get loaded tokenizer"""
+        if self._tokenizer is None:
+            raise ModelNotLoadedException()
+        return self._tokenizer
+
+    @property
+    def device(self) -> torch.device:
+        """Get device"""
+        if self._device is None:
+            raise ModelNotLoadedException()
+        return self._device
+
+    def is_loaded(self) -> bool:
+        """Check if model is loaded"""
+        return all([
+            self._model is not None,
+            self._tokenizer is not None,
+            self._device is not None
+        ])
+
+
+# Singleton instance
+model_loader = ModelLoader()
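A sketch of using the singleton outside the API, assuming the vinai/phobert-base weights can be downloaded and the fine-tuned checkpoint exists at settings.MODEL_PATH; class index 1 is the one treated as toxic elsewhere in the service:

# Direct use of the loader outside FastAPI (sketch).
import torch

from app.models.model_loader import model_loader

model, tokenizer, device = model_loader.load()

enc = tokenizer("Cảm ơn bạn đã chia sẻ.", return_tensors="pt",
                truncation=True, max_length=128)
with torch.no_grad():
    logits, _ = model(enc["input_ids"].to(device), enc["attention_mask"].to(device))
probs = torch.softmax(logits, dim=1)[0]
print("P(class 0), P(class 1):", probs.tolist())  # class 1 is treated as toxic downstream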
app/models/phobert_model.py
ADDED
@@ -0,0 +1,105 @@
+"""
+PhoBERT Model
+=============
+Model architecture definition (Single Responsibility)
+"""
+
+import torch
+import torch.nn as nn
+from typing import Tuple, Optional
+
+
+class PhoBERTFineTuned(nn.Module):
+    """
+    Fine-tuned PhoBERT model for toxic text classification
+
+    Responsibilities:
+    - Define model architecture
+    - Forward pass computation
+    """
+
+    def __init__(
+        self,
+        embedding_model: nn.Module,
+        hidden_dim: int = 768,
+        dropout: float = 0.3,
+        num_classes: int = 2,
+        num_layers_to_finetune: int = 4,
+        pooling: str = 'mean'
+    ):
+        super(PhoBERTFineTuned, self).__init__()
+
+        self.embedding = embedding_model
+        self.pooling = pooling
+        self.num_layers_to_finetune = num_layers_to_finetune
+
+        # Freeze all parameters
+        for param in self.embedding.parameters():
+            param.requires_grad = False
+
+        # Unfreeze last N layers
+        if num_layers_to_finetune > 0:
+            total_layers = len(self.embedding.encoder.layer)
+            layers_to_train = list(range(
+                total_layers - num_layers_to_finetune,
+                total_layers
+            ))
+
+            for layer_idx in layers_to_train:
+                for param in self.embedding.encoder.layer[layer_idx].parameters():
+                    param.requires_grad = True
+
+            if hasattr(self.embedding, 'pooler') and self.embedding.pooler is not None:
+                for param in self.embedding.pooler.parameters():
+                    param.requires_grad = True
+
+        # Classification head
+        self.dropout = nn.Dropout(dropout)
+        self.fc1 = nn.Linear(hidden_dim, 256)
+        self.fc2 = nn.Linear(256, num_classes)
+        self.relu = nn.ReLU()
+        self.layer_norm = nn.LayerNorm(hidden_dim)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        return_embeddings: bool = False
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Forward pass
+
+        Args:
+            input_ids: Input token IDs
+            attention_mask: Attention mask
+            return_embeddings: Whether to return embeddings
+
+        Returns:
+            logits: Classification logits
+            embeddings: Hidden states (if return_embeddings=True)
+        """
+        # Get embeddings
+        outputs = self.embedding(input_ids, attention_mask=attention_mask)
+        embeddings = outputs.last_hidden_state
+
+        # Pooling
+        if self.pooling == 'cls':
+            pooled = embeddings[:, 0, :]
+        elif self.pooling == 'mean':
+            mask_expanded = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
+            sum_embeddings = torch.sum(embeddings * mask_expanded, 1)
+            sum_mask = mask_expanded.sum(1)
+            pooled = sum_embeddings / sum_mask
+        else:
+            raise ValueError(f"Unknown pooling method: {self.pooling}")
+
+        # Classification
+        pooled = self.layer_norm(pooled)
+        out = self.dropout(pooled)
+        out = self.relu(self.fc1(out))
+        out = self.dropout(out)
+        logits = self.fc2(out)
+
+        if return_embeddings:
+            return logits, embeddings
+        return logits, None
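The 'mean' branch of forward() is a masked mean over real tokens only. A toy check of that pooling arithmetic with one padded position:

# Toy check of the masked mean pooling used in forward(): padded positions
# (attention_mask == 0) must not contribute to the sentence vector.
import torch

embeddings = torch.tensor([[[1.0, 1.0],
                            [3.0, 5.0],
                            [9.0, 9.0]]])        # (batch=1, seq=3, hidden=2)
attention_mask = torch.tensor([[1, 1, 0]])       # last position is padding

mask_expanded = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
sum_embeddings = torch.sum(embeddings * mask_expanded, 1)
sum_mask = mask_expanded.sum(1)
pooled = sum_embeddings / sum_mask
print(pooled)  # tensor([[2., 3.]]) -- the mean of the two real tokens only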
app/schemas/__init__.py
ADDED
File without changes
app/schemas/requests.py
ADDED
@@ -0,0 +1,52 @@
+"""
+Request Schemas
+===============
+DTOs for API requests
+"""
+
+from pydantic import BaseModel, Field, field_validator
+
+
+class AnalysisRequest(BaseModel):
+    """Request for text analysis"""
+
+    text: str = Field(
+        ...,
+        description="Text to analyze for toxicity",
+        min_length=1,
+        max_length=5000,
+        examples=["Đồ ngu ngốc, mất dạy!"]
+    )
+
+    include_html: bool = Field(
+        default=True,
+        description="Include HTML highlighting in response"
+    )
+
+    include_word_scores: bool = Field(
+        default=True,
+        description="Include detailed word-level scores"
+    )
+
+    include_summary_table: bool = Field(
+        default=False,
+        description="Include summary table of all words"
+    )
+
+    @field_validator('text')
+    @classmethod
+    def validate_text(cls, v: str) -> str:
+        """Validate text input"""
+        if not v or not v.strip():
+            raise ValueError("Text cannot be empty or only whitespace")
+        return v.strip()
+
+    class Config:
+        json_schema_extra = {
+            "example": {
+                "text": "Đồ ngu ngốc, mất dạy! Cảm ơn bạn đã chia sẻ.",
+                "include_html": True,
+                "include_word_scores": True,
+                "include_summary_table": False
+            }
+        }
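A sketch of the validator's behaviour under pydantic v2: surrounding whitespace is stripped, and whitespace-only input is rejected:

# Validation behaviour sketch (pydantic v2).
from pydantic import ValidationError

from app.schemas.requests import AnalysisRequest

req = AnalysisRequest(text="  Cảm ơn bạn đã chia sẻ.  ")
print(repr(req.text))        # surrounding whitespace stripped by validate_text
print(req.include_html)      # True (default)

try:
    AnalysisRequest(text="   ")
except ValidationError as exc:
    print(exc.errors()[0]["msg"])  # mentions "empty or only whitespace"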
app/schemas/responses.py
ADDED
@@ -0,0 +1,138 @@
+"""
+Response Schemas
+================
+DTOs for API responses
+"""
+
+from pydantic import BaseModel, Field
+from typing import List, Optional, Dict
+from enum import Enum
+
+
+class SentimentLabel(str, Enum):
+    """Sentiment labels"""
+    TOXIC = "toxic"
+    CLEAN = "clean"
+
+
+class WordScore(BaseModel):
+    """Word-level score information"""
+
+    word: str = Field(..., description="The word")
+    score: float = Field(..., ge=0.0, le=1.0, description="Toxicity score (0-1)")
+    position: Dict[str, int] = Field(..., description="Position in text {start, end}")
+    is_toxic: bool = Field(..., description="Whether word is toxic")
+    is_stop_word: bool = Field(..., description="Whether word is a stop word")
+
+
+class SentenceResult(BaseModel):
+    """Sentence-level analysis result"""
+
+    sentence_number: int = Field(..., description="Sentence index (1-based)")
+    text: str = Field(..., description="Sentence text")
+    label: SentimentLabel = Field(..., description="Toxic or clean")
+    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
+    threshold: float = Field(..., ge=0.0, le=1.0, description="Threshold used")
+    word_count: int = Field(..., description="Number of words")
+    word_scores: Optional[List[WordScore]] = Field(None, description="Word-level scores")
+
+
+class ToxicWordSummary(BaseModel):
+    """Summary of toxic words"""
+
+    word: str = Field(..., description="Toxic word")
+    score: float = Field(..., ge=0.0, le=1.0, description="Maximum score")
+    occurrences: int = Field(..., description="Number of occurrences")
+    sentences: List[int] = Field(..., description="Sentence numbers containing this word")
+
+
+class Statistics(BaseModel):
+    """Overall statistics"""
+
+    total_words: int = Field(..., description="Total number of words")
+    toxic_words: int = Field(..., description="Number of toxic words")
+    mean_score: float = Field(..., ge=0.0, le=1.0, description="Mean toxicity score")
+    median_score: float = Field(..., ge=0.0, le=1.0, description="Median toxicity score")
+    max_score: float = Field(..., ge=0.0, le=1.0, description="Maximum toxicity score")
+    min_score: float = Field(..., ge=0.0, le=1.0, description="Minimum toxicity score")
+
+
+class AnalysisResponse(BaseModel):
+    """Complete analysis response"""
+
+    success: bool = Field(True, description="Whether analysis succeeded")
+    text: str = Field(..., description="Original input text")
+    overall_label: SentimentLabel = Field(..., description="Overall text sentiment")
+    toxic_sentence_count: int = Field(..., description="Number of toxic sentences")
+    clean_sentence_count: int = Field(..., description="Number of clean sentences")
+    total_sentences: int = Field(..., description="Total number of sentences")
+    sentences: List[SentenceResult] = Field(..., description="Sentence-level results")
+    toxic_words_summary: List[ToxicWordSummary] = Field(..., description="Summary of toxic words")
+    statistics: Statistics = Field(..., description="Overall statistics")
+    html_highlighted: Optional[str] = Field(None, description="HTML with highlighting")
+
+    class Config:
+        json_schema_extra = {
+            "example": {
+                "success": True,
+                "text": "Đồ ngu ngốc!",
+                "overall_label": "toxic",
+                "toxic_sentence_count": 1,
+                "clean_sentence_count": 0,
+                "total_sentences": 1,
+                "sentences": [
+                    {
+                        "sentence_number": 1,
+                        "text": "Đồ ngu ngốc!",
+                        "label": "toxic",
+                        "confidence": 0.998,
+                        "threshold": 0.62,
+                        "word_count": 3,
+                        "word_scores": [
+                            {
+                                "word": "Đồ",
+                                "score": 0.902,
+                                "position": {"start": 0, "end": 2},
+                                "is_toxic": True,
+                                "is_stop_word": False
+                            }
+                        ]
+                    }
+                ],
+                "toxic_words_summary": [
+                    {
+                        "word": "ngu",
+                        "score": 0.924,
+                        "occurrences": 1,
+                        "sentences": [1]
+                    }
+                ],
+                "statistics": {
+                    "total_words": 3,
+                    "toxic_words": 3,
+                    "mean_score": 0.856,
+                    "median_score": 0.865,
+                    "max_score": 0.924,
+                    "min_score": 0.756
+                },
+                "html_highlighted": "<div>...</div>"
+            }
+        }
+
+
+class HealthResponse(BaseModel):
+    """Health check response"""
+
+    status: str = Field(..., description="Service status")
+    model_loaded: bool = Field(..., description="Whether model is loaded")
+    device: str = Field(..., description="Device being used (cpu/cuda)")
+    model_name: str = Field(..., description="Model name")
+    version: str = Field(..., description="API version")
+
+
+class ErrorResponse(BaseModel):
+    """Error response"""
+
+    success: bool = Field(False, description="Always false for errors")
+    error: str = Field(..., description="Error message")
+    detail: Optional[str] = Field(None, description="Detailed error information")
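These schemas double as a record of the JSON shapes the API emits. A small serialization sketch (pydantic v2 model_dump_json assumed):

# Serialization sketch for two of the response models.
from app.schemas.responses import ErrorResponse, HealthResponse

health = HealthResponse(
    status="healthy",
    model_loaded=True,
    device="cuda",
    model_name="vinai/phobert-base",
    version="1.0.0",
)
error = ErrorResponse(error="Model not loaded. Please check server logs.")
print(health.model_dump_json())
print(error.model_dump_json())  # {"success":false,"error":...,"detail":null}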
app/services/__init__.py
ADDED
File without changes
app/services/analysis_service.py
ADDED
@@ -0,0 +1,318 @@
+"""
+Analysis Service
+================
+Main analysis orchestrator (Dependency Inversion + Open/Closed)
+"""
+
+import numpy as np
+from typing import List, Dict
+from collections import defaultdict
+
+from app.models.model_loader import model_loader
+from app.services.text_processor import TextProcessor
+from app.services.gradient_service import GradientService
+from app.services.html_generator import HTMLGenerator
+from app.schemas.requests import AnalysisRequest
+from app.schemas.responses import (
+    AnalysisResponse, SentenceResult, WordScore,
+    ToxicWordSummary, Statistics, SentimentLabel
+)
+from app.core.config import settings
+from app.core.exceptions import AnalysisException
+
+
+class AnalysisService:
+    """
+    Main analysis service
+
+    Responsibilities:
+    - Orchestrate analysis pipeline
+    - Coordinate between services
+    - Build response
+
+    Dependencies:
+    - TextProcessor: Text processing
+    - GradientService: Gradient computation
+    - HTMLGenerator: HTML generation
+    """
+
+    def __init__(self):
+        self.text_processor = TextProcessor()
+        self.gradient_service = GradientService()
+        self.html_generator = HTMLGenerator()
+
+    def analyze(self, request: AnalysisRequest) -> AnalysisResponse:
+        """
+        Analyze text for toxicity
+
+        Args:
+            request: Analysis request
+
+        Returns:
+            Analysis response
+
+        Raises:
+            AnalysisException: If analysis fails
+        """
+        try:
+            # 1. Split into sentences
+            sentences = self.text_processor.split_into_sentences(request.text)
+
+            # 2. Analyze each sentence
+            sentence_results = []
+            for i, sent_info in enumerate(sentences, 1):
+                sent_result = self._analyze_sentence(
+                    sent_info,
+                    i,
+                    request.include_word_scores
+                )
+                sentence_results.append(sent_result)
+
+            # 3. Generate statistics
+            statistics = self._compute_statistics(sentence_results)
+
+            # 4. Extract toxic words summary
+            toxic_words_summary = self._extract_toxic_words_summary(sentence_results)
+
+            # 5. Generate HTML if requested
+            html_highlighted = None
+            if request.include_html:
+                html_highlighted = self.html_generator.generate_highlighted_html(
+                    request.text,
+                    [self._convert_to_dict(r) for r in sentence_results]
+                )
+
+            # 6. Determine overall label
+            toxic_count = sum(1 for r in sentence_results if r.label == SentimentLabel.TOXIC)
+            overall_label = SentimentLabel.TOXIC if toxic_count > 0 else SentimentLabel.CLEAN
+
+            # 7. Build response
+            return AnalysisResponse(
+                success=True,
+                text=request.text,
+                overall_label=overall_label,
+                toxic_sentence_count=toxic_count,
+                clean_sentence_count=len(sentences) - toxic_count,
+                total_sentences=len(sentences),
+                sentences=sentence_results,
+                toxic_words_summary=toxic_words_summary,
+                statistics=statistics,
+                html_highlighted=html_highlighted
+            )
+
+        except Exception as e:
+            raise AnalysisException(detail=str(e))
+
+    def _analyze_sentence(
+        self,
+        sent_info: Dict[str, any],
+        sent_number: int,
+        include_word_scores: bool
+    ) -> SentenceResult:
+        """Analyze single sentence"""
+        sent_text = sent_info['text']
+
+        # Extract words
+        words = self.text_processor.extract_words(sent_text)
+
+        if len(words) == 0:
+            return SentenceResult(
+                sentence_number=sent_number,
+                text=sent_text,
+                label=SentimentLabel.CLEAN,
+                confidence=0.0,
+                threshold=0.6,
+                word_count=0,
+                word_scores=[] if include_word_scores else None
+            )
+
+        # Tokenize
+        encoding = model_loader.tokenizer(
+            sent_text.lower().strip(),
+            add_special_tokens=True,
+            max_length=settings.MAX_LENGTH,
+            padding='max_length',
+            truncation=True,
+            return_tensors='pt'
+        )
+
+        # Compute gradients
+        gradient_scores, predicted_class, confidence = self.gradient_service.compute_integrated_gradients(
+            model=model_loader.model,
+            input_ids=encoding['input_ids'],
+            attention_mask=encoding['attention_mask'],
+            device=model_loader.device
+        )
+
+        # Get tokens
+        tokens = model_loader.tokenizer.convert_ids_to_tokens(
+            encoding['input_ids'][0].cpu().numpy()
+        )
+        valid_length = encoding['attention_mask'][0].sum().item()
+        tokens = tokens[:valid_length]
+
+        # Normalize gradients
+        gradient_scores_norm = self.gradient_service.normalize_scores(gradient_scores)
+
+        # Map to words
+        word_scores = self._map_tokens_to_words(tokens, gradient_scores_norm, words)
+
+        # Determine toxicity
+        is_toxic = (predicted_class == 1)
+        label = SentimentLabel.TOXIC if is_toxic else SentimentLabel.CLEAN
+
+        # Compute threshold
+        threshold = self.gradient_service.compute_threshold(word_scores, is_toxic)
+
+        # Build word scores
+        word_score_objects = None
+        if include_word_scores:
+            word_score_objects = []
+            for word_info, score in zip(words, word_scores):
+                word_score_objects.append(WordScore(
+                    word=word_info['word'],
+                    score=float(score),
+                    position={'start': word_info['start'], 'end': word_info['end']},
+                    is_toxic=score > threshold and not self.text_processor.is_stop_word(word_info['word']),
+                    is_stop_word=self.text_processor.is_stop_word(word_info['word'])
+                ))
+
+        return SentenceResult(
+            sentence_number=sent_number,
+            text=sent_text,
+            label=label,
+            confidence=float(confidence),
+            threshold=float(threshold),
+            word_count=len(words),
+            word_scores=word_score_objects
+        )
+
+    def _map_tokens_to_words(
+        self,
+        tokens: List[str],
+        token_scores: np.ndarray,
+        original_words: List[Dict[str, any]]
+    ) -> np.ndarray:
+        """Map token scores to words"""
+        clean_tokens = []
+        clean_scores = []
+
+        for token, score in zip(tokens, token_scores):
+            if token not in ['<s>', '</s>', '<pad>', '<unk>']:
+                clean_token = token.replace('_', '').replace('@@', '').strip()
+                if clean_token and not self.text_processor.is_punctuation(clean_token):
+                    clean_tokens.append(clean_token)
+                    clean_scores.append(score)
+
+        word_scores = []
+        token_idx = 0
+
+        for word_info in original_words:
+            word = word_info['word'].lower()
+
+            matching_scores = []
+            temp_idx = token_idx
+            accumulated = ""
+
+            while temp_idx < len(clean_tokens):
+                accumulated += clean_tokens[temp_idx]
+                matching_scores.append(clean_scores[temp_idx])
+
+                if accumulated == word:
+                    token_idx = temp_idx + 1
+                    break
+                elif len(accumulated) >= len(word):
+                    break
+
+                temp_idx += 1
+
+            word_scores.append(max(matching_scores) if matching_scores else 0.0)
+
+        return np.array(word_scores)
+
+    def _compute_statistics(self, sentence_results: List[SentenceResult]) -> Statistics:
+        """Compute overall statistics"""
+        all_scores = []
+        toxic_words_count = 0
+
+        for sent_result in sentence_results:
+            if sent_result.word_scores:
+                for ws in sent_result.word_scores:
+                    all_scores.append(ws.score)
+                    if ws.is_toxic:
+                        toxic_words_count += 1
+
+        if len(all_scores) == 0:
+            return Statistics(
+                total_words=0,
+                toxic_words=0,
+                mean_score=0.0,
+                median_score=0.0,
+                max_score=0.0,
+                min_score=0.0
+            )
+
+        all_scores = np.array(all_scores)
+
+        return Statistics(
+            total_words=len(all_scores),
+            toxic_words=toxic_words_count,
+            mean_score=float(np.mean(all_scores)),
+            median_score=float(np.median(all_scores)),
+            max_score=float(np.max(all_scores)),
+            min_score=float(np.min(all_scores))
+        )
+
+    def _extract_toxic_words_summary(
+        self,
+        sentence_results: List[SentenceResult]
+    ) -> List[ToxicWordSummary]:
+        """Extract summary of toxic words"""
+        toxic_words_dict = defaultdict(lambda: {
+            'max_score': 0.0,
+            'occurrences': 0,
+            'sentences': []
+        })
+
+        for sent_result in sentence_results:
+            if sent_result.word_scores:
+                for ws in sent_result.word_scores:
+                    if ws.is_toxic:
+                        word = ws.word
+                        toxic_words_dict[word]['max_score'] = max(
+                            toxic_words_dict[word]['max_score'],
+                            ws.score
+                        )
+                        toxic_words_dict[word]['occurrences'] += 1
+                        if sent_result.sentence_number not in toxic_words_dict[word]['sentences']:
+                            toxic_words_dict[word]['sentences'].append(sent_result.sentence_number)
+
+        # Convert to list and sort by score
+        toxic_words_summary = [
+            ToxicWordSummary(
+                word=word,
+                score=data['max_score'],
+                occurrences=data['occurrences'],
+                sentences=sorted(data['sentences'])
+            )
+            for word, data in toxic_words_dict.items()
+        ]
+
+        toxic_words_summary.sort(key=lambda x: x.score, reverse=True)
+
+        return toxic_words_summary
+
+    def _convert_to_dict(self, sent_result: SentenceResult) -> Dict[str, any]:
+        """Convert SentenceResult to dict for HTML generator"""
+        return {
+            'sent_start': sent_result.word_scores[0].position['start'] if sent_result.word_scores and len(sent_result.word_scores) > 0 else 0,
+            'sent_end': sent_result.word_scores[-1].position['end'] if sent_result.word_scores and len(sent_result.word_scores) > 0 else len(sent_result.text),
+            'is_toxic': sent_result.label == SentimentLabel.TOXIC,
+            'words': [{'word': ws.word, 'start': ws.position['start'], 'end': ws.position['end']} for ws in sent_result.word_scores] if sent_result.word_scores else [],
+            'scores': [ws.score for ws in sent_result.word_scores] if sent_result.word_scores else [],
+            'threshold': sent_result.threshold
+        }
+
+
+# Singleton instance
+analysis_service = AnalysisService()
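The _map_tokens_to_words helper greedily concatenates cleaned subword pieces until they spell the next original word and assigns that word the maximum piece score. A toy illustration with hypothetical pieces and scores (a word split across two subwords takes the larger of the two):

# Toy illustration of the greedy subword-to-word mapping idea
# (hypothetical pieces; real pieces come from the PhoBERT tokenizer).
import numpy as np

clean_tokens = ["đồ", "ng", "ốc"]    # '_' / '@@' markers already stripped
clean_scores = [0.30, 0.70, 0.90]
original_words = ["đồ", "ngốc"]

word_scores, token_idx = [], 0
for word in original_words:
    matching, temp_idx, accumulated = [], token_idx, ""
    while temp_idx < len(clean_tokens):
        accumulated += clean_tokens[temp_idx]
        matching.append(clean_scores[temp_idx])
        if accumulated == word:
            token_idx = temp_idx + 1
            break
        elif len(accumulated) >= len(word):
            break
        temp_idx += 1
    word_scores.append(max(matching) if matching else 0.0)

print(np.array(word_scores))  # [0.3 0.9] -- "ngốc" takes the max of its two pieces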
app/services/gradient_service.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradient Service
|
| 3 |
+
================
|
| 4 |
+
Gradient computation using Integrated Gradients (Single Responsibility)
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import numpy as np
|
| 10 |
+
from typing import Tuple
|
| 11 |
+
|
| 12 |
+
from app.models.phobert_model import PhoBERTFineTuned
|
| 13 |
+
from app.core.config import settings
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class GradientService:
|
| 17 |
+
"""
|
| 18 |
+
Gradient computation service
|
| 19 |
+
|
| 20 |
+
Responsibilities:
|
| 21 |
+
- Compute integrated gradients
|
| 22 |
+
- Calculate importance scores
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
@staticmethod
|
| 26 |
+
def compute_integrated_gradients(
|
| 27 |
+
model: PhoBERTFineTuned,
|
| 28 |
+
input_ids: torch.Tensor,
|
| 29 |
+
attention_mask: torch.Tensor,
|
| 30 |
+
device: torch.device,
|
| 31 |
+
target_class: int | None = None,
|
| 32 |
+
steps: int | None = None
|
| 33 |
+
) -> Tuple[np.ndarray, int, float]:
|
| 34 |
+
"""
|
| 35 |
+
Compute integrated gradients
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
model: Model to analyze
|
| 39 |
+
input_ids: Input token IDs
|
| 40 |
+
attention_mask: Attention mask
|
| 41 |
+
device: Device
|
| 42 |
+
target_class: Target class (optional)
|
| 43 |
+
steps: Number of integration steps
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
importance_scores: Token importance scores
|
| 47 |
+
predicted_class: Predicted class
|
| 48 |
+
confidence: Prediction confidence
|
| 49 |
+
"""
|
| 50 |
+
if steps is None:
|
| 51 |
+
steps = settings.GRADIENT_STEPS
|
| 52 |
+
|
| 53 |
+
model.eval()
|
| 54 |
+
|
| 55 |
+
input_ids = input_ids.to(device)
|
| 56 |
+
attention_mask = attention_mask.to(device)
|
| 57 |
+
|
| 58 |
+
        # Get original embeddings
        with torch.no_grad():
            outputs = model.embedding(input_ids, attention_mask=attention_mask)
            original_hidden = outputs.last_hidden_state

        baseline_hidden = torch.zeros_like(original_hidden)
        integrated_grads = torch.zeros_like(original_hidden)

        # Integrate gradients
        for step in range(steps):
            alpha = (step + 1) / steps
            interpolated = baseline_hidden + alpha * (original_hidden - baseline_hidden)
            interpolated = interpolated.detach().clone()
            interpolated.requires_grad = True

            # Forward pass through classification head
            if model.pooling == 'cls':
                pooled = interpolated[:, 0, :]
            else:
                mask_expanded = attention_mask.unsqueeze(-1).expand(interpolated.size()).float()
                sum_embeddings = torch.sum(interpolated * mask_expanded, 1)
                sum_mask = mask_expanded.sum(1)
                pooled = sum_embeddings / sum_mask

            pooled = model.layer_norm(pooled)
            out = model.dropout(pooled)
            out = model.relu(model.fc1(out))
            out = model.dropout(out)
            logits = model.fc2(out)

            # Get prediction on first step
            if step == 0:
                probs = F.softmax(logits, dim=1)
                predicted_class = torch.argmax(probs, dim=1).item()
                confidence = probs[0, predicted_class].item()
                if target_class is None:
                    target_class = predicted_class

            # Backward pass
            model.zero_grad()
            logits[0, target_class].backward()
            integrated_grads += interpolated.grad

        # Average and scale
        integrated_grads = integrated_grads / steps
        integrated_grads = integrated_grads * (original_hidden - baseline_hidden)

        # Compute importance scores
        importance_scores = torch.sum(torch.abs(integrated_grads), dim=-1)
        importance_scores = importance_scores[0].cpu().detach().numpy()

        valid_length = attention_mask[0].sum().item()
        importance_scores = importance_scores[:valid_length]

        return importance_scores, predicted_class, confidence

    @staticmethod
    def normalize_scores(scores: np.ndarray) -> np.ndarray:
        """
        Normalize scores to [0, 1]

        Args:
            scores: Raw scores

        Returns:
            Normalized scores
        """
        min_score = scores.min()
        max_score = scores.max()

        if max_score - min_score < 1e-8:
            return np.ones_like(scores) * 0.5

        return (scores - min_score) / (max_score - min_score)

    @staticmethod
    def compute_threshold(
        scores: np.ndarray,
        is_toxic: bool,
        percentile: int | None = None
    ) -> float:
        """
        Compute threshold for toxicity

        Args:
            scores: Word scores
            is_toxic: Whether text is toxic
            percentile: Percentile for threshold

        Returns:
            Threshold value
        """
        if percentile is None:
            percentile = settings.PERCENTILE_THRESHOLD

        if len(scores) == 0:
            return 0.6

        mean_score = np.mean(scores)
        percentile_score = np.percentile(scores, percentile)
        threshold = 0.6 * percentile_score + 0.4 * mean_score

        if is_toxic:
            threshold = max(threshold, 0.55)
        else:
            threshold = max(threshold, 0.75)

        threshold = np.clip(threshold, 0.45, 0.90)

        return float(threshold)
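For reference, a minimal usage sketch of the two scoring helpers above. The class name GradientService is an assumption (only part of gradient_service.py is visible in this hunk), and the score values are purely illustrative:

# Illustrative sketch only -- the class name GradientService is assumed.
import numpy as np
from app.services.gradient_service import GradientService

raw_scores = np.array([0.12, 0.85, 0.40, 0.97])             # per-word attribution scores
norm_scores = GradientService.normalize_scores(raw_scores)  # rescaled into [0, 1]
threshold = GradientService.compute_threshold(norm_scores, is_toxic=True, percentile=70)
toxic_mask = norm_scores > threshold                        # words that would be highlighted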
app/services/html_generator.py
ADDED
@@ -0,0 +1,130 @@
"""
HTML Generator
==============
Generate HTML highlighting (Single Responsibility)
"""

from typing import Any, Dict, List

from app.services.text_processor import TextProcessor


class HTMLGenerator:
    """
    HTML generation service

    Responsibilities:
    - Generate HTML with highlighting
    - Format toxic/clean sentences differently
    """

    @staticmethod
    def generate_highlighted_html(
        text: str,
        sentence_results: List[Dict[str, Any]]
    ) -> str:
        """
        Generate HTML with highlighting

        Args:
            text: Original text
            sentence_results: List of sentence analysis results

        Returns:
            HTML string with highlighting
        """
        html = '<div style="line-height: 2.2; font-size: 16px; font-family: Arial; max-width: 900px;">'

        last_end = 0

        for sent_data in sentence_results:
            sent_start = sent_data['sent_start']
            sent_end = sent_data['sent_end']
            is_toxic = sent_data['is_toxic']
            words = sent_data['words']
            scores = sent_data['scores']
            threshold = sent_data['threshold']

            # Add space between sentences
            if sent_start > last_end:
                html += text[last_end:sent_start]

            sent_text = text[sent_start:sent_end]

            if is_toxic:
                # Toxic sentence - highlight words
                sent_html = HTMLGenerator._generate_toxic_sentence_html(
                    sent_text, sent_start, words, scores, threshold
                )
                html += f'<span style="border-left: 3px solid #ff6b6b; padding-left: 8px; display: inline-block; margin: 4px 0;">{sent_html}</span>'
            else:
                # Clean sentence - plain text
                html += f'<span style="color: #444;">{sent_text}</span>'

            last_end = sent_end

        # Add remaining text
        if last_end < len(text):
            html += text[last_end:]

        html += '</div>'
        return html

    @staticmethod
    def _generate_toxic_sentence_html(
        sent_text: str,
        sent_start: int,
        words: List[Dict[str, Any]],
        scores: List[float],
        threshold: float
    ) -> str:
        """
        Generate HTML for toxic sentence

        Args:
            sent_text: Sentence text
            sent_start: Sentence start position in full text
            words: List of words
            scores: Word scores
            threshold: Toxicity threshold

        Returns:
            HTML string for sentence
        """
        sent_html = ""
        char_idx = 0
        word_idx = 0

        while char_idx < len(sent_text):
            if word_idx < len(words):
                word_info = words[word_idx]
                word_start_rel = word_info['start'] - sent_start
                word_end_rel = word_info['end'] - sent_start

                if char_idx == word_start_rel:
                    word = word_info['word']
                    score = scores[word_idx]

                    if score > threshold and not TextProcessor.is_stop_word(word) and len(word) > 1:
                        # Toxic word - red background
                        color = int(255 * (1 - score))
                        sent_html += (
                            f'<span style="background-color: rgb(255, {color}, {color}); '
                            f'padding: 2px 4px; margin: 0 1px; border-radius: 3px; '
                            f'font-weight: bold;">{word}</span>'
                        )
                    else:
                        # Non-toxic word
                        if TextProcessor.is_stop_word(word):
                            sent_html += f'<span style="color: #aaa; font-style: italic;">{word}</span>'
                        else:
                            sent_html += f'<span style="color: #333;">{word}</span>'

                    char_idx = word_end_rel
                    word_idx += 1
                    continue

            # Not at word - add character (punctuation, space, etc)
            sent_html += sent_text[char_idx]
            char_idx += 1

        return sent_html
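A short usage sketch of HTMLGenerator follows. The text, offsets, and scores are made-up illustrative values; the sentence_results shape simply mirrors the keys read by generate_highlighted_html above:

# Illustrative sketch only -- values below are invented for demonstration.
from app.services.html_generator import HTMLGenerator

text = "Chao ban. May ngu the!"
sentence_results = [
    {   # second sentence flagged as toxic; offsets are absolute positions in text
        'sent_start': 10, 'sent_end': 22, 'is_toxic': True,
        'words': [
            {'word': 'May', 'start': 10, 'end': 13},
            {'word': 'ngu', 'start': 14, 'end': 17},
            {'word': 'the', 'start': 18, 'end': 21},
        ],
        'scores': [0.30, 0.95, 0.20],
        'threshold': 0.60,
    },
]
html = HTMLGenerator.generate_highlighted_html(text, sentence_results)
# "ngu" gets a red background span; the other words are rendered as plain text.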
app/services/text_processor.py
ADDED
@@ -0,0 +1,142 @@
"""
Text Processor
==============
Text processing utilities (Single Responsibility)
"""

import re
from typing import Any, Dict, List


class TextProcessor:
    """
    Text processing service

    Responsibilities:
    - Split text into sentences
    - Extract words from text
    - Identify stop words
    - Identify punctuation
    """

    STOP_WORDS = {
        'này', 'kia', 'đó', 'ấy', 'nọ', 'đây', 'nào',
        'các', 'những', 'mọi', 'cả',
        'tôi', 'ta', 'mình', 'bạn', 'anh', 'chị', 'em',
        'nó', 'họ', 'chúng', 'ai', 'gì',
        'và', 'hoặc', 'nhưng', 'mà', 'nên', 'vì', 'nếu', 'thì', 'hay',
        'rồi', 'còn', 'cũng', 'luôn', 'đều',
        'thế', 'như',
        'của', 'cho', 'với', 'từ', 'bởi', 'về', 'trong', 'ngoài',
        'là', 'có', 'được', 'bị', 'ở', 'đang', 'sẽ', 'đã',
        'thể', 'phải', 'nên', 'muốn', 'cần', 'biết',
        'rất', 'quá', 'khá', 'hơi', 'vẫn', 'còn',
        'chỉ', 'vừa', 'mới',
        'đâu', 'sao',
        'không', 'chẳng', 'chưa',
        'nhiều', 'ít', 'vài', 'một',
        'việc', 'chuyện', 'điều', 'lúc', 'khi',
        'ra', 'vào', 'nhau', 'nhữ',
        'vậy', 'ạ', 'nhé',
    }

    PUNCTUATION = set('.,!?;:()[]{}"\'-/\\@#$%^&*+=<>~`|')

    @staticmethod
    def split_into_sentences(text: str) -> List[Dict[str, Any]]:
        """
        Split text into sentences

        Args:
            text: Input text

        Returns:
            List of sentences with positions
        """
        sentence_pattern = r'([.!?]+)\s*'
        parts = re.split(sentence_pattern, text)

        sentences = []
        current_pos = 0
        i = 0

        while i < len(parts):
            if not parts[i].strip():
                current_pos += len(parts[i])
                i += 1
                continue

            if not re.match(r'^[.!?]+$', parts[i]):
                sentence_text = parts[i]

                if i + 1 < len(parts) and re.match(r'^[.!?]+$', parts[i + 1]):
                    sentence_text += parts[i + 1]
                    i += 2
                else:
                    i += 1

                if sentence_text.strip():
                    sentences.append({
                        'text': sentence_text,
                        'start': current_pos,
                        'end': current_pos + len(sentence_text)
                    })

                current_pos += len(sentence_text)
            else:
                current_pos += len(parts[i])
                i += 1

        if len(sentences) == 0:
            sentences.append({'text': text, 'start': 0, 'end': len(text)})

        return sentences

    @staticmethod
    def extract_words(text: str) -> List[Dict[str, Any]]:
        """
        Extract words from text

        Args:
            text: Input text

        Returns:
            List of words with positions
        """
        pattern = r'[a-zA-Zàáảãạăắằẳẵặâấầẩẫậèéẻẽẹêếềểễệìíỉĩịòóỏõọôốồổỗộơớờởỡợùúủũụưứừửữựỳýỷỹỵđ_]+'

        words = []
        for match in re.finditer(pattern, text, re.IGNORECASE):
            words.append({
                'word': match.group(),
                'start': match.start(),
                'end': match.end()
            })

        return words

    @classmethod
    def is_stop_word(cls, word: str) -> bool:
        """
        Check if word is a stop word

        Args:
            word: Word to check

        Returns:
            True if stop word
        """
        return word.lower().strip() in cls.STOP_WORDS

    @classmethod
    def is_punctuation(cls, token: str) -> bool:
        """
        Check if token is punctuation

        Args:
            token: Token to check

        Returns:
            True if punctuation
        """
        return not token or all(c in cls.PUNCTUATION for c in token)
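And a usage sketch of TextProcessor (the outputs shown in comments are indicative, not exhaustive):

# Illustrative sketch only.
from app.services.text_processor import TextProcessor

text = "Chao ban. May ngu the!"
sentences = TextProcessor.split_into_sentences(text)
# -> list of dicts with 'text', 'start', 'end'; the first entry is
#    {'text': 'Chao ban.', 'start': 0, 'end': 9}
words = TextProcessor.extract_words(sentences[1]['text'])
# -> [{'word': 'May', 'start': 0, 'end': 3}, {'word': 'ngu', ...}, {'word': 'the', ...}]
print(TextProcessor.is_stop_word('nhưng'))   # True  (Vietnamese stop word)
print(TextProcessor.is_punctuation('!?'))    # True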
models/PhoBERTFineTuned_best.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4fb4d10f4754fe5c7d45d992ea7c0461f5e4e9fffc6a66fb96ed66ccddb90618
size 540876678
requirements.txt
ADDED
@@ -0,0 +1,9 @@
fastapi==0.104.1
uvicorn[standard]==0.24.0
pydantic==2.5.0
pydantic-settings==2.1.0
python-multipart==0.0.6
torch==2.1.0
transformers==4.35.0
numpy==1.24.3
python-dotenv==1.0.0