4 changes: 2 additions & 2 deletions README.md
@@ -16,11 +16,11 @@ Unzip the data in your working directory.

### Running the code

To train a model on the Bio data, run `python -m nqe.bio.train`.
To train a model on the Bio data, run `python -m netquery.bio.train`.
See that file for a list of possible arguments, and note that by default it assumes that the data is in a subdirectory of your working directory (i.e., "./bio_data").
By default, the script will log its output and store a trained version of the model after training.
The train, test, and validation performance will be recorded in the log file.
If you are training with a GPU, be sure to add the cuda flag, i.e., `python -m nqe.bio.train --cuda`.
If you are training with a GPU, be sure to add the cuda flag, i.e., `python -m netquery.bio.train --cuda`.
The default parameters correspond to the best performing variant from the paper.

NB: Currently the training files are non-portable pickle files.
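If the shipped pickles were written under Python 2 (the code base predates this Python 3 port), loading them from Python 3 may need an explicit string encoding. A minimal sketch, assuming the default "./bio_data" location and the `graph_data.pkl` layout used by `netquery/bio/data_utils.py`; the `encoding` argument is only needed for Python 2-era pickles:

```python
import pickle

# encoding="latin1" lets Python 3 unpickle byte strings produced by Python 2's cPickle;
# drop it if the archives were re-pickled under Python 3.
with open("./bio_data/graph_data.pkl", "rb") as f:
    rels, adj_lists, node_maps = pickle.load(f, encoding="latin1")

print(len(rels), len(node_maps))
```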
4 changes: 2 additions & 2 deletions netquery/aggregators.py
@@ -119,7 +119,7 @@ def __init__(self, features, feature_dims, cuda=False):
self.features = features
self.feat_dims = feature_dims
self.pool_matrix = {}
for mode, feat_dim in self.feat_dims.iteritems():
for mode, feat_dim in self.feat_dims.items():
self.pool_matrix[mode] = nn.Parameter(torch.FloatTensor(feat_dim, feat_dim))
init.xavier_uniform(self.pool_matrix[mode])
self.register_parameter(mode+"_pool", self.pool_matrix[mode])
@@ -174,7 +174,7 @@ def __init__(self, features, feature_dims,
self.features = features
self.feat_dims = feature_dims
self.pool_matrix = {}
for mode, feat_dim in self.feat_dims.iteritems():
for mode, feat_dim in self.feat_dims.items():
self.pool_matrix[mode] = nn.Parameter(torch.FloatTensor(feat_dim, feat_dim))
init.xavier_uniform(self.pool_matrix[mode])
self.register_parameter(mode+"_pool", self.pool_matrix[mode])
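The only functional change in this file is Python 3's removal of `dict.iteritems()`; `dict.items()` returns a lazy view and is a drop-in replacement in loops. A minimal self-contained sketch of the registration pattern above, assuming hypothetical modes and dimensions (it also uses `init.xavier_uniform_`, the non-deprecated spelling, which this file's context lines still leave as `init.xavier_uniform`):

```python
import torch
import torch.nn as nn
from torch.nn import init

class PoolAggregatorSketch(nn.Module):
    """Stand-in for the aggregator's per-mode pooling matrices."""
    def __init__(self, feature_dims):
        super(PoolAggregatorSketch, self).__init__()
        self.pool_matrix = {}
        # Python 3: .items() replaces .iteritems() and iterates lazily.
        for mode, feat_dim in feature_dims.items():
            self.pool_matrix[mode] = nn.Parameter(torch.FloatTensor(feat_dim, feat_dim))
            init.xavier_uniform_(self.pool_matrix[mode])   # in-place Xavier init
            self.register_parameter(mode + "_pool", self.pool_matrix[mode])

agg = PoolAggregatorSketch({"protein": 16, "disease": 16})   # made-up modes/dims
print([name for name, _ in agg.named_parameters()])
```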
23 changes: 12 additions & 11 deletions netquery/bio/data_utils.py
@@ -1,4 +1,5 @@
import cPickle as pickle
#import cPickle as pickle
import pickle
import torch
from collections import OrderedDict, defaultdict
from multiprocessing import Process
@@ -10,7 +11,7 @@

def load_graph(data_dir, embed_dim):
rels, adj_lists, node_maps = pickle.load(open(data_dir+"/graph_data.pkl", "rb"))
node_maps = {m : {n : i for i, n in enumerate(id_list)} for m, id_list in node_maps.iteritems()}
node_maps = {m : {n : i for i, n in enumerate(id_list)} for m, id_list in node_maps.items()}
for m in node_maps:
node_maps[m][-1] = -1
feature_dims = {m : embed_dim for m in rels}
@@ -39,36 +40,36 @@ def clean_test():
else:
to_keep = 10000
test_queries = load_queries_by_type("/dfs/scratch0/nqe-bio/{:s}_queries_{:d}-split.pkl".format(kind, i), keep_graph=True)
print "Loaded", i, kind
print("Loaded", i, kind)
for query_type in test_queries:
test_queries[query_type] = [q for q in test_queries[query_type] if len(q.get_edges().intersection(deleted_edges)) > 0]
test_queries[query_type] = test_queries[query_type][:to_keep]
test_queries = [q.serialize() for queries in test_queries.values() for q in queries]
pickle.dump(test_queries, open("/dfs/scratch0/nqe-bio/{:s}_queries_{:d}-clean.pkl".format(kind, i), "wb"), protocol=pickle.HIGHEST_PROTOCOL)
print "Finished", i, kind
print("Finished", i, kind)



def make_train_test_edge_data(data_dir):
print "Loading graph..."
print("Loading graph...")
graph, _, _ = load_graph(data_dir, 10)
print "Getting all edges..."
print("Getting all edges...")
edges = graph.get_all_edges()
split_point = int(0.1*len(edges))
val_test_edges = edges[:split_point]
print "Getting negative samples..."
print("Getting negative samples...")
val_test_edge_negsamples = [graph.get_negative_edge_samples(e, 100) for e in val_test_edges]
print "Making and storing test queries."
print("Making and storing test queries.")
val_test_edge_queries = [Query(("1-chain", val_test_edges[i]), val_test_edge_negsamples[i], None, 100) for i in range(split_point)]
val_split_point = int(0.1*len(val_test_edge_queries))
val_queries = val_test_edge_queries[:val_split_point]
test_queries = val_test_edge_queries[val_split_point:]
pickle.dump([q.serialize() for q in val_queries], open(data_dir+"/val_edges.pkl", "w"), protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump([q.serialize() for q in test_queries], open(data_dir+"/test_edges.pkl", "w"), protocol=pickle.HIGHEST_PROTOCOL)

print "Removing test edges..."
print("Removing test edges...")
graph.remove_edges(val_test_edges)
print "Making and storing train queries."
print("Making and storing train queries.")
train_edges = graph.get_all_edges()
train_queries = [Query(("1-chain", e), None, None) for e in train_edges]
pickle.dump([q.serialize() for q in train_queries], open(data_dir+"/train_edges.pkl", "w"), protocol=pickle.HIGHEST_PROTOCOL)
@@ -78,7 +79,7 @@ def _discard_negatives(file_name, small_prop=0.9):
# queries = [q if random.random() > small_prop else (q[0],[random.choice(tuple(q[1]))], None if q[2] is None else [random.choice(tuple(q[2]))]) for q in queries]
queries = [q if random.random() > small_prop else (q[0],[random.choice(list(q[1]))], None if q[2] is None else [random.choice(list(q[2]))]) for q in queries]
pickle.dump(queries, open(file_name.split(".")[0] + "-split.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL)
print "Finished", file_name
print("Finished", file_name)

def discard_negatives(data_dir):
_discard_negatives(data_dir + "/val_edges.pkl")
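A related Python 3 detail worth noting for this file: `pickle` only accepts binary file handles, and the unchanged lines above still open `val_edges.pkl`, `test_edges.pkl`, and `train_edges.pkl` with mode `"w"`. A minimal sketch of the round trip as it has to look under Python 3 (the file name and query payload are illustrative):

```python
import pickle

queries = [("1-chain", ("node_a", ("protein", "interacts", "protein"), "node_b"))]  # made-up serialized query

# Python 3: pickle.dump/load require binary mode ("wb"/"rb"); text mode raises TypeError.
with open("train_edges.pkl", "wb") as f:
    pickle.dump(queries, f, protocol=pickle.HIGHEST_PROTOCOL)

with open("train_edges.pkl", "rb") as f:
    assert pickle.load(f) == queries
```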
6 changes: 3 additions & 3 deletions netquery/bio/train.py
@@ -26,18 +26,18 @@
parser.add_argument("--opt", type=str, default="adam")
args = parser.parse_args()

print "Loading graph data.."
print("Loading graph data..")
graph, feature_modules, node_maps = load_graph(args.data_dir, args.embed_dim)
if args.cuda:
graph.features = cudify(feature_modules, node_maps)
out_dims = {mode:args.embed_dim for mode in graph.relations}

print "Loading edge data.."
print("Loading edge data..")
train_queries = load_queries_by_formula(args.data_dir + "/train_edges.pkl")
val_queries = load_test_queries_by_formula(args.data_dir + "/val_edges.pkl")
test_queries = load_test_queries_by_formula(args.data_dir + "/test_edges.pkl")

print "Loading query data.."
print("Loading query data..")
for i in range(2,4):
train_queries.update(load_queries_by_formula(args.data_dir + "/train_queries_{:d}.pkl".format(i)))
i_val_queries = load_test_queries_by_formula(args.data_dir + "/val_queries_{:d}.pkl".format(i))
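The `--cuda` switch mentioned in the README is presumably declared as a store-true flag next to the `--opt` argument shown above. A sketch of that argparse pattern, with illustrative defaults (the real `train.py` defaults may differ):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", type=str, default="./bio_data")
parser.add_argument("--embed_dim", type=int, default=128)
parser.add_argument("--opt", type=str, default="adam")
parser.add_argument("--cuda", action="store_true", help="move features and model to the GPU")
args = parser.parse_args()

# `python -m netquery.bio.train --cuda` would set args.cuda to True.
print(args.cuda, args.embed_dim)
```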
13 changes: 7 additions & 6 deletions netquery/data_utils.py
@@ -1,5 +1,6 @@
from collections import defaultdict
import cPickle as pickle
#import cPickle as pickle
import pickle
from multiprocessing import Process
from netquery.graph import Query

@@ -45,13 +46,13 @@ def sample_clean_test(graph_loader, data_dir):
val_queries_2 = test_graph.sample_test_queries(train_graph, ["2-chain", "2-inter"], 10, 900)
val_queries_2.extend(test_graph.sample_test_queries(train_graph, ["2-chain", "2-inter"], 100, 1000))
val_queries_2 = list(set(val_queries_2)-set(test_queries_2))
print len(val_queries_2)
print(len(val_queries_2))
test_queries_3 = test_graph.sample_test_queries(train_graph, ["3-chain", "3-inter", "3-inter_chain", "3-chain_inter"], 9000, 1)
test_queries_3.extend(test_graph.sample_test_queries(train_graph, ["3-chain", "3-inter", "3-inter_chain", "3-chain_inter"], 1000, 1000))
val_queries_3 = test_graph.sample_test_queries(train_graph, ["3-chain", "3-inter", "3-inter_chain", "3-chain_inter"], 900, 1)
val_queries_3.extend(test_graph.sample_test_queries(train_graph, ["3-chain", "3-inter", "3-inter_chain", "3-chain_inter"], 100, 1000))
val_queries_3 = list(set(val_queries_3)-set(test_queries_3))
print len(val_queries_3)
print(len(val_queries_3))
pickle.dump([q.serialize() for q in test_queries_2], open(data_dir + "/test_queries_2-newclean.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump([q.serialize() for q in test_queries_3], open(data_dir + "/test_queries_3-newclean.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump([q.serialize() for q in val_queries_2], open(data_dir + "/val_queries_2-newclean.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL)
@@ -67,16 +68,16 @@ def clean_test(train_queries, test_queries):
def parallel_sample_worker(pid, num_samples, graph, data_dir, is_test, test_edges):
if not is_test:
graph.remove_edges([(q.target_node, q.formula.rels[0], q.anchor_nodes[0]) for q in test_edges])
print "Running worker", pid
print("Running worker", pid)
queries_2 = graph.sample_queries(2, num_samples, 100 if is_test else 1, verbose=True)
queries_3 = graph.sample_queries(3, num_samples, 100 if is_test else 1, verbose=True)
print "Done running worker, now saving data", pid
print("Done running worker, now saving data", pid)
pickle.dump([q.serialize() for q in queries_2], open(data_dir + "/queries_2-{:d}.pkl".format(pid), "wb"), protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump([q.serialize() for q in queries_3], open(data_dir + "/queries_3-{:d}.pkl".format(pid), "wb"), protocol=pickle.HIGHEST_PROTOCOL)

def parallel_sample(graph, num_workers, samples_per_worker, data_dir, test=False, start_ind=None):
if test:
print "Loading test/val data.."
print("Loading test/val data..")
test_edges = load_queries(data_dir + "/test_edges.pkl")
val_edges = load_queries(data_dir + "/val_edges.pkl")
else:
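For context, the workers above are typically fanned out with `multiprocessing.Process`; a sketch of how `parallel_sample` presumably dispatches them (the truncated body may differ in detail), assuming `parallel_sample_worker` from this module is in scope:

```python
from multiprocessing import Process

def launch_workers(graph, num_workers, samples_per_worker, data_dir, is_test, test_edges):
    # One process per worker id; each writes queries_2-<pid>.pkl and queries_3-<pid>.pkl.
    procs = [Process(target=parallel_sample_worker,
                     args=(pid, samples_per_worker, graph, data_dir, is_test, test_edges))
             for pid in range(num_workers)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```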
6 changes: 3 additions & 3 deletions netquery/decoders.py
@@ -136,7 +136,7 @@ def __init__(self, relations, dims):
for r2 in relations[r1]:
rel = (r1, r2[1], r2[0])
self.mats[rel] = nn.Parameter(torch.FloatTensor(dims[rel[0]], dims[rel[2]]))
init.xavier_uniform(self.mats[rel])
init.xavier_uniform_(self.mats[rel])
self.register_parameter("_".join(rel), self.mats[rel])

def forward(self, embeds1, embeds2, rels):
@@ -279,10 +279,10 @@ def __init__(self, mode_dims, expand_dims, agg_func=torch.min):
self.agg_func = agg_func
for mode in mode_dims:
self.pre_mats[mode] = nn.Parameter(torch.FloatTensor(expand_dims[mode], mode_dims[mode]))
init.xavier_uniform(self.pre_mats[mode])
init.xavier_uniform_(self.pre_mats[mode])
self.register_parameter(mode+"_premat", self.pre_mats[mode])
self.post_mats[mode] = nn.Parameter(torch.FloatTensor(mode_dims[mode], expand_dims[mode]))
init.xavier_uniform(self.post_mats[mode])
init.xavier_uniform_(self.post_mats[mode])
self.register_parameter(mode+"_postmat", self.post_mats[mode])

def forward(self, embeds1, embeds2, mode, embeds3 = []):
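The substantive change in this file is the switch to `init.xavier_uniform_`: since PyTorch 0.4 the trailing-underscore initializers are the supported in-place API, and the older spellings are deprecated (and may be removed in newer releases). A minimal sketch with made-up dimensions:

```python
import torch
import torch.nn as nn
from torch.nn import init

dims = {"protein": 32, "disease": 32}              # hypothetical per-mode dims
rel = ("protein", "interacts", "protein")          # hypothetical relation triple

mat = nn.Parameter(torch.FloatTensor(dims[rel[0]], dims[rel[2]]))
init.xavier_uniform_(mat)   # in-place Xavier/Glorot init; the non-underscore form is deprecated
print(mat.shape)
```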
6 changes: 3 additions & 3 deletions netquery/encoders.py
@@ -22,7 +22,7 @@ def __init__(self, features, feature_modules):
feature_modules -- This should be a map from mode -> torch.nn.EmbeddingBag
"""
super(DirectEncoder, self).__init__()
for name, module in feature_modules.iteritems():
for name, module in feature_modules.items():
self.add_module("feat-"+name, module)
self.features = features

@@ -74,7 +74,7 @@ def __init__(self, features, feature_dims,
self.adj_lists = adj_lists
self.relations = relations
self.aggregator = aggregator
for name, module in feature_modules.iteritems():
for name, module in feature_modules.items():
self.add_module("feat-"+name, module)
if base_model != None:
self.base_model = base_model
@@ -92,7 +92,7 @@ def __init__(self, features, feature_dims,
self.self_params = {}
self.compress_params = {}
self.lns = {}
for mode, feat_dim in self.feat_dims.iteritems():
for mode, feat_dim in self.feat_dims.items():
if self.layer_norm:
self.lns[mode] = LayerNorm(out_dims[mode])
self.add_module(mode+"_ln", self.lns[mode])
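Beyond the `.items()` swap, these constructors register one embedding table per node mode via `add_module`, which is what makes them show up in `.parameters()` and move with `.cuda()`. A self-contained sketch, assuming per-mode `nn.EmbeddingBag` lookups as the docstring above describes (vocabulary sizes are made up):

```python
import torch.nn as nn

class DirectEncoderSketch(nn.Module):
    def __init__(self, feature_modules):
        super(DirectEncoderSketch, self).__init__()
        # Python 3: .items() instead of .iteritems(); add_module registers each
        # per-mode embedding so it is tracked by parameters()/cuda()/state_dict().
        for name, module in feature_modules.items():
            self.add_module("feat-" + name, module)

feature_modules = {m: nn.EmbeddingBag(num_embeddings=1000, embedding_dim=16)
                   for m in ["protein", "disease"]}          # made-up modes and sizes
enc = DirectEncoderSketch(feature_modules)
print(sum(p.numel() for p in enc.parameters()))              # 2 * 1000 * 16
```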
24 changes: 12 additions & 12 deletions netquery/graph.py
@@ -113,18 +113,18 @@ def __init__(self, features, feature_dims, relations, adj_lists):
self.full_sets = defaultdict(set)
self.full_lists = {}
self.meta_neighs = defaultdict(dict)
for rel, adjs in self.adj_lists.iteritems():
for rel, adjs in self.adj_lists.items():
full_set = set(self.adj_lists[rel].keys())
self.full_sets[rel[0]] = self.full_sets[rel[0]].union(full_set)
for mode, full_set in self.full_sets.iteritems():
for mode, full_set in self.full_sets.items():
self.full_lists[mode] = list(full_set)
self._cache_edge_counts()
self._make_flat_adj_lists()

def _make_flat_adj_lists(self):
self.flat_adj_lists = defaultdict(lambda : defaultdict(list))
for rel, adjs in self.adj_lists.iteritems():
for node, neighs in adjs.iteritems():
for rel, adjs in self.adj_lists.items():
for node, neighs in adjs.items():
self.flat_adj_lists[rel[0]][node].extend([(rel, neigh) for neigh in neighs])

def _cache_edge_counts(self):
@@ -140,10 +140,10 @@ def _cache_edge_counts(self):
self.rel_weights = OrderedDict()
self.mode_edges = defaultdict(float)
self.mode_weights = OrderedDict()
for rel, edge_count in self.rel_edges.iteritems():
for rel, edge_count in self.rel_edges.items():
self.rel_weights[rel] = edge_count / self.edges
self.mode_edges[rel[0]] += edge_count
for mode, edge_count in self.mode_edges.iteritems():
for mode, edge_count in self.mode_edges.items():
self.mode_weights[mode] = edge_count / self.edges

def remove_edges(self, edge_list):
@@ -167,10 +167,10 @@ def get_all_edges(self, seed=0, exclude_rels=set([])):
"""
edges = []
random.seed(seed)
for rel, adjs in self.adj_lists.iteritems():
for rel, adjs in self.adj_lists.items():
if rel in exclude_rels:
continue
for node, neighs in adjs.iteritems():
for node, neighs in adjs.items():
edges.extend([(node, rel, neigh) for neigh in neighs if neigh != -1])
random.shuffle(edges)
return edges
@@ -179,10 +179,10 @@ def get_all_edges_byrel(self, seed=0,
exclude_rels=set([])):
random.seed(seed)
edges = defaultdict(list)
for rel, adjs in self.adj_lists.iteritems():
for rel, adjs in self.adj_lists.items():
if rel in exclude_rels:
continue
for node, neighs in adjs.iteritems():
for node, neighs in adjs.items():
edges[(rel,)].extend([(node, neigh) for neigh in neighs if neigh != -1])

def get_negative_edge_samples(self, edge, num, rejection_sample=True):
@@ -216,7 +216,7 @@ def sample_test_queries(self, train_graph, q_types, samples_per_type, neg_sample
queries.append(query)
sampled += 1
if sampled % 1000 == 0 and verbose:
print "Sampled", sampled
print("Sampled", sampled)
return queries

def sample_queries(self, arity, num_samples, neg_sample_max, verbose=True):
@@ -233,7 +233,7 @@ def sample_queries(self, arity, num_samples, neg_sample_max, verbose=True):
queries.append(query)
sampled += 1
if sampled % 1000 == 0 and verbose:
print "Sampled", sampled
print("Sampled", sampled)
return queries


2 changes: 1 addition & 1 deletion netquery/model.py
@@ -46,7 +46,7 @@ def margin_loss(self, nodes1, nodes2, rels):
Maximizes relationship scores for true pairs vs negative samples.
"""
affs = self.forward(nodes1, nodes2, rels)
neg_nodes = [random.randint(1,len(self.graph.adj_lists[_reverse_relation[rels[-1]]])-1) for _ in xrange(len(nodes1))]
neg_nodes = [random.randint(1,len(self.graph.adj_lists[_reverse_relation[rels[-1]]])-1) for _ in range(len(nodes1))]
neg_affs = self.forward(nodes1, neg_nodes,
rels)
margin = 1 - (affs - neg_affs)
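`xrange` no longer exists in Python 3; `range` is itself lazy now, so the negative-sampling comprehension above keeps the same memory profile. A tiny sketch of that line in isolation (the adjacency dict is a stand-in):

```python
import random

adj_for_rel = {i: ["n%d" % i] for i in range(50)}   # stand-in for graph.adj_lists[reverse_relation]
nodes1 = ["a", "b", "c"]

# Python 3: range() is lazy, like Python 2's xrange(), so no index list is materialized.
neg_nodes = [random.randint(1, len(adj_for_rel) - 1) for _ in range(len(nodes1))]
print(neg_nodes)
```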
18 changes: 9 additions & 9 deletions netquery/train_helpers.py
@@ -1,5 +1,5 @@
import numpy as np
from utils import eval_auc_queries, eval_perc_queries
from netquery.utils import eval_auc_queries, eval_perc_queries
import torch

def check_conv(vals, window=2, tol=1e-6):
@@ -19,7 +19,7 @@ def update_loss(loss, losses, ema_loss, ema_alpha=0.01):
def run_eval(model, queries, iteration, logger, by_type=False):
vals = {}
def _print_by_rel(rel_aucs, logger):
for rels, auc in rel_aucs.iteritems():
for rels, auc in rel_aucs.items():
logger.info(str(rels) + "\t" + str(auc))
for query_type in queries["one_neg"]:
auc, rel_aucs = eval_auc_queries(queries["one_neg"][query_type], model)
@@ -45,15 +45,15 @@ def run_train(model, optimizer, train_queries, val_queries, test_queries, logger
vals = []
losses = []
conv_test = None
for i in xrange(max_iter):
for i in range(max_iter):

optimizer.zero_grad()
loss = run_batch(train_queries["1-chain"], model, i, batch_size)
if not edge_conv and (check_conv(vals) or len(losses) >= max_burn_in):
logger.info("Edge converged at iteration {:d}".format(i-1))
logger.info("Testing at edge conv...")
conv_test = run_eval(model, test_queries, i, logger)
conv_test = np.mean(conv_test.values())
conv_test = np.mean(list(conv_test.values()))
edge_conv = True
losses = []
ema_loss = None
@@ -74,7 +74,7 @@ def run_train(model, optimizer, train_queries, val_queries, test_queries, logger
logger.info("Fully converged at iteration {:d}".format(i))
break

losses, ema_loss = update_loss(loss.data[0], losses, ema_loss)
losses, ema_loss = update_loss(loss.item(), losses, ema_loss)
loss.backward()
optimizer.step()

@@ -84,20 +84,20 @@
if i >= val_every and i % val_every == 0:
v = run_eval(model, val_queries, i, logger)
if edge_conv:
vals.append(np.mean(v.values()))
vals.append(np.mean(list(v.values())))
else:
vals.append(v["1-chain"])

v = run_eval(model, test_queries, i, logger)
logger.info("Test macro-averaged val: {:f}".format(np.mean(v.values())))
logger.info("Improvement from edge conv: {:f}".format((np.mean(v.values())-conv_test)/conv_test))
logger.info("Test macro-averaged val: {:f}".format(np.mean(list(v.values()))))
logger.info("Improvement from edge conv: {:f}".format((np.mean(list(v.values()))-conv_test)/conv_test))

def run_batch(train_queries, enc_dec, iter_count, batch_size, hard_negatives=False):
num_queries = [float(len(queries)) for queries in train_queries.values()]
denom = float(sum(num_queries))
formula_index = np.argmax(np.random.multinomial(1,
np.array(num_queries)/denom))
formula = train_queries.keys()[formula_index]
formula = list(train_queries.keys())[formula_index]
n = len(train_queries[formula])
start = (iter_count * batch_size) % n
end = min(((iter_count+1) * batch_size) % n, n)
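The remaining hunks in this file are about Python 3's lazy dict views (which are neither indexable nor accepted by `np.mean`) and about PyTorch's move from `loss.data[0]` to `loss.item()` for extracting a Python scalar from a zero-dimensional loss tensor. A minimal sketch of both patterns with made-up values:

```python
import numpy as np
import torch

vals = {"1-chain": 0.91, "2-inter": 0.84}        # hypothetical per-query-type AUCs
mean_auc = np.mean(list(vals.values()))          # wrap the view in list() before np.mean
first_formula = list(vals.keys())[0]             # dict_keys is not indexable in Python 3

loss = (torch.randn(8) ** 2).mean()              # stand-in scalar loss tensor
loss_value = loss.item()                         # PyTorch >= 0.4: replaces loss.data[0]
print(first_formula, mean_auc, loss_value)
```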