Spaces:
Running
Running
| import gradio as gr | |
| from transformers import pipeline, MarianMTModel, AutoTokenizer | |
| import os | |
| import azure.cognitiveservices.speech as speechsdk | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
# Dialect display name -> one-letter tag. translate_english prefixes the English
# input with each tag so that one batch yields one translation per dialect.
dialects = {"Palestinian/Jordanian": "P", "Syrian": "S", "Lebanese": "L", "Egyptian": "E"}

# Hugging Face model repositories loaded below.
_EN2AR_REPO = "guymorlan/English2Dialect"
_AR2EN_REPO = "guymorlan/Shami2English"
_TRANSLIT_REPO = "guymorlan/DialectTransliterator"

# English -> Arabic dialect model. output_attentions=True exposes the
# cross-attentions that align_words reads when colorizing output.
translator_en2ar = MarianMTModel.from_pretrained(_EN2AR_REPO, output_attentions=True)
tokenizer_en2ar = AutoTokenizer.from_pretrained(_EN2AR_REPO)

# Levantine Arabic -> English model, configured the same way.
translator_ar2en = MarianMTModel.from_pretrained(_AR2EN_REPO, output_attentions=True)
tokenizer_ar2en = AutoTokenizer.from_pretrained(_AR2EN_REPO)

# Arabic text -> pronunciation (shown as "Palestinian Pronunciation (Urban)" in the UI).
transliterator = pipeline(task="translation", model=_TRANSLIT_REPO)

# Azure Speech configuration; credentials are read from the environment
# (os.environ.get yields None when a variable is unset).
speech_config = speechsdk.SpeechConfig(
    subscription=os.environ.get("SPEECH_KEY"),
    region=os.environ.get("SPEECH_REGION"),
)
def generate_diverging_colors(num_colors, palette='Set3'):
    """Return ``num_colors`` hex color codes sampled from a matplotlib colormap.

    Args:
        num_colors: How many colors to return. Values <= 0 yield an empty list.
        palette: Name of a registered matplotlib colormap (e.g. "Set2", "Set3").

    Returns:
        A list of 6-digit lowercase hex strings without a leading "#"
        (callers prepend it when building CSS).
    """
    if num_colors <= 0:
        return []
    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # pyplot.get_cmap(name, lut) is the supported equivalent and resamples the
    # colormap into `num_colors` entries.
    cmap = plt.get_cmap(palette, num_colors)
    # One RGBA row per color, each component in [0, 1].
    colors_rgba = cmap(np.arange(num_colors))
    return [
        f"{int(r * 255):02x}{int(g * 255):02x}{int(b * 255):02x}"
        for r, g, b, _alpha in colors_rgba
    ]
# SentencePiece word-boundary marker ("▁", U+2581). Written as an escape so the
# comparison survives re-encoding of this file.
_SP_SPACE = "\u2581"
# Target tokens that never take part in word alignment or rendering.
_ALIGN_SPECIAL_TOKENS = frozenset({"</s>", "<pad>", "<unk>"})


def _merge_alignment(attention, tgt_tokens, threshold):
    """Group target positions into word-level entries with their linked source positions.

    Returns a list of ``[target_positions, source_positions]`` pairs.
    """
    merged = []
    for tgt_idx, row in enumerate(attention):
        token = tgt_tokens[tgt_idx]
        # The previous check compared this token string against token *ids*,
        # so special tokens were never actually skipped.
        if token in _ALIGN_SPECIAL_TOKENS:
            continue
        # Source positions whose attention weight exceeds the threshold.
        linked = (row > threshold).nonzero().squeeze(-1).tolist()
        if merged:
            prev_tgt, prev_src = merged[-1]
            # Join the previous entry when this piece continues a word (no
            # leading "▁") or attends to a source position it already covers.
            if not token.startswith(_SP_SPACE) or any(s in prev_src for s in linked):
                prev_tgt.append(tgt_idx)
                prev_src.extend(linked)
                continue
        merged.append([[tgt_idx], linked])
    return merged


def _assign_colors(merged):
    """Map ("tgt", i) / ("src", j) keys to color indices so linked tokens share one.

    Returns ``(key -> color index, number of distinct colors)``.
    """
    color_of = {}
    ncolors = 0
    for tgt_positions, src_positions in merged:
        keys = [("tgt", t) for t in tgt_positions] + [("src", s) for s in src_positions]
        # Reuse the color of the first key already colored by an earlier entry.
        color = next((color_of[k] for k in keys if k in color_of), None)
        # `is None`, not falsiness: color index 0 is a valid, assigned color.
        if color is None:
            color = ncolors
            ncolors += 1
        for k in keys:
            color_of.setdefault(k, color)
    return color_of, ncolors


def _render_tokens(tokens, side, colordict, skip_first=False):
    """Render tokens as one 30px HTML span, coloring each by its alignment group."""
    import html

    pieces = []
    for idx, token in enumerate(tokens):
        if (skip_first and idx == 0) or token in _ALIGN_SPECIAL_TOKENS:
            continue
        color = colordict.get((side, idx))
        if color is not None:
            style = f"color: #{color}"
        else:
            style = "color: var(--color-text-body)"
        # Tokens come from user text / model output: escape before embedding.
        text = html.escape(token, quote=False).replace(_SP_SPACE, " ")
        pieces.append(f"<span style='{style}'>{text}</span>")
    body = "".join(pieces)
    return f"<span style='font-size: 30px'>{body}</span>"


def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids, threshold=0.4,
                skip_first_src=True, layer=2, head=7):
    """Render source and target tokens as HTML, coloring aligned words alike.

    Alignment is read from a single cross-attention head. Each target
    (decoder) token is linked to every source (encoder) position whose
    weight exceeds ``threshold``. Adjacent target pieces are merged when
    they continue a word or share a source position. Each merged group,
    together with its linked source tokens, gets one color.

    Args:
        outputs: Output of a forward pass run with attentions enabled.
            ``outputs.cross_attentions[layer][0][head]`` is indexed as
            [target position][source position], following the transformers
            (batch, heads, tgt_len, src_len) convention. Batch item 0 is used.
        tokenizer: Tokenizer used to turn ids back into token strings.
        encoder_input_ids: Source ids; only row 0 is used.
        decoder_input_ids: Target ids; only row 0 is used.
        threshold: Attention weight above which a source position is linked.
        skip_first_src: Omit the first source token from the output HTML.
            translate_english relies on this to hide the dialect tag it
            prepends.
        layer, head: Which cross-attention layer and head to read. The
            defaults match the values that were previously hard-coded.

    Returns:
        ``(srchtml, tgthtml)``: HTML for the source and target tokens.
        Special tokens (</s>, <pad>, <unk>) are omitted.
    """
    src_tokens = tokenizer.convert_ids_to_tokens(encoder_input_ids[0].tolist())
    tgt_tokens = tokenizer.convert_ids_to_tokens(decoder_input_ids[0].tolist())

    attention = outputs.cross_attentions[layer][0][head]
    merged = _merge_alignment(attention, tgt_tokens, threshold)

    color_ids, ncolors = _assign_colors(merged)
    palette = generate_diverging_colors(ncolors, palette="Set2") if ncolors else []
    colordict = {key: palette[cid] for key, cid in color_ids.items()}

    srchtml = _render_tokens(src_tokens, "src", colordict, skip_first=skip_first_src)
    tgthtml = _render_tokens(tgt_tokens, "tgt", colordict)
    return srchtml, tgthtml
def translate_english(input_text, include):
    """Translate English text into the selected Arabic dialects.

    Palestinian/Jordanian is always produced. Syrian, Lebanese and Egyptian
    are added when "SYR", "LEB" or "EGY" are in ``include``. When "Colorize"
    is in ``include``, the Palestinian result is rendered as HTML in which
    aligned source and target words share a color.

    Args:
        input_text: English text to translate.
        include: Selected option labels from the UI checkbox group.

    Returns:
        ``(palhtml, pal_out, sy_out, lb_out, eg_out)``. Dialects that were not
        requested are "". All five values are "" for empty input.
    """
    import html

    if not input_text:
        return "", "", "", "", ""
    sy, lb, eg = "SYR" in include, "LEB" in include, "EGY" in include

    # One source sentence per requested dialect, each prefixed with its tag
    # from `dialects` (Palestinian is always included). Filtering keeps the
    # batch in P, S, L, E order, which the decoded[...] indexing below relies
    # on. The previous pop()-from-the-end approach removed the wrong entries:
    # e.g. with only EGY enabled, the "Egyptian" output was really Syrian.
    enabled = (True, sy, lb, eg)
    inputs = [f"{code} {input_text}" for code, on in zip(dialects.values(), enabled) if on]

    # padding=True keeps batching safe should the tagged inputs ever tokenize
    # to different lengths; the attention mask is passed along to generate().
    batch = tokenizer_en2ar(inputs, return_tensors="pt", padding=True)
    input_tokens = batch.input_ids
    outputs = translator_en2ar.generate(**batch)
    decoded = tokenizer_en2ar.batch_decode(outputs, skip_special_tokens=True)

    pal_out = decoded[0]
    sy_out = decoded[1] if sy else ""
    lb_out = decoded[1 + sy] if lb else ""
    eg_out = decoded[1 + sy + lb] if eg else ""

    if "Colorize" not in include:
        # Model output is derived from user text: escape it before embedding in HTML.
        palhtml = f"<div style='font-size: 30px; direction: rtl'>{html.escape(pal_out)}</div>"
        return palhtml, pal_out, sy_out, lb_out, eg_out

    # Re-run the Palestinian pair through the model to obtain its
    # cross-attentions. Any padding in that row is dropped so it is neither
    # attended to nor rendered.
    encoder_input_ids = input_tokens[0][batch.attention_mask[0].bool()].unsqueeze(0)
    decoder_input_ids = outputs[0].unsqueeze(0)
    html_outputs = translator_en2ar(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)

    # The alignment threshold shrinks as the (padded) input gets longer.
    seq_len = input_tokens.shape[1]
    if seq_len < 10:
        threshold = 0.4
    elif seq_len < 20:
        threshold = 0.10
    else:
        threshold = 0.05

    srchtml, tgthtml = align_words(html_outputs, tokenizer_en2ar, encoder_input_ids, decoder_input_ids, threshold)
    palhtml = f"{srchtml}<br><br><div style='direction: rtl'>{tgthtml}</div>"
    return palhtml, pal_out, sy_out, lb_out, eg_out
def translate_arabic(input_text, include=("Colorize",)):
    """Translate Levantine Arabic text into English and return it as HTML.

    Args:
        input_text: Arabic source text.
        include: Option labels. When "Colorize" is present, the output shows
            source and target tokens with aligned words sharing a color.
            Otherwise it shows the plain translation. An immutable tuple
            default avoids the shared-mutable-default pitfall.

    Returns:
        An HTML string, or "" for empty input.
    """
    import html

    if not input_text:
        return ""
    input_tokens = tokenizer_ar2en(input_text, return_tensors="pt").input_ids
    outputs = translator_ar2en.generate(input_tokens)

    if "Colorize" not in include:
        # Decode with tokenizer_ar2en, the tokenizer paired with
        # translator_ar2en (this previously used tokenizer_en2ar by mistake).
        decoded = tokenizer_ar2en.batch_decode(outputs, skip_special_tokens=True)
        return f"<div style='font-size: 30px;'>{html.escape(decoded[0])}</div>"

    # Re-run the pair through the model to obtain cross-attentions for alignment.
    encoder_input_ids = input_tokens[0].unsqueeze(0)
    decoder_input_ids = outputs[0].unsqueeze(0)
    html_outputs = translator_ar2en(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)

    # NOTE(review): these thresholds are non-monotonic (0.1, 0.01, 0.05),
    # unlike the steadily decreasing ones in translate_english. Possibly a
    # typo; the values are kept as-is pending confirmation.
    seq_len = input_tokens.shape[1]
    if seq_len < 20:
        threshold = 0.1
    elif seq_len < 30:
        threshold = 0.01
    else:
        threshold = 0.05

    # The source has no prepended tag here, so keep its first token.
    srchtml, tgthtml = align_words(html_outputs, tokenizer_ar2en, encoder_input_ids, decoder_input_ids, threshold, skip_first_src=False)
    return f"<div style='direction: rtl'>{srchtml}</div><br><br><div>{tgthtml}</div>"
def get_audio(input_text):
    """Synthesize ``input_text`` as speech with Azure TTS and return the .wav path.

    Uses the module-level ``speech_config`` with the ``ar-SY-AmanyNeural`` voice.

    Args:
        input_text: Arabic text to speak. The UI passes the Palestinian translation.

    Returns:
        Path to a newly created temporary .wav file, or None when
        ``input_text`` is empty (there is nothing to synthesize).

    Raises:
        RuntimeError: If Azure reports that synthesis did not complete.
    """
    import tempfile

    if not input_text:
        return None
    # Write to a fresh temporary file rather than f"{input_text}.wav":
    # - the text is user-controlled and may contain "/" or "..";
    # - it may exceed the filesystem's name-length limit;
    # - two requests with the same text would otherwise share one file.
    fd, wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # only the unique name is needed; the SDK writes to it
    audio_config = speechsdk.audio.AudioOutputConfig(filename=wav_path)
    speech_config.speech_synthesis_voice_name = 'ar-SY-AmanyNeural'
    synthesizer = speechsdk.SpeechSynthesizer(speech_config=speech_config, audio_config=audio_config)
    result = synthesizer.speak_text_async(input_text).get()
    # Surface failures (e.g. cancellation) instead of returning an empty file.
    if result.reason != speechsdk.ResultReason.SynthesizingAudioCompleted:
        raise RuntimeError(f"Azure speech synthesis did not complete: {result.reason}")
    return wav_path
def get_transliteration(input_text, include=("Translit.",)):
    """Return the transliterated pronunciation of Arabic text.

    Args:
        input_text: Arabic text. In the UI this is the Palestinian translation,
            which is cleared to "" for empty input.
        include: Option labels. Transliteration runs only when "Translit." is
            present. The tuple default avoids a shared mutable default.

    Returns:
        The transliterator's ``translation_text``. Returns "" when the option
        is disabled or the input is empty, so the model is not run on empty text.
    """
    if "Translit." not in include or not input_text:
        return ""
    result = transliterator([input_text])
    return result[0]["translation_text"]
# NOTE(review): `bla` is not referenced anywhere in this file; candidate for removal.
bla = """
"""

# Custom CSS passed to gr.Blocks:
# - a 25px font for the #liter and #trans textareas;
# - right-to-left direction for #trans (Arabic) textareas;
# - #2563eb / gradient styling for secondary buttons in light (:root) and dark themes.
# The :root block used to declare --button-secondary-background-focus twice.
# The first value (#2563eb) was always overridden by the later rgb(...) one,
# so the dead declaration was dropped.
css = """
#liter textarea, #trans textarea { font-size: 25px;}
#trans textarea { direction: rtl; }
#check { border-style: none !important; }
:root {--button-secondary-background-base: #2563eb !important;
--button-secondary-background-hover: linear-gradient(to bottom right, #0692e8, #5859c2);
--button-secondary-text-color-base: white !important;
--button-secondary-text-color-hover: white !important;
--button-secondary-background-focus: rgb(51 122 216 / 70%) !important;
--button-secondary-text-color-focus: white !important}
.dark {--button-secondary-background-base: #2563eb !important;
--button-secondary-background-focus: rgb(51 122 216 / 70%) !important;
--button-secondary-background-hover: linear-gradient(to bottom right, #0692e8, #5859c2)}
.feather-music { stroke: #2563eb; }
"""
def toggle_visibility(include):
    """Show or hide the optional output textboxes to match the checked options.

    Returns four ``gr.Textbox.update`` objects, in this order: transliteration,
    Syrian, Lebanese, Egyptian. This matches the outputs of the
    ``include.change`` wiring.
    """
    option_order = ("Translit.", "SYR", "LEB", "EGY")
    return [gr.Textbox.update(visible=option in include) for option in option_order]
# Gradio UI with three tabs (En > Ar, Ar > En, Transliterate), styled with the
# module-level `css` string.
# NOTE(review): elem_id values "trans", "main" and "liter" are reused across
# several components. HTML ids should be unique; getElementById('main') in the
# _js hook below returns only the first match.
| with gr.Blocks(title = "Levantine Arabic Translator", css=css, theme="default") as demo: | |
| gr.HTML("<h2><span style='color: #2563eb'>Levantine Arabic</span> Translator</h2>") | |
# --- En > Ar tab: English input, feature checkboxes, dialect outputs, TTS audio ---
| with gr.Tab('En > Ar'): | |
| with gr.Row(): | |
| with gr.Column(): | |
| input_text = gr.Textbox(label="Input", placeholder="Enter English text", lines=2) | |
| gr.Examples(["I wanted to go to the store yesterday, but it rained", "How are you feeling today?"], input_text) | |
| btn = gr.Button("Translate", label="Translate") | |
| with gr.Row(): | |
# Checked options are passed as `include` to translate_english, get_transliteration
# and toggle_visibility. The default value enables transliteration, Egyptian and colorizing.
| include = gr.CheckboxGroup(["Translit.", "SYR", "LEB", "EGY", "Colorize"], | |
| label="Disable features to speed up translation", | |
| value=["Translit.", "EGY", "Colorize"]) | |
| gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il). Pronunciation model is specifically tailored to urban Palestinian Arabic. Text-to-speech uses Microsoft Azure's API and may provide different result from the transliterated pronunciation.") | |
| with gr.Column(): | |
| with gr.Box(label = "Palestinian"): | |
| gr.Markdown("Palestinian") | |
| with gr.Box(): | |
# pal_html displays the (optionally colorized) Palestinian translation. `pal` is a
# hidden textbox holding the plain text; it feeds transliteration and the audio button.
| pal_html = gr.HTML("<br>", visible=True, label="Palestinian", elem_id="main") | |
| pal = gr.Textbox(lines=1, label="Palestinian", elem_id="trans", visible=False) | |
| pal_translit = gr.Textbox(lines=1, label="Palestinian Pronunciation (Urban)", elem_id="liter") | |
# Initial visibility matches the default checkbox value: SYR/LEB hidden, EGY shown.
| sy = gr.Textbox(lines=1, label="Syrian", elem_id="trans", visible=False) | |
| lb = gr.Textbox(lines=1, label="Lebanese", elem_id="trans", visible=False) | |
| eg = gr.Textbox(lines=1, label="Egyptian", elem_id="trans") | |
| # with gr.Row(): | |
| audio = gr.Audio(label="Audio - Palestinian", interactive=False) | |
| audio_button = gr.Button("Get Audio", label="Click Here to Get Audio") | |
# Synthesizes speech from the hidden plain-text Palestinian translation.
| audio_button.click(get_audio, inputs=[pal], outputs=[audio]) | |
# The _js hook scrolls the element with id "main" into view and returns its two
# arguments unchanged. The button and Enter-submit share the same handler and outputs.
| btn.click(translate_english,inputs=[input_text, include], outputs=[pal_html, pal, sy, lb, eg], api_name="en2ar", _js="function jump(x, y){document.getElementById('main').scrollIntoView(); return [x, y];}") | |
| input_text.submit(translate_english, inputs=[input_text, include], outputs=[pal_html, pal, sy, lb, eg],scroll_to_output=True) | |
# Re-transliterate whenever the hidden Palestinian text changes (get_transliteration
# returns "" when "Translit." is unchecked).
| pal.change(get_transliteration, inputs=[pal, include], outputs=[pal_translit]); | |
| include.change(toggle_visibility, inputs=[include], outputs=[pal_translit, sy, lb, eg]) | |
# --- Ar > En tab ---
| with gr.Tab('Ar > En'): | |
| with gr.Row(): | |
| with gr.Column(): | |
| input_text = gr.Textbox(label="Input", placeholder="Enter Levantine Arabic text", lines=1, elem_id="trans") | |
# NOTE(review): these example strings look mis-encoded (UTF-8 Arabic decoded as a
# Thai code page). They most likely should read "خلينا ندور على مطعم تاني" and
# "قديش حق البندورة؟". Verify against the original file.
| gr.Examples(["ุฎูููุง ูุฏูุฑ ุนูู ู ุทุนู ุชุงูู", "ูุฏูุด ุญู ุงูุจูุฏูุฑุฉุ"], input_text) | |
| btn = gr.Button("Translate", label="Translate") | |
| gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il).") | |
| with gr.Column(): | |
| with gr.Box(label = "English"): | |
| gr.Markdown("English") | |
| with gr.Box(): | |
| eng = gr.HTML("<br>", label="English", elem_id="main") | |
# Only input_text is wired, so translate_arabic always runs with its default
# include (colorized output).
| btn.click(translate_arabic,inputs=input_text, outputs=[eng], api_name = "ar2en") | |
# --- Transliterate tab ---
| with gr.Tab("Transliterate"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| input_text = gr.Textbox(label="Input", placeholder="Enter Levantine Arabic text", lines=1) | |
# NOTE(review): same apparently mis-encoded Arabic examples as in the Ar > En tab.
| gr.Examples(["ุฎูููุง ูุฏูุฑ ุนูู ู ุทุนู ุชุงูู", "ูุฏูุด ุญู ุงูุจูุฏูุฑุฉุ"], input_text) | |
| btn = gr.Button("Transliterate", label="Transliterate") | |
| gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il)") | |
| with gr.Column(): | |
| translit = gr.Textbox(label="Transliteration", lines=1, elem_id="liter") | |
# Only input_text is wired, so get_transliteration uses its default include
# and always transliterates.
| btn.click(get_transliteration, inputs=input_text, outputs=[translit]) | |
# NOTE(review): launches the server at module import time (no __main__ guard).
| demo.launch() |