diff --git a/models/gat.py b/models/gat.py
index f444714..74ff93a 100644
--- a/models/gat.py
+++ b/models/gat.py
@@ -18,8 +18,10 @@ def __init__(self, num_features_xd=78, n_output=1, num_features_xt=25,
 
         # 1D convolution on protein sequence
         self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim)
-        self.conv_xt1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=8)
-        self.fc_xt1 = nn.Linear(32*121, output_dim)
+        self.conv_xt_1 = nn.Conv1d(in_channels=embed_dim, out_channels=n_filters, kernel_size=8)
+        self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=2*n_filters, kernel_size=8)
+        self.conv_xt_3 = nn.Conv1d(in_channels=2*n_filters, out_channels=3*n_filters, kernel_size=8)
+        self.fc1_xt = nn.Linear(3*n_filters, output_dim)
 
         # combined layers
         self.fc1 = nn.Linear(256, 1024)
@@ -46,12 +48,16 @@ def forward(self, data):
         # protein input feed-forward:
         target = data.target
         embedded_xt = self.embedding_xt(target)
-        conv_xt = self.conv_xt1(embedded_xt)
-        conv_xt = self.relu(conv_xt)
+        embedded_xt = torch.permute(embedded_xt, (0, 2, 1))
 
-        # flatten
-        xt = conv_xt.view(-1, 32 * 121)
-        xt = self.fc_xt1(xt)
+        conv_xt = self.conv_xt_1(embedded_xt)
+        conv_xt = self.relu(conv_xt)
+        conv_xt = self.conv_xt_2(conv_xt)
+        conv_xt = self.relu(conv_xt)
+        conv_xt = self.conv_xt_3(conv_xt)
+        conv_xt = self.relu(conv_xt)
+        xt = torch.max(conv_xt, dim = -1)[0]
+        xt = self.fc1_xt(xt)
 
         # concat
         xc = torch.cat((x, xt), 1)
diff --git a/models/gat_gcn.py b/models/gat_gcn.py
index d99d1b4..2240710 100644
--- a/models/gat_gcn.py
+++ b/models/gat_gcn.py
@@ -23,8 +23,10 @@ def __init__(self, n_output=1, num_features_xd=78, num_features_xt=25,
 
         # 1D convolution on protein sequence
         self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim)
-        self.conv_xt_1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=8)
-        self.fc1_xt = nn.Linear(32*121, output_dim)
+        self.conv_xt_1 = nn.Conv1d(in_channels=embed_dim, out_channels=n_filters, kernel_size=8)
+        self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=2*n_filters, kernel_size=8)
+        self.conv_xt_3 = nn.Conv1d(in_channels=2*n_filters, out_channels=3*n_filters, kernel_size=8)
+        self.fc1_xt = nn.Linear(3*n_filters, output_dim)
 
         # combined layers
         self.fc1 = nn.Linear(256, 1024)
@@ -46,9 +48,15 @@ def forward(self, data):
         x = self.fc_g2(x)
 
         embedded_xt = self.embedding_xt(target)
+        embedded_xt = torch.permute(embedded_xt, (0, 2, 1))
+
         conv_xt = self.conv_xt_1(embedded_xt)
-        # flatten
-        xt = conv_xt.view(-1, 32 * 121)
+        conv_xt = self.relu(conv_xt)
+        conv_xt = self.conv_xt_2(conv_xt)
+        conv_xt = self.relu(conv_xt)
+        conv_xt = self.conv_xt_3(conv_xt)
+        conv_xt = self.relu(conv_xt)
+        xt = torch.max(conv_xt, dim = -1)[0]
         xt = self.fc1_xt(xt)
 
         # concat
diff --git a/models/gcn.py b/models/gcn.py
index 6c583c3..ec27562 100644
--- a/models/gcn.py
+++ b/models/gcn.py
@@ -22,8 +22,10 @@ def __init__(self, n_output=1, n_filters=32, embed_dim=128,num_features_xd=78, n
 
         # protein sequence branch (1d conv)
         self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim)
-        self.conv_xt_1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=8)
-        self.fc1_xt = nn.Linear(32*121, output_dim)
+        self.conv_xt_1 = nn.Conv1d(in_channels=embed_dim, out_channels=n_filters, kernel_size=8)
+        self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=2*n_filters, kernel_size=8)
+        self.conv_xt_3 = nn.Conv1d(in_channels=2*n_filters, out_channels=3*n_filters, kernel_size=8)
+        self.fc1_xt = nn.Linear(3*n_filters, output_dim)
 
         # combined layers
         self.fc1 = nn.Linear(2*output_dim, 1024)
@@ -54,9 +56,15 @@ def forward(self, data):
 
         # 1d conv layers
         embedded_xt = self.embedding_xt(target)
+        embedded_xt = torch.permute(embedded_xt, (0, 2, 1))
+
         conv_xt = self.conv_xt_1(embedded_xt)
-        # flatten
-        xt = conv_xt.view(-1, 32 * 121)
+        conv_xt = self.relu(conv_xt)
+        conv_xt = self.conv_xt_2(conv_xt)
+        conv_xt = self.relu(conv_xt)
+        conv_xt = self.conv_xt_3(conv_xt)
+        conv_xt = self.relu(conv_xt)
+        xt = torch.max(conv_xt, dim = -1)[0]
         xt = self.fc1_xt(xt)
 
         # concat
diff --git a/models/ginconv.py b/models/ginconv.py
index bd37f4c..8d06aed 100644
--- a/models/ginconv.py
+++ b/models/ginconv.py
@@ -41,8 +41,10 @@ def __init__(self, n_output=1,num_features_xd=78, num_features_xt=25,
 
         # 1D convolution on protein sequence
         self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim)
-        self.conv_xt_1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=8)
-        self.fc1_xt = nn.Linear(32*121, output_dim)
+        self.conv_xt_1 = nn.Conv1d(in_channels=embed_dim, out_channels=n_filters, kernel_size=8)
+        self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=2*n_filters, kernel_size=8)
+        self.conv_xt_3 = nn.Conv1d(in_channels=2*n_filters, out_channels=3*n_filters, kernel_size=8)
+        self.fc1_xt = nn.Linear(3*n_filters, output_dim)
 
         # combined layers
         self.fc1 = nn.Linear(256, 1024)
@@ -68,9 +70,15 @@ def forward(self, data):
         x = F.dropout(x, p=0.2, training=self.training)
 
         embedded_xt = self.embedding_xt(target)
+        embedded_xt = torch.permute(embedded_xt, (0, 2, 1))
+
         conv_xt = self.conv_xt_1(embedded_xt)
-        # flatten
-        xt = conv_xt.view(-1, 32 * 121)
+        conv_xt = self.relu(conv_xt)
+        conv_xt = self.conv_xt_2(conv_xt)
+        conv_xt = self.relu(conv_xt)
+        conv_xt = self.conv_xt_3(conv_xt)
+        conv_xt = self.relu(conv_xt)
+        xt = torch.max(conv_xt, dim = -1)[0]
         xt = self.fc1_xt(xt)
 
         # concat
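
The same change is applied to all four models: the 1D convolutions now run over sequence positions with the embedding as input channels (instead of treating the fixed length 1000 as channels), two more conv layers are stacked, and the fixed-size flatten (32*121) is replaced by a global max pool so the branch no longer hard-codes the sequence length. Below is a minimal standalone sketch of the patched protein branch for reference only; it is not part of the patch. It assumes the defaults visible in the hunks (num_features_xt=25, embed_dim=128, n_filters=32), an assumed output_dim=128, and a padded target length of 1000; the class and variable names here are illustrative.

```python
import torch
import torch.nn as nn


class ProteinConvBranch(nn.Module):
    """Illustrative sketch of the revised protein-sequence branch (not part of the diff)."""

    def __init__(self, num_features_xt=25, embed_dim=128, n_filters=32, output_dim=128):
        super().__init__()
        # +1 accounts for the padding index used by the integer-encoded sequences
        self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim)
        # convolve along sequence positions, embedding dimension as input channels
        self.conv_xt_1 = nn.Conv1d(embed_dim, n_filters, kernel_size=8)
        self.conv_xt_2 = nn.Conv1d(n_filters, 2 * n_filters, kernel_size=8)
        self.conv_xt_3 = nn.Conv1d(2 * n_filters, 3 * n_filters, kernel_size=8)
        self.fc1_xt = nn.Linear(3 * n_filters, output_dim)
        self.relu = nn.ReLU()

    def forward(self, target):
        # target: (batch, seq_len) integer-encoded protein sequence
        embedded_xt = self.embedding_xt(target)           # (batch, seq_len, embed_dim)
        embedded_xt = embedded_xt.permute(0, 2, 1)        # (batch, embed_dim, seq_len)
        conv_xt = self.relu(self.conv_xt_1(embedded_xt))  # (batch, 32, seq_len - 7)
        conv_xt = self.relu(self.conv_xt_2(conv_xt))      # (batch, 64, seq_len - 14)
        conv_xt = self.relu(self.conv_xt_3(conv_xt))      # (batch, 96, seq_len - 21)
        # global max pool over positions replaces the old fixed-size flatten
        xt = torch.max(conv_xt, dim=-1)[0]                # (batch, 96)
        return self.fc1_xt(xt)                            # (batch, output_dim)


if __name__ == "__main__":
    branch = ProteinConvBranch()
    dummy_target = torch.randint(0, 26, (4, 1000))        # batch of 4 padded sequences
    print(branch(dummy_target).shape)                     # torch.Size([4, 128])
```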