exponentialXP/smrnn

This new architecture is based on the RNN architecture, has almost the same speed and memory usage as an RNN, achieves nearly identical loss to the transformer architecture, and reduces the time complexity from O(n^2) to O(n)!

In almost all modern language models, the transformer architecture is used and is the beating heart of the model. However, the transformer architecture is extremely inefficient on long sequences. I will introduce a new RNN-like architecture, which I call the matmul-and-scale RNN (SMRNN), which has linear time complexity as opposed to the transformer's quadratic time complexity.

Like traditional RNNs, the first step is to run the tokens through an embedding layer. Then, for each timestep, we update the hidden state using a Scale(token) layer and a Matmul(h, token) layer, initializing the hidden state at 0. H[t] is the result of applying each layer in turn:

h' = SiLU(h + MatmulLayer(concat(h, Scale(token) * H[t-1])))

Then, like any other network, we convert the hidden states to logits (like a reverse embedding) and apply softmax. A minimal code sketch follows the layer definitions below.

Scale Layer:
Linear(in_dim=emb dim, out_dim=emb dim),
LayerNorm(shape=emb dim)

Matmul Layer:
Linear(in_dim=emb dim x 2, out_dim=emb dim),
LayerNorm(shape=emb dim)
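
Below is a minimal PyTorch sketch of how I read this recurrence. The class names, tensor shapes, and the choice to seed each timestep's layer stack with the token embedding are my assumptions for illustration, not code from this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SMRNNLayer(nn.Module):
    """One matmul-and-scale layer: h' = SiLU(h + Matmul(concat(h, Scale(x_t) * H[t-1])))."""
    def __init__(self, emb_dim: int):
        super().__init__()
        # Scale layer: Linear(emb_dim -> emb_dim) followed by LayerNorm
        self.scale = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.LayerNorm(emb_dim))
        # Matmul layer: Linear(emb_dim * 2 -> emb_dim) followed by LayerNorm
        self.matmul = nn.Sequential(nn.Linear(emb_dim * 2, emb_dim), nn.LayerNorm(emb_dim))

    def forward(self, h, x_t, h_prev):
        # h: running value through the layer stack at this timestep
        # x_t: token embedding at timestep t; h_prev: H[t-1] from the previous timestep
        gated = self.scale(x_t) * h_prev                      # Scale(token) * H[t-1]
        update = self.matmul(torch.cat([h, gated], dim=-1))   # MatmulLayer(concat(h, ...))
        return F.silu(h + update)                             # h' = SiLU(h + ...)

class SMRNN(nn.Module):
    def __init__(self, vocab_size: int, emb_dim: int = 256, n_layers: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.layers = nn.ModuleList([SMRNNLayer(emb_dim) for _ in range(n_layers)])
        self.lm_head = nn.Linear(emb_dim, vocab_size)         # "reverse embedding" to logits

    def forward(self, tokens):                                # tokens: (batch, seq_len)
        B, T = tokens.shape
        emb = self.embed(tokens)
        h_prev = emb.new_zeros(B, emb.size(-1))               # hidden state initialized at 0
        logits = []
        for t in range(T):                                    # one fixed-size update per token -> O(n)
            h = emb[:, t]   # assumption: the write-up does not say what seeds h at each timestep
            for layer in self.layers:
                h = layer(h, emb[:, t], h_prev)
            h_prev = h                                        # h' after the last layer becomes H[t]
            logits.append(self.lm_head(h))
        return torch.stack(logits, dim=1)                     # softmax is applied by the loss or at sampling time
```

Because the update at timestep t only needs H[t-1] and a fixed-size hidden state, the loop does a constant amount of work per token, which is where the O(n) total cost comes from.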

WHY THIS ARCHITECTURE?
First of all, the transformer's memory requirements and computation time grow quadratically with context length, O(n^2), while RNNs have a time complexity of O(n). Secondly, this architecture's performance is in a different ballpark from vanilla RNNs and LSTMs and is even comparable to transformers, while keeping half or more of the efficiency of a vanilla RNN.
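
To make the complexity difference concrete, here is a shape-level illustration. It is not the SMRNN update itself; the recurrent line is just a stand-in to show the fixed-size state.

```python
import torch

n, d = 4096, 256                 # sequence length, embedding/state dimension
q = k = torch.randn(n, d)

# Self-attention materializes an (n, n) score matrix: memory and compute grow as O(n^2).
scores = q @ k.T                 # shape (4096, 4096) -> ~16.8 million entries

# A recurrent model keeps only a fixed-size state per step: O(n) total work over the sequence.
h = torch.zeros(d)
for t in range(n):               # n sequential updates, each touching only d numbers
    h = torch.tanh(h + q[t])     # placeholder update; the SMRNN update is described above
print(scores.shape, h.shape)     # torch.Size([4096, 4096]) torch.Size([256])
```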


Hyperparameters: Learning Rate = 3e-4, Batch Size = 32, Sequence Length = 128, Optimizer = AdamW, Emb Dim / State Dim = 256, Layers = 4, Heads (transformer only) = 4, Tokenizer = GPT-2 (tiktoken), Gradient Clipping = On (Max Norm = 1)
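
Wired into a training step, those settings might look roughly like this. It reuses the SMRNN sketch above; the vocabulary size is GPT-2's 50257, and everything except the listed hyperparameter values is a placeholder.

```python
import torch
import torch.nn.functional as F

# Hyperparameter values taken from the list above
config = dict(lr=3e-4, batch_size=32, seq_len=128, emb_dim=256, n_layers=4, max_grad_norm=1.0)

model = SMRNN(vocab_size=50257, emb_dim=config["emb_dim"], n_layers=config["n_layers"])
optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])

def train_step(tokens, targets):
    # tokens, targets: (batch_size, seq_len) int64 tensors from a GPT-2 (tiktoken) tokenized dataloader
    logits = model(tokens)                                    # (batch, seq_len, vocab)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])  # gradient clipping, max norm = 1
    optimizer.step()
    return loss.item()
```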

| Architecture | Validation Loss (FineWeb-Edu) | Non-Embedding Parameters |
| --- | --- | --- |
| SMRNN | 5.5 | ~800,000 |
| Transformer | 5.4 | ~3,200,000 |
| Vanilla RNN | 7.7 | ~600,000 |

CONCLUSION
These results show that the model has comparable performance to transformers even with its O(n) state compression and roughly 4x fewer non-embedding parameters. However, it is unknown whether this architecture will scale up as well as transformers, given the limited compute available for this experiment. Larger-scale comparisons with longer sequence lengths (and more) are needed to find out whether this architecture can practically outperform transformers in the long run.

About

Almost SOTA LLM architecture, with O(n) time complexity
