import torch
import torch.nn as nn
import numpy as np


class SpectralFeatureExtractor(nn.Module):
    """Scale each feature channel in the frequency domain by a learned complex weight.

    One complex scalar is learned per channel and applied to every frequency
    bin of that channel's real FFT taken along the sequence axis.

    Args:
        num_features: Number of channels ``C`` of the ``(B, L, C)`` input.
    """

    def __init__(self, num_features: int) -> None:
        super().__init__()
        # Stored as (real, imag) float pairs so the parameter stays a plain
        # float32 tensor; it is viewed as complex inside forward().
        self.complex_weight = nn.Parameter(
            torch.randn(num_features, 2, dtype=torch.float32) * 0.02
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Filter ``x`` of shape ``(B, L, C)`` and return a tensor of the same shape."""
        _, seq_len, _ = x.shape  # also enforces a 3-D input
        channels_first = x.transpose(1, 2)  # (B, C, L)
        spectrum = torch.fft.rfft(channels_first, dim=2, norm="ortho")
        weight = torch.view_as_complex(self.complex_weight)  # (C,)
        # Broadcast the per-channel weight over batch and frequency bins.
        weighted = spectrum * weight.unsqueeze(0).unsqueeze(-1)
        # n=seq_len restores the original length, including odd lengths.
        restored = torch.fft.irfft(weighted, n=seq_len, dim=2, norm="ortho")
        return restored.transpose(1, 2)


class ChannelMixing(nn.Module):
    """Re-weight channels using a bottleneck MLP over the sequence-averaged input.

    The input is averaged over dim 1, passed through ``fc1 -> GELU -> fc2``,
    and the result multiplies every time step of the input. No squashing
    non-linearity is applied to the resulting per-channel factors.

    Args:
        num_features: Number of channels ``C`` of the ``(B, L, C)`` input.
        reduction: Bottleneck reduction factor for the hidden width.
    """

    def __init__(self, num_features: int, reduction: int = 4) -> None:
        super().__init__()
        # Clamp to at least one hidden unit: with num_features < reduction the
        # integer division yields 0, which would create an empty bottleneck
        # and make the gate depend only on fc2's bias.
        hidden = max(1, num_features // reduction)
        self.fc1 = nn.Linear(num_features, hidden)
        self.fc2 = nn.Linear(hidden, num_features)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled channel-wise; output shape equals input shape."""
        pooled = x.mean(dim=1)  # (B, C)
        gate = self.fc2(self.act(self.fc1(pooled)))  # (B, C)
        return x * gate.unsqueeze(1)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings, then apply dropout.

    Args:
        d_model: Embedding width of the input (last dimension).
        max_len: Number of positions in the precomputed table; inputs whose
            dim 1 exceeds this cannot be encoded.
        dropout: Dropout probability applied after the addition.
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.2) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd-indexed column than
        # even-indexed ones, so trim div_term to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Buffer (not a parameter): follows .to()/.cuda() and is saved in the
        # state dict, but is never trained.
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add encodings for the first ``x.size(1)`` positions to ``x``."""
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)


class MineROINet(nn.Module):
    """Sequence classifier: spectral filter -> channel mixing -> Transformer encoder.

    Pipeline applied to an input of shape ``(B, L, input_dim)``:
    spectral filtering, channel re-weighting, projection to ``d_model``
    (skipped when ``input_dim == d_model``), positional encoding, Transformer
    encoder, mean pooling over the sequence axis, and an MLP classifier head.

    Args:
        input_dim: Number of input features per time step.
        d_model: Transformer embedding width.
        nhead: Number of attention heads.
        num_layers: Number of Transformer encoder layers.
        dim_feedforward: Hidden width of each encoder layer's feed-forward net.
        dropout: Dropout probability used throughout the model.
        num_classes: Number of output logits.
        seq_len: Size of the positional-encoding table; inputs longer than
            this fail at the positional-encoding addition.
    """

    def __init__(
        self,
        input_dim: int,
        d_model: int = 64,
        nhead: int = 2,
        num_layers: int = 2,
        dim_feedforward: int = 256,
        dropout: float = 0.2,
        num_classes: int = 3,
        seq_len: int = 30,
    ) -> None:
        super().__init__()
        self.spectral = SpectralFeatureExtractor(input_dim)
        self.channel_mix = ChannelMixing(input_dim)
        self.input_projection = (
            nn.Linear(input_dim, d_model) if input_dim != d_model else nn.Identity()
        )
        self.pos_encoder = PositionalEncoding(d_model, max_len=seq_len, dropout=dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )

        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Dropout(dropout),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, num_classes),
        )

    def forward(self, seq: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(B, num_classes)`` for ``seq``."""
        seq = self.spectral(seq)
        seq = self.channel_mix(seq)
        seq = self.input_projection(seq)
        seq = self.pos_encoder(seq)
        encoded = self.transformer_encoder(seq)
        pooled = encoded.mean(dim=1)  # average over the sequence axis
        return self.classifier(pooled)


def create_model_30day(input_dim: int, num_classes: int = 3) -> MineROINet:
    """Build a ``MineROINet`` configured for sequences of up to 30 steps."""
    return MineROINet(
        input_dim=input_dim,
        d_model=64,
        nhead=2,
        num_layers=2,
        dim_feedforward=256,
        dropout=0.2,
        num_classes=num_classes,
        seq_len=30,
    )


# def create_model_60day(input_dim, num_classes=3):
#     return MineROINet(input_dim=input_dim, d_model=64, nhead=4, num_layers=2,
#                       dim_feedforward=256, dropout=0.2, num_classes=num_classes, seq_len=60)