diff --git a/Manifest.toml b/Manifest.toml index 2f11e11..c3237cc 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -18,6 +18,12 @@ git-tree-sha1 = "82dab828020b872fa9efd3abec1152b075bc7cbf" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" version = "1.0.0" +[[Arpack]] +deps = ["BinaryProvider", "Libdl", "LinearAlgebra"] +git-tree-sha1 = "07a2c077bdd4b6d23a40342a8a108e2ee5e58ab6" +uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97" +version = "0.3.1" + [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -35,15 +41,15 @@ version = "0.5.6" [[CSTParser]] deps = ["Tokenize"] -git-tree-sha1 = "0ff80f68f55fcde2ed98d7b24d7abaf20727f3f8" +git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b" uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" -version = "0.6.1" +version = "0.6.2" [[CodecZlib]] -deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"] -git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b" +deps = ["BinaryProvider", "Libdl", "TranscodingStreams"] +git-tree-sha1 = "05916673a2627dd91b4969ff8ba6941bc85a960e" uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.5.2" +version = "0.6.0" [[ColorTypes]] deps = ["FixedPointNumbers", "Random"] @@ -52,10 +58,10 @@ uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" version = "0.8.0" [[Colors]] -deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"] -git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543" +deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport"] +git-tree-sha1 = "c9c1845d6bf22e34738bee65c357a69f416ed5d1" uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.9.5" +version = "0.9.6" [[CommonSubexpressions]] deps = ["Test"] @@ -82,9 +88,9 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" version = "4.0.0" [[DataAPI]] -git-tree-sha1 = "891a09f4f90361a28d0391c104a65c0202e22624" +git-tree-sha1 = "8903f0219d3472543fc4b2f5ebaf675a07f817c0" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.0.0" +version = "1.0.1" [[DataStructures]] deps = ["InteractiveUtils", "OrderedCollections"] @@ -116,6 +122,12 @@ version = "0.0.10" deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +[[Distributions]] +deps = ["LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"] +git-tree-sha1 = "baaf9e165ba8a2d11fb4fb3511782ee070ee3694" +uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" +version = "0.21.1" + [[FFTW]] deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"] git-tree-sha1 = "e1a479d3c972f20c9a70563eec740bbfc786f515" @@ -123,10 +135,10 @@ uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" version = "0.3.0" [[FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays", "Test"] -git-tree-sha1 = "9ab8f76758cbabba8d7f103c51dce7f73fcf8e92" +deps = ["LinearAlgebra", "Random", "SparseArrays"] +git-tree-sha1 = "8fba6ddaf66b45dec830233cea0aae43eb1261ad" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.6.3" +version = "0.6.4" [[FixedPointNumbers]] git-tree-sha1 = "d14a6fa5890ea3a7e5dcab6811114f132fec2b4b" @@ -135,7 +147,7 @@ version = "0.6.1" [[Flux]] deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DelimitedFiles", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "Statistics", "StatsBase", "Test", "ZipFile", "Zygote"] -git-tree-sha1 = "01debcc73af2428239343f5c3c0f8562529e91a0" +git-tree-sha1 = "7bd127360ae2e71abc3033e25e5be161520ec66f" repo-rev = "zygote" repo-url = "https://github.com/FluxML/Flux.jl.git" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" @@ -167,9 +179,9 @@ version = "0.21.0" [[Juno]] deps = ["Base64", "Logging", "Media", "Profile", "Test"] -git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175" +git-tree-sha1 = "30d94657a422d09cb97b6f86f04f750fa9c50df8" uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" -version = "0.7.0" +version = "0.7.2" [[LibGit2]] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" @@ -227,6 +239,12 @@ git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" version = "1.1.0" +[[PDMats]] +deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"] +git-tree-sha1 = "9d6a9b3e19634612fb1edcafc4b1d75242b24bde" +uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" +version = "0.9.9" + [[Parsers]] deps = ["Dates", "Test"] git-tree-sha1 = "db2b35dedab3c0e46dc15996d170af07a5ab91c9" @@ -245,6 +263,12 @@ uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" deps = ["Printf"] uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" +[[QuadGK]] +deps = ["DataStructures", "LinearAlgebra", "Test"] +git-tree-sha1 = "3ce467a8e76c6030d4c3786e7d3a73442017cdc0" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.0.3" + [[REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" @@ -265,6 +289,12 @@ git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1" uuid = "ae029012-a4dd-5104-9daa-d747884805df" version = "0.5.2" +[[Rmath]] +deps = ["BinaryProvider", "Libdl", "Random", "Statistics", "Test"] +git-tree-sha1 = "9a6c758cdf73036c3239b0afbea790def1dabff9" +uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" +version = "0.5.0" + [[SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -310,6 +340,16 @@ git-tree-sha1 = "c53e809e63fe5cf5de13632090bc3520649c9950" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" version = "0.32.0" +[[StatsFuns]] +deps = ["Rmath", "SpecialFunctions", "Test"] +git-tree-sha1 = "b3a4e86aa13c732b8a8c0ba0c3d3264f55e6bb3e" +uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +version = "0.8.0" + +[[SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + [[Test]] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -321,9 +361,9 @@ uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" version = "0.5.0" [[Tokenize]] -git-tree-sha1 = "c8a8b00ae44a94950814ff77850470711a360225" +git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf" uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" -version = "0.5.5" +version = "0.5.6" [[TranscodingStreams]] deps = ["Random", "Test"] diff --git a/Project.toml b/Project.toml index 91f67ad..53c6fd5 100644 --- a/Project.toml +++ b/Project.toml @@ -4,8 +4,10 @@ authors = ["Julian P Samaroo "] version = "0.1.0" [deps] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" IRTools = "7869d1d1-7146-5819-86e3-90919afe41df" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" @@ -14,6 +16,3 @@ julia = "≥ 1.0.0" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Test"] diff --git a/src/Access.jl b/src/Access.jl new file mode 100644 index 0000000..57bc8a1 --- /dev/null +++ b/src/Access.jl @@ -0,0 +1,130 @@ +using Flux: softmax + +cosine_sim(u, v) = (u'v)/(norm(u)*norm(v)) + +# content-based addressing +""" + memprobdistrib(MemMat, key, keystrength) + +Defines a normalized probability distribution over the memory locations. +""" +function memprobdistrib(M, k, β) + out = [cosine_sim(k, M[i, :]) for i in 1:size(M, 1)] .* β + out = softmax(out) +end + +oneplus(x::AbstractVecOrMat) = log.(exp.(x) .+ 1) + +mutable struct Access + MemMat # R^{N*W} + LinkMat # R^{N*N} + readWts # RH element Array of [0, 1]^N + wrtWt # WH element Array of (0, 1)^{N} + usageVec # (0, 1)^{N} + precedenceWt # (0, 1)^{N} + readVecs + numWriteHeads + numReadHeads + wordSize + memorySize +end + +function Access(memorySize=128, wordSize=20, numReadHeads=1, numWriteHeads=1) + MemMat = zeros(Float32, memorySize, wordSize) + LinkMat = zeros(Float32, memorySize, memorySize) + + readWts = rand(Float32, memorySize, numReadHeads) + fill!(readWts, 1f-6) + + wrtWt = rand(Float32, memorySize, numWriteHeads) + fill!(wrtWt, 1f-6) + + usageVec = zeros(Float32, memorySize, numWriteHeads) + fill!(usageVec, 1f-6) + precedenceWt = zeros(Float32, memorySize) + + readVecs = zeros(Float32, wordSize, numReadHeads) + fill!(readVecs, 1f-6) + + Access(MemMat, LinkMat, readWts, wrtWt, usageVec, precedenceWt, readVecs, + numWriteHeads, numReadHeads, wordSize, memorySize) +end + +""" + interfacedisect(interfaceVec, writeHead, wordSize, readHeads) + +Disects the interface vector obtained as the output to obtain various memory +access controls for the DNC. +""" +function interfacedisect(interfaceVec, writeHeads, wordSize, readHeads) + demarcations = cumsum([0, # Starting Index + readHeads*wordSize, # read keys + readHeads, # read strengths + writeHeads*wordSize, # write keys + writeHeads, # write strengths + writeHeads*wordSize, # erase vectors + writeHeads*wordSize, # write vectors + readHeads, # free gates + writeHeads, # allocation gates + writeHeads, # write gates + readHeads * (1 + 2writeHeads) # read modes + ]) + + + readkeys = interfaceVec[demarcations[1]+1:demarcations[2]] + readstrengths = oneplus(interfaceVec[demarcations[2]+1:demarcations[3]]) + writekeys = interfaceVec[demarcations[3]+1:demarcations[4]] + writestrengths = oneplus(interfaceVec[demarcations[4]+1:demarcations[5]]) + eraseVec = σ.(interfaceVec[demarcations[5]+1:demarcations[6]]) + writeVec = interfaceVec[demarcations[6]+1:demarcations[7]] + freeGts = σ.(interfaceVec[demarcations[7]+1:demarcations[8]]) + allocGt = σ.(interfaceVec[demarcations[8]+1:demarcations[9]]) + writeGt = σ.(interfaceVec[demarcations[9]+1:demarcations[10]]) + readmodes = softmax(interfaceVec[demarcations[10]+1:demarcations[11]]) + + readkeys = reshape(readkeys, wordSize, readHeads) # W * RH + writekeys = reshape(writekeys, wordSize, writeHeads) # W * WH + eraseVec = reshape(eraseVec, wordSize, writeHeads) # W * WH + writeVec = reshape(writeVec, wordSize, writeHeads) # W * WH + readmodes = reshape(readmodes, (1+2writeHeads), readHeads) # (WH for backward + WH for forward + 1 for content lookup) * RH + + return (readkeys=readkeys, readstrengths=readstrengths, writekeys=writekeys, + writestrengths=writestrengths, eraseVec=eraseVec, writeVec=writeVec, + freeGts=freeGts, allocGts=allocGt, writeGts=writeGt, readmodes=readmodes) +end + +function (access::Access)(interfaceVec) + # dynamic memory allocation + + interface = interfacedisect(interfaceVec, access.numWriteHeads, access.wordSize, access.numReadHeads) + + memRetVec = prod(1 .- interface[:freeGts]' .* access.readWts, dims=2) # Memory Retention Vector = [0, 1]^{N} + access.usageVec = (access.usageVec .+ access.wrtWt .- access.usageVec .* access.wrtWt) .* memRetVec + freelist = sortperm(access.usageVec) # Z^{N} + allocWt = zeros(access.usageVec) + @. allocWt[freelist] = (1 - access.usageVec[freelist]) * cumprod([1; access.usageVec[freelist]][1:end-1]) # (0, 1)^{N} + + # writing + wrtcntWt = memprobdistrib(access.MemMat, interface[:writekeys], interface[:writestrengths]) # Write content weighting = (0, 1)^{N} + access.wrtWt .= interface[:writeGts] * (interface[:allocGts] * allocWt + (1 - interface[:allocGts])*wrtcntWt) + @. access.MemMat *= (ones(access.MemMat) - access.wrtWt*interface[:eraseVec]') # First we erase... + @. access.MemMat += access.wrtWt*interface[:writeVec]' # Then we write. + + # temporal linkage + eye = Matrix{Float32}(I, size(access.LinkMat)...) + prevlinkscale = @. 1 - access.wrtWt - access.wrtWt' + newlink = @. access.wrtWt * access.precedenceWt' + @. access.LinkMat = (1 - eye) * (prevlinkscale * access.LinkMat + newlink) + + access.precedenceWt = (1 - sum(access.wrtWt)) .* access.precedenceWt .+ access.wrtWt + + # reading + forwardWts = [access.LinkMat * readWt for readWt in access.readWts] + backwardWts = [access.LinkMat' * readWt for readWt in access.readWts] + readcntWts = memprobdistrib.([access.MemMat], readkeys, readstrengths) # Read content weightings + + access.readWts = [π[1].*b .+ π[2].*readcntWts .+ π[3].*f for (π, b, f) in (interface[:readmodes], backwardWts, forwardWts)] + access.readvecs = [access.MemMat' * W_r for W_r in access.readWts] + + return(readvecs) +end diff --git a/src/DNC.jl b/src/DNC.jl new file mode 100644 index 0000000..5e86d89 --- /dev/null +++ b/src/DNC.jl @@ -0,0 +1,44 @@ +#include("DNCLSTM.jl") +include("Access.jl") + +using Distributions: TruncatedNormal +using Flux + +mutable struct DNC + controller + access::Access + interfaceVec +end + +function DNC(memory_size=16, word_size=16, num_read_heads=4, num_write_heads=1, + hidden_size=64, output_size=4, input_size=4) + controller_input_size = input_size + word_size * num_read_heads + + # A truncated normal distribution with no elements further than 2σ of μ. + dist = TruncatedNormal(0, 0.01, -0.02, 0.02) + + # The interface vector of this dimension will only work when the number of + # write heads is equal to 1. I have yet to figure out how this value changes + # as the number of write heads change. + interface_vec_dimensions = word_size * num_read_heads + 3word_size + 5num_read_heads + 3 + controller_output_size = output_size + interface_vec_dimensions + controller = LSTM(controller_input_size, controller_output_size) + access = Access(memory_size, word_size, num_read_heads, num_write_heads) + + interface_vec = rand(dist, interface_vec_dimensions) + DNC(controller, access, interface_vec) +end + +function (dnc::DNC)(input) + readVecs = dnc.access.readVecs + + # flattening readVecs + readVecs = reshape(readVecs, readVecs |> size |> prod) + + # concatinating them with input to form controller input + controller_input = [input; readVecs] + + controller_output = dnc.controller(controller_input) + + +end diff --git a/src/DNCLSTM.jl b/src/DNCLSTM.jl index 0db42d1..cde94f3 100644 --- a/src/DNCLSTM.jl +++ b/src/DNCLSTM.jl @@ -9,7 +9,7 @@ end @treelike DNCLSTMCell function DNCLSTMCell(in::Integer, hidden::Integer; init=glorot_uniform) - DNCLSTMCell([init(hidden, in+2*hidden) for i in 1:4]..., + DNCLSTMCell([init(hidden, in) for i in 1:4]..., [zeros(hidden) for i in 1:6]...) end diff --git a/src/DifferentiableNeuralComputer.jl b/src/DifferentiableNeuralComputer.jl index 246daa7..180ce4b 100644 --- a/src/DifferentiableNeuralComputer.jl +++ b/src/DifferentiableNeuralComputer.jl @@ -1,14 +1,7 @@ module DifferentiableNeuralComputer -import Flux -import Zygote - include("DNCLSTM.jl") +include("Access.jl") +include("DNC.jl") -struct DNC end - -function (dnc::DNC)(x) - return x -end - -end # module +end # module DifferentiableNeuralComputer diff --git a/src/sandbox.jl b/src/sandbox.jl new file mode 100644 index 0000000..d38c71b --- /dev/null +++ b/src/sandbox.jl @@ -0,0 +1,18 @@ +# Prepping training data -> random memory pattern +#include("DNCLSTM.jl") + +in = 4 +#model = DNCLSTM(in, 4) + +num_seq = 10 +seq_len = 6 +seq_width = 4 +con = rand(1:seq_width, seq_len) + +seq = zeros(seq_len, seq_width) +idx = con |> enumerate .|> CartesianIndex +seq[idx] .= 1 +zer = zeros(seq_len, seq_width) + +final_i_data = hcat(seq', zer') +final_o_data = hcat(zer', seq')