from functools import lru_cache

from langchain.llms.base import LLM
from langchain.memory import ConversationBufferWindowMemory
from langchain.schema.messages import get_buffer_string
from transformers import GPT2TokenizerFast
@lru_cache(maxsize=1)
def _get_tokenizer():
    """Load the GPT-2 fast tokenizer once; loading it is expensive, so cache it."""
    return GPT2TokenizerFast.from_pretrained("gpt2")


def get_num_tokens(text):
    """Return the number of GPT-2 tokens in *text*.

    The original re-instantiated the tokenizer on every call; the cached
    helper loads it exactly once per process.
    """
    return len(_get_tokenizer().tokenize(text))
def get_memory_num_tokens(memory):
    """Return the total GPT-2 token count of all messages held in *memory*.

    Each message is rendered to text with ``get_buffer_string`` and counted
    individually; the per-message totals are summed.
    """
    # Generator instead of an intermediate list — same result, no throwaway list.
    return sum(get_num_tokens(get_buffer_string([m]))
               for m in memory.chat_memory.messages)
def validate_memory_len(memory, max_token_limit=2000):
    """Trim *memory* in place until its token count is <= *max_token_limit*.

    Oldest messages are dropped first. Returns the (mutated) memory object,
    matching the original contract.

    Fixes vs. original:
    - removed the redundant ``if`` that duplicated the ``while`` condition;
    - O(n^2) -> O(n) tokens counted: instead of re-counting the whole buffer
      after every pop, subtract the popped message's own token count;
    - guard on an empty buffer so the loop can never spin forever.
    """
    buffer = memory.chat_memory.messages
    curr_buffer_length = get_memory_num_tokens(memory)
    while buffer and curr_buffer_length > max_token_limit:
        dropped = buffer.pop(0)
        curr_buffer_length -= get_num_tokens(get_buffer_string([dropped]))
    return memory
if __name__ == '__main__':
    # Smoke test: count tokens via the module's own helper instead of
    # re-loading the tokenizer and re-implementing the count inline.
    text = '''Hi'''
    print(get_num_tokens(text))