diff --git a/packages/modelfusion/src/model-function/generate-text/prompt-template/Phi3PromptTemplate.test.ts b/packages/modelfusion/src/model-function/generate-text/prompt-template/Phi3PromptTemplate.test.ts
new file mode 100644
index 000000000..664631ee5
--- /dev/null
+++ b/packages/modelfusion/src/model-function/generate-text/prompt-template/Phi3PromptTemplate.test.ts
@@ -0,0 +1,72 @@
+import { chat, instruction, text } from "./Phi3PromptTemplate";
+
+describe("text prompt", () => {
+  it("should format prompt", () => {
+    const prompt = text().format("prompt");
+
+    expect(prompt).toMatchSnapshot();
+  });
+});
+
+describe("instruction prompt", () => {
+  it("should format prompt with instruction", () => {
+    const prompt = instruction().format({
+      instruction: "instruction",
+    });
+
+    expect(prompt).toMatchSnapshot();
+  });
+
+  it("should format prompt with system and instruction", () => {
+    const prompt = instruction().format({
+      system: "system",
+      instruction: "instruction",
+    });
+
+    expect(prompt).toMatchSnapshot();
+  });
+
+  it("should format prompt with instruction and response prefix", () => {
+    const prompt = instruction().format({
+      instruction: "instruction",
+      responsePrefix: "response prefix",
+    });
+
+    expect(prompt).toMatchSnapshot();
+  });
+});
+
+describe("chat prompt", () => {
+  it("should format prompt with user message", () => {
+    const prompt = chat().format({
+      messages: [{ role: "user", content: "user message" }],
+    });
+
+    expect(prompt).toMatchSnapshot();
+  });
+
+  it("should format prompt with user-assistant-user messages", () => {
+    const prompt = chat().format({
+      messages: [
+        { role: "user", content: "1st user message" },
+        { role: "assistant", content: "assistant message" },
+        { role: "user", content: "2nd user message" },
+      ],
+    });
+
+    expect(prompt).toMatchSnapshot();
+  });
+
+  it("should format prompt with system message and user-assistant-user messages", () => {
+    const prompt = chat().format({
+      system: "you are a chatbot",
+      messages: [
+        { role: "user", content: "1st user message" },
+        { role: "assistant", content: "assistant message" },
+        { role: "user", content: "2nd user message" },
+      ],
+    });
+
+    expect(prompt).toMatchSnapshot();
+  });
+});
diff --git a/packages/modelfusion/src/model-function/generate-text/prompt-template/Phi3PromptTemplate.ts b/packages/modelfusion/src/model-function/generate-text/prompt-template/Phi3PromptTemplate.ts
new file mode 100644
index 000000000..cd8f25f4a
--- /dev/null
+++ b/packages/modelfusion/src/model-function/generate-text/prompt-template/Phi3PromptTemplate.ts
@@ -0,0 +1,116 @@
+import { TextGenerationPromptTemplate } from "../TextGenerationPromptTemplate";
+import { ChatPrompt } from "./ChatPrompt";
+import { validateContentIsString } from "./ContentPart";
+import { InstructionPrompt } from "./InstructionPrompt";
+import { InvalidPromptError } from "./InvalidPromptError";
+
+const END_SEGMENT = "<|end|>";
+
+function segmentStart(role: "system" | "user" | "assistant") {
+  return `<|${role}|>\n`;
+}
+
+function segment(
+  role: "system" | "user" | "assistant",
+  text: string | undefined
+) {
+  return text == null ? "" : `${segmentStart(role)}${text}${END_SEGMENT}\n`;
+}
+
+/**
+ * Formats a text prompt using the Phi3 format.
+ */
+export function text(): TextGenerationPromptTemplate<string, string> {
+  return {
+    stopSequences: [END_SEGMENT],
+    format(prompt) {
+      // prompt and then prefix start of assistant response:
+      return segment("user", prompt) + segmentStart("assistant");
+    },
+  };
+}
+
+/**
+ * Formats an instruction prompt using the Phi3 format.
+ *
+ * Phi3 prompt template:
+ * ```
+ * <|system|>
+ * ${ system prompt }<|end|>
+ * <|user|>
+ * ${ instruction }<|end|>
+ * <|assistant|>
+ * ${response prefix}
+ * ```
+ */
+export function instruction(): TextGenerationPromptTemplate<
+  InstructionPrompt,
+  string
+> {
+  return {
+    stopSequences: [END_SEGMENT],
+    format(prompt) {
+      const instruction = validateContentIsString(prompt.instruction, prompt);
+
+      return (
+        segment("system", prompt.system) +
+        segment("user", instruction) +
+        segmentStart("assistant") +
+        (prompt.responsePrefix ?? "")
+      );
+    },
+  };
+}
+
+/**
+ * Formats a chat prompt using the Phi3 format.
+ *
+ * Phi3 prompt template:
+ * ```
+ * <|system|>
+ * You are a helpful assistant that answers questions about the world.<|end|>
+ * <|user|>
+ * What is the capital of France?<|end|>
+ * <|assistant|>
+ * Paris<|end|>
+ * ```
+ */
+export function chat(): TextGenerationPromptTemplate<ChatPrompt, string> {
+  return {
+    format(prompt) {
+      let text = prompt.system != null ? segment("system", prompt.system) : "";
+
+      for (const { role, content } of prompt.messages) {
+        switch (role) {
+          case "user": {
+            text += segment("user", validateContentIsString(content, prompt));
+            break;
+          }
+          case "assistant": {
+            text += segment(
+              "assistant",
+              validateContentIsString(content, prompt)
+            );
+            break;
+          }
+          case "tool": {
+            throw new InvalidPromptError(
+              "Tool messages are not supported.",
+              prompt
+            );
+          }
+          default: {
+            const _exhaustiveCheck: never = role;
+            throw new Error(`Unsupported role: ${_exhaustiveCheck}`);
+          }
+        }
+      }
+
+      // prefix start of assistant response:
+      text += segmentStart("assistant");
+
+      return text;
+    },
+    stopSequences: [END_SEGMENT],
+  };
+}
diff --git a/packages/modelfusion/src/model-function/generate-text/prompt-template/__snapshots__/Phi3PromptTemplate.test.ts.snap b/packages/modelfusion/src/model-function/generate-text/prompt-template/__snapshots__/Phi3PromptTemplate.test.ts.snap
new file mode 100644
index 000000000..bee05099a
--- /dev/null
+++ b/packages/modelfusion/src/model-function/generate-text/prompt-template/__snapshots__/Phi3PromptTemplate.test.ts.snap
@@ -0,0 +1,62 @@
+// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html
+
+exports[`chat prompt > should format prompt with system message and user-assistant-user messages 1`] = `
+"<|system|>
+you are a chatbot<|end|>
+<|user|>
+1st user message<|end|>
+<|assistant|>
+assistant message<|end|>
+<|user|>
+2nd user message<|end|>
+<|assistant|>
+"
+`;
+
+exports[`chat prompt > should format prompt with user message 1`] = `
+"<|user|>
+user message<|end|>
+<|assistant|>
+"
+`;
+
+exports[`chat prompt > should format prompt with user-assistant-user messages 1`] = `
+"<|user|>
+1st user message<|end|>
+<|assistant|>
+assistant message<|end|>
+<|user|>
+2nd user message<|end|>
+<|assistant|>
+"
+`;
+
+exports[`instruction prompt > should format prompt with instruction 1`] = `
+"<|user|>
+instruction<|end|>
+<|assistant|>
+"
+`;
+
+exports[`instruction prompt > should format prompt with instruction and response prefix 1`] = `
+"<|user|>
+instruction<|end|>
+<|assistant|>
+response prefix"
+`;
+
+exports[`instruction prompt > should format prompt with system and instruction 1`] = `
+"<|system|>
+system<|end|>
+<|user|>
+instruction<|end|>
+<|assistant|>
+"
+`;
+
+exports[`text prompt > should format prompt 1`] = `
+"<|user|>
+prompt<|end|>
+<|assistant|>
+"
+`;
diff --git a/packages/modelfusion/src/model-provider/llamacpp/LlamaCppPrompt.ts b/packages/modelfusion/src/model-provider/llamacpp/LlamaCppPrompt.ts
index 05e14c9d1..021d3f66b 100644
--- a/packages/modelfusion/src/model-provider/llamacpp/LlamaCppPrompt.ts
+++ b/packages/modelfusion/src/model-provider/llamacpp/LlamaCppPrompt.ts
@@ -9,6 +9,7 @@ import * as synthiaPrompt from "../../model-function/generate-text/prompt-templa
 import * as textPrompt from "../../model-function/generate-text/prompt-template/TextPromptTemplate";
 import * as vicunaPrompt from "../../model-function/generate-text/prompt-template/VicunaPromptTemplate";
 import * as LlamaCppBakLLaVA1Prompt from "./LlamaCppBakLLaVA1PromptTemplate";
+import * as phi3Prompt from "../../model-function/generate-text/prompt-template/Phi3PromptTemplate";
 import { LlamaCppCompletionPrompt } from "./LlamaCppCompletionModel";
 
 export function asLlamaCppPromptTemplate(
@@ -79,3 +80,4 @@ export const Alpaca = asLlamaCppTextPromptTemplateProvider(alpacaPrompt);
 export const Synthia = asLlamaCppTextPromptTemplateProvider(synthiaPrompt);
 export const Vicuna = asLlamaCppTextPromptTemplateProvider(vicunaPrompt);
 export const BakLLaVA1 = LlamaCppBakLLaVA1Prompt;
+export const Phi3 = asLlamaCppTextPromptTemplateProvider(phi3Prompt);
diff --git a/packages/modelfusion/src/model-provider/ollama/OllamaCompletionPrompt.ts b/packages/modelfusion/src/model-provider/ollama/OllamaCompletionPrompt.ts
index 738d679ef..1954a29c8 100644
--- a/packages/modelfusion/src/model-provider/ollama/OllamaCompletionPrompt.ts
+++ b/packages/modelfusion/src/model-provider/ollama/OllamaCompletionPrompt.ts
@@ -8,6 +8,7 @@ import { TextGenerationPromptTemplateProvider } from "../../model-function/gener
 import * as synthiaPrompt from "../../model-function/generate-text/prompt-template/SynthiaPromptTemplate";
 import * as textPrompt from "../../model-function/generate-text/prompt-template/TextPromptTemplate";
 import * as vicunaPrompt from "../../model-function/generate-text/prompt-template/VicunaPromptTemplate";
+import * as phi3Prompt from "../../model-function/generate-text/prompt-template/Phi3PromptTemplate";
 import { OllamaCompletionPrompt } from "./OllamaCompletionModel";
 
 export function asOllamaCompletionPromptTemplate(
@@ -83,3 +84,4 @@ export const Synthia =
   asOllamaCompletionTextPromptTemplateProvider(synthiaPrompt);
 export const Vicuna =
   asOllamaCompletionTextPromptTemplateProvider(vicunaPrompt);
+export const Phi3 = asOllamaCompletionTextPromptTemplateProvider(phi3Prompt);
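A minimal usage sketch for the new template, assuming the provider wiring above is surfaced as `ollama.prompt.Phi3` in the same way as the existing Vicuna/Synthia providers; the model name `"phi3"` and the `raw: true` setting are assumptions for illustration, not part of this diff:

```ts
import { generateText, ollama } from "modelfusion";

const story = await generateText({
  model: ollama
    .CompletionTextGenerator({
      model: "phi3", // assumed local Ollama model name
      promptTemplate: ollama.prompt.Phi3, // provider added in this diff
      raw: true, // assumed: let the template, not the server, render the prompt
    })
    .withTextPrompt(), // plain string prompts go through text()
  prompt: "Write a one-sentence story about a robot learning to code.",
});

console.log(story); // generation stops at the <|end|> stop sequence
```

The template functions can also be exercised directly: `instruction().format({ instruction: "instruction", responsePrefix: "response prefix" })` yields exactly the string pinned by the snapshot above, ending in `<|assistant|>\nresponse prefix` so the model continues from the prefix.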