File size: 428 Bytes
c69c4af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
"""Task class definition """
from dataclasses import dataclass
from typing import List, Type
import torch.nn as nn
@dataclass
class Task:
    """Configuration bundle describing a single task.

    All attributes are ordinary dataclass fields (so ``__init__``, ``__eq__``
    and ``__repr__`` are generated from them). ``num_classes`` is not stored;
    it is derived from ``class_labels`` each time it is read.
    """

    # Identifier for the task.
    name: str
    # Labels of the task's classes; their count is what ``num_classes`` reports.
    class_labels: List[str]
    # Loss for this task. The annotation is ``Type[nn.Module]``, i.e. a class
    # rather than an instance. NOTE(review): confirm callers pass the class
    # itself and not an instantiated loss module.
    criterion: Type[nn.Module]
    # Scalar weight for this task, 1.0 unless overridden. NOTE(review):
    # presumably scales this task's loss when tasks are combined; verify at
    # the call site.
    weight: float = 1.0
    # Boolean switch, off by default. Its consumer lives outside this file;
    # presumably it enables a class-weighted loss. TODO confirm.
    use_weighted_loss: bool = False

    @property
    def num_classes(self) -> int:
        """Return how many entries ``class_labels`` holds (recomputed on each access)."""
        labels = self.class_labels
        return len(labels)
|