""" |
|
|
Quick test script for B2NL v6.1.2 app functionality |
|
|
""" |
|
|
|
|
|
import sys |
|
|
from pathlib import Path |
|
|
import torch |
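
# Make the intelligent-tokenizer_v6.1.2 package importable from this script's location.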
parent_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, str(parent_dir / 'intelligent-tokenizer_v6.1.2'))

from core.unified_model import IntelligentTokenizerModelV61
from core.byte_tokenizer_v6 import ByteTokenizerV6


def test_model():
    device = torch.device('cpu')
    tokenizer = ByteTokenizerV6(max_seq_len=64)
    model = IntelligentTokenizerModelV61(vocab_size=260, max_seq_len=64).to(device)
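
    # Best checkpoint from the v6.1.2 compression-first training run.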
    checkpoint_path = parent_dir / 'intelligent-tokenizer_v6.1.2' / 'checkpoints' / 'v612_compression_first' / 'best_model.pt'

    if checkpoint_path.exists():
        print(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(str(checkpoint_path), map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"[OK] Loaded checkpoint: Epoch {checkpoint.get('epoch', 'N/A')}")
        model.eval()
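
        # Korean sample sentence: "Hello. The weather is nice today."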
        test_text = "안녕하세요. 오늘 날씨가 좋네요."
        print(f"\nTest text: {test_text}")
        byte_seq = list(test_text.encode('utf-8'))[:62]
        print(f"Bytes: {len(byte_seq)}")
        input_ids = torch.tensor([[tokenizer.BOS] + byte_seq + [tokenizer.EOS]], dtype=torch.long).to(device)
        if input_ids.size(1) < 64:
            padding = torch.full((1, 64 - input_ids.size(1)), tokenizer.PAD, dtype=torch.long).to(device)
            input_ids = torch.cat([input_ids, padding], dim=1)
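
        # Attend only to non-PAD positions.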
        attention_mask = (input_ids != tokenizer.PAD).float()
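
        # One forward pass; the model reconstructs its own input, so labels=input_ids.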
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids,
                epoch=233,
                use_cross_attention=True
            )

        print(f"\n[OK] Model outputs available: {list(outputs.keys())}")
        if 'eojeol_boundaries' in outputs:
            boundaries = torch.argmax(outputs['eojeol_boundaries'], dim=-1)[0]
            num_groups = torch.sum(boundaries == 1).item() + 1
            compression = len(byte_seq) / num_groups
            print(f"[OK] Compression: {len(byte_seq)} bytes -> {num_groups} tokens = {compression:.1f}:1")
            groups = []
            current_group = []
            boundaries_np = boundaries.cpu().numpy()

            for i in range(min(len(byte_seq), len(boundaries_np))):
                is_boundary = (i == 0) or (boundaries_np[i] == 1)

                if is_boundary and current_group:
                    try:
                        group_text = bytes(current_group).decode('utf-8', errors='replace')
                        groups.append(f"<{group_text}>")
                    except Exception:
                        groups.append(f"<{len(current_group)}B>")
                    current_group = []

                if i < len(byte_seq):
                    current_group.append(byte_seq[i])

            # Flush the final group after the loop.
            if current_group:
                try:
                    group_text = bytes(current_group).decode('utf-8', errors='replace')
                    groups.append(f"<{group_text}>")
                except Exception:
                    groups.append(f"<{len(current_group)}B>")

            print(f"[OK] Groups: {' '.join(groups)}")
        if 'encoder_hidden_states' in outputs:
            last_hidden = outputs['encoder_hidden_states'][-1] if isinstance(outputs['encoder_hidden_states'], tuple) else outputs['encoder_hidden_states']
            embeddings = last_hidden[0, 0, :20]
            emb_values = embeddings.cpu().numpy()
            print("\n[OK] Embeddings (first 20 dims):")
            for i in range(0, len(emb_values), 5):
                dims = emb_values[i:min(i + 5, len(emb_values))]
                dim_strs = [f'{v:7.4f}' for v in dims]
                print(f"  Dim {i:2d}-{min(i + 4, len(emb_values) - 1):2d}: [{', '.join(dim_strs)}]")
            print(f"\n  Stats - Mean: {emb_values.mean():.4f}, Std: {emb_values.std():.4f}, Min: {emb_values.min():.4f}, Max: {emb_values.max():.4f}")
        if 'logits' in outputs:
            pred_ids = outputs['logits'].argmax(dim=-1)[0]

            # Truncate at the first end-marking special token (ID 256 or 258).
            valid_length = 64
            for i in range(1, len(pred_ids)):
                if pred_ids[i] == 256 or pred_ids[i] == 258:
                    valid_length = i
                    break

            # Drop the BOS position and any remaining special-token IDs (>= 256).
            pred_ids = pred_ids[1:valid_length]
            pred_ids = pred_ids[pred_ids < 256]

            if len(pred_ids) > 0:
                try:
                    # tolist() gives Python ints in [0, 255]; calling bytes() on the
                    # raw int64 array would serialize eight bytes per element.
                    reconstructed = bytes(pred_ids.cpu().tolist()).decode('utf-8', errors='ignore')
                    print(f"\n[OK] Reconstructed: {reconstructed}")

                    # Character-level accuracy against the matching prefix of the original.
                    orig_text = test_text[:len(reconstructed)]
                    matches = sum(1 for o, r in zip(orig_text, reconstructed) if o == r)
                    accuracy = (matches / max(len(orig_text), 1)) * 100
                    print(f"[OK] Accuracy: {accuracy:.1f}%")
                except Exception:
                    print("[ERROR] Reconstruction decode error")

        print("\n[SUCCESS] All tests passed!")

    else:
        print(f"[ERROR] Checkpoint not found at {checkpoint_path}")
        return False

    return True


if __name__ == "__main__":
    print("=" * 60)
    print("B2NL v6.1.2 App Test")
    print("=" * 60)

    success = test_model()

    if success:
        print("\n[READY] Ready to run the Gradio app!")
        print("Run: python app.py")
    else:
        print("\n[WARNING] Please check the checkpoint path")