Spaces:
Sleeping
Sleeping
| import json | |
| import torch | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from typing import Union, Dict, Any | |
| from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, MarianMTModel, MarianTokenizer | |
# --- Model checkpoints ---------------------------------------------------------
M2M100_MODEL_NAME = "facebook/m2m100_418M"
OPUS_MT_MODEL_NAME = "Helsinki-NLP/opus-mt-tc-big-en-ko"

# All inference in this service runs on the CPU.
device = torch.device("cpu")

# M2M100: multilingual many-to-many translation model.
# `nn.Module.to` returns the module itself, so chaining keeps the same object.
m2m100_tokenizer = M2M100Tokenizer.from_pretrained(M2M100_MODEL_NAME)
m2m100_model = M2M100ForConditionalGeneration.from_pretrained(M2M100_MODEL_NAME).to(device)

# Helsinki-NLP Opus-MT checkpoint named for English -> Korean.
# NOTE(review): translate_opus_mt loads its own checkpoints by name and does not
# reference opus_tokenizer / opus_model; confirm whether these are still needed.
opus_tokenizer = MarianTokenizer.from_pretrained(OPUS_MT_MODEL_NAME)
opus_model = MarianMTModel.from_pretrained(OPUS_MT_MODEL_NAME).to(device)

# FastAPI application instance.
app = FastAPI()
# Request payload accepted by translate_json.
class TranslationRequest(BaseModel):
    # Backend to use: "m2m100" or "opus-mt".
    model: str
    # Source language code, e.g. "ko", "en", "fr".
    from_lang: str
    # Target language code, e.g. "ko", "fr".
    to: str
    # JSON object whose string values should be translated.
    data: Dict[str, Any]
# M2M100 translation function (multilingual).
def translate_m2m100(text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate a single string with the shared M2M100 model.

    Args:
        text: Text to translate. Empty or whitespace-only text is returned as-is.
        src_lang: M2M100 language code of ``text`` (e.g. "en").
        tgt_lang: M2M100 language code to translate into (e.g. "ko").

    Returns:
        The first decoded generation, with special tokens removed.
    """
    # Nothing worth translating: hand the input straight back.
    if text.strip() == "":
        return text

    # The tokenizer reads ``src_lang`` when encoding (per the HF M2M100 docs).
    # NOTE(review): this mutates the module-level tokenizer, so calls running on
    # several threads at once could interfere with each other.
    m2m100_tokenizer.src_lang = src_lang
    inputs = m2m100_tokenizer(text, return_tensors="pt").to(device)

    # Per the HF M2M100 docs, forcing the first generated token to the target
    # language id selects the output language.
    target_bos_id = m2m100_tokenizer.get_lang_id(tgt_lang)
    outputs = m2m100_model.generate(**inputs, forced_bos_token_id=target_bos_id)

    decoded = m2m100_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return decoded[0]
# Helsinki-NLP Opus-MT translation (English <-> Korean only).

# Checkpoint used for each supported (src_lang, tgt_lang) direction.
_OPUS_MT_CHECKPOINTS = {
    ("en", "ko"): "Helsinki-NLP/opus-mt-en-ko",
    ("ko", "en"): "Helsinki-NLP/opus-mt-ko-en",
}

# Loaded (tokenizer, model) pairs keyed by checkpoint name. Filled lazily so each
# checkpoint is loaded at most once per process instead of once per translated string.
_opus_mt_cache: Dict[str, Any] = {}


def _load_opus_mt(model_name: str):
    """Return the cached ``(tokenizer, model)`` pair for *model_name*, loading it on first use."""
    pair = _opus_mt_cache.get(model_name)
    if pair is None:
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name).to(device)
        pair = (tokenizer, model)
        _opus_mt_cache[model_name] = pair
    return pair


def translate_opus_mt(text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate one string between English and Korean with Helsinki-NLP Opus-MT.

    Args:
        text: Text to translate. Empty or whitespace-only text is returned as-is.
        src_lang: "en" or "ko".
        tgt_lang: "ko" or "en" (the other of the two).

    Returns:
        The first decoded generation, with special tokens removed.

    Raises:
        HTTPException: status 400 when (src_lang, tgt_lang) is not en->ko or ko->en.

    NOTE(review): the module-level ``opus_model`` (OPUS_MT_MODEL_NAME) is not used
    here; en->ko goes through "Helsinki-NLP/opus-mt-en-ko" instead. Confirm
    which checkpoint is intended.
    """
    if not text.strip():
        return text  # empty / whitespace-only input is returned unchanged

    model_name = _OPUS_MT_CHECKPOINTS.get((src_lang, tgt_lang))
    if model_name is None:
        raise HTTPException(status_code=400, detail="Opus-MTλ μμ΄ β νκ΅μ΄λ§ μ§μν©λλ€.")

    # Previously the checkpoint was re-loaded from from_pretrained on every call
    # (i.e. for every string in the request); reuse the cached pair instead.
    tokenizer, model = _load_opus_mt(model_name)
    encoded_text = tokenizer(text, return_tensors="pt", padding=True).to(device)
    generated_tokens = model.generate(**encoded_text)
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
# Recursive JSON translation.
def recursive_translate(json_obj: Union[Dict[str, Any], str], src_lang: str, tgt_lang: str, model_type: str):
    """Translate every string value in *json_obj*, descending into nested dicts.

    Args:
        json_obj: A string, a dict (translated value-by-value, keys untouched), or
            any other value.
        src_lang: Source language code passed to the translator.
        tgt_lang: Target language code passed to the translator.
        model_type: "m2m100" or "opus-mt".

    Returns:
        The translated string, a new dict with translated values, or *json_obj*
        itself for any other type. Numbers, lists, None, etc. are returned
        unchanged; lists are NOT descended into.

    Raises:
        ValueError: if a string must be translated but *model_type* is unknown.
            Previously this case silently returned None, replacing the text.
    """
    if isinstance(json_obj, str):  # a single string: translate it
        if model_type == "m2m100":
            return translate_m2m100(json_obj, src_lang, tgt_lang)
        if model_type == "opus-mt":
            return translate_opus_mt(json_obj, src_lang, tgt_lang)
        raise ValueError(f"Unsupported model type: {model_type!r}")
    if isinstance(json_obj, dict):  # a dict: translate each value recursively
        return {key: recursive_translate(value, src_lang, tgt_lang, model_type) for key, value in json_obj.items()}
    return json_obj  # numbers, lists, etc. are returned as-is
async def translate_json(request: TranslationRequest):
    """Validate the requested model/language pair, then translate ``request.data``.

    String values in ``request.data`` (including inside nested dicts) are
    translated via ``recursive_translate``; other values come back unchanged.

    Raises:
        HTTPException: status 400 for an unknown model, a language code outside
            the M2M100 allow-list, or an Opus-MT pair other than en->ko / ko->en.

    NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
    coroutine, so as shown it is never registered with ``app``. Confirm it is
    mounted elsewhere.
    NOTE(review): the translation calls are synchronous and are not awaited, so
    they run on the event loop and block other requests while in progress.
    NOTE(review): the Korean error messages appear mis-encoded (mojibake). They are
    kept byte-for-byte here; confirm the file's encoding.
    """
    model_type = request.model  # "m2m100" or "opus-mt"
    src_lang = request.from_lang
    tgt_lang = request.to
    input_data = request.data

    # Language codes accepted for M2M100 (the model itself covers many more).
    supported_langs = {"ko", "en", "fr", "es", "ja", "zh", "de", "it"}
    # Exact directions Opus-MT can handle. Checking the pair (not each code
    # separately) rejects en->en / ko->ko here instead of failing later in
    # translate_opus_mt with a different message.
    opus_mt_pairs = {("en", "ko"), ("ko", "en")}

    # Model selection / validation
    if model_type == "m2m100":
        if src_lang not in supported_langs or tgt_lang not in supported_langs:
            raise HTTPException(status_code=400, detail=f"μ§μλμ§ μλ μΈμ΄ μ½λ: {src_lang} β {tgt_lang}")
    elif model_type == "opus-mt":
        if (src_lang, tgt_lang) not in opus_mt_pairs:
            raise HTTPException(status_code=400, detail="Opus-MT λͺ¨λΈμ μμ΄ β νκ΅μ΄ λ²μλ§ μ§μν©λλ€.")
    else:
        raise HTTPException(status_code=400, detail="μ§μλμ§ μλ λͺ¨λΈ μ ν")

    # Translate the JSON payload recursively.
    return recursive_translate(input_data, src_lang, tgt_lang, model_type)