From c969d44dd412441f0a739207776ffd6036dd7022 Mon Sep 17 00:00:00 2001 From: Aidan Bleser <117548273+ieedan@users.noreply.github.com> Date: Thu, 19 Jun 2025 12:16:13 -0500 Subject: [PATCH] WIP Branching (#29) --- src/lib/backend/convex/conversations.ts | 59 ++++++++ src/lib/backend/convex/schema.ts | 1 + src/lib/components/app-sidebar.svelte | 4 + .../components/icons/branch-and-regen.svelte | 52 +++++++ src/lib/components/icons/branch.svelte | 28 ++++ src/lib/components/icons/index.ts | 4 +- src/lib/components/ui/button/button.svelte | 1 + src/routes/api/generate-message/+server.ts | 85 +++++++----- src/routes/chat/[id]/+page.svelte | 2 +- src/routes/chat/[id]/message.svelte | 129 ++++++++++++++++-- 10 files changed, 317 insertions(+), 48 deletions(-) create mode 100644 src/lib/components/icons/branch-and-regen.svelte create mode 100644 src/lib/components/icons/branch.svelte diff --git a/src/lib/backend/convex/conversations.ts b/src/lib/backend/convex/conversations.ts index f5f34f2..8f09b1f 100644 --- a/src/lib/backend/convex/conversations.ts +++ b/src/lib/backend/convex/conversations.ts @@ -147,6 +147,65 @@ export const createAndAddMessage = mutation({ }, }); +export const createBranched = mutation({ + args: { + conversation_id: v.id('conversations'), + from_message_id: v.id('messages'), + session_token: v.string(), + }, + handler: async (ctx, args): Promise> => { + const session = await ctx.runQuery(api.betterAuth.publicGetSession, { + session_token: args.session_token, + }); + + if (!session) throw new Error('Unauthorized'); + + const existingConversation = await ctx.db.get(args.conversation_id); + + console.log(existingConversation); + + if (!existingConversation) throw new Error('Conversation not found'); + if (existingConversation.user_id !== session.userId && !existingConversation.public) + throw new Error('Unauthorized'); + + const messages = await ctx.db + .query('messages') + .withIndex('by_conversation', (q) => q.eq('conversation_id', args.conversation_id)) + .collect(); + + const messageIndex = messages.findIndex((m) => m._id === args.from_message_id); + + const newMessages = messages.slice(0, messageIndex + 1); + + const newConversationId = await ctx.db.insert('conversations', { + title: existingConversation.title, + branched_from: existingConversation._id, + user_id: session.userId, + updated_at: Date.now(), + generating: false, + public: false, + cost_usd: newMessages.reduce((acc, m) => acc + (m.cost_usd ?? 0), 0), + }); + + console.log(newConversationId); + + await Promise.all( + newMessages.map((m) => { + const newMessage = { + ...m, + _id: undefined, + _creationTime: undefined, + conversation_id: newConversationId, + }; + + return ctx.db.insert('messages', newMessage); + }) + ); + + return newConversationId; + }, +}); + export const updateTitle = mutation({ args: { conversation_id: v.id('conversations'), diff --git a/src/lib/backend/convex/schema.ts b/src/lib/backend/convex/schema.ts index 6c823a2..97aef42 100644 --- a/src/lib/backend/convex/schema.ts +++ b/src/lib/backend/convex/schema.ts @@ -54,6 +54,7 @@ export default defineSchema({ generating: v.optional(v.boolean()), cost_usd: v.optional(v.number()), public: v.optional(v.boolean()), + branched_from: v.optional(v.id('conversations')), }).index('by_user', ['user_id']), messages: defineTable({ conversation_id: v.string(), diff --git a/src/lib/components/app-sidebar.svelte b/src/lib/components/app-sidebar.svelte index 13460b6..9b08808 100644 --- a/src/lib/components/app-sidebar.svelte +++ b/src/lib/components/app-sidebar.svelte @@ -19,6 +19,7 @@ import XIcon from '~icons/lucide/x'; import { Button } from './ui/button'; import { callModal } from './ui/modal/global-modal.svelte'; + import SplitIcon from '~icons/lucide/split'; let { searchModalOpen = $bindable(false) }: { searchModalOpen: boolean } = $props(); @@ -198,6 +199,9 @@ )} >

+ {#if conversation.branched_from} + + {/if} {conversation.title}

diff --git a/src/lib/components/icons/branch-and-regen.svelte b/src/lib/components/icons/branch-and-regen.svelte new file mode 100644 index 0000000..d51de23 --- /dev/null +++ b/src/lib/components/icons/branch-and-regen.svelte @@ -0,0 +1,52 @@ + + + + + + + + + + diff --git a/src/lib/components/icons/branch.svelte b/src/lib/components/icons/branch.svelte new file mode 100644 index 0000000..4783025 --- /dev/null +++ b/src/lib/components/icons/branch.svelte @@ -0,0 +1,28 @@ + + + + + + diff --git a/src/lib/components/icons/index.ts b/src/lib/components/icons/index.ts index 61e6d6a..bfed7f2 100644 --- a/src/lib/components/icons/index.ts +++ b/src/lib/components/icons/index.ts @@ -2,6 +2,8 @@ import type { HTMLAttributes } from 'svelte/elements'; import GitHub from './github.svelte'; import TypeScript from './typescript.svelte'; import Svelte from './svelte.svelte'; +import Branch from './branch.svelte'; +import BranchAndRegen from './branch-and-regen.svelte'; export interface Props extends HTMLAttributes { class?: string; @@ -9,4 +11,4 @@ export interface Props extends HTMLAttributes { height?: number; } -export { GitHub, TypeScript, Svelte }; +export { GitHub, TypeScript, Svelte, Branch, BranchAndRegen }; diff --git a/src/lib/components/ui/button/button.svelte b/src/lib/components/ui/button/button.svelte index 37c5887..838907e 100644 --- a/src/lib/components/ui/button/button.svelte +++ b/src/lib/components/ui/button/button.svelte @@ -93,6 +93,7 @@ this={href ? 'a' : 'button'} {...rest} data-slot="button" + data-loading={loading} type={href ? undefined : type} href={href && !disabled ? href : undefined} disabled={href ? undefined : disabled || loading} diff --git a/src/routes/api/generate-message/+server.ts b/src/routes/api/generate-message/+server.ts index df1c354..6a331ec 100644 --- a/src/routes/api/generate-message/+server.ts +++ b/src/routes/api/generate-message/+server.ts @@ -18,23 +18,34 @@ import { parseMessageForRules } from '$lib/utils/rules.js'; // Set to true to enable debug logging const ENABLE_LOGGING = true; -const reqBodySchema = z.object({ - message: z.string(), - model_id: z.string(), +const reqBodySchema = z + .object({ + message: z.string().optional(), + model_id: z.string(), - session_token: z.string(), - conversation_id: z.string().optional(), - web_search_enabled: z.boolean().optional(), - images: z - .array( - z.object({ - url: z.string(), - storage_id: z.string(), - fileName: z.string().optional(), - }) - ) - .optional(), -}); + session_token: z.string(), + conversation_id: z.string().optional(), + web_search_enabled: z.boolean().optional(), + images: z + .array( + z.object({ + url: z.string(), + storage_id: z.string(), + fileName: z.string().optional(), + }) + ) + .optional(), + }) + .refine( + (data) => { + if (data.conversation_id === undefined && data.message === undefined) return false; + + return true; + }, + { + message: 'You must provide a message when creating a new conversation', + } + ); export type GenerateMessageRequestBody = z.infer; @@ -688,6 +699,11 @@ export const POST: RequestHandler = async ({ request }) => { let conversationId = args.conversation_id; if (!conversationId) { + // technically zod should catch this but just in case + if (args.message === undefined) { + return error(400, 'You must provide a message when creating a new conversation'); + } + const convMessageResult = await ResultAsync.fromPromise( client.mutation(api.conversations.createAndAddMessage, { content: args.message, @@ -722,25 +738,28 @@ export const POST: RequestHandler = async ({ request }) => { ); } else { log('Using existing conversation', startTime); - const userMessageResult = await ResultAsync.fromPromise( - client.mutation(api.messages.create, { - conversation_id: conversationId as Id<'conversations'>, - content: args.message, - session_token: args.session_token, - model_id: args.model_id, - role: 'user', - images: args.images, - web_search_enabled: args.web_search_enabled, - }), - (e) => `Failed to create user message: ${e}` - ); - if (userMessageResult.isErr()) { - log(`User message creation failed: ${userMessageResult.error}`, startTime); - return error(500, 'Failed to create user message'); + if (args.message) { + const userMessageResult = await ResultAsync.fromPromise( + client.mutation(api.messages.create, { + conversation_id: conversationId as Id<'conversations'>, + content: args.message, + session_token: args.session_token, + model_id: args.model_id, + role: 'user', + images: args.images, + web_search_enabled: args.web_search_enabled, + }), + (e) => `Failed to create user message: ${e}` + ); + + if (userMessageResult.isErr()) { + log(`User message creation failed: ${userMessageResult.error}`, startTime); + return error(500, 'Failed to create user message'); + } + + log('User message created', startTime); } - - log('User message created', startTime); } // Set generating status to true before starting background generation diff --git a/src/routes/chat/[id]/+page.svelte b/src/routes/chat/[id]/+page.svelte index d06fe1b..5c8da57 100644 --- a/src/routes/chat/[id]/+page.svelte +++ b/src/routes/chat/[id]/+page.svelte @@ -61,7 +61,7 @@ {conversation.data?.title} | thom.chat -
+
{#if !conversation.data && !conversation.isLoading}
diff --git a/src/routes/chat/[id]/message.svelte b/src/routes/chat/[id]/message.svelte index a863fa9..07ab0ad 100644 --- a/src/routes/chat/[id]/message.svelte +++ b/src/routes/chat/[id]/message.svelte @@ -1,7 +1,7 @@ {#if message.role !== 'system' && !(message.role === 'assistant' && message.content.length === 0 && !message.error)} @@ -101,26 +163,67 @@
+ + {#snippet trigger(tooltip)} + + {/snippet} + Branch off this message + + {#if message.role === 'user'} + + {#snippet trigger(tooltip)} + + {/snippet} + Branch and regenerate + + {/if} {#if message.content.length > 0} - - {/if} - {#if message.model_id !== undefined} - {message.model_id} - {/if} - {#if message.web_search_enabled} - Web search enabled + + {#snippet trigger(tooltip)} + + {/snippet} + Copy + {/if} + {#if message.role === 'assistant'} + {#if message.model_id !== undefined} + {message.model_id} + {/if} + {#if message.web_search_enabled} + Web search enabled + {/if} - {#if message.cost_usd !== undefined} - - ${message.cost_usd.toFixed(6)} - + {#if message.cost_usd !== undefined} + + ${message.cost_usd.toFixed(6)} + + {/if} {/if}