diff --git a/Cargo.toml b/Cargo.toml index 8adf109..15dc02b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,6 @@ members = [ "tools/conformance", "protokit", "protokit_build", - "protokit_proto", "protokit_grpc", "protokit_desc", "protokit_textformat", @@ -53,7 +52,6 @@ textformat = { path = "protokit_textformat", package = "protokit_textformat", ve desc = { path = "protokit_desc", package = "protokit_desc", version = "0.2.0" } grpc = { path = "protokit_grpc", package = "protokit_grpc", version = "0.2.0" } derive = { path = "protokit_derive", package = "protokit_derive", version = "0.2.0" } -proto = { path = "protokit_proto", package = "protokit_proto", version = "0.2.0", default-features = false } protokit = { path = "protokit", version = "0.2.0" } lex_core = { package = "lexical-core", version = "0.8.5" } diff --git a/protokit_build/Cargo.toml b/protokit_build/Cargo.toml index 30759f3..cd998d1 100644 --- a/protokit_build/Cargo.toml +++ b/protokit_build/Cargo.toml @@ -9,22 +9,14 @@ authors = [""] description = "Usable protocol buffers" [features] -default = ["protoc"] -protoc = ["descriptors"] -parser = ["proto"] - -descriptors = ["desc/descriptors"] +default = [] [dependencies.binformat] workspace = true [dependencies.desc] workspace = true - -[dependencies.proto] -workspace = true -default-features = false -optional = true +features = ["descriptors"] [dependencies] anyhow = { workspace = true } diff --git a/protokit_build/src/deps.rs b/protokit_build/src/deps.rs deleted file mode 100644 index d4dd7d1..0000000 --- a/protokit_build/src/deps.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub use std::fs::{create_dir_all, File}; -pub use std::io::Write; -pub use std::path::{Path, PathBuf}; - -pub use anyhow::bail; -pub use desc::arcstr::ArcStr; -pub use desc::*; diff --git a/protokit_build/src/filegen/grpc.rs b/protokit_build/src/filegen/grpc.rs index 4267633..12f696b 100644 --- a/protokit_build/src/filegen/grpc.rs +++ b/protokit_build/src/filegen/grpc.rs @@ -1,320 +1,350 @@ use convert_case::{Case, Casing}; -use desc::{FileDef, RpcDef, ServiceDef}; use quote::__private::TokenStream; use quote::{format_ident, quote}; -use crate::filegen::{rustify_name, CodeGenerator}; - -impl CodeGenerator<'_> { - pub fn generate_server(&self, file: &FileDef, svc: &ServiceDef) -> TokenStream { - let svc_qualified_raw_name = format!("{}.{}", file.package, svc.name); - - let svc_name = format_ident!("{}", rustify_name(svc.name.as_str())); - let mod_name = format_ident!("{}_server", rustify_name(svc.name.as_str())); - let server_name = format_ident!("{}Server", rustify_name(svc.name.as_str())); - - let mut trait_items = vec![]; - let mut arms = vec![]; - let mut defs = vec![]; - - for (_, rpc) in &svc.rpc { - let rpc: &RpcDef = rpc; - let rpc_struct = format_ident!("{}Svc", rpc.name.as_str()); - let method_name = format_ident!("{}", rpc.name.as_str().to_case(Case::Snake)); - let stream_name = format_ident!("{}Stream", rpc.name.as_str()); - let path = format!("/{}.{}/{}", file.package, svc.name, rpc.name); - - let raw_req_type = self - .base_type(&rpc.req_typ) - // .with_context(|| format!("{msg_name}.{field_raw_name} in {:?}", file.name)) - .expect("Resolving name"); - - let raw_res_type = self - .base_type(&rpc.res_typ) - // .with_context(|| format!("{msg_name}.{field_raw_name} in {:?}", file.name)) - .expect("Resolving name"); - - let mut rpc_kind_method = quote! { unary }; - - let mut req_type = raw_req_type.clone(); - let res_type; - let response_type; - let mut stream_def = quote! {}; - - let svc_type = match (&rpc.req_stream, &rpc.res_stream) { - (false, false) => { - req_type = quote! { super::#raw_req_type }; - res_type = quote! { super::#raw_res_type }; - response_type = quote! { Self::Response }; - quote! { UnaryService } - } - (true, false) => { - req_type = quote! { tonic::Streaming }; - res_type = quote! { super::#raw_res_type }; - response_type = quote! { Self::Response }; - rpc_kind_method = quote! { client_streaming }; - quote! { ClientStreamingService } - } - (false, true) => { - req_type = quote! { super::#req_type }; - res_type = quote! { Self::#stream_name }; - stream_def = quote! { - type ResponseStream = S::#stream_name; - }; - trait_items.push(quote! { - type #stream_name: Stream> + Send + 'static; - }); - response_type = quote! { Self::ResponseStream }; - rpc_kind_method = quote! { server_streaming }; - quote! { ServerStreamingService } - } - (true, true) => { - req_type = quote! { tonic::Streaming }; - res_type = quote! { Self::#stream_name }; - stream_def = quote! { - type ResponseStream = S::#stream_name; - }; - trait_items.push(quote! { - type #stream_name: Stream> + Send + 'static; - }); - response_type = quote! { Self::ResponseStream }; - rpc_kind_method = quote! { streaming }; - quote! { StreamingService } - } - }; - - trait_items.push(quote! { - async fn #method_name(&self, req: tonic::Request<#req_type>) -> Result, tonic::Status>; - }); - - defs.push(quote! { - struct #rpc_struct(Arc); - impl tonic::server::#svc_type for #rpc_struct { - type Response = super::#raw_res_type; - #stream_def - type Future = BoxFuture< - tonic::Response<#response_type>, - tonic::Status, - >; - - fn call(&mut self, request: tonic::Request<#req_type>) -> Self::Future { - let inner = self.0.clone(); - Box::pin(async move { inner.#method_name(request).await }) - } - } - }); - arms.push(quote! { - #path => { - let inner = self.0.clone(); - let fut = async move { - let method = #rpc_struct(inner); - let codec = protokit::grpc::TonicCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec); - let res = grpc.#rpc_kind_method(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - }); - } - quote! { - mod #mod_name { - use super::*; - use protokit::grpc::*; - #[protokit::grpc::async_trait] - pub trait #svc_name : Send + Sync + 'static { - #(#trait_items)* - } - #[derive(Debug)] - pub struct #server_name (pub Arc); - impl Clone for #server_name { - fn clone(&self) -> Self { - Self(self.0.clone()) - } - } - impl From for #server_name { - fn from(v: S) -> Self { - Self(::std::sync::Arc::new(v)) - } - } - impl From<::std::sync::Arc> for #server_name { - fn from(v: ::std::sync::Arc) -> Self { - Self(v) - } - } - - #(#defs)* +use crate::{DescriptorPool, ServiceDescriptorProto}; +use super::{rustify_name, rust_type_name, Options}; - impl Service> for #server_name - where - S: #svc_name, - B: Body + Send + 'static, - B::Error: Into> + Send + 'static, +fn base_type_from_type_name(pool: &DescriptorPool, opts: &Options, type_name: &str) -> TokenStream { + let rust_name = rust_type_name(pool, type_name); + let ident = format_ident!("{}", rust_name); + let gen = opts.generics.struct_use_generics(); + quote! { #ident #gen } +} - { - type Response = http::Response; - type Error = core::convert::Infallible; - type Future = BoxFuture; +pub fn generate_server(pool: &DescriptorPool, opts: &Options, svc: &ServiceDescriptorProto) -> TokenStream { + let svc_name_str = svc.name.as_deref().unwrap_or(""); - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: http::Request) -> Self::Future { - match req.uri().path() { - #(#arms)* - _ => Box::pin(async move { - Ok( - http::Response::builder() - .status(200) - .header("grpc-status", "12") - .header("content-type", "application/grpc") - .body(empty_body()) - .unwrap(), - ) - }) - } - } - } - impl tonic::server::NamedService for #server_name { - const NAME: &'static str = #svc_qualified_raw_name; - } + // We need to find the file this service is in to get the package + let mut package = ""; + for file in &pool.descriptor_set.file { + for file_svc in &file.service { + if file_svc.name.as_deref() == Some(svc_name_str) { + package = file.package.as_deref().unwrap_or(""); + break; } - pub use #mod_name::*; } } - pub fn generate_client(&self, file: &FileDef, svc: &ServiceDef) -> TokenStream { - let _svc_qualified_raw_name = format!("{}.{}", file.package, svc.name); + let svc_qualified_raw_name = format!("{}.{}", package, svc_name_str); - let _svc_name = format_ident!("{}", rustify_name(svc.name.as_str())); - let mod_name = format_ident!("{}_client", rustify_name(svc.name.as_str())); - let client_name = format_ident!("{}Client", rustify_name(svc.name.as_str())); + let svc_name = format_ident!("{}", rustify_name(svc_name_str)); + let mod_name = format_ident!("{}_server", rustify_name(svc_name_str)); + let server_name = format_ident!("{}Server", rustify_name(svc_name_str)); - let mut methods = vec![]; - for (_, rpc) in &svc.rpc { - let rpc: &RpcDef = rpc; - // let rpc_struct = format_ident!("{}Svc", rpc.name.as_str()); - let method_name = format_ident!("{}", rpc.name.as_str().to_case(Case::Snake)); - let _stream_name = format_ident!("{}Stream", rpc.name.as_str()); - let path = format!("/{}.{}/{}", file.package, svc.name, rpc.name); + let mut trait_items = vec![]; + let mut arms = vec![]; + let mut defs = vec![]; - let raw_req_type = self - .base_type(&rpc.req_typ) - // .with_context(|| format!("{msg_name}.{field_raw_name} in {:?}", file.name)) - .expect("Resolving name"); + for rpc in &svc.method { + let rpc_name = rpc.name.as_deref().unwrap_or(""); + let rpc_struct = format_ident!("{}Svc", rpc_name); + let method_name = format_ident!("{}", rpc_name.to_case(Case::Snake)); + let stream_name = format_ident!("{}Stream", rpc_name); + let path = format!("/{}.{}/{}", package, svc_name_str, rpc_name); - let raw_res_type = self - .base_type(&rpc.res_typ) - // .with_context(|| format!("{msg_name}.{field_raw_name} in {:?}", file.name)) - .expect("Resolving name"); + let req_type_name = rpc.input_type.as_deref().unwrap_or(""); + let res_type_name = rpc.output_type.as_deref().unwrap_or(""); - let _req_type = raw_req_type.clone(); - let _res_type = raw_res_type.clone(); + let raw_req_type = base_type_from_type_name(pool, opts, req_type_name); + let raw_res_type = base_type_from_type_name(pool, opts, res_type_name); - let mut rpc_kind_method = quote! { unary }; + let req_stream = rpc.client_streaming.unwrap_or(false); + let res_stream = rpc.server_streaming.unwrap_or(false); - let res_type = if rpc.res_stream { - quote!( tonic::Streaming ) - } else { - quote! { super::#raw_res_type } - }; + let mut rpc_kind_method = quote! { unary }; + let mut stream_def = quote! {}; - let (req_type, req_convert) = if rpc.req_stream { + let (req_type, res_type, response_type, svc_type) = match (req_stream, res_stream) { + (false, false) => { + ( + quote! { super::#raw_req_type }, + quote! { super::#raw_res_type }, + quote! { Self::Response }, + quote! { UnaryService }, + ) + } + (true, false) => { + rpc_kind_method = quote! { client_streaming }; ( - quote! { impl tonic::IntoStreamingRequest }, - quote! { into_streaming_request }, + quote! { tonic::Streaming }, + quote! { super::#raw_res_type }, + quote! { Self::Response }, + quote! { ClientStreamingService }, ) - } else { + } + (false, true) => { + stream_def = quote! { + type ResponseStream = S::#stream_name; + }; + trait_items.push(quote! { + type #stream_name: Stream> + Send + 'static; + }); + rpc_kind_method = quote! { server_streaming }; ( - quote! { impl tonic::IntoRequest }, - quote! { into_request }, + quote! { super::#raw_req_type }, + quote! { Self::#stream_name }, + quote! { Self::ResponseStream }, + quote! { ServerStreamingService }, ) - }; + } + (true, true) => { + stream_def = quote! { + type ResponseStream = S::#stream_name; + }; + trait_items.push(quote! { + type #stream_name: Stream> + Send + 'static; + }); + rpc_kind_method = quote! { streaming }; + ( + quote! { tonic::Streaming }, + quote! { Self::#stream_name }, + quote! { Self::ResponseStream }, + quote! { StreamingService }, + ) + } + }; + + trait_items.push(quote! { + async fn #method_name(&self, req: tonic::Request<#req_type>) -> Result, tonic::Status>; + }); - match (&rpc.req_stream, &rpc.res_stream) { - (false, false) => {} - (true, false) => { - rpc_kind_method = quote! { client_streaming }; + defs.push(quote! { + struct #rpc_struct(Arc); + impl tonic::server::#svc_type for #rpc_struct { + type Response = super::#raw_res_type; + #stream_def + type Future = BoxFuture< + tonic::Response<#response_type>, + tonic::Status, + >; + + fn call(&mut self, request: tonic::Request<#req_type>) -> Self::Future { + let inner = self.0.clone(); + Box::pin(async move { inner.#method_name(request).await }) } - (false, true) => { - rpc_kind_method = quote! { server_streaming }; + } + }); + arms.push(quote! { + #path => { + let inner = self.0.clone(); + let fut = async move { + let method = #rpc_struct(inner); + let codec = protokit::grpc::TonicCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec); + let res = grpc.#rpc_kind_method(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + }); + } + + quote! { + mod #mod_name { + use super::*; + use protokit::grpc::*; + #[protokit::grpc::async_trait] + pub trait #svc_name : Send + Sync + 'static { + #(#trait_items)* + } + #[derive(Debug)] + pub struct #server_name (pub Arc); + impl Clone for #server_name { + fn clone(&self) -> Self { + Self(self.0.clone()) } - (true, true) => { - rpc_kind_method = quote! { streaming }; + } + impl From for #server_name { + fn from(v: S) -> Self { + Self(::std::sync::Arc::new(v)) } - }; - methods.push(quote! { - pub async fn #method_name( - &mut self, - request: #req_type, - ) -> Result, tonic::Status> { - self.inner - .ready() - .await - .map_err(|e| { - Status::new(Code::Unknown, format!("Service was not ready: {}", e.into())) - })?; - let codec = protokit - ::grpc::TonicCodec::default(); - let path = http::uri::PathAndQuery::from_static(#path); - self.inner.#rpc_kind_method(request.#req_convert(), path, codec).await + } + impl From<::std::sync::Arc> for #server_name { + fn from(v: ::std::sync::Arc) -> Self { + Self(v) } - }) - } - quote! { - mod #mod_name { - use super::*; - use protokit::grpc::*; - #[derive(Debug, Clone)] - pub struct #client_name { - inner: tonic::client::Grpc, + } + + #(#defs)* + + impl Service> for #server_name + where + S: #svc_name, + B: Body + Send + 'static, + B::Error: Into> + Send + 'static, + + { + type Response = http::Response; + type Error = core::convert::Infallible; + type Future = BoxFuture; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } - impl #client_name { - pub async fn connect(dst: D) -> Result - where - D: core::convert::TryInto, - D::Error: Into, - { - let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; - Ok(Self::new(conn)) + + fn call(&mut self, req: http::Request) -> Self::Future { + match req.uri().path() { + #(#arms)* + _ => Box::pin(async move { + Ok( + http::Response::builder() + .status(200) + .header("grpc-status", "12") + .header("content-type", "application/grpc") + .body(empty_body()) + .unwrap(), + ) + }) } } - impl #client_name + } + impl tonic::server::NamedService for #server_name { + const NAME: &'static str = #svc_qualified_raw_name; + } + } + pub use #mod_name::*; + } +} + +pub fn generate_client(pool: &DescriptorPool, opts: &Options, svc: &ServiceDescriptorProto) -> TokenStream { + let svc_name_str = svc.name.as_deref().unwrap_or(""); + + // We need to find the file this service is in to get the package + let mut package = ""; + for file in &pool.descriptor_set.file { + for file_svc in &file.service { + if file_svc.name.as_deref() == Some(svc_name_str) { + package = file.package.as_deref().unwrap_or(""); + break; + } + } + } + + let _svc_qualified_raw_name = format!("{}.{}", package, svc_name_str); + + let _svc_name = format_ident!("{}", rustify_name(svc_name_str)); + let mod_name = format_ident!("{}_client", rustify_name(svc_name_str)); + let client_name = format_ident!("{}Client", rustify_name(svc_name_str)); + + let mut methods = vec![]; + for rpc in &svc.method { + let rpc_name = rpc.name.as_deref().unwrap_or(""); + let method_name = format_ident!("{}", rpc_name.to_case(Case::Snake)); + let _stream_name = format_ident!("{}Stream", rpc_name); + let path = format!("/{}.{}/{}", package, svc_name_str, rpc_name); + + let req_type_name = rpc.input_type.as_deref().unwrap_or(""); + let res_type_name = rpc.output_type.as_deref().unwrap_or(""); + + let raw_req_type = base_type_from_type_name(pool, opts, req_type_name); + let raw_res_type = base_type_from_type_name(pool, opts, res_type_name); + + let req_stream = rpc.client_streaming.unwrap_or(false); + let res_stream = rpc.server_streaming.unwrap_or(false); + + let mut rpc_kind_method = quote! { unary }; + + let res_type = if res_stream { + quote!( tonic::Streaming ) + } else { + quote! { super::#raw_res_type } + }; + + let (req_type, req_convert) = if req_stream { + ( + quote! { impl tonic::IntoStreamingRequest }, + quote! { into_streaming_request }, + ) + } else { + ( + quote! { impl tonic::IntoRequest }, + quote! { into_request }, + ) + }; + + match (req_stream, res_stream) { + (false, false) => {} + (true, false) => { + rpc_kind_method = quote! { client_streaming }; + } + (false, true) => { + rpc_kind_method = quote! { server_streaming }; + } + (true, true) => { + rpc_kind_method = quote! { streaming }; + } + }; + methods.push(quote! { + pub async fn #method_name( + &mut self, + request: #req_type, + ) -> Result, tonic::Status> { + self.inner + .ready() + .await + .map_err(|e| { + Status::new(Code::Unknown, format!("Service was not ready: {}", e.into())) + })?; + let codec = protokit + ::grpc::TonicCodec::default(); + let path = http::uri::PathAndQuery::from_static(#path); + self.inner.#rpc_kind_method(request.#req_convert(), path, codec).await + } + }) + } + quote! { + mod #mod_name { + use super::*; + use protokit::grpc::*; + #[derive(Debug, Clone)] + pub struct #client_name { + inner: tonic::client::Grpc, + } + impl #client_name { + pub async fn connect(dst: D) -> Result where - S: tonic::client::GrpcService, - S::Error: Into, - S::ResponseBody: Body + Send + 'static, - ::Error: Into + Send, + D: core::convert::TryInto, + D::Error: Into, { - pub fn new(inner: S) -> Self { - let inner = tonic::client::Grpc::new(inner); - Self { inner } - } - pub fn with_interceptor( - inner: S, - interceptor: F, - ) -> #client_name> - where - F: tonic::service::Interceptor, - S::ResponseBody: Default, - S: tonic::codegen::Service< - http::Request, - Response = http::Response< - >::ResponseBody, - >, + let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; + Ok(Self::new(conn)) + } + } + impl #client_name + where + S: tonic::client::GrpcService, + S::Error: Into, + S::ResponseBody: Body + Send + 'static, + ::Error: Into + Send, + { + pub fn new(inner: S) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_interceptor( + inner: S, + interceptor: F, + ) -> #client_name> + where + F: tonic::service::Interceptor, + S::ResponseBody: Default, + S: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, >, - , - >>::Error: Into + Send + Sync, - { - #client_name::new(InterceptedService::new(inner, interceptor)) - } - #(#methods)* + >, + , + >>::Error: Into + Send + Sync, + { + #client_name::new(InterceptedService::new(inner, interceptor)) } - + #(#methods)* } - pub use #mod_name::*; + } + pub use #mod_name::*; } } diff --git a/protokit_build/src/filegen/mod.rs b/protokit_build/src/filegen/mod.rs index e6b793c..d3980ca 100644 --- a/protokit_build/src/filegen/mod.rs +++ b/protokit_build/src/filegen/mod.rs @@ -1,14 +1,18 @@ -use core::ops::Deref; -use core::str::FromStr; -use std::collections::{BTreeMap, HashSet}; +use std::collections::HashSet; +use std::fs::{create_dir_all, File}; +use std::io::Write; +use std::path::Path; use anyhow::{Context, Result}; use convert_case::{Case, Casing}; -use desc::Syntax::{Proto2, Proto3}; use proc_macro2::TokenStream; use quote::{format_ident, quote}; -use crate::deps::*; +use crate::{ + DescriptorPool, DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, + FieldDescriptorProtoLabel, FieldDescriptorProtoType, FileDescriptorProto, + OneofDescriptorProto, ServiceDescriptorProto, +}; pub mod grpc; @@ -19,9 +23,10 @@ pub struct Generics { } impl Generics { - fn liftetime_arg(&self) -> Option { + fn lifetime_arg(&self) -> Option { self.buf_arg.clone() } + fn struct_def_generics(&self) -> TokenStream { match (&self.buf_arg, &self.alloc_arg) { (Some(b), Some(a)) => quote! { < #b, #a : std::alloc::Allocator + Debug> }, @@ -43,7 +48,6 @@ impl Generics { #[derive(Debug)] pub struct Options { - pub replacement: BTreeMap, pub import_root: Option, pub string_type: TokenStream, @@ -54,22 +58,12 @@ pub struct Options { pub generics: Generics, pub protoattrs: Vec, - pub force_box: HashSet, - pub track_unknowns: bool, } -impl Options { - pub fn replace_import(&mut self, from: &str, to: &str) -> &mut Self { - self.replacement.insert(from.to_string(), to.to_string()); - self - } -} - impl Default for Options { fn default() -> Self { Self { - replacement: Default::default(), import_root: None, string_type: quote! { String }, bytes_type: quote! { Vec }, @@ -77,33 +71,27 @@ impl Default for Options { unknown_type: quote! { ::protokit::binformat::UnknownFieldsOwned }, generics: Generics::default(), protoattrs: vec![], - force_box: Default::default(), track_unknowns: false, } } } const STRICT: &[&str] = &[ - "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum", "extern", "false", "fn", - "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref", "return", "self", "Self", - "static", "struct", "super", "trait", "true", "type", "unsafe", "use", "where", "while", + "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum", "extern", + "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", + "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true", "type", + "unsafe", "use", "where", "while", ]; const RESERVED: &[&str] = &[ - "abstract", "become", "box", "do", "final", "macro", "override", "priv", "try", "typeof", "unsized", "virtual", - "yield", + "abstract", "become", "box", "do", "final", "macro", "override", "priv", "try", "typeof", + "unsized", "virtual", "yield", ]; pub const TYPES: &[&str] = &["Option", "Result"]; pub fn rustify_name(n: impl AsRef) -> String { let n = n.as_ref().replace('.', "_"); - // let pos = n.find('.'); - // let n = if let Some(pos) = pos { - // &n[pos + 1..] - // } else { - // n - // }; for s in STRICT.iter().chain(RESERVED) { if *s == n { return format!("r#{n}"); @@ -114,210 +102,396 @@ pub fn rustify_name(n: impl AsRef) -> String { return format!("Proto{n}"); } } - n.replace('.', "") } -pub fn builtin_type_marker(typ: BuiltinType) -> &'static str { +/// Extract the simple name from a fully qualified type name +/// e.g. ".google.protobuf.Any" -> "Any" +fn simple_name(fqn: &str) -> &str { + fqn.rsplit('.').next().unwrap_or(fqn) +} + +/// Get the Rust type name from a fully qualified protobuf type name +fn rust_type_name(pool: &DescriptorPool, type_name: &str) -> String { + if let Some(loc) = pool.lookup(type_name) { + // Use the path to construct nested type names + rustify_name(loc.path.join("_")) + } else { + // Fallback to simple name + rustify_name(simple_name(type_name)) + } +} + +fn builtin_type_marker(typ: FieldDescriptorProtoType) -> &'static str { match typ { - BuiltinType::Int32 => "varint", - BuiltinType::Int64 => "varint", - BuiltinType::Uint32 => "varint", - BuiltinType::Uint64 => "varint", - BuiltinType::Sint32 => "sigint", - BuiltinType::Sint64 => "sigint", - BuiltinType::Bool => "bool", - BuiltinType::Fixed64 => "fixed64", - BuiltinType::Sfixed64 => "fixed64", - BuiltinType::Fixed32 => "fixed32", - BuiltinType::Sfixed32 => "fixed32", - BuiltinType::Double => "fixed64", - BuiltinType::Float => "fixed32", - BuiltinType::String_ => "string", - BuiltinType::Bytes_ => "bytes", + FieldDescriptorProtoType::TYPE_INT32 => "varint", + FieldDescriptorProtoType::TYPE_INT64 => "varint", + FieldDescriptorProtoType::TYPE_UINT32 => "varint", + FieldDescriptorProtoType::TYPE_UINT64 => "varint", + FieldDescriptorProtoType::TYPE_SINT32 => "sigint", + FieldDescriptorProtoType::TYPE_SINT64 => "sigint", + FieldDescriptorProtoType::TYPE_BOOL => "bool", + FieldDescriptorProtoType::TYPE_FIXED64 => "fixed64", + FieldDescriptorProtoType::TYPE_SFIXED64 => "fixed64", + FieldDescriptorProtoType::TYPE_FIXED32 => "fixed32", + FieldDescriptorProtoType::TYPE_SFIXED32 => "fixed32", + FieldDescriptorProtoType::TYPE_DOUBLE => "fixed64", + FieldDescriptorProtoType::TYPE_FLOAT => "fixed32", + FieldDescriptorProtoType::TYPE_STRING => "string", + FieldDescriptorProtoType::TYPE_BYTES => "bytes", + FieldDescriptorProtoType::TYPE_MESSAGE => "nested", + FieldDescriptorProtoType::TYPE_GROUP => "group", + FieldDescriptorProtoType::TYPE_ENUM => "protoenum", + _ => panic!("Unknown field type"), } } +fn builtin_rust_type(opts: &Options, typ: FieldDescriptorProtoType) -> TokenStream { + match typ { + FieldDescriptorProtoType::TYPE_INT32 => quote! { i32 }, + FieldDescriptorProtoType::TYPE_INT64 => quote! { i64 }, + FieldDescriptorProtoType::TYPE_UINT32 => quote! { u32 }, + FieldDescriptorProtoType::TYPE_UINT64 => quote! { u64 }, + FieldDescriptorProtoType::TYPE_SINT32 => quote! { i32 }, + FieldDescriptorProtoType::TYPE_SINT64 => quote! { i64 }, + FieldDescriptorProtoType::TYPE_BOOL => quote! { bool }, + FieldDescriptorProtoType::TYPE_FIXED64 => quote! { u64 }, + FieldDescriptorProtoType::TYPE_SFIXED64 => quote! { i64 }, + FieldDescriptorProtoType::TYPE_FIXED32 => quote! { u32 }, + FieldDescriptorProtoType::TYPE_SFIXED32 => quote! { i32 }, + FieldDescriptorProtoType::TYPE_DOUBLE => quote! { f64 }, + FieldDescriptorProtoType::TYPE_FLOAT => quote! { f32 }, + FieldDescriptorProtoType::TYPE_STRING => opts.string_type.clone(), + FieldDescriptorProtoType::TYPE_BYTES => opts.bytes_type.clone(), + _ => panic!("Not a builtin type"), + } +} + +fn is_scalar_type(typ: FieldDescriptorProtoType) -> bool { + !matches!( + typ, + FieldDescriptorProtoType::TYPE_STRING + | FieldDescriptorProtoType::TYPE_BYTES + | FieldDescriptorProtoType::TYPE_MESSAGE + | FieldDescriptorProtoType::TYPE_GROUP + ) +} + pub struct CodeGenerator<'a> { - context: &'a FileSetDef, + pool: &'a DescriptorPool, + file: &'a FileDescriptorProto, options: &'a Options, types: Vec, output: Vec, } -pub fn resolve_name(set: &FileSetDef, id: GlobalDefId) -> Result { - if let Some((msg, _)) = set.message_by_id(id) { - return Ok(rustify_name(msg.name.as_str())); - } else if let Some((en, _)) = set.enum_by_id(id) { - return Ok(rustify_name(en.name.as_str())); - } else { - bail!( - "Could not resolve {} {} {:b} files:{}: {:#?}", - id, - id >> 32, - id & 0xFFFFFFFF, - set.files.len(), - set.files.keys() - ); +impl<'a> CodeGenerator<'a> { + pub fn new(pool: &'a DescriptorPool, file_idx: usize, options: &'a Options) -> Self { + Self { + pool, + file: &pool.descriptor_set.file[file_idx], + options, + types: vec![], + output: vec![], + } } -} -impl CodeGenerator<'_> { - pub fn resolve_name(&self, id: GlobalDefId) -> Result { - resolve_name(self.context, id) + fn is_proto3(&self) -> bool { + self.file.syntax.as_deref() == Some("proto3") } - pub fn builtin_rusttype(&self, typ: BuiltinType) -> TokenStream { - TokenStream::from_str(match typ { - BuiltinType::Int32 => "i32", - BuiltinType::Int64 => "i64", - BuiltinType::Uint32 => "u32", - BuiltinType::Uint64 => "u64", - BuiltinType::Sint32 => "i32", - BuiltinType::Sint64 => "i64", - BuiltinType::Bool => "bool", - BuiltinType::Fixed64 => "u64", - BuiltinType::Sfixed64 => "i64", - BuiltinType::Fixed32 => "u32", - BuiltinType::Sfixed32 => "i32", - BuiltinType::Double => "f64", - BuiltinType::Float => "f32", - BuiltinType::String_ => return self.options.string_type.clone(), - BuiltinType::Bytes_ => return self.options.bytes_type.clone(), - }) - .unwrap() + + fn package(&self) -> &str { + self.file.package.as_deref().unwrap_or("") } - pub fn type_marker(typ: &DataType) -> TokenStream { - TokenStream::from_str(match typ { - DataType::Unresolved(_, _) => panic!(), - DataType::Builtin(b) => builtin_type_marker(*b), - DataType::Message(_) => "nested", - DataType::Group(_) => "group", - DataType::Enum(_) => "protoenum", - - DataType::Map(k) => { - eprintln!("{:?} to {:?}", typ, k); - return TokenStream::from_str(&format!( - "map({}, {})", - builtin_type_marker(k.0), - Self::type_marker(&k.1) - )) - .unwrap(); - } - }) - .unwrap() + fn base_type(&self, field: &FieldDescriptorProto) -> Result { + let typ = field.r#type.as_ref().unwrap(); + + if let Some(type_name) = &field.type_name { + // Message, Enum, or Group type + let rust_name = rust_type_name(self.pool, type_name); + let ident = format_ident!("{}", rust_name); + let gen = self.options.generics.struct_use_generics(); + Ok(quote! { #ident #gen }) + } else { + // Builtin type + Ok(builtin_rust_type(self.options, *typ)) + } } - pub fn base_type(&self, typ: &DataType) -> Result { - Ok(match typ { - DataType::Unresolved(path, _) => { - panic!("Name {path} was not resolved to actual type") - } - DataType::Builtin(bt) => return Ok(self.builtin_rusttype(*bt)), - DataType::Message(id) | DataType::Group(id) => { - // let borrow = self.borrow(); - let gen = self.options.generics.struct_use_generics(); - let ident = format_ident!("{}", self.resolve_name(*id)?); + fn is_map_field(&self, field: &FieldDescriptorProto) -> Option<(&DescriptorProto, &FieldDescriptorProto, &FieldDescriptorProto)> { + if field.label != Some(FieldDescriptorProtoLabel::LABEL_REPEATED) { + return None; + } - quote! {#ident #gen} - } - DataType::Enum(id) => TokenStream::from_str(&self.resolve_name(*id)?).unwrap(), - DataType::Map(m) => { - let kt = self.base_type(&DataType::Builtin(m.deref().0))?; - let vt = self.base_type(&m.deref().1)?; - let mt = &self.options.map_type; - return Ok(quote! { #mt<#kt,#vt> }); - } - }) + let type_name = field.type_name.as_ref()?; + let msg = self.pool.get_message(type_name)?; + + // Check if it's a map entry + if msg.options.as_ref().and_then(|o| o.map_entry) != Some(true) { + return None; + } + + // Get key and value fields + let key_field = msg.field.iter().find(|f| f.number == Some(1))?; + let value_field = msg.field.iter().find(|f| f.number == Some(2))?; + + Some((msg, key_field, value_field)) } - pub fn field_type(&self, field: &FieldDef) -> Result { - let base = self.base_type(&field.typ)?; - let is_msg = matches!(field.typ, DataType::Message(..) | DataType::Group(_)); + fn field_type(&self, field: &FieldDescriptorProto) -> Result { + let typ = field.r#type.as_ref().unwrap(); + let is_msg = matches!( + typ, + &FieldDescriptorProtoType::TYPE_MESSAGE | &FieldDescriptorProtoType::TYPE_GROUP + ); + let is_repeated = field.label == Some(FieldDescriptorProtoLabel::LABEL_REPEATED); + + // Check if this is a map field + if let Some((_, key_field, value_field)) = self.is_map_field(field) { + let kt = self.base_type(key_field)?; + let vt = self.base_type(value_field)?; + let mt = &self.options.map_type; + return Ok(quote! { #mt<#kt, #vt> }); + } - let force_box = match field.typ { - DataType::Message(m) | DataType::Group(m) => self.options.force_box.contains(&m), - _ => false, - }; + let base = self.base_type(field)?; + + let force_box = field + .type_name + .as_ref() + .map(|t| self.pool.is_boxed(t)) + .unwrap_or(false); + + let is_optional = field.label == Some(FieldDescriptorProtoLabel::LABEL_OPTIONAL); + let is_required = field.label == Some(FieldDescriptorProtoLabel::LABEL_REQUIRED); + let is_proto3_optional = field.proto3_optional == Some(true); + + match (is_repeated, is_optional || is_proto3_optional, is_required, is_msg, force_box) { + // Repeated field -> Vec + (true, _, _, _, _) => Ok(quote!(Vec<#base>)), + + // Required message with boxing + (false, _, true, true, true) => Ok(quote!(Option>)), + // Required message without boxing + (false, _, true, true, false) => Ok(quote!(#base)), + // Required non-message + (false, _, true, false, _) => Ok(quote!(#base)), + + // Optional message with boxing + (false, true, _, true, true) => Ok(quote!(Option>)), + // Optional message without boxing + (false, true, _, true, false) => Ok(quote!(Option<#base>)), + // Optional non-message + (false, true, _, false, _) => Ok(quote!(Option<#base>)), + + // Singular (proto3 implicit presence) message with boxing + (false, false, false, true, true) => Ok(quote!(Option>)), + // Singular message without boxing + (false, false, false, true, false) => Ok(quote!(Option<#base>)), + // Singular non-message (proto3) + (false, false, false, false, _) => Ok(base), + } + } - match (field.frequency, is_msg, force_box) { - (Frequency::Singular | Frequency::Required, false, _) => Ok(base), + fn type_marker(&self, field: &FieldDescriptorProto) -> TokenStream { + let typ = field.r#type.as_ref().unwrap(); - (Frequency::Required, true, true) => Ok(quote!(Option>)), - (Frequency::Singular, true, true) => Ok(quote!(Option>)), + // Check for map fields + if let Some((_, key_field, value_field)) = self.is_map_field(field) { + let key_typ = key_field.r#type.as_ref().unwrap(); + let key_marker = builtin_type_marker(*key_typ); - (Frequency::Required, true, false) => Ok(quote!(#base)), - (Frequency::Singular, true, false) => Ok(quote!(Option<#base>)), + let val_typ = value_field.r#type.as_ref().unwrap(); + let val_marker = builtin_type_marker(*val_typ); - (Frequency::Optional, false, _) => Ok(quote!(Option<#base>)), + let marker_str = format!("map({}, {})", key_marker, val_marker); + return core::str::FromStr::from_str(&marker_str).unwrap(); + } - (Frequency::Optional, true, true) => Ok(quote!(Option>)), - (Frequency::Optional, true, false) => Ok(quote!(Option<#base>)), + let marker = builtin_type_marker(*typ); + core::str::FromStr::from_str(marker).unwrap() + } - (Frequency::Repeated, _, _) => Ok(quote!(Vec<#base>)), + fn field_frequency(&self, field: &FieldDescriptorProto) -> TokenStream { + let typ = field.r#type.as_ref().unwrap(); + let is_msg = matches!( + typ, + &FieldDescriptorProtoType::TYPE_MESSAGE | &FieldDescriptorProtoType::TYPE_GROUP + ); + let is_scalar = is_scalar_type(*typ); + let is_enum = *typ == FieldDescriptorProtoType::TYPE_ENUM; + let force_box = field + .type_name + .as_ref() + .map(|t| self.pool.is_boxed(t)) + .unwrap_or(false); + + // Check if packed + let is_packed = field + .options + .as_ref() + .and_then(|o| o.packed) + .unwrap_or(false); + + let is_repeated = field.label == Some(FieldDescriptorProtoLabel::LABEL_REPEATED); + let is_optional = field.label == Some(FieldDescriptorProtoLabel::LABEL_OPTIONAL); + let is_required = field.label == Some(FieldDescriptorProtoLabel::LABEL_REQUIRED); + let is_proto3_optional = field.proto3_optional == Some(true); + + // Handle map fields specially - they're singular even though they're repeated messages + if self.is_map_field(field).is_some() { + return quote! { singular }; } + + let freq = match (is_repeated, is_optional, is_required, is_msg, force_box) { + (true, _, _, _, _) if is_packed => "packed", + (true, _, _, _, _) if self.is_proto3() && (is_scalar || is_enum) => "packed", + (true, _, _, _, _) => "repeated", + (false, true, _, _, _) => "optional", + (false, _, _, true, _) if is_proto3_optional => "optional", + (false, false, false, true, _) => "optional", + (false, false, false, false, _) if is_proto3_optional => "optional", + (false, false, false, false, _) => "singular", + (false, _, true, _, true) => "optional", + (false, _, true, _, _) => "required", + }; + + core::str::FromStr::from_str(freq).unwrap() } - // pub fn protoattrs(&self, msg: &MessageDef) -> TokenStream { - // if !self.options.protoattrs.is_empty() { - // let attrs = &self.options.protoattrs; - // quote! { #[proto(#(#attrs,)*)] } - // } else { - // quote! {} - // } - // } - - pub fn file(&mut self, f: &FileDef) -> Result<()> { - for (_, en) in f.enums.iter() { - self.r#enum(f, en)?; + pub fn generate(&mut self) -> Result<()> { + // Generate enums first (they may be referenced by messages) + for enum_desc in &self.file.enum_type { + self.generate_enum(enum_desc, &[])?; } - for (i, (name, msg)) in f.messages.iter().enumerate() { - self.message(f, i as u32, name, msg)? + // Generate messages + for msg_desc in &self.file.message_type { + self.generate_message(msg_desc, &[])?; } - for (_name, svc) in f.services.iter() { - self.output.push(self.generate_server(f, svc)); - self.output.push(self.generate_client(f, svc)); + // Generate services + for svc_desc in &self.file.service { + self.generate_service(svc_desc)?; } Ok(()) } - pub fn message(&mut self, file: &FileDef, msg_idx: u32, msg_name: &ArcStr, msg: &MessageDef) -> Result<()> { - if msg.is_virtual_map { + fn generate_enum(&mut self, desc: &EnumDescriptorProto, parent_path: &[&str]) -> Result<()> { + let name = desc.name.as_deref().unwrap_or(""); + let full_name = if parent_path.is_empty() { + name.to_string() + } else { + format!("{}_{}", parent_path.join("_"), name) + }; + + let ident = format_ident!("{}", rustify_name(&full_name)); + + let open = if self.is_proto3() { + quote! { open } + } else { + quote! { closed } + }; + + let variants = desc.value.iter().map(|v| { + let var_name = v.name.as_deref().unwrap_or(""); + let var_ident = format_ident!("{}", var_name); + let num = v.number.unwrap_or(0); + quote! { + #[var(#num, #var_name)] + pub const #var_ident: #ident = #ident(#num); + } + }); + + self.output.push(quote! { + #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] + pub struct #ident(pub i32); + #[protoenum(#open)] + impl #ident { + #(#variants)* + } + }); + + Ok(()) + } + + fn generate_message(&mut self, desc: &DescriptorProto, parent_path: &[&str]) -> Result<()> { + let name = desc.name.as_deref().unwrap_or(""); + + // Skip map entry types - they're synthetic + if desc.options.as_ref().and_then(|o| o.map_entry) == Some(true) { return Ok(()); } - let mut extfields = vec![]; - for (def, ext) in &file.extenders { - if (def & LOCAL_ONLY_ID) == msg_idx { - for ext in ext { - let (ext, extfile) = self.context.ext_by_id(*ext).unwrap(); - extfields.extend(ext.fields.by_number.values().map(|f| (f, extfile.package.as_str()))) - } - } + let mut path = parent_path.to_vec(); + path.push(name); + + // Generate nested enums + for enum_desc in &desc.enum_type { + self.generate_enum(enum_desc, &path)?; } - let ident = format_ident!("{}", rustify_name(msg_name)); - // let borrow = self.borrow(); + // Generate nested messages + for nested in &desc.nested_type { + self.generate_message(nested, &path)?; + } + + let full_name = path.join("_"); + let ident = format_ident!("{}", rustify_name(&full_name)); let generics = self.options.generics.struct_def_generics(); - let name = msg.name.as_str(); - let pkg = file.package.as_str(); + let pkg = self.package(); + let attrs = quote! { #[proto(name = #name, package = #pkg)] }; - let extfields = extfields.into_iter().map(|(f, pkg)| self.field(file, f, Some(pkg))); - - let fields = msg.fields.by_number.values().map(|f| self.field(file, f, None)); + // Collect fields that are part of oneofs + let mut oneof_field_indices: HashSet = HashSet::new(); + for oneof in &desc.oneof_decl { + let oneof_name = oneof.name.as_deref().unwrap_or(""); + for field in &desc.field { + if let Some(idx) = field.oneof_index { + if desc.oneof_decl.get(idx as usize).and_then(|o| o.name.as_deref()) == Some(oneof_name) { + // Don't add proto3 optional synthetic oneofs to the set + if field.proto3_optional != Some(true) { + oneof_field_indices.insert(field.number.unwrap_or(0)); + } + } + } + } + } - let fields = fields.chain(extfields).collect::>>()?; + // Generate regular fields (excluding oneof members) + let fields = desc + .field + .iter() + .filter(|f| { + // Include if not part of a oneof, or if it's a proto3 optional synthetic oneof + if f.oneof_index.is_some() { + // proto3 optional creates a synthetic oneof - include it as a regular field + f.proto3_optional == Some(true) + || !oneof_field_indices.contains(&f.number.unwrap_or(0)) + } else { + true + } + }) + .map(|f| self.generate_field(f)) + .collect::>>()?; - let oneofs = msg - .oneofs + // Generate oneofs (excluding proto3 optional synthetic oneofs) + let oneofs = desc + .oneof_decl .iter() - .map(|(_, def)| self.oneof(msg_name, def)) + .enumerate() + .filter(|(idx, _oneof)| { + // Skip synthetic oneofs for proto3 optional + !desc.field.iter().any(|f| { + f.oneof_index == Some(*idx as i32) && f.proto3_optional == Some(true) + }) + }) + .map(|(idx, oneof)| self.generate_oneof(&full_name, idx, oneof, &desc.field)) .collect::>>()?; let last = if self.options.track_unknowns { @@ -330,7 +504,7 @@ impl CodeGenerator<'_> { None }; - self.types.push(quote! {#ident}); + self.types.push(quote! { #ident }); self.output.push(quote! { #[derive(Debug, Default, Clone, PartialEq, Proto)] #attrs @@ -344,38 +518,66 @@ impl CodeGenerator<'_> { Ok(()) } - pub fn oneof(&mut self, msg_name: &str, def: &OneOfDef) -> Result { - let field_ident = format_ident!("{}", rustify_name(&def.name)); + fn generate_field(&self, field: &FieldDescriptorProto) -> Result { + let name = field.name.as_deref().unwrap_or(""); + let fname = format_ident!("{}", rustify_name(name)); + let num = field.number.unwrap_or(0) as u32; + let typ = self.field_type(field)?; + let kind = self.type_marker(field); + let freq = self.field_frequency(field); + + Ok(quote! { + #[field(#num, #name, #kind, #freq)] + pub #fname: #typ + }) + } + + fn generate_oneof( + &mut self, + msg_name: &str, + oneof_idx: usize, + oneof: &OneofDescriptorProto, + all_fields: &[FieldDescriptorProto], + ) -> Result { + let oneof_name = oneof.name.as_deref().unwrap_or(""); + let field_ident = format_ident!("{}", rustify_name(oneof_name)); let oneof_type = format_ident!( "{}OneOf{}", rustify_name(msg_name), - def.name.as_str().to_case(Case::Pascal) + oneof_name.to_case(Case::Pascal) ); - // let borrow = self.borrow(); + let generics = self.options.generics.struct_def_generics(); let use_generics = self.options.generics.struct_use_generics(); let borrow_or_static = self .options .generics - .liftetime_arg() + .lifetime_arg() .unwrap_or_else(|| quote! { 'static }); let mut nums = vec![]; let mut names = vec![]; let mut vars = vec![]; - let mut default = None; - for (_, var) in def.fields.by_number.iter() { - let var_name = format_ident!("{}", var.name.as_str().to_case(Case::Pascal)); - let typ = self.base_type(&var.typ)?; - let num = var.num as u32; - let name = var.name.as_str(); - let kind = Self::type_marker(&var.typ); + for field in all_fields { + if field.oneof_index != Some(oneof_idx as i32) { + continue; + } + // Skip proto3 optional synthetic oneofs + if field.proto3_optional == Some(true) { + continue; + } + + let field_name = field.name.as_deref().unwrap_or(""); + let var_name = format_ident!("{}", field_name.to_case(Case::Pascal)); + let typ = self.base_type(field)?; + let num = field.number.unwrap_or(0) as u32; + let kind = self.type_marker(field); vars.push(quote! { - #[field(#num, #name, #kind, raw)] + #[field(#num, #field_name, #kind, raw)] #var_name(#typ), }); @@ -386,16 +588,15 @@ impl CodeGenerator<'_> { Self::#var_name(Default::default()) } } - }) + }); } nums.push(num); - names.push(name); + names.push(field_name); } self.output.push(quote! { #[derive(Debug, Clone, PartialEq, Proto)] - // #attrs pub enum #oneof_type #generics { #(#vars)* __Unused(::core::marker::PhantomData<& #borrow_or_static ()>), @@ -409,174 +610,110 @@ impl CodeGenerator<'_> { }) } - pub fn r#enum(&mut self, file: &FileDef, def: &EnumDef) -> Result<()> { - let ident = format_ident!("{}", rustify_name(def.name.as_str())); - let open = if file.syntax == Proto3 { - quote! { open } - } else { - quote! { closed } - }; - let variants = def.variants.by_name.iter().map(|(_, def)| { - let name = def.name.as_str(); - let var_ident = format_ident!("{}", def.name.as_str()); - let num = def.num; - quote! { - #[var(#num, #name)] - pub const #var_ident: #ident = #ident(#num); - } - }); - self.output.push(quote! { - #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] - pub struct #ident(pub i32); - #[protoenum(#open)] - impl #ident { - #(#variants)* - } - }); - + fn generate_service(&mut self, svc: &ServiceDescriptorProto) -> Result<()> { + self.output.push(self.generate_server(svc)); + self.output.push(self.generate_client(svc)); Ok(()) } - pub fn field(&self, file: &FileDef, def: &FieldDef, extpkg: Option<&str>) -> Result { - let typ = self.field_type(def)?; - let fname = format_ident!("{}", rustify_name(def.name.as_str())); - let name = match extpkg { - Some(pkg) => format!("[{}.{}]", pkg, def.name), - None => def.name.to_string(), - }; - - if let DataType::Enum(id) = def.typ { - if let Some((en, efile)) = self.context.enum_by_id(id) { - if file.syntax == Proto3 && efile.syntax == Proto2 { - panic!("Can't use proto2 enum ({}) in proto3 file", en.name); - } - } - } - - let num = def.num as u32; - - let force_box = match def.typ { - DataType::Message(m) | DataType::Group(m) => self.options.force_box.contains(&m), - _ => false, - }; - - let kind = Self::type_marker(&def.typ); - let freq = TokenStream::from_str(match def.frequency { - Frequency::Singular if def.typ.is_message() => "optional", - Frequency::Singular => "singular", - Frequency::Optional => "optional", - Frequency::Repeated if def.is_packed() => "packed", - #[cfg(feature = "descriptors")] - Frequency::Repeated - if file.syntax == Proto3 && def.typ.is_scalar() && def.options.packed != Some(false) => - { - "packed" - } - #[cfg(not(feature = "descriptors"))] - Frequency::Repeated if file.syntax == Proto3 && def.typ.is_scalar() => "packed", - Frequency::Repeated => "repeated", - Frequency::Required if force_box => "optional", - Frequency::Required => "required", - }) - .unwrap(); + fn generate_server(&self, svc: &ServiceDescriptorProto) -> TokenStream { + grpc::generate_server(self.pool, self.options, svc) + } - Ok(quote! { - #[field(#num, #name, #kind, #freq)] - pub #fname: #typ - }) + fn generate_client(&self, svc: &ServiceDescriptorProto) -> TokenStream { + grpc::generate_client(self.pool, self.options, svc) } } -pub fn generate_file(ctx: &FileSetDef, opts: &Options, name: PathBuf, file: &FileDef) -> Result<()> { - let mut generator = CodeGenerator { - context: ctx, - options: opts, - types: vec![], - output: vec![], - }; - - generator.file(file)?; - +pub fn generate_file( + pool: &DescriptorPool, + opts: &Options, + path: impl AsRef, + file_idx: usize, +) -> Result<()> { + let file = &pool.descriptor_set.file[file_idx]; + let mut generator = CodeGenerator::new(pool, file_idx, opts); + generator.generate()?; + + // Generate extension registrations let mut extensions = vec![]; - for ext_def in file.extensions.values() { - let extendee = &ext_def.in_message; - for field in ext_def.fields.by_number.values() { - let name = &field.name; - let number = field.num as u32; - let typ = match field.typ { - DataType::Builtin(t) => t as u32, - DataType::Enum(_) => 0, // ENUM - DataType::Message(_) => 11, // MESSAGE - DataType::Group(_) => 10, // GROUP - _ => continue, - }; - let repeated = field.frequency == Frequency::Repeated; - let extendee = extendee.as_str(); - let name = name.as_str(); + for ext in &file.extension { + let extendee = ext.extendee.as_deref().unwrap_or(""); + let name = ext.name.as_deref().unwrap_or(""); + let number = ext.number.unwrap_or(0) as u32; + let typ = ext.r#type.map(|t| t.0 as u32).unwrap_or(0); + let repeated = ext.label == Some(FieldDescriptorProtoLabel::LABEL_REPEATED); + + let type_name_str = ext.type_name.as_deref().map(|t| rust_type_name(pool, t)).unwrap_or_default(); + let type_name = type_name_str.as_str(); + + extensions.push(quote! { + registry.register_extension(#extendee, #number, #name, #typ, #repeated, #type_name); + }); + } - let type_name_str = match field.typ { - DataType::Message(id) | DataType::Group(id) | DataType::Enum(id) => resolve_name(ctx, id).unwrap(), - _ => "".to_string(), - }; - let type_name = type_name_str.as_str(); + // Generate imports + let imports = file.dependency.iter().map(|dep_name| { + // Find the file in the descriptor set + let dep_file = pool.descriptor_set.file.iter().find(|f| f.name.as_deref() == Some(dep_name)); - extensions.push(quote! { - registry.register_extension(#extendee, #number, #name, #typ, #repeated, #type_name); - }); - } - } + if let Some(dep_file) = dep_file { + let dep_package = dep_file.package.as_deref().unwrap_or(""); + let our_package = file.package.as_deref().unwrap_or(""); - let root = opts.import_root.clone(); - let imports = file.imports.iter().map(|imp| imp.file_idx).map(|file_idx| { - let (_, other): (_, &FileDef) = ctx.files.get_index(file_idx).unwrap(); - let _our_name = name.file_name().unwrap().to_str().unwrap(); + let their_name = if dep_name.contains('/') { + &dep_name[dep_name.rfind('/').unwrap() + 1..] + } else { + dep_name.as_str() + }; + let their_name = if their_name.contains('.') { + &their_name[..their_name.rfind('.').unwrap()] + } else { + their_name + }; - // if let Some(rep) = ctx.replacement.get(other.name.as_str()) { - // let rep = TokenStream::from_str(rep).unwrap(); - // return quote! { use #rep::*; }; - // } + // Calculate relative path + let mut our_module = our_package; + let mut that_module = dep_package; - let their_name = if other.name.contains('/') { - &other.name.as_str()[other.name.rfind('/').unwrap() + 1..] - } else { - other.name.as_str() - }; - let their_name = if their_name.contains('.') { - &their_name[..their_name.rfind('.').unwrap()] - } else { - their_name - }; - let mut our_module = file.package.as_str(); - let mut that_module = other.package.as_str(); + while !our_module.is_empty() && !that_module.is_empty() { + if our_module.chars().next() == that_module.chars().next() { + our_module = &our_module[1..]; + that_module = &that_module[1..]; + } else { + break; + } + } - while !our_module.is_empty() && !that_module.is_empty() && our_module[..1] == that_module[..1] { - our_module = &our_module[1..]; - that_module = &that_module[1..]; - } - let mut path = String::new(); - path.push_str("super::"); + let mut path = String::new(); + path.push_str("super::"); - if !our_module.is_empty() { - for _s in our_module.strip_prefix('.').unwrap_or(our_module).split('.') { - path.push_str("super::"); + if !our_module.is_empty() { + for _s in our_module.strip_prefix('.').unwrap_or(our_module).split('.') { + path.push_str("super::"); + } } - } - if !that_module.is_empty() { - for s in that_module.split('.') { - path.push_str(&rustify_name(s)); - path.push_str("::") + if !that_module.is_empty() { + for s in that_module.split('.') { + path.push_str(&rustify_name(s)); + path.push_str("::"); + } } - } - path.push_str(&rustify_name(their_name)); + path.push_str(&rustify_name(their_name)); - let import = TokenStream::from_str(&path).unwrap(); - quote! { - use #import::*; + let import: TokenStream = core::str::FromStr::from_str(&path).unwrap(); + quote! { + use #import::*; + } + } else { + quote! {} } }); + let root = opts.import_root.clone(); let output = generator.output; let types = generator.types; + let maproot = if let Some(ref root) = root { quote! { use #root as protokit; } } else { @@ -600,36 +737,32 @@ pub fn generate_file(ctx: &FileSetDef, opts: &Options, name: PathBuf, file: &Fil #(#imports)* #(#output)* }; - let output = syn::parse2(output.clone()).with_context(|| output.to_string()).unwrap(); + + let output = syn::parse2(output.clone()) + .with_context(|| output.to_string()) + .unwrap(); let output = prettyplease::unparse(&output); - println!("Creating file: {name:?}"); - create_dir_all(name.parent().unwrap()).unwrap(); + let path = path.as_ref(); + create_dir_all(path.parent().unwrap())?; let mut f = File::options() .write(true) .create(true) .truncate(true) - .open(name) - .unwrap(); + .open(path)?; f.write_all(output.as_bytes())?; f.flush()?; + Ok(()) } -// #[cfg(feature = "descriptors")] -// pub fn generate_descriptor(ctx: &TranslateCtx, name: impl AsRef) { -// let mut output = vec![]; -// ctx.def.to_descriptor().encode(&mut output).unwrap(); -// -// let mut f = make_file(name); -// -// f.write_all(&output).unwrap(); -// f.flush().unwrap(); -// } - -pub fn generate_mod<'s>(path: impl AsRef, opts: &Options, files: impl Iterator) -> Result<()> { +pub fn generate_mod<'s>( + path: impl AsRef, + opts: &Options, + files: impl Iterator, +) -> Result<()> { let root = opts.import_root.clone(); let files: Vec<_> = files .map(|v| { @@ -655,9 +788,14 @@ pub fn generate_mod<'s>(path: impl AsRef, opts: &Options, files: impl Iter } }; - create_dir_all(path.as_ref().parent().unwrap())?; + create_dir_all(path.as_ref())?; - let mut f = make_file(path.as_ref().join("mod.rs"))?; + let mod_path = path.as_ref().join("mod.rs"); + let mut f = File::options() + .write(true) + .create(true) + .truncate(true) + .open(&mod_path)?; let output = syn::parse2(output)?; let output = prettyplease::unparse(&output); @@ -666,11 +804,3 @@ pub fn generate_mod<'s>(path: impl AsRef, opts: &Options, files: impl Iter Ok(()) } - -pub fn make_file(path: impl AsRef) -> Result { - let path = path.as_ref(); - - let f = File::options().write(true).create(true).truncate(true).open(path)?; - - Ok(f) -} diff --git a/protokit_build/src/filegen/tabular.rs b/protokit_build/src/filegen/tabular.rs deleted file mode 100644 index 2562aaa..0000000 --- a/protokit_build/src/filegen/tabular.rs +++ /dev/null @@ -1,102 +0,0 @@ -use core::str::FromStr; -use anyhow::bail; -use quote::__private::TokenStream; -use quote::quote; -use protokit_desc::{DataType, FieldDef, Frequency}; -use protokit_proto::translate::TranslateCtx; -use crate::filegen::Options; - -pub fn tabular_parser(ctx: &TranslateCtx, opts: &Options, f: &FieldDef, i: usize) -> Result { - use crate::BuiltinType::*; - let str = TokenStream::from_str(&type_to_str(ctx, opts, &f.typ)?).unwrap(); - Ok(match &f.typ { - DataType::Builtin(String_ | Bytes_) if f.is_repeated() => quote!( Bytes::, #i> ), - DataType::Builtin(String_ | Bytes_) => quote!( Bytes::<#str, #i> ), - - DataType::Builtin(bt) if bt.is_varint() && bt.is_zigzag() && f.is_repeated_packed() => { - quote! { PackedSInt::<#str, #i> } - } - DataType::Builtin(bt) if bt.is_varint() && bt.is_zigzag() => { - quote! { SInt::<#str, #i> } - } - DataType::Builtin(bt) if bt.is_varint() && f.is_repeated_packed() => { - quote! { PackedVInt::<#str, #i> } - } - DataType::Builtin(bt) if bt.is_varint() => { - quote! { VInt::<#str, #i> } - } - DataType::Builtin(_bt) if f.is_repeated_packed() => { - quote! { PackedFixed::<#str, #i> } - } - DataType::Builtin(_bt) => { - quote! { Fixed::<#str, #i> } - } - DataType::Message(_m) if f.is_repeated() => { - quote! { Repeated::<#str, #i> } - } - DataType::Message(_m) => { - quote! { Nested::<#str, #i> } - } - DataType::Enum(_m) if f.is_repeated() => { - quote! { Enum::<#str, #i> } - } - DataType::Enum(_m) => { - quote! { Enum::<#str, #i> } - } - DataType::Map(m) => { - let kf = FieldDef { - name: Default::default(), - frequency: Frequency::Normal, - typ: DataType::Builtin(m.0.clone()), - num: 0, - #[cfg(feature = "descriptors")] - options: Default::default(), - }; - let kp = tabular_parser(ctx, opts, &kf, 1)?; - let _vf = FieldDef { - name: Default::default(), - frequency: Frequency::Normal, - typ: m.1.clone(), - num: 0, - #[cfg(feature = "descriptors")] - options: Default::default(), - }; - let vp = tabular_parser(ctx, opts, &kf, 1)?; - quote! { Mapped:: } - } - other => bail!("Unknown: {other:?}"), - }) -} - -// let tabular_fields = tabular_fields.values(); -// let tabular_format = if cfg!(feature = "tabular") { -// quote! { -// impl binformat::tabular::TableDecodable for #msg_name { -// fn info(&self) -> binformat::tabular::MessageInfo { -// const _INFO: &[binformat::tabular::FieldInfo] = &[ -// #(#tabular_fields,)* -// ]; -// binformat::tabular::MessageInfo::sorted(_INFO, #qualified_name) -// } -// } -// } -// } else { -// quote! {} -// }; - -pub fn gnerate_tabular_field() { - let taglen = u64::required_space(normal_tag as u64); - // TODO: Use tag len here, that's the intended usecase - let tabular = tabular_parser(ctx, opts, &field, taglen).unwrap(); - let encoded_tag = protokit_binformat::tabular::tag(normal_tag); - let field_num = field_idx as u32; - out.tabular_fields.insert(encoded_tag,quote! { - binformat::tabular::FieldInfo { - tag: #encoded_tag, - offset: binformat::tabular::offset_of!(#msg_name, #name) as isize, - parser: ::decode, - number: #field_num, - } - }); - -} \ No newline at end of file diff --git a/protokit_build/src/lib.rs b/protokit_build/src/lib.rs index c816eb6..4b64d55 100644 --- a/protokit_build/src/lib.rs +++ b/protokit_build/src/lib.rs @@ -1,69 +1,375 @@ -use core::str::FromStr; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; -use std::usize; +use std::path::{Path, PathBuf}; pub use anyhow::Result; +use anyhow::bail; use petgraph::graph::{DefaultIx, NodeIndex}; use petgraph::{Direction, Graph}; -use quote::__private::TokenStream; use quote::quote; -use crate::deps::*; - -mod deps; mod filegen; -#[cfg(all(not(feature = "protoc"), not(feature = "parser")))] -compile_error!("Either enable 'protoc' (to use system protoc) or 'parser' (to use builtin parser) feature"); - -#[cfg(all(feature = "protoc", feature = "parser"))] -compile_error!("Either disable 'protoc' or 'parser' feature"); +pub use desc::generated::google::protobuf::descriptor::*; + +/// A descriptor pool that provides lookup capabilities over a FileDescriptorSet. +/// This is a thin overlay that indexes the descriptors by their fully qualified names. +#[derive(Debug, Default)] +pub struct DescriptorPool { + /// The underlying FileDescriptorSet from protoc + pub descriptor_set: FileDescriptorSet, + /// Index from fully qualified type name to (file_index, type_path) + /// e.g. ".google.protobuf.Any" -> (file_idx, ["Any"]) + pub types: HashMap, + /// Set of types that need to be boxed due to circular references + pub boxed_types: HashSet, +} -const REMAPS: &[(&str, &str)] = &[ - // ("google/protobuf/any.proto", "root::types::any"), - // ("google/protobuf/empty.proto", "root::types::empty"), - // ("google/protobuf/timestamp.proto", "root::types::timestamp"), - // ("google/protobuf/field_mask.proto", "root::types::field_mask"), - // ("google/protobuf/descriptor.proto", "root::types::descriptor"), -]; +/// Location of a type within the descriptor set +#[derive(Debug, Clone)] +pub struct TypeLocation { + pub file_idx: usize, + pub path: Vec, + pub kind: TypeKind, +} -#[cfg(feature = "parser")] -#[derive(Default, Debug)] -pub struct ParserContext { - ctx: proto::translate::TranslateCtx, +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TypeKind { + Message, + Enum, + Service, } -#[cfg(feature = "parser")] -impl ParserContext { - pub fn include(&mut self, p: impl Into) { - self.ctx.include(p) +impl DescriptorPool { + pub fn new(descriptor_set: FileDescriptorSet) -> Self { + let mut pool = Self { + descriptor_set, + types: HashMap::new(), + boxed_types: HashSet::new(), + }; + pool.build_index(); + pool.detect_cycles(); + pool } - pub fn compile(&mut self, p: impl Into) -> Result<()> { - self.ctx.compile_file(p.into())?; - Ok(()) + + fn build_index(&mut self) { + // Collect indexing work first to avoid borrow issues + struct IndexWork { + file_idx: usize, + path: Vec, + fqn: String, + kind: TypeKind, + } + + fn collect_message_indices( + work: &mut Vec, + file_idx: usize, + package: &str, + parent_path: &[String], + msg: &DescriptorProto, + ) { + let name = msg.name.as_deref().unwrap_or(""); + let mut path = parent_path.to_vec(); + path.push(name.to_string()); + + let fqn = if parent_path.is_empty() { + format!(".{}.{}", package, name) + } else { + format!(".{}.{}.{}", package, parent_path.join("."), name) + }; + + work.push(IndexWork { + file_idx, + path: path.clone(), + fqn: fqn.clone(), + kind: TypeKind::Message, + }); + + // Index nested enums + for enum_desc in &msg.enum_type { + let enum_name = enum_desc.name.as_deref().unwrap_or(""); + let enum_fqn = format!("{}.{}", fqn, enum_name); + let mut enum_path = path.clone(); + enum_path.push(enum_name.to_string()); + work.push(IndexWork { + file_idx, + path: enum_path, + fqn: enum_fqn, + kind: TypeKind::Enum, + }); + } + + // Index nested messages + for nested_msg in &msg.nested_type { + collect_message_indices(work, file_idx, package, &path, nested_msg); + } + } + + let mut work = Vec::new(); + + for (file_idx, file) in self.descriptor_set.file.iter().enumerate() { + let package = file.package.as_deref().unwrap_or(""); + + // Index top-level enums + for enum_desc in &file.enum_type { + let name = enum_desc.name.as_deref().unwrap_or(""); + let fqn = format!(".{}.{}", package, name); + work.push(IndexWork { + file_idx, + path: vec![name.to_string()], + fqn, + kind: TypeKind::Enum, + }); + } + + // Index top-level messages (and nested types) + for msg_desc in &file.message_type { + collect_message_indices(&mut work, file_idx, package, &[], msg_desc); + } + + // Index services + for svc_desc in &file.service { + let name = svc_desc.name.as_deref().unwrap_or(""); + let fqn = format!(".{}.{}", package, name); + work.push(IndexWork { + file_idx, + path: vec![name.to_string()], + fqn, + kind: TypeKind::Service, + }); + } + } + + // Now apply all the indices + for item in work { + self.types.insert( + item.fqn, + TypeLocation { + file_idx: item.file_idx, + path: item.path, + kind: item.kind, + }, + ); + } + } + + fn detect_cycles(&mut self) { + // Build a dependency graph + let mut graph: Graph = Graph::new(); + let mut node_indices: HashMap = HashMap::new(); + let mut field_counts: HashMap = HashMap::new(); + + // Add nodes for all messages + for (fqn, loc) in &self.types { + if loc.kind == TypeKind::Message { + let idx = graph.add_node(fqn.clone()); + node_indices.insert(fqn.clone(), idx); + } + } + + // Add edges based on field references + for file in &self.descriptor_set.file { + let package = file.package.as_deref().unwrap_or(""); + for msg in &file.message_type { + self.add_message_edges(&mut graph, &node_indices, &mut field_counts, package, &[], msg); + } + } + + // Find SCCs using Tarjan's algorithm + let mut cycles: Vec>> = petgraph::algo::tarjan_scc(&graph) + .into_iter() + .filter(|scc| scc.len() > 1 || { + // Check for self-loops + scc.len() == 1 && graph.find_edge(scc[0], scc[0]).is_some() + }) + .map(|v| HashSet::from_iter(v.into_iter())) + .collect(); + + // Find nodes to box to break cycles (prefer nodes with more incoming edges) + let mut to_box = HashSet::new(); + loop { + let mut counts: HashMap = HashMap::new(); + for cycle in &cycles { + for node in cycle.iter() { + *counts.entry(*node).or_default() += 1; + } + } + + if let Some(max) = counts + .iter() + .filter_map(|a| field_counts.get(a.0).map(|fcount| (a.0, a.1, fcount))) + .max_by(|a, b| { + let ac = graph.edges_directed(*a.0, Direction::Incoming).count(); + let bc = graph.edges_directed(*b.0, Direction::Incoming).count(); + (ac * 100 + a.1 * 10 + a.2).cmp(&(bc * 100 + b.1 * 10 + b.2)) + }) + { + to_box.insert(*max.0); + cycles.retain_mut(|cycle| !cycle.contains(&max.0)); + } else { + break; + } + } + + // Convert node indices back to type names + for idx in to_box { + if let Some(name) = graph.node_weight(idx) { + self.boxed_types.insert(name.clone()); + } + } + } + + fn add_message_edges( + &self, + graph: &mut Graph, + node_indices: &HashMap, + field_counts: &mut HashMap, + package: &str, + parent_path: &[String], + msg: &DescriptorProto, + ) { + let name = msg.name.as_deref().unwrap_or(""); + let src_fqn = if parent_path.is_empty() { + format!(".{}.{}", package, name) + } else { + format!(".{}.{}.{}", package, parent_path.join("."), name) + }; + + let Some(&src_idx) = node_indices.get(&src_fqn) else { + return; + }; + + // Add edges for fields + for field in &msg.field { + if let Some(type_name) = &field.type_name { + if let Some(&dst_idx) = node_indices.get(type_name) { + graph.add_edge(src_idx, dst_idx, ()); + *field_counts.entry(src_idx).or_default() += 1; + } + } + } + + // Process nested messages + let mut path = parent_path.to_vec(); + path.push(name.to_string()); + for nested_msg in &msg.nested_type { + self.add_message_edges(graph, node_indices, field_counts, package, &path, nested_msg); + } + } + + pub fn file(&self, idx: usize) -> &FileDescriptorProto { + &self.descriptor_set.file[idx] + } + + pub fn is_boxed(&self, type_name: &str) -> bool { + self.boxed_types.contains(type_name) + } + + pub fn lookup(&self, type_name: &str) -> Option<&TypeLocation> { + self.types.get(type_name) + } + + /// Get a message descriptor by its fully qualified name + pub fn get_message(&self, fqn: &str) -> Option<&DescriptorProto> { + let loc = self.types.get(fqn)?; + if loc.kind != TypeKind::Message { + return None; + } + let file = &self.descriptor_set.file[loc.file_idx]; + self.find_message_in_file(file, &loc.path) } - pub fn finish(self) -> Result { - Ok(self.ctx.def) + + fn find_message_in_file<'a>( + &self, + file: &'a FileDescriptorProto, + path: &[String], + ) -> Option<&'a DescriptorProto> { + if path.is_empty() { + return None; + } + + let mut current: Option<&DescriptorProto> = None; + let first = &path[0]; + + for msg in &file.message_type { + if msg.name.as_deref() == Some(first) { + current = Some(msg); + break; + } + } + + for name in &path[1..] { + let msg = current?; + current = None; + for nested in &msg.nested_type { + if nested.name.as_deref() == Some(name) { + current = Some(nested); + break; + } + } + } + + current + } + + /// Get an enum descriptor by its fully qualified name + pub fn get_enum(&self, fqn: &str) -> Option<&EnumDescriptorProto> { + let loc = self.types.get(fqn)?; + if loc.kind != TypeKind::Enum { + return None; + } + let file = &self.descriptor_set.file[loc.file_idx]; + self.find_enum_in_file(file, &loc.path) + } + + fn find_enum_in_file<'a>( + &self, + file: &'a FileDescriptorProto, + path: &[String], + ) -> Option<&'a EnumDescriptorProto> { + if path.is_empty() { + return None; + } + + if path.len() == 1 { + // Top-level enum + for e in &file.enum_type { + if e.name.as_deref() == Some(&path[0]) { + return Some(e); + } + } + return None; + } + + // Nested enum - find parent message first + let msg = self.find_message_in_file(file, &path[..path.len() - 1])?; + let enum_name = &path[path.len() - 1]; + for e in &msg.enum_type { + if e.name.as_deref() == Some(enum_name) { + return Some(e); + } + } + None } } -#[cfg(feature = "protoc")] +/// Build context for generating Rust code from protobuf files #[derive(Default, Debug)] pub struct ProtocContext { pub includes: Vec, pub proto_paths: Vec, } -#[cfg(feature = "protoc")] impl ProtocContext { pub fn include(&mut self, p: impl Into) { self.includes.push(p.into()); } + pub fn compile(&mut self, p: impl Into) -> Result<()> { self.proto_paths.push(p.into()); Ok(()) } - pub fn finish(self) -> Result { + + pub fn finish(self) -> Result { let mut cmd = std::process::Command::new("protoc"); cmd.arg("--experimental_allow_proto3_optional"); @@ -76,115 +382,45 @@ impl ProtocContext { cmd.arg(format!("{}", p.display())); } - cmd.arg(format!("-o{}/descriptor.bin", std::env::var("OUT_DIR").unwrap())); + let out_dir = std::env::var("OUT_DIR").unwrap(); + cmd.arg(format!("-o{}/descriptor.bin", out_dir)); + let out = cmd.output().expect("PROTOC invocation failed"); if !out.status.success() { bail!("Protoc error: {}", String::from_utf8_lossy(&out.stderr)) } - let data = std::fs::read(Path::new(&std::env::var("OUT_DIR").unwrap()).join("descriptor.bin")).unwrap(); - let desc = binformat::decode::(data.as_slice())?; + let data = std::fs::read(Path::new(&out_dir).join("descriptor.bin"))?; + let desc = binformat::decode::(data.as_slice())?; - Ok(FileSetDef::from_descriptor(desc)) + Ok(DescriptorPool::new(desc)) } } #[must_use] #[derive(Default, Debug)] pub struct Build { - #[cfg(feature = "parser")] - pub ctx: ParserContext, - #[cfg(feature = "protoc")] pub ctx: ProtocContext, pub options: filegen::Options, pub out_dir: Option, } -fn generate(opts: &mut filegen::Options, set: &desc::FileSetDef, out_dir: PathBuf) -> Result<()> { - create_dir_all(&out_dir).unwrap(); - - let mut graph = Graph::new(); - let mut fields = HashMap::::new(); - for (fidx, file) in set.files.values().enumerate() { - for (midx, msg) in file.messages.values().enumerate() { - eprintln!("msg: {:?}", msg.name); - let src = local_to_global(fidx, LOCAL_DEFID_MSG | (midx as u32)); - // let src = resolve_name(set, src).unwrap(); - let src = graph.add_node(src); - for field in msg.fields.by_number.values() { - match &field.typ { - DataType::Unresolved(_, _) => { - panic!("Should be resolved"); - } - DataType::Message(msg) => { - // let msg = resolve_name(set, *msg).unwrap(); - let dst = graph.add_node(*msg); - graph.add_edge(src, dst, ()); - *fields.entry(src).or_default() += 1; - } - DataType::Group(msg) => { - // let msg = resolve_name(set, *msg).unwrap(); - let dst = graph.add_node(*msg); - graph.add_edge(src, dst, ()); - *fields.entry(src).or_default() += 1; - } - _ => {} - } - } - } - } - - let mut cycles: Vec>> = petgraph::algo::tarjan_scc(&graph) - .into_iter() - .map(|v| HashSet::from_iter(v.into_iter())) - .collect(); +fn generate(opts: &filegen::Options, pool: &DescriptorPool, out_dir: PathBuf) -> Result<()> { + std::fs::create_dir_all(&out_dir)?; - let mut to_remove = HashSet::new(); - - loop { - let mut counts: HashMap = HashMap::new(); - for cycle in &cycles { - for node in cycle.iter() { - *counts.entry(*node).or_default() += 1; - } - } - if let Some(max) = counts - .iter() - .filter_map(|a| fields.get(a.0).map(|fcount| (a.0, a.1, fcount))) - .max_by(|a, b| { - let ac = graph.edges_directed(*a.0, Direction::Incoming).count(); - let bc = graph.edges_directed(*b.0, Direction::Incoming).count(); - (ac * 100 + a.1 * 10 + a.2).cmp(&(bc * 100 + b.1 * 10 + b.2)) - }) - { - to_remove.insert(*max.0); - cycles.retain_mut(|cycle| !cycle.contains(&max.0)) - } else { - break; - } - } - - let nodes = to_remove - .into_iter() - .map(|item| graph.node_weight(item).cloned().unwrap()) - .collect::>(); - - // panic!("TO REMOVE: {:?}", nodes); - - opts.force_box = nodes; - - // TODO: Use package name + file name let mut generated_names = vec![]; - for (_, file) in set.files.values().enumerate() { - // if self.ctx.replacement.contains_key(file.name.as_str()) { - // continue; - // } - let path = Path::new(file.name.as_str()); - let file_name = - file.package.replace('.', "/") + "/" + path.with_extension("rs").file_name().unwrap().to_str().unwrap(); - let out_name = out_dir.join(&file_name); - generated_names.push(file_name.clone()); - filegen::generate_file(set, opts, out_name, file).unwrap(); + for (file_idx, file) in pool.descriptor_set.file.iter().enumerate() { + let file_name = file.name.as_deref().unwrap_or("unknown.proto"); + let package = file.package.as_deref().unwrap_or(""); + + let path = Path::new(file_name); + let out_name = package.replace('.', "/") + + "/" + + path.with_extension("rs").file_name().unwrap().to_str().unwrap(); + let out_path = out_dir.join(&out_name); + + generated_names.push(out_name.clone()); + filegen::generate_file(pool, opts, out_path, file_idx)?; } let dirs: Vec> = generated_names.iter().map(|v| v.split('/').collect()).collect(); @@ -202,30 +438,17 @@ fn generate(opts: &mut filegen::Options, set: &desc::FileSetDef, out_dir: PathBu } for (k, v) in &subdirs { - eprintln!("Creating module in: {:?}", out_dir.join(k)); filegen::generate_mod(out_dir.join(k), opts, v.iter().copied())?; } - // #[cfg(feature = "descriptors")] - // filegen::generate_descriptor(&self.ctx, out_dir.join("descriptor.bin")); Ok(()) } impl Build { pub fn new() -> Self { - let mut this = Self::without_replacements(); - for (from, to) in REMAPS { - this.options.replace_import(from, to); - } - this + Self::default() } - pub fn without_replacements() -> Self { - Self { - ctx: Default::default(), - ..Default::default() - } - } pub fn include(mut self, p: impl Into) -> Self { self.ctx.include(p); self @@ -249,18 +472,22 @@ impl Build { self.options.track_unknowns = t; self } + pub fn root(mut self, s: &str) -> Self { - self.options.import_root = Some(TokenStream::from_str(s).unwrap()); + self.options.import_root = Some(core::str::FromStr::from_str(s).unwrap()); self } + pub fn string_type(mut self, typ: &str) -> Self { - self.options.string_type = TokenStream::from_str(typ).unwrap(); + self.options.string_type = core::str::FromStr::from_str(typ).unwrap(); self } + pub fn bytes_type(mut self, typ: &str) -> Self { - self.options.bytes_type = TokenStream::from_str(typ).unwrap(); + self.options.bytes_type = core::str::FromStr::from_str(typ).unwrap(); self } + pub fn out_dir(mut self, p: impl Into) -> Self { self.out_dir = Some(p.into()); self @@ -272,10 +499,11 @@ impl Build { Ok(self) } - pub fn generate(mut self) -> anyhow::Result<()> { + pub fn generate(self) -> anyhow::Result<()> { let out_dir = self .out_dir .unwrap_or_else(|| PathBuf::from(std::env::var("OUT_DIR").unwrap())); - generate(&mut self.options, &self.ctx.finish()?, out_dir) + let pool = self.ctx.finish()?; + generate(&self.options, &pool, out_dir) } } diff --git a/protokit_textformat/src/lib.rs b/protokit_textformat/src/lib.rs index d04ba9f..f15f8f0 100644 --- a/protokit_textformat/src/lib.rs +++ b/protokit_textformat/src/lib.rs @@ -10,14 +10,6 @@ use thiserror::Error; mod escape; mod lex; pub mod reflect; -#[cfg(test)] -mod repro_ext; -#[cfg(test)] -mod repro_list; -#[cfg(test)] -mod repro_p2; -#[cfg(test)] -mod repro_sep; pub mod stream; use escape::unescape; @@ -623,19 +615,31 @@ pub fn merge_repeated<'buf, T: TextField<'buf> + Default>( out.last_mut().unwrap().merge_value(stream)?; match stream.cur { // End of the list - RBracket | EndOfFile if is_list => { - // In this case we must advance one past the rbracket + RBracket if is_list => { stream.advance(); return Ok(()); } - // Comma/Semi as elem separator + EndOfFile => { + if is_list { + stream.advance(); + } + return Ok(()); + } + // Comma as elem separator Comma => { + // Check if after comma is a new field (not a value) + if stream.lookahead_is_field() { + // Don't consume comma, let caller handle field separator + return Ok(()); + } stream.advance(); continue; } Semi => { if is_list { - return crate::unexpected(Comma, stream.cur, stream.buf()); + // Semicolon works as separator inside list too + stream.advance(); + continue; } else { // For non-list (top-level repeated), check if next is field or value if stream.lookahead_is_field() { @@ -646,15 +650,13 @@ pub fn merge_repeated<'buf, T: TextField<'buf> + Default>( } } } - // This was the last entry in this field, return + // For non-list, any other token means end of this repeated field _ if !is_list => { return Ok(()); } - // Implicit separator + // Implicit separator in list (whitespace between values) _ => { - if is_list { - return crate::unexpected(Comma, stream.cur, stream.buf()); - } + // In a list, allow implicit separators (whitespace) between values continue; } } diff --git a/protokit_textformat/src/repro_ext.rs b/protokit_textformat/src/repro_ext.rs deleted file mode 100644 index fd64360..0000000 --- a/protokit_textformat/src/repro_ext.rs +++ /dev/null @@ -1,46 +0,0 @@ -#[cfg(test)] -mod tests { - use crate::reflect::Registry; - use crate::stream::InputStream; - use crate::Token; - - #[test] - fn test_extension_key_parsing() { - let reg = Registry::default(); - let txt = "[some.ext.field]: 123"; - let mut stream = InputStream::new(txt, ®); - - // InputStream::new doesn't advance automatically to first token? - // Let's check constructor. It calls Lexer::new, sets cur=StartOfFile. - // next_token() or advance() needed? - // merge_field loop usually calls advance. - - // Manually drive pars_key - stream.advance(); // Cur should be LBracket - assert_eq!(stream.token(), Token::LBracket); - - // But parse_key expects to start AT the key tokens? - // parse_key implementation: - // match self.cur { LBracket => ... } - // So yes, we need to be at LBracket. - - let key = stream.parse_key().expect("Should parse extension key"); - assert_eq!(key, "some.ext.field"); - - // Should be at Colon now - assert_eq!(stream.token(), Token::Colon); - } - - #[test] - fn test_simple_key_parsing() { - let reg = Registry::default(); - let txt = "simple_field: 123"; - let mut stream = InputStream::new(txt, ®); - stream.advance(); - assert_eq!(stream.token(), Token::Ident); - - let key = stream.parse_key().expect("Should parse simple key"); - assert_eq!(key, "simple_field"); - assert_eq!(stream.token(), Token::Colon); - } -} diff --git a/protokit_textformat/src/repro_list.rs b/protokit_textformat/src/repro_list.rs deleted file mode 100644 index 8226b8d..0000000 --- a/protokit_textformat/src/repro_list.rs +++ /dev/null @@ -1,140 +0,0 @@ -use crate::decode; - -#[derive(Debug, Default, PartialEq)] -struct TestStringList { - s: Vec, -} - -impl binformat::ProtoName for TestStringList { - fn qualified_name(&self) -> &'static str { - "TestStringList" - } -} - -impl<'buf> binformat::BinProto<'buf> for TestStringList { - fn merge_field(&mut self, _tag: u32, _stream: &mut binformat::InputStream<'buf>) -> binformat::Result<()> { - Ok(()) - } - fn size(&self, _stack: &mut binformat::SizeStack) -> usize { - 0 - } - fn encode(&self, _stream: &mut binformat::OutputStream) {} -} - -impl<'buf> crate::TextProto<'buf> for TestStringList { - fn decode(&mut self, stream: &mut crate::InputStream<'buf>) -> crate::Result<()> { - while stream.token() != crate::Token::EndOfFile { - self.merge_field(stream)?; - } - Ok(()) - } - - fn merge_field(&mut self, stream: &mut crate::InputStream<'buf>) -> crate::Result<()> { - if stream.token() == crate::Token::Ident && stream.buf() == "s" { - stream.advance(); - if stream.try_consume(crate::Token::Colon) { - // ok - } - crate::merge_repeated(&mut self.s, stream)?; - } else { - // unknown field or end - if stream.token() != crate::Token::EndOfFile { - stream.advance(); - } - } - Ok(()) - } - fn encode(&self, _stream: &mut crate::OutputStream) {} -} - -#[test] -fn test_string_list_space_separated() { - let input = "s: [\"a\" \"b\"]"; - let reg = crate::reflect::Registry::default(); - let msg: TestStringList = decode(input, ®).unwrap(); - assert_eq!(msg.s, vec!["ab".to_string()]); -} - -#[test] -fn test_string_concatenation() { - let input = "s: [\"a\" \"b\", \"c\"]"; - let reg = crate::reflect::Registry::default(); - // Expect: ["ab", "c"] - let msg: TestStringList = decode(input, ®).unwrap(); - assert_eq!(msg.s, vec!["ab".to_string(), "c".to_string()]); -} - -#[derive(Debug, Default, PartialEq)] -struct TestIntList { - i: Vec, -} - -impl binformat::ProtoName for TestIntList { - fn qualified_name(&self) -> &'static str { - "TestIntList" - } -} - -impl<'buf> binformat::BinProto<'buf> for TestIntList { - fn merge_field(&mut self, _tag: u32, _stream: &mut binformat::InputStream<'buf>) -> binformat::Result<()> { - Ok(()) - } - fn size(&self, _stack: &mut binformat::SizeStack) -> usize { - 0 - } - fn encode(&self, _stream: &mut binformat::OutputStream) {} -} - -impl<'buf> crate::TextProto<'buf> for TestIntList { - fn decode(&mut self, stream: &mut crate::InputStream<'buf>) -> crate::Result<()> { - while stream.token() != crate::Token::EndOfFile { - self.merge_field(stream)?; - } - Ok(()) - } - fn merge_field(&mut self, stream: &mut crate::InputStream<'buf>) -> crate::Result<()> { - if stream.token() == crate::Token::Ident && stream.buf() == "i" { - stream.advance(); - if stream.try_consume(crate::Token::Colon) {} - crate::merge_repeated(&mut self.i, stream)?; - } else { - if stream.token() != crate::Token::EndOfFile { - stream.advance(); - } - } - Ok(()) - } - fn encode(&self, _stream: &mut crate::OutputStream) {} -} - -#[test] -fn test_int_list_space_separated() { - let input = "i: [1 2]"; - let reg = crate::reflect::Registry::default(); - let msg: TestIntList = decode(input, ®).unwrap(); - assert_eq!(msg.i, vec![1, 2]); -} - -#[test] -fn test_int_list_comma() { - let input = "i: [1, 2]"; - let reg = crate::reflect::Registry::default(); - let msg: TestIntList = decode(input, ®).unwrap(); - assert_eq!(msg.i, vec![1, 2]); -} - -#[test] -fn test_int_list_comment() { - let input = "i: [1 /* comment */ 2]"; - let reg = crate::reflect::Registry::default(); - let msg: TestIntList = decode(input, ®).unwrap(); - assert_eq!(msg.i, vec![1, 2]); -} - -#[test] -fn test_int_list_mixed() { - let input = "i: [1, 2; 3 4]"; - let reg = crate::reflect::Registry::default(); - let msg: TestIntList = decode(input, ®).unwrap(); - assert_eq!(msg.i, vec![1, 2, 3, 4]); -} diff --git a/protokit_textformat/src/repro_p2.rs b/protokit_textformat/src/repro_p2.rs deleted file mode 100644 index 0531f5c..0000000 --- a/protokit_textformat/src/repro_p2.rs +++ /dev/null @@ -1,149 +0,0 @@ -#[cfg(test)] -mod tests { - use crate::reflect::Registry; - use crate::*; - use std::fmt::Display; - use std::str::FromStr; - - // Simulate a derived Enum - #[derive(Debug, PartialEq, Clone, Copy)] - struct MyEnum(i32); - - impl Default for MyEnum { - fn default() -> Self { - Self(0) - } - } - - impl Display for MyEnum { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } - } - - impl FromStr for MyEnum { - type Err = Error; - fn from_str(_s: &str) -> Result { - Ok(Self(0)) // Dummy - } - } - - impl From for MyEnum { - fn from(v: u32) -> Self { - Self(v as i32) - } - } - impl From for u32 { - fn from(v: MyEnum) -> Self { - v.0 as u32 - } - } - - impl Enum for MyEnum {} - - impl<'buf> TextField<'buf> for MyEnum { - fn merge_value(&mut self, stream: &mut InputStream<'buf>) -> Result<()> { - // Logic copied from protokit_derive - match stream.field() { - "FOO" => { - *self = Self(0); - stream.advance(); - } - "BAR" => { - *self = Self(1); - stream.advance(); - } - name => { - if let Ok(v) = name.parse::() { - *self = Self(v); - stream.advance(); - } else { - return crate::unknown(name); - } - } - } - Ok(()) - } - - fn emit_value(&self, stream: &mut OutputStream) { - match self.0 { - 0 => stream.ident("FOO"), - 1 => stream.ident("BAR"), - v => stream.disp(&v), - } - } - } - - #[test] - fn test_enum_number_parsing() { - let reg = Registry::default(); - let txt = "1"; - let mut stream = InputStream::new(txt, ®); - stream.advance(); // Prime first token - let mut e = MyEnum::default(); - e.merge_value(&mut stream).expect("Should parse '1'"); - assert_eq!(e.0, 1); - } - - #[test] - fn test_float_zero_parsing() { - let reg = Registry::default(); - let txt = "0"; - let mut stream = InputStream::new(txt, ®); - stream.advance(); - let v = stream.f64().expect("Should parse '0' as float"); - assert_eq!(v, 0.0); - } - - #[test] - fn test_float_neg_zero_parsing() { - let reg = Registry::default(); - let txt = "-0"; - let mut stream = InputStream::new(txt, ®); - stream.advance(); - let v = stream.f64().expect("Should parse '-0' as float"); - assert_eq!(v, 0.0); - assert!(v.is_sign_negative()); - } - - // Simulate Message with enum field to test full merge_field flow - #[derive(Default)] - struct MyMsg { - e: MyEnum, - } - - impl<'buf> TextProto<'buf> for MyMsg { - fn merge_field(&mut self, stream: &mut InputStream<'buf>) -> Result<()> { - match stream.field() { - "e" => { - // merge_single - self.e.merge_text(stream) - } - _ => crate::skip_unknown(stream), - } - } - fn encode(&self, _stream: &mut OutputStream) {} - fn decode(&mut self, stream: &mut InputStream<'buf>) -> Result<()> { - stream.message_fields(self) - } - } - impl binformat::ProtoName for MyMsg { - fn qualified_name(&self) -> &'static str { - "MyMsg" - } - } - - #[test] - fn test_msg_enum_parsing() { - let reg = Registry::default(); - let txt = "e: 1"; - let mut stream = InputStream::new(txt, ®); - // message_fields advances internally? - // No, message_fields expects stream to be at start of message? - // InputStream::new sets root=true. - // stream.message_fields -> if root=true, advances. - let mut msg = MyMsg::default(); - stream.message_fields(&mut msg).expect("Should parse msg"); - assert_eq!(msg.e.0, 1); - } -} diff --git a/protokit_textformat/src/repro_sep.rs b/protokit_textformat/src/repro_sep.rs deleted file mode 100644 index 1c40a77..0000000 --- a/protokit_textformat/src/repro_sep.rs +++ /dev/null @@ -1,103 +0,0 @@ -#[cfg(test)] -mod tests { - use crate::reflect::Registry; - use crate::stream::InputStream; - use crate::*; - - #[derive(Default)] - struct MockRepeated { - vals: Vec, - bools: Vec, - } - - impl<'buf> TextProto<'buf> for MockRepeated { - fn merge_field(&mut self, stream: &mut InputStream<'buf>) -> Result<()> { - match stream.parse_key()?.as_ref() { - "vals" => merge_repeated(&mut self.vals, stream), - "bools" => merge_repeated(&mut self.bools, stream), - "1" => { - // Simulate integer key 1 mapping to vals - merge_repeated(&mut self.vals, stream) - } - _ => skip_unknown(stream), - } - } - fn encode(&self, _stream: &mut OutputStream) {} - fn decode(&mut self, stream: &mut InputStream<'buf>) -> Result<()> { - stream.message_fields(self) - } - } - - impl binformat::ProtoName for MockRepeated { - fn qualified_name(&self) -> &'static str { - "MockRepeated" - } - } - - #[test] - fn test_comma_separated_fields() { - let reg = Registry::default(); - // "field: val, field: val" - let txt = "vals: 1, vals: 2"; - let mut stream = InputStream::new(txt, ®); - let mut msg = MockRepeated::default(); - stream - .message_fields(&mut msg) - .expect("Should parse comma separated fields"); - assert_eq!(msg.vals, vec![1, 2]); - } - - #[test] - fn test_comma_separated_values() { - let reg = Registry::default(); - // "field: 1, 2" - let txt = "vals: 3, 4"; - let mut stream = InputStream::new(txt, ®); - let mut msg = MockRepeated::default(); - stream - .message_fields(&mut msg) - .expect("Should parse comma separated values"); - assert_eq!(msg.vals, vec![3, 4]); - } - - #[test] - fn test_semicolon_separated_fields() { - let reg = Registry::default(); - // "field: val; field: val" - let txt = "vals: 5; vals: 6"; - let mut stream = InputStream::new(txt, ®); - let mut msg = MockRepeated::default(); - stream - .message_fields(&mut msg) - .expect("Should parse semicolon separated fields"); - assert_eq!(msg.vals, vec![5, 6]); - } - - #[test] - fn test_integer_key() { - let reg = Registry::default(); - let txt = "1: 10"; - let mut stream = InputStream::new(txt, ®); - let mut msg = MockRepeated::default(); - stream.message_fields(&mut msg).expect("Should parse integer key"); - assert_eq!(msg.vals, vec![10]); - } - - #[test] - fn test_bool_separators() { - let reg = Registry::default(); - let txt = "bools: true, bools: false"; - let mut stream = InputStream::new(txt, ®); - let mut msg = MockRepeated::default(); - stream.message_fields(&mut msg).expect("Should parse bool fields"); - assert_eq!(msg.bools, vec![true, false]); - - let txt2 = "bools: true, false"; - let mut stream2 = InputStream::new(txt2, ®); - let mut msg2 = MockRepeated::default(); - stream2 - .message_fields(&mut msg2) - .expect("Should parse comma bool values"); - assert_eq!(msg2.bools, vec![true, false]); - } -} diff --git a/tools/conformance/Cargo.toml b/tools/conformance/Cargo.toml index e2f8a97..6b0243d 100644 --- a/tools/conformance/Cargo.toml +++ b/tools/conformance/Cargo.toml @@ -19,4 +19,4 @@ byteorder = "1.4.3" protokit = { path = "../../protokit", features = ["textformat"] } [build-dependencies] -build = { workspace = true, features = ["protoc"] } +build = { workspace = true } diff --git a/tools/gendesc/Cargo.toml b/tools/gendesc/Cargo.toml index 8d7990a..d7a9f3c 100644 --- a/tools/gendesc/Cargo.toml +++ b/tools/gendesc/Cargo.toml @@ -11,5 +11,3 @@ publish = false [dependencies.protokit_build] path = "../../protokit_build" -default-features = false -features = ["protoc"]