# MineROI-Net / model.py
import torch
import torch.nn as nn
import numpy as np
class SpectralFeatureExtractor(nn.Module):
    """Frequency-domain filter: applies a learnable complex weight to each input channel."""

    def __init__(self, num_features):
        super().__init__()
        # One complex weight per channel, stored as (real, imag) pairs.
        self.complex_weight = nn.Parameter(torch.randn(num_features, 2, dtype=torch.float32) * 0.02)

    def forward(self, x):
        # x: (B, L, C) -> filter each channel along the time axis in the frequency domain.
        B, L, C = x.shape
        x = x.transpose(1, 2)                                    # (B, C, L)
        x_fft = torch.fft.rfft(x, dim=2, norm="ortho")           # (B, C, L//2 + 1)
        weight = torch.view_as_complex(self.complex_weight)      # (C,) complex
        x_weighted = x_fft * weight.unsqueeze(0).unsqueeze(-1)   # broadcast to (B, C, F)
        x_out = torch.fft.irfft(x_weighted, n=L, dim=2, norm="ortho")
        return x_out.transpose(1, 2)                             # back to (B, L, C)
class ChannelMixing(nn.Module):
    """Squeeze-and-excitation style reweighting of feature channels."""

    def __init__(self, num_features, reduction=4):
        super().__init__()
        self.fc1 = nn.Linear(num_features, num_features // reduction)
        self.fc2 = nn.Linear(num_features // reduction, num_features)
        self.act = nn.GELU()

    def forward(self, x):
        # x: (B, L, C); pool over time, compute per-channel weights, rescale the input.
        identity = x
        x_pooled = x.mean(dim=1)                              # (B, C)
        x_weighted = self.fc2(self.act(self.fc1(x_pooled)))   # (B, C)
        out = identity * x_weighted.unsqueeze(1)              # (B, L, C)
        return out
class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding followed by dropout."""

    def __init__(self, d_model, max_len=5000, dropout=0.2):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)                                  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x: (B, L, d_model); add the first L positional encodings.
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)
class MineROINet(nn.Module):
    """Spectral filtering + channel mixing + Transformer encoder, with mean pooling and an MLP head."""

    def __init__(self, input_dim, d_model=64, nhead=2, num_layers=2, dim_feedforward=256,
                 dropout=0.2, num_classes=3, seq_len=30):
        super().__init__()
        self.spectral = SpectralFeatureExtractor(input_dim)
        self.channel_mix = ChannelMixing(input_dim)
        self.input_projection = nn.Linear(input_dim, d_model) if input_dim != d_model else nn.Identity()
        self.pos_encoder = PositionalEncoding(d_model, max_len=seq_len, dropout=dropout)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=dropout, activation="gelu",
                                                   batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Dropout(dropout),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, num_classes),
        )

    def forward(self, seq):
        # seq: (B, seq_len, input_dim) -> class logits of shape (B, num_classes).
        seq = self.spectral(seq)            # frequency-domain filtering per channel
        seq = self.channel_mix(seq)         # SE-style channel reweighting
        seq = self.input_projection(seq)    # project features to d_model
        seq = self.pos_encoder(seq)         # add positional information
        z = self.transformer_encoder(seq)   # (B, seq_len, d_model)
        pooled = z.mean(dim=1)              # temporal mean pooling
        out = self.classifier(pooled)
        return out
def create_model_30day(input_dim, num_classes=3):
    """Factory for the 30-day sequence configuration."""
    return MineROINet(input_dim=input_dim, d_model=64, nhead=2, num_layers=2,
                      dim_feedforward=256, dropout=0.2, num_classes=num_classes, seq_len=30)
# def create_model_60day(input_dim, num_classes=3):
# return MineROINet(input_dim=input_dim, d_model=64, nhead=4, num_layers=2,
# dim_feedforward=256, dropout=0.2, num_classes=num_classes, seq_len=60)