Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,7 +10,8 @@ dialects = {"Palestinian/Jordanian": "P", "Syrian": "S", "Lebanese": "L", "Egypt
|
|
| 10 |
# translator_en2ar = pipeline(task="translation", model="guymorlan/English2Dialect")
|
| 11 |
translator_en2ar = MarianMTModel.from_pretrained("guymorlan/English2Dialect", output_attentions=True)
|
| 12 |
tokenizer_en2ar = AutoTokenizer.from_pretrained("guymorlan/English2Dialect")
|
| 13 |
-
translator_ar2en =
|
|
|
|
| 14 |
transliterator = pipeline(task="translation", model="guymorlan/DialectTransliterator")
|
| 15 |
|
| 16 |
speech_config = speechsdk.SpeechConfig(subscription=os.environ.get('SPEECH_KEY'), region=os.environ.get('SPEECH_REGION'))
|
|
@@ -28,7 +29,7 @@ def generate_diverging_colors(num_colors, palette='Set3'): # courtesy of ChatGPT
|
|
| 28 |
return colors_hex
|
| 29 |
|
| 30 |
|
| 31 |
-
def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids, threshold=0.4):
|
| 32 |
alignment = []
|
| 33 |
for i, tok in enumerate(outputs.cross_attentions[2][0][7]):
|
| 34 |
alignment.append([[i], (tok > threshold).nonzero().squeeze(-1).tolist()])
|
|
@@ -93,7 +94,7 @@ def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids, thresh
|
|
| 93 |
|
| 94 |
srchtml = []
|
| 95 |
for i, token in enumerate(encoder_input_ids[0]):
|
| 96 |
-
if i == 0:
|
| 97 |
continue
|
| 98 |
if f"trg_{i}" in colordict:
|
| 99 |
label = f"trg_{i}"
|
|
@@ -158,13 +159,42 @@ def translate_english(input_text, include):
|
|
| 158 |
|
| 159 |
return palhtml, pal_out, sy_out, lb_out, eg_out
|
| 160 |
|
| 161 |
-
def translate_arabic(input_text):
|
| 162 |
if not input_text:
|
| 163 |
return ""
|
| 164 |
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
def get_audio(input_text):
|
| 170 |
audio_config = speechsdk.audio.AudioOutputConfig(filename=f"{input_text}.wav")
|
|
@@ -244,6 +274,7 @@ with gr.Blocks(title = "Levantine Arabic Translator", css=css, theme="default")
|
|
| 244 |
input_text.submit(translate_english, inputs=[input_text, include], outputs=[pal_html, pal, sy, lb, eg],scroll_to_output=True)
|
| 245 |
pal.change(get_transliteration, inputs=[pal, include], outputs=[pal_translit]);
|
| 246 |
include.change(toggle_visibility, inputs=[include], outputs=[pal_translit, sy, lb, eg])
|
|
|
|
| 247 |
with gr.Tab('Ar > En'):
|
| 248 |
with gr.Row():
|
| 249 |
with gr.Column():
|
|
@@ -252,8 +283,12 @@ with gr.Blocks(title = "Levantine Arabic Translator", css=css, theme="default")
|
|
| 252 |
btn = gr.Button("Translate", label="Translate")
|
| 253 |
gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il).")
|
| 254 |
with gr.Column():
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
| 256 |
btn.click(translate_arabic,inputs=input_text, outputs=[eng])
|
|
|
|
| 257 |
with gr.Tab("Transliterate"):
|
| 258 |
with gr.Row():
|
| 259 |
with gr.Column():
|
|
|
|
| 10 |
# translator_en2ar = pipeline(task="translation", model="guymorlan/English2Dialect")
|
| 11 |
translator_en2ar = MarianMTModel.from_pretrained("guymorlan/English2Dialect", output_attentions=True)
|
| 12 |
tokenizer_en2ar = AutoTokenizer.from_pretrained("guymorlan/English2Dialect")
|
| 13 |
+
translator_ar2en = MarianMTModel.from_pretrained("guymorlan/Shami2English", output_attentions=True)
|
| 14 |
+
tokenizer_ar2en = AutoTokenizer.from_pretrained("guymorlan/Shami2English")
|
| 15 |
transliterator = pipeline(task="translation", model="guymorlan/DialectTransliterator")
|
| 16 |
|
| 17 |
speech_config = speechsdk.SpeechConfig(subscription=os.environ.get('SPEECH_KEY'), region=os.environ.get('SPEECH_REGION'))
|
|
|
|
| 29 |
return colors_hex
|
| 30 |
|
| 31 |
|
| 32 |
+
def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids, threshold=0.4, skip_first_src=True):
|
| 33 |
alignment = []
|
| 34 |
for i, tok in enumerate(outputs.cross_attentions[2][0][7]):
|
| 35 |
alignment.append([[i], (tok > threshold).nonzero().squeeze(-1).tolist()])
|
|
|
|
| 94 |
|
| 95 |
srchtml = []
|
| 96 |
for i, token in enumerate(encoder_input_ids[0]):
|
| 97 |
+
if skip_first_src and i == 0:
|
| 98 |
continue
|
| 99 |
if f"trg_{i}" in colordict:
|
| 100 |
label = f"trg_{i}"
|
|
|
|
| 159 |
|
| 160 |
return palhtml, pal_out, sy_out, lb_out, eg_out
|
| 161 |
|
| 162 |
+
def translate_arabic(input_text, include=None):
    """Translate Arabic text to English and return the result as HTML.

    Args:
        input_text: Arabic source text. Falsy input (empty string or None)
            short-circuits and nothing is sent to the model.
        include: Collection of option labels. When it contains "Colorize",
            the result is the source/target pair of HTML fragments built by
            ``align_words``; otherwise only the plain translation is
            returned. Defaults to ``["Colorize"]`` when omitted or None.

    Returns:
        An HTML snippet as a string, or "" when there is nothing to translate.
    """
    if not input_text:
        return ""
    if include is None:
        # Resolved per call instead of sharing one mutable default list.
        include = ["Colorize"]

    input_tokens = tokenizer_ar2en(input_text, return_tensors="pt").input_ids
    outputs = translator_ar2en.generate(input_tokens)

    if "Colorize" not in include:
        import html

        # Decode with the tokenizer that belongs to the Arabic->English
        # model; the English->Arabic tokenizer is paired with a different
        # model and is not the right one for these output ids.
        decoded = tokenizer_ar2en.batch_decode(outputs, skip_special_tokens=True)
        # The translation is generated text, so escape markup characters
        # before embedding it in the HTML component.
        plain_text = html.escape(decoded[0], quote=False)
        return f"<div style='font-size: 30px;'>{plain_text}</div>"

    # Keep only the first sequence, with the batch dimension restored.
    encoder_input_ids = input_tokens[0].unsqueeze(0)
    decoder_input_ids = outputs[0].unsqueeze(0)

    # Forward pass over the generated sequence; align_words reads the
    # cross-attention weights from this output object.
    html_outputs = translator_ar2en(
        input_ids=encoder_input_ids,
        decoder_input_ids=decoder_input_ids,
    )

    # Attention threshold chosen from the number of source tokens.
    # NOTE(review): the values are not monotonic in length
    # (0.1 -> 0.01 -> 0.05); kept as-is, confirm this is intended.
    num_src_tokens = input_tokens.shape[1]
    if num_src_tokens < 20:
        threshold = 0.1
    elif num_src_tokens < 30:
        threshold = 0.01
    else:
        threshold = 0.05

    # skip_first_src=False: do not drop the first source token here.
    srchtml, tgthtml = align_words(
        html_outputs,
        tokenizer_ar2en,
        encoder_input_ids,
        decoder_input_ids,
        threshold,
        skip_first_src=False,
    )
    return f"<div style='direction: rtl'>{srchtml}</div><br><br><div>{tgthtml}</div>"
|
| 198 |
|
| 199 |
def get_audio(input_text):
|
| 200 |
audio_config = speechsdk.audio.AudioOutputConfig(filename=f"{input_text}.wav")
|
|
|
|
| 274 |
input_text.submit(translate_english, inputs=[input_text, include], outputs=[pal_html, pal, sy, lb, eg],scroll_to_output=True)
|
| 275 |
pal.change(get_transliteration, inputs=[pal, include], outputs=[pal_translit]);
|
| 276 |
include.change(toggle_visibility, inputs=[include], outputs=[pal_translit, sy, lb, eg])
|
| 277 |
+
|
| 278 |
with gr.Tab('Ar > En'):
|
| 279 |
with gr.Row():
|
| 280 |
with gr.Column():
|
|
|
|
| 283 |
btn = gr.Button("Translate", label="Translate")
|
| 284 |
gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il).")
|
| 285 |
with gr.Column():
|
| 286 |
+
with gr.Box(label = "English"):
|
| 287 |
+
gr.Markdown("English")
|
| 288 |
+
with gr.Box():
|
| 289 |
+
eng = gr.HTML("<br>", label="English", elem_id="main")
|
| 290 |
btn.click(translate_arabic,inputs=input_text, outputs=[eng])
|
| 291 |
+
|
| 292 |
with gr.Tab("Transliterate"):
|
| 293 |
with gr.Row():
|
| 294 |
with gr.Column():
|