|
|
from tqdm import tqdm |
|
|
import network |
|
|
import utils |
|
|
import os |
|
|
import random |
|
|
import argparse |
|
|
import numpy as np |
|
|
import json |
|
|
|
|
|
from torch.utils import data |
|
|
from datasets import VOCSegmentation, Cityscapes |
|
|
from utils import ext_transforms as et |
|
|
from metrics import StreamSegMetrics |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from PIL import Image |
|
|
import matplotlib |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
|
def get_argparser():
    """Build the command-line parser for training / evaluating DeepLab models.

    Returns:
        argparse.ArgumentParser: parser whose ``parse_args()`` result is the
        ``opts`` namespace consumed by ``get_dataset``, ``validate`` and ``main``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--out_dir", type=str, default="run_0")

    # Dataset options
    parser.add_argument("--data_root", type=str, default='',
                        help="path to Dataset")
    parser.add_argument("--dataset", type=str, default='voc',
                        choices=['voc'], help='Name of dataset')
    parser.add_argument("--num_classes", type=int, default=None,
                        help="num classes (default: None)")

    # Model options
    parser.add_argument("--model", type=str, default='deeplabv3plus_resnet101',
                        choices=['deeplabv3plus_resnet101', 'deeplabv3plus_resnet50', 'deeplabv3plus_mobilenet',
                                 'deeplabv3plus_xception', 'deeplabv3plus_hrnetv2_48', 'deeplabv3plus_hrnetv2_32',
                                 'deeplabv3_resnet101', 'deeplabv3_resnet50', 'deeplabv3_mobilenet',
                                 'deeplabv3_xception', 'deeplabv3_hrnetv2_48', 'deeplabv3_hrnetv2_32'],
                        help='model name')
    parser.add_argument("--separable_conv", action='store_true', default=False,
                        help="apply separable conv to decoder and aspp")
    parser.add_argument("--output_stride", type=int, default=16, choices=[8, 16])

    # Entropy-Optimized Attention Network options. --use_eoaNet is already on
    # by default, so --no_eoaNet is the switch that actually changes anything.
    parser.add_argument("--use_eoaNet", action='store_true', default=True,
                        help="Use Entropy-Optimized Attention Network")
    parser.add_argument("--no_eoaNet", action='store_false', dest='use_eoaNet',
                        help="Disable Entropy-Optimized Attention Network")
    parser.add_argument("--msa_scales", nargs='+', type=int, default=[1, 2, 4],
                        help="Scales for Multi-Scale Attention")
    parser.add_argument("--eog_beta", type=float, default=0.3,
                        help="Entropy threshold for Entropy-Optimized Gating")

    # Train options
    parser.add_argument("--test_only", action='store_true', default=False)
    parser.add_argument("--save_val_results", action='store_true', default=False,
                        help="save segmentation results to \"./results\"")
    # The default must be an int literal: argparse only runs `type` on string
    # defaults, so the former `30e3` reached main() as the float 30000.0.
    parser.add_argument("--total_itrs", type=int, default=30000,
                        help="number of training iterations (default: 30k)")
    parser.add_argument("--lr", type=float, default=0.02,
                        help="learning rate (default: 0.02)")
    parser.add_argument("--lr_policy", type=str, default='poly', choices=['poly', 'step'],
                        help="learning rate scheduler policy")
    parser.add_argument("--step_size", type=int, default=10000)
    # --crop_val defaults to True, so a store_true flag alone could never turn
    # it off; --no_crop_val mirrors the --use_eoaNet / --no_eoaNet pair.
    parser.add_argument("--crop_val", action='store_true', default=True,
                        help='crop validation (default: True)')
    parser.add_argument("--no_crop_val", action='store_false', dest='crop_val',
                        help='disable validation cropping')
    parser.add_argument("--batch_size", type=int, default=32,
                        help='batch size (default: 32)')
    parser.add_argument("--val_batch_size", type=int, default=4,
                        help='batch size for validation (default: 4)')
    parser.add_argument("--crop_size", type=int, default=513)

    parser.add_argument("--ckpt", default=None, type=str,
                        help="restore from checkpoint")
    parser.add_argument("--continue_training", action='store_true', default=False)

    parser.add_argument("--loss_type", type=str, default='cross_entropy',
                        choices=['cross_entropy', 'focal_loss'],
                        help="loss type (default: cross_entropy)")
    # NOTE(review): not read anywhere in the visible code; confirm whether it
    # should drive CUDA_VISIBLE_DEVICES.
    parser.add_argument("--gpu_id", type=str, default='0,1',
                        help="GPU ID")
    parser.add_argument("--weight_decay", type=float, default=1e-4,
                        help='weight decay (default: 1e-4)')
    parser.add_argument("--random_seed", type=int, default=1,
                        help="random seed (default: 1)")
    parser.add_argument("--print_interval", type=int, default=10,
                        help="print interval of loss (default: 10)")
    parser.add_argument("--val_interval", type=int, default=100,
                        help="iteration interval for eval (default: 100)")
    parser.add_argument("--download", action='store_true', default=False,
                        help="download datasets")

    # PASCAL VOC options
    parser.add_argument("--year", type=str, default='2012_aug',
                        choices=['2012_aug', '2012', '2011', '2009', '2008', '2007'], help='year of VOC')
    return parser
|
|
|
|
|
|
|
|
def get_dataset(opts):
    """Build the train/val datasets together with their augmentation pipelines.

    Args:
        opts: parsed options. Reads ``dataset``, ``data_root``, ``crop_size``
            and, for VOC, ``crop_val``, ``year`` and ``download``.

    Returns:
        tuple: ``(train_dst, val_dst)``.

    Raises:
        ValueError: if ``opts.dataset`` is neither ``'voc'`` nor ``'cityscapes'``.
    """
    # Check up front: an unknown name previously skipped both branches and
    # crashed on the final return with UnboundLocalError.
    if opts.dataset not in ('voc', 'cityscapes'):
        raise ValueError("Unsupported dataset: %r" % (opts.dataset,))

    # Normalization statistics shared by every pipeline below (the same values
    # are passed to utils.Denormalize elsewhere in this file).
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    if opts.dataset == 'voc':
        train_transform = et.ExtCompose([
            et.ExtRandomScale((0.5, 2.0)),
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size), pad_if_needed=True),
            et.ExtRandomHorizontalFlip(),
            et.ExtToTensor(),
            et.ExtNormalize(mean=mean, std=std),
        ])
        if opts.crop_val:
            val_transform = et.ExtCompose([
                et.ExtResize(opts.crop_size),
                et.ExtCenterCrop(opts.crop_size),
                et.ExtToTensor(),
                et.ExtNormalize(mean=mean, std=std),
            ])
        else:
            # No resize/crop: validation images keep their original size.
            val_transform = et.ExtCompose([
                et.ExtToTensor(),
                et.ExtNormalize(mean=mean, std=std),
            ])
        train_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                    image_set='train', download=opts.download,
                                    transform=train_transform)
        # Only the training split honours --download.
        val_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                  image_set='val', download=False,
                                  transform=val_transform)
    else:  # 'cityscapes'
        train_transform = et.ExtCompose([
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size)),
            et.ExtColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            et.ExtRandomHorizontalFlip(),
            et.ExtToTensor(),
            et.ExtNormalize(mean=mean, std=std),
        ])
        val_transform = et.ExtCompose([
            et.ExtToTensor(),
            et.ExtNormalize(mean=mean, std=std),
        ])
        train_dst = Cityscapes(root=opts.data_root,
                               split='train', transform=train_transform)
        val_dst = Cityscapes(root=opts.data_root,
                             split='val', transform=val_transform)

    return train_dst, val_dst
|
|
|
|
|
|
|
|
def validate(opts, model, loader, device, metrics, ret_samples_ids=None):
    """Evaluate ``model`` over ``loader`` and return metrics plus selected samples.

    The caller is responsible for putting ``model`` in eval mode; this function
    only disables autograd.

    Args:
        opts: parsed options; only ``save_val_results`` is read. When set,
            every image, decoded target, decoded prediction and an overlay
            are written as PNGs under ``results/``.
        model: network called as ``model(images)``.
        loader: iterable of ``(images, labels)`` batches. When saving results,
            ``loader.dataset.decode_target`` is used to colourize label maps.
        device: device the batches are moved to.
        metrics: accumulator; ``reset()``, ``update(targets, preds)`` and
            ``get_results()`` are called on it.
        ret_samples_ids: optional collection of batch indices; for each listed
            batch, its first ``(image, target, pred)`` triple is returned.

    Returns:
        tuple: ``(score, ret_samples)``, where ``score`` is whatever
        ``metrics.get_results()`` returns.
    """
    metrics.reset()
    ret_samples = []
    if opts.save_val_results:
        # exist_ok replaces the racy exists() + mkdir() pair.
        os.makedirs('results', exist_ok=True)
        denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])
        img_id = 0

    with torch.no_grad():
        # Wrap the loader itself (not enumerate(loader)) so tqdm can use
        # len(loader) and show a real progress bar.
        for i, (images, labels) in enumerate(tqdm(loader)):
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(images)
            # Prediction = argmax over dim 1; presumably outputs are
            # (N, num_classes, H, W) logits — TODO confirm for EOA variants.
            preds = outputs.detach().max(dim=1)[1].cpu().numpy()
            targets = labels.cpu().numpy()

            metrics.update(targets, preds)
            if ret_samples_ids is not None and i in ret_samples_ids:
                ret_samples.append(
                    (images[0].detach().cpu().numpy(), targets[0], preds[0]))

            if opts.save_val_results:
                # `j` (not `i`) so the batch index above is not shadowed.
                for j in range(len(images)):
                    image = images[j].detach().cpu().numpy()
                    target = targets[j]
                    pred = preds[j]

                    image = (denorm(image) * 255).transpose(1, 2, 0).astype(np.uint8)
                    target = loader.dataset.decode_target(target).astype(np.uint8)
                    pred = loader.dataset.decode_target(pred).astype(np.uint8)

                    Image.fromarray(image).save('results/%d_image.png' % img_id)
                    Image.fromarray(target).save('results/%d_target.png' % img_id)
                    Image.fromarray(pred).save('results/%d_pred.png' % img_id)

                    # Prediction overlaid on the image, with no axes or ticks.
                    fig = plt.figure()
                    plt.imshow(image)
                    plt.axis('off')
                    plt.imshow(pred, alpha=0.7)
                    ax = plt.gca()
                    ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    ax.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    plt.savefig('results/%d_overlay.png' % img_id, bbox_inches='tight', pad_inches=0)
                    plt.close(fig)
                    img_id += 1

    score = metrics.get_results()
    return score, ret_samples
|
|
|
|
|
def main(opts):
    """Train a DeepLab segmentation model, or only evaluate it with ``--test_only``.

    Side effects, relative to the current working directory:
      * TensorBoard scalars go to ``logs/``.
      * Every ``val_interval`` iterations a checkpoint is written to
        ``checkpoints/``; only the two most recent ones are kept.
      * Each new best validation mIoU saves a ``best_*`` checkpoint, appends
        to ``checkpoints/best_score.txt`` and rewrites ``final_info.json``.

    Args:
        opts: namespace from ``get_argparser().parse_args()``. ``num_classes``
            and, for uncropped VOC validation, ``val_batch_size`` are
            overwritten in place.
    """
    # The class count is fixed by the dataset and overrides --num_classes.
    dataset_name = opts.dataset.lower()
    if dataset_name == 'voc':
        opts.num_classes = 21
    elif dataset_name == 'cityscapes':
        opts.num_classes = 19

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Seed every RNG source used by this script.
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    writer = SummaryWriter(log_dir='logs')

    # Uncropped VOC validation images keep their original size (see
    # get_dataset), so batch them one at a time.
    if opts.dataset == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1

    train_dst, val_dst = get_dataset(opts)

    # Never ask for a batch larger than the dataset itself.
    effective_batch_size = min(opts.batch_size, len(train_dst))
    effective_val_batch_size = min(opts.val_batch_size, len(val_dst))
    if effective_batch_size < opts.batch_size:
        print(f"Warning: Reducing batch size from {opts.batch_size} to {effective_batch_size} due to small dataset")

    # Use the reduced size that the warning above announces (the full
    # opts.batch_size used to be passed here regardless).
    train_loader = data.DataLoader(
        train_dst, batch_size=effective_batch_size, shuffle=True, num_workers=2,
        drop_last=False)
    # Fixed order: metrics are order-independent, and saved result ids
    # (results/<id>_*.png) then map to the same samples on every run.
    val_loader = data.DataLoader(
        val_dst, batch_size=effective_val_batch_size, shuffle=False, num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    model = network.modeling.__dict__[opts.model](
        num_classes=opts.num_classes,
        output_stride=opts.output_stride,
        use_eoaNet=opts.use_eoaNet,
        msa_scales=opts.msa_scales,
        eog_beta=opts.eog_beta
    )
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    metrics = StreamSegMetrics(opts.num_classes)

    # The backbone trains at a tenth of the head's learning rate.
    # NOTE(review): only backbone and classifier parameters are optimized;
    # confirm any EOA-specific modules live under one of them.
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)

    # Both schedulers are stepped once per iteration (see the training loop).
    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)

    # Label value 255 is ignored by the loss.
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

    def save_ckpt(path):
        """Save model/optimizer/scheduler state and progress counters to ``path``."""
        torch.save({
            "cur_itrs": cur_itrs,
            # `model` is the nn.DataParallel wrapper by the time this runs;
            # saving the inner module matches the pre-wrap load_state_dict below.
            "model_state": model.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            # Stored as a plain float so the checkpoint does not pickle a
            # non-builtin scalar type (e.g. a NumPy float).
            "best_score": float(best_score),
        }, path)
        print("Model saved as %s" % path)

    os.makedirs('checkpoints', exist_ok=True)

    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0

    # NOTE(review): per the PyTorch docs, SyncBatchNorm only synchronizes
    # statistics under DistributedDataParallel; the model is wrapped in
    # nn.DataParallel below. Confirm this conversion is intended.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # Load on CPU first; the wrapped model is moved to `device` afterwards.
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint
    else:
        if opts.ckpt is not None:
            # Previously a mistyped --ckpt path silently fell through to retraining.
            print("[!] Checkpoint not found: %s" % opts.ckpt)
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    if opts.test_only:
        model.eval()
        val_score, _ = validate(
            opts=opts, model=model, loader=val_loader, device=device, metrics=metrics)
        print(metrics.to_str(val_score))
        writer.close()
        return

    interval_loss = 0.0
    latest_checkpoints = []  # paths of the periodic checkpoints still on disk
    while True:
        model.train()
        cur_epochs += 1
        for images, labels in train_loader:
            cur_itrs += 1

            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            np_loss = loss.item()
            interval_loss += np_loss
            writer.add_scalar('Loss/train', np_loss, cur_itrs)

            # Mean loss over the last --print_interval iterations
            # (this interval used to be hard-coded to 10).
            if cur_itrs % opts.print_interval == 0:
                interval_loss = interval_loss / opts.print_interval
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                interval_loss = 0.0

            if cur_itrs % opts.val_interval == 0:
                ckpt_path = f'checkpoints/latest_{cur_itrs}_{opts.model}_{opts.dataset}_os{opts.output_stride}.pth'
                save_ckpt(ckpt_path)
                latest_checkpoints.append(ckpt_path)

                # Keep only the two most recent periodic checkpoints on disk.
                if len(latest_checkpoints) > 2:
                    oldest_ckpt_path = latest_checkpoints.pop(0)
                    try:
                        os.remove(oldest_ckpt_path)
                        print(f"Successfully removed old checkpoint: {oldest_ckpt_path}")
                    except FileNotFoundError:
                        print(f"Warning: Could not remove checkpoint because it was not found: {oldest_ckpt_path}")
                    except OSError as e:
                        print(f"Error removing checkpoint {oldest_ckpt_path}: {e}")

                print("validation...")
                model.eval()
                val_score, _ = validate(
                    opts=opts, model=model, loader=val_loader, device=device, metrics=metrics)
                print(metrics.to_str(val_score))

                writer.add_scalar('Metrics/Mean_IoU', val_score['Mean IoU'], cur_itrs)
                writer.add_scalar('Metrics/Overall_Acc', val_score['Overall Acc'], cur_itrs)
                writer.add_scalar('Metrics/Mean_Acc', val_score['Mean Acc'], cur_itrs)

                if val_score['Mean IoU'] > best_score:
                    best_score = val_score['Mean IoU']
                    save_ckpt(f'checkpoints/best_{opts.model}_{opts.dataset}_os{opts.output_stride}.pth')
                    with open('checkpoints/best_score.txt', 'a') as f:
                        f.write(f"iter:{cur_itrs}\n{str(best_score)}\n")
                    with open("final_info.json", "w") as f:
                        final_info = {
                            "voc12_aug": {
                                "means": {
                                    "mIoU": float(val_score['Mean IoU']),
                                    "OA": float(val_score['Overall Acc']),
                                    # Was val_score['Mean IoU'] (copy-paste bug).
                                    "mAcc": float(val_score['Mean Acc'])
                                }
                            }
                        }
                        json.dump(final_info, f, indent=4)

                model.train()
            scheduler.step()

            if cur_itrs >= opts.total_itrs:
                writer.close()
                return
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    args = get_argparser().parse_args()
    try:
        main(args)
    except Exception:
        import traceback
        print("Original error in subprocess:", flush=True)
        # Persist the traceback and close the log file deterministically
        # (it was previously opened inline and never closed). The bare
        # `raise` still lets the interpreter report the error on stderr.
        with open("traceback.log", "w") as log_file:
            traceback.print_exc(file=log_file)
        raise
|
|
|