ai-game/gbnf_derive/src/lib.rs

285 lines
8.9 KiB
Rust

use darling::ast::Data;
use darling::{FromDeriveInput, FromField};
use proc_macro::{Span, TokenStream};
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote, ToTokens};
use syn::{parse_macro_input, Attribute};
use syn::{DeriveInput, Ident, LitStr};
/// Per-field input parsed by darling from the struct that derives `Gbnf`.
///
/// Only the `gbnf_limit_primitive` and `gbnf_limit_complex` helper attributes are
/// forwarded into `attrs` (see `forward_attrs` below); other field attributes are
/// not collected.
#[derive(Clone, Debug, FromField)]
#[darling(forward_attrs(gbnf_limit_primitive, gbnf_limit_complex))]
struct GbnfFieldDef {
    /// Field name. Optional in darling's model; the code in this file `expect`s it
    /// to be present.
    ident: Option<syn::Ident>,
    /// Declared type of the field.
    ty: syn::Type,
    /// Declared visibility; reused on the corresponding field of the generated
    /// limit struct.
    vis: syn::Visibility,
    /// The forwarded limit helper attributes present on this field.
    attrs: Vec<syn::Attribute>,
}
/// Which limit helper attribute, if any, a field carries.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum LimitType {
    /// `#[gbnf_limit_primitive]`: the limit-struct field is `GbnfLimitedPrimitive<T>`.
    LimitedPrimitive,
    /// `#[gbnf_limit_complex]`: the limit-struct field is `GbnfLimitedComplex<T>`.
    LimitedComplex,
    /// No limit helper attribute on the field.
    NotLimited,
}
impl GbnfFieldDef {
    /// Classify this field by the limit helper attribute it carries.
    ///
    /// # Panics
    /// Panics if the field carries both `#[gbnf_limit_primitive]` and
    /// `#[gbnf_limit_complex]`.
    fn limit_type(&self) -> LimitType {
        // `Path::is_ident` avoids allocating a String per attribute and, unlike
        // `get_ident().unwrap()`, cannot panic on a multi-segment path.
        let has_attr = |name: &str| self.attrs.iter().any(|attr| attr.path().is_ident(name));
        match (
            has_attr("gbnf_limit_primitive"),
            has_attr("gbnf_limit_complex"),
        ) {
            (true, false) => LimitType::LimitedPrimitive,
            (false, true) => LimitType::LimitedComplex,
            (true, true) => panic!("cannot be both a primitive and complex limit"),
            (false, false) => LimitType::NotLimited,
        }
    }
}
impl ToTokens for GbnfFieldDef {
    /// Render this field as a member of the generated limit struct:
    /// `vis name: GbnfLimitedPrimitive<T>` or `vis name: GbnfLimitedComplex<T>`,
    /// where `T` is the field's original type.
    ///
    /// # Panics
    /// Panics if the field has no limit helper attribute (callers filter with
    /// `find_limited_fields` before rendering), or if it has both.
    fn to_tokens(&self, tokens: &mut TokenStream2) {
        let ident = &self.ident;
        let ty = &self.ty; // TODO figure out how to get ident out
        let vis = &self.vis;
        // Exhaustive on purpose: adding a LimitType variant must force a decision here.
        let wrapper_type = match self.limit_type() {
            LimitType::LimitedPrimitive => quote! { GbnfLimitedPrimitive<#ty> },
            LimitType::LimitedComplex => quote! { GbnfLimitedComplex<#ty> },
            LimitType::NotLimited => {
                panic!("somehow attempting to make a limited field without a helper attr")
            }
        };
        tokens.extend(quote! {
            #vis #ident: #wrapper_type
        });
    }
}
/// Whole-item input parsed by darling from the item that derives `Gbnf`.
///
/// `supports(struct_named)` restricts accepted inputs to structs with named fields;
/// darling reports an error for any other shape.
#[derive(Debug, FromDeriveInput)]
#[darling(supports(struct_named))]
struct GbnfStructDef {
    /// Name of the deriving struct.
    ident: syn::Ident,
    /// The struct's fields. Enum variants are typed as `()` because enums are
    /// rejected by `supports(struct_named)`.
    data: darling::ast::Data<(), GbnfFieldDef>,
}
impl ToTokens for GbnfStructDef {
    /// Emit `pub struct Name { ... }`, rendering every field through
    /// `ToTokens for GbnfFieldDef`.
    ///
    /// # Panics
    /// Panics if `data` is not the struct variant.
    fn to_tokens(&self, tokens: &mut TokenStream2) {
        let Data::Struct(struct_def) = &self.data else {
            panic!("Can only use GbnfLimit on structs with owned data");
        };
        let name = &self.ident;
        let field_defs = &struct_def.fields;
        tokens.extend(quote! {
            pub struct #name {
                #(#field_defs),*
            }
        });
    }
}
/// Whether `attr` is one of this derive's limit helper attributes
/// (`#[gbnf_limit_primitive]` or `#[gbnf_limit_complex]`).
fn is_limited_field(attr: &Attribute) -> bool {
    // `is_ident` compares without allocating and simply returns false for
    // multi-segment paths instead of panicking on `get_ident().unwrap()`.
    let path = attr.path();
    path.is_ident("gbnf_limit_primitive") || path.is_ident("gbnf_limit_complex")
}
/// Find fields in the struct with a #[gbnf_limit] attribute.
fn find_limited_fields(fields: &[GbnfFieldDef]) -> impl Iterator<Item = &GbnfFieldDef> + '_ {
fields.iter().filter(|field| {
field
.attrs
.iter()
.find(|attr| is_limited_field(attr))
.is_some()
})
}
/// Find fields in the struct without a #[gbnf_limit] attribute.
fn find_non_limited_fields(fields: &[GbnfFieldDef]) -> impl Iterator<Item = &GbnfFieldDef> + '_ {
fields.iter().filter(|field| {
field
.attrs
.iter()
.find(|attr| is_limited_field(attr))
.is_none()
})
}
/// Generate the grammar-access impls for the deriving struct.
///
/// `AsGbnfComplex` is emitted in every case. Beyond that:
/// - if any field carries a limit attribute, emit a `<Name>GbnfLimit` struct (one
///   wrapped field per limited field), its `GbnfLimitStructMarker` / `AsGbnfLimit`
///   impls, `GbnfLimitType` / `GbnfLimitTypeContainer` impls on the original struct,
///   and `to_grammar_with_limit`;
/// - otherwise emit `to_grammar()`, which builds the grammar once and caches it in
///   a `OnceLock`.
///
/// # Panics
/// Panics with "no ident!" if a limited field has no name.
fn generate_to_grammar_impl(original_struct_name: &Ident, fields: &[GbnfFieldDef]) -> TokenStream2 {
    let limit_struct_name = format_ident!("{}GbnfLimit", original_struct_name);
    // Each limited field renders (via `ToTokens for GbnfFieldDef`) as
    // `vis name: GbnfLimitedPrimitive<T>` or `vis name: GbnfLimitedComplex<T>`.
    let limited_fields: Vec<&GbnfFieldDef> = find_limited_fields(fields).collect();
    let as_gbnf_complex_impl = quote! {
        impl AsGbnfComplex for #original_struct_name {
            fn to_gbnf_complex() -> GbnfComplex {
                Self::to_gbnf().as_complex()
            }
        }
    };

    if limited_fields.is_empty() {
        return quote! {
            impl #original_struct_name {
                pub fn to_grammar() -> &'static str {
                    use std::sync::OnceLock;
                    static GRAMMAR: OnceLock<String> = OnceLock::new();
                    GRAMMAR.get_or_init(|| Self::to_gbnf().as_complex().to_grammar(None))
                }
            }
            #as_gbnf_complex_impl
        };
    }

    // `("field_name", self.field_name.to_gbnf_limit())` tuples, fed into
    // `HashMap::from` to build the complex GbnfLimit.
    let from_assignments = limited_fields.iter().map(|field| {
        let ident = field.ident.as_ref().expect("no ident!");
        let key = LitStr::new(&ident.to_string(), Span::call_site().into());
        quote! { (#key, self.#ident.to_gbnf_limit()) }
    });

    quote! {
        pub struct #limit_struct_name {
            #(#limited_fields),*
        }
        impl GbnfLimitStructMarker for #limit_struct_name {}
        impl AsGbnfLimit for #limit_struct_name {
            fn to_gbnf_limit(self) -> GbnfLimit {
                GbnfLimit::Complex(
                    HashMap::from([
                        #(#from_assignments),*
                    ])
                )
            }
        }
        impl GbnfLimitType for #original_struct_name {
            type Type = #limit_struct_name;
        }
        impl GbnfLimitTypeContainer<#original_struct_name> for #original_struct_name {
            type ContainerType = #limit_struct_name;
        }
        impl #original_struct_name {
            pub fn to_grammar_with_limit(limit: #limit_struct_name) -> String {
                let gbnf_limit = limit.to_gbnf_limit();
                Self::to_gbnf().as_complex().to_grammar(Some(gbnf_limit))
            }
        }
        #as_gbnf_complex_impl
    }
}
/// Generate the GBNF rules and the limit struct (if applicable).
///
/// Emits `use gbnf::prelude::*;`, an `AsGbnf` impl whose `to_gbnf()` lists all
/// fields (limited fields first, then unlimited ones), and the impls produced by
/// `generate_to_grammar_impl`.
///
/// If darling rejects `input` (e.g. it is not a struct with named fields), its error
/// is returned via `Error::write_errors`. Rustc then reports darling's message at the
/// offending span instead of a bare "proc-macro derive panicked".
fn generate_gbnf(input: &DeriveInput) -> TokenStream {
    let gbnf_struct = match GbnfStructDef::from_derive_input(input) {
        Ok(parsed) => parsed,
        Err(err) => return err.write_errors().into(),
    };
    let struct_name = gbnf_struct.ident;
    let struct_name_str = LitStr::new(&struct_name.to_string(), Span::call_site().into());
    let fields = match gbnf_struct.data {
        Data::Struct(struct_def) => struct_def.fields,
        // `supports(struct_named)` on GbnfStructDef should make this unreachable.
        _ => panic!("Can only use GbnfLimit on structs with owned data"),
    };
    // Gbnf rule generation stuff
    let limited_gbnfs = map_limited_gbnfs(&fields);
    let non_limited_gbnfs = map_non_limited_gbnfs(&fields);
    let gbnfs = limited_gbnfs.chain(non_limited_gbnfs);
    let as_gbnf_impl = quote! {
        impl AsGbnf for #struct_name {
            fn to_gbnf() -> gbnf::GbnfFieldType {
                GbnfFieldType::Complex(
                    GbnfComplex {
                        name: String::from(#struct_name_str),
                        fields: vec![#(#gbnfs),*]
                    }
                )
            }
        }
    };
    let to_grammar_impl = generate_to_grammar_impl(&struct_name, &fields);
    let final_output = quote! {
        use gbnf::prelude::*;
        #as_gbnf_impl
        #to_grammar_impl
    };
    final_output.into()
}
/// Build one `gbnf_field!("name", Type)` invocation per field that carries no
/// limit helper attribute.
///
/// # Panics
/// Panics with "no ident" if such a field has no name.
fn map_non_limited_gbnfs(fields: &[GbnfFieldDef]) -> impl Iterator<Item = TokenStream2> + '_ {
    find_non_limited_fields(fields).map(|field| {
        let name = field.ident.as_ref().expect("no ident");
        let name_lit = LitStr::new(&name.to_string(), Span::call_site().into());
        let ty = &field.ty;
        quote! { gbnf_field!(#name_lit, #ty) }
    })
}
/// Build one `GbnfField { .., limited: true }` literal per field that carries a
/// `#[gbnf_limit_primitive]` or `#[gbnf_limit_complex]` attribute.
///
/// # Panics
/// Panics with "no ident" if such a field has no name.
fn map_limited_gbnfs(fields: &[GbnfFieldDef]) -> impl Iterator<Item = TokenStream2> + '_ {
    find_limited_fields(fields).map(|field| {
        let name = field.ident.as_ref().expect("no ident");
        let name_lit = LitStr::new(&name.to_string(), Span::call_site().into());
        let ty = &field.ty;
        quote! {
            GbnfField {
                field_name: #name_lit.to_string(),
                field_type: gbnf_field_type!(#ty),
                limited: true,
            }
        }
    })
}
/// Convert a Rust type into a GBNF grammar.
#[proc_macro_derive(Gbnf, attributes(gbnf_limit_primitive, gbnf_limit_complex))]
pub fn gbnf(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
generate_gbnf(&input)
}