abort messages

This commit is contained in:
Thomas G. Lopes 2025-06-18 00:29:00 +01:00
parent e8b5f94456
commit d4e6228869
5 changed files with 234 additions and 18 deletions

View file

@ -0,0 +1,88 @@
import { api } from '$lib/backend/convex/_generated/api';
import type { Id } from '$lib/backend/convex/_generated/dataModel';
import { error, json, type RequestHandler } from '@sveltejs/kit';
import { ConvexHttpClient } from 'convex/browser';
import { ResultAsync } from 'neverthrow';
import { z } from 'zod/v4';
import { getSessionCookie } from 'better-auth/cookies';
import { PUBLIC_CONVEX_URL } from '$env/static/public';
// Import the global cache from generate-message
import { generationAbortControllers } from '../generate-message/cache.js';
// Convex HTTP client used by this endpoint for its query and mutation calls.
const client = new ConvexHttpClient(PUBLIC_CONVEX_URL);
// Shape of the JSON body accepted by the POST handler below.
const reqBodySchema = z.object({
// Conversation whose in-flight generation should be aborted.
conversation_id: z.string(),
// Must equal the token extracted from the session cookie, otherwise the handler answers 401.
session_token: z.string(),
});
// Request body type, derived from the schema so the client helper and the handler stay in sync.
export type CancelGenerationRequestBody = z.infer<typeof reqBodySchema>;
// JSON payload returned on success. `cancelled` is true only when an AbortController
// was found for the conversation and aborted; false means there was nothing to abort.
export type CancelGenerationResponse = {
ok: true;
cancelled: boolean;
};
/**
 * Wraps a cancel-generation result in a JSON response.
 *
 * Exists so the payload is checked against `CancelGenerationResponse` at every
 * return site instead of being passed to `json` untyped.
 */
function response(res: CancelGenerationResponse) {
	const payload: CancelGenerationResponse = res;
	return json(payload);
}
/**
 * Cancels the in-flight AI generation for a conversation.
 *
 * Flow:
 *  1. Parse and validate the JSON body (400 on failure).
 *  2. Check that the session cookie token matches `session_token` from the body (401 otherwise).
 *  3. Query the conversation with that token (403 if the query fails).
 *  4. If an AbortController is cached for the conversation, abort it, drop it from the
 *     cache and set the conversation's `generating` flag to false.
 *
 * Responds with `{ ok: true, cancelled }`, where `cancelled` tells whether a controller
 * was actually found and aborted.
 */
export const POST: RequestHandler = async ({ request }) => {
	const bodyResult = await ResultAsync.fromPromise(
		request.json(),
		() => 'Failed to parse request body'
	);
	if (bodyResult.isErr()) {
		return error(400, 'Failed to parse request body');
	}

	const parsed = reqBodySchema.safeParse(bodyResult.value);
	if (!parsed.success) {
		return error(400, parsed.error);
	}
	const args = parsed.data;

	// The cookie value is split on '.' and only the first segment is compared with the body token.
	const cookie = getSessionCookie(request.headers);
	const sessionToken = cookie?.split('.')[0] ?? null;
	if (!sessionToken || sessionToken !== args.session_token) {
		return error(401, 'Unauthorized');
	}

	// Cast once; the validated body only guarantees a plain string.
	const conversationId = args.conversation_id as Id<'conversations'>;

	// Verify the user owns this conversation.
	// NOTE(review): this relies on `getById` rejecting for foreign or missing conversations;
	// a resolved null/undefined value is not checked here - confirm against the Convex query.
	const conversationResult = await ResultAsync.fromPromise(
		client.query(api.conversations.getById, {
			conversation_id: conversationId,
			session_token: sessionToken,
		}),
		(e) => `Failed to get conversation: ${e}`
	);
	if (conversationResult.isErr()) {
		return error(403, 'Conversation not found or unauthorized');
	}

	// Try to cancel the generation
	const abortController = generationAbortControllers.get(args.conversation_id);
	if (!abortController) {
		// Nothing is cached for this conversation, so there is nothing to abort.
		return response({ ok: true, cancelled: false });
	}

	abortController.abort();
	generationAbortControllers.delete(args.conversation_id);

	// Update conversation generating status to false
	const updateResult = await ResultAsync.fromPromise(
		client.mutation(api.conversations.updateGenerating, {
			conversation_id: conversationId,
			generating: false,
			session_token: sessionToken,
		}),
		(e) => `Failed to update generating status: ${e}`
	);
	if (updateResult.isErr()) {
		// The abort itself already happened, so the request still succeeds; the failure is
		// logged instead of being dropped so a stuck `generating` flag can be diagnosed.
		console.error(updateResult.error);
	}

	return response({ ok: true, cancelled: true });
};

View file

@ -0,0 +1,17 @@
import { ResultAsync } from 'neverthrow';
import type { CancelGenerationRequestBody, CancelGenerationResponse } from './+server';
/**
 * Client-side helper that asks the server to cancel the running generation
 * for a conversation.
 *
 * @param args - Conversation id and session token, sent as the JSON request body.
 * @returns A `ResultAsync` that is `Ok` with the parsed `CancelGenerationResponse`
 * when the server answers with a 2xx status, and `Err` when the request fails at the
 * network level, the server answers with a non-2xx status, or the body is not valid JSON.
 */
export async function callCancelGeneration(args: CancelGenerationRequestBody) {
	// Do the whole exchange inside one promise so every failure mode (fetch rejection,
	// HTTP error status, body parse failure) is funnelled into the Err branch instead
	// of escaping as an unhandled rejection.
	const exchange = async (): Promise<CancelGenerationResponse> => {
		const res = await fetch('/api/cancel-generation', {
			method: 'POST',
			headers: {
				'Content-Type': 'application/json',
			},
			body: JSON.stringify(args),
		});

		// Error responses carry an error payload, not a CancelGenerationResponse,
		// so they must not be surfaced as a successful result.
		if (!res.ok) {
			throw new Error(`Cancel generation request failed with status ${res.status}`);
		}

		return (await res.json()) as CancelGenerationResponse;
	};

	return ResultAsync.fromPromise(exchange(), (e) => e);
}

View file

@ -10,6 +10,7 @@ import { waitUntil } from '@vercel/functions';
import { z } from 'zod/v4';
import type { ChatCompletionSystemMessageParam } from 'openai/resources';
import { getSessionCookie } from 'better-auth/cookies';
import { generationAbortControllers } from './cache.js';
// Set to true to enable debug logging
const ENABLE_LOGGING = true;
@ -162,6 +163,7 @@ async function generateAIResponse({
modelResultPromise,
keyResultPromise,
rulesResultPromise,
abortSignal,
}: {
conversationId: string;
sessionToken: string;
@ -169,9 +171,15 @@ async function generateAIResponse({
keyResultPromise: ResultAsync<string | null, string>;
modelResultPromise: ResultAsync<Doc<'user_enabled_models'> | null, string>;
rulesResultPromise: ResultAsync<Doc<'user_rules'>[], string>;
abortSignal?: AbortSignal;
}) {
log('Starting AI response generation in background', startTime);
if (abortSignal?.aborted) {
log('AI response generation aborted before starting', startTime);
return;
}
const [modelResult, keyResult, messagesQueryResult, rulesResult] = await Promise.all([
modelResultPromise,
keyResultPromise,
@ -272,12 +280,19 @@ ${attachedRules.map((r) => `- ${r.name}: ${r.rule}`).join('\n')}`,
};
});
if (abortSignal?.aborted) {
log('AI response generation aborted before OpenAI call', startTime);
return;
}
const streamResult = await ResultAsync.fromPromise(
openai.chat.completions.create({
model: model.model_id,
messages: [...formattedMessages, systemMessage],
temperature: 0.7,
stream: true,
}, {
signal: abortSignal,
}),
(e) => `OpenAI API call failed: ${e}`
);
@ -317,6 +332,11 @@ ${attachedRules.map((r) => `- ${r.name}: ${r.rule}`).join('\n')}`,
try {
for await (const chunk of stream) {
if (abortSignal?.aborted) {
log('AI response generation aborted during streaming', startTime);
break;
}
chunkCount++;
content += chunk.choices[0]?.delta?.content || '';
if (!content) continue;
@ -420,6 +440,10 @@ ${attachedRules.map((r) => `- ${r.name}: ${r.rule}`).join('\n')}`,
log('Background: Cost usd updated', startTime);
} catch (error) {
log(`Background stream processing error: ${error}`, startTime);
} finally {
// Clean up the cached AbortController
generationAbortControllers.delete(conversationId);
log('Background: Cleaned up abort controller', startTime);
}
}
@ -537,6 +561,25 @@ export const POST: RequestHandler = async ({ request }) => {
log('User message created', startTime);
}
// Set generating status to true before starting background generation
const setGeneratingResult = await ResultAsync.fromPromise(
client.mutation(api.conversations.updateGenerating, {
conversation_id: conversationId as Id<'conversations'>,
generating: true,
session_token: sessionToken,
}),
(e) => `Failed to set generating status: ${e}`
);
if (setGeneratingResult.isErr()) {
log(`Failed to set generating status: ${setGeneratingResult.error}`, startTime);
return error(500, 'Failed to set generating status');
}
// Create and cache AbortController for this generation
const abortController = new AbortController();
generationAbortControllers.set(conversationId, abortController);
// Start AI response generation in background - don't await
waitUntil(
generateAIResponse({
@ -546,8 +589,22 @@ export const POST: RequestHandler = async ({ request }) => {
modelResultPromise,
keyResultPromise,
rulesResultPromise,
}).catch((error) => {
abortSignal: abortController.signal,
}).catch(async (error) => {
log(`Background AI response generation error: ${error}`, startTime);
// Reset generating status on error
try {
await client.mutation(api.conversations.updateGenerating, {
conversation_id: conversationId as Id<'conversations'>,
generating: false,
session_token: sessionToken,
});
} catch (e) {
log(`Failed to reset generating status after error: ${e}`, startTime);
}
}).finally(() => {
// Clean up the cached AbortController
generationAbortControllers.delete(conversationId);
})
);

View file

@ -0,0 +1,2 @@
// Global cache for AbortControllers keyed by conversation ID.
// An entry is added when a generation is started and removed when that generation
// finishes or is cancelled; the cancel endpoint looks the controller up here to abort it.
// NOTE(review): this is a plain module-level Map, so it presumably only covers generations
// started in the same server process - confirm this is acceptable for the deployment target.
export const generationAbortControllers = new Map<string, AbortController>();

View file

@ -28,6 +28,7 @@
import { Avatar } from 'melt/components';
import { Debounced, ElementSize, IsMounted, ScrollState } from 'runed';
import SendIcon from '~icons/lucide/arrow-up';
import StopIcon from '~icons/lucide/square';
import ChevronDownIcon from '~icons/lucide/chevron-down';
import ImageIcon from '~icons/lucide/image';
import LoaderCircleIcon from '~icons/lucide/loader-circle';
@ -38,6 +39,7 @@
import UploadIcon from '~icons/lucide/upload';
import XIcon from '~icons/lucide/x';
import { callGenerateMessage } from '../api/generate-message/call.js';
import { callCancelGeneration } from '../api/cancel-generation/call.js';
import ModelPicker from './model-picker.svelte';
const client = useConvexClient();
@ -46,7 +48,42 @@
let form = $state<HTMLFormElement>();
let textarea = $state<HTMLTextAreaElement>();
let abortController = $state<AbortController | null>(null);
const currentConversationQuery = useCachedQuery(api.conversations.getById, () => ({
conversation_id: page.params.id as Id<'conversations'>,
session_token: session.current?.session.token ?? '',
}));
const isGenerating = $derived(Boolean(currentConversationQuery.data?.generating));
/**
 * Asks the server to cancel the generation running for the current conversation.
 *
 * Does nothing when there is no conversation id in the route or no session token.
 * Failures are only logged to the console; the local abort controller reference is
 * cleared afterwards regardless of the outcome.
 */
async function stopGeneration() {
	const conversationId = page.params.id;
	const token = session.current?.session.token;
	if (!conversationId || !token) return;

	try {
		const outcome = await callCancelGeneration({
			conversation_id: conversationId,
			session_token: token,
		});

		if (outcome.isOk()) {
			console.log('Generation cancelled:', outcome.value.cancelled);
		} else {
			console.error('Failed to cancel generation:', outcome.error);
		}
	} catch (err) {
		console.error('Error cancelling generation:', err);
	}

	// Clear local abort controller if it exists
	if (abortController !== null) abortController = null;
}
async function handleSubmit() {
if (isGenerating) return;
const formData = new FormData(form);
const message = formData.get('message');
@ -58,19 +95,26 @@
const imagesCopy = [...selectedImages];
selectedImages = [];
const res = await callGenerateMessage({
message: messageCopy,
session_token: session.current?.session.token,
conversation_id: page.params.id ?? undefined,
model_id: settings.modelId,
images: imagesCopy.length > 0 ? imagesCopy : undefined,
});
if (res.isErr()) return; // TODO: Handle error
try {
const res = await callGenerateMessage({
message: messageCopy,
session_token: session.current?.session.token,
conversation_id: page.params.id ?? undefined,
model_id: settings.modelId,
images: imagesCopy.length > 0 ? imagesCopy : undefined,
});
const cid = res.value.conversation_id;
if (res.isErr()) {
return; // TODO: Handle error
}
if (page.params.id !== cid) {
goto(`/chat/${cid}`);
const cid = res.value.conversation_id;
if (page.params.id !== cid) {
goto(`/chat/${cid}`);
}
} catch (error) {
console.error('Error generating message:', error);
}
}
@ -670,9 +714,11 @@
<textarea
{...pick(popover.trigger, ['id', 'style', 'onfocusout', 'onfocus'])}
bind:this={textarea}
disabled={!openRouterKeyQuery.data}
disabled={!openRouterKeyQuery.data || isGenerating}
class="text-foreground placeholder:text-muted-foreground/60 max-h-64 min-h-[60px] w-full resize-none !overflow-y-auto bg-transparent text-base leading-6 outline-none disabled:cursor-not-allowed disabled:opacity-50"
placeholder="Type your message here... Tag rules with @"
placeholder={isGenerating
? 'Generating response...'
: 'Type your message here... Tag rules with @'}
name="message"
onkeydown={(e) => {
if (e.key === 'Enter' && !e.shiftKey && !popover.open) {
@ -715,14 +761,20 @@
<Tooltip placement="top">
{#snippet trigger(tooltip)}
<button
type="submit"
class="border-reflect button-reflect hover:bg-primary/90 active:bg-primary text-primary-foreground relative h-9 w-9 rounded-lg p-2 font-semibold shadow transition"
type={isGenerating ? 'button' : 'submit'}
onclick={isGenerating ? stopGeneration : undefined}
disabled={isGenerating ? false : !message.trim()}
class="border-reflect button-reflect hover:bg-primary/90 active:bg-primary text-primary-foreground relative h-9 w-9 rounded-lg p-2 font-semibold shadow transition disabled:cursor-not-allowed disabled:opacity-50"
{...tooltip.trigger}
>
<SendIcon class="!size-5" />
{#if isGenerating}
<StopIcon class="!size-5" />
{:else}
<SendIcon class="!size-5" />
{/if}
</button>
{/snippet}
Send message
{isGenerating ? 'Stop generation' : 'Send message'}
</Tooltip>
</div>
<div class="flex flex-col gap-2 pr-2 sm:flex-row sm:items-center">