Spaces:
Configuration error
Configuration error
| from collections import namedtuple | |
| from functools import cache, cached_property | |
| from io import BytesIO | |
| from os import environ | |
| from os.path import isfile, join | |
| from re import MULTILINE, escape, search, sub | |
| from subprocess import CalledProcessError, DEVNULL, TimeoutExpired | |
| from tempfile import NamedTemporaryFile, TemporaryDirectory | |
| from typing import Optional, Union | |
| import warnings | |
| from PIL import Image, ImageOps | |
| import requests | |
| import torch | |
| from torch.cuda import current_device, is_available as has_cuda | |
| from transformers import TextGenerationPipeline as TGP, TextStreamer, pipeline, ImageToTextPipeline as ITP | |
| from transformers.utils import logging | |
| from transformers.utils.hub import is_remote_url | |
| from pdf2image.pdf2image import convert_from_bytes | |
| from pdfCropMargins import crop | |
| import fitz | |
| logger = logging.get_logger("transformers") | |
| from os import killpg, getpgid | |
| from subprocess import Popen, TimeoutExpired, CalledProcessError, CompletedProcess, PIPE | |
| from signal import SIGKILL | |
def run(*popenargs, input=None, timeout=None, check=False, **kwargs):
    """
    Run a command like `subprocess.run`, but kill its whole process group on failure.

    The child is started in its own session (`start_new_session=True`), so on
    timeout or any other exception SIGKILL is sent to the child's process
    group instead of only to the direct child.

    Args:
        *popenargs: positional arguments forwarded to `Popen`.
        input: data to send to the child's stdin, or None. When given, stdin
            is connected to a pipe automatically.
        timeout: seconds to wait for the child before killing its group.
        check: if true, raise `CalledProcessError` on a non-zero exit status.
        **kwargs: keyword arguments forwarded to `Popen`.

    Returns:
        A `CompletedProcess` with the arguments, return code, and whatever
        stdout/stderr were captured.

    Raises:
        ValueError: if both `input` and `stdin` are supplied.
        TimeoutExpired: if the child does not finish within `timeout`.
        CalledProcessError: if `check` is true and the exit status is non-zero.
    """
    if input is not None:
        if kwargs.get("stdin") is not None:
            raise ValueError("stdin and input arguments may not both be used.")
        # without a pipe on stdin, communicate() has nowhere to write `input`
        kwargs["stdin"] = PIPE

    def kill_process_group(process):
        try:
            killpg(getpgid(process.pid), SIGKILL)
        except ProcessLookupError:
            # the process is already gone; don't mask the exception that
            # brought us here with a lookup failure
            pass

    with Popen(*popenargs, start_new_session=True, **kwargs) as process:
        try:
            stdout, stderr = process.communicate(input, timeout=timeout)
        except TimeoutExpired:
            kill_process_group(process)
            process.wait()  # reap the child before propagating the timeout
            raise
        except BaseException:  # including KeyboardInterrupt; always re-raised
            kill_process_group(process)
            raise
        retcode = process.poll()
        if check and retcode:
            raise CalledProcessError(retcode, process.args,
                                     output=stdout, stderr=stderr)
    return CompletedProcess(process.args, retcode, stdout, stderr)  # type: ignore
def check_output(*popenargs, timeout=None, **kwargs):
    """
    Run a command through `run` and return what it wrote to stdout.

    stdout is captured via a pipe and `check=True` is always passed, so a
    non-zero exit status surfaces as `CalledProcessError`. All other arguments
    are forwarded to `run` unchanged.
    """
    completed = run(*popenargs, stdout=PIPE, timeout=timeout, check=True, **kwargs)
    return completed.stdout
class PdfDocument:
    """Thin wrapper around the raw bytes of a PDF file."""

    def __init__(self, raw: bytes):
        # the complete PDF file content, kept in memory
        self.raw = raw

    def save(self, filename):
        """Write the stored bytes to `filename`, replacing any existing file."""
        with open(filename, "wb") as handle:
            handle.write(self.raw)
| class TikzDocument: | |
| """ | |
| Faciliate some operations with TikZ code. To compile the images a full | |
| TeXLive installation is assumed to be on the PATH. Cropping additionally | |
| requires Ghostscript, and rasterization needs poppler (apart from the 'pdf' | |
| optional dependencies). | |
| """ | |
| # engines to try, could also try: https://tex.stackexchange.com/a/495999 | |
| engines = ["pdflatex", "lualatex", "xelatex"] | |
| Output = namedtuple("Output", ['pdf', 'status', 'log'], defaults=[None, -1, ""]) | |
| def __init__(self, code: str, timeout=120): | |
| self.code = code | |
| self.timeout = timeout | |
| def status(self) -> int: | |
| return self.compile().status | |
| def pdf(self) -> Optional[PdfDocument]: | |
| return self.compile().pdf | |
| def log(self) -> str: | |
| return self.compile().log | |
| def compiled_with_errors(self) -> bool: | |
| return self.status != 0 | |
| def has_content(self) -> bool: | |
| """true if we have an image that isn't empty""" | |
| return (img:=self.rasterize()) is not None and img.getcolors(1) is None | |
| def set_engines(cls, engines: Union[str, list]): | |
| cls.engines = [engines] if isinstance(engines, str) else engines | |
| def compile(self) -> "Output": | |
| output = dict() | |
| with TemporaryDirectory() as tmpdirname: | |
| with NamedTemporaryFile(dir=tmpdirname, buffering=0) as tmpfile: | |
| codelines = self.code.split("\n") | |
| # make sure we don't have page numbers in compiled pdf (for cropping) | |
| codelines.insert(1, r"{cmd}\AtBeginDocument{{{cmd}}}".format(cmd=r"\thispagestyle{empty}\pagestyle{empty}")) | |
| tmpfile.write("\n".join(codelines).encode()) | |
| try: | |
| # compile | |
| errorln, tmppdf, outpdf = 0, f"{tmpfile.name}.pdf", join(tmpdirname, "tikz.pdf") | |
| open(f"{tmpfile.name}.bbl", 'a').close() # some classes expect a bibfile | |
| def try_save_last_page(): | |
| try: | |
| doc = fitz.open(tmppdf) # type: ignore | |
| doc.select([len(doc)-1]) | |
| doc.save(outpdf) | |
| except: | |
| pass | |
| for engine in self.engines: | |
| try: | |
| check_output( | |
| cwd=tmpdirname, | |
| timeout=self.timeout, | |
| stderr=DEVNULL, | |
| env=environ | dict(max_print_line="1000"), # improve formatting of log | |
| args=["latexmk", "-f", "-nobibtex", "-norc", "-file-line-error", "-interaction=nonstopmode", f"-{engine}", tmpfile.name] | |
| ) | |
| except (CalledProcessError, TimeoutExpired) as proc: | |
| log = getattr(proc, "output", b'').decode(errors="ignore") | |
| error = search(rf'^{escape(tmpfile.name)}:(\d+):.+$', log, MULTILINE) | |
| # only update status and log if first error occurs later than in previous engine | |
| if (linenr:=int(error.group(1)) if error else 0) > errorln: | |
| errorln = linenr | |
| output.update(status=getattr(proc, 'returncode', -1), log=log) | |
| try_save_last_page() | |
| else: | |
| output.update(status=0, log='') | |
| try_save_last_page() | |
| break | |
| # crop | |
| croppdf = f"{tmpfile.name}.crop" | |
| crop(["-gsf", "-c", "gb", "-p", "0", "-a", "-1", "-o", croppdf, outpdf], quiet=True) | |
| if isfile(croppdf): | |
| with open(croppdf, "rb") as pdf: | |
| output['pdf'] = PdfDocument(pdf.read()) | |
| except (FileNotFoundError, NameError) as e: | |
| logger.error("Missing dependencies: " + ( | |
| "Install this project with the [pdf] feature name!" if isinstance(e, NameError) | |
| else "Did you install TeX Live?" | |
| )) | |
| except RuntimeError: # pdf error during cropping | |
| pass | |
| if output.get("status") == 0 and not output.get("pdf", None): | |
| logger.warning("Could compile document but something seems to have gone wrong during cropping!") | |
| return self.Output(**output) | |
| def rasterize(self, size=336, expand_to_square=True) -> Optional[Image.Image]: | |
| if self.pdf: | |
| image = convert_from_bytes(self.pdf.raw, size=size, single_file=True)[0] | |
| if expand_to_square: | |
| image = ImageOps.pad(image, (size, size), color='white') | |
| return image | |
| def save(self, filename: str, *args, **kwargs): | |
| match filename.split(".")[-1]: | |
| case "tex": content = self.code.encode() | |
| case "pdf": content = getattr(self.pdf, "raw", bytes()) | |
| case fmt if img := self.rasterize(*args, **kwargs): | |
| img.save(imgByteArr:=BytesIO(), format=fmt) | |
| content = imgByteArr.getvalue() | |
| case fmt: raise ValueError(f"Couldn't save with format '{fmt}'!") | |
| with open(filename, "wb") as f: | |
| f.write(content) | |
class TikzGenerator:
    """
    Generate TikZ code for an image with an image-to-text pipeline.

    Sampling settings given to the constructor become the default generation
    arguments; they can be overridden per call through keyword arguments to
    `generate`. Instances are callable, which is the same as calling `generate`.
    """

    def __init__(
        self,
        pipe: ITP,
        temperature: float = 0.8, # based on "a systematic evaluation of large language models of code"
        top_p: float = 0.95,
        top_k: int = 0,
        stream: bool = False,
        expand_to_square: bool = False,
        clean_up_output: bool = True,
    ):
        # NOTE(review): `stream` is accepted but not used anywhere in this class
        self.pipeline = pipe
        self.expand_to_square = expand_to_square
        self.clean_up_output = clean_up_output
        self.default_kwargs = {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "do_sample": True,
            "max_new_tokens": 1024,
        }

    def generate(self, image: Image.Image, **generate_kwargs):
        """
        Run the pipeline on `image` and wrap the generated text in a `TikzDocument`.

        `generate_kwargs` are merged over the defaults from the constructor.
        With `clean_up_output` enabled, leading whitespace introduced after
        trailing special tokens of the prompt is removed and known artifacts
        are patched.
        """
        prompt = "Assistant helps to write down the TikZ code for the user's image. USER: <image>\nWrite down the TikZ code to draw the diagram shown in the lol. ASSISTANT:"
        tokenizer = self.pipeline.tokenizer
        merged_kwargs = self.default_kwargs | generate_kwargs
        results = self.pipeline(image, prompt=prompt, generate_kwargs=merged_kwargs)
        text = results[0]["generated_text"] # type: ignore

        if self.clean_up_output:
            # walk the prompt tokens back to front and drop one leading
            # whitespace character per trailing special token, because
            # skip_special_tokens in pipeline adds unwanted prefix spaces if
            # the prompt ends with special tokens
            for token in reversed(tokenizer.tokenize(prompt)): # type: ignore
                if not text or not text[0].isspace() or token not in tokenizer.all_special_tokens: # type: ignore
                    break
                text = text[1:]

            # occasionally observed artifacts, as (pattern, replacement) pairs
            known_artifacts = [
                (r'\bamsop\b', 'amsopn'),
            ]
            for pattern, fixed in known_artifacts:
                text = sub(pattern, fixed, text) # type: ignore

        return TikzDocument(text)

    def __call__(self, *args, **kwargs):
        """Shorthand for `generate`."""
        return self.generate(*args, **kwargs)