diff --git a/GNN/TemporalGAT/Data_Preprocessing/DataAquisitionPreprocessingPipeline.png b/GNN/TemporalGAT/Data_Preprocessing/DataAquisitionPreprocessingPipeline.png new file mode 100644 index 0000000..99b4ac4 Binary files /dev/null and b/GNN/TemporalGAT/Data_Preprocessing/DataAquisitionPreprocessingPipeline.png differ diff --git a/GNN/TemporalGAT/Data_Preprocessing/getBinnedBurstInfo.py b/GNN/TemporalGAT/Data_Preprocessing/getBinnedBurstInfo.py new file mode 100644 index 0000000..6e95605 --- /dev/null +++ b/GNN/TemporalGAT/Data_Preprocessing/getBinnedBurstInfo.py @@ -0,0 +1,118 @@ +'''
+GETBINNEDBURSTINFO Python implementation of GETBINNEDBURSTINFO.m:
+ GETBINNEDBURSTINFO Identify bursts using binned data and save burst info
+
+ This function reads the BrainGrid result (.h5 file), specifically the
+ spikesHistory dataset (the spike count for every 10 ms bin),
+ identifies bursts using a burst threshold of 0.5 spikes/sec/neuron,
+ calculates burst information for the bins above the threshold,
+ and stores it in allBinnedBurstInfo.csv.
+
+ Listed below are the burst information fields and how they are calculated:
+ ID - the burst ID
+ startBin# - the bin number at which the burst starts (above the threshold)
+ endBin# - the bin number at which the burst ends (the last above-threshold bin before the next gap)
+ width(bins) - the number of bins spanned by the burst, computed from the
+ starting and ending bin numbers
+ totalSpikeCount - the sum of all the spikesPerBin values between the starting and ending bin numbers
+ peakBin - startBin# + peakHeightIndex (the index, within the burst, of the bin with the peakHeight)
+ peakHeight(spikes) - the highest spike count among the bins of the burst
+ Interval(bins) - difference between the current peakBin and the previous burst's peakBin
+ - the previous peakBin is initialized to 0
+
+
+ Author: Jewel Y. Lee (jewel.yh.lee@gmail.com)
+ Last updated: 02/10/2022 added documentation, cleaned redundant code
+ Last updated by: Vu T. Tieu (vttieu1995@gmail.com)
+
+Syntax: getBinnedBurstInfo(datasetName)
+
+Input:
+datasetName - path to the Graphitti dataset (the entire path can be used, without the .h5 extension); for example
+ '/CSSDIV/research/biocomputing/data/2025/tR_1.0--fE_0.90_10000'
+
+Output:
+ allBinnedBurstInfo.csv - burst metadata.
The columns of the csv are: + burst ID, startBin#, endBin#, width(bins), + totalSpikeCount, peakBin, + peakHeight(spikes), Interval(bins) + +Author: Marina Rosenwald (marinarosenwald@gmail.com) +Last updated: 12/16/2025 +''' + + +import h5py +import numpy as np +import csv +import os +import sys +import time + +def getBinnedBurstInfo(datasetName, adjustedBurstThreshold=0.005): + + with h5py.File(f"{datasetName}.h5", "r") as f: + if '/neuronTypes' in f: + nNeurons = f['/neuronTypes'].shape[0] + elif '/spikesProbedNeurons' in f: + nNeurons = f['/spikesProbedNeurons'].shape[0] + else: + raise KeyError("No neuronTypes or spikesProbedNeurons dataset found") + + spikesPerBin = f['/spikesHistory'][:].astype(float) + nNeurons=10000 + spikesPerNeuronPerBin = spikesPerBin / nNeurons + binsAboveThreshold = np.where(spikesPerNeuronPerBin >= adjustedBurstThreshold)[0] + + if len(binsAboveThreshold) == 0: + print("No bursts detected above threshold") + return + + burstBoundaries = np.where(np.diff(binsAboveThreshold) > 1)[0] + burstBoundaries = np.concatenate(([0], burstBoundaries, [len(binsAboveThreshold) - 1])) + nBursts = len(burstBoundaries) - 1 + + print(nBursts) + + previousPeak = 0 + out_path = os.path.join(datasetName, "allBinnedBurstInfo.csv") + with open(out_path, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + writer.writerow([ + "ID", "startBin#", "endBin#", "width(bins)", "totalSpikeCount", + "peakBin", "peakHeight(spikes)", "Interval(bins)" + ]) + for iBurst in range(nBursts): + startBinNum = binsAboveThreshold[burstBoundaries[iBurst]] + endBinNum = binsAboveThreshold[burstBoundaries[iBurst + 1]] + burstSlice = spikesPerBin[startBinNum:endBinNum+1] + width = (endBinNum - startBinNum) + 1 + + if burstSlice.size == 0: + continue + totalSpikeCount = np.sum(burstSlice) + peakHeightIndex = np.argmax(burstSlice) + peakHeight = burstSlice[peakHeightIndex] + peakBin = startBinNum + peakHeightIndex + burstPeakInterval = peakBin - previousPeak + writer.writerow([ + iBurst + 1, startBinNum, endBinNum, width, + int(totalSpikeCount), peakBin, int(peakHeight), burstPeakInterval + ]) + previousPeak = peakBin + + + print(f"Saved burst info to {out_path}") + +if __name__ == "__main__": + # example execution: python ./getBinnedBurstSpikes.py /CSSDIV/research/biocomputing/data/2025/tR_1.0--fE_0.90_10000 + h5dir = sys.argv[1] + + start = time.time() + getBinnedBurstInfo(h5dir) + end = time.time() + + elapsed_time = end - start + + + print('Elapsed time: ' + str(elapsed_time) + ' seconds') diff --git a/GNN/TemporalGAT/Data_Preprocessing/getBinnedBurstSpikes.py b/GNN/TemporalGAT/Data_Preprocessing/getBinnedBurstSpikes.py new file mode 100644 index 0000000..9cccde5 --- /dev/null +++ b/GNN/TemporalGAT/Data_Preprocessing/getBinnedBurstSpikes.py @@ -0,0 +1,97 @@ +''' +GETBINNEDBURSTSPIKES Python implementation of GETBINNEDBURSTSPIKES.m: + GETBINNEDBURSTSPIKES Get spikes within every burst and save as flattened images of arrays + of bursts. These are the results of "spike train binning", mentioned in lee-thesis18 Figure 4.1. + Spike train binning adds up the binary values for every. + Read BrainGrid result dataset "spikesProbedNeurons" to retrieve location + of each spike that happended within a burst, and save as flatten image arrays. + + Example of flattened image arrays in a 5x5 matrix: + 1 1 2 0 1 + 1 2 2 1 1 + 1 3 2 1 0 + 2 2 2 3 1 + 0 1 3 2 2 + Each cell in a matrix this function output represents a pixel in which the number indicates how bright it is + with 5 being the highest value. 
The values in each cell is the result of "spike train binning" (this functions main algorithm) + This value can be understood as sum of all spike rate within 100 time step at that specific + x and y location. + + Author: Jewel Y. Lee (jewel.yh.lee@gmail.com) + Updated: 02/22/2022 added documentation and changed output filetype to a single .mat file + Updated by: Vu T. Tieu (vttieu1995@gmail.com) + Updated: May 2023 minor tweaks + Updated by: Michael Stiber + +Syntax: getBinnedBurstSpikes(h5dir) + +Input: +datasetName - Graphitti dataset the entire path can be used; for example + '/CSSDIV/research/biocomputing/data/2025/tR_1.0--fE_0.90_10000' + +Output: + - allFrames.npz - collection of flattened image arrays of a burst + +Author: Marina Rosenwald (marinarosenwald@gmail.com) +Last updated: 12/16/2025 +''' + +import h5py +import numpy as np +import os +import sys +import time + +def getBinnedBurstSpikes(h5dir, n_neurons=10000, head=10, tail=0, binSize=10.0, timeStepSize=0.1): + + timeStepsPerBin = int(binSize / timeStepSize) + + print("Starting read of spikes from HDF5 file...") + with h5py.File(f"{h5dir}.h5", "r") as f: + spikesProbedNeurons = f['/spikesProbedNeurons'][:] + print("done") + + burstInfoPath = os.path.join(h5dir, "allBinnedBurstInfo.csv") + burstInfo = np.loadtxt(burstInfoPath, delimiter=",", skiprows=1, usecols=(1,2,3,4,5,6,7)) + nBursts = burstInfo.shape[0] + + allFrames = [] + + print("Starting burst analysis...") + for iBurst in range(nBursts): + burstStartBinNum = int(burstInfo[iBurst,0]) + burstEndBinNum = int(burstInfo[iBurst,1]) + width = int(burstInfo[iBurst,2]) + + startingTimeStep = (burstStartBinNum - head + 1) * timeStepsPerBin + endingTimeStep = (burstEndBinNum + tail) * timeStepsPerBin + + frame = np.zeros((n_neurons, width + head + tail), dtype=np.uint16) + + for jNeuron in range(n_neurons): + neuronSpikes = spikesProbedNeurons[:, jNeuron] + spike_mask = (neuronSpikes >= startingTimeStep) & (neuronSpikes <= endingTimeStep) + spike_times = neuronSpikes[spike_mask] + for curSpikeTime in spike_times: + bin_idx = int((curSpikeTime - startingTimeStep) / timeStepsPerBin) + if 0 <= bin_idx < frame.shape[1]: + frame[jNeuron, bin_idx] += 1 + + allFrames.append(frame) + print(f"\tdone with burst {iBurst+1}/{nBursts}") + + np.savez(os.path.join(h5dir, "allFrames.npz"), *allFrames) + print(f"Saved allFrames.npz with {len(allFrames)} bursts and {n_neurons} neurons per frame.") + +if __name__ == "__main__": + # example execution: python ./getBinnedBurstSpikes.py /CSSDIV/research/biocomputing/data/2025/tR_1.0--fE_0.90_10000 + h5dir = sys.argv[1] + + start = time.time() + getBinnedBurstSpikes(h5dir) + end = time.time() + + elapsed_time = end - start + + + print('Elapsed time: ' + str(elapsed_time) + ' seconds') diff --git a/GNN/TemporalGAT/Model/TemporalGAT.py b/GNN/TemporalGAT/Model/TemporalGAT.py new file mode 100644 index 0000000..16bc6f5 --- /dev/null +++ b/GNN/TemporalGAT/Model/TemporalGAT.py @@ -0,0 +1,207 @@ +""" +TEMPORALGAT Train a Temporal Graph Attention Network on binned burst spike data + + This script constructs a windowed burst dataset from binned spike count data + produced by Graphitti simulations, trains a Temporal Graph Attention Network (Temporal GAT) + to predict neuron-level importance scores for future bursts, and saves the trained + model to disk. + + The pipeline performs the following steps: + 1. Load binned spike data (allFrames.npz), where each frame contains spike counts + for all neurons across time bins. + 2. 
Segment each burst frame into fixed-width temporal windows. + 3. Compute per-neuron vertex features for each window: + - mean firing rate within the window + - binary participation indicator (whether the neuron fired at least once) + - optional spatial (x, y) neuron coordinates + 4. Load a structural connectivity graph from a GraphML file and construct + edge indices and edge attributes. + 5. Generate training targets by computing each neuron's normalized total + outgoing synaptic strength in the future burst. + 6. Train a Temporal Graph Attention Network consisting of: + - a window-level Graph Attention Network (GAT) + - a temporal GRU to integrate information across windows + - a linear output layer producing neuron-level importance scores + 7. Evaluate the model after each epoch using Spearman and Pearson correlation + and Top-K precision/recall metrics. + 8. Save the trained model weights and configuration as a PyTorch checkpoint. + + Model output: + - A vector of predicted neuron importance values for each burst, where higher + values indicate greater structural influence on future bursting behavior. + +Syntax: + getBinnedBurstSpikes.py + +Input: + h5dir - Path to the Graphitti dataset directory containing: + * allFrames.npz (binned spike count data) + Example: + '/CSSDIV/research/biocomputing/data/2025/tR_1.0--fE_0.90_10000' + + graphml_path - Path to the GraphML file describing the network connectivity + (edge weights represent synaptic strength). + + num_epochs - Number of training epochs for the Temporal GAT model. + +Output: + - - PyTorch checkpoint containing: + * trained model state_dict + * model architecture configuration + Saved to the input dataset directory (h5dir). + +Author: Marina Rosenwald + +Last updated: 12/16/2025 +""" + +import numpy as np +import networkx as nx +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from scipy.stats import spearmanr, pearsonr +import torch.optim as optim +import sys +import time +import os +from burstTemporalGAT import BurstTemporalGAT +from burstWindowDataset import BurstWindowDataset + + +def load_graphml_edge_index(graphml_file): + G = nx.read_graphml(graphml_file) + if len(G.edges) == 0: + print("Warning: Graph has no edges!") + return torch.empty((2,0), dtype=torch.long), torch.empty((0,1), dtype=torch.float) + + mapping = {n:i for i,n in enumerate(G.nodes())} + G = nx.relabel_nodes(G, mapping) + + edges = list(G.edges) + edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous() + + edge_attr = [G[u][v].get('weight', 1.0) for u,v in edges] + edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1) + + return edge_index, edge_attr + + +def collate_fn(batch): + x_seqs, edge_indices, edge_attrs, targets = zip(*batch) + return list(x_seqs), list(edge_indices), list(edge_attrs), torch.stack(targets) + + +def evaluate(model, loader, k=50): + model.eval() + all_spearman, all_pearson, all_precisions, all_recalls = [], [], [], [] + with torch.no_grad(): + for i, (x_seq_batch, edge_index_batch, edge_attr_batch, target_batch) in enumerate(loader): + for j, (x_seq, ei_seq, ea_seq, target) in enumerate( + zip(x_seq_batch, edge_index_batch, edge_attr_batch, target_batch) + ): + preds = model(x_seq, [ei_seq]*x_seq.shape[0], [ea_seq]*x_seq.shape[0]) + preds = preds.squeeze().cpu().numpy() + target = target.cpu().numpy() + + rho, _ = spearmanr(preds, target) + r, _ = pearsonr(preds, target) + all_spearman.append(rho) + all_pearson.append(r) + + true_topk = set(target.argsort()[-k:]) + pred_topk = 
set(preds.argsort()[-k:]) + precision = len(true_topk & pred_topk) / len(pred_topk) + recall = len(true_topk & pred_topk) / len(true_topk) + all_precisions.append(precision) + all_recalls.append(recall) + + print(f"Spearman: {sum(all_spearman)/len(all_spearman):.3f}") + print(f"Pearson: {sum(all_pearson)/len(all_pearson):.3f}") + print(f"Top-{k} Precision: {sum(all_precisions)/len(all_precisions):.3f}") + print(f"Top-{k} Recall: {sum(all_recalls)/len(all_recalls):.3f}") + + return { + "spearman": all_spearman, + "pearson": all_pearson, + "precision": all_precisions, + "recall": all_recalls, + } + +def train(loader, dataset, num_epochs): + model = BurstTemporalGAT(vertex_in=4, edge_in=1, hid=64, heads=4, gru_hid=128) + criterion = nn.SmoothL1Loss() + optimizer = optim.Adam(model.parameters(), lr=1e-3) + + for epoch in range(num_epochs): + model.train() + total_loss = 0 + for x_seq_batch, edge_index_batch, edge_attr_batch, target_batch in loader: + optimizer.zero_grad() + preds = [] + targets = [] + for x_seq, ei_seq, ea_seq, target in zip(x_seq_batch, edge_index_batch, edge_attr_batch, target_batch): + pred = model(x_seq, [ei_seq]*x_seq.shape[0], [ea_seq]*x_seq.shape[0]) + pred_scaled = pred.squeeze() / dataset.target_scale + target_scaled2 = target / dataset.target_scale + + preds.append(pred_scaled) + targets.append(target_scaled2) + + preds = torch.stack(preds) + targets = torch.stack(targets) + loss = criterion(preds, targets) + loss.backward() + optimizer.step() + total_loss += loss.item() + print(f"Epoch {epoch}: loss={total_loss/len(loader):.4f}") + evaluate(model, loader, k=50) + return model + +def save_model(model, path): + checkpoint = { + "model_state_dict": model.state_dict(), + "model_config": { + "vertex_in": 4, + "edge_in": 1, + "hid": 64, + "heads": 4, + "gru_hid": 128, + } + } + torch.save(checkpoint, path) + + +def main(h5dir, graphml_path, num_epochs): + data = np.load(os.path.join(h5dir, "allFrames.npz")) + allFrames = [data[key] for key in data] + + edge_index, edge_attr = load_graphml_edge_index(graphml_path) + adj_snapshots = [(edge_index, edge_attr)] * len(allFrames) + + dataset = BurstWindowDataset( + allFrames=allFrames, + adj_snapshots=adj_snapshots, + window_bins=20, + horizon=1, + include_coords=True + ) + loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn) + out_model = train(loader, dataset, num_epochs) + save_model(out_model, os.path.join(h5dir, "burst_temporal_gat.pt")) + + +if __name__ == "__main__": + # example execution: python ./TemporalGAT.py /CSSDIV/research/biocomputing/data/2025/tR_1.0--fE_0.90_10000 /CSSDIV/research/biocomputing/data/2025/tR_1.0--fE_0.90_10000_growth_weights.graphml + # number of epochs to train model over + h5dir = sys.argv[1] + graphml_path = sys.argv[2] + num_epochs = sys.argv[3] + + start = time.time() + main(h5dir, graphml_path, int(num_epochs)) + end = time.time() + + elapsed_time = end - start + + print('Elapsed time: ' + str(elapsed_time) + ' seconds') \ No newline at end of file diff --git a/GNN/TemporalGAT/Model/burstTemporalGAT.py b/GNN/TemporalGAT/Model/burstTemporalGAT.py new file mode 100644 index 0000000..3677c4e --- /dev/null +++ b/GNN/TemporalGAT/Model/burstTemporalGAT.py @@ -0,0 +1,68 @@ +""" +BURSTTEMPORALGAT Temporal Graph Attention Network for Burst-Based Neuron Importance Modeling + + This module defines the BurstTemporalGAT model, a neural architecture designed + to model temporal bursting dynamics in large-scale spiking neural networks. 
+ + The model operates on a sequence of graph-structured time windows corresponding + to a single burst. For each window: + 1. A Graph Attention Network (WindowGAT) computes vertex embeddings using + neuron-level features and synaptic connectivity. + 2. The sequence of window-level vertex embeddings is stacked across time. + 3. A Gated Recurrent Unit (GRU) integrates temporal information across windows. + 4. A linear output layer produces a scalar importance value for each neuron, + representing its predicted influence on future bursting behavior. + + This model is intended to be used with binned spike data and static or + quasi-static connectivity graphs generated by Graphitti simulations. + +Syntax: + model = BurstTemporalGAT( + vertex_in=, + edge_in=, + hid=, + heads=, + gru_hid= + ) + +Input: + x_seq - Tensor of shape [T, N, F], where: + T = number of temporal windows in a burst + N = number of neurons + F = number of vertex features + edge_index_seq - List (length T) of edge index tensors [2, E] + edge_attr_seq - List (length T) of edge attribute tensors [E, edge_in] + +Output: + - Tensor of shape [N, 1] containing predicted neuron importance values for + the final time window of the burst. + +Author: Marina Rosenwald + +Last updated: 12/16/2025 +""" + + +from windowGAT import WindowGAT +import torch.nn as nn +import torch + +class BurstTemporalGAT(nn.Module): + def __init__(self, vertex_in, edge_in, hid=64, heads=4, gru_hid=128): + super().__init__() + self.gat = WindowGAT(vertex_in, edge_in, hid, heads) + self.gru = nn.GRU(input_size=hid*heads, hidden_size=gru_hid, batch_first=True) + self.out = nn.Linear(gru_hid, 1) + + def forward(self, x_seq, edge_index_seq, edge_attr_seq): + T, N, F = x_seq.shape + vertex_embeds = [] + for t in range(T): + x_t = x_seq[t] + edge_index_t = edge_index_seq[t] + edge_attr_t = edge_attr_seq[t] + h_vertexs = self.gat(x_t, edge_index_t, edge_attr_t) + vertex_embeds.append(h_vertexs) + H = torch.stack(vertex_embeds, dim=1) + out_seq, _ = self.gru(H) + return self.out(out_seq[:, -1]) \ No newline at end of file diff --git a/GNN/TemporalGAT/Model/burstWindowDataset.py b/GNN/TemporalGAT/Model/burstWindowDataset.py new file mode 100644 index 0000000..3611c1a --- /dev/null +++ b/GNN/TemporalGAT/Model/burstWindowDataset.py @@ -0,0 +1,152 @@ +""" +BURSTWINDOWDATASET Windowed Burst Dataset for Temporal Graph Neural Networks + + This module defines the BurstWindowDataset class, which constructs a + windowed, graph-structured dataset from binned spike count data generated + by Graphitti simulations. + + Each sample in the dataset corresponds to a single burst and consists of: + - a sequence of temporal windows of neuron-level features + - a static graph connectivity structure (edge indices and attributes) + - a target vector representing neuron importance in a future burst + + Dataset construction pipeline: + 1. Pad all burst frames to a fixed number of neurons (N = 10,000). + 2. Segment each burst into fixed-width temporal windows. + 3. Compute per-neuron vertex features for each window: + - mean firing rate within the window + - binary participation indicator (neuron fired at least once) + - optional spatial (x, y) neuron coordinates + 4. Attach graph connectivity snapshots (edge indices and edge attributes). + 5. Generate training targets by computing each neuron’s normalized total + outgoing synaptic strength in a future burst (prediction horizon). + 6. Scale targets for numerical stability during training. 
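+
+ As an illustration of steps 3 and 5 above, here is a minimal, self-contained sketch with toy inputs that mirrors
+ the implementation below (the array names and sizes are illustrative, not part of the class API):
+
+     import numpy as np
+
+     N, window_bins = 10000, 20
+     win = np.random.randint(0, 3, size=(N, window_bins))        # toy spike counts for one window of one burst
+     future_edge_index = np.random.randint(0, N, size=(2, 500))  # toy future graph snapshot (src, dst)
+     future_edge_attr = np.random.rand(500)                      # toy synaptic weights
+     target_scale = 1e6
+
+     # Step 3: per-neuron vertex features for one window
+     mean_rate = win.astype(np.float32).mean(axis=1)              # mean firing rate in the window
+     participation = (win.sum(axis=1) > 0).astype(np.float32)     # fired at least once in the window
+     coords = np.array([[i % 100, i // 100] for i in range(N)], dtype=np.float32)
+     vertex_feat = np.concatenate([np.stack([mean_rate, participation], axis=1), coords], axis=1)  # [N, 4]
+
+     # Step 5: target = each neuron's normalized total outgoing synaptic strength in the future burst
+     out_strength = np.zeros(N, dtype=np.float32)
+     np.add.at(out_strength, future_edge_index[0], future_edge_attr)   # sum weights per source neuron
+     target = out_strength / (out_strength.max() + 1e-12) * target_scale
+
+ Note: when batching with a PyTorch DataLoader, a list-based collate_fn (as in TemporalGAT.py) is needed,
+ because different bursts yield different numbers of windows.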
+ + The dataset is compatible with PyTorch DataLoader objects and is intended + for use with temporal GNN architectures such as BurstTemporalGAT. + +Syntax: + dataset = BurstWindowDataset( + allFrames=, + adj_snapshots=, + window_bins=, + horizon=, + include_coords=, + target_scale= + ) + +Input: + allFrames - List of arrays [N, T], containing binned spike counts + for each neuron across time bins for each burst. + adj_snapshots - List of graph snapshots, where each snapshot contains: + * edge_index [2, E] + * edge_attr [E, edge_in] + window_bins - Number of time bins per temporal window. + horizon - Number of bursts ahead used to construct the target. + include_coords - Whether to append spatial (x, y) neuron coordinates + to the vertex feature vectors. + target_scale - Scalar factor applied to target values for normalization. + +Output (per sample): + x_seq - Tensor [T_w, N, F] of windowed vertex features + edge_index - Tensor [2, E] of graph connectivity + edge_attr - Tensor [E, edge_in] of edge attributes + target - Tensor [N] of neuron importance values + +Author: Marina Rosenwald + +Last updated: 12/16/2025 +""" + +import torch +from torch.utils.data import Dataset +import numpy as np + +class BurstWindowDataset(Dataset): + @staticmethod + def pad_frame(frame, n_total=10000): + n_neurons, width = frame.shape + if n_neurons == n_total: + return frame + padded = np.zeros((n_total, width), dtype=frame.dtype) + padded[:n_neurons, :] = frame + return padded + + def __init__(self, allFrames, adj_snapshots, window_bins=20, horizon=1, include_coords=True, target_scale=1e6): + self.nNeurons = 10000 + self.window_bins = window_bins + self.horizon = horizon + self.nBursts = len(allFrames) + self.include_coords = include_coords + self.target_scale = target_scale + + self.allFrames = [self.pad_frame(f, n_total=self.nNeurons) for f in allFrames] + + self.windowed_frames = [] + for frame in self.allFrames: + width = frame.shape[1] + windows = [] + for start in range(0, width, self.window_bins): + end = min(start + self.window_bins, width) + win = frame[:, start:end] + + firing_counts = win.astype(np.float32) + participation = (win.sum(axis=1) > 0).astype(np.float32) + + vertex_feat = np.stack( + [firing_counts.mean(axis=1), participation], + axis=1 + ) + + if self.include_coords: + coords = np.array( + [[i % 100, i // 100] for i in range(self.nNeurons)], + dtype=np.float32 + ) + vertex_feat = np.concatenate([vertex_feat, coords], axis=1) + + windows.append(vertex_feat) + + self.windowed_frames.append(np.stack(windows)) + + self.edge_lists = [ + (edge_index, edge_attr if edge_attr.ndim == 2 else edge_attr.unsqueeze(1)) + for edge_index, edge_attr in adj_snapshots + ] + + self.targets = [] + for future_edge_index, future_edge_attr in self.edge_lists[horizon:]: + src = future_edge_index[0] + w = future_edge_attr + if w.ndim > 1: + w = w.squeeze() + + out_strength = np.zeros(self.nNeurons, dtype=np.float32) + np.add.at(out_strength, src, w.numpy()) + + out_strength /= (out_strength.max() + 1e-12) + out_strength *= self.target_scale + + self.targets.append(out_strength) + + targets_array = np.stack(self.targets) + print("\n========== GLOBAL TARGET STATISTICS ==========") + print("Total samples:", len(self.targets)) + print("min =", targets_array.min()) + print("max =", targets_array.max()) + print("mean =", targets_array.mean()) + print("median =", np.median(targets_array)) + nonzero = np.count_nonzero(targets_array) + print("nonzero count =", nonzero) + print("percent nonzero =", nonzero / 
targets_array.size * 100, "%") + print("==============================================\n") + + def __len__(self): + return self.nBursts - self.horizon + + def __getitem__(self, idx): + x_seq = torch.tensor(self.windowed_frames[idx], dtype=torch.float) + edge_index = torch.tensor(self.edge_lists[idx][0], dtype=torch.long) + edge_attr = torch.tensor(self.edge_lists[idx][1], dtype=torch.float) + target = torch.tensor(self.targets[idx], dtype=torch.float) + return x_seq, edge_index, edge_attr, target diff --git a/GNN/TemporalGAT/Model/modelRunPipeline.png b/GNN/TemporalGAT/Model/modelRunPipeline.png new file mode 100644 index 0000000..3b02d42 Binary files /dev/null and b/GNN/TemporalGAT/Model/modelRunPipeline.png differ diff --git a/GNN/TemporalGAT/Model/visualizeModel.py b/GNN/TemporalGAT/Model/visualizeModel.py new file mode 100644 index 0000000..86bec81 --- /dev/null +++ b/GNN/TemporalGAT/Model/visualizeModel.py @@ -0,0 +1,181 @@ +""" +VISUALIZEMODEL Load a Trained Temporal GAT and Visualize Neuron Importance + + This script loads a previously trained BurstTemporalGAT model and applies it + to windowed burst data to compute and visualize neuron-level importance scores. + + The pipeline performs the following steps: + 1. Load a trained Temporal Graph Attention Network (Temporal GAT) checkpoint. + 2. Reconstruct the burst window dataset from binned spike count data. + 3. Load the network connectivity graph from a GraphML file. + 4. Run inference on each burst to generate predicted neuron importance values. + 5. Optionally visualize predicted importance values at regular burst intervals. + 6. Compute and visualize the mean neuron importance across all bursts. + + Neuron importance values are visualized spatially using neuron (x, y) + coordinates, allowing structural and spatial patterns in burst influence + to be examined. + +Syntax: + TemporalGATInference.py + +Input: + h5dir - Path to the Graphitti dataset directory containing: + * allFrames.npz + * burst_temporal_gat.pt (trained model checkpoint) + Example: + '/CSSDIV/research/biocomputing/data/2025/tR_1.0--fE_0.90_10000' + + graphml_path - Path to the GraphML file describing synaptic connectivity. 
+ + (optional) optional_arg_50 - none value default + if 3rd argument present: return the neuron importance + at every 50 bursts + +Output: + - Spatial scatter plots showing: + * neuron importance for selected bursts (optional) + * mean neuron importance averaged across all bursts + +Author: Marina Rosenwald + +Last updated:12/16/2025 +""" + + +import torch +from burstTemporalGAT import BurstTemporalGAT +from burstWindowDataset import BurstWindowDataset +from torch.utils.data import DataLoader +import sys +import time +import os +import matplotlib.pyplot as plt +import numpy as np +import networkx as nx + +def load_model(path, device="cpu"): + checkpoint = torch.load(path, map_location=device) + + config = checkpoint["model_config"] + model = BurstTemporalGAT( + vertex_in=config["vertex_in"], + edge_in=config["edge_in"], + hid=config["hid"], + heads=config["heads"], + gru_hid=config["gru_hid"], + ) + + model.load_state_dict(checkpoint["model_state_dict"]) + model.to(device) + model.eval() + + return model + +def load_graphml_edge_index(graphml_file): + G = nx.read_graphml(graphml_file) + if len(G.edges) == 0: + print("Warning: Graph has no edges!") + return torch.empty((2,0), dtype=torch.long), torch.empty((0,1), dtype=torch.float) + + mapping = {n:i for i,n in enumerate(G.nodes())} + G = nx.relabel_nodes(G, mapping) + + edges = list(G.edges) + edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous() + + edge_attr = [G[u][v].get('weight', 1.0) for u,v in edges] + edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1) + + return edge_index, edge_attr + + +def collate_fn(batch): + x_seqs, edge_indices, edge_attrs, targets = zip(*batch) + return list(x_seqs), list(edge_indices), list(edge_attrs), torch.stack(targets) + +def get_neuron_coords(): + neuron_coords = np.array([[i % 100, i // 100] for i in range(10000)]) + return neuron_coords + +def visualizeAvg(model, loader, neuron_coords, optional_arg_50): + total_bursts = 0 + all_predictions = [] + for batch_idx, (x_seq_batch, edge_index_batch, edge_attr_batch, target_batch) in enumerate(loader): + for burst_idx, (x_seq, ei_seq, ea_seq, target) in enumerate( + zip(x_seq_batch, edge_index_batch, edge_attr_batch, target_batch) + ): + total_bursts += 1 + + preds = model( + x_seq, + [ei_seq] * x_seq.shape[0], + [ea_seq] * x_seq.shape[0] + ).squeeze().detach().cpu().numpy() + + all_predictions.append(preds) + # Plot every 50 bursts + if optional_arg_50 is not None: + if total_bursts % 50 == 0: + print( + f"Burst {total_bursts} -- min: {preds.min():.4f}, " + f"max: {preds.max():.4f}, mean: {preds.mean():.4f}" + ) + + plt.figure(figsize=(6, 6)) + sc = plt.scatter(neuron_coords[:, 0], neuron_coords[:, 1], c=preds, cmap='viridis', s=10) + plt.colorbar(sc, label='Mean Predicted Importance') + plt.title(f"Mean Neuron Importance At Burst {total_bursts}") + plt.xlabel("X Coordinate") + plt.ylabel("Y Coordinate") + # plt.show() + plt.savefig(os.path.join(h5dir, f"Mean_Neuron_Importance_{total_bursts}_Burst.png"), dpi=300, bbox_inches='tight') + all_predictions = np.vstack(all_predictions) + + mean_importance = all_predictions.mean(axis=0) + plt.figure(figsize=(6, 6)) + sc = plt.scatter( + neuron_coords[:, 0], neuron_coords[:, 1], + c=mean_importance, cmap='gray_r', s=10 + ) + plt.colorbar(sc, label='Mean Predicted Importance') + plt.title("Mean Neuron Importance Across All Bursts") + plt.xlabel("X Coordinate") + plt.ylabel("Y Coordinate") + # plt.show() + plt.savefig(os.path.join(h5dir, 
"Mean_Neuron_Importance_Across_All_Bursts.png"), dpi=300, bbox_inches='tight') + +def main(h5dir, graphml_path, optional_arg_50): + model = load_model(os.path.join(h5dir, "burst_temporal_gat.pt")) + + data = np.load(os.path.join(h5dir, "allFrames.npz")) + allFrames = [data[key] for key in data] + + edge_index, edge_attr = load_graphml_edge_index(graphml_path) + adj_snapshots = [(edge_index, edge_attr)] * len(allFrames) + + dataset = BurstWindowDataset( + allFrames=allFrames, + adj_snapshots=adj_snapshots, + window_bins=20, + horizon=1, + include_coords=True + ) + loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn) + neuron_coords = get_neuron_coords() + visualizeAvg(model, loader, neuron_coords, optional_arg_50) + +if __name__ == "__main__": + # example execution: python ./visualizeModel.py /CSSDIV/research/biocomputing/data/2025/tR_1.0--fE_0.90_10000 /CSSDIV/research/biocomputing/data/2025/tR_1.0--fE_0.90_10000_growth_weights.graphml + + h5dir = sys.argv[1] + graphml_path = sys.argv[2] + optional_arg_50 = sys.argv[3] if len(sys.argv) > 3 else None + + start = time.time() + main(h5dir, graphml_path, optional_arg_50) + end = time.time() + + elapsed_time = end - start + + print('Elapsed time: ' + str(elapsed_time) + ' seconds') \ No newline at end of file diff --git a/GNN/TemporalGAT/Model/windowGAT.py b/GNN/TemporalGAT/Model/windowGAT.py new file mode 100644 index 0000000..178a6b7 --- /dev/null +++ b/GNN/TemporalGAT/Model/windowGAT.py @@ -0,0 +1,88 @@ +""" +WINDOWGAT Window-Level Graph Attention Network for Burst-Based Neural Data + + This module defines the WindowGAT class, a graph attention mechanism used + to compute neuron level embeddings for a single temporal window of a burst. + + The WindowGAT operates on a graph where vertices represent neurons and edges + represent synaptic connections. For each edge, attention scores are computed + using both source vertex features and edge attributes, allowing the model to + weight incoming messages based on synaptic strength and vertex activity. + + Core components: + - group_softmax: a destination-wise softmax operation that normalizes + attention scores across all incoming edges for each target vertex. + - Multi-head attention mechanism that aggregates messages from source + neurons to destination neurons. + - Edge-aware attention, where edge attributes are concatenated with + source-vertex features during key and value computation. + + This module is designed to be used as a building block within higher-level + temporal models (e.g., BurstTemporalGAT), where it is applied independently + to each temporal window. 
+ +Syntax: + gat = WindowGAT( + vertex_in=, + edge_in=, + hid=, + heads= + ) + +Input: + x - Tensor [N, F] of vertex (neuron) features + edge_index - Tensor [2, E] specifying source and destination vertices + edge_attr - Tensor [E, edge_in] of edge attributes (e.g., synaptic weights) + return_attn - Boolean flag indicating whether to return attention weights + +Output: + out - Tensor [N, hid * heads] of vertex embeddings + attn - (optional) Tensor [E, heads] of attention weights + (src, dst) - (optional) Edge index tuple corresponding to attention values + +Author: Marina Rosenwald + +Last updated: 12/16/2025 +""" + + +import torch.nn as nn +import torch +import torch.nn.functional as F + +def group_softmax(scores, dst, num_vertices): + H = scores.size(1) + max_scores = torch.full((num_vertices, H), -1e9, device=scores.device, dtype=scores.dtype) + max_scores.index_reduce_(0, dst, scores, reduce='amax') + exp_scores = torch.exp(scores - max_scores[dst]) + denom = torch.zeros(num_vertices, H, device=scores.device, dtype=scores.dtype) + denom.index_add_(0, dst, exp_scores) + return exp_scores / (denom[dst] + 1e-12) + +class WindowGAT(nn.Module): + def __init__(self, vertex_in, edge_in, hid=64, heads=4): + super().__init__() + self.q = nn.Linear(vertex_in, hid*heads, bias=False) + self.k = nn.Linear(vertex_in + edge_in, hid*heads, bias=False) + self.v = nn.Linear(vertex_in + edge_in, hid*heads, bias=False) + self.heads, self.hid = heads, hid + + def forward(self, x, edge_index, edge_attr, return_attn=False): + src, dst = edge_index + q = self.q(x[dst]).view(-1, self.heads, self.hid) + + if edge_attr.ndim == 1: + edge_attr = edge_attr.unsqueeze(-1) + + kv_in = torch.cat([x[src], edge_attr], dim=-1) + k = self.k(kv_in).view(-1, self.heads, self.hid) + v = self.v(kv_in).view(-1, self.heads, self.hid) + scores = (q * k).sum(-1) / (self.hid ** 0.5) + attn = group_softmax(scores, dst, x.size(0)) + + out = torch.zeros(x.size(0), self.heads, self.hid, device=x.device) + out.index_add_(0, dst, attn.unsqueeze(-1) * v) + out = F.elu(out.reshape(x.size(0), -1)) + if return_attn: + return out, attn, (src, dst) + return out \ No newline at end of file
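A typical end-to-end run of this pipeline, assembled from the example executions in the scripts above (the dataset and GraphML paths are the examples from the sources; the epoch count of 20 is illustrative):

    # 1. Identify bursts from the binned spike history (writes allBinnedBurstInfo.csv into the dataset directory)
    python ./getBinnedBurstInfo.py /CSSDIV/research/biocomputing/data/2025/tR_1.0--fE_0.90_10000
    # 2. Extract per-burst spike frames (writes allFrames.npz)
    python ./getBinnedBurstSpikes.py /CSSDIV/research/biocomputing/data/2025/tR_1.0--fE_0.90_10000
    # 3. Train the Temporal GAT (writes burst_temporal_gat.pt); the last argument is the number of epochs
    python ./TemporalGAT.py /CSSDIV/research/biocomputing/data/2025/tR_1.0--fE_0.90_10000 /CSSDIV/research/biocomputing/data/2025/tR_1.0--fE_0.90_10000_growth_weights.graphml 20
    # 4. Visualize neuron importance (pass any third argument to also plot every 50th burst)
    python ./visualizeModel.py /CSSDIV/research/biocomputing/data/2025/tR_1.0--fE_0.90_10000 /CSSDIV/research/biocomputing/data/2025/tR_1.0--fE_0.90_10000_growth_weights.graphml

Note that getBinnedBurstInfo.py expects the dataset path without the .h5 extension and writes its CSV into a directory of the same name, which getBinnedBurstSpikes.py then reads.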