Skip to content
76 changes: 58 additions & 18 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ git-tree-sha1 = "82dab828020b872fa9efd3abec1152b075bc7cbf"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "1.0.0"

[[Arpack]]
deps = ["BinaryProvider", "Libdl", "LinearAlgebra"]
git-tree-sha1 = "07a2c077bdd4b6d23a40342a8a108e2ee5e58ab6"
uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97"
version = "0.3.1"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

Expand All @@ -35,15 +41,15 @@ version = "0.5.6"

[[CSTParser]]
deps = ["Tokenize"]
git-tree-sha1 = "0ff80f68f55fcde2ed98d7b24d7abaf20727f3f8"
git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b"
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "0.6.1"
version = "0.6.2"

[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b"
deps = ["BinaryProvider", "Libdl", "TranscodingStreams"]
git-tree-sha1 = "05916673a2627dd91b4969ff8ba6941bc85a960e"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.2"
version = "0.6.0"

[[ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
Expand All @@ -52,10 +58,10 @@ uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.8.0"

[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"]
git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport"]
git-tree-sha1 = "c9c1845d6bf22e34738bee65c357a69f416ed5d1"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.9.5"
version = "0.9.6"

[[CommonSubexpressions]]
deps = ["Test"]
Expand All @@ -82,9 +88,9 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0"

[[DataAPI]]
git-tree-sha1 = "891a09f4f90361a28d0391c104a65c0202e22624"
git-tree-sha1 = "8903f0219d3472543fc4b2f5ebaf675a07f817c0"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.0.0"
version = "1.0.1"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
Expand Down Expand Up @@ -116,17 +122,23 @@ version = "0.0.10"
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[Distributions]]
deps = ["LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"]
git-tree-sha1 = "baaf9e165ba8a2d11fb4fb3511782ee070ee3694"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.21.1"

[[FFTW]]
deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"]
git-tree-sha1 = "e1a479d3c972f20c9a70563eec740bbfc786f515"
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
version = "0.3.0"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays", "Test"]
git-tree-sha1 = "9ab8f76758cbabba8d7f103c51dce7f73fcf8e92"
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "8fba6ddaf66b45dec830233cea0aae43eb1261ad"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.6.3"
version = "0.6.4"

[[FixedPointNumbers]]
git-tree-sha1 = "d14a6fa5890ea3a7e5dcab6811114f132fec2b4b"
Expand All @@ -135,7 +147,7 @@ version = "0.6.1"

[[Flux]]
deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DelimitedFiles", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "Statistics", "StatsBase", "Test", "ZipFile", "Zygote"]
git-tree-sha1 = "01debcc73af2428239343f5c3c0f8562529e91a0"
git-tree-sha1 = "7bd127360ae2e71abc3033e25e5be161520ec66f"
repo-rev = "zygote"
repo-url = "https://github.com/FluxML/Flux.jl.git"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Expand Down Expand Up @@ -167,9 +179,9 @@ version = "0.21.0"

[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175"
git-tree-sha1 = "30d94657a422d09cb97b6f86f04f750fa9c50df8"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.7.0"
version = "0.7.2"

[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
Expand Down Expand Up @@ -227,6 +239,12 @@ git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.1.0"

[[PDMats]]
deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"]
git-tree-sha1 = "9d6a9b3e19634612fb1edcafc4b1d75242b24bde"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.9.9"

[[Parsers]]
deps = ["Dates", "Test"]
git-tree-sha1 = "db2b35dedab3c0e46dc15996d170af07a5ab91c9"
Expand All @@ -245,6 +263,12 @@ uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
deps = ["Printf"]
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"

[[QuadGK]]
deps = ["DataStructures", "LinearAlgebra", "Test"]
git-tree-sha1 = "3ce467a8e76c6030d4c3786e7d3a73442017cdc0"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
version = "2.0.3"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Expand All @@ -265,6 +289,12 @@ git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "0.5.2"

[[Rmath]]
deps = ["BinaryProvider", "Libdl", "Random", "Statistics", "Test"]
git-tree-sha1 = "9a6c758cdf73036c3239b0afbea790def1dabff9"
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
version = "0.5.0"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

Expand Down Expand Up @@ -310,6 +340,16 @@ git-tree-sha1 = "c53e809e63fe5cf5de13632090bc3520649c9950"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.32.0"

[[StatsFuns]]
deps = ["Rmath", "SpecialFunctions", "Test"]
git-tree-sha1 = "b3a4e86aa13c732b8a8c0ba0c3d3264f55e6bb3e"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.8.0"

[[SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand All @@ -321,9 +361,9 @@ uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0"

[[Tokenize]]
git-tree-sha1 = "c8a8b00ae44a94950814ff77850470711a360225"
git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf"
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
version = "0.5.5"
version = "0.5.6"

[[TranscodingStreams]]
deps = ["Random", "Test"]
Expand Down
5 changes: 2 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ authors = ["Julian P Samaroo <jpsamaroo@jpsamaroo.me>"]
version = "0.1.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

Expand All @@ -14,6 +16,3 @@ julia = "≥ 1.0.0"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
130 changes: 130 additions & 0 deletions src/Access.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
using Flux: softmax

cosine_sim(u, v) = (u'v)/(norm(u)*norm(v))

# content-based addressing
"""
memprobdistrib(MemMat, key, keystrength)

Defines a normalized probability distribution over the memory locations.
"""
function memprobdistrib(M, k, β)
out = [cosine_sim(k, M[i, :]) for i in 1:size(M, 1)] .* β
out = softmax(out)
end

oneplus(x::AbstractVecOrMat) = log.(exp.(x) .+ 1)

mutable struct Access
MemMat # R^{N*W}
LinkMat # R^{N*N}
readWts # RH element Array of [0, 1]^N
wrtWt # WH element Array of (0, 1)^{N}
usageVec # (0, 1)^{N}
precedenceWt # (0, 1)^{N}
readVecs
numWriteHeads
numReadHeads
wordSize
memorySize
end

function Access(memorySize=128, wordSize=20, numReadHeads=1, numWriteHeads=1)
MemMat = zeros(Float32, memorySize, wordSize)
LinkMat = zeros(Float32, memorySize, memorySize)

readWts = rand(Float32, memorySize, numReadHeads)
fill!(readWts, 1f-6)

wrtWt = rand(Float32, memorySize, numWriteHeads)
fill!(wrtWt, 1f-6)

usageVec = zeros(Float32, memorySize, numWriteHeads)
fill!(usageVec, 1f-6)
precedenceWt = zeros(Float32, memorySize)

readVecs = zeros(Float32, wordSize, numReadHeads)
fill!(readVecs, 1f-6)

Access(MemMat, LinkMat, readWts, wrtWt, usageVec, precedenceWt, readVecs,
numWriteHeads, numReadHeads, wordSize, memorySize)
end

"""
interfacedisect(interfaceVec, writeHead, wordSize, readHeads)

Disects the interface vector obtained as the output to obtain various memory
access controls for the DNC.
"""
function interfacedisect(interfaceVec, writeHeads, wordSize, readHeads)
demarcations = cumsum([0, # Starting Index
readHeads*wordSize, # read keys
readHeads, # read strengths
writeHeads*wordSize, # write keys
writeHeads, # write strengths
writeHeads*wordSize, # erase vectors
writeHeads*wordSize, # write vectors
readHeads, # free gates
writeHeads, # allocation gates
writeHeads, # write gates
readHeads * (1 + 2writeHeads) # read modes
])


readkeys = interfaceVec[demarcations[1]+1:demarcations[2]]
readstrengths = oneplus(interfaceVec[demarcations[2]+1:demarcations[3]])
writekeys = interfaceVec[demarcations[3]+1:demarcations[4]]
writestrengths = oneplus(interfaceVec[demarcations[4]+1:demarcations[5]])
eraseVec = σ.(interfaceVec[demarcations[5]+1:demarcations[6]])
writeVec = interfaceVec[demarcations[6]+1:demarcations[7]]
freeGts = σ.(interfaceVec[demarcations[7]+1:demarcations[8]])
allocGt = σ.(interfaceVec[demarcations[8]+1:demarcations[9]])
writeGt = σ.(interfaceVec[demarcations[9]+1:demarcations[10]])
readmodes = softmax(interfaceVec[demarcations[10]+1:demarcations[11]])

readkeys = reshape(readkeys, wordSize, readHeads) # W * RH
writekeys = reshape(writekeys, wordSize, writeHeads) # W * WH
eraseVec = reshape(eraseVec, wordSize, writeHeads) # W * WH
writeVec = reshape(writeVec, wordSize, writeHeads) # W * WH
readmodes = reshape(readmodes, (1+2writeHeads), readHeads) # (WH for backward + WH for forward + 1 for content lookup) * RH

return (readkeys=readkeys, readstrengths=readstrengths, writekeys=writekeys,
writestrengths=writestrengths, eraseVec=eraseVec, writeVec=writeVec,
freeGts=freeGts, allocGts=allocGt, writeGts=writeGt, readmodes=readmodes)
end

function (access::Access)(interfaceVec)
# dynamic memory allocation

interface = interfacedisect(interfaceVec, access.numWriteHeads, access.wordSize, access.numReadHeads)

memRetVec = prod(1 .- interface[:freeGts]' .* access.readWts, dims=2) # Memory Retention Vector = [0, 1]^{N}
access.usageVec = (access.usageVec .+ access.wrtWt .- access.usageVec .* access.wrtWt) .* memRetVec
freelist = sortperm(access.usageVec) # Z^{N}
allocWt = zeros(access.usageVec)
@. allocWt[freelist] = (1 - access.usageVec[freelist]) * cumprod([1; access.usageVec[freelist]][1:end-1]) # (0, 1)^{N}

# writing
wrtcntWt = memprobdistrib(access.MemMat, interface[:writekeys], interface[:writestrengths]) # Write content weighting = (0, 1)^{N}
access.wrtWt .= interface[:writeGts] * (interface[:allocGts] * allocWt + (1 - interface[:allocGts])*wrtcntWt)
@. access.MemMat *= (ones(access.MemMat) - access.wrtWt*interface[:eraseVec]') # First we erase...
@. access.MemMat += access.wrtWt*interface[:writeVec]' # Then we write.

# temporal linkage
eye = Matrix{Float32}(I, size(access.LinkMat)...)
prevlinkscale = @. 1 - access.wrtWt - access.wrtWt'
newlink = @. access.wrtWt * access.precedenceWt'
@. access.LinkMat = (1 - eye) * (prevlinkscale * access.LinkMat + newlink)

access.precedenceWt = (1 - sum(access.wrtWt)) .* access.precedenceWt .+ access.wrtWt

# reading
forwardWts = [access.LinkMat * readWt for readWt in access.readWts]
backwardWts = [access.LinkMat' * readWt for readWt in access.readWts]
readcntWts = memprobdistrib.([access.MemMat], readkeys, readstrengths) # Read content weightings

access.readWts = [π[1].*b .+ π[2].*readcntWts .+ π[3].*f for (π, b, f) in (interface[:readmodes], backwardWts, forwardWts)]
access.readvecs = [access.MemMat' * W_r for W_r in access.readWts]

return(readvecs)
end
44 changes: 44 additions & 0 deletions src/DNC.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include("DNCLSTM.jl")
include("Access.jl")

using Distributions: TruncatedNormal
using Flux

mutable struct DNC
controller
access::Access
interfaceVec
end

function DNC(memory_size=16, word_size=16, num_read_heads=4, num_write_heads=1,
hidden_size=64, output_size=4, input_size=4)
controller_input_size = input_size + word_size * num_read_heads

# A truncated normal distribution with no elements further than 2σ of μ.
dist = TruncatedNormal(0, 0.01, -0.02, 0.02)

# The interface vector of this dimension will only work when the number of
# write heads is equal to 1. I have yet to figure out how this value changes
# as the number of write heads change.
interface_vec_dimensions = word_size * num_read_heads + 3word_size + 5num_read_heads + 3
controller_output_size = output_size + interface_vec_dimensions
controller = LSTM(controller_input_size, controller_output_size)
access = Access(memory_size, word_size, num_read_heads, num_write_heads)

interface_vec = rand(dist, interface_vec_dimensions)
DNC(controller, access, interface_vec)
end

function (dnc::DNC)(input)
readVecs = dnc.access.readVecs

# flattening readVecs
readVecs = reshape(readVecs, readVecs |> size |> prod)

# concatinating them with input to form controller input
controller_input = [input; readVecs]

controller_output = dnc.controller(controller_input)


end
2 changes: 1 addition & 1 deletion src/DNCLSTM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ end
@treelike DNCLSTMCell

function DNCLSTMCell(in::Integer, hidden::Integer; init=glorot_uniform)
DNCLSTMCell([init(hidden, in+2*hidden) for i in 1:4]...,
DNCLSTMCell([init(hidden, in) for i in 1:4]...,
[zeros(hidden) for i in 1:6]...)
end

Expand Down
13 changes: 3 additions & 10 deletions src/DifferentiableNeuralComputer.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
module DifferentiableNeuralComputer

import Flux
import Zygote

include("DNCLSTM.jl")
include("Access.jl")
include("DNC.jl")

struct DNC end

function (dnc::DNC)(x)
return x
end

end # module
end # module DifferentiableNeuralComputer
Loading