# Last commit (ggunio, 13c2c77): Update to B2NL v6.1.2 POC - 18.6:1 compression
# with 6 languages (Korean, English, Chinese, Japanese, Spanish, Arabic)
"""
Quick test script for B2NL v6.1.2 app functionality
"""
import sys
from pathlib import Path
import torch
# Prepend the `intelligent-tokenizer_v6.1.2` directory (found under the
# directory three `.parent` hops above this script) to the import path so the
# `core.*` imports below resolve against it.
parent_dir = Path(__file__).parent.parent.parent
sys.path[:0] = [str(parent_dir / 'intelligent-tokenizer_v6.1.2')]
from core.unified_model import IntelligentTokenizerModelV61
from core.byte_tokenizer_v6 import ByteTokenizerV6
def _group_bytes(byte_seq, boundaries):
    """Split ``byte_seq`` into display groups at predicted boundary positions.

    Args:
        byte_seq: List of byte values (ints in 0..255).
        boundaries: Sequence of per-position labels; a label equal to 1 starts
            a new group at that position. Only the first
            ``min(len(byte_seq), len(boundaries))`` positions are used.

    Returns:
        One ``"<text>"`` string per group, where ``text`` is the group's bytes
        decoded as UTF-8 with invalid sequences replaced by U+FFFD.
    """
    groups = []
    current = []
    for i in range(min(len(byte_seq), len(boundaries))):
        # A boundary only closes a group that already has content.
        if boundaries[i] == 1 and current:
            groups.append(current)
            current = []
        current.append(byte_seq[i])
    if current:
        groups.append(current)
    # errors='replace' cannot raise, so no try/except is needed here.
    return [f"<{bytes(group).decode('utf-8', errors='replace')}>" for group in groups]


def _prediction_bytes(pred_ids, stop_ids=(256, 258), max_len=64):
    """Extract byte values from a sequence of predicted token ids.

    Position 0 is skipped (presumably the BOS slot, mirroring the input
    layout). The sequence is cut at the first id in ``stop_ids`` after
    position 0, or at ``max_len`` if none occurs. Remaining ids >= 256
    (non-byte tokens) are dropped.

    Args:
        pred_ids: Sequence of Python ints (e.g. from ``Tensor.tolist()``).
        stop_ids: Ids that terminate the prediction. 256 and 258 are the
            values this script has always used; which special tokens they
            denote is not visible here -- TODO confirm against ByteTokenizerV6.
        max_len: Cut-off used when no stop id is found.

    Returns:
        List of ints in 0..255, suitable for ``bytes(...)``.
    """
    end = next((i for i in range(1, len(pred_ids)) if pred_ids[i] in stop_ids), max_len)
    return [token for token in pred_ids[1:end] if token < 256]


def _char_accuracy(reference, candidate):
    """Return the percentage of character positions where the strings agree.

    The denominator is the longer of the two strings, so both missing and
    extra characters count as errors. Two empty strings score 100.0.
    """
    denominator = max(len(reference), len(candidate))
    if denominator == 0:
        return 100.0
    matches = sum(1 for ref_ch, cand_ch in zip(reference, candidate) if ref_ch == cand_ch)
    return matches / denominator * 100


def _print_embeddings(hidden_states):
    """Print the first 20 values at index ``[0, 0]`` of the last hidden state.

    ``hidden_states`` is either a tensor or a tuple of per-layer tensors, in
    which case the last entry is used. Indexing ``[0, 0, :20]`` assumes a
    (batch, seq, dim) layout -- TODO confirm against the model.
    """
    last_hidden = hidden_states[-1] if isinstance(hidden_states, tuple) else hidden_states
    values = last_hidden[0, 0, :20].cpu().numpy()
    print(f"\n[OK] Embeddings (first 20 dims):")
    for start in range(0, len(values), 5):
        dim_strs = [f'{v:7.4f}' for v in values[start:start + 5]]
        end = min(start + 4, len(values) - 1)
        print(f" Dim {start:2d}-{end:2d}: [{', '.join(dim_strs)}]")
    print(f"\n Stats - Mean: {values.mean():.4f}, Std: {values.std():.4f}, Min: {values.min():.4f}, Max: {values.max():.4f}")


def test_model():
    """Smoke-test the B2NL v6.1.2 model on one Korean sentence.

    Builds the byte tokenizer and model, loads ``best_model.pt`` from the
    ``v612_compression_first`` checkpoint directory under ``parent_dir``, runs
    a single CPU forward pass, and prints:
    - the output keys;
    - a boundary-based grouping and compression ratio;
    - the first 20 hidden-state values;
    - the reconstructed text with a character-level accuracy.

    Returns:
        True if the checkpoint was found and the forward pass completed,
        False if the checkpoint file does not exist.
    """
    seq_len = 64  # model/tokenizer window: BOS + up to seq_len - 2 bytes + EOS
    device = torch.device('cpu')
    tokenizer = ByteTokenizerV6(max_seq_len=seq_len)
    model = IntelligentTokenizerModelV61(vocab_size=260, max_seq_len=seq_len).to(device)

    checkpoint_path = (parent_dir / 'intelligent-tokenizer_v6.1.2' / 'checkpoints'
                       / 'v612_compression_first' / 'best_model.pt')
    if not checkpoint_path.exists():
        print(f"[ERROR] Checkpoint not found at {checkpoint_path}")
        return False

    print(f"Loading checkpoint from {checkpoint_path}")
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(str(checkpoint_path), map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"[OK] Loaded checkpoint: Epoch {checkpoint.get('epoch', 'N/A')}")
    model.eval()

    # Korean test sentence ("Hello. The weather is nice today.").
    test_text = "안녕하세요. 오늘 날씨가 좋네요."
    print(f"\nTest text: {test_text}")

    # Keep at most seq_len - 2 bytes so BOS and EOS still fit in the window.
    byte_seq = list(test_text.encode('utf-8'))[:seq_len - 2]
    print(f"Bytes: {len(byte_seq)}")

    input_ids = torch.tensor([[tokenizer.BOS] + byte_seq + [tokenizer.EOS]],
                             dtype=torch.long).to(device)
    pad_len = seq_len - input_ids.size(1)
    if pad_len > 0:
        padding = torch.full((1, pad_len), tokenizer.PAD, dtype=torch.long).to(device)
        input_ids = torch.cat([input_ids, padding], dim=1)
    attention_mask = (input_ids != tokenizer.PAD).float()

    # Forward pass - v6.1.2 production mode
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids,
            epoch=233,  # Match checkpoint epoch for best performance
            use_cross_attention=True,  # Enable cross-attention for better reconstruction
        )
    print(f"\n[OK] Model outputs available: {list(outputs.keys())}")

    if 'eojeol_boundaries' in outputs:
        boundaries = torch.argmax(outputs['eojeol_boundaries'], dim=-1)[0]
        # NOTE(review): this counts boundary predictions over every position of
        # the padded window (BOS/EOS/PAD included), not only the text bytes --
        # confirm this is the intended token count.
        num_groups = torch.sum(boundaries == 1).item() + 1
        compression = len(byte_seq) / num_groups
        print(f"[OK] Compression: {len(byte_seq)} bytes -> {num_groups} tokens = {compression:.1f}:1")
        # NOTE(review): boundaries[i] is paired with byte_seq[i], but input
        # position 0 holds BOS; if boundaries align with input_ids this is
        # off by one -- confirm.
        groups = _group_bytes(byte_seq, boundaries.cpu().tolist())
        print(f"[OK] Groups: {' '.join(groups)}")

    if 'encoder_hidden_states' in outputs:
        _print_embeddings(outputs['encoder_hidden_states'])

    if 'logits' in outputs:
        # .tolist() yields Python ints; bytes() on a numpy int64 array would
        # instead copy its raw 8-byte-per-element buffer.
        pred_ids = outputs['logits'].argmax(dim=-1)[0].tolist()
        byte_ids = _prediction_bytes(pred_ids, max_len=seq_len)
        if byte_ids:
            reconstructed = bytes(byte_ids).decode('utf-8', errors='ignore')
            print(f"\n[OK] Reconstructed: {reconstructed}")
            accuracy = _char_accuracy(test_text, reconstructed)
            print(f"[OK] Accuracy: {accuracy:.1f}%")

    print("\n[SUCCESS] All tests passed!")
    return True
if __name__ == "__main__":
    print("="*60)
    print("B2NL v6.1.2 App Test")
    print("="*60)
    success = test_model()
    if success:
        print("\n[READY] Ready to run the Gradio app!")
        print("Run: python app.py")
    else:
        print("\n[WARNING] Please check the checkpoint path")
    # Report the outcome through the exit status so shells/CI can detect failure.
    sys.exit(0 if success else 1)