Implement derive macro to convert structs to GBNF rules.

This is the initial implementation of a derive macro that converts
structs into GBNF grammars. It simplifies the code by letting us drop
all the hardcoded GBNF strings and prevents the errors that come from
manually editing or copying those strings, among other benefits.

The main purpose of this implementation is to lay the foundation for
generating hyper-specific GBNF rules that let us limit LLM output to
specific UUIDs. The LLM can't produce a bogus exit or entity ID if its
response is constrained to a specific list of UUIDs.
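
For example (a sketch of the intended usage; the two-field struct shown here is
abbreviated and the grammar text is illustrative):

    use gbnf::prelude::*;
    use gbnf_derive::Gbnf;

    #[derive(Gbnf)]
    pub struct ParsedCommand {
        pub verb: String,
        pub target: String,
    }

    // ParsedCommand::to_grammar() then produces roughly:
    // root ::= ParsedCommand
    // ParsedCommand ::= "{" ws "\"verb\":" ws string "," ws "\"target\":" ws string "}"
    // ws ::= [ \t\n]*
    // string ::= "\"" ([^"]*) "\""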
projectmoon 2024-02-01 12:41:08 +01:00
parent c6f10f7a61
commit 1e80ae508e
46 changed files with 4147 additions and 338 deletions

2
.gitignore vendored

@@ -1,4 +1,4 @@
/target
**/target
surreal.db/
todo.org
config.toml

662
Cargo.lock generated

File diff suppressed because it is too large


@@ -1,37 +1,7 @@
[package]
name = "ai-game"
version = "0.1.0"
edition = "2021"
[workspace]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
tokio = { version = "1", features = ["full"] }
anyhow = "1.0.75"
futures = "0.3"
eventsource-client = "0.11.0"
progenitor = { git = "https://github.com/oxidecomputer/progenitor" }
progenitor-client = { git = "https://github.com/oxidecomputer/progenitor" }
reqwest = { version = "0.11", features = ["json", "stream"] }
serde = { version = "1.0", features = ["derive", "rc"] }
serde_json = "1.0"
async-trait = "0.1.74"
reedline = "0.27.1"
async-recursion = "1.0.5"
thiserror = "1.0.53"
strum = {version = "0.25", features = [ "derive" ] }
uuid = {version = "1.6.1", features = [ "std", "v7", "fast-rng" ] }
polodb_core = "4.4.0"
arangors = "0.5.4"
itertools = "0.12.0"
crossterm = "0.27.0"
textwrap = "0.16.0"
config = "0.13.4"
tabled = "0.15.0"
[build-dependencies]
prettyplease = "0.1.25"
progenitor = { git = "https://github.com/oxidecomputer/progenitor" }
progenitor-client = { git = "https://github.com/oxidecomputer/progenitor" }
serde_json = "1.0"
syn = "1.0"
members = [
"game",
"gbnf",
"gbnf_derive"
]

BIN
game.db

Binary file not shown.

2836
game/Cargo.lock generated Normal file

File diff suppressed because it is too large

39
game/Cargo.toml Normal file

@@ -0,0 +1,39 @@
[package]
name = "ai-game"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
tokio = { version = "1", features = ["full"] }
anyhow = "1.0.75"
futures = "0.3"
eventsource-client = "0.11.0"
progenitor = { git = "https://github.com/oxidecomputer/progenitor" }
progenitor-client = { git = "https://github.com/oxidecomputer/progenitor" }
reqwest = { version = "0.11", features = ["json", "stream"] }
serde = { version = "1.0", features = ["derive", "rc"] }
serde_json = "1.0"
async-trait = "0.1.74"
reedline = "0.27.1"
async-recursion = "1.0.5"
thiserror = "1.0.53"
strum = {version = "0.25", features = [ "derive" ] }
uuid = {version = "1.6.1", features = [ "std", "v7", "fast-rng" ] }
polodb_core = "4.4.0"
arangors = "0.5.4"
itertools = "0.12.0"
crossterm = "0.27.0"
textwrap = "0.16.0"
config = "0.13.4"
tabled = "0.15.0"
gbnf = { path = "../gbnf" }
gbnf_derive = { path = "../gbnf_derive" }
[build-dependencies]
prettyplease = "0.1.25"
progenitor = { git = "https://github.com/oxidecomputer/progenitor" }
progenitor-client = { git = "https://github.com/oxidecomputer/progenitor" }
serde_json = "1.0"
syn = "1.0"


@@ -0,0 +1,22 @@
// use crate::models::commands::ChangeScene;
// use super::Gbnf;
// // TODO put all events in one place and change root based on event
const CHANGE_SCENE_BNF: &'static str = r#"
root ::= ChangeScene
ChangeScene ::= "{" ws "\"scenekey\":" ws string "}"
ChangeScenelist ::= "[]" | "[" ws ChangeScene ("," ws ChangeScene)* "]"
string ::= "\"" ([^"]*) "\""
boolean ::= "true" | "false"
ws ::= [ \t\n]*
number ::= [0-9]+ "."? [0-9]*
stringlist ::= "[" ws "]" | "[" ws string ("," ws string)* ws "]"
numberlist ::= "[" ws "]" | "[" ws string ("," ws number)* ws "]"
"#;
// impl Gbnf for ChangeScene {
// fn to_gbnf() -> String {
// CHANGE_SCENE_BNF.to_string()
// }
// }

0
game/src/ai/gbnf/mod.rs Normal file


@@ -1,4 +1,5 @@
pub(self) mod coherence;
pub mod gbnf;
pub mod convo;
pub mod generator;
pub mod prompts;


@@ -0,0 +1,11 @@
use super::tables::exit_table;
use crate::models::world::scenes::Scene;
pub(super) const CHANGE_SCENE: &'static str = r#"
The player is moving to a new scene. Pick the correct scene key from the exits table, based on the place the player wants to go.
"#;
pub(super) fn change_scene(scene: &Scene) -> String {
// The exit table is currently included at the beginning of the prompt.
CHANGE_SCENE.replacen("{EXIT_TABLE}", &exit_table(&scene.exits).to_string(), 1)
}


@@ -1,18 +1,4 @@
use crate::ai::convo::AiPrompt;
pub const COMMAND_BNF: &str = r#"
root ::= Commands
Command ::= "{" ws "\"verb\":" ws string "," ws "\"target\":" ws string "," ws "\"location\":" ws string "," ws "\"using\":" ws string "}"
Commandlist ::= "[]" | "[" ws Command ("," ws Command)* "]"
Commands ::= "{" ws "\"commands\":" ws Commandlist "," ws "\"count\":" ws number "}"
Commandslist ::= "[]" | "[" ws Commands ("," ws Commands)* "]"
string ::= "\"" ([^"]*) "\""
boolean ::= "true" | "false"
ws ::= [ \t\n]*
number ::= [0-9]+ "."? [0-9]*
stringlist ::= "[" ws "]" | "[" ws string ("," ws string)* ws "]"
numberlist ::= "[" ws "]" | "[" ws string ("," ws number)* ws "]"
"#;
use crate::{ai::convo::AiPrompt, models::commands::ParsedCommands};
pub const INTRO_PROMPT: &'static str = r#"
[INST]
@@ -104,8 +90,8 @@ Text: `{}`
[/INST]";
pub fn intro_prompt(cmd: &str) -> AiPrompt {
let mut prompt = INTRO_PROMPT.replace("{}", cmd);
AiPrompt::new_with_grammar(&prompt, COMMAND_BNF)
let prompt = INTRO_PROMPT.replace("{}", cmd);
AiPrompt::new_with_grammar(&prompt, &ParsedCommands::to_grammar())
}
pub fn continuation_prompt(cmd: &str) -> AiPrompt {
@@ -116,11 +102,11 @@ pub fn continuation_prompt(cmd: &str) -> AiPrompt {
prompt.push_str("[/INST]");
AiPrompt::new_with_grammar(&prompt, COMMAND_BNF)
AiPrompt::new_with_grammar(&prompt, &ParsedCommands::to_grammar())
}
pub fn coherence_prompt() -> AiPrompt {
AiPrompt::new_with_grammar(COHERENCE_PROMPT, COMMAND_BNF)
AiPrompt::new_with_grammar(COHERENCE_PROMPT, &ParsedCommands::to_grammar())
}
pub fn find_verbs_prompt(cmd: &str) -> AiPrompt {


@@ -0,0 +1,96 @@
use crate::models::commands::{
CommandEvent, CommandEventType, EventConversionFailure, ParsedCommand, RawCommandExecution,
};
use crate::models::world::items::Item;
use crate::models::world::people::Person;
use crate::models::world::scenes::{Exit, Prop, Scene, Stage};
use crate::models::Insertable;
use itertools::Itertools;
use tabled::settings::Style;
use tabled::{Table, Tabled};
const UNKNOWN: &'static str = "unknown";
const PERSON: &'static str = "person";
const ITEM: &'static str = "item";
const PROP: &'static str = "prop";
const NO_KEY: &'static str = "n/a";
#[derive(Tabled)]
pub struct EntityTableRow<'a> {
name: &'a str,
#[tabled(rename = "type")]
entity_type: &'a str,
key: &'a str,
}
impl<'a> From<&'a Person> for EntityTableRow<'a> {
fn from(value: &'a Person) -> Self {
EntityTableRow {
name: &value.name,
key: value.key().unwrap_or(UNKNOWN),
entity_type: PERSON,
}
}
}
impl<'a> From<&'a Item> for EntityTableRow<'a> {
fn from(value: &'a Item) -> Self {
EntityTableRow {
name: &value.name,
key: value.key().unwrap_or(UNKNOWN),
entity_type: ITEM,
}
}
}
impl<'a> From<&'a Prop> for EntityTableRow<'a> {
fn from(value: &'a Prop) -> Self {
EntityTableRow {
name: &value.name,
entity_type: PROP,
key: NO_KEY,
}
}
}
#[derive(Tabled)]
pub struct ExitTableRow<'a> {
pub name: &'a str,
pub direction: &'a str,
pub scene_key: &'a str,
pub region: &'a str,
}
impl<'a> From<&'a Exit> for ExitTableRow<'a> {
fn from(value: &'a Exit) -> Self {
ExitTableRow {
name: &value.name,
direction: &value.direction,
scene_key: &value.scene_key,
region: &value.region,
}
}
}
pub(super) fn entity_table(stage: &Stage) -> Table {
let people = stage.people.iter().map_into::<EntityTableRow>();
let items = stage.items.iter().map_into::<EntityTableRow>();
let props = stage.scene.props.iter().map_into::<EntityTableRow>();
let entities = people.chain(items).chain(props);
let mut entities_table = Table::new(entities);
entities_table.with(Style::markdown());
entities_table
}
pub(super) fn exit_table<'a, I>(exits: I) -> Table
where
I: IntoIterator<Item = &'a Exit>,
{
let exits = exits.into_iter();
let mut table = Table::new(exits.map_into::<ExitTableRow>());
table.with(Style::markdown());
table
}


@@ -1,10 +1,10 @@
use ai::logic::AiLogic;
use anyhow::Result;
use config::Config;
use game_loop::GameLoop;
use ai::logic::AiLogic;
use models::world::scenes::{root_scene_id, Stage};
use state::GameState;
use std::{io::stdout, rc::Rc, time::Duration, str::FromStr};
use std::{io::stdout, rc::Rc, str::FromStr, time::Duration};
use arangors::Connection;
@@ -68,7 +68,6 @@ fn load_config() -> Result<GameConfig> {
.build()
.unwrap();
let kobold_endpoint = settings
.get::<Option<String>>("connection.kobold_endpoint")?
.unwrap_or("http://127.0.0.1:5001/api".to_string());


@@ -3,6 +3,8 @@ use std::fmt::Display;
use serde::{Deserialize, Serialize};
use strum::{EnumString, EnumVariantNames};
use thiserror::Error;
use gbnf::prelude::*;
use gbnf_derive::Gbnf;
/// Stored in the database to bypass AI 'parsing' when possible.
#[derive(Debug, Serialize, Deserialize, Clone)]
@@ -12,7 +14,7 @@ pub struct CachedParsedCommand {
pub commands: ParsedCommands,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Serialize, Deserialize, Clone, Gbnf)]
pub struct ParsedCommands {
#[serde(default)]
pub original: String, // The original text entered by the player, set by code.
@@ -30,7 +32,7 @@ impl ParsedCommands {
}
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Serialize, Deserialize, Clone, Gbnf)]
pub struct ParsedCommand {
pub verb: String,
pub target: String,

12
gbnf/Cargo.toml Normal file

@@ -0,0 +1,12 @@
[package]
name = "gbnf"
version = "0.1.0"
edition = "2021"
[dependencies]
auto_impl = "1.1.2"
syn = { version = "2.0", features = [ "derive", "full", "parsing", "printing", "visit", "visit-mut", "clone-impls", "proc-macro" ] }
quote = "1.0.35"
itertools = "0.12.0"
serde = "1.0.196"
serde_derive = "1.0.196"

386
gbnf/src/lib.rs Normal file

@@ -0,0 +1,386 @@
extern crate proc_macro;
use itertools::Itertools;
use serde::de::DeserializeOwned;
pub mod prelude {
pub use crate::gbnf_field;
pub use crate::gbnf_field_type;
pub use crate::AsGbnf;
pub use crate::AsGrammar;
pub use crate::GbnfComplex;
pub use crate::GbnfField;
pub use crate::GbnfFieldType;
pub use crate::GbnfPrimitive;
pub use crate::GbnfRule;
pub use crate::GbnfToken;
}
// TODOs for this implementation:
// 1. Move primitive definitions (string, bool, etc) to the bottom of generated grammar.
// 2. Implement support for limited values.
// 3. Generate static strings for the grammar rules where possible.
// 4. Properly support optional types (right now they map to non-optional values).
// Converts GBNF definitions (through the types below) into the grammar
// rules.
pub trait AsGrammar {
fn rules(&self) -> Vec<GbnfRule>;
fn token(&self) -> String;
}
/// Trait for regular types to implement to convert themselves to a
/// GBNF value.
pub trait AsGbnf {
fn to_gbnf() -> GbnfFieldType;
}
macro_rules! define_field_type {
($type:ty, $gbnf_type:expr) => {
impl AsGbnf for $type {
fn to_gbnf() -> GbnfFieldType {
$gbnf_type
}
}
};
}
macro_rules! define_array_blanket_impl {
($len:expr) => {
impl<T> AsGbnf for [T; $len]
where
T: AsGbnf + DeserializeOwned,
{
fn to_gbnf() -> GbnfFieldType {
use GbnfFieldType::*;
match <T as AsGbnf>::to_gbnf() {
Primitive(primitive_type) => PrimitiveList(primitive_type),
OptionalPrimitive(primitive_type) => PrimitiveList(primitive_type),
Complex(complex_type) => ComplexList(complex_type),
OptionalComplex(complex_type) => ComplexList(complex_type),
Limited(_) => panic!("limited values are not yet supported"),
ComplexList(_) | PrimitiveList(_) => panic!("nested lists not supported"),
}
}
}
};
}
#[macro_export]
macro_rules! gbnf_field_type {
($type:ty) => {
<$type as AsGbnf>::to_gbnf()
};
}
#[macro_export]
macro_rules! gbnf_field {
($field_name:literal, $field_type:ty) => {
GbnfField {
field_name: $field_name.to_string(),
field_type: gbnf_field_type!($field_type),
}
};
}
// Implemented field type mappings for common rust types.
define_field_type!(i16, GbnfFieldType::Primitive(GbnfPrimitive::Number));
define_field_type!(u16, GbnfFieldType::Primitive(GbnfPrimitive::Number));
define_field_type!(i32, GbnfFieldType::Primitive(GbnfPrimitive::Number));
define_field_type!(u32, GbnfFieldType::Primitive(GbnfPrimitive::Number));
define_field_type!(i64, GbnfFieldType::Primitive(GbnfPrimitive::Number));
define_field_type!(u64, GbnfFieldType::Primitive(GbnfPrimitive::Number));
define_field_type!(f32, GbnfFieldType::Primitive(GbnfPrimitive::Number));
define_field_type!(f64, GbnfFieldType::Primitive(GbnfPrimitive::Number));
define_field_type!(usize, GbnfFieldType::Primitive(GbnfPrimitive::Number));
define_field_type!(bool, GbnfFieldType::Primitive(GbnfPrimitive::Boolean));
define_field_type!(String, GbnfFieldType::Primitive(GbnfPrimitive::String));
define_field_type!(char, GbnfFieldType::Primitive(GbnfPrimitive::String));
// Macro-based blanket impls for arrays
define_array_blanket_impl!(1);
define_array_blanket_impl!(2);
define_array_blanket_impl!(3);
define_array_blanket_impl!(4);
define_array_blanket_impl!(5);
define_array_blanket_impl!(6);
define_array_blanket_impl!(7);
define_array_blanket_impl!(8);
define_array_blanket_impl!(9);
define_array_blanket_impl!(10);
define_array_blanket_impl!(11);
define_array_blanket_impl!(12);
define_array_blanket_impl!(13);
define_array_blanket_impl!(14);
define_array_blanket_impl!(15);
define_array_blanket_impl!(16);
// Blanket implementations to cover more types
impl<T> AsGbnf for Vec<T>
where
T: AsGbnf,
{
fn to_gbnf() -> GbnfFieldType {
use GbnfFieldType::*;
match <T as AsGbnf>::to_gbnf() {
Primitive(primitive_type) => PrimitiveList(primitive_type),
OptionalPrimitive(primitive_type) => PrimitiveList(primitive_type),
Complex(complex_type) => ComplexList(complex_type),
OptionalComplex(complex_type) => ComplexList(complex_type),
Limited(_) => panic!("limited values not yet supported"),
ComplexList(_) | PrimitiveList(_) => panic!("nested lists not supported"),
}
}
}
impl<T> AsGbnf for Option<T>
where
T: AsGbnf,
{
fn to_gbnf() -> GbnfFieldType {
use GbnfFieldType::*;
match <T as AsGbnf>::to_gbnf() {
Primitive(primitive_type) => OptionalPrimitive(primitive_type),
Complex(complex_type) => OptionalComplex(complex_type),
OptionalPrimitive(_) | OptionalComplex(_) => panic!("nested options are not allowed"),
Limited(_) => panic!("limited values not yet supported"),
_ => panic!("optional type cannot be a list"),
}
}
}
// Actual GBNF rule itself. Holds rule text for dedup.
#[derive(Debug, Clone, Eq, Hash, PartialEq)]
pub struct GbnfRule {
name: String,
text: String,
}
impl GbnfRule {
pub fn new(token: String, rule_text: String) -> GbnfRule {
GbnfRule {
name: token,
text: rule_text,
}
}
pub fn single(token: String, rule_text: String) -> Vec<GbnfRule> {
vec![GbnfRule::new(token, rule_text)]
}
}
/// Tokens in the GBNF rule.
pub enum GbnfToken {
Space,
}
impl GbnfToken {
pub(self) const SPACE: &'static str = r#"[ \t\n]*"#;
}
impl AsGrammar for GbnfToken {
fn rules(&self) -> Vec<GbnfRule> {
match self {
Self::Space => GbnfRule::single(self.token(), Self::SPACE.to_string()),
}
}
fn token(&self) -> String {
match self {
Self::Space => "ws".to_string(),
}
}
}
/// Represents a primitive value in the GBNF, the simplest possible
/// value a type can hold.
#[derive(Debug)]
pub enum GbnfPrimitive {
String,
Boolean,
Number,
}
impl GbnfPrimitive {
pub(self) const STRING: &'static str = r#""\"" ([^"]*) "\"""#;
pub(self) const BOOLEAN: &'static str = r#""true" | "false""#;
pub(self) const NUMBER: &'static str = r#"[0-9]+ "."? [0-9]*"#;
}
impl AsGrammar for GbnfPrimitive {
/// Output the raw GBNF rule of this primitive.
fn rules(&self) -> Vec<GbnfRule> {
let rule_text = match self {
Self::Boolean => Self::BOOLEAN,
Self::Number => Self::NUMBER,
Self::String => Self::STRING,
};
GbnfRule::single(self.token(), rule_text.to_string())
}
/// Output the token name of the GBNF rule (to refer to in other
/// rules).
fn token(&self) -> String {
String::from(match self {
Self::Boolean => "boolean",
Self::Number => "number",
Self::String => "string",
})
}
}
/// Categorize all types of fields that the generated grammar can
/// handle.
#[derive(Debug)]
pub enum GbnfFieldType {
/// A single property on the type, e.g. myField: i32
Primitive(GbnfPrimitive),
/// Can be a value or null.
OptionalPrimitive(GbnfPrimitive),
/// A list/vec of primitive types.
PrimitiveList(GbnfPrimitive),
/// A complex type, with its own properties.
Complex(GbnfComplex),
/// Can be a value or null.
OptionalComplex(GbnfComplex),
/// A list/vec of complex types.
ComplexList(GbnfComplex),
/// A single property field, but with limited values allowed,
/// constrained by the primitive type.
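/// For example, a string field restricted to a known set of UUID values.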
Limited(GbnfPrimitive),
}
impl GbnfFieldType {
pub fn as_complex(self) -> GbnfComplex {
match self {
GbnfFieldType::Complex(complex) => complex,
_ => panic!("Not a GBNF complex type"),
}
}
}
/// Connect a property name and a field type to generate a GBNF rule.
#[derive(Debug)]
pub struct GbnfField {
pub field_name: String,
pub field_type: GbnfFieldType,
}
impl GbnfField {
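// For a list of strings, for example, the template below expands to: "[]" | "[" ws string ("," ws string)* "]"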
fn list_rule(field_type: &(impl AsGrammar + ?Sized)) -> String {
r#""[]" | "[" {SPACE} {TYPE_NAME} ("," {SPACE} {TYPE_NAME})* "]""#
.replace("{LIST_NAME}", "")
.replace("{SPACE}", &GbnfToken::Space.token())
.replace("{TYPE_NAME}", &field_type.token())
}
fn list_rules<T: AsGrammar>(&self, f: &T) -> Vec<GbnfRule> {
// Create two rules: one for the list and one for its actual type.
let list_rule = GbnfRule::new(self.token(), Self::list_rule(f));
let mut rules = vec![list_rule];
rules.append(&mut f.rules());
rules
}
}
impl AsGrammar for GbnfField {
fn token(&self) -> String {
match &self.field_type {
GbnfFieldType::Primitive(f) => f.token(),
GbnfFieldType::OptionalPrimitive(f) => f.token(),
GbnfFieldType::PrimitiveList(f) => format!("{}List", f.token()),
GbnfFieldType::Complex(f) => f.token(),
GbnfFieldType::OptionalComplex(f) => f.token(),
GbnfFieldType::ComplexList(f) => format!("{}List", f.token()),
GbnfFieldType::Limited(f) => f.token(),
_ => "".to_string(),
}
}
// TODO need to implement optional rules, which probably involves
// wrapping the primitive rule in parens, and then ORing to null.
fn rules(&self) -> Vec<GbnfRule> {
match &self.field_type {
GbnfFieldType::Complex(f) => f.rules(),
GbnfFieldType::OptionalComplex(f) => f.rules(),
GbnfFieldType::ComplexList(f) => self.list_rules(f),
GbnfFieldType::Primitive(f) => f.rules(),
GbnfFieldType::OptionalPrimitive(f) => f.rules(),
GbnfFieldType::PrimitiveList(f) => self.list_rules(f),
GbnfFieldType::Limited(f) => f.rules(),
}
}
}
/// The complex type is a direct mapping from a supported Rust struct,
/// and is also used to generate the root of a GBNF grammar.
#[derive(Debug)]
pub struct GbnfComplex {
pub name: String,
pub fields: Vec<GbnfField>,
}
impl GbnfComplex {
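// Renders the complete grammar: a root rule aliasing this complex type, followed by the
// deduplicated rules for the type, its fields, and the primitives/tokens they reference.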
pub fn to_grammar(&self) -> String {
let mut rules = vec![GbnfRule::new("root".to_string(), self.name.clone())];
rules.append(&mut self.rules());
for field in &self.fields {
rules.append(&mut field.rules());
}
rules
.into_iter()
.unique()
.map(|rule| format!("{} ::= {}", rule.name, rule.text))
.join("\n")
}
}
impl AsGrammar for GbnfComplex {
fn rules(&self) -> Vec<GbnfRule> {
// This will output the full set of rules for the complex type.
// Deduplication handled later.
let mut rule = String::new();
rule.push_str(r#""{" "#);
let field_rules_text = self
.fields
.iter()
.map(|field| {
let mut text = String::new();
text.push_str(&GbnfToken::Space.token());
text.push_str(" ");
text.push_str(&format!(
r#""\"{}\":" {} {}"#,
field.field_name,
GbnfToken::Space.token(),
field.token(),
));
text
})
.join(r#" "," "#);
rule.push_str(&field_rules_text);
rule.push_str(r#" "}""#);
let mut rules = GbnfRule::single(self.token(), rule);
rules.append(&mut GbnfToken::Space.rules());
rules
}
fn token(&self) -> String {
self.name.clone()
}
}

14
gbnf_derive/Cargo.toml Normal file

@@ -0,0 +1,14 @@
[package]
name = "gbnf_derive"
version = "0.1.0"
edition = "2021"
[lib]
proc-macro = true
[dependencies]
auto_impl = "1.1.2"
syn = { version = "2.0", features = [ "derive", "full", "parsing", "printing", "visit", "visit-mut", "clone-impls", "proc-macro" ] }
quote = "1.0.35"
itertools = "0.12.0"
gbnf = { path = "../gbnf" }

110
gbnf_derive/src/lib.rs Normal file

@@ -0,0 +1,110 @@
use proc_macro::{Span, TokenStream};
use quote::{quote, ToTokens};
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{braced, parse_macro_input};
use syn::{DeriveInput, Field, Ident, LitStr, Token};
#[derive(Debug)]
struct GbnfStructDef {
name: Ident,
fields: Punctuated<Field, Token![,]>,
}
impl Parse for GbnfStructDef {
fn parse(input: ParseStream) -> syn::Result<Self> {
// Discard tokens we don't care about.
let _: Option<Token![pub]> = input.parse()?;
let _: Option<Token![struct]> = input.parse()?;
let content;
let name: Ident = input.parse()?;
let _ = braced!(content in input);
Ok(GbnfStructDef {
name,
fields: content.parse_terminated(Field::parse_named, Token![,])?,
})
}
}
fn generate_gbnf(input: TokenStream, create_struct: bool) -> TokenStream {
// To define complex types, we take a struct into the macro, and
// then output a bunch of calls to gbnf_field (wrapped in gbnf
// complex).
// We could also generate the entire complex type now during macro
// run, and then shove the resulting GBNF rule into the type as a
// static string.
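// Rough sketch of the expansion, for a hypothetical `struct Foo { bar: String }`: the
// generated AsGbnf impl returns GbnfFieldType::Complex(GbnfComplex { name: "Foo".into(),
// fields: vec![gbnf_field!("bar", String)] }), and Foo::to_grammar() then renders that
// complex type into the final grammar string.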
if let Ok(expr_struct) = syn::parse::<GbnfStructDef>(input.clone()) {
let struct_name_str = LitStr::new(&expr_struct.name.to_string(), Span::call_site().into());
let struct_name = expr_struct.name;
let fields = expr_struct.fields.iter();
let gbnfs: Vec<_> = expr_struct
.fields
.iter()
.map(|field| {
let field_type = &field.ty;
let field_ident = field
.ident
.clone()
.map(|i| i.to_string())
.map(|field_name| LitStr::new(&field_name, Span::call_site().into()))
.expect("no ident");
quote! { gbnf_field!(#field_ident, #field_type) }
})
.collect();
let struct_frag = if create_struct {
quote! {
pub struct #struct_name {
#(#fields),*
}
}
} else {
quote! {}
};
let code = quote! {
#struct_frag
impl #struct_name {
pub fn to_grammar() -> String {
Self::to_gbnf().as_complex().to_grammar()
}
}
impl AsGbnf for #struct_name {
fn to_gbnf() -> gbnf::GbnfFieldType {
GbnfFieldType::Complex(
GbnfComplex {
name: String::from(#struct_name_str),
fields: vec![#(#gbnfs),*]
}
)
}
}
};
code.into()
} else {
panic!("Can only generate GBNF from structs (pub or private)");
}
}
/// Create a GBNF complex type as a Rust struct.
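/// For example (a sketch; `NewScene` and its field are hypothetical):
/// `gbnf_complex! { pub struct NewScene { name: String } }` emits the struct definition
/// along with its `AsGbnf` impl and a `to_grammar()` helper.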
#[proc_macro]
pub fn gbnf_complex(input: TokenStream) -> TokenStream {
generate_gbnf(input, true)
}
/// Add the ability to convert a Rust type into a GBNF grammar.
#[proc_macro_derive(Gbnf)]
pub fn gbnf(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
generate_gbnf(input.to_token_stream().into(), false)
}

217
src/ai/gbnf/mod.rs Normal file

@@ -0,0 +1,217 @@
extern crate proc_macro;
use auto_impl::auto_impl;
use itertools::Itertools;
mod events;
// Actual GBNF rule itself. Holds rule text for dedup.
#[derive(Debug, Clone, Eq, Hash, PartialEq)]
pub struct GbnfRule {
name: String,
text: String,
}
impl GbnfRule {
pub fn new(token: String, rule_text: String) -> GbnfRule {
GbnfRule {
name: token,
text: rule_text,
}
}
pub fn single(token: String, rule_text: String) -> Vec<GbnfRule> {
vec![GbnfRule::new(token, rule_text)]
}
}
// token() returns the GBNF identifier for the rule.
// rules() returns the rule text itself.
#[auto_impl(&, Box)]
pub trait TokenAndRule {
fn rules(&self) -> Vec<GbnfRule>;
fn token(&self) -> String;
}
pub enum GbnfToken {
Space,
}
impl GbnfToken {
pub(self) const SPACE: &'static str = r#"[ \t\n]*"#;
}
impl TokenAndRule for GbnfToken {
fn rules(&self) -> Vec<GbnfRule> {
match self {
Self::Space => GbnfRule::single(self.token(), Self::SPACE.to_string()),
}
}
fn token(&self) -> String {
match self {
Self::Space => "ws".to_string(),
}
}
}
#[derive(Debug)]
pub enum GbnfPrimitive {
String,
Boolean,
Number,
}
impl GbnfPrimitive {
pub(self) const STRING: &'static str = r#""\"" ([^"]*) "\"""#;
pub(self) const BOOLEAN: &'static str = r#""true" | "false""#;
pub(self) const NUMBER: &'static str = r#"[0-9]+ "."? [0-9]*"#;
}
impl TokenAndRule for GbnfPrimitive {
/// Output the raw GBNF rule of this primitive.
fn rules(&self) -> Vec<GbnfRule> {
let rule_text = match self {
Self::Boolean => Self::BOOLEAN,
Self::Number => Self::NUMBER,
Self::String => Self::STRING,
};
GbnfRule::single(self.token(), rule_text.to_string())
}
/// Output the token name of the GBNF rule (to refer to in other
/// rules).
fn token(&self) -> String {
String::from(match self {
Self::Boolean => "boolean",
Self::Number => "number",
Self::String => "string",
})
}
}
#[derive(Debug)]
pub enum FieldType {
/// A single property on the type, e.g. myField: i32
Primitive(GbnfPrimitive),
/// A complex property, with its own properties.
Complex(GbnfType),
/// A list/vec of primitive types.
PrimitiveList(GbnfPrimitive),
/// A list/vec of complex types.
ComplexList(GbnfType),
/// A single property field, but with limited values allowed,
/// constrained by the primitive type.
Limited(GbnfPrimitive),
}
#[derive(Debug)]
pub struct GbnfField {
pub field_name: String,
pub field_type: FieldType,
}
#[derive(Debug)]
pub struct GbnfType {
pub name: String,
pub fields: Vec<GbnfField>,
}
impl GbnfField {
fn list_rule(field_type: &(impl TokenAndRule + ?Sized)) -> String {
r#""[]" | "[" {SPACE} {TYPE_NAME} ("," {SPACE} {TYPE_NAME})* "]""#
.replace("{LIST_NAME}", "")
.replace("{SPACE}", &GbnfToken::Space.token())
.replace("{TYPE_NAME}", &field_type.token())
}
fn list_rules<T: TokenAndRule>(&self, f: &T) -> Vec<GbnfRule> {
// Create two rules: one for the list and one for its actual type.
let list_rule = GbnfRule::new(self.token(), Self::list_rule(f));
let mut rules = vec![list_rule];
rules.append(&mut f.rules());
rules
}
}
impl TokenAndRule for GbnfField {
fn token(&self) -> String {
match &self.field_type {
FieldType::Primitive(f) => f.token(),
FieldType::PrimitiveList(f) => format!("{}_List", f.token()),
FieldType::Complex(f) => f.token(),
FieldType::ComplexList(f) => format!("{}_List", f.token()),
FieldType::Limited(f) => f.token(),
_ => "".to_string(),
}
}
fn rules(&self) -> Vec<GbnfRule> {
match &self.field_type {
FieldType::ComplexList(f) => self.list_rules(f),
FieldType::Complex(f) => f.rules(),
FieldType::PrimitiveList(f) => self.list_rules(f),
FieldType::Primitive(f) => f.rules(),
FieldType::Limited(f) => f.rules(),
}
}
}
impl TokenAndRule for GbnfType {
fn rules(&self) -> Vec<GbnfRule> {
// This will output the full set of rules for the complex type.
// Deduplication handled later.
let mut rule = String::new();
rule.push_str(r#""{ "#);
let field_rules_text = self
.fields
.iter()
.map(|field| {
let mut text = String::new();
text.push_str(&GbnfToken::Space.token());
text.push_str(" ");
text.push_str(&format!(
r#""\"{}\":" {} {}"#,
field.field_name,
GbnfToken::Space.token(),
field.token(),
));
text
})
.join(r#" "," "#);
rule.push_str(&field_rules_text);
rule.push_str(r#" "}""#);
GbnfRule::single(self.token(), rule)
}
fn token(&self) -> String {
self.name.clone()
}
}
pub fn create_gbnf(gbnf_type: GbnfType) -> String {
let mut rules = vec![GbnfRule::new("root".to_string(), gbnf_type.name.clone())];
rules.append(&mut gbnf_type.rules());
for field in gbnf_type.fields {
rules.append(&mut field.rules());
}
rules
.into_iter()
.unique()
.map(|rule| format!("{} ::= {}", rule.name, rule.text))
.join("\n")
}