From 78f92976e6e7bf516e315c078819ed00d8613519 Mon Sep 17 00:00:00 2001 From: Joe <34472403+joeat1@users.noreply.github.com> Date: Sun, 5 Jul 2020 21:26:14 +0800 Subject: [PATCH] Update utils.py --- graphsage/utils.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/graphsage/utils.py b/graphsage/utils.py index ff05072d..823328bd 100644 --- a/graphsage/utils.py +++ b/graphsage/utils.py @@ -5,13 +5,9 @@ import json import sys import os  import networkx as nx from networkx.readwrite import json_graph -version_info = list(map(int, nx.__version__.split('.'))) -major = version_info[0] -minor = version_info[1] -assert (major <= 1) and (minor <= 11), "networkx major version > 1.11" WALK_LEN=5 N_WALKS=50 @@ -19,9 +15,12 @@ def load_data(prefix, normalize=True, load_walks=False): G_data = json.load(open(prefix + "-G.json")) G = json_graph.node_link_graph(G_data) -    if isinstance(G.nodes()[0], int): -        conversion = lambda n : int(n) -    else: +    try: +        if isinstance(G.nodes()[0], dict): +            conversion = lambda n : int(n) +        else: +            conversion = lambda n : n +    except (KeyError, IndexError): conversion = lambda n : n  if os.path.exists(prefix + "-feats.npy"): @@ -44,7 +43,7 @@ ## (necessary because of networkx weirdness with the Reddit data) broken_count = 0 for node in G.nodes(): -        if not 'val' in G.node[node] or not 'test' in G.node[node]: +        if not 'val' in G.nodes[node] or not 'test' in G.nodes[node]: G.remove_node(node) broken_count += 1 print("Removed {:d} nodes that lacked proper annotations due to networkx versioning issues".format(broken_count)) @@ -53,15 +52,15 @@ ## (some datasets might already have this..) print("Loaded data.. 
now preprocessing..") for edge in G.edges(): - if (G.node[edge[0]]['val'] or G.node[edge[1]]['val'] or - G.node[edge[0]]['test'] or G.node[edge[1]]['test']): + if (G.nodes[edge[0]]['val'] or G.nodes[edge[1]]['val'] or + G.nodes[edge[0]]['test'] or G.nodes[edge[1]]['test']): G[edge[0]][edge[1]]['train_removed'] = True else: G[edge[0]][edge[1]]['train_removed'] = False if normalize and not feats is None: from sklearn.preprocessing import StandardScaler - train_ids = np.array([id_map[n] for n in G.nodes() if not G.node[n]['val'] and not G.node[n]['test']]) + train_ids = np.array([id_map[n] for n in G.nodes() if not G.nodes[n]['val'] and not G.nodes[n]['test']]) train_feats = feats[train_ids] scaler = StandardScaler() scaler.fit(train_feats) @@ -97,7 +96,7 @@ def run_random_walks(G, nodes, num_walks=N_WALKS): out_file = sys.argv[2] G_data = json.load(open(graph_file)) G = json_graph.node_link_graph(G_data) - nodes = [n for n in G.nodes() if not G.node[n]["val"] and not G.node[n]["test"]] + nodes = [n for n in G.nodes() if not G.nodes[n]["val"] and not G.nodes[n]["test"]] G = G.subgraph(nodes) pairs = run_random_walks(G, nodes) with open(out_file, "w") as fp: