WIP Branching (#29)

This commit is contained in:
Aidan Bleser 2025-06-19 12:16:13 -05:00 committed by GitHub
parent 13b3449d82
commit c969d44dd4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 317 additions and 48 deletions

View file

@ -147,6 +147,65 @@ export const createAndAddMessage = mutation({
},
});
export const createBranched = mutation({
args: {
conversation_id: v.id('conversations'),
from_message_id: v.id('messages'),
session_token: v.string(),
},
handler: async (ctx, args): Promise<Id<'conversations'>> => {
const session = await ctx.runQuery(api.betterAuth.publicGetSession, {
session_token: args.session_token,
});
if (!session) throw new Error('Unauthorized');
const existingConversation = await ctx.db.get(args.conversation_id);
console.log(existingConversation);
if (!existingConversation) throw new Error('Conversation not found');
if (existingConversation.user_id !== session.userId && !existingConversation.public)
throw new Error('Unauthorized');
const messages = await ctx.db
.query('messages')
.withIndex('by_conversation', (q) => q.eq('conversation_id', args.conversation_id))
.collect();
const messageIndex = messages.findIndex((m) => m._id === args.from_message_id);
const newMessages = messages.slice(0, messageIndex + 1);
const newConversationId = await ctx.db.insert('conversations', {
title: existingConversation.title,
branched_from: existingConversation._id,
user_id: session.userId,
updated_at: Date.now(),
generating: false,
public: false,
cost_usd: newMessages.reduce((acc, m) => acc + (m.cost_usd ?? 0), 0),
});
console.log(newConversationId);
await Promise.all(
newMessages.map((m) => {
const newMessage = {
...m,
_id: undefined,
_creationTime: undefined,
conversation_id: newConversationId,
};
return ctx.db.insert('messages', newMessage);
})
);
return newConversationId;
},
});
export const updateTitle = mutation({
args: {
conversation_id: v.id('conversations'),

View file

@ -54,6 +54,7 @@ export default defineSchema({
generating: v.optional(v.boolean()),
cost_usd: v.optional(v.number()),
public: v.optional(v.boolean()),
branched_from: v.optional(v.id('conversations')),
}).index('by_user', ['user_id']),
messages: defineTable({
conversation_id: v.string(),

View file

@ -19,6 +19,7 @@
import XIcon from '~icons/lucide/x';
import { Button } from './ui/button';
import { callModal } from './ui/modal/global-modal.svelte';
import SplitIcon from '~icons/lucide/split';
let { searchModalOpen = $bindable(false) }: { searchModalOpen: boolean } = $props();
@ -198,6 +199,9 @@
)}
>
<p class="truncate rounded-lg py-2 pr-4 pl-3 whitespace-nowrap">
{#if conversation.branched_from}
<SplitIcon class="text-muted-foreground/50 mr-1 inline size-4" />
{/if}
<span>{conversation.title}</span>
</p>
<div class="pr-2">

View file

@ -0,0 +1,52 @@
<script lang="ts">
import type { Props } from '.';
import { cn } from '$lib/utils/utils';
let { class: className, ...rest }: Props = $props();
</script>
<svg
class={cn('size-4 fill-current', className)}
viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg"
{...rest}
>
<path
d="M2.96132 21.09V4.55M2.96132 4.55V2.64M2.96132 4.55V5.78C2.96132 7.39 4.17132 8.89 6.16132 9.72L10.7413 11.64C12.7213 12.47 13.9413 13.96 13.9413 15.58V19.42"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
/>
<path
d="M17.4713 17.59L14.0613 21.25L10.4013 17.84"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
/>
<path
d="M20.8382 6.76662C20.8382 5.5946 20.3727 4.47058 19.5439 3.64183C18.7152 2.81309 17.5911 2.3475 16.4191 2.3475C15.1837 2.35215 13.9979 2.83421 13.1097 3.69288L12 4.80257"
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
/>
<path
d="M12 2.3475V4.80257H14.4551"
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
/>
<path
d="M12 6.76662C12 7.93865 12.4656 9.06267 13.2943 9.89141C14.1231 10.7202 15.2471 11.1857 16.4191 11.1857C17.6545 11.1811 18.8403 10.699 19.7285 9.84037L20.8382 8.73068"
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
/>
<path
d="M18.3832 8.73067H20.8382V11.1857"
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
/>
</svg>

View file

@ -0,0 +1,28 @@
<script lang="ts">
import type { Props } from '.';
import { cn } from '$lib/utils/utils';
let { class: className, ...rest }: Props = $props();
</script>
<svg
class={cn('size-4 fill-current', className)}
viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg"
{...rest}
>
<path
d="M6.02 21.09V4.55M6.02 4.55V2.64M6.02 4.55V5.78C6.02 7.39 7.23 8.89 9.22 9.72L13.8 11.64C15.78 12.47 17 13.96 17 15.58V19.42"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
/>
<path
d="M20.53 17.59L17.12 21.25L13.46 17.84"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
/>
</svg>

View file

@ -2,6 +2,8 @@ import type { HTMLAttributes } from 'svelte/elements';
import GitHub from './github.svelte';
import TypeScript from './typescript.svelte';
import Svelte from './svelte.svelte';
import Branch from './branch.svelte';
import BranchAndRegen from './branch-and-regen.svelte';
export interface Props extends HTMLAttributes<SVGElement> {
class?: string;
@ -9,4 +11,4 @@ export interface Props extends HTMLAttributes<SVGElement> {
height?: number;
}
export { GitHub, TypeScript, Svelte };
export { GitHub, TypeScript, Svelte, Branch, BranchAndRegen };

View file

@ -93,6 +93,7 @@
this={href ? 'a' : 'button'}
{...rest}
data-slot="button"
data-loading={loading}
type={href ? undefined : type}
href={href && !disabled ? href : undefined}
disabled={href ? undefined : disabled || loading}

View file

@ -18,8 +18,9 @@ import { parseMessageForRules } from '$lib/utils/rules.js';
// Set to true to enable debug logging
const ENABLE_LOGGING = true;
const reqBodySchema = z.object({
message: z.string(),
const reqBodySchema = z
.object({
message: z.string().optional(),
model_id: z.string(),
session_token: z.string(),
@ -34,7 +35,17 @@ const reqBodySchema = z.object({
})
)
.optional(),
});
})
.refine(
(data) => {
if (data.conversation_id === undefined && data.message === undefined) return false;
return true;
},
{
message: 'You must provide a message when creating a new conversation',
}
);
export type GenerateMessageRequestBody = z.infer<typeof reqBodySchema>;
@ -688,6 +699,11 @@ export const POST: RequestHandler = async ({ request }) => {
let conversationId = args.conversation_id;
if (!conversationId) {
// technically zod should catch this but just in case
if (args.message === undefined) {
return error(400, 'You must provide a message when creating a new conversation');
}
const convMessageResult = await ResultAsync.fromPromise(
client.mutation(api.conversations.createAndAddMessage, {
content: args.message,
@ -722,6 +738,8 @@ export const POST: RequestHandler = async ({ request }) => {
);
} else {
log('Using existing conversation', startTime);
if (args.message) {
const userMessageResult = await ResultAsync.fromPromise(
client.mutation(api.messages.create, {
conversation_id: conversationId as Id<'conversations'>,
@ -742,6 +760,7 @@ export const POST: RequestHandler = async ({ request }) => {
log('User message created', startTime);
}
}
// Set generating status to true before starting background generation
const setGeneratingResult = await ResultAsync.fromPromise(

View file

@ -61,7 +61,7 @@
<title>{conversation.data?.title} | thom.chat</title>
</svelte:head>
<div class="flex h-full flex-1 flex-col py-4">
<div class="flex h-full flex-1 flex-col py-4 pt-6">
{#if !conversation.data && !conversation.isLoading}
<div class="flex flex-1 flex-col items-center justify-center gap-4 pt-[25svh]">
<div>

View file

@ -1,7 +1,7 @@
<script lang="ts">
import { cn } from '$lib/utils/utils';
import { tv } from 'tailwind-variants';
import type { Doc } from '$lib/backend/convex/_generated/dataModel';
import type { Doc, Id } from '$lib/backend/convex/_generated/dataModel';
import { CopyButton } from '$lib/components/ui/copy-button';
import '../../../markdown.css';
import MarkdownRenderer from './markdown-renderer.svelte';
@ -9,6 +9,15 @@
import { sanitizeHtml } from '$lib/utils/markdown-it';
import { on } from 'svelte/events';
import { isHtmlElement } from '$lib/utils/is';
import { Button } from '$lib/components/ui/button';
import Tooltip from '$lib/components/ui/tooltip.svelte';
import { useConvexClient } from 'convex-svelte';
import { api } from '$lib/backend/convex/_generated/api';
import { session } from '$lib/state/session.svelte';
import { ResultAsync } from 'neverthrow';
import { goto } from '$app/navigation';
import { callGenerateMessage } from '../../api/generate-message/call';
import * as Icons from '$lib/components/icons';
const style = tv({
base: 'prose rounded-xl p-2 max-w-full',
@ -24,6 +33,8 @@
message: Doc<'messages'>;
};
const client = useConvexClient();
let { message }: Props = $props();
let imageModal = $state<{ open: boolean; imageUrl: string; fileName: string }>({
@ -39,6 +50,57 @@
fileName,
};
}
async function createBranchedConversation() {
const res = await ResultAsync.fromPromise(
client.mutation(api.conversations.createBranched, {
conversation_id: message.conversation_id as Id<'conversations'>,
from_message_id: message._id,
session_token: session.current?.session.token ?? '',
}),
(e) => e
);
if (res.isErr()) {
console.error(res.error);
return;
}
await goto(`/chat/${res.value}`);
}
async function branchAndGenerate() {
const res = await ResultAsync.fromPromise(
client.mutation(api.conversations.createBranched, {
conversation_id: message.conversation_id as Id<'conversations'>,
from_message_id: message._id,
session_token: session.current?.session.token ?? '',
}),
(e) => e
);
if (res.isErr()) {
console.error(res.error);
return;
}
const cid = res.value;
const generateRes = await callGenerateMessage({
session_token: session.current?.session.token ?? '',
conversation_id: cid,
model_id: message.model_id!,
images: message.images,
web_search_enabled: message.web_search_enabled,
});
if (generateRes.isErr()) {
// TODO: add error toast
return;
}
await goto(`/chat/${cid}`);
}
</script>
{#if message.role !== 'system' && !(message.role === 'assistant' && message.content.length === 0 && !message.error)}
@ -101,15 +163,55 @@
</div>
<div
class={cn(
'flex place-items-center gap-2 opacity-0 transition-opacity group-hover:opacity-100',
'flex place-items-center gap-2 md:opacity-0 transition-opacity group-hover:opacity-100',
{
'justify-end': message.role === 'user',
}
)}
>
{#if message.content.length > 0}
<CopyButton class="size-7" text={message.content} />
<Tooltip>
{#snippet trigger(tooltip)}
<Button
size="icon"
variant="ghost"
class={cn('group order-2 size-7', { 'order-1': message.role === 'user' })}
onClickPromise={createBranchedConversation}
{...tooltip.trigger}
>
<Icons.Branch class="group-data-[loading=true]:opacity-0" />
</Button>
{/snippet}
Branch off this message
</Tooltip>
{#if message.role === 'user'}
<Tooltip>
{#snippet trigger(tooltip)}
<Button
size="icon"
variant="ghost"
class={cn('group order-0 size-7')}
onClickPromise={branchAndGenerate}
{...tooltip.trigger}
>
<Icons.BranchAndRegen class="group-data-[loading=true]:opacity-0" />
</Button>
{/snippet}
Branch and regenerate
</Tooltip>
{/if}
{#if message.content.length > 0}
<Tooltip>
{#snippet trigger(tooltip)}
<CopyButton
class={cn('order-1 size-7', { 'order-2': message.role === 'user' })}
text={message.content}
{...tooltip.trigger}
/>
{/snippet}
Copy
</Tooltip>
{/if}
{#if message.role === 'assistant'}
{#if message.model_id !== undefined}
<span class="text-muted-foreground text-xs">{message.model_id}</span>
{/if}
@ -122,6 +224,7 @@
${message.cost_usd.toFixed(6)}
</span>
{/if}
{/if}
</div>
</div>