From 0747012d83db761997da46949904bd8093bacd05 Mon Sep 17 00:00:00 2001 From: Jonathan Sum <777JonathanSum@gmail.com> Date: Mon, 30 Jan 2023 07:28:32 -0800 Subject: [PATCH] update update --- transformer_unify.ipynb | 88 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 2 deletions(-) diff --git a/transformer_unify.ipynb b/transformer_unify.ipynb index ea63ccd..fd4894d 100644 --- a/transformer_unify.ipynb +++ b/transformer_unify.ipynb @@ -158,11 +158,90 @@ "source": [ "(l2 == y).all() # cool" ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 512, 512])\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q1a = F.gelu(layer1(q))\n", + "print(q1a.shape)\n", + "torch.allclose(q1a,att)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 512, 128])\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q2a = layer2(q1a)\n", + "y = att @ v\n", + "print(q2a.shape)\n", + "torch.allclose(q2a,y)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(True)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(q2a == y).all()" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "base", "language": "python", "name": "python3" }, @@ -176,7 +255,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.9.12" + }, + "vscode": { + "interpreter": { + "hash": "277b78f2730b0903abf2859ae01821a113fbf907e4071225ef1d3ec9542f1da7" + } } }, "nbformat": 4,