From d3d56719f120be7d46ffabbc96cf62dc125fb42a Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sat, 2 Apr 2016 20:32:09 +0200 Subject: [PATCH 1/7] Trying to fix graphgen with parallel containers --- graphgen.lua | 22 +++++++++++++++++++-- models.lua | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/graphgen.lua b/graphgen.lua index 7f07a26..7c1f7d9 100644 --- a/graphgen.lua +++ b/graphgen.lua @@ -102,6 +102,7 @@ local function generateGraph(net, input, opts) local storageHash = {} local nodes = {} + local trickyNodes = {} local g = graph.Graph() @@ -168,7 +169,17 @@ local function generateGraph(net, input, opts) nodes[toPtr] = nodes[toPtr] or createNode(name,to) - assert(nodes[fromPtr], 'Parent node inexistant for module '.. name) + --assert(nodes[fromPtr], 'Parent node inexistant for module '.. name) + if not nodes[fromPtr] then + --[[ + print('Printing debug') + print(debug.getinfo(2)) + --]] + + nodes[fromPtr] = createNode('oups',from) + table.insert(trickyNodes, fromPtr) + trickyNodes[fromPtr] = nodes[fromPtr] + end -- insert edge g:add(graph.Edge(nodes[fromPtr],nodes[toPtr])) @@ -199,7 +210,14 @@ local function generateGraph(net, input, opts) -- those containers effectively do some computation, so they have their -- place in the graph for i,branch in ipairs(m.modules) do - local last_module = branch:get(branch:size()) + --local last_module = branch:get(branch:size()) + local last_module + if branch.modues then + last_module = branch:get(#branch.modules) + else + last_module = branch + end + local out = last_module.output local ptr = torch.pointer(out) diff --git a/models.lua b/models.lua index 8269c71..fd05903 100644 --- a/models.lua +++ b/models.lua @@ -72,6 +72,62 @@ models.siamese = function() return m, input end +models.siamese_parallel = function() + local fSize = {1, 32, 64} + local featuresOut = 128 + + local desc = nn.Sequential() + desc:add(nn.Reshape(1,64,64)) + desc:add(nn.SpatialAveragePooling(2,2,2,2)) + desc:add(nn.SpatialConvolution(fSize[1], fSize[2], 7,7)) + desc:add(nn.ReLU()) + desc:add(nn.SpatialMaxPooling(2,2,2,2)) + desc:add(nn.SpatialConvolution(fSize[2], fSize[3], 6,6)) + desc:add(nn.ReLU()) + desc:add(nn.View(-1):setNumInputDims(3)) + desc:add(nn.Linear(4096, 128)) + desc:add(nn.Contiguous()) + + local siamese = nn.Parallel(2,2) + local siam = desc:clone() + desc:share(siam, 'weight', 'bias', 'gradWeight', 'gradBias') + siamese:add(desc) + siamese:add(siam) + + local top = nn.Sequential() + top:add(nn.Linear(featuresOut*2, featuresOut*2)) + top:add(nn.ReLU()) + top:add(nn.Linear(featuresOut*2, 1)) + + local model = nn.Sequential():add(siamese):add(top) + + local input = torch.rand(1,2,64,64) + + return model, input +end + +models.basic_parallel_middle = function() + local model = nn.Sequential():add(nn.Linear(2,2)) + local prl = nn.Parallel(2,1) + prl:add(nn.Linear(2,2)) + prl:add(nn.Linear(2,2)) + model:add(prl) + local input = torch.rand(2,2) + return model, input +end + +models.basic_splitTable = function() + local model = nn.Sequential():add(nn.Linear(2,2)) + model:add(nn.SplitTable(2)) + local prl = nn.ParallelTable() + prl:add(nn.ReLU()) + prl:add(nn.Sigmoid()) + model:add(prl) + model:add(nn.JoinTable(1)) + local input = torch.rand(2,2) + return model, input +end + models.basic_concat = function() local m = nn.Sequential() local cat = nn.ConcatTable() From 48a65c874893ad73d524cc7b9af697078b607b60 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 7 Apr 2016 08:07:47 +0200 Subject: [PATCH 2/7] More changes --- graphgen.lua | 43 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/graphgen.lua b/graphgen.lua index 7c1f7d9..a141d32 100644 --- a/graphgen.lua +++ b/graphgen.lua @@ -34,6 +34,7 @@ local colorNames = { "goldenrod","goldenrod1","goldenrod2","goldenrod3","goldenrod4" } + -- some modules exist only for constructing -- the flow of information, and should not -- have their place in the computation graph @@ -159,6 +160,32 @@ local function generateGraph(net, input, opts) end end + --local oldFuncs = {DoubleTensor={},FloatTensor={}} + local oldF + local idx = 1 + + local function hackTorch() + --oldFuncs.DoubleTensor.__index = torch.DoubleTensor.__index + oldF = torch.DoubleTensor.__index + torch.DoubleTensor.__index = function(...) + local r = oldF(...) + trickyNodes[torch.pointer(r)] = idx + idx = idx+1 + return r + end + end + + local function unhackTorch() + torch.DoubleTensor.__index = oldF + end + + local m_idx = {} + net:apply(function(x) + table.insert(m_idx, x) + end) + + + -- create edge "from" -> "to", creating "to" on the way with "name" -- the edges can be seen as linking modules, but in fact it links the output -- tensor of each module @@ -175,12 +202,16 @@ local function generateGraph(net, input, opts) print('Printing debug') print(debug.getinfo(2)) --]] + local n = trickyNodes[fromPtr] + print(n) + local n2 = torch.typename(m_idx[n]) - nodes[fromPtr] = createNode('oups',from) - table.insert(trickyNodes, fromPtr) - trickyNodes[fromPtr] = nodes[fromPtr] + nodes[fromPtr] = createNode(n2,from) + --nodes[fromPtr] = createNode('oups',from) + --table.insert(trickyNodes, fromPtr) + --trickyNodes[fromPtr] = nodes[fromPtr] end - + -- insert edge g:add(graph.Edge(nodes[fromPtr],nodes[toPtr])) elseif torch.isTensor(from) then @@ -234,11 +265,13 @@ local function generateGraph(net, input, opts) -- fill the states from each tensor net:forward(input) - + -- overwriting the standard functions to generate our graph net:apply(apply_func) -- generate the graph + hackTorch() net:forward(input) + unhackTorch() if opts.addOutputNode then -- add dummy output node and link the last module to it From ae8a9eb0867c44d539a7d6a89bf8c2f913ceaf1d Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 12 Apr 2016 00:32:14 +0200 Subject: [PATCH 3/7] Seems to work! --- graphgen.lua | 58 +++++++++++++++++++++++++++++----------------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/graphgen.lua b/graphgen.lua index a141d32..a81db06 100644 --- a/graphgen.lua +++ b/graphgen.lua @@ -161,16 +161,21 @@ local function generateGraph(net, input, opts) end --local oldFuncs = {DoubleTensor={},FloatTensor={}} + local origTorchFuncs = {DoubleTensor={},FloatTensor={},CudaTensor={}} + local hackableTorchFuncs = {'select'} local oldF - local idx = 1 + local idx_m = {__input=input} local function hackTorch() --oldFuncs.DoubleTensor.__index = torch.DoubleTensor.__index - oldF = torch.DoubleTensor.__index - torch.DoubleTensor.__index = function(...) + --oldF = torch.DoubleTensor.__index + oldF = torch.DoubleTensor.select + --torch.DoubleTensor.__index = function(...) + torch.DoubleTensor.select = function(...) local r = oldF(...) - trickyNodes[torch.pointer(r)] = idx - idx = idx+1 + if r then + trickyNodes[torch.pointer(r)] = {idx_m,'select'} + end return r end end @@ -179,13 +184,6 @@ local function generateGraph(net, input, opts) torch.DoubleTensor.__index = oldF end - local m_idx = {} - net:apply(function(x) - table.insert(m_idx, x) - end) - - - -- create edge "from" -> "to", creating "to" on the way with "name" -- the edges can be seen as linking modules, but in fact it links the output -- tensor of each module @@ -196,20 +194,14 @@ local function generateGraph(net, input, opts) nodes[toPtr] = nodes[toPtr] or createNode(name,to) - --assert(nodes[fromPtr], 'Parent node inexistant for module '.. name) if not nodes[fromPtr] then - --[[ - print('Printing debug') - print(debug.getinfo(2)) - --]] local n = trickyNodes[fromPtr] - print(n) - local n2 = torch.typename(m_idx[n]) + --print(n) + --print(torch.typename(n[1])) + local n2 = n[2] - nodes[fromPtr] = createNode(n2,from) - --nodes[fromPtr] = createNode('oups',from) - --table.insert(trickyNodes, fromPtr) - --trickyNodes[fromPtr] = nodes[fromPtr] + local pfrom = n[1].__input + addEdge(pfrom,from,n2) end -- insert edge @@ -225,11 +217,19 @@ local function generateGraph(net, input, opts) end end + local old_idx_m = {}-- = {{__input=input}} + -- go over the network keeping track of the input/output for each module -- we overwrite the updateOutput for that. local function apply_func(m) local basefunc = m.updateOutput m.updateOutput = function(self, input) + self.__input = input + table.insert(old_idx_m,idx_m) + idx_m = self + local output = basefunc(self, input) + idx_m = table.remove(old_idx_m) + --idx_m = self if isSingleOperationModule(m) then local name = tostring(m) if m.inplace then -- handle it differently ? @@ -257,20 +257,26 @@ local function generateGraph(net, input, opts) addEdge(out, self.output, torch.typename(m)) end end - return basefunc(self, input) + --idx = idx + 1 + return output end + --print(idx) + end createBoundaryNode(input, 'Input') -- fill the states from each tensor - net:forward(input) + --net:forward(input) + hackTorch() -- overwriting the standard functions to generate our graph net:apply(apply_func) -- generate the graph - hackTorch() net:forward(input) + + --print(trickyNodes) + print(old_idx_m) unhackTorch() if opts.addOutputNode then From 49142d54ab558e7708af32c6e590b88661d8f401 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 12 Apr 2016 01:00:25 +0200 Subject: [PATCH 4/7] Cleaning up --- graphgen.lua | 61 ++++++++++++++++++++++++++-------------------------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/graphgen.lua b/graphgen.lua index a81db06..5ec4f64 100644 --- a/graphgen.lua +++ b/graphgen.lua @@ -160,28 +160,35 @@ local function generateGraph(net, input, opts) end end - --local oldFuncs = {DoubleTensor={},FloatTensor={}} - local origTorchFuncs = {DoubleTensor={},FloatTensor={},CudaTensor={}} + local origTorchFuncs = {DoubleTensor={},FloatTensor={}} + if package.loaded.cutorch then + origTorchFuncs.CudaTensor = {} + end local hackableTorchFuncs = {'select'} - local oldF - local idx_m = {__input=input} + local current_module = {__input=input} local function hackTorch() - --oldFuncs.DoubleTensor.__index = torch.DoubleTensor.__index - --oldF = torch.DoubleTensor.__index - oldF = torch.DoubleTensor.select - --torch.DoubleTensor.__index = function(...) - torch.DoubleTensor.select = function(...) - local r = oldF(...) - if r then - trickyNodes[torch.pointer(r)] = {idx_m,'select'} + for torchType, t in pairs(origTorchFuncs) do + for _, func in ipairs(hackableTorchFuncs) do + oldFunc = torch[torchType][func] + t[func] = oldFunc + torch[torchType][func] = function(...) + local res = oldFunc(...) + if res then + trickyNodes[torch.pointer(res)] = {current_module, 'torch.'..func} + end + return res + end end - return r end end local function unhackTorch() - torch.DoubleTensor.__index = oldF + for torchType, t in pairs(origTorchFuncs) do + for _, func in ipairs(hackableTorchFuncs) do + torch[torchType][func] = t[func] + end + end end -- create edge "from" -> "to", creating "to" on the way with "name" @@ -195,13 +202,12 @@ local function generateGraph(net, input, opts) nodes[toPtr] = nodes[toPtr] or createNode(name,to) if not nodes[fromPtr] then - local n = trickyNodes[fromPtr] - --print(n) - --print(torch.typename(n[1])) - local n2 = n[2] + local trickyNode = trickyNodes[fromPtr] + assert(trickyNode, "Could't handle previous node to "..name) + local trickyNodeName = trickyNode[2] - local pfrom = n[1].__input - addEdge(pfrom,from,n2) + local trickyParentFrom = trickyNode[1].__input + addEdge(trickyParentFrom,from,trickyNodeName) end -- insert edge @@ -217,7 +223,7 @@ local function generateGraph(net, input, opts) end end - local old_idx_m = {}-- = {{__input=input}} + local stack_visited_modules = {} -- go over the network keeping track of the input/output for each module -- we overwrite the updateOutput for that. @@ -225,11 +231,10 @@ local function generateGraph(net, input, opts) local basefunc = m.updateOutput m.updateOutput = function(self, input) self.__input = input - table.insert(old_idx_m,idx_m) - idx_m = self + table.insert(stack_visited_modules, current_module) + current_module = self local output = basefunc(self, input) - idx_m = table.remove(old_idx_m) - --idx_m = self + current_module = table.remove(stack_visited_modules) if isSingleOperationModule(m) then local name = tostring(m) if m.inplace then -- handle it differently ? @@ -257,11 +262,8 @@ local function generateGraph(net, input, opts) addEdge(out, self.output, torch.typename(m)) end end - --idx = idx + 1 return output end - --print(idx) - end createBoundaryNode(input, 'Input') @@ -275,8 +277,6 @@ local function generateGraph(net, input, opts) -- generate the graph net:forward(input) - --print(trickyNodes) - print(old_idx_m) unhackTorch() if opts.addOutputNode then @@ -302,6 +302,7 @@ local function generateGraph(net, input, opts) -- clean up the modified function net:apply(function(x) x.updateOutput = nil + x.__input = nil end) return g From 018f199a904bfbcb4bb16ba375b2c297bdeb6437 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 12 Apr 2016 08:16:18 +0200 Subject: [PATCH 5/7] Add comments --- graphgen.lua | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/graphgen.lua b/graphgen.lua index 5ec4f64..7a3c9ea 100644 --- a/graphgen.lua +++ b/graphgen.lua @@ -104,6 +104,8 @@ local function generateGraph(net, input, opts) local storageHash = {} local nodes = {} local trickyNodes = {} + local current_module = {__input=input} + local stack_visited_modules = {} local g = graph.Graph() @@ -161,12 +163,19 @@ local function generateGraph(net, input, opts) end local origTorchFuncs = {DoubleTensor={},FloatTensor={}} + -- also hack the cuda counter-parts if cutorch is loaded if package.loaded.cutorch then origTorchFuncs.CudaTensor = {} end - local hackableTorchFuncs = {'select'} - local current_module = {__input=input} - + -- list of functions to hack. seems that can't extend due to stack + -- overflow reasons + local hackableTorchFuncs = {'select','__index'} + + -- we will temporarily overwrite torch functions to keep track + -- of all created tensors during the forward call. This will + -- allow us to handle some corner cases where the input tensor is + -- not part of the state of a module (i.e., it's not the output + -- of another module) local function hackTorch() for torchType, t in pairs(origTorchFuncs) do for _, func in ipairs(hackableTorchFuncs) do @@ -201,6 +210,10 @@ local function generateGraph(net, input, opts) nodes[toPtr] = nodes[toPtr] or createNode(name,to) + -- if "from" tensor is not present in "nodes" table, this means that + -- "from" is not the output of a module, and was created on the fly + -- during for example a slicing of a tensor. "trickyNodes" contains + -- all tensors that were generated on the fly if not nodes[fromPtr] then local trickyNode = trickyNodes[fromPtr] assert(trickyNode, "Could't handle previous node to "..name) @@ -223,8 +236,6 @@ local function generateGraph(net, input, opts) end end - local stack_visited_modules = {} - -- go over the network keeping track of the input/output for each module -- we overwrite the updateOutput for that. local function apply_func(m) From 6bd137451fbac247f6d5d7678db7867c3e8d4caa Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 12 Apr 2016 08:23:17 +0200 Subject: [PATCH 6/7] Minor tweaks --- graphgen.lua | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/graphgen.lua b/graphgen.lua index 7a3c9ea..f8b0b6d 100644 --- a/graphgen.lua +++ b/graphgen.lua @@ -34,7 +34,6 @@ local colorNames = { "goldenrod","goldenrod1","goldenrod2","goldenrod3","goldenrod4" } - -- some modules exist only for constructing -- the flow of information, and should not -- have their place in the computation graph @@ -179,11 +178,12 @@ local function generateGraph(net, input, opts) local function hackTorch() for torchType, t in pairs(origTorchFuncs) do for _, func in ipairs(hackableTorchFuncs) do - oldFunc = torch[torchType][func] + local oldFunc = torch[torchType][func] t[func] = oldFunc torch[torchType][func] = function(...) local res = oldFunc(...) if res then + -- heavy use of upvalues trickyNodes[torch.pointer(res)] = {current_module, 'torch.'..func} end return res @@ -241,11 +241,14 @@ local function generateGraph(net, input, opts) local function apply_func(m) local basefunc = m.updateOutput m.updateOutput = function(self, input) + -- add input to self to help keep track of it self.__input = input + -- keeps a stack of visited modules table.insert(stack_visited_modules, current_module) current_module = self local output = basefunc(self, input) current_module = table.remove(stack_visited_modules) + -- add edges to the graph according to the node type if isSingleOperationModule(m) then local name = tostring(m) if m.inplace then -- handle it differently ? @@ -257,10 +260,9 @@ local function generateGraph(net, input, opts) -- those containers effectively do some computation, so they have their -- place in the graph for i,branch in ipairs(m.modules) do - --local last_module = branch:get(branch:size()) local last_module - if branch.modues then - last_module = branch:get(#branch.modules) + if branch.modules then + last_module = branch:get(branch:size()) else last_module = branch end From 0ed02cbfdb97229174ccea2bf0356c55191c66a4 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 12 Apr 2016 08:31:02 +0200 Subject: [PATCH 7/7] No need of prefilling forward --- graphgen.lua | 3 --- 1 file changed, 3 deletions(-) diff --git a/graphgen.lua b/graphgen.lua index f8b0b6d..f75c069 100644 --- a/graphgen.lua +++ b/graphgen.lua @@ -281,9 +281,6 @@ local function generateGraph(net, input, opts) createBoundaryNode(input, 'Input') - -- fill the states from each tensor - --net:forward(input) - hackTorch() -- overwriting the standard functions to generate our graph net:apply(apply_func)