WIP Branching (#29)
This commit is contained in:
parent
13b3449d82
commit
c969d44dd4
10 changed files with 317 additions and 48 deletions
|
|
@ -147,6 +147,65 @@ export const createAndAddMessage = mutation({
|
|||
},
|
||||
});
|
||||
|
||||
export const createBranched = mutation({
|
||||
args: {
|
||||
conversation_id: v.id('conversations'),
|
||||
from_message_id: v.id('messages'),
|
||||
session_token: v.string(),
|
||||
},
|
||||
handler: async (ctx, args): Promise<Id<'conversations'>> => {
|
||||
const session = await ctx.runQuery(api.betterAuth.publicGetSession, {
|
||||
session_token: args.session_token,
|
||||
});
|
||||
|
||||
if (!session) throw new Error('Unauthorized');
|
||||
|
||||
const existingConversation = await ctx.db.get(args.conversation_id);
|
||||
|
||||
console.log(existingConversation);
|
||||
|
||||
if (!existingConversation) throw new Error('Conversation not found');
|
||||
if (existingConversation.user_id !== session.userId && !existingConversation.public)
|
||||
throw new Error('Unauthorized');
|
||||
|
||||
const messages = await ctx.db
|
||||
.query('messages')
|
||||
.withIndex('by_conversation', (q) => q.eq('conversation_id', args.conversation_id))
|
||||
.collect();
|
||||
|
||||
const messageIndex = messages.findIndex((m) => m._id === args.from_message_id);
|
||||
|
||||
const newMessages = messages.slice(0, messageIndex + 1);
|
||||
|
||||
const newConversationId = await ctx.db.insert('conversations', {
|
||||
title: existingConversation.title,
|
||||
branched_from: existingConversation._id,
|
||||
user_id: session.userId,
|
||||
updated_at: Date.now(),
|
||||
generating: false,
|
||||
public: false,
|
||||
cost_usd: newMessages.reduce((acc, m) => acc + (m.cost_usd ?? 0), 0),
|
||||
});
|
||||
|
||||
console.log(newConversationId);
|
||||
|
||||
await Promise.all(
|
||||
newMessages.map((m) => {
|
||||
const newMessage = {
|
||||
...m,
|
||||
_id: undefined,
|
||||
_creationTime: undefined,
|
||||
conversation_id: newConversationId,
|
||||
};
|
||||
|
||||
return ctx.db.insert('messages', newMessage);
|
||||
})
|
||||
);
|
||||
|
||||
return newConversationId;
|
||||
},
|
||||
});
|
||||
|
||||
export const updateTitle = mutation({
|
||||
args: {
|
||||
conversation_id: v.id('conversations'),
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ export default defineSchema({
|
|||
generating: v.optional(v.boolean()),
|
||||
cost_usd: v.optional(v.number()),
|
||||
public: v.optional(v.boolean()),
|
||||
branched_from: v.optional(v.id('conversations')),
|
||||
}).index('by_user', ['user_id']),
|
||||
messages: defineTable({
|
||||
conversation_id: v.string(),
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
import XIcon from '~icons/lucide/x';
|
||||
import { Button } from './ui/button';
|
||||
import { callModal } from './ui/modal/global-modal.svelte';
|
||||
import SplitIcon from '~icons/lucide/split';
|
||||
|
||||
let { searchModalOpen = $bindable(false) }: { searchModalOpen: boolean } = $props();
|
||||
|
||||
|
|
@ -198,6 +199,9 @@
|
|||
)}
|
||||
>
|
||||
<p class="truncate rounded-lg py-2 pr-4 pl-3 whitespace-nowrap">
|
||||
{#if conversation.branched_from}
|
||||
<SplitIcon class="text-muted-foreground/50 mr-1 inline size-4" />
|
||||
{/if}
|
||||
<span>{conversation.title}</span>
|
||||
</p>
|
||||
<div class="pr-2">
|
||||
|
|
|
|||
52
src/lib/components/icons/branch-and-regen.svelte
Normal file
52
src/lib/components/icons/branch-and-regen.svelte
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
<script lang="ts">
|
||||
import type { Props } from '.';
|
||||
import { cn } from '$lib/utils/utils';
|
||||
|
||||
let { class: className, ...rest }: Props = $props();
|
||||
</script>
|
||||
|
||||
<svg
|
||||
class={cn('size-4 fill-current', className)}
|
||||
viewBox="0 0 24 24"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
{...rest}
|
||||
>
|
||||
<path
|
||||
d="M2.96132 21.09V4.55M2.96132 4.55V2.64M2.96132 4.55V5.78C2.96132 7.39 4.17132 8.89 6.16132 9.72L10.7413 11.64C12.7213 12.47 13.9413 13.96 13.9413 15.58V19.42"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M17.4713 17.59L14.0613 21.25L10.4013 17.84"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M20.8382 6.76662C20.8382 5.5946 20.3727 4.47058 19.5439 3.64183C18.7152 2.81309 17.5911 2.3475 16.4191 2.3475C15.1837 2.35215 13.9979 2.83421 13.1097 3.69288L12 4.80257"
|
||||
stroke="currentColor"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M12 2.3475V4.80257H14.4551"
|
||||
stroke="currentColor"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M12 6.76662C12 7.93865 12.4656 9.06267 13.2943 9.89141C14.1231 10.7202 15.2471 11.1857 16.4191 11.1857C17.6545 11.1811 18.8403 10.699 19.7285 9.84037L20.8382 8.73068"
|
||||
stroke="currentColor"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M18.3832 8.73067H20.8382V11.1857"
|
||||
stroke="currentColor"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
28
src/lib/components/icons/branch.svelte
Normal file
28
src/lib/components/icons/branch.svelte
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
<script lang="ts">
|
||||
import type { Props } from '.';
|
||||
import { cn } from '$lib/utils/utils';
|
||||
|
||||
let { class: className, ...rest }: Props = $props();
|
||||
</script>
|
||||
|
||||
<svg
|
||||
class={cn('size-4 fill-current', className)}
|
||||
viewBox="0 0 24 24"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
{...rest}
|
||||
>
|
||||
<path
|
||||
d="M6.02 21.09V4.55M6.02 4.55V2.64M6.02 4.55V5.78C6.02 7.39 7.23 8.89 9.22 9.72L13.8 11.64C15.78 12.47 17 13.96 17 15.58V19.42"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M20.53 17.59L17.12 21.25L13.46 17.84"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
|
|
@ -2,6 +2,8 @@ import type { HTMLAttributes } from 'svelte/elements';
|
|||
import GitHub from './github.svelte';
|
||||
import TypeScript from './typescript.svelte';
|
||||
import Svelte from './svelte.svelte';
|
||||
import Branch from './branch.svelte';
|
||||
import BranchAndRegen from './branch-and-regen.svelte';
|
||||
|
||||
export interface Props extends HTMLAttributes<SVGElement> {
|
||||
class?: string;
|
||||
|
|
@ -9,4 +11,4 @@ export interface Props extends HTMLAttributes<SVGElement> {
|
|||
height?: number;
|
||||
}
|
||||
|
||||
export { GitHub, TypeScript, Svelte };
|
||||
export { GitHub, TypeScript, Svelte, Branch, BranchAndRegen };
|
||||
|
|
|
|||
|
|
@ -93,6 +93,7 @@
|
|||
this={href ? 'a' : 'button'}
|
||||
{...rest}
|
||||
data-slot="button"
|
||||
data-loading={loading}
|
||||
type={href ? undefined : type}
|
||||
href={href && !disabled ? href : undefined}
|
||||
disabled={href ? undefined : disabled || loading}
|
||||
|
|
|
|||
|
|
@ -18,8 +18,9 @@ import { parseMessageForRules } from '$lib/utils/rules.js';
|
|||
// Set to true to enable debug logging
|
||||
const ENABLE_LOGGING = true;
|
||||
|
||||
const reqBodySchema = z.object({
|
||||
message: z.string(),
|
||||
const reqBodySchema = z
|
||||
.object({
|
||||
message: z.string().optional(),
|
||||
model_id: z.string(),
|
||||
|
||||
session_token: z.string(),
|
||||
|
|
@ -34,7 +35,17 @@ const reqBodySchema = z.object({
|
|||
})
|
||||
)
|
||||
.optional(),
|
||||
});
|
||||
})
|
||||
.refine(
|
||||
(data) => {
|
||||
if (data.conversation_id === undefined && data.message === undefined) return false;
|
||||
|
||||
return true;
|
||||
},
|
||||
{
|
||||
message: 'You must provide a message when creating a new conversation',
|
||||
}
|
||||
);
|
||||
|
||||
export type GenerateMessageRequestBody = z.infer<typeof reqBodySchema>;
|
||||
|
||||
|
|
@ -688,6 +699,11 @@ export const POST: RequestHandler = async ({ request }) => {
|
|||
|
||||
let conversationId = args.conversation_id;
|
||||
if (!conversationId) {
|
||||
// technically zod should catch this but just in case
|
||||
if (args.message === undefined) {
|
||||
return error(400, 'You must provide a message when creating a new conversation');
|
||||
}
|
||||
|
||||
const convMessageResult = await ResultAsync.fromPromise(
|
||||
client.mutation(api.conversations.createAndAddMessage, {
|
||||
content: args.message,
|
||||
|
|
@ -722,6 +738,8 @@ export const POST: RequestHandler = async ({ request }) => {
|
|||
);
|
||||
} else {
|
||||
log('Using existing conversation', startTime);
|
||||
|
||||
if (args.message) {
|
||||
const userMessageResult = await ResultAsync.fromPromise(
|
||||
client.mutation(api.messages.create, {
|
||||
conversation_id: conversationId as Id<'conversations'>,
|
||||
|
|
@ -742,6 +760,7 @@ export const POST: RequestHandler = async ({ request }) => {
|
|||
|
||||
log('User message created', startTime);
|
||||
}
|
||||
}
|
||||
|
||||
// Set generating status to true before starting background generation
|
||||
const setGeneratingResult = await ResultAsync.fromPromise(
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@
|
|||
<title>{conversation.data?.title} | thom.chat</title>
|
||||
</svelte:head>
|
||||
|
||||
<div class="flex h-full flex-1 flex-col py-4">
|
||||
<div class="flex h-full flex-1 flex-col py-4 pt-6">
|
||||
{#if !conversation.data && !conversation.isLoading}
|
||||
<div class="flex flex-1 flex-col items-center justify-center gap-4 pt-[25svh]">
|
||||
<div>
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
<script lang="ts">
|
||||
import { cn } from '$lib/utils/utils';
|
||||
import { tv } from 'tailwind-variants';
|
||||
import type { Doc } from '$lib/backend/convex/_generated/dataModel';
|
||||
import type { Doc, Id } from '$lib/backend/convex/_generated/dataModel';
|
||||
import { CopyButton } from '$lib/components/ui/copy-button';
|
||||
import '../../../markdown.css';
|
||||
import MarkdownRenderer from './markdown-renderer.svelte';
|
||||
|
|
@ -9,6 +9,15 @@
|
|||
import { sanitizeHtml } from '$lib/utils/markdown-it';
|
||||
import { on } from 'svelte/events';
|
||||
import { isHtmlElement } from '$lib/utils/is';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import Tooltip from '$lib/components/ui/tooltip.svelte';
|
||||
import { useConvexClient } from 'convex-svelte';
|
||||
import { api } from '$lib/backend/convex/_generated/api';
|
||||
import { session } from '$lib/state/session.svelte';
|
||||
import { ResultAsync } from 'neverthrow';
|
||||
import { goto } from '$app/navigation';
|
||||
import { callGenerateMessage } from '../../api/generate-message/call';
|
||||
import * as Icons from '$lib/components/icons';
|
||||
|
||||
const style = tv({
|
||||
base: 'prose rounded-xl p-2 max-w-full',
|
||||
|
|
@ -24,6 +33,8 @@
|
|||
message: Doc<'messages'>;
|
||||
};
|
||||
|
||||
const client = useConvexClient();
|
||||
|
||||
let { message }: Props = $props();
|
||||
|
||||
let imageModal = $state<{ open: boolean; imageUrl: string; fileName: string }>({
|
||||
|
|
@ -39,6 +50,57 @@
|
|||
fileName,
|
||||
};
|
||||
}
|
||||
|
||||
async function createBranchedConversation() {
|
||||
const res = await ResultAsync.fromPromise(
|
||||
client.mutation(api.conversations.createBranched, {
|
||||
conversation_id: message.conversation_id as Id<'conversations'>,
|
||||
from_message_id: message._id,
|
||||
session_token: session.current?.session.token ?? '',
|
||||
}),
|
||||
(e) => e
|
||||
);
|
||||
|
||||
if (res.isErr()) {
|
||||
console.error(res.error);
|
||||
return;
|
||||
}
|
||||
|
||||
await goto(`/chat/${res.value}`);
|
||||
}
|
||||
|
||||
async function branchAndGenerate() {
|
||||
const res = await ResultAsync.fromPromise(
|
||||
client.mutation(api.conversations.createBranched, {
|
||||
conversation_id: message.conversation_id as Id<'conversations'>,
|
||||
from_message_id: message._id,
|
||||
session_token: session.current?.session.token ?? '',
|
||||
}),
|
||||
(e) => e
|
||||
);
|
||||
|
||||
if (res.isErr()) {
|
||||
console.error(res.error);
|
||||
return;
|
||||
}
|
||||
|
||||
const cid = res.value;
|
||||
|
||||
const generateRes = await callGenerateMessage({
|
||||
session_token: session.current?.session.token ?? '',
|
||||
conversation_id: cid,
|
||||
model_id: message.model_id!,
|
||||
images: message.images,
|
||||
web_search_enabled: message.web_search_enabled,
|
||||
});
|
||||
|
||||
if (generateRes.isErr()) {
|
||||
// TODO: add error toast
|
||||
return;
|
||||
}
|
||||
|
||||
await goto(`/chat/${cid}`);
|
||||
}
|
||||
</script>
|
||||
|
||||
{#if message.role !== 'system' && !(message.role === 'assistant' && message.content.length === 0 && !message.error)}
|
||||
|
|
@ -101,15 +163,55 @@
|
|||
</div>
|
||||
<div
|
||||
class={cn(
|
||||
'flex place-items-center gap-2 opacity-0 transition-opacity group-hover:opacity-100',
|
||||
'flex place-items-center gap-2 md:opacity-0 transition-opacity group-hover:opacity-100',
|
||||
{
|
||||
'justify-end': message.role === 'user',
|
||||
}
|
||||
)}
|
||||
>
|
||||
{#if message.content.length > 0}
|
||||
<CopyButton class="size-7" text={message.content} />
|
||||
<Tooltip>
|
||||
{#snippet trigger(tooltip)}
|
||||
<Button
|
||||
size="icon"
|
||||
variant="ghost"
|
||||
class={cn('group order-2 size-7', { 'order-1': message.role === 'user' })}
|
||||
onClickPromise={createBranchedConversation}
|
||||
{...tooltip.trigger}
|
||||
>
|
||||
<Icons.Branch class="group-data-[loading=true]:opacity-0" />
|
||||
</Button>
|
||||
{/snippet}
|
||||
Branch off this message
|
||||
</Tooltip>
|
||||
{#if message.role === 'user'}
|
||||
<Tooltip>
|
||||
{#snippet trigger(tooltip)}
|
||||
<Button
|
||||
size="icon"
|
||||
variant="ghost"
|
||||
class={cn('group order-0 size-7')}
|
||||
onClickPromise={branchAndGenerate}
|
||||
{...tooltip.trigger}
|
||||
>
|
||||
<Icons.BranchAndRegen class="group-data-[loading=true]:opacity-0" />
|
||||
</Button>
|
||||
{/snippet}
|
||||
Branch and regenerate
|
||||
</Tooltip>
|
||||
{/if}
|
||||
{#if message.content.length > 0}
|
||||
<Tooltip>
|
||||
{#snippet trigger(tooltip)}
|
||||
<CopyButton
|
||||
class={cn('order-1 size-7', { 'order-2': message.role === 'user' })}
|
||||
text={message.content}
|
||||
{...tooltip.trigger}
|
||||
/>
|
||||
{/snippet}
|
||||
Copy
|
||||
</Tooltip>
|
||||
{/if}
|
||||
{#if message.role === 'assistant'}
|
||||
{#if message.model_id !== undefined}
|
||||
<span class="text-muted-foreground text-xs">{message.model_id}</span>
|
||||
{/if}
|
||||
|
|
@ -122,6 +224,7 @@
|
|||
${message.cost_usd.toFixed(6)}
|
||||
</span>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue