WIP Branching (#29)

This commit is contained in:
Aidan Bleser 2025-06-19 12:16:13 -05:00 committed by GitHub
parent 13b3449d82
commit c969d44dd4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 317 additions and 48 deletions

View file

@ -147,6 +147,65 @@ export const createAndAddMessage = mutation({
}, },
}); });
export const createBranched = mutation({
args: {
conversation_id: v.id('conversations'),
from_message_id: v.id('messages'),
session_token: v.string(),
},
handler: async (ctx, args): Promise<Id<'conversations'>> => {
const session = await ctx.runQuery(api.betterAuth.publicGetSession, {
session_token: args.session_token,
});
if (!session) throw new Error('Unauthorized');
const existingConversation = await ctx.db.get(args.conversation_id);
console.log(existingConversation);
if (!existingConversation) throw new Error('Conversation not found');
if (existingConversation.user_id !== session.userId && !existingConversation.public)
throw new Error('Unauthorized');
const messages = await ctx.db
.query('messages')
.withIndex('by_conversation', (q) => q.eq('conversation_id', args.conversation_id))
.collect();
const messageIndex = messages.findIndex((m) => m._id === args.from_message_id);
const newMessages = messages.slice(0, messageIndex + 1);
const newConversationId = await ctx.db.insert('conversations', {
title: existingConversation.title,
branched_from: existingConversation._id,
user_id: session.userId,
updated_at: Date.now(),
generating: false,
public: false,
cost_usd: newMessages.reduce((acc, m) => acc + (m.cost_usd ?? 0), 0),
});
console.log(newConversationId);
await Promise.all(
newMessages.map((m) => {
const newMessage = {
...m,
_id: undefined,
_creationTime: undefined,
conversation_id: newConversationId,
};
return ctx.db.insert('messages', newMessage);
})
);
return newConversationId;
},
});
export const updateTitle = mutation({ export const updateTitle = mutation({
args: { args: {
conversation_id: v.id('conversations'), conversation_id: v.id('conversations'),

View file

@ -54,6 +54,7 @@ export default defineSchema({
generating: v.optional(v.boolean()), generating: v.optional(v.boolean()),
cost_usd: v.optional(v.number()), cost_usd: v.optional(v.number()),
public: v.optional(v.boolean()), public: v.optional(v.boolean()),
branched_from: v.optional(v.id('conversations')),
}).index('by_user', ['user_id']), }).index('by_user', ['user_id']),
messages: defineTable({ messages: defineTable({
conversation_id: v.string(), conversation_id: v.string(),

View file

@ -19,6 +19,7 @@
import XIcon from '~icons/lucide/x'; import XIcon from '~icons/lucide/x';
import { Button } from './ui/button'; import { Button } from './ui/button';
import { callModal } from './ui/modal/global-modal.svelte'; import { callModal } from './ui/modal/global-modal.svelte';
import SplitIcon from '~icons/lucide/split';
let { searchModalOpen = $bindable(false) }: { searchModalOpen: boolean } = $props(); let { searchModalOpen = $bindable(false) }: { searchModalOpen: boolean } = $props();
@ -198,6 +199,9 @@
)} )}
> >
<p class="truncate rounded-lg py-2 pr-4 pl-3 whitespace-nowrap"> <p class="truncate rounded-lg py-2 pr-4 pl-3 whitespace-nowrap">
{#if conversation.branched_from}
<SplitIcon class="text-muted-foreground/50 mr-1 inline size-4" />
{/if}
<span>{conversation.title}</span> <span>{conversation.title}</span>
</p> </p>
<div class="pr-2"> <div class="pr-2">

View file

@ -0,0 +1,52 @@
<script lang="ts">
import type { Props } from '.';
import { cn } from '$lib/utils/utils';
let { class: className, ...rest }: Props = $props();
</script>
<svg
class={cn('size-4 fill-current', className)}
viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg"
{...rest}
>
<path
d="M2.96132 21.09V4.55M2.96132 4.55V2.64M2.96132 4.55V5.78C2.96132 7.39 4.17132 8.89 6.16132 9.72L10.7413 11.64C12.7213 12.47 13.9413 13.96 13.9413 15.58V19.42"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
/>
<path
d="M17.4713 17.59L14.0613 21.25L10.4013 17.84"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
/>
<path
d="M20.8382 6.76662C20.8382 5.5946 20.3727 4.47058 19.5439 3.64183C18.7152 2.81309 17.5911 2.3475 16.4191 2.3475C15.1837 2.35215 13.9979 2.83421 13.1097 3.69288L12 4.80257"
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
/>
<path
d="M12 2.3475V4.80257H14.4551"
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
/>
<path
d="M12 6.76662C12 7.93865 12.4656 9.06267 13.2943 9.89141C14.1231 10.7202 15.2471 11.1857 16.4191 11.1857C17.6545 11.1811 18.8403 10.699 19.7285 9.84037L20.8382 8.73068"
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
/>
<path
d="M18.3832 8.73067H20.8382V11.1857"
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
/>
</svg>

View file

@ -0,0 +1,28 @@
<script lang="ts">
import type { Props } from '.';
import { cn } from '$lib/utils/utils';
let { class: className, ...rest }: Props = $props();
</script>
<svg
class={cn('size-4 fill-current', className)}
viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg"
{...rest}
>
<path
d="M6.02 21.09V4.55M6.02 4.55V2.64M6.02 4.55V5.78C6.02 7.39 7.23 8.89 9.22 9.72L13.8 11.64C15.78 12.47 17 13.96 17 15.58V19.42"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
/>
<path
d="M20.53 17.59L17.12 21.25L13.46 17.84"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
/>
</svg>

View file

@ -2,6 +2,8 @@ import type { HTMLAttributes } from 'svelte/elements';
import GitHub from './github.svelte'; import GitHub from './github.svelte';
import TypeScript from './typescript.svelte'; import TypeScript from './typescript.svelte';
import Svelte from './svelte.svelte'; import Svelte from './svelte.svelte';
import Branch from './branch.svelte';
import BranchAndRegen from './branch-and-regen.svelte';
export interface Props extends HTMLAttributes<SVGElement> { export interface Props extends HTMLAttributes<SVGElement> {
class?: string; class?: string;
@ -9,4 +11,4 @@ export interface Props extends HTMLAttributes<SVGElement> {
height?: number; height?: number;
} }
export { GitHub, TypeScript, Svelte }; export { GitHub, TypeScript, Svelte, Branch, BranchAndRegen };

View file

@ -93,6 +93,7 @@
this={href ? 'a' : 'button'} this={href ? 'a' : 'button'}
{...rest} {...rest}
data-slot="button" data-slot="button"
data-loading={loading}
type={href ? undefined : type} type={href ? undefined : type}
href={href && !disabled ? href : undefined} href={href && !disabled ? href : undefined}
disabled={href ? undefined : disabled || loading} disabled={href ? undefined : disabled || loading}

View file

@ -18,23 +18,34 @@ import { parseMessageForRules } from '$lib/utils/rules.js';
// Set to true to enable debug logging // Set to true to enable debug logging
const ENABLE_LOGGING = true; const ENABLE_LOGGING = true;
const reqBodySchema = z.object({ const reqBodySchema = z
message: z.string(), .object({
model_id: z.string(), message: z.string().optional(),
model_id: z.string(),
session_token: z.string(), session_token: z.string(),
conversation_id: z.string().optional(), conversation_id: z.string().optional(),
web_search_enabled: z.boolean().optional(), web_search_enabled: z.boolean().optional(),
images: z images: z
.array( .array(
z.object({ z.object({
url: z.string(), url: z.string(),
storage_id: z.string(), storage_id: z.string(),
fileName: z.string().optional(), fileName: z.string().optional(),
}) })
) )
.optional(), .optional(),
}); })
.refine(
(data) => {
if (data.conversation_id === undefined && data.message === undefined) return false;
return true;
},
{
message: 'You must provide a message when creating a new conversation',
}
);
export type GenerateMessageRequestBody = z.infer<typeof reqBodySchema>; export type GenerateMessageRequestBody = z.infer<typeof reqBodySchema>;
@ -688,6 +699,11 @@ export const POST: RequestHandler = async ({ request }) => {
let conversationId = args.conversation_id; let conversationId = args.conversation_id;
if (!conversationId) { if (!conversationId) {
// technically zod should catch this but just in case
if (args.message === undefined) {
return error(400, 'You must provide a message when creating a new conversation');
}
const convMessageResult = await ResultAsync.fromPromise( const convMessageResult = await ResultAsync.fromPromise(
client.mutation(api.conversations.createAndAddMessage, { client.mutation(api.conversations.createAndAddMessage, {
content: args.message, content: args.message,
@ -722,25 +738,28 @@ export const POST: RequestHandler = async ({ request }) => {
); );
} else { } else {
log('Using existing conversation', startTime); log('Using existing conversation', startTime);
const userMessageResult = await ResultAsync.fromPromise(
client.mutation(api.messages.create, {
conversation_id: conversationId as Id<'conversations'>,
content: args.message,
session_token: args.session_token,
model_id: args.model_id,
role: 'user',
images: args.images,
web_search_enabled: args.web_search_enabled,
}),
(e) => `Failed to create user message: ${e}`
);
if (userMessageResult.isErr()) { if (args.message) {
log(`User message creation failed: ${userMessageResult.error}`, startTime); const userMessageResult = await ResultAsync.fromPromise(
return error(500, 'Failed to create user message'); client.mutation(api.messages.create, {
conversation_id: conversationId as Id<'conversations'>,
content: args.message,
session_token: args.session_token,
model_id: args.model_id,
role: 'user',
images: args.images,
web_search_enabled: args.web_search_enabled,
}),
(e) => `Failed to create user message: ${e}`
);
if (userMessageResult.isErr()) {
log(`User message creation failed: ${userMessageResult.error}`, startTime);
return error(500, 'Failed to create user message');
}
log('User message created', startTime);
} }
log('User message created', startTime);
} }
// Set generating status to true before starting background generation // Set generating status to true before starting background generation

View file

@ -61,7 +61,7 @@
<title>{conversation.data?.title} | thom.chat</title> <title>{conversation.data?.title} | thom.chat</title>
</svelte:head> </svelte:head>
<div class="flex h-full flex-1 flex-col py-4"> <div class="flex h-full flex-1 flex-col py-4 pt-6">
{#if !conversation.data && !conversation.isLoading} {#if !conversation.data && !conversation.isLoading}
<div class="flex flex-1 flex-col items-center justify-center gap-4 pt-[25svh]"> <div class="flex flex-1 flex-col items-center justify-center gap-4 pt-[25svh]">
<div> <div>

View file

@ -1,7 +1,7 @@
<script lang="ts"> <script lang="ts">
import { cn } from '$lib/utils/utils'; import { cn } from '$lib/utils/utils';
import { tv } from 'tailwind-variants'; import { tv } from 'tailwind-variants';
import type { Doc } from '$lib/backend/convex/_generated/dataModel'; import type { Doc, Id } from '$lib/backend/convex/_generated/dataModel';
import { CopyButton } from '$lib/components/ui/copy-button'; import { CopyButton } from '$lib/components/ui/copy-button';
import '../../../markdown.css'; import '../../../markdown.css';
import MarkdownRenderer from './markdown-renderer.svelte'; import MarkdownRenderer from './markdown-renderer.svelte';
@ -9,6 +9,15 @@
import { sanitizeHtml } from '$lib/utils/markdown-it'; import { sanitizeHtml } from '$lib/utils/markdown-it';
import { on } from 'svelte/events'; import { on } from 'svelte/events';
import { isHtmlElement } from '$lib/utils/is'; import { isHtmlElement } from '$lib/utils/is';
import { Button } from '$lib/components/ui/button';
import Tooltip from '$lib/components/ui/tooltip.svelte';
import { useConvexClient } from 'convex-svelte';
import { api } from '$lib/backend/convex/_generated/api';
import { session } from '$lib/state/session.svelte';
import { ResultAsync } from 'neverthrow';
import { goto } from '$app/navigation';
import { callGenerateMessage } from '../../api/generate-message/call';
import * as Icons from '$lib/components/icons';
const style = tv({ const style = tv({
base: 'prose rounded-xl p-2 max-w-full', base: 'prose rounded-xl p-2 max-w-full',
@ -24,6 +33,8 @@
message: Doc<'messages'>; message: Doc<'messages'>;
}; };
const client = useConvexClient();
let { message }: Props = $props(); let { message }: Props = $props();
let imageModal = $state<{ open: boolean; imageUrl: string; fileName: string }>({ let imageModal = $state<{ open: boolean; imageUrl: string; fileName: string }>({
@ -39,6 +50,57 @@
fileName, fileName,
}; };
} }
async function createBranchedConversation() {
const res = await ResultAsync.fromPromise(
client.mutation(api.conversations.createBranched, {
conversation_id: message.conversation_id as Id<'conversations'>,
from_message_id: message._id,
session_token: session.current?.session.token ?? '',
}),
(e) => e
);
if (res.isErr()) {
console.error(res.error);
return;
}
await goto(`/chat/${res.value}`);
}
async function branchAndGenerate() {
const res = await ResultAsync.fromPromise(
client.mutation(api.conversations.createBranched, {
conversation_id: message.conversation_id as Id<'conversations'>,
from_message_id: message._id,
session_token: session.current?.session.token ?? '',
}),
(e) => e
);
if (res.isErr()) {
console.error(res.error);
return;
}
const cid = res.value;
const generateRes = await callGenerateMessage({
session_token: session.current?.session.token ?? '',
conversation_id: cid,
model_id: message.model_id!,
images: message.images,
web_search_enabled: message.web_search_enabled,
});
if (generateRes.isErr()) {
// TODO: add error toast
return;
}
await goto(`/chat/${cid}`);
}
</script> </script>
{#if message.role !== 'system' && !(message.role === 'assistant' && message.content.length === 0 && !message.error)} {#if message.role !== 'system' && !(message.role === 'assistant' && message.content.length === 0 && !message.error)}
@ -101,26 +163,67 @@
</div> </div>
<div <div
class={cn( class={cn(
'flex place-items-center gap-2 opacity-0 transition-opacity group-hover:opacity-100', 'flex place-items-center gap-2 md:opacity-0 transition-opacity group-hover:opacity-100',
{ {
'justify-end': message.role === 'user', 'justify-end': message.role === 'user',
} }
)} )}
> >
<Tooltip>
{#snippet trigger(tooltip)}
<Button
size="icon"
variant="ghost"
class={cn('group order-2 size-7', { 'order-1': message.role === 'user' })}
onClickPromise={createBranchedConversation}
{...tooltip.trigger}
>
<Icons.Branch class="group-data-[loading=true]:opacity-0" />
</Button>
{/snippet}
Branch off this message
</Tooltip>
{#if message.role === 'user'}
<Tooltip>
{#snippet trigger(tooltip)}
<Button
size="icon"
variant="ghost"
class={cn('group order-0 size-7')}
onClickPromise={branchAndGenerate}
{...tooltip.trigger}
>
<Icons.BranchAndRegen class="group-data-[loading=true]:opacity-0" />
</Button>
{/snippet}
Branch and regenerate
</Tooltip>
{/if}
{#if message.content.length > 0} {#if message.content.length > 0}
<CopyButton class="size-7" text={message.content} /> <Tooltip>
{/if} {#snippet trigger(tooltip)}
{#if message.model_id !== undefined} <CopyButton
<span class="text-muted-foreground text-xs">{message.model_id}</span> class={cn('order-1 size-7', { 'order-2': message.role === 'user' })}
{/if} text={message.content}
{#if message.web_search_enabled} {...tooltip.trigger}
<span class="text-muted-foreground text-xs"> Web search enabled </span> />
{/snippet}
Copy
</Tooltip>
{/if} {/if}
{#if message.role === 'assistant'}
{#if message.model_id !== undefined}
<span class="text-muted-foreground text-xs">{message.model_id}</span>
{/if}
{#if message.web_search_enabled}
<span class="text-muted-foreground text-xs"> Web search enabled </span>
{/if}
{#if message.cost_usd !== undefined} {#if message.cost_usd !== undefined}
<span class="text-muted-foreground text-xs"> <span class="text-muted-foreground text-xs">
${message.cost_usd.toFixed(6)} ${message.cost_usd.toFixed(6)}
</span> </span>
{/if}
{/if} {/if}
</div> </div>
</div> </div>