Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,16 @@
"start": "next start"
},
"dependencies": {
"tiktoken": "^1.0.11",
"@radix-ui/react-checkbox": "^1.0.3",
"@radix-ui/react-popover": "^1.0.5",
"@radix-ui/react-select": "^1.2.0",
"@tanstack/react-query": "^4.20.2",
"@tiptap/core": "^2.3.0",
"@tiptap/extension-document": "^2.3.0",
"@tiptap/extension-paragraph": "^2.3.0",
"@tiptap/extension-text": "^2.3.0",
"@tiptap/pm": "^2.3.0",
"@tiptap/react": "^2.3.0",
"@trpc/client": "^10.9.0",
"@trpc/next": "^10.9.0",
"@trpc/react-query": "^10.9.0",
Expand All @@ -31,6 +36,7 @@
"superjson": "1.9.1",
"tailwind-merge": "^1.10.0",
"tailwindcss-animate": "^1.0.5",
"tiktoken": "^1.0.11",
"zod": "^3.20.6"
},
"devDependencies": {
Expand Down
227 changes: 227 additions & 0 deletions src/components/RichEditor.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
import Document from "@tiptap/extension-document";
import Paragraph from "@tiptap/extension-paragraph";
import Text from "@tiptap/extension-text";
import {
EditorContent,
Extension,
useEditor,
type JSONContent,
} from "@tiptap/react";
import { useEffect, useMemo } from "react";
import { Plugin, PluginKey } from "@tiptap/pm/state";

import { type Node } from "@tiptap/pm/model";
import { Decoration, DecorationSet } from "@tiptap/pm/view";
import Graphemer from "graphemer";
import { cn } from "~/utils/cn";
import { type UserModelChoice, getUserSelectedEncoder } from "~/utils/model";

/**
 * Background utility classes cycled through when highlighting tokens: the
 * n-th decoration gets `COLORS[n % COLORS.length]`, so adjacent tokens are
 * painted with different colours.
 *
 * NOTE(review): these look like Tailwind class names; presumably they must
 * stay as complete string literals so the CSS build can detect them — confirm
 * before generating them programmatically.
 */
const COLORS = [
"bg-sky-200",
"bg-amber-200",
"bg-blue-200",
"bg-green-200",
"bg-orange-200",
"bg-cyan-200",
"bg-gray-200",
"bg-purple-200",
"bg-indigo-200",
"bg-lime-200",
"bg-rose-200",
"bg-violet-200",
"bg-yellow-200",
"bg-emerald-200",
"bg-zinc-200",
"bg-red-200",
"bg-fuchsia-200",
"bg-pink-200",
"bg-teal-200",
];

/**
 * Turns plain text into a list of paragraph nodes, one per `\n`-separated
 * line.
 *
 * @param content Raw text; `null`/`undefined` yields an empty list.
 * @returns One `paragraph` node per line. An empty line becomes a paragraph
 *   with an empty `content` array (no empty text node is emitted).
 */
function convertTextToJSONContent(
  content: string | null | undefined
): JSONContent[] {
  if (content === null || content === undefined) {
    return [];
  }

  const paragraphs: JSONContent[] = [];
  for (const line of content.split("\n")) {
    const children: JSONContent[] =
      line.length > 0 ? [{ type: "text", text: line }] : [];
    paragraphs.push({ type: "paragraph", content: children });
  }
  return paragraphs;
}

/**
 * Wraps a text value in a root `doc` node.
 *
 * @param content Text to convert; anything that is not a string produces a
 *   document with no children.
 * @returns A `doc` node whose children are the paragraphs built from `content`.
 */
function convertMessageToJSONContent(
  content: string | null | undefined
): JSONContent {
  const paragraphs =
    typeof content === "string" ? convertTextToJSONContent(content) : [];

  return { type: "doc", content: paragraphs };
}

// Key under which the token-highlighting ProseMirror plugin is registered.
const key = new PluginKey("tiktokenizer");

/**
 * Shifts `needle` by the number of entries in `haystack` that are less than
 * or equal to it.
 *
 * @param haystack Ascending list of offsets.
 * @param needle Offset to translate.
 * @returns `needle` plus the count of `haystack` entries `<= needle`
 *   (so an empty `haystack` returns `needle` unchanged).
 */
function binarySearch(haystack: number[], needle: number) {
  let lo = 0;

  // Closed-interval upper-bound search: on exit `lo` is the index of the
  // first entry greater than `needle`, i.e. the count of entries <= needle.
  for (let hi = haystack.length - 1; lo <= hi; ) {
    const mid = (lo + hi) >> 1;
    const goRight = haystack[mid]! <= needle;
    if (goRight) {
      lo = mid + 1;
    } else {
      hi = mid - 1;
    }
  }

  return needle + lo;
}

/**
 * Tiptap extension that gives every tokenizer token in the document its own
 * background colour.
 *
 * The document's top-level nodes are joined with `\n` into one string, the
 * string is encoded with the encoder selected by `options.model`, and each
 * token (or run of tokens, when a single token does not decode to whole
 * graphemes) becomes one inline decoration.
 */
export const TokenHighlighter = Extension.create<{ model: UserModelChoice }>({
  name: "colorHighlighter",

  addProseMirrorPlugins() {
    const encoder = getUserSelectedEncoder(this.options.model);

    // tiktoken encoders hold memory that is only released by an explicit
    // `free()`. The editor (and with it this plugin) is recreated whenever the
    // model changes, so release the encoder together with its editor.
    let encoderFreed = false;
    this.editor.on("destroy", () => {
      if (encoderFreed) return;
      encoderFreed = true;
      encoder.free();
    });

    const graphemer = new Graphemer();
    const textDecoder = new TextDecoder();

    /** Builds the full decoration set for `doc` from scratch. */
    function getTokenDecorations(doc: Node): DecorationSet {
      let text = "";
      // bounds[i] is the offset in `text` at which top-level node i was
      // appended; `binarySearch` uses it to turn text offsets into document
      // positions.
      const bounds: number[] = [];

      doc.descendants((node, position) => {
        bounds.push(text.length);
        text += (position > 0 ? "\n" : "") + node.textContent;

        // Do not descend: `textContent` already covers the node's children.
        return false;
      });

      const encoding = encoder.encode(text, "all");
      const inputGraphemes = graphemer.splitGraphemes(text);

      let textAcc = 0; // offset in `text` already covered by decorations
      let graphemeCursor = 0; // index into `inputGraphemes` already covered
      let byteAcc: number[] = []; // bytes of tokens not yet emitted

      const decorations: Decoration[] = [];
      for (const token of encoding) {
        byteAcc.push(...encoder.decode_single_token_bytes(token));

        const segmentText = textDecoder.decode(new Uint8Array(byteAcc));
        const graphemes = graphemer.splitGraphemes(segmentText);

        // Only emit once the accumulated bytes decode to exactly the next
        // graphemes of the input; otherwise keep accumulating tokens.
        const matchesInput = graphemes.every(
          (grapheme, offset) =>
            inputGraphemes[graphemeCursor + offset] === grapheme
        );
        if (!matchesInput) continue;

        decorations.push(
          Decoration.inline(
            binarySearch(bounds, textAcc),
            binarySearch(bounds, textAcc + segmentText.length),
            { class: cn(COLORS[decorations.length % COLORS.length]) }
          )
        );

        textAcc += segmentText.length;
        graphemeCursor += graphemes.length;
        byteAcc = [];
      }

      return DecorationSet.create(doc, decorations);
    }

    return [
      new Plugin({
        key,
        state: {
          init: (_, { doc }) => getTokenDecorations(doc),
          // Re-tokenize only when the document itself changed.
          apply: (transaction, oldState) =>
            transaction.docChanged
              ? getTokenDecorations(transaction.doc)
              : oldState,
        },
        props: {
          decorations(state) {
            return this.getState(state);
          },
        },
      }),
    ];
  },
});

/**
 * Flattens an editor document into plain text: one line per top-level node,
 * joined with `\n`.
 */
function serializeDocToText(json: JSONContent): string {
  const lines: string[] = [];

  // assume root is type="doc"
  for (const child of json.content ?? []) {
    switch (child.type) {
      case "paragraph": {
        const text = (child.content ?? [])
          .map((i) => i.text)
          .filter((x): x is string => x != null)
          .join("");

        lines.push(text);
        break;
      }
      case "text": {
        if (child.text) lines.push(child.text);
        break;
      }
    }
  }

  return lines.join("\n");
}

/**
 * Controlled plain-text editor whose content is highlighted token by token.
 *
 * @param props.value Text to display.
 * @param props.onChange Called with the new text after every editor update.
 * @param props.model Model/encoding used for highlighting; the editor is
 *   recreated when it changes.
 */
export function RichEditor(props: {
  value: string;
  onChange: (value: string) => void;
  model: UserModelChoice;
}) {
  const { value, onChange, model } = props;

  const content = useMemo(() => convertMessageToJSONContent(value), [value]);

  const editor = useEditor(
    {
      extensions: [
        Document,
        Paragraph,
        Text,
        TokenHighlighter.configure({ model }),
      ],
      editorProps: {
        attributes: { class: cn("outline-none p-3 border rounded-md") },
      },

      content,
    },
    [model]
  );

  // Subscribe here rather than through the `onUpdate` option: the options are
  // only re-read when `model` changes, so a handler passed there would keep
  // calling the `onChange` from the render that created the editor.
  useEffect(() => {
    if (!editor) return;

    const handleUpdate = () => onChange(serializeDocToText(editor.getJSON()));
    editor.on("update", handleUpdate);

    return () => {
      editor.off("update", handleUpdate);
    };
  }, [editor, onChange]);

  // Push external `value` changes into the editor.
  useEffect(() => {
    if (!editor) return;

    // The editor's own edits come back through `value`; when it already shows
    // this text, replacing the document would only re-run tokenization and
    // disturb the caret.
    if (serializeDocToText(editor.getJSON()) === value) return;

    const { from, to } = editor.state.selection;
    editor.commands.setContent(content, false, { preserveWhitespace: "full" });
    editor.commands.setTextSelection({ from, to });
  }, [editor, content, value]);

  return (
    <div className="grid w-full">
      <EditorContent editor={editor} />
    </div>
  );
}
38 changes: 5 additions & 33 deletions src/pages/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,41 +16,11 @@ import {
} from "~/sections/EncoderSelect";
import { TokenViewer } from "~/sections/TokenViewer";
import { TextArea } from "~/components/Input";
import {
encoding_for_model,
get_encoding,
type TiktokenModel,
type TiktokenEncoding,
} from "tiktoken";
import { type TiktokenModel, type TiktokenEncoding } from "tiktoken";
import { getSegments } from "~/utils/segments";
import { useRouter } from "next/router";

/**
 * Creates the tiktoken encoder for the user's choice.
 *
 * @param params Either a model name or an encoding name.
 * @returns The encoder for the model (with the chat special tokens added for
 *   the chat models listed below) or for the named encoding.
 * @throws Error("Invalid params") when neither key is present.
 */
function getUserSelectedEncoder(
  params: { model: TiktokenModel } | { encoder: TiktokenEncoding }
) {
  if (!("model" in params)) {
    if ("encoder" in params) {
      return get_encoding(params.encoder);
    }

    throw new Error("Invalid params");
  }

  // Models whose encoder is extended with the chat delimiter tokens.
  const chatModels: TiktokenModel[] = [
    "gpt-4",
    "gpt-4-32k",
    "gpt-3.5-turbo",
    "gpt-4-1106-preview",
  ];

  if (chatModels.includes(params.model)) {
    return encoding_for_model(params.model, {
      "<|im_start|>": 100264,
      "<|im_end|>": 100265,
      "<|im_sep|>": 100266,
    });
  }

  return encoding_for_model(params.model);
}
import { RichEditor } from "~/components/RichEditor";
import { getUserSelectedEncoder } from "../utils/model";

function isChatModel(
params: { model: TiktokenModel } | { encoder: TiktokenEncoding }
Expand Down Expand Up @@ -130,6 +100,8 @@ const Home: NextPage<InferGetServerSidePropsType<typeof getServerSideProps>> = (
/>
</div>

<RichEditor model={params} value={inputText} onChange={setInputText} />

<div className="grid gap-4 md:grid-cols-2">
<section className="flex flex-col gap-4">
{isChatModel(params) && (
Expand Down
35 changes: 35 additions & 0 deletions src/utils/model.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import {
encoding_for_model,
get_encoding,
type TiktokenModel,
type TiktokenEncoding,
} from "tiktoken";

/**
 * What the user picked to tokenize with: either a model name (`model`) or an
 * encoding name (`encoder`). Exactly one of the two keys is expected.
 */
export type UserModelChoice =
| { model: TiktokenModel }
| { encoder: TiktokenEncoding };

/**
 * Creates the tiktoken encoder for the user's choice.
 *
 * @param params Either `{ model }` or `{ encoder }`.
 * @returns For a model: its encoder, extended with the `<|im_start|>`,
 *   `<|im_end|>` and `<|im_sep|>` special tokens for the chat models listed
 *   in the switch. For an encoding name: that encoding.
 * @throws Error("Invalid params") when neither key is present.
 */
export function getUserSelectedEncoder(params: UserModelChoice) {
  if ("model" in params) {
    const { model } = params;

    switch (model) {
      case "gpt-4":
      case "gpt-4-32k":
      case "gpt-3.5-turbo":
      case "gpt-4-1106-preview":
        return encoding_for_model(model, {
          "<|im_start|>": 100264,
          "<|im_end|>": 100265,
          "<|im_sep|>": 100266,
        });
      default:
        return encoding_for_model(model);
    }
  }

  if ("encoder" in params) {
    return get_encoding(params.encoder);
  }

  throw new Error("Invalid params");
}
Loading