diff --git a/game/src/ai/prompts/parsing_prompts.rs b/game/src/ai/prompts/parsing_prompts.rs index 3769efb..79d8a01 100644 --- a/game/src/ai/prompts/parsing_prompts.rs +++ b/game/src/ai/prompts/parsing_prompts.rs @@ -1,4 +1,7 @@ -use crate::{ai::convo::AiPrompt, models::commands::ParsedCommands}; +use crate::{ + ai::convo::AiPrompt, + models::commands::{ParsedCommands, VerbsResponse}, +}; pub const INTRO_PROMPT: &'static str = r#" [INST] @@ -69,18 +72,6 @@ Check the generated commands for coherence according to these instructions. Your [/INST] "#; -pub const FIND_VERBS_BNF: &str = r#" -root ::= Verbs -Verbs ::= "{" ws "\"verbs\":" ws stringlist "}" -Verbslist ::= "[]" | "[" ws Verbs ("," ws Verbs)* "]" -string ::= "\"" ([^"]*) "\"" -boolean ::= "true" | "false" -ws ::= [ \t\n]* -number ::= [0-9]+ "."? [0-9]* -stringlist ::= "[" ws "]" | "[" ws string ("," ws string)* ws "]" -numberlist ::= "[" ws "]" | "[" ws string ("," ws number)* ws "]" -"#; - pub const FIND_VERBS_PROMPT: &'static str = " [INST] @@ -91,7 +82,7 @@ Text: `{}` pub fn intro_prompt(cmd: &str) -> AiPrompt { let prompt = INTRO_PROMPT.replace("{}", cmd); - AiPrompt::new_with_grammar(&prompt, &ParsedCommands::to_grammar()) + AiPrompt::new_with_grammar(&prompt, ParsedCommands::to_grammar()) } pub fn continuation_prompt(cmd: &str) -> AiPrompt { @@ -102,14 +93,14 @@ pub fn continuation_prompt(cmd: &str) -> AiPrompt { prompt.push_str("[/INST]"); - AiPrompt::new_with_grammar(&prompt, &ParsedCommands::to_grammar()) + AiPrompt::new_with_grammar(&prompt, ParsedCommands::to_grammar()) } pub fn coherence_prompt() -> AiPrompt { - AiPrompt::new_with_grammar(COHERENCE_PROMPT, &ParsedCommands::to_grammar()) + AiPrompt::new_with_grammar(COHERENCE_PROMPT, ParsedCommands::to_grammar()) } pub fn find_verbs_prompt(cmd: &str) -> AiPrompt { let prompt = FIND_VERBS_PROMPT.replace("{}", cmd); - AiPrompt::new_with_grammar(&prompt, FIND_VERBS_BNF) + AiPrompt::new_with_grammar(&prompt, VerbsResponse::to_grammar()) } diff --git a/game/src/models/commands.rs b/game/src/models/commands.rs index ec165a0..e3012cb 100644 --- a/game/src/models/commands.rs +++ b/game/src/models/commands.rs @@ -40,7 +40,7 @@ pub struct ParsedCommand { pub using: String, } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, Gbnf)] pub struct VerbsResponse { pub verbs: Vec, } diff --git a/gbnf_derive/src/lib.rs b/gbnf_derive/src/lib.rs index 7d206ba..9d284aa 100644 --- a/gbnf_derive/src/lib.rs +++ b/gbnf_derive/src/lib.rs @@ -68,8 +68,10 @@ fn generate_gbnf(input: TokenStream, create_struct: bool) -> TokenStream { #struct_frag impl #struct_name { - pub fn to_grammar() -> String { - Self::to_gbnf().as_complex().to_grammar() + pub fn to_grammar() -> &'static str { + use std::sync::OnceLock; + static GRAMMAR: OnceLock = OnceLock::new(); + GRAMMAR.get_or_init(|| Self::to_gbnf().as_complex().to_grammar()) } }