handrix committed
Commit ae4e2a6 · 1 Parent(s): bc7497c

Initial deployment - Toxic Detection API
.dockerignore ADDED
@@ -0,0 +1,63 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual environments
+ venv/
+ ENV/
+ env/
+ .venv
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Git
+ .git/
+ .gitignore
+
+ # Tests
+ tests/
+ .pytest_cache/
+ .coverage
+ htmlcov/
+
+ # Documentation
+ *.md
+ !README.md
+
+ # Environment
+ .env
+ .env.local
+
+ # Logs
+ *.log
+
+ # Other
+ *.tar.gz
+ *.zip
.env.example ADDED
@@ -0,0 +1,16 @@
+ # Model Configuration
+ MODEL_NAME=vinai/phobert-base
+ MODEL_PATH=./models/PhoBERTFineTuned_best.pth
+ MAX_LENGTH=128
+ DEVICE=cuda
+
+ # API Configuration
+ API_HOST=0.0.0.0
+ API_PORT=8000
+ API_RELOAD=True
+
+ # CORS
+ ALLOWED_ORIGINS=*
+
+ # Logging
+ LOG_LEVEL=INFO
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ models/*.pth filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,49 @@
+ # Multi-stage build for smaller image size
+ FROM python:3.10-slim as builder
+
+ # Set working directory
+ WORKDIR /app
+
+ # Install build dependencies
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements
+ COPY requirements.txt .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir --user -r requirements.txt
+
+ # Final stage
+ FROM python:3.10-slim
+
+ # Set working directory
+ WORKDIR /app
+
+ # Copy Python packages from builder
+ COPY --from=builder /root/.local /root/.local
+
+ # Make sure scripts in .local are usable
+ ENV PATH=/root/.local/bin:$PATH
+
+ # Copy application code
+ COPY ./app ./app
+
+ # Create models directory
+ RUN mkdir -p ./models
+
+ # Set environment variables
+ ENV PYTHONUNBUFFERED=1
+ ENV PYTHONDONTWRITEBYTECODE=1
+ ENV MODEL_PATH=/app/models/PhoBERTFineTuned_best.pth
+
+ # Expose port
+ EXPOSE 7860
+
+ # Health check (stdlib urllib, since requests is not listed in requirements.txt)
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
+     CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/api/v1/health')"
+
+ # Run the application
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
app/__init__.py ADDED
File without changes
app/api/__init__.py ADDED
File without changes
app/api/routes.py ADDED
@@ -0,0 +1,113 @@
+ """
+ API Routes
+ ==========
+ FastAPI routes (Interface Segregation)
+ """
+
+ from fastapi import APIRouter, Depends, HTTPException, status
+ from typing import Dict
+
+ from app.schemas.requests import AnalysisRequest
+ from app.schemas.responses import AnalysisResponse, HealthResponse, ErrorResponse
+ from app.services.analysis_service import analysis_service
+ from app.models.model_loader import model_loader
+ from app.core.config import settings
+ from app.core.exceptions import ModelNotLoadedException, AnalysisException
+
+
+ router = APIRouter()
+
+
+ @router.get(
+     "/",
+     response_model=Dict[str, str],
+     summary="Root endpoint",
+     tags=["General"]
+ )
+ async def root():
+     """Root endpoint - API information"""
+     return {
+         "message": "Toxic Text Detection API",
+         "version": settings.API_VERSION,
+         "docs": "/docs",
+         "health": "/api/v1/health"
+     }
+
+
+ @router.get(
+     "/health",
+     response_model=HealthResponse,
+     summary="Health check",
+     tags=["General"]
+ )
+ async def health_check():
+     """
+     Health check endpoint
+
+     Returns service status and model information
+     """
+     return HealthResponse(
+         status="healthy" if model_loader.is_loaded() else "unhealthy",
+         model_loaded=model_loader.is_loaded(),
+         device=str(model_loader.device) if model_loader.is_loaded() else "unknown",
+         model_name=settings.MODEL_NAME,
+         version=settings.API_VERSION
+     )
+
+
+ @router.post(
+     "/analyze",
+     response_model=AnalysisResponse,
+     responses={
+         200: {"description": "Analysis successful"},
+         400: {"model": ErrorResponse, "description": "Invalid input"},
+         500: {"model": ErrorResponse, "description": "Analysis failed"},
+         503: {"model": ErrorResponse, "description": "Model not loaded"}
+     },
+     summary="Analyze text for toxicity",
+     tags=["Analysis"]
+ )
+ async def analyze_text(request: AnalysisRequest):
+     """
+     Analyze text for toxic content
+
+     This endpoint analyzes Vietnamese text to detect toxic content using
+     a fine-tuned PhoBERT model with gradient-based explainability.
+
+     **Features:**
+     - Sentence-level toxicity detection
+     - Word-level importance scores
+     - HTML highlighting of toxic content
+     - Detailed statistics
+
+     **Parameters:**
+     - **text**: Text to analyze (required)
+     - **include_html**: Include HTML highlighting (default: true)
+     - **include_word_scores**: Include word-level scores (default: true)
+     - **include_summary_table**: Include summary table (default: false)
+
+     **Returns:**
+     - Overall toxicity label (toxic/clean)
+     - Sentence-level analysis
+     - Word-level scores and toxic words summary
+     - HTML with highlighted toxic content
+     - Statistical information
+     """
+     # Check if model is loaded
+     if not model_loader.is_loaded():
+         raise ModelNotLoadedException()
+
+     # Perform analysis
+     try:
+         result = analysis_service.analyze(request)
+         return result
+     except AnalysisException as e:
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=str(e)
+         )
+     except Exception as e:
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Unexpected error: {str(e)}"
+         )
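As a quick sanity check of the /analyze contract above, a minimal client sketch — assuming the container is running locally on port 7860 (the port the Dockerfile exposes) and that the third-party requests package is installed; it is not part of this repo's requirements.txt:

import requests  # assumption: installed separately, not in requirements.txt

payload = {
    "text": "Đồ ngu ngốc, mất dạy! Cảm ơn bạn đã chia sẻ.",
    "include_html": False,
    "include_word_scores": True,
}
resp = requests.post("http://localhost:7860/api/v1/analyze", json=payload, timeout=60)
resp.raise_for_status()
result = resp.json()
print(result["overall_label"], f'{result["toxic_sentence_count"]}/{result["total_sentences"]} toxic sentences')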
app/core/__init__.py ADDED
File without changes
app/core/config.py ADDED
@@ -0,0 +1,46 @@
+ """
+ Core Configuration
+ ==================
+ Application settings using Pydantic Settings
+ """
+
+ from pydantic_settings import BaseSettings
+ from typing import List
+ import torch
+
+
+ class Settings(BaseSettings):
+     """Application settings"""
+
+     # Model Configuration
+     MODEL_NAME: str = "vinai/phobert-base"
+     MODEL_PATH: str = "./models/PhoBERTFineTuned_best.pth"
+     MAX_LENGTH: int = 128
+     DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # API Configuration
+     API_TITLE: str = "Toxic Text Detection API"
+     API_VERSION: str = "1.0.0"
+     API_DESCRIPTION: str = "Vietnamese toxic text detection with gradient-based explainability"
+     API_HOST: str = "0.0.0.0"
+     API_PORT: int = 8000
+     API_RELOAD: bool = True
+
+     # CORS
+     ALLOWED_ORIGINS: List[str] = ["*"]
+
+     # Analysis Settings
+     GRADIENT_STEPS: int = 20
+     PERCENTILE_THRESHOLD: int = 75
+     MIN_WORD_LENGTH: int = 2
+
+     # Logging
+     LOG_LEVEL: str = "INFO"
+
+     class Config:
+         env_file = ".env"
+         case_sensitive = True
+
+
+ # Singleton instance
+ settings = Settings()
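For reference, a minimal sketch of how these settings resolve under pydantic-settings' usual precedence (constructor arguments, then environment variables and .env, then the defaults above); because case_sensitive=True, environment variable names must match the field names exactly:

import os
from app.core.config import Settings

os.environ["MAX_LENGTH"] = "256"            # env var overrides the default of 128
print(Settings().MAX_LENGTH)                # 256
print(Settings(MAX_LENGTH=64).MAX_LENGTH)   # constructor argument wins: 64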
app/core/exceptions.py ADDED
@@ -0,0 +1,44 @@
+ """
+ Custom Exceptions
+ =================
+ Application-specific exceptions
+ """
+
+ from fastapi import HTTPException, status
+
+
+ class ToxicDetectionException(HTTPException):
+     """Base exception for toxic detection"""
+
+     def __init__(self, detail: str, status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR):
+         super().__init__(status_code=status_code, detail=detail)
+
+
+ class ModelNotLoadedException(ToxicDetectionException):
+     """Raised when model is not loaded"""
+
+     def __init__(self):
+         super().__init__(
+             detail="Model not loaded. Please check server logs.",
+             status_code=status.HTTP_503_SERVICE_UNAVAILABLE
+         )
+
+
+ class InvalidTextException(ToxicDetectionException):
+     """Raised when input text is invalid"""
+
+     def __init__(self, detail: str = "Invalid text input"):
+         super().__init__(
+             detail=detail,
+             status_code=status.HTTP_400_BAD_REQUEST
+         )
+
+
+ class AnalysisException(ToxicDetectionException):
+     """Raised when analysis fails"""
+
+     def __init__(self, detail: str = "Analysis failed"):
+         super().__init__(
+             detail=detail,
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
+         )
app/main.py ADDED
@@ -0,0 +1,115 @@
+ """
+ Main FastAPI Application
+ =========================
+ Application entry point
+ """
+
+ import logging
+ from contextlib import asynccontextmanager
+ from fastapi import FastAPI
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse
+
+ from app.core.config import settings
+ from app.core.exceptions import ToxicDetectionException
+ from app.models.model_loader import model_loader
+ from app.api.routes import router
+
+
+ # Configure logging
+ logging.basicConfig(
+     level=getattr(logging, settings.LOG_LEVEL),
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     """
+     Lifespan events
+
+     Startup: Load model
+     Shutdown: Cleanup
+     """
+     # Startup
+     logger.info("Starting up...")
+     try:
+         logger.info("Loading model...")
+         model_loader.load()
+         logger.info("Model loaded successfully")
+     except Exception as e:
+         logger.error(f"Failed to load model: {str(e)}")
+         # Continue anyway - health endpoint will show model not loaded
+
+     yield
+
+     # Shutdown
+     logger.info("Shutting down...")
+
+
+ # Create FastAPI app
+ app = FastAPI(
+     title=settings.API_TITLE,
+     description=settings.API_DESCRIPTION,
+     version=settings.API_VERSION,
+     lifespan=lifespan,
+     docs_url="/docs",
+     redoc_url="/redoc",
+     openapi_url="/openapi.json"
+ )
+
+ # CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=settings.ALLOWED_ORIGINS,
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ # Exception handlers
+ @app.exception_handler(ToxicDetectionException)
+ async def toxic_detection_exception_handler(request, exc: ToxicDetectionException):
+     """Handle custom exceptions"""
+     return JSONResponse(
+         status_code=exc.status_code,
+         content={
+             "success": False,
+             "error": exc.detail,
+             "detail": None
+         }
+     )
+
+
+ @app.exception_handler(Exception)
+ async def general_exception_handler(request, exc: Exception):
+     """Handle general exceptions"""
+     logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
+     return JSONResponse(
+         status_code=500,
+         content={
+             "success": False,
+             "error": "Internal server error",
+             "detail": str(exc) if settings.LOG_LEVEL == "DEBUG" else None
+         }
+     )
+
+
+ # Include routers
+ app.include_router(router, prefix="/api/v1", tags=["v1"])
+ app.include_router(router, prefix="", tags=["root"])
+
+
+ # For direct run
+ if __name__ == "__main__":
+     import uvicorn
+
+     uvicorn.run(
+         "app.main:app",
+         host=settings.API_HOST,
+         port=settings.API_PORT,
+         reload=settings.API_RELOAD,
+         log_level=settings.LOG_LEVEL.lower()
+     )
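A minimal smoke test of the app above — a sketch assuming httpx is installed (FastAPI's TestClient depends on it, and it is not pinned in requirements.txt). Entering the client as a context manager runs the lifespan, so the model load is attempted; if the weights file is absent, startup still succeeds and /health reports "unhealthy":

from fastapi.testclient import TestClient

from app.main import app

with TestClient(app) as client:  # context manager triggers lifespan startup/shutdown
    r = client.get("/api/v1/health")
    data = r.json()
    print(r.status_code, data["status"], data["model_loaded"])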
app/models/__init__.py ADDED
File without changes
app/models/model_loader.py ADDED
@@ -0,0 +1,123 @@
+ """
+ Model Loader
+ ============
+ Responsible for loading and initializing models (Single Responsibility)
+ """
+
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+ from typing import Tuple
+ import logging
+
+ from app.models.phobert_model import PhoBERTFineTuned
+ from app.core.config import settings
+ from app.core.exceptions import ModelNotLoadedException
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class ModelLoader:
+     """
+     Model loader service
+
+     Responsibilities:
+     - Load tokenizer
+     - Load base model
+     - Load fine-tuned weights
+     - Initialize model on correct device
+     """
+
+     def __init__(self):
+         self._model: PhoBERTFineTuned | None = None
+         self._tokenizer: AutoTokenizer | None = None
+         self._device: torch.device | None = None
+
+     def load(self) -> Tuple[PhoBERTFineTuned, AutoTokenizer, torch.device]:
+         """
+         Load model, tokenizer, and set device
+
+         Returns:
+             model: Loaded model
+             tokenizer: Loaded tokenizer
+             device: Device (CPU/CUDA)
+
+         Raises:
+             ModelNotLoadedException: If loading fails
+         """
+         try:
+             # Set device
+             self._device = torch.device(settings.DEVICE)
+             logger.info(f"Using device: {self._device}")
+
+             # Load tokenizer
+             logger.info(f"Loading tokenizer: {settings.MODEL_NAME}")
+             self._tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
+
+             # Load base model
+             logger.info(f"Loading base model: {settings.MODEL_NAME}")
+             phobert = AutoModel.from_pretrained(settings.MODEL_NAME)
+
+             # Initialize fine-tuned model
+             logger.info("Initializing fine-tuned model")
+             self._model = PhoBERTFineTuned(
+                 embedding_model=phobert,
+                 hidden_dim=768,
+                 dropout=0.3,
+                 num_classes=2,
+                 num_layers_to_finetune=4,
+                 pooling='mean'
+             )
+
+             # Load weights
+             logger.info(f"Loading weights from: {settings.MODEL_PATH}")
+             state_dict = torch.load(
+                 settings.MODEL_PATH,
+                 map_location=self._device
+             )
+             self._model.load_state_dict(state_dict)
+
+             # Move to device and set eval mode
+             self._model = self._model.to(self._device)
+             self._model.eval()
+
+             logger.info("Model loaded successfully")
+
+             return self._model, self._tokenizer, self._device
+
+         except Exception as e:
+             logger.error(f"Failed to load model: {str(e)}")
+             raise ModelNotLoadedException()
+
+     @property
+     def model(self) -> PhoBERTFineTuned:
+         """Get loaded model"""
+         if self._model is None:
+             raise ModelNotLoadedException()
+         return self._model
+
+     @property
+     def tokenizer(self) -> AutoTokenizer:
+         """Get loaded tokenizer"""
+         if self._tokenizer is None:
+             raise ModelNotLoadedException()
+         return self._tokenizer
+
+     @property
+     def device(self) -> torch.device:
+         """Get device"""
+         if self._device is None:
+             raise ModelNotLoadedException()
+         return self._device
+
+     def is_loaded(self) -> bool:
+         """Check if model is loaded"""
+         return all([
+             self._model is not None,
+             self._tokenizer is not None,
+             self._device is not None
+         ])
+
+
+ # Singleton instance
+ model_loader = ModelLoader()
app/models/phobert_model.py ADDED
@@ -0,0 +1,105 @@
+ """
+ PhoBERT Model
+ =============
+ Model architecture definition (Single Responsibility)
+ """
+
+ import torch
+ import torch.nn as nn
+ from typing import Tuple, Optional
+
+
+ class PhoBERTFineTuned(nn.Module):
+     """
+     Fine-tuned PhoBERT model for toxic text classification
+
+     Responsibilities:
+     - Define model architecture
+     - Forward pass computation
+     """
+
+     def __init__(
+         self,
+         embedding_model: nn.Module,
+         hidden_dim: int = 768,
+         dropout: float = 0.3,
+         num_classes: int = 2,
+         num_layers_to_finetune: int = 4,
+         pooling: str = 'mean'
+     ):
+         super(PhoBERTFineTuned, self).__init__()
+
+         self.embedding = embedding_model
+         self.pooling = pooling
+         self.num_layers_to_finetune = num_layers_to_finetune
+
+         # Freeze all parameters
+         for param in self.embedding.parameters():
+             param.requires_grad = False
+
+         # Unfreeze last N layers
+         if num_layers_to_finetune > 0:
+             total_layers = len(self.embedding.encoder.layer)
+             layers_to_train = list(range(
+                 total_layers - num_layers_to_finetune,
+                 total_layers
+             ))
+
+             for layer_idx in layers_to_train:
+                 for param in self.embedding.encoder.layer[layer_idx].parameters():
+                     param.requires_grad = True
+
+             if hasattr(self.embedding, 'pooler') and self.embedding.pooler is not None:
+                 for param in self.embedding.pooler.parameters():
+                     param.requires_grad = True
+
+         # Classification head
+         self.dropout = nn.Dropout(dropout)
+         self.fc1 = nn.Linear(hidden_dim, 256)
+         self.fc2 = nn.Linear(256, num_classes)
+         self.relu = nn.ReLU()
+         self.layer_norm = nn.LayerNorm(hidden_dim)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         return_embeddings: bool = False
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         """
+         Forward pass
+
+         Args:
+             input_ids: Input token IDs
+             attention_mask: Attention mask
+             return_embeddings: Whether to return embeddings
+
+         Returns:
+             logits: Classification logits
+             embeddings: Hidden states (if return_embeddings=True)
+         """
+         # Get embeddings
+         outputs = self.embedding(input_ids, attention_mask=attention_mask)
+         embeddings = outputs.last_hidden_state
+
+         # Pooling
+         if self.pooling == 'cls':
+             pooled = embeddings[:, 0, :]
+         elif self.pooling == 'mean':
+             mask_expanded = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
+             sum_embeddings = torch.sum(embeddings * mask_expanded, 1)
+             sum_mask = mask_expanded.sum(1)
+             pooled = sum_embeddings / sum_mask
+         else:
+             raise ValueError(f"Unknown pooling method: {self.pooling}")
+
+         # Classification
+         pooled = self.layer_norm(pooled)
+         out = self.dropout(pooled)
+         out = self.relu(self.fc1(out))
+         out = self.dropout(out)
+         logits = self.fc2(out)
+
+         if return_embeddings:
+             return logits, embeddings
+         return logits, None
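For reference, the 'mean' branch of the forward pass is masked mean pooling over the token axis; with hidden states h_{b,t} and attention mask m_{b,t} ∈ {0,1}, the code computes, in LaTeX form:

\mathrm{pooled}_b = \frac{\sum_{t=1}^{T} m_{b,t}\, h_{b,t}}{\sum_{t=1}^{T} m_{b,t}}

so padding positions contribute nothing to the sentence representation.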
app/schemas/__init__.py ADDED
File without changes
app/schemas/requests.py ADDED
@@ -0,0 +1,52 @@
+ """
+ Request Schemas
+ ===============
+ DTOs for API requests
+ """
+
+ from pydantic import BaseModel, Field, field_validator
+
+
+ class AnalysisRequest(BaseModel):
+     """Request for text analysis"""
+
+     text: str = Field(
+         ...,
+         description="Text to analyze for toxicity",
+         min_length=1,
+         max_length=5000,
+         examples=["Đồ ngu ngốc, mất dạy!"]
+     )
+
+     include_html: bool = Field(
+         default=True,
+         description="Include HTML highlighting in response"
+     )
+
+     include_word_scores: bool = Field(
+         default=True,
+         description="Include detailed word-level scores"
+     )
+
+     include_summary_table: bool = Field(
+         default=False,
+         description="Include summary table of all words"
+     )
+
+     @field_validator('text')
+     @classmethod
+     def validate_text(cls, v: str) -> str:
+         """Validate text input"""
+         if not v or not v.strip():
+             raise ValueError("Text cannot be empty or only whitespace")
+         return v.strip()
+
+     class Config:
+         json_schema_extra = {
+             "example": {
+                 "text": "Đồ ngu ngốc, mất dạy! Cảm ơn bạn đã chia sẻ.",
+                 "include_html": True,
+                 "include_word_scores": True,
+                 "include_summary_table": False
+             }
+         }
app/schemas/responses.py ADDED
@@ -0,0 +1,138 @@
+ """
+ Response Schemas
+ ================
+ DTOs for API responses
+ """
+
+ from pydantic import BaseModel, Field
+ from typing import List, Optional, Dict
+ from enum import Enum
+
+
+ class SentimentLabel(str, Enum):
+     """Sentiment labels"""
+     TOXIC = "toxic"
+     CLEAN = "clean"
+
+
+ class WordScore(BaseModel):
+     """Word-level score information"""
+
+     word: str = Field(..., description="The word")
+     score: float = Field(..., ge=0.0, le=1.0, description="Toxicity score (0-1)")
+     position: Dict[str, int] = Field(..., description="Position in text {start, end}")
+     is_toxic: bool = Field(..., description="Whether word is toxic")
+     is_stop_word: bool = Field(..., description="Whether word is a stop word")
+
+
+ class SentenceResult(BaseModel):
+     """Sentence-level analysis result"""
+
+     sentence_number: int = Field(..., description="Sentence index (1-based)")
+     text: str = Field(..., description="Sentence text")
+     label: SentimentLabel = Field(..., description="Toxic or clean")
+     confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
+     threshold: float = Field(..., ge=0.0, le=1.0, description="Threshold used")
+     word_count: int = Field(..., description="Number of words")
+     word_scores: Optional[List[WordScore]] = Field(None, description="Word-level scores")
+
+
+ class ToxicWordSummary(BaseModel):
+     """Summary of toxic words"""
+
+     word: str = Field(..., description="Toxic word")
+     score: float = Field(..., ge=0.0, le=1.0, description="Maximum score")
+     occurrences: int = Field(..., description="Number of occurrences")
+     sentences: List[int] = Field(..., description="Sentence numbers containing this word")
+
+
+ class Statistics(BaseModel):
+     """Overall statistics"""
+
+     total_words: int = Field(..., description="Total number of words")
+     toxic_words: int = Field(..., description="Number of toxic words")
+     mean_score: float = Field(..., ge=0.0, le=1.0, description="Mean toxicity score")
+     median_score: float = Field(..., ge=0.0, le=1.0, description="Median toxicity score")
+     max_score: float = Field(..., ge=0.0, le=1.0, description="Maximum toxicity score")
+     min_score: float = Field(..., ge=0.0, le=1.0, description="Minimum toxicity score")
+
+
+ class AnalysisResponse(BaseModel):
+     """Complete analysis response"""
+
+     success: bool = Field(True, description="Whether analysis succeeded")
+     text: str = Field(..., description="Original input text")
+     overall_label: SentimentLabel = Field(..., description="Overall text sentiment")
+     toxic_sentence_count: int = Field(..., description="Number of toxic sentences")
+     clean_sentence_count: int = Field(..., description="Number of clean sentences")
+     total_sentences: int = Field(..., description="Total number of sentences")
+     sentences: List[SentenceResult] = Field(..., description="Sentence-level results")
+     toxic_words_summary: List[ToxicWordSummary] = Field(..., description="Summary of toxic words")
+     statistics: Statistics = Field(..., description="Overall statistics")
+     html_highlighted: Optional[str] = Field(None, description="HTML with highlighting")
+
+     class Config:
+         json_schema_extra = {
+             "example": {
+                 "success": True,
+                 "text": "Đồ ngu ngốc!",
+                 "overall_label": "toxic",
+                 "toxic_sentence_count": 1,
+                 "clean_sentence_count": 0,
+                 "total_sentences": 1,
+                 "sentences": [
+                     {
+                         "sentence_number": 1,
+                         "text": "Đồ ngu ngốc!",
+                         "label": "toxic",
+                         "confidence": 0.998,
+                         "threshold": 0.62,
+                         "word_count": 3,
+                         "word_scores": [
+                             {
+                                 "word": "Đồ",
+                                 "score": 0.902,
+                                 "position": {"start": 0, "end": 2},
+                                 "is_toxic": True,
+                                 "is_stop_word": False
+                             }
+                         ]
+                     }
+                 ],
+                 "toxic_words_summary": [
+                     {
+                         "word": "ngu",
+                         "score": 0.924,
+                         "occurrences": 1,
+                         "sentences": [1]
+                     }
+                 ],
+                 "statistics": {
+                     "total_words": 3,
+                     "toxic_words": 3,
+                     "mean_score": 0.856,
+                     "median_score": 0.865,
+                     "max_score": 0.924,
+                     "min_score": 0.756
+                 },
+                 "html_highlighted": "<div>...</div>"
+             }
+         }
+
+
+ class HealthResponse(BaseModel):
+     """Health check response"""
+
+     status: str = Field(..., description="Service status")
+     model_loaded: bool = Field(..., description="Whether model is loaded")
+     device: str = Field(..., description="Device being used (cpu/cuda)")
+     model_name: str = Field(..., description="Model name")
+     version: str = Field(..., description="API version")
+
+
+ class ErrorResponse(BaseModel):
+     """Error response"""
+
+     success: bool = Field(False, description="Always false for errors")
+     error: str = Field(..., description="Error message")
+     detail: Optional[str] = Field(None, description="Detailed error information")
app/services/__init__.py ADDED
File without changes
app/services/analysis_service.py ADDED
@@ -0,0 +1,318 @@
+ """
+ Analysis Service
+ ================
+ Main analysis orchestrator (Dependency Inversion + Open/Closed)
+ """
+
+ import numpy as np
+ from typing import Any, List, Dict
+ from collections import defaultdict
+
+ from app.models.model_loader import model_loader
+ from app.services.text_processor import TextProcessor
+ from app.services.gradient_service import GradientService
+ from app.services.html_generator import HTMLGenerator
+ from app.schemas.requests import AnalysisRequest
+ from app.schemas.responses import (
+     AnalysisResponse, SentenceResult, WordScore,
+     ToxicWordSummary, Statistics, SentimentLabel
+ )
+ from app.core.config import settings
+ from app.core.exceptions import AnalysisException
+
+
+ class AnalysisService:
+     """
+     Main analysis service
+
+     Responsibilities:
+     - Orchestrate analysis pipeline
+     - Coordinate between services
+     - Build response
+
+     Dependencies:
+     - TextProcessor: Text processing
+     - GradientService: Gradient computation
+     - HTMLGenerator: HTML generation
+     """
+
+     def __init__(self):
+         self.text_processor = TextProcessor()
+         self.gradient_service = GradientService()
+         self.html_generator = HTMLGenerator()
+
+     def analyze(self, request: AnalysisRequest) -> AnalysisResponse:
+         """
+         Analyze text for toxicity
+
+         Args:
+             request: Analysis request
+
+         Returns:
+             Analysis response
+
+         Raises:
+             AnalysisException: If analysis fails
+         """
+         try:
+             # 1. Split into sentences
+             sentences = self.text_processor.split_into_sentences(request.text)
+
+             # 2. Analyze each sentence
+             sentence_results = []
+             for i, sent_info in enumerate(sentences, 1):
+                 sent_result = self._analyze_sentence(
+                     sent_info,
+                     i,
+                     request.include_word_scores
+                 )
+                 sentence_results.append(sent_result)
+
+             # 3. Generate statistics
+             statistics = self._compute_statistics(sentence_results)
+
+             # 4. Extract toxic words summary
+             toxic_words_summary = self._extract_toxic_words_summary(sentence_results)
+
+             # 5. Generate HTML if requested
+             html_highlighted = None
+             if request.include_html:
+                 html_highlighted = self.html_generator.generate_highlighted_html(
+                     request.text,
+                     [self._convert_to_dict(r) for r in sentence_results]
+                 )
+
+             # 6. Determine overall label
+             toxic_count = sum(1 for r in sentence_results if r.label == SentimentLabel.TOXIC)
+             overall_label = SentimentLabel.TOXIC if toxic_count > 0 else SentimentLabel.CLEAN
+
+             # 7. Build response
+             return AnalysisResponse(
+                 success=True,
+                 text=request.text,
+                 overall_label=overall_label,
+                 toxic_sentence_count=toxic_count,
+                 clean_sentence_count=len(sentences) - toxic_count,
+                 total_sentences=len(sentences),
+                 sentences=sentence_results,
+                 toxic_words_summary=toxic_words_summary,
+                 statistics=statistics,
+                 html_highlighted=html_highlighted
+             )
+
+         except Exception as e:
+             raise AnalysisException(detail=str(e))
+
+     def _analyze_sentence(
+         self,
+         sent_info: Dict[str, Any],
+         sent_number: int,
+         include_word_scores: bool
+     ) -> SentenceResult:
+         """Analyze single sentence"""
+         sent_text = sent_info['text']
+
+         # Extract words
+         words = self.text_processor.extract_words(sent_text)
+
+         if len(words) == 0:
+             return SentenceResult(
+                 sentence_number=sent_number,
+                 text=sent_text,
+                 label=SentimentLabel.CLEAN,
+                 confidence=0.0,
+                 threshold=0.6,
+                 word_count=0,
+                 word_scores=[] if include_word_scores else None
+             )
+
+         # Tokenize
+         encoding = model_loader.tokenizer(
+             sent_text.lower().strip(),
+             add_special_tokens=True,
+             max_length=settings.MAX_LENGTH,
+             padding='max_length',
+             truncation=True,
+             return_tensors='pt'
+         )
+
+         # Compute gradients
+         gradient_scores, predicted_class, confidence = self.gradient_service.compute_integrated_gradients(
+             model=model_loader.model,
+             input_ids=encoding['input_ids'],
+             attention_mask=encoding['attention_mask'],
+             device=model_loader.device
+         )
+
+         # Get tokens
+         tokens = model_loader.tokenizer.convert_ids_to_tokens(
+             encoding['input_ids'][0].cpu().numpy()
+         )
+         valid_length = encoding['attention_mask'][0].sum().item()
+         tokens = tokens[:valid_length]
+
+         # Normalize gradients
+         gradient_scores_norm = self.gradient_service.normalize_scores(gradient_scores)
+
+         # Map to words
+         word_scores = self._map_tokens_to_words(tokens, gradient_scores_norm, words)
+
+         # Determine toxicity
+         is_toxic = (predicted_class == 1)
+         label = SentimentLabel.TOXIC if is_toxic else SentimentLabel.CLEAN
+
+         # Compute threshold
+         threshold = self.gradient_service.compute_threshold(word_scores, is_toxic)
+
+         # Build word scores
+         word_score_objects = None
+         if include_word_scores:
+             word_score_objects = []
+             for word_info, score in zip(words, word_scores):
+                 word_score_objects.append(WordScore(
+                     word=word_info['word'],
+                     score=float(score),
+                     position={'start': word_info['start'], 'end': word_info['end']},
+                     is_toxic=score > threshold and not self.text_processor.is_stop_word(word_info['word']),
+                     is_stop_word=self.text_processor.is_stop_word(word_info['word'])
+                 ))
+
+         return SentenceResult(
+             sentence_number=sent_number,
+             text=sent_text,
+             label=label,
+             confidence=float(confidence),
+             threshold=float(threshold),
+             word_count=len(words),
+             word_scores=word_score_objects
+         )
+
+     def _map_tokens_to_words(
+         self,
+         tokens: List[str],
+         token_scores: np.ndarray,
+         original_words: List[Dict[str, Any]]
+     ) -> np.ndarray:
+         """Map token scores to words"""
+         clean_tokens = []
+         clean_scores = []
+
+         for token, score in zip(tokens, token_scores):
+             if token not in ['<s>', '</s>', '<pad>', '<unk>']:
+                 clean_token = token.replace('_', '').replace('@@', '').strip()
+                 if clean_token and not self.text_processor.is_punctuation(clean_token):
+                     clean_tokens.append(clean_token)
+                     clean_scores.append(score)
+
+         word_scores = []
+         token_idx = 0
+
+         for word_info in original_words:
+             word = word_info['word'].lower()
+
+             matching_scores = []
+             temp_idx = token_idx
+             accumulated = ""
+
+             while temp_idx < len(clean_tokens):
+                 accumulated += clean_tokens[temp_idx]
+                 matching_scores.append(clean_scores[temp_idx])
+
+                 if accumulated == word:
+                     token_idx = temp_idx + 1
+                     break
+                 elif len(accumulated) >= len(word):
+                     break
+
+                 temp_idx += 1
+
+             word_scores.append(max(matching_scores) if matching_scores else 0.0)
+
+         return np.array(word_scores)
+
+     def _compute_statistics(self, sentence_results: List[SentenceResult]) -> Statistics:
+         """Compute overall statistics"""
+         all_scores = []
+         toxic_words_count = 0
+
+         for sent_result in sentence_results:
+             if sent_result.word_scores:
+                 for ws in sent_result.word_scores:
+                     all_scores.append(ws.score)
+                     if ws.is_toxic:
+                         toxic_words_count += 1
+
+         if len(all_scores) == 0:
+             return Statistics(
+                 total_words=0,
+                 toxic_words=0,
+                 mean_score=0.0,
+                 median_score=0.0,
+                 max_score=0.0,
+                 min_score=0.0
+             )
+
+         all_scores = np.array(all_scores)
+
+         return Statistics(
+             total_words=len(all_scores),
+             toxic_words=toxic_words_count,
+             mean_score=float(np.mean(all_scores)),
+             median_score=float(np.median(all_scores)),
+             max_score=float(np.max(all_scores)),
+             min_score=float(np.min(all_scores))
+         )
+
+     def _extract_toxic_words_summary(
+         self,
+         sentence_results: List[SentenceResult]
+     ) -> List[ToxicWordSummary]:
+         """Extract summary of toxic words"""
+         toxic_words_dict = defaultdict(lambda: {
+             'max_score': 0.0,
+             'occurrences': 0,
+             'sentences': []
+         })
+
+         for sent_result in sentence_results:
+             if sent_result.word_scores:
+                 for ws in sent_result.word_scores:
+                     if ws.is_toxic:
+                         word = ws.word
+                         toxic_words_dict[word]['max_score'] = max(
+                             toxic_words_dict[word]['max_score'],
+                             ws.score
+                         )
+                         toxic_words_dict[word]['occurrences'] += 1
+                         if sent_result.sentence_number not in toxic_words_dict[word]['sentences']:
+                             toxic_words_dict[word]['sentences'].append(sent_result.sentence_number)
+
+         # Convert to list and sort by score
+         toxic_words_summary = [
+             ToxicWordSummary(
+                 word=word,
+                 score=data['max_score'],
+                 occurrences=data['occurrences'],
+                 sentences=sorted(data['sentences'])
+             )
+             for word, data in toxic_words_dict.items()
+         ]
+
+         toxic_words_summary.sort(key=lambda x: x.score, reverse=True)
+
+         return toxic_words_summary
+
+     def _convert_to_dict(self, sent_result: SentenceResult) -> Dict[str, Any]:
+         """Convert SentenceResult to dict for HTML generator"""
+         return {
+             'sent_start': sent_result.word_scores[0].position['start'] if sent_result.word_scores and len(sent_result.word_scores) > 0 else 0,
+             'sent_end': sent_result.word_scores[-1].position['end'] if sent_result.word_scores and len(sent_result.word_scores) > 0 else len(sent_result.text),
+             'is_toxic': sent_result.label == SentimentLabel.TOXIC,
+             'words': [{'word': ws.word, 'start': ws.position['start'], 'end': ws.position['end']} for ws in sent_result.word_scores] if sent_result.word_scores else [],
+             'scores': [ws.score for ws in sent_result.word_scores] if sent_result.word_scores else [],
+             'threshold': sent_result.threshold
+         }
+
+
+ # Singleton instance
+ analysis_service = AnalysisService()
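End to end, the orchestrator above can be exercised without going through HTTP — a sketch assuming the LFS weights file is present so model_loader.load() succeeds:

from app.models.model_loader import model_loader
from app.schemas.requests import AnalysisRequest
from app.services.analysis_service import analysis_service

model_loader.load()  # requires models/PhoBERTFineTuned_best.pth
response = analysis_service.analyze(AnalysisRequest(text="Đồ ngu ngốc! Cảm ơn bạn."))
for s in response.sentences:
    print(s.sentence_number, s.label.value, round(s.confidence, 3))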
app/services/gradient_service.py ADDED
@@ -0,0 +1,167 @@
+ """
+ Gradient Service
+ ================
+ Gradient computation using Integrated Gradients (Single Responsibility)
+ """
+
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from typing import Tuple
+
+ from app.models.phobert_model import PhoBERTFineTuned
+ from app.core.config import settings
+
+
+ class GradientService:
+     """
+     Gradient computation service
+
+     Responsibilities:
+     - Compute integrated gradients
+     - Calculate importance scores
+     """
+
+     @staticmethod
+     def compute_integrated_gradients(
+         model: PhoBERTFineTuned,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         device: torch.device,
+         target_class: int | None = None,
+         steps: int | None = None
+     ) -> Tuple[np.ndarray, int, float]:
+         """
+         Compute integrated gradients
+
+         Args:
+             model: Model to analyze
+             input_ids: Input token IDs
+             attention_mask: Attention mask
+             device: Device
+             target_class: Target class (optional)
+             steps: Number of integration steps
+
+         Returns:
+             importance_scores: Token importance scores
+             predicted_class: Predicted class
+             confidence: Prediction confidence
+         """
+         if steps is None:
+             steps = settings.GRADIENT_STEPS
+
+         model.eval()
+
+         input_ids = input_ids.to(device)
+         attention_mask = attention_mask.to(device)
+
+         # Get original embeddings
+         with torch.no_grad():
+             outputs = model.embedding(input_ids, attention_mask=attention_mask)
+             original_hidden = outputs.last_hidden_state
+
+         baseline_hidden = torch.zeros_like(original_hidden)
+         integrated_grads = torch.zeros_like(original_hidden)
+
+         # Integrate gradients
+         for step in range(steps):
+             alpha = (step + 1) / steps
+             interpolated = baseline_hidden + alpha * (original_hidden - baseline_hidden)
+             interpolated = interpolated.detach().clone()
+             interpolated.requires_grad = True
+
+             # Forward pass through classification head
+             if model.pooling == 'cls':
+                 pooled = interpolated[:, 0, :]
+             else:
+                 mask_expanded = attention_mask.unsqueeze(-1).expand(interpolated.size()).float()
+                 sum_embeddings = torch.sum(interpolated * mask_expanded, 1)
+                 sum_mask = mask_expanded.sum(1)
+                 pooled = sum_embeddings / sum_mask
+
+             pooled = model.layer_norm(pooled)
+             out = model.dropout(pooled)
+             out = model.relu(model.fc1(out))
+             out = model.dropout(out)
+             logits = model.fc2(out)
+
+             # Get prediction on first step
+             if step == 0:
+                 probs = F.softmax(logits, dim=1)
+                 predicted_class = torch.argmax(probs, dim=1).item()
+                 confidence = probs[0, predicted_class].item()
+                 if target_class is None:
+                     target_class = predicted_class
+
+             # Backward pass
+             model.zero_grad()
+             logits[0, target_class].backward()
+             integrated_grads += interpolated.grad
+
+         # Average and scale
+         integrated_grads = integrated_grads / steps
+         integrated_grads = integrated_grads * (original_hidden - baseline_hidden)
+
+         # Compute importance scores
+         importance_scores = torch.sum(torch.abs(integrated_grads), dim=-1)
+         importance_scores = importance_scores[0].cpu().detach().numpy()
+
+         valid_length = attention_mask[0].sum().item()
+         importance_scores = importance_scores[:valid_length]
+
+         return importance_scores, predicted_class, confidence
+
+     @staticmethod
+     def normalize_scores(scores: np.ndarray) -> np.ndarray:
+         """
+         Normalize scores to [0, 1]
+
+         Args:
+             scores: Raw scores
+
+         Returns:
+             Normalized scores
+         """
+         min_score = scores.min()
+         max_score = scores.max()
+
+         if max_score - min_score < 1e-8:
+             return np.ones_like(scores) * 0.5
+
+         return (scores - min_score) / (max_score - min_score)
+
+     @staticmethod
+     def compute_threshold(
+         scores: np.ndarray,
+         is_toxic: bool,
+         percentile: int | None = None
+     ) -> float:
+         """
+         Compute threshold for toxicity
+
+         Args:
+             scores: Word scores
+             is_toxic: Whether text is toxic
+             percentile: Percentile for threshold
+
+         Returns:
+             Threshold value
+         """
+         if percentile is None:
+             percentile = settings.PERCENTILE_THRESHOLD
+
+         if len(scores) == 0:
+             return 0.6
+
+         mean_score = np.mean(scores)
+         percentile_score = np.percentile(scores, percentile)
+         threshold = 0.6 * percentile_score + 0.4 * mean_score
+
+         if is_toxic:
+             threshold = max(threshold, 0.55)
+         else:
+             threshold = max(threshold, 0.75)
+
+         threshold = np.clip(threshold, 0.45, 0.90)
+
+         return float(threshold)
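The loop in compute_integrated_gradients is a right-endpoint Riemann approximation of Integrated Gradients over the encoder's hidden states h, with a zero baseline h' = 0; in LaTeX, per hidden unit i:

\mathrm{IG}_i(h) \approx (h_i - h'_i) \cdot \frac{1}{S} \sum_{s=1}^{S} \left. \frac{\partial F\big(h' + \alpha (h - h')\big)}{\partial h_i} \right|_{\alpha = s/S}

where S = GRADIENT_STEPS and F is the logit of the target class. The per-token score is the L1 norm of the attribution over the hidden dimension, \mathrm{score}_t = \sum_i |\mathrm{IG}_{t,i}|, later min-max normalized to [0, 1]. compute_threshold then blends the percentile and mean of those normalized scores, \theta = 0.6\,P_{75} + 0.4\,\bar{s}, floors it at 0.55 (toxic) or 0.75 (clean), and clips the result to [0.45, 0.90].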
app/services/html_generator.py ADDED
@@ -0,0 +1,130 @@
+ """
+ HTML Generator
+ ==============
+ Generate HTML highlighting (Single Responsibility)
+ """
+
+ from typing import Any, List, Dict
+ from app.services.text_processor import TextProcessor
+
+
+ class HTMLGenerator:
+     """
+     HTML generation service
+
+     Responsibilities:
+     - Generate HTML with highlighting
+     - Format toxic/clean sentences differently
+     """
+
+     @staticmethod
+     def generate_highlighted_html(
+         text: str,
+         sentence_results: List[Dict[str, Any]]
+     ) -> str:
+         """
+         Generate HTML with highlighting
+
+         Args:
+             text: Original text
+             sentence_results: List of sentence analysis results
+
+         Returns:
+             HTML string with highlighting
+         """
+         html = '<div style="line-height: 2.2; font-size: 16px; font-family: Arial; max-width: 900px;">'
+
+         last_end = 0
+
+         for sent_data in sentence_results:
+             sent_start = sent_data['sent_start']
+             sent_end = sent_data['sent_end']
+             is_toxic = sent_data['is_toxic']
+             words = sent_data['words']
+             scores = sent_data['scores']
+             threshold = sent_data['threshold']
+
+             # Add space between sentences
+             if sent_start > last_end:
+                 html += text[last_end:sent_start]
+
+             sent_text = text[sent_start:sent_end]
+
+             if is_toxic:
+                 # Toxic sentence - highlight words
+                 sent_html = HTMLGenerator._generate_toxic_sentence_html(
+                     sent_text, sent_start, words, scores, threshold
+                 )
+                 html += f'<span style="border-left: 3px solid #ff6b6b; padding-left: 8px; display: inline-block; margin: 4px 0;">{sent_html}</span>'
+             else:
+                 # Clean sentence - plain text
+                 html += f'<span style="color: #444;">{sent_text}</span>'
+
+             last_end = sent_end
+
+         # Add remaining text
+         if last_end < len(text):
+             html += text[last_end:]
+
+         html += '</div>'
+         return html
+
+     @staticmethod
+     def _generate_toxic_sentence_html(
+         sent_text: str,
+         sent_start: int,
+         words: List[Dict[str, Any]],
+         scores: List[float],
+         threshold: float
+     ) -> str:
+         """
+         Generate HTML for toxic sentence
+
+         Args:
+             sent_text: Sentence text
+             sent_start: Sentence start position in full text
+             words: List of words
+             scores: Word scores
+             threshold: Toxicity threshold
+
+         Returns:
+             HTML string for sentence
+         """
+         sent_html = ""
+         char_idx = 0
+         word_idx = 0
+
+         while char_idx < len(sent_text):
+             if word_idx < len(words):
+                 word_info = words[word_idx]
+                 word_start_rel = word_info['start'] - sent_start
+                 word_end_rel = word_info['end'] - sent_start
+
+                 if char_idx == word_start_rel:
+                     word = word_info['word']
+                     score = scores[word_idx]
+
+                     if score > threshold and not TextProcessor.is_stop_word(word) and len(word) > 1:
+                         # Toxic word - red background
+                         color = int(255 * (1 - score))
+                         sent_html += (
+                             f'<span style="background-color: rgb(255, {color}, {color}); '
+                             f'padding: 2px 4px; margin: 0 1px; border-radius: 3px; '
+                             f'font-weight: bold;">{word}</span>'
+                         )
+                     else:
+                         # Non-toxic word
+                         if TextProcessor.is_stop_word(word):
+                             sent_html += f'<span style="color: #aaa; font-style: italic;">{word}</span>'
+                         else:
+                             sent_html += f'<span style="color: #333;">{word}</span>'
+
+                     char_idx = word_end_rel
+                     word_idx += 1
+                     continue
+
+             # Not at word - add character (punctuation, space, etc)
+             sent_html += sent_text[char_idx]
+             char_idx += 1
+
+         return sent_html
app/services/text_processor.py ADDED
@@ -0,0 +1,142 @@
+ """
+ Text Processor
+ ==============
+ Text processing utilities (Single Responsibility)
+ """
+
+ import re
+ from typing import Any, List, Dict
+
+
+ class TextProcessor:
+     """
+     Text processing service
+
+     Responsibilities:
+     - Split text into sentences
+     - Extract words from text
+     - Identify stop words
+     - Identify punctuation
+     """
+
+     STOP_WORDS = {
+         'này', 'kia', 'đó', 'ấy', 'nọ', 'đây', 'nào',
+         'các', 'những', 'mọi', 'cả',
+         'tôi', 'ta', 'mình', 'bạn', 'anh', 'chị', 'em',
+         'nó', 'họ', 'chúng', 'ai', 'gì',
+         'và', 'hoặc', 'nhưng', 'mà', 'nên', 'vì', 'nếu', 'thì', 'hay',
+         'rồi', 'còn', 'cũng', 'luôn', 'đều',
+         'thế', 'như',
+         'của', 'cho', 'với', 'từ', 'bởi', 'về', 'trong', 'ngoài',
+         'là', 'có', 'được', 'bị', 'ở', 'đang', 'sẽ', 'đã',
+         'thể', 'phải', 'nên', 'muốn', 'cần', 'biết',
+         'rất', 'quá', 'khá', 'hơi', 'vẫn', 'còn',
+         'chỉ', 'vừa', 'mới',
+         'đâu', 'sao',
+         'không', 'chẳng', 'chưa',
+         'nhiều', 'ít', 'vài', 'một',
+         'việc', 'chuyện', 'điều', 'lúc', 'khi',
+         'ra', 'vào', 'nhau', 'nhữ',
+         'vậy', 'ạ', 'nhé',
+     }
+
+     PUNCTUATION = set('.,!?;:()[]{}"\'-/\\@#$%^&*+=<>~`|')
+
+     @staticmethod
+     def split_into_sentences(text: str) -> List[Dict[str, Any]]:
+         """
+         Split text into sentences
+
+         Args:
+             text: Input text
+
+         Returns:
+             List of sentences with positions
+         """
+         sentence_pattern = r'([.!?]+)\s*'
+         parts = re.split(sentence_pattern, text)
+
+         sentences = []
+         current_pos = 0
+         i = 0
+
+         while i < len(parts):
+             if not parts[i].strip():
+                 current_pos += len(parts[i])
+                 i += 1
+                 continue
+
+             if not re.match(r'^[.!?]+$', parts[i]):
+                 sentence_text = parts[i]
+
+                 if i + 1 < len(parts) and re.match(r'^[.!?]+$', parts[i + 1]):
+                     sentence_text += parts[i + 1]
+                     i += 2
+                 else:
+                     i += 1
+
+                 if sentence_text.strip():
+                     sentences.append({
+                         'text': sentence_text,
+                         'start': current_pos,
+                         'end': current_pos + len(sentence_text)
+                     })
+
+                 current_pos += len(sentence_text)
+             else:
+                 current_pos += len(parts[i])
+                 i += 1
+
+         if len(sentences) == 0:
+             sentences.append({'text': text, 'start': 0, 'end': len(text)})
+
+         return sentences
+
+     @staticmethod
+     def extract_words(text: str) -> List[Dict[str, Any]]:
+         """
+         Extract words from text
+
+         Args:
+             text: Input text
+
+         Returns:
+             List of words with positions
+         """
+         pattern = r'[a-zA-Zàáảãạăắằẳẵặâấầẩẫậèéẻẽẹêếềểễệìíỉĩịòóỏõọôốồổỗộơớờởỡợùúủũụưứừửữựỳýỷỹỵđ_]+'
+
+         words = []
+         for match in re.finditer(pattern, text, re.IGNORECASE):
+             words.append({
+                 'word': match.group(),
+                 'start': match.start(),
+                 'end': match.end()
+             })
+
+         return words
+
+     @classmethod
+     def is_stop_word(cls, word: str) -> bool:
+         """
+         Check if word is a stop word
+
+         Args:
+             word: Word to check
+
+         Returns:
+             True if stop word
+         """
+         return word.lower().strip() in cls.STOP_WORDS
+
+     @classmethod
+     def is_punctuation(cls, token: str) -> bool:
+         """
+         Check if token is punctuation
+
+         Args:
+             token: Token to check
+
+         Returns:
+             True if punctuation
+         """
+         return not token or all(c in cls.PUNCTUATION for c in token)
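A quick trace of the splitter on a two-sentence string. Note that the \s* in the split pattern swallows the inter-sentence space, so the second sentence's reported start (7) lags its true offset in the original string (8) by the swallowed whitespace:

from app.services.text_processor import TextProcessor

for s in TextProcessor.split_into_sentences("Đồ ngu! Cảm ơn bạn."):
    print(s['start'], s['end'], repr(s['text']))
# 0 7  'Đồ ngu!'
# 7 18 'Cảm ơn bạn.'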
models/PhoBERTFineTuned_best.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4fb4d10f4754fe5c7d45d992ea7c0461f5e4e9fffc6a66fb96ed66ccddb90618
+ size 540876678
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ fastapi==0.104.1
+ uvicorn[standard]==0.24.0
+ pydantic==2.5.0
+ pydantic-settings==2.1.0
+ python-multipart==0.0.6
+ torch==2.1.0
+ transformers==4.35.0
+ numpy==1.24.3
+ python-dotenv==1.0.0