From 3f348837711e37c39dbff72c5ab1ced5ba0b07f1 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 7 Dec 2025 08:57:36 +0000 Subject: [PATCH 1/2] refactor: Remove custom parser, rely solely on protoc for descriptor generation This major refactoring removes the custom nom-based parser from protokit and simplifies the code generation pipeline to work directly with protoc's FileDescriptorSet output. Key changes: - Remove protokit_proto crate from the workspace - Remove parser and protoc feature flags from protokit_build - Create new DescriptorPool as a thin overlay on FileDescriptorSet - Rewrite code generation to work directly with protobuf descriptors - Simplify grpc.rs to use the new pool-based approach - Remove unused deps.rs and tabular.rs files - Update conformance and gendesc tools to use the simplified API The new architecture: 1. protoc generates FileDescriptorSet with fully resolved type names 2. DescriptorPool indexes types by FQN and detects circular references 3. CodeGenerator works directly with protobuf descriptor types This eliminates the complexity of the custom parser and AST translation layer while leveraging protoc's robust name resolution. --- Cargo.toml | 2 - protokit_build/Cargo.toml | 12 +- protokit_build/src/deps.rs | 7 - protokit_build/src/filegen/grpc.rs | 592 +++++++++--------- protokit_build/src/filegen/mod.rs | 864 +++++++++++++++----------- protokit_build/src/filegen/tabular.rs | 102 --- protokit_build/src/lib.rs | 514 ++++++++++----- tools/conformance/Cargo.toml | 2 +- tools/gendesc/Cargo.toml | 2 - 9 files changed, 1182 insertions(+), 915 deletions(-) delete mode 100644 protokit_build/src/deps.rs delete mode 100644 protokit_build/src/filegen/tabular.rs diff --git a/Cargo.toml b/Cargo.toml index 8adf109..15dc02b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,6 @@ members = [ "tools/conformance", "protokit", "protokit_build", - "protokit_proto", "protokit_grpc", "protokit_desc", "protokit_textformat", @@ -53,7 +52,6 @@ textformat = { path = "protokit_textformat", package = "protokit_textformat", ve desc = { path = "protokit_desc", package = "protokit_desc", version = "0.2.0" } grpc = { path = "protokit_grpc", package = "protokit_grpc", version = "0.2.0" } derive = { path = "protokit_derive", package = "protokit_derive", version = "0.2.0" } -proto = { path = "protokit_proto", package = "protokit_proto", version = "0.2.0", default-features = false } protokit = { path = "protokit", version = "0.2.0" } lex_core = { package = "lexical-core", version = "0.8.5" } diff --git a/protokit_build/Cargo.toml b/protokit_build/Cargo.toml index 30759f3..cd998d1 100644 --- a/protokit_build/Cargo.toml +++ b/protokit_build/Cargo.toml @@ -9,22 +9,14 @@ authors = [""] description = "Usable protocol buffers" [features] -default = ["protoc"] -protoc = ["descriptors"] -parser = ["proto"] - -descriptors = ["desc/descriptors"] +default = [] [dependencies.binformat] workspace = true [dependencies.desc] workspace = true - -[dependencies.proto] -workspace = true -default-features = false -optional = true +features = ["descriptors"] [dependencies] anyhow = { workspace = true } diff --git a/protokit_build/src/deps.rs b/protokit_build/src/deps.rs deleted file mode 100644 index d4dd7d1..0000000 --- a/protokit_build/src/deps.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub use std::fs::{create_dir_all, File}; -pub use std::io::Write; -pub use std::path::{Path, PathBuf}; - -pub use anyhow::bail; -pub use desc::arcstr::ArcStr; -pub use desc::*; diff --git a/protokit_build/src/filegen/grpc.rs b/protokit_build/src/filegen/grpc.rs index 4267633..12f696b 100644 --- a/protokit_build/src/filegen/grpc.rs +++ b/protokit_build/src/filegen/grpc.rs @@ -1,320 +1,350 @@ use convert_case::{Case, Casing}; -use desc::{FileDef, RpcDef, ServiceDef}; use quote::__private::TokenStream; use quote::{format_ident, quote}; -use crate::filegen::{rustify_name, CodeGenerator}; - -impl CodeGenerator<'_> { - pub fn generate_server(&self, file: &FileDef, svc: &ServiceDef) -> TokenStream { - let svc_qualified_raw_name = format!("{}.{}", file.package, svc.name); - - let svc_name = format_ident!("{}", rustify_name(svc.name.as_str())); - let mod_name = format_ident!("{}_server", rustify_name(svc.name.as_str())); - let server_name = format_ident!("{}Server", rustify_name(svc.name.as_str())); - - let mut trait_items = vec![]; - let mut arms = vec![]; - let mut defs = vec![]; - - for (_, rpc) in &svc.rpc { - let rpc: &RpcDef = rpc; - let rpc_struct = format_ident!("{}Svc", rpc.name.as_str()); - let method_name = format_ident!("{}", rpc.name.as_str().to_case(Case::Snake)); - let stream_name = format_ident!("{}Stream", rpc.name.as_str()); - let path = format!("/{}.{}/{}", file.package, svc.name, rpc.name); - - let raw_req_type = self - .base_type(&rpc.req_typ) - // .with_context(|| format!("{msg_name}.{field_raw_name} in {:?}", file.name)) - .expect("Resolving name"); - - let raw_res_type = self - .base_type(&rpc.res_typ) - // .with_context(|| format!("{msg_name}.{field_raw_name} in {:?}", file.name)) - .expect("Resolving name"); - - let mut rpc_kind_method = quote! { unary }; - - let mut req_type = raw_req_type.clone(); - let res_type; - let response_type; - let mut stream_def = quote! {}; - - let svc_type = match (&rpc.req_stream, &rpc.res_stream) { - (false, false) => { - req_type = quote! { super::#raw_req_type }; - res_type = quote! { super::#raw_res_type }; - response_type = quote! { Self::Response }; - quote! { UnaryService } - } - (true, false) => { - req_type = quote! { tonic::Streaming }; - res_type = quote! { super::#raw_res_type }; - response_type = quote! { Self::Response }; - rpc_kind_method = quote! { client_streaming }; - quote! { ClientStreamingService } - } - (false, true) => { - req_type = quote! { super::#req_type }; - res_type = quote! { Self::#stream_name }; - stream_def = quote! { - type ResponseStream = S::#stream_name; - }; - trait_items.push(quote! { - type #stream_name: Stream> + Send + 'static; - }); - response_type = quote! { Self::ResponseStream }; - rpc_kind_method = quote! { server_streaming }; - quote! { ServerStreamingService } - } - (true, true) => { - req_type = quote! { tonic::Streaming }; - res_type = quote! { Self::#stream_name }; - stream_def = quote! { - type ResponseStream = S::#stream_name; - }; - trait_items.push(quote! { - type #stream_name: Stream> + Send + 'static; - }); - response_type = quote! { Self::ResponseStream }; - rpc_kind_method = quote! { streaming }; - quote! { StreamingService } - } - }; - - trait_items.push(quote! { - async fn #method_name(&self, req: tonic::Request<#req_type>) -> Result, tonic::Status>; - }); - - defs.push(quote! { - struct #rpc_struct(Arc); - impl tonic::server::#svc_type for #rpc_struct { - type Response = super::#raw_res_type; - #stream_def - type Future = BoxFuture< - tonic::Response<#response_type>, - tonic::Status, - >; - - fn call(&mut self, request: tonic::Request<#req_type>) -> Self::Future { - let inner = self.0.clone(); - Box::pin(async move { inner.#method_name(request).await }) - } - } - }); - arms.push(quote! { - #path => { - let inner = self.0.clone(); - let fut = async move { - let method = #rpc_struct(inner); - let codec = protokit::grpc::TonicCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec); - let res = grpc.#rpc_kind_method(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - }); - } - quote! { - mod #mod_name { - use super::*; - use protokit::grpc::*; - #[protokit::grpc::async_trait] - pub trait #svc_name : Send + Sync + 'static { - #(#trait_items)* - } - #[derive(Debug)] - pub struct #server_name (pub Arc); - impl Clone for #server_name { - fn clone(&self) -> Self { - Self(self.0.clone()) - } - } - impl From for #server_name { - fn from(v: S) -> Self { - Self(::std::sync::Arc::new(v)) - } - } - impl From<::std::sync::Arc> for #server_name { - fn from(v: ::std::sync::Arc) -> Self { - Self(v) - } - } - - #(#defs)* +use crate::{DescriptorPool, ServiceDescriptorProto}; +use super::{rustify_name, rust_type_name, Options}; - impl Service> for #server_name - where - S: #svc_name, - B: Body + Send + 'static, - B::Error: Into> + Send + 'static, +fn base_type_from_type_name(pool: &DescriptorPool, opts: &Options, type_name: &str) -> TokenStream { + let rust_name = rust_type_name(pool, type_name); + let ident = format_ident!("{}", rust_name); + let gen = opts.generics.struct_use_generics(); + quote! { #ident #gen } +} - { - type Response = http::Response; - type Error = core::convert::Infallible; - type Future = BoxFuture; +pub fn generate_server(pool: &DescriptorPool, opts: &Options, svc: &ServiceDescriptorProto) -> TokenStream { + let svc_name_str = svc.name.as_deref().unwrap_or(""); - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: http::Request) -> Self::Future { - match req.uri().path() { - #(#arms)* - _ => Box::pin(async move { - Ok( - http::Response::builder() - .status(200) - .header("grpc-status", "12") - .header("content-type", "application/grpc") - .body(empty_body()) - .unwrap(), - ) - }) - } - } - } - impl tonic::server::NamedService for #server_name { - const NAME: &'static str = #svc_qualified_raw_name; - } + // We need to find the file this service is in to get the package + let mut package = ""; + for file in &pool.descriptor_set.file { + for file_svc in &file.service { + if file_svc.name.as_deref() == Some(svc_name_str) { + package = file.package.as_deref().unwrap_or(""); + break; } - pub use #mod_name::*; } } - pub fn generate_client(&self, file: &FileDef, svc: &ServiceDef) -> TokenStream { - let _svc_qualified_raw_name = format!("{}.{}", file.package, svc.name); + let svc_qualified_raw_name = format!("{}.{}", package, svc_name_str); - let _svc_name = format_ident!("{}", rustify_name(svc.name.as_str())); - let mod_name = format_ident!("{}_client", rustify_name(svc.name.as_str())); - let client_name = format_ident!("{}Client", rustify_name(svc.name.as_str())); + let svc_name = format_ident!("{}", rustify_name(svc_name_str)); + let mod_name = format_ident!("{}_server", rustify_name(svc_name_str)); + let server_name = format_ident!("{}Server", rustify_name(svc_name_str)); - let mut methods = vec![]; - for (_, rpc) in &svc.rpc { - let rpc: &RpcDef = rpc; - // let rpc_struct = format_ident!("{}Svc", rpc.name.as_str()); - let method_name = format_ident!("{}", rpc.name.as_str().to_case(Case::Snake)); - let _stream_name = format_ident!("{}Stream", rpc.name.as_str()); - let path = format!("/{}.{}/{}", file.package, svc.name, rpc.name); + let mut trait_items = vec![]; + let mut arms = vec![]; + let mut defs = vec![]; - let raw_req_type = self - .base_type(&rpc.req_typ) - // .with_context(|| format!("{msg_name}.{field_raw_name} in {:?}", file.name)) - .expect("Resolving name"); + for rpc in &svc.method { + let rpc_name = rpc.name.as_deref().unwrap_or(""); + let rpc_struct = format_ident!("{}Svc", rpc_name); + let method_name = format_ident!("{}", rpc_name.to_case(Case::Snake)); + let stream_name = format_ident!("{}Stream", rpc_name); + let path = format!("/{}.{}/{}", package, svc_name_str, rpc_name); - let raw_res_type = self - .base_type(&rpc.res_typ) - // .with_context(|| format!("{msg_name}.{field_raw_name} in {:?}", file.name)) - .expect("Resolving name"); + let req_type_name = rpc.input_type.as_deref().unwrap_or(""); + let res_type_name = rpc.output_type.as_deref().unwrap_or(""); - let _req_type = raw_req_type.clone(); - let _res_type = raw_res_type.clone(); + let raw_req_type = base_type_from_type_name(pool, opts, req_type_name); + let raw_res_type = base_type_from_type_name(pool, opts, res_type_name); - let mut rpc_kind_method = quote! { unary }; + let req_stream = rpc.client_streaming.unwrap_or(false); + let res_stream = rpc.server_streaming.unwrap_or(false); - let res_type = if rpc.res_stream { - quote!( tonic::Streaming ) - } else { - quote! { super::#raw_res_type } - }; + let mut rpc_kind_method = quote! { unary }; + let mut stream_def = quote! {}; - let (req_type, req_convert) = if rpc.req_stream { + let (req_type, res_type, response_type, svc_type) = match (req_stream, res_stream) { + (false, false) => { + ( + quote! { super::#raw_req_type }, + quote! { super::#raw_res_type }, + quote! { Self::Response }, + quote! { UnaryService }, + ) + } + (true, false) => { + rpc_kind_method = quote! { client_streaming }; ( - quote! { impl tonic::IntoStreamingRequest }, - quote! { into_streaming_request }, + quote! { tonic::Streaming }, + quote! { super::#raw_res_type }, + quote! { Self::Response }, + quote! { ClientStreamingService }, ) - } else { + } + (false, true) => { + stream_def = quote! { + type ResponseStream = S::#stream_name; + }; + trait_items.push(quote! { + type #stream_name: Stream> + Send + 'static; + }); + rpc_kind_method = quote! { server_streaming }; ( - quote! { impl tonic::IntoRequest }, - quote! { into_request }, + quote! { super::#raw_req_type }, + quote! { Self::#stream_name }, + quote! { Self::ResponseStream }, + quote! { ServerStreamingService }, ) - }; + } + (true, true) => { + stream_def = quote! { + type ResponseStream = S::#stream_name; + }; + trait_items.push(quote! { + type #stream_name: Stream> + Send + 'static; + }); + rpc_kind_method = quote! { streaming }; + ( + quote! { tonic::Streaming }, + quote! { Self::#stream_name }, + quote! { Self::ResponseStream }, + quote! { StreamingService }, + ) + } + }; + + trait_items.push(quote! { + async fn #method_name(&self, req: tonic::Request<#req_type>) -> Result, tonic::Status>; + }); - match (&rpc.req_stream, &rpc.res_stream) { - (false, false) => {} - (true, false) => { - rpc_kind_method = quote! { client_streaming }; + defs.push(quote! { + struct #rpc_struct(Arc); + impl tonic::server::#svc_type for #rpc_struct { + type Response = super::#raw_res_type; + #stream_def + type Future = BoxFuture< + tonic::Response<#response_type>, + tonic::Status, + >; + + fn call(&mut self, request: tonic::Request<#req_type>) -> Self::Future { + let inner = self.0.clone(); + Box::pin(async move { inner.#method_name(request).await }) } - (false, true) => { - rpc_kind_method = quote! { server_streaming }; + } + }); + arms.push(quote! { + #path => { + let inner = self.0.clone(); + let fut = async move { + let method = #rpc_struct(inner); + let codec = protokit::grpc::TonicCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec); + let res = grpc.#rpc_kind_method(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + }); + } + + quote! { + mod #mod_name { + use super::*; + use protokit::grpc::*; + #[protokit::grpc::async_trait] + pub trait #svc_name : Send + Sync + 'static { + #(#trait_items)* + } + #[derive(Debug)] + pub struct #server_name (pub Arc); + impl Clone for #server_name { + fn clone(&self) -> Self { + Self(self.0.clone()) } - (true, true) => { - rpc_kind_method = quote! { streaming }; + } + impl From for #server_name { + fn from(v: S) -> Self { + Self(::std::sync::Arc::new(v)) } - }; - methods.push(quote! { - pub async fn #method_name( - &mut self, - request: #req_type, - ) -> Result, tonic::Status> { - self.inner - .ready() - .await - .map_err(|e| { - Status::new(Code::Unknown, format!("Service was not ready: {}", e.into())) - })?; - let codec = protokit - ::grpc::TonicCodec::default(); - let path = http::uri::PathAndQuery::from_static(#path); - self.inner.#rpc_kind_method(request.#req_convert(), path, codec).await + } + impl From<::std::sync::Arc> for #server_name { + fn from(v: ::std::sync::Arc) -> Self { + Self(v) } - }) - } - quote! { - mod #mod_name { - use super::*; - use protokit::grpc::*; - #[derive(Debug, Clone)] - pub struct #client_name { - inner: tonic::client::Grpc, + } + + #(#defs)* + + impl Service> for #server_name + where + S: #svc_name, + B: Body + Send + 'static, + B::Error: Into> + Send + 'static, + + { + type Response = http::Response; + type Error = core::convert::Infallible; + type Future = BoxFuture; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } - impl #client_name { - pub async fn connect(dst: D) -> Result - where - D: core::convert::TryInto, - D::Error: Into, - { - let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; - Ok(Self::new(conn)) + + fn call(&mut self, req: http::Request) -> Self::Future { + match req.uri().path() { + #(#arms)* + _ => Box::pin(async move { + Ok( + http::Response::builder() + .status(200) + .header("grpc-status", "12") + .header("content-type", "application/grpc") + .body(empty_body()) + .unwrap(), + ) + }) } } - impl #client_name + } + impl tonic::server::NamedService for #server_name { + const NAME: &'static str = #svc_qualified_raw_name; + } + } + pub use #mod_name::*; + } +} + +pub fn generate_client(pool: &DescriptorPool, opts: &Options, svc: &ServiceDescriptorProto) -> TokenStream { + let svc_name_str = svc.name.as_deref().unwrap_or(""); + + // We need to find the file this service is in to get the package + let mut package = ""; + for file in &pool.descriptor_set.file { + for file_svc in &file.service { + if file_svc.name.as_deref() == Some(svc_name_str) { + package = file.package.as_deref().unwrap_or(""); + break; + } + } + } + + let _svc_qualified_raw_name = format!("{}.{}", package, svc_name_str); + + let _svc_name = format_ident!("{}", rustify_name(svc_name_str)); + let mod_name = format_ident!("{}_client", rustify_name(svc_name_str)); + let client_name = format_ident!("{}Client", rustify_name(svc_name_str)); + + let mut methods = vec![]; + for rpc in &svc.method { + let rpc_name = rpc.name.as_deref().unwrap_or(""); + let method_name = format_ident!("{}", rpc_name.to_case(Case::Snake)); + let _stream_name = format_ident!("{}Stream", rpc_name); + let path = format!("/{}.{}/{}", package, svc_name_str, rpc_name); + + let req_type_name = rpc.input_type.as_deref().unwrap_or(""); + let res_type_name = rpc.output_type.as_deref().unwrap_or(""); + + let raw_req_type = base_type_from_type_name(pool, opts, req_type_name); + let raw_res_type = base_type_from_type_name(pool, opts, res_type_name); + + let req_stream = rpc.client_streaming.unwrap_or(false); + let res_stream = rpc.server_streaming.unwrap_or(false); + + let mut rpc_kind_method = quote! { unary }; + + let res_type = if res_stream { + quote!( tonic::Streaming ) + } else { + quote! { super::#raw_res_type } + }; + + let (req_type, req_convert) = if req_stream { + ( + quote! { impl tonic::IntoStreamingRequest }, + quote! { into_streaming_request }, + ) + } else { + ( + quote! { impl tonic::IntoRequest }, + quote! { into_request }, + ) + }; + + match (req_stream, res_stream) { + (false, false) => {} + (true, false) => { + rpc_kind_method = quote! { client_streaming }; + } + (false, true) => { + rpc_kind_method = quote! { server_streaming }; + } + (true, true) => { + rpc_kind_method = quote! { streaming }; + } + }; + methods.push(quote! { + pub async fn #method_name( + &mut self, + request: #req_type, + ) -> Result, tonic::Status> { + self.inner + .ready() + .await + .map_err(|e| { + Status::new(Code::Unknown, format!("Service was not ready: {}", e.into())) + })?; + let codec = protokit + ::grpc::TonicCodec::default(); + let path = http::uri::PathAndQuery::from_static(#path); + self.inner.#rpc_kind_method(request.#req_convert(), path, codec).await + } + }) + } + quote! { + mod #mod_name { + use super::*; + use protokit::grpc::*; + #[derive(Debug, Clone)] + pub struct #client_name { + inner: tonic::client::Grpc, + } + impl #client_name { + pub async fn connect(dst: D) -> Result where - S: tonic::client::GrpcService, - S::Error: Into, - S::ResponseBody: Body + Send + 'static, - ::Error: Into + Send, + D: core::convert::TryInto, + D::Error: Into, { - pub fn new(inner: S) -> Self { - let inner = tonic::client::Grpc::new(inner); - Self { inner } - } - pub fn with_interceptor( - inner: S, - interceptor: F, - ) -> #client_name> - where - F: tonic::service::Interceptor, - S::ResponseBody: Default, - S: tonic::codegen::Service< - http::Request, - Response = http::Response< - >::ResponseBody, - >, + let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; + Ok(Self::new(conn)) + } + } + impl #client_name + where + S: tonic::client::GrpcService, + S::Error: Into, + S::ResponseBody: Body + Send + 'static, + ::Error: Into + Send, + { + pub fn new(inner: S) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_interceptor( + inner: S, + interceptor: F, + ) -> #client_name> + where + F: tonic::service::Interceptor, + S::ResponseBody: Default, + S: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, >, - , - >>::Error: Into + Send + Sync, - { - #client_name::new(InterceptedService::new(inner, interceptor)) - } - #(#methods)* + >, + , + >>::Error: Into + Send + Sync, + { + #client_name::new(InterceptedService::new(inner, interceptor)) } - + #(#methods)* } - pub use #mod_name::*; + } + pub use #mod_name::*; } } diff --git a/protokit_build/src/filegen/mod.rs b/protokit_build/src/filegen/mod.rs index e6b793c..d3980ca 100644 --- a/protokit_build/src/filegen/mod.rs +++ b/protokit_build/src/filegen/mod.rs @@ -1,14 +1,18 @@ -use core::ops::Deref; -use core::str::FromStr; -use std::collections::{BTreeMap, HashSet}; +use std::collections::HashSet; +use std::fs::{create_dir_all, File}; +use std::io::Write; +use std::path::Path; use anyhow::{Context, Result}; use convert_case::{Case, Casing}; -use desc::Syntax::{Proto2, Proto3}; use proc_macro2::TokenStream; use quote::{format_ident, quote}; -use crate::deps::*; +use crate::{ + DescriptorPool, DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, + FieldDescriptorProtoLabel, FieldDescriptorProtoType, FileDescriptorProto, + OneofDescriptorProto, ServiceDescriptorProto, +}; pub mod grpc; @@ -19,9 +23,10 @@ pub struct Generics { } impl Generics { - fn liftetime_arg(&self) -> Option { + fn lifetime_arg(&self) -> Option { self.buf_arg.clone() } + fn struct_def_generics(&self) -> TokenStream { match (&self.buf_arg, &self.alloc_arg) { (Some(b), Some(a)) => quote! { < #b, #a : std::alloc::Allocator + Debug> }, @@ -43,7 +48,6 @@ impl Generics { #[derive(Debug)] pub struct Options { - pub replacement: BTreeMap, pub import_root: Option, pub string_type: TokenStream, @@ -54,22 +58,12 @@ pub struct Options { pub generics: Generics, pub protoattrs: Vec, - pub force_box: HashSet, - pub track_unknowns: bool, } -impl Options { - pub fn replace_import(&mut self, from: &str, to: &str) -> &mut Self { - self.replacement.insert(from.to_string(), to.to_string()); - self - } -} - impl Default for Options { fn default() -> Self { Self { - replacement: Default::default(), import_root: None, string_type: quote! { String }, bytes_type: quote! { Vec }, @@ -77,33 +71,27 @@ impl Default for Options { unknown_type: quote! { ::protokit::binformat::UnknownFieldsOwned }, generics: Generics::default(), protoattrs: vec![], - force_box: Default::default(), track_unknowns: false, } } } const STRICT: &[&str] = &[ - "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum", "extern", "false", "fn", - "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref", "return", "self", "Self", - "static", "struct", "super", "trait", "true", "type", "unsafe", "use", "where", "while", + "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum", "extern", + "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", + "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true", "type", + "unsafe", "use", "where", "while", ]; const RESERVED: &[&str] = &[ - "abstract", "become", "box", "do", "final", "macro", "override", "priv", "try", "typeof", "unsized", "virtual", - "yield", + "abstract", "become", "box", "do", "final", "macro", "override", "priv", "try", "typeof", + "unsized", "virtual", "yield", ]; pub const TYPES: &[&str] = &["Option", "Result"]; pub fn rustify_name(n: impl AsRef) -> String { let n = n.as_ref().replace('.', "_"); - // let pos = n.find('.'); - // let n = if let Some(pos) = pos { - // &n[pos + 1..] - // } else { - // n - // }; for s in STRICT.iter().chain(RESERVED) { if *s == n { return format!("r#{n}"); @@ -114,210 +102,396 @@ pub fn rustify_name(n: impl AsRef) -> String { return format!("Proto{n}"); } } - n.replace('.', "") } -pub fn builtin_type_marker(typ: BuiltinType) -> &'static str { +/// Extract the simple name from a fully qualified type name +/// e.g. ".google.protobuf.Any" -> "Any" +fn simple_name(fqn: &str) -> &str { + fqn.rsplit('.').next().unwrap_or(fqn) +} + +/// Get the Rust type name from a fully qualified protobuf type name +fn rust_type_name(pool: &DescriptorPool, type_name: &str) -> String { + if let Some(loc) = pool.lookup(type_name) { + // Use the path to construct nested type names + rustify_name(loc.path.join("_")) + } else { + // Fallback to simple name + rustify_name(simple_name(type_name)) + } +} + +fn builtin_type_marker(typ: FieldDescriptorProtoType) -> &'static str { match typ { - BuiltinType::Int32 => "varint", - BuiltinType::Int64 => "varint", - BuiltinType::Uint32 => "varint", - BuiltinType::Uint64 => "varint", - BuiltinType::Sint32 => "sigint", - BuiltinType::Sint64 => "sigint", - BuiltinType::Bool => "bool", - BuiltinType::Fixed64 => "fixed64", - BuiltinType::Sfixed64 => "fixed64", - BuiltinType::Fixed32 => "fixed32", - BuiltinType::Sfixed32 => "fixed32", - BuiltinType::Double => "fixed64", - BuiltinType::Float => "fixed32", - BuiltinType::String_ => "string", - BuiltinType::Bytes_ => "bytes", + FieldDescriptorProtoType::TYPE_INT32 => "varint", + FieldDescriptorProtoType::TYPE_INT64 => "varint", + FieldDescriptorProtoType::TYPE_UINT32 => "varint", + FieldDescriptorProtoType::TYPE_UINT64 => "varint", + FieldDescriptorProtoType::TYPE_SINT32 => "sigint", + FieldDescriptorProtoType::TYPE_SINT64 => "sigint", + FieldDescriptorProtoType::TYPE_BOOL => "bool", + FieldDescriptorProtoType::TYPE_FIXED64 => "fixed64", + FieldDescriptorProtoType::TYPE_SFIXED64 => "fixed64", + FieldDescriptorProtoType::TYPE_FIXED32 => "fixed32", + FieldDescriptorProtoType::TYPE_SFIXED32 => "fixed32", + FieldDescriptorProtoType::TYPE_DOUBLE => "fixed64", + FieldDescriptorProtoType::TYPE_FLOAT => "fixed32", + FieldDescriptorProtoType::TYPE_STRING => "string", + FieldDescriptorProtoType::TYPE_BYTES => "bytes", + FieldDescriptorProtoType::TYPE_MESSAGE => "nested", + FieldDescriptorProtoType::TYPE_GROUP => "group", + FieldDescriptorProtoType::TYPE_ENUM => "protoenum", + _ => panic!("Unknown field type"), } } +fn builtin_rust_type(opts: &Options, typ: FieldDescriptorProtoType) -> TokenStream { + match typ { + FieldDescriptorProtoType::TYPE_INT32 => quote! { i32 }, + FieldDescriptorProtoType::TYPE_INT64 => quote! { i64 }, + FieldDescriptorProtoType::TYPE_UINT32 => quote! { u32 }, + FieldDescriptorProtoType::TYPE_UINT64 => quote! { u64 }, + FieldDescriptorProtoType::TYPE_SINT32 => quote! { i32 }, + FieldDescriptorProtoType::TYPE_SINT64 => quote! { i64 }, + FieldDescriptorProtoType::TYPE_BOOL => quote! { bool }, + FieldDescriptorProtoType::TYPE_FIXED64 => quote! { u64 }, + FieldDescriptorProtoType::TYPE_SFIXED64 => quote! { i64 }, + FieldDescriptorProtoType::TYPE_FIXED32 => quote! { u32 }, + FieldDescriptorProtoType::TYPE_SFIXED32 => quote! { i32 }, + FieldDescriptorProtoType::TYPE_DOUBLE => quote! { f64 }, + FieldDescriptorProtoType::TYPE_FLOAT => quote! { f32 }, + FieldDescriptorProtoType::TYPE_STRING => opts.string_type.clone(), + FieldDescriptorProtoType::TYPE_BYTES => opts.bytes_type.clone(), + _ => panic!("Not a builtin type"), + } +} + +fn is_scalar_type(typ: FieldDescriptorProtoType) -> bool { + !matches!( + typ, + FieldDescriptorProtoType::TYPE_STRING + | FieldDescriptorProtoType::TYPE_BYTES + | FieldDescriptorProtoType::TYPE_MESSAGE + | FieldDescriptorProtoType::TYPE_GROUP + ) +} + pub struct CodeGenerator<'a> { - context: &'a FileSetDef, + pool: &'a DescriptorPool, + file: &'a FileDescriptorProto, options: &'a Options, types: Vec, output: Vec, } -pub fn resolve_name(set: &FileSetDef, id: GlobalDefId) -> Result { - if let Some((msg, _)) = set.message_by_id(id) { - return Ok(rustify_name(msg.name.as_str())); - } else if let Some((en, _)) = set.enum_by_id(id) { - return Ok(rustify_name(en.name.as_str())); - } else { - bail!( - "Could not resolve {} {} {:b} files:{}: {:#?}", - id, - id >> 32, - id & 0xFFFFFFFF, - set.files.len(), - set.files.keys() - ); +impl<'a> CodeGenerator<'a> { + pub fn new(pool: &'a DescriptorPool, file_idx: usize, options: &'a Options) -> Self { + Self { + pool, + file: &pool.descriptor_set.file[file_idx], + options, + types: vec![], + output: vec![], + } } -} -impl CodeGenerator<'_> { - pub fn resolve_name(&self, id: GlobalDefId) -> Result { - resolve_name(self.context, id) + fn is_proto3(&self) -> bool { + self.file.syntax.as_deref() == Some("proto3") } - pub fn builtin_rusttype(&self, typ: BuiltinType) -> TokenStream { - TokenStream::from_str(match typ { - BuiltinType::Int32 => "i32", - BuiltinType::Int64 => "i64", - BuiltinType::Uint32 => "u32", - BuiltinType::Uint64 => "u64", - BuiltinType::Sint32 => "i32", - BuiltinType::Sint64 => "i64", - BuiltinType::Bool => "bool", - BuiltinType::Fixed64 => "u64", - BuiltinType::Sfixed64 => "i64", - BuiltinType::Fixed32 => "u32", - BuiltinType::Sfixed32 => "i32", - BuiltinType::Double => "f64", - BuiltinType::Float => "f32", - BuiltinType::String_ => return self.options.string_type.clone(), - BuiltinType::Bytes_ => return self.options.bytes_type.clone(), - }) - .unwrap() + + fn package(&self) -> &str { + self.file.package.as_deref().unwrap_or("") } - pub fn type_marker(typ: &DataType) -> TokenStream { - TokenStream::from_str(match typ { - DataType::Unresolved(_, _) => panic!(), - DataType::Builtin(b) => builtin_type_marker(*b), - DataType::Message(_) => "nested", - DataType::Group(_) => "group", - DataType::Enum(_) => "protoenum", - - DataType::Map(k) => { - eprintln!("{:?} to {:?}", typ, k); - return TokenStream::from_str(&format!( - "map({}, {})", - builtin_type_marker(k.0), - Self::type_marker(&k.1) - )) - .unwrap(); - } - }) - .unwrap() + fn base_type(&self, field: &FieldDescriptorProto) -> Result { + let typ = field.r#type.as_ref().unwrap(); + + if let Some(type_name) = &field.type_name { + // Message, Enum, or Group type + let rust_name = rust_type_name(self.pool, type_name); + let ident = format_ident!("{}", rust_name); + let gen = self.options.generics.struct_use_generics(); + Ok(quote! { #ident #gen }) + } else { + // Builtin type + Ok(builtin_rust_type(self.options, *typ)) + } } - pub fn base_type(&self, typ: &DataType) -> Result { - Ok(match typ { - DataType::Unresolved(path, _) => { - panic!("Name {path} was not resolved to actual type") - } - DataType::Builtin(bt) => return Ok(self.builtin_rusttype(*bt)), - DataType::Message(id) | DataType::Group(id) => { - // let borrow = self.borrow(); - let gen = self.options.generics.struct_use_generics(); - let ident = format_ident!("{}", self.resolve_name(*id)?); + fn is_map_field(&self, field: &FieldDescriptorProto) -> Option<(&DescriptorProto, &FieldDescriptorProto, &FieldDescriptorProto)> { + if field.label != Some(FieldDescriptorProtoLabel::LABEL_REPEATED) { + return None; + } - quote! {#ident #gen} - } - DataType::Enum(id) => TokenStream::from_str(&self.resolve_name(*id)?).unwrap(), - DataType::Map(m) => { - let kt = self.base_type(&DataType::Builtin(m.deref().0))?; - let vt = self.base_type(&m.deref().1)?; - let mt = &self.options.map_type; - return Ok(quote! { #mt<#kt,#vt> }); - } - }) + let type_name = field.type_name.as_ref()?; + let msg = self.pool.get_message(type_name)?; + + // Check if it's a map entry + if msg.options.as_ref().and_then(|o| o.map_entry) != Some(true) { + return None; + } + + // Get key and value fields + let key_field = msg.field.iter().find(|f| f.number == Some(1))?; + let value_field = msg.field.iter().find(|f| f.number == Some(2))?; + + Some((msg, key_field, value_field)) } - pub fn field_type(&self, field: &FieldDef) -> Result { - let base = self.base_type(&field.typ)?; - let is_msg = matches!(field.typ, DataType::Message(..) | DataType::Group(_)); + fn field_type(&self, field: &FieldDescriptorProto) -> Result { + let typ = field.r#type.as_ref().unwrap(); + let is_msg = matches!( + typ, + &FieldDescriptorProtoType::TYPE_MESSAGE | &FieldDescriptorProtoType::TYPE_GROUP + ); + let is_repeated = field.label == Some(FieldDescriptorProtoLabel::LABEL_REPEATED); + + // Check if this is a map field + if let Some((_, key_field, value_field)) = self.is_map_field(field) { + let kt = self.base_type(key_field)?; + let vt = self.base_type(value_field)?; + let mt = &self.options.map_type; + return Ok(quote! { #mt<#kt, #vt> }); + } - let force_box = match field.typ { - DataType::Message(m) | DataType::Group(m) => self.options.force_box.contains(&m), - _ => false, - }; + let base = self.base_type(field)?; + + let force_box = field + .type_name + .as_ref() + .map(|t| self.pool.is_boxed(t)) + .unwrap_or(false); + + let is_optional = field.label == Some(FieldDescriptorProtoLabel::LABEL_OPTIONAL); + let is_required = field.label == Some(FieldDescriptorProtoLabel::LABEL_REQUIRED); + let is_proto3_optional = field.proto3_optional == Some(true); + + match (is_repeated, is_optional || is_proto3_optional, is_required, is_msg, force_box) { + // Repeated field -> Vec + (true, _, _, _, _) => Ok(quote!(Vec<#base>)), + + // Required message with boxing + (false, _, true, true, true) => Ok(quote!(Option>)), + // Required message without boxing + (false, _, true, true, false) => Ok(quote!(#base)), + // Required non-message + (false, _, true, false, _) => Ok(quote!(#base)), + + // Optional message with boxing + (false, true, _, true, true) => Ok(quote!(Option>)), + // Optional message without boxing + (false, true, _, true, false) => Ok(quote!(Option<#base>)), + // Optional non-message + (false, true, _, false, _) => Ok(quote!(Option<#base>)), + + // Singular (proto3 implicit presence) message with boxing + (false, false, false, true, true) => Ok(quote!(Option>)), + // Singular message without boxing + (false, false, false, true, false) => Ok(quote!(Option<#base>)), + // Singular non-message (proto3) + (false, false, false, false, _) => Ok(base), + } + } - match (field.frequency, is_msg, force_box) { - (Frequency::Singular | Frequency::Required, false, _) => Ok(base), + fn type_marker(&self, field: &FieldDescriptorProto) -> TokenStream { + let typ = field.r#type.as_ref().unwrap(); - (Frequency::Required, true, true) => Ok(quote!(Option>)), - (Frequency::Singular, true, true) => Ok(quote!(Option>)), + // Check for map fields + if let Some((_, key_field, value_field)) = self.is_map_field(field) { + let key_typ = key_field.r#type.as_ref().unwrap(); + let key_marker = builtin_type_marker(*key_typ); - (Frequency::Required, true, false) => Ok(quote!(#base)), - (Frequency::Singular, true, false) => Ok(quote!(Option<#base>)), + let val_typ = value_field.r#type.as_ref().unwrap(); + let val_marker = builtin_type_marker(*val_typ); - (Frequency::Optional, false, _) => Ok(quote!(Option<#base>)), + let marker_str = format!("map({}, {})", key_marker, val_marker); + return core::str::FromStr::from_str(&marker_str).unwrap(); + } - (Frequency::Optional, true, true) => Ok(quote!(Option>)), - (Frequency::Optional, true, false) => Ok(quote!(Option<#base>)), + let marker = builtin_type_marker(*typ); + core::str::FromStr::from_str(marker).unwrap() + } - (Frequency::Repeated, _, _) => Ok(quote!(Vec<#base>)), + fn field_frequency(&self, field: &FieldDescriptorProto) -> TokenStream { + let typ = field.r#type.as_ref().unwrap(); + let is_msg = matches!( + typ, + &FieldDescriptorProtoType::TYPE_MESSAGE | &FieldDescriptorProtoType::TYPE_GROUP + ); + let is_scalar = is_scalar_type(*typ); + let is_enum = *typ == FieldDescriptorProtoType::TYPE_ENUM; + let force_box = field + .type_name + .as_ref() + .map(|t| self.pool.is_boxed(t)) + .unwrap_or(false); + + // Check if packed + let is_packed = field + .options + .as_ref() + .and_then(|o| o.packed) + .unwrap_or(false); + + let is_repeated = field.label == Some(FieldDescriptorProtoLabel::LABEL_REPEATED); + let is_optional = field.label == Some(FieldDescriptorProtoLabel::LABEL_OPTIONAL); + let is_required = field.label == Some(FieldDescriptorProtoLabel::LABEL_REQUIRED); + let is_proto3_optional = field.proto3_optional == Some(true); + + // Handle map fields specially - they're singular even though they're repeated messages + if self.is_map_field(field).is_some() { + return quote! { singular }; } + + let freq = match (is_repeated, is_optional, is_required, is_msg, force_box) { + (true, _, _, _, _) if is_packed => "packed", + (true, _, _, _, _) if self.is_proto3() && (is_scalar || is_enum) => "packed", + (true, _, _, _, _) => "repeated", + (false, true, _, _, _) => "optional", + (false, _, _, true, _) if is_proto3_optional => "optional", + (false, false, false, true, _) => "optional", + (false, false, false, false, _) if is_proto3_optional => "optional", + (false, false, false, false, _) => "singular", + (false, _, true, _, true) => "optional", + (false, _, true, _, _) => "required", + }; + + core::str::FromStr::from_str(freq).unwrap() } - // pub fn protoattrs(&self, msg: &MessageDef) -> TokenStream { - // if !self.options.protoattrs.is_empty() { - // let attrs = &self.options.protoattrs; - // quote! { #[proto(#(#attrs,)*)] } - // } else { - // quote! {} - // } - // } - - pub fn file(&mut self, f: &FileDef) -> Result<()> { - for (_, en) in f.enums.iter() { - self.r#enum(f, en)?; + pub fn generate(&mut self) -> Result<()> { + // Generate enums first (they may be referenced by messages) + for enum_desc in &self.file.enum_type { + self.generate_enum(enum_desc, &[])?; } - for (i, (name, msg)) in f.messages.iter().enumerate() { - self.message(f, i as u32, name, msg)? + // Generate messages + for msg_desc in &self.file.message_type { + self.generate_message(msg_desc, &[])?; } - for (_name, svc) in f.services.iter() { - self.output.push(self.generate_server(f, svc)); - self.output.push(self.generate_client(f, svc)); + // Generate services + for svc_desc in &self.file.service { + self.generate_service(svc_desc)?; } Ok(()) } - pub fn message(&mut self, file: &FileDef, msg_idx: u32, msg_name: &ArcStr, msg: &MessageDef) -> Result<()> { - if msg.is_virtual_map { + fn generate_enum(&mut self, desc: &EnumDescriptorProto, parent_path: &[&str]) -> Result<()> { + let name = desc.name.as_deref().unwrap_or(""); + let full_name = if parent_path.is_empty() { + name.to_string() + } else { + format!("{}_{}", parent_path.join("_"), name) + }; + + let ident = format_ident!("{}", rustify_name(&full_name)); + + let open = if self.is_proto3() { + quote! { open } + } else { + quote! { closed } + }; + + let variants = desc.value.iter().map(|v| { + let var_name = v.name.as_deref().unwrap_or(""); + let var_ident = format_ident!("{}", var_name); + let num = v.number.unwrap_or(0); + quote! { + #[var(#num, #var_name)] + pub const #var_ident: #ident = #ident(#num); + } + }); + + self.output.push(quote! { + #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] + pub struct #ident(pub i32); + #[protoenum(#open)] + impl #ident { + #(#variants)* + } + }); + + Ok(()) + } + + fn generate_message(&mut self, desc: &DescriptorProto, parent_path: &[&str]) -> Result<()> { + let name = desc.name.as_deref().unwrap_or(""); + + // Skip map entry types - they're synthetic + if desc.options.as_ref().and_then(|o| o.map_entry) == Some(true) { return Ok(()); } - let mut extfields = vec![]; - for (def, ext) in &file.extenders { - if (def & LOCAL_ONLY_ID) == msg_idx { - for ext in ext { - let (ext, extfile) = self.context.ext_by_id(*ext).unwrap(); - extfields.extend(ext.fields.by_number.values().map(|f| (f, extfile.package.as_str()))) - } - } + let mut path = parent_path.to_vec(); + path.push(name); + + // Generate nested enums + for enum_desc in &desc.enum_type { + self.generate_enum(enum_desc, &path)?; } - let ident = format_ident!("{}", rustify_name(msg_name)); - // let borrow = self.borrow(); + // Generate nested messages + for nested in &desc.nested_type { + self.generate_message(nested, &path)?; + } + + let full_name = path.join("_"); + let ident = format_ident!("{}", rustify_name(&full_name)); let generics = self.options.generics.struct_def_generics(); - let name = msg.name.as_str(); - let pkg = file.package.as_str(); + let pkg = self.package(); + let attrs = quote! { #[proto(name = #name, package = #pkg)] }; - let extfields = extfields.into_iter().map(|(f, pkg)| self.field(file, f, Some(pkg))); - - let fields = msg.fields.by_number.values().map(|f| self.field(file, f, None)); + // Collect fields that are part of oneofs + let mut oneof_field_indices: HashSet = HashSet::new(); + for oneof in &desc.oneof_decl { + let oneof_name = oneof.name.as_deref().unwrap_or(""); + for field in &desc.field { + if let Some(idx) = field.oneof_index { + if desc.oneof_decl.get(idx as usize).and_then(|o| o.name.as_deref()) == Some(oneof_name) { + // Don't add proto3 optional synthetic oneofs to the set + if field.proto3_optional != Some(true) { + oneof_field_indices.insert(field.number.unwrap_or(0)); + } + } + } + } + } - let fields = fields.chain(extfields).collect::>>()?; + // Generate regular fields (excluding oneof members) + let fields = desc + .field + .iter() + .filter(|f| { + // Include if not part of a oneof, or if it's a proto3 optional synthetic oneof + if f.oneof_index.is_some() { + // proto3 optional creates a synthetic oneof - include it as a regular field + f.proto3_optional == Some(true) + || !oneof_field_indices.contains(&f.number.unwrap_or(0)) + } else { + true + } + }) + .map(|f| self.generate_field(f)) + .collect::>>()?; - let oneofs = msg - .oneofs + // Generate oneofs (excluding proto3 optional synthetic oneofs) + let oneofs = desc + .oneof_decl .iter() - .map(|(_, def)| self.oneof(msg_name, def)) + .enumerate() + .filter(|(idx, _oneof)| { + // Skip synthetic oneofs for proto3 optional + !desc.field.iter().any(|f| { + f.oneof_index == Some(*idx as i32) && f.proto3_optional == Some(true) + }) + }) + .map(|(idx, oneof)| self.generate_oneof(&full_name, idx, oneof, &desc.field)) .collect::>>()?; let last = if self.options.track_unknowns { @@ -330,7 +504,7 @@ impl CodeGenerator<'_> { None }; - self.types.push(quote! {#ident}); + self.types.push(quote! { #ident }); self.output.push(quote! { #[derive(Debug, Default, Clone, PartialEq, Proto)] #attrs @@ -344,38 +518,66 @@ impl CodeGenerator<'_> { Ok(()) } - pub fn oneof(&mut self, msg_name: &str, def: &OneOfDef) -> Result { - let field_ident = format_ident!("{}", rustify_name(&def.name)); + fn generate_field(&self, field: &FieldDescriptorProto) -> Result { + let name = field.name.as_deref().unwrap_or(""); + let fname = format_ident!("{}", rustify_name(name)); + let num = field.number.unwrap_or(0) as u32; + let typ = self.field_type(field)?; + let kind = self.type_marker(field); + let freq = self.field_frequency(field); + + Ok(quote! { + #[field(#num, #name, #kind, #freq)] + pub #fname: #typ + }) + } + + fn generate_oneof( + &mut self, + msg_name: &str, + oneof_idx: usize, + oneof: &OneofDescriptorProto, + all_fields: &[FieldDescriptorProto], + ) -> Result { + let oneof_name = oneof.name.as_deref().unwrap_or(""); + let field_ident = format_ident!("{}", rustify_name(oneof_name)); let oneof_type = format_ident!( "{}OneOf{}", rustify_name(msg_name), - def.name.as_str().to_case(Case::Pascal) + oneof_name.to_case(Case::Pascal) ); - // let borrow = self.borrow(); + let generics = self.options.generics.struct_def_generics(); let use_generics = self.options.generics.struct_use_generics(); let borrow_or_static = self .options .generics - .liftetime_arg() + .lifetime_arg() .unwrap_or_else(|| quote! { 'static }); let mut nums = vec![]; let mut names = vec![]; let mut vars = vec![]; - let mut default = None; - for (_, var) in def.fields.by_number.iter() { - let var_name = format_ident!("{}", var.name.as_str().to_case(Case::Pascal)); - let typ = self.base_type(&var.typ)?; - let num = var.num as u32; - let name = var.name.as_str(); - let kind = Self::type_marker(&var.typ); + for field in all_fields { + if field.oneof_index != Some(oneof_idx as i32) { + continue; + } + // Skip proto3 optional synthetic oneofs + if field.proto3_optional == Some(true) { + continue; + } + + let field_name = field.name.as_deref().unwrap_or(""); + let var_name = format_ident!("{}", field_name.to_case(Case::Pascal)); + let typ = self.base_type(field)?; + let num = field.number.unwrap_or(0) as u32; + let kind = self.type_marker(field); vars.push(quote! { - #[field(#num, #name, #kind, raw)] + #[field(#num, #field_name, #kind, raw)] #var_name(#typ), }); @@ -386,16 +588,15 @@ impl CodeGenerator<'_> { Self::#var_name(Default::default()) } } - }) + }); } nums.push(num); - names.push(name); + names.push(field_name); } self.output.push(quote! { #[derive(Debug, Clone, PartialEq, Proto)] - // #attrs pub enum #oneof_type #generics { #(#vars)* __Unused(::core::marker::PhantomData<& #borrow_or_static ()>), @@ -409,174 +610,110 @@ impl CodeGenerator<'_> { }) } - pub fn r#enum(&mut self, file: &FileDef, def: &EnumDef) -> Result<()> { - let ident = format_ident!("{}", rustify_name(def.name.as_str())); - let open = if file.syntax == Proto3 { - quote! { open } - } else { - quote! { closed } - }; - let variants = def.variants.by_name.iter().map(|(_, def)| { - let name = def.name.as_str(); - let var_ident = format_ident!("{}", def.name.as_str()); - let num = def.num; - quote! { - #[var(#num, #name)] - pub const #var_ident: #ident = #ident(#num); - } - }); - self.output.push(quote! { - #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] - pub struct #ident(pub i32); - #[protoenum(#open)] - impl #ident { - #(#variants)* - } - }); - + fn generate_service(&mut self, svc: &ServiceDescriptorProto) -> Result<()> { + self.output.push(self.generate_server(svc)); + self.output.push(self.generate_client(svc)); Ok(()) } - pub fn field(&self, file: &FileDef, def: &FieldDef, extpkg: Option<&str>) -> Result { - let typ = self.field_type(def)?; - let fname = format_ident!("{}", rustify_name(def.name.as_str())); - let name = match extpkg { - Some(pkg) => format!("[{}.{}]", pkg, def.name), - None => def.name.to_string(), - }; - - if let DataType::Enum(id) = def.typ { - if let Some((en, efile)) = self.context.enum_by_id(id) { - if file.syntax == Proto3 && efile.syntax == Proto2 { - panic!("Can't use proto2 enum ({}) in proto3 file", en.name); - } - } - } - - let num = def.num as u32; - - let force_box = match def.typ { - DataType::Message(m) | DataType::Group(m) => self.options.force_box.contains(&m), - _ => false, - }; - - let kind = Self::type_marker(&def.typ); - let freq = TokenStream::from_str(match def.frequency { - Frequency::Singular if def.typ.is_message() => "optional", - Frequency::Singular => "singular", - Frequency::Optional => "optional", - Frequency::Repeated if def.is_packed() => "packed", - #[cfg(feature = "descriptors")] - Frequency::Repeated - if file.syntax == Proto3 && def.typ.is_scalar() && def.options.packed != Some(false) => - { - "packed" - } - #[cfg(not(feature = "descriptors"))] - Frequency::Repeated if file.syntax == Proto3 && def.typ.is_scalar() => "packed", - Frequency::Repeated => "repeated", - Frequency::Required if force_box => "optional", - Frequency::Required => "required", - }) - .unwrap(); + fn generate_server(&self, svc: &ServiceDescriptorProto) -> TokenStream { + grpc::generate_server(self.pool, self.options, svc) + } - Ok(quote! { - #[field(#num, #name, #kind, #freq)] - pub #fname: #typ - }) + fn generate_client(&self, svc: &ServiceDescriptorProto) -> TokenStream { + grpc::generate_client(self.pool, self.options, svc) } } -pub fn generate_file(ctx: &FileSetDef, opts: &Options, name: PathBuf, file: &FileDef) -> Result<()> { - let mut generator = CodeGenerator { - context: ctx, - options: opts, - types: vec![], - output: vec![], - }; - - generator.file(file)?; - +pub fn generate_file( + pool: &DescriptorPool, + opts: &Options, + path: impl AsRef, + file_idx: usize, +) -> Result<()> { + let file = &pool.descriptor_set.file[file_idx]; + let mut generator = CodeGenerator::new(pool, file_idx, opts); + generator.generate()?; + + // Generate extension registrations let mut extensions = vec![]; - for ext_def in file.extensions.values() { - let extendee = &ext_def.in_message; - for field in ext_def.fields.by_number.values() { - let name = &field.name; - let number = field.num as u32; - let typ = match field.typ { - DataType::Builtin(t) => t as u32, - DataType::Enum(_) => 0, // ENUM - DataType::Message(_) => 11, // MESSAGE - DataType::Group(_) => 10, // GROUP - _ => continue, - }; - let repeated = field.frequency == Frequency::Repeated; - let extendee = extendee.as_str(); - let name = name.as_str(); + for ext in &file.extension { + let extendee = ext.extendee.as_deref().unwrap_or(""); + let name = ext.name.as_deref().unwrap_or(""); + let number = ext.number.unwrap_or(0) as u32; + let typ = ext.r#type.map(|t| t.0 as u32).unwrap_or(0); + let repeated = ext.label == Some(FieldDescriptorProtoLabel::LABEL_REPEATED); + + let type_name_str = ext.type_name.as_deref().map(|t| rust_type_name(pool, t)).unwrap_or_default(); + let type_name = type_name_str.as_str(); + + extensions.push(quote! { + registry.register_extension(#extendee, #number, #name, #typ, #repeated, #type_name); + }); + } - let type_name_str = match field.typ { - DataType::Message(id) | DataType::Group(id) | DataType::Enum(id) => resolve_name(ctx, id).unwrap(), - _ => "".to_string(), - }; - let type_name = type_name_str.as_str(); + // Generate imports + let imports = file.dependency.iter().map(|dep_name| { + // Find the file in the descriptor set + let dep_file = pool.descriptor_set.file.iter().find(|f| f.name.as_deref() == Some(dep_name)); - extensions.push(quote! { - registry.register_extension(#extendee, #number, #name, #typ, #repeated, #type_name); - }); - } - } + if let Some(dep_file) = dep_file { + let dep_package = dep_file.package.as_deref().unwrap_or(""); + let our_package = file.package.as_deref().unwrap_or(""); - let root = opts.import_root.clone(); - let imports = file.imports.iter().map(|imp| imp.file_idx).map(|file_idx| { - let (_, other): (_, &FileDef) = ctx.files.get_index(file_idx).unwrap(); - let _our_name = name.file_name().unwrap().to_str().unwrap(); + let their_name = if dep_name.contains('/') { + &dep_name[dep_name.rfind('/').unwrap() + 1..] + } else { + dep_name.as_str() + }; + let their_name = if their_name.contains('.') { + &their_name[..their_name.rfind('.').unwrap()] + } else { + their_name + }; - // if let Some(rep) = ctx.replacement.get(other.name.as_str()) { - // let rep = TokenStream::from_str(rep).unwrap(); - // return quote! { use #rep::*; }; - // } + // Calculate relative path + let mut our_module = our_package; + let mut that_module = dep_package; - let their_name = if other.name.contains('/') { - &other.name.as_str()[other.name.rfind('/').unwrap() + 1..] - } else { - other.name.as_str() - }; - let their_name = if their_name.contains('.') { - &their_name[..their_name.rfind('.').unwrap()] - } else { - their_name - }; - let mut our_module = file.package.as_str(); - let mut that_module = other.package.as_str(); + while !our_module.is_empty() && !that_module.is_empty() { + if our_module.chars().next() == that_module.chars().next() { + our_module = &our_module[1..]; + that_module = &that_module[1..]; + } else { + break; + } + } - while !our_module.is_empty() && !that_module.is_empty() && our_module[..1] == that_module[..1] { - our_module = &our_module[1..]; - that_module = &that_module[1..]; - } - let mut path = String::new(); - path.push_str("super::"); + let mut path = String::new(); + path.push_str("super::"); - if !our_module.is_empty() { - for _s in our_module.strip_prefix('.').unwrap_or(our_module).split('.') { - path.push_str("super::"); + if !our_module.is_empty() { + for _s in our_module.strip_prefix('.').unwrap_or(our_module).split('.') { + path.push_str("super::"); + } } - } - if !that_module.is_empty() { - for s in that_module.split('.') { - path.push_str(&rustify_name(s)); - path.push_str("::") + if !that_module.is_empty() { + for s in that_module.split('.') { + path.push_str(&rustify_name(s)); + path.push_str("::"); + } } - } - path.push_str(&rustify_name(their_name)); + path.push_str(&rustify_name(their_name)); - let import = TokenStream::from_str(&path).unwrap(); - quote! { - use #import::*; + let import: TokenStream = core::str::FromStr::from_str(&path).unwrap(); + quote! { + use #import::*; + } + } else { + quote! {} } }); + let root = opts.import_root.clone(); let output = generator.output; let types = generator.types; + let maproot = if let Some(ref root) = root { quote! { use #root as protokit; } } else { @@ -600,36 +737,32 @@ pub fn generate_file(ctx: &FileSetDef, opts: &Options, name: PathBuf, file: &Fil #(#imports)* #(#output)* }; - let output = syn::parse2(output.clone()).with_context(|| output.to_string()).unwrap(); + + let output = syn::parse2(output.clone()) + .with_context(|| output.to_string()) + .unwrap(); let output = prettyplease::unparse(&output); - println!("Creating file: {name:?}"); - create_dir_all(name.parent().unwrap()).unwrap(); + let path = path.as_ref(); + create_dir_all(path.parent().unwrap())?; let mut f = File::options() .write(true) .create(true) .truncate(true) - .open(name) - .unwrap(); + .open(path)?; f.write_all(output.as_bytes())?; f.flush()?; + Ok(()) } -// #[cfg(feature = "descriptors")] -// pub fn generate_descriptor(ctx: &TranslateCtx, name: impl AsRef) { -// let mut output = vec![]; -// ctx.def.to_descriptor().encode(&mut output).unwrap(); -// -// let mut f = make_file(name); -// -// f.write_all(&output).unwrap(); -// f.flush().unwrap(); -// } - -pub fn generate_mod<'s>(path: impl AsRef, opts: &Options, files: impl Iterator) -> Result<()> { +pub fn generate_mod<'s>( + path: impl AsRef, + opts: &Options, + files: impl Iterator, +) -> Result<()> { let root = opts.import_root.clone(); let files: Vec<_> = files .map(|v| { @@ -655,9 +788,14 @@ pub fn generate_mod<'s>(path: impl AsRef, opts: &Options, files: impl Iter } }; - create_dir_all(path.as_ref().parent().unwrap())?; + create_dir_all(path.as_ref())?; - let mut f = make_file(path.as_ref().join("mod.rs"))?; + let mod_path = path.as_ref().join("mod.rs"); + let mut f = File::options() + .write(true) + .create(true) + .truncate(true) + .open(&mod_path)?; let output = syn::parse2(output)?; let output = prettyplease::unparse(&output); @@ -666,11 +804,3 @@ pub fn generate_mod<'s>(path: impl AsRef, opts: &Options, files: impl Iter Ok(()) } - -pub fn make_file(path: impl AsRef) -> Result { - let path = path.as_ref(); - - let f = File::options().write(true).create(true).truncate(true).open(path)?; - - Ok(f) -} diff --git a/protokit_build/src/filegen/tabular.rs b/protokit_build/src/filegen/tabular.rs deleted file mode 100644 index 2562aaa..0000000 --- a/protokit_build/src/filegen/tabular.rs +++ /dev/null @@ -1,102 +0,0 @@ -use core::str::FromStr; -use anyhow::bail; -use quote::__private::TokenStream; -use quote::quote; -use protokit_desc::{DataType, FieldDef, Frequency}; -use protokit_proto::translate::TranslateCtx; -use crate::filegen::Options; - -pub fn tabular_parser(ctx: &TranslateCtx, opts: &Options, f: &FieldDef, i: usize) -> Result { - use crate::BuiltinType::*; - let str = TokenStream::from_str(&type_to_str(ctx, opts, &f.typ)?).unwrap(); - Ok(match &f.typ { - DataType::Builtin(String_ | Bytes_) if f.is_repeated() => quote!( Bytes::, #i> ), - DataType::Builtin(String_ | Bytes_) => quote!( Bytes::<#str, #i> ), - - DataType::Builtin(bt) if bt.is_varint() && bt.is_zigzag() && f.is_repeated_packed() => { - quote! { PackedSInt::<#str, #i> } - } - DataType::Builtin(bt) if bt.is_varint() && bt.is_zigzag() => { - quote! { SInt::<#str, #i> } - } - DataType::Builtin(bt) if bt.is_varint() && f.is_repeated_packed() => { - quote! { PackedVInt::<#str, #i> } - } - DataType::Builtin(bt) if bt.is_varint() => { - quote! { VInt::<#str, #i> } - } - DataType::Builtin(_bt) if f.is_repeated_packed() => { - quote! { PackedFixed::<#str, #i> } - } - DataType::Builtin(_bt) => { - quote! { Fixed::<#str, #i> } - } - DataType::Message(_m) if f.is_repeated() => { - quote! { Repeated::<#str, #i> } - } - DataType::Message(_m) => { - quote! { Nested::<#str, #i> } - } - DataType::Enum(_m) if f.is_repeated() => { - quote! { Enum::<#str, #i> } - } - DataType::Enum(_m) => { - quote! { Enum::<#str, #i> } - } - DataType::Map(m) => { - let kf = FieldDef { - name: Default::default(), - frequency: Frequency::Normal, - typ: DataType::Builtin(m.0.clone()), - num: 0, - #[cfg(feature = "descriptors")] - options: Default::default(), - }; - let kp = tabular_parser(ctx, opts, &kf, 1)?; - let _vf = FieldDef { - name: Default::default(), - frequency: Frequency::Normal, - typ: m.1.clone(), - num: 0, - #[cfg(feature = "descriptors")] - options: Default::default(), - }; - let vp = tabular_parser(ctx, opts, &kf, 1)?; - quote! { Mapped:: } - } - other => bail!("Unknown: {other:?}"), - }) -} - -// let tabular_fields = tabular_fields.values(); -// let tabular_format = if cfg!(feature = "tabular") { -// quote! { -// impl binformat::tabular::TableDecodable for #msg_name { -// fn info(&self) -> binformat::tabular::MessageInfo { -// const _INFO: &[binformat::tabular::FieldInfo] = &[ -// #(#tabular_fields,)* -// ]; -// binformat::tabular::MessageInfo::sorted(_INFO, #qualified_name) -// } -// } -// } -// } else { -// quote! {} -// }; - -pub fn gnerate_tabular_field() { - let taglen = u64::required_space(normal_tag as u64); - // TODO: Use tag len here, that's the intended usecase - let tabular = tabular_parser(ctx, opts, &field, taglen).unwrap(); - let encoded_tag = protokit_binformat::tabular::tag(normal_tag); - let field_num = field_idx as u32; - out.tabular_fields.insert(encoded_tag,quote! { - binformat::tabular::FieldInfo { - tag: #encoded_tag, - offset: binformat::tabular::offset_of!(#msg_name, #name) as isize, - parser: ::decode, - number: #field_num, - } - }); - -} \ No newline at end of file diff --git a/protokit_build/src/lib.rs b/protokit_build/src/lib.rs index c816eb6..4b64d55 100644 --- a/protokit_build/src/lib.rs +++ b/protokit_build/src/lib.rs @@ -1,69 +1,375 @@ -use core::str::FromStr; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; -use std::usize; +use std::path::{Path, PathBuf}; pub use anyhow::Result; +use anyhow::bail; use petgraph::graph::{DefaultIx, NodeIndex}; use petgraph::{Direction, Graph}; -use quote::__private::TokenStream; use quote::quote; -use crate::deps::*; - -mod deps; mod filegen; -#[cfg(all(not(feature = "protoc"), not(feature = "parser")))] -compile_error!("Either enable 'protoc' (to use system protoc) or 'parser' (to use builtin parser) feature"); - -#[cfg(all(feature = "protoc", feature = "parser"))] -compile_error!("Either disable 'protoc' or 'parser' feature"); +pub use desc::generated::google::protobuf::descriptor::*; + +/// A descriptor pool that provides lookup capabilities over a FileDescriptorSet. +/// This is a thin overlay that indexes the descriptors by their fully qualified names. +#[derive(Debug, Default)] +pub struct DescriptorPool { + /// The underlying FileDescriptorSet from protoc + pub descriptor_set: FileDescriptorSet, + /// Index from fully qualified type name to (file_index, type_path) + /// e.g. ".google.protobuf.Any" -> (file_idx, ["Any"]) + pub types: HashMap, + /// Set of types that need to be boxed due to circular references + pub boxed_types: HashSet, +} -const REMAPS: &[(&str, &str)] = &[ - // ("google/protobuf/any.proto", "root::types::any"), - // ("google/protobuf/empty.proto", "root::types::empty"), - // ("google/protobuf/timestamp.proto", "root::types::timestamp"), - // ("google/protobuf/field_mask.proto", "root::types::field_mask"), - // ("google/protobuf/descriptor.proto", "root::types::descriptor"), -]; +/// Location of a type within the descriptor set +#[derive(Debug, Clone)] +pub struct TypeLocation { + pub file_idx: usize, + pub path: Vec, + pub kind: TypeKind, +} -#[cfg(feature = "parser")] -#[derive(Default, Debug)] -pub struct ParserContext { - ctx: proto::translate::TranslateCtx, +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TypeKind { + Message, + Enum, + Service, } -#[cfg(feature = "parser")] -impl ParserContext { - pub fn include(&mut self, p: impl Into) { - self.ctx.include(p) +impl DescriptorPool { + pub fn new(descriptor_set: FileDescriptorSet) -> Self { + let mut pool = Self { + descriptor_set, + types: HashMap::new(), + boxed_types: HashSet::new(), + }; + pool.build_index(); + pool.detect_cycles(); + pool } - pub fn compile(&mut self, p: impl Into) -> Result<()> { - self.ctx.compile_file(p.into())?; - Ok(()) + + fn build_index(&mut self) { + // Collect indexing work first to avoid borrow issues + struct IndexWork { + file_idx: usize, + path: Vec, + fqn: String, + kind: TypeKind, + } + + fn collect_message_indices( + work: &mut Vec, + file_idx: usize, + package: &str, + parent_path: &[String], + msg: &DescriptorProto, + ) { + let name = msg.name.as_deref().unwrap_or(""); + let mut path = parent_path.to_vec(); + path.push(name.to_string()); + + let fqn = if parent_path.is_empty() { + format!(".{}.{}", package, name) + } else { + format!(".{}.{}.{}", package, parent_path.join("."), name) + }; + + work.push(IndexWork { + file_idx, + path: path.clone(), + fqn: fqn.clone(), + kind: TypeKind::Message, + }); + + // Index nested enums + for enum_desc in &msg.enum_type { + let enum_name = enum_desc.name.as_deref().unwrap_or(""); + let enum_fqn = format!("{}.{}", fqn, enum_name); + let mut enum_path = path.clone(); + enum_path.push(enum_name.to_string()); + work.push(IndexWork { + file_idx, + path: enum_path, + fqn: enum_fqn, + kind: TypeKind::Enum, + }); + } + + // Index nested messages + for nested_msg in &msg.nested_type { + collect_message_indices(work, file_idx, package, &path, nested_msg); + } + } + + let mut work = Vec::new(); + + for (file_idx, file) in self.descriptor_set.file.iter().enumerate() { + let package = file.package.as_deref().unwrap_or(""); + + // Index top-level enums + for enum_desc in &file.enum_type { + let name = enum_desc.name.as_deref().unwrap_or(""); + let fqn = format!(".{}.{}", package, name); + work.push(IndexWork { + file_idx, + path: vec![name.to_string()], + fqn, + kind: TypeKind::Enum, + }); + } + + // Index top-level messages (and nested types) + for msg_desc in &file.message_type { + collect_message_indices(&mut work, file_idx, package, &[], msg_desc); + } + + // Index services + for svc_desc in &file.service { + let name = svc_desc.name.as_deref().unwrap_or(""); + let fqn = format!(".{}.{}", package, name); + work.push(IndexWork { + file_idx, + path: vec![name.to_string()], + fqn, + kind: TypeKind::Service, + }); + } + } + + // Now apply all the indices + for item in work { + self.types.insert( + item.fqn, + TypeLocation { + file_idx: item.file_idx, + path: item.path, + kind: item.kind, + }, + ); + } + } + + fn detect_cycles(&mut self) { + // Build a dependency graph + let mut graph: Graph = Graph::new(); + let mut node_indices: HashMap = HashMap::new(); + let mut field_counts: HashMap = HashMap::new(); + + // Add nodes for all messages + for (fqn, loc) in &self.types { + if loc.kind == TypeKind::Message { + let idx = graph.add_node(fqn.clone()); + node_indices.insert(fqn.clone(), idx); + } + } + + // Add edges based on field references + for file in &self.descriptor_set.file { + let package = file.package.as_deref().unwrap_or(""); + for msg in &file.message_type { + self.add_message_edges(&mut graph, &node_indices, &mut field_counts, package, &[], msg); + } + } + + // Find SCCs using Tarjan's algorithm + let mut cycles: Vec>> = petgraph::algo::tarjan_scc(&graph) + .into_iter() + .filter(|scc| scc.len() > 1 || { + // Check for self-loops + scc.len() == 1 && graph.find_edge(scc[0], scc[0]).is_some() + }) + .map(|v| HashSet::from_iter(v.into_iter())) + .collect(); + + // Find nodes to box to break cycles (prefer nodes with more incoming edges) + let mut to_box = HashSet::new(); + loop { + let mut counts: HashMap = HashMap::new(); + for cycle in &cycles { + for node in cycle.iter() { + *counts.entry(*node).or_default() += 1; + } + } + + if let Some(max) = counts + .iter() + .filter_map(|a| field_counts.get(a.0).map(|fcount| (a.0, a.1, fcount))) + .max_by(|a, b| { + let ac = graph.edges_directed(*a.0, Direction::Incoming).count(); + let bc = graph.edges_directed(*b.0, Direction::Incoming).count(); + (ac * 100 + a.1 * 10 + a.2).cmp(&(bc * 100 + b.1 * 10 + b.2)) + }) + { + to_box.insert(*max.0); + cycles.retain_mut(|cycle| !cycle.contains(&max.0)); + } else { + break; + } + } + + // Convert node indices back to type names + for idx in to_box { + if let Some(name) = graph.node_weight(idx) { + self.boxed_types.insert(name.clone()); + } + } + } + + fn add_message_edges( + &self, + graph: &mut Graph, + node_indices: &HashMap, + field_counts: &mut HashMap, + package: &str, + parent_path: &[String], + msg: &DescriptorProto, + ) { + let name = msg.name.as_deref().unwrap_or(""); + let src_fqn = if parent_path.is_empty() { + format!(".{}.{}", package, name) + } else { + format!(".{}.{}.{}", package, parent_path.join("."), name) + }; + + let Some(&src_idx) = node_indices.get(&src_fqn) else { + return; + }; + + // Add edges for fields + for field in &msg.field { + if let Some(type_name) = &field.type_name { + if let Some(&dst_idx) = node_indices.get(type_name) { + graph.add_edge(src_idx, dst_idx, ()); + *field_counts.entry(src_idx).or_default() += 1; + } + } + } + + // Process nested messages + let mut path = parent_path.to_vec(); + path.push(name.to_string()); + for nested_msg in &msg.nested_type { + self.add_message_edges(graph, node_indices, field_counts, package, &path, nested_msg); + } + } + + pub fn file(&self, idx: usize) -> &FileDescriptorProto { + &self.descriptor_set.file[idx] + } + + pub fn is_boxed(&self, type_name: &str) -> bool { + self.boxed_types.contains(type_name) + } + + pub fn lookup(&self, type_name: &str) -> Option<&TypeLocation> { + self.types.get(type_name) + } + + /// Get a message descriptor by its fully qualified name + pub fn get_message(&self, fqn: &str) -> Option<&DescriptorProto> { + let loc = self.types.get(fqn)?; + if loc.kind != TypeKind::Message { + return None; + } + let file = &self.descriptor_set.file[loc.file_idx]; + self.find_message_in_file(file, &loc.path) } - pub fn finish(self) -> Result { - Ok(self.ctx.def) + + fn find_message_in_file<'a>( + &self, + file: &'a FileDescriptorProto, + path: &[String], + ) -> Option<&'a DescriptorProto> { + if path.is_empty() { + return None; + } + + let mut current: Option<&DescriptorProto> = None; + let first = &path[0]; + + for msg in &file.message_type { + if msg.name.as_deref() == Some(first) { + current = Some(msg); + break; + } + } + + for name in &path[1..] { + let msg = current?; + current = None; + for nested in &msg.nested_type { + if nested.name.as_deref() == Some(name) { + current = Some(nested); + break; + } + } + } + + current + } + + /// Get an enum descriptor by its fully qualified name + pub fn get_enum(&self, fqn: &str) -> Option<&EnumDescriptorProto> { + let loc = self.types.get(fqn)?; + if loc.kind != TypeKind::Enum { + return None; + } + let file = &self.descriptor_set.file[loc.file_idx]; + self.find_enum_in_file(file, &loc.path) + } + + fn find_enum_in_file<'a>( + &self, + file: &'a FileDescriptorProto, + path: &[String], + ) -> Option<&'a EnumDescriptorProto> { + if path.is_empty() { + return None; + } + + if path.len() == 1 { + // Top-level enum + for e in &file.enum_type { + if e.name.as_deref() == Some(&path[0]) { + return Some(e); + } + } + return None; + } + + // Nested enum - find parent message first + let msg = self.find_message_in_file(file, &path[..path.len() - 1])?; + let enum_name = &path[path.len() - 1]; + for e in &msg.enum_type { + if e.name.as_deref() == Some(enum_name) { + return Some(e); + } + } + None } } -#[cfg(feature = "protoc")] +/// Build context for generating Rust code from protobuf files #[derive(Default, Debug)] pub struct ProtocContext { pub includes: Vec, pub proto_paths: Vec, } -#[cfg(feature = "protoc")] impl ProtocContext { pub fn include(&mut self, p: impl Into) { self.includes.push(p.into()); } + pub fn compile(&mut self, p: impl Into) -> Result<()> { self.proto_paths.push(p.into()); Ok(()) } - pub fn finish(self) -> Result { + + pub fn finish(self) -> Result { let mut cmd = std::process::Command::new("protoc"); cmd.arg("--experimental_allow_proto3_optional"); @@ -76,115 +382,45 @@ impl ProtocContext { cmd.arg(format!("{}", p.display())); } - cmd.arg(format!("-o{}/descriptor.bin", std::env::var("OUT_DIR").unwrap())); + let out_dir = std::env::var("OUT_DIR").unwrap(); + cmd.arg(format!("-o{}/descriptor.bin", out_dir)); + let out = cmd.output().expect("PROTOC invocation failed"); if !out.status.success() { bail!("Protoc error: {}", String::from_utf8_lossy(&out.stderr)) } - let data = std::fs::read(Path::new(&std::env::var("OUT_DIR").unwrap()).join("descriptor.bin")).unwrap(); - let desc = binformat::decode::(data.as_slice())?; + let data = std::fs::read(Path::new(&out_dir).join("descriptor.bin"))?; + let desc = binformat::decode::(data.as_slice())?; - Ok(FileSetDef::from_descriptor(desc)) + Ok(DescriptorPool::new(desc)) } } #[must_use] #[derive(Default, Debug)] pub struct Build { - #[cfg(feature = "parser")] - pub ctx: ParserContext, - #[cfg(feature = "protoc")] pub ctx: ProtocContext, pub options: filegen::Options, pub out_dir: Option, } -fn generate(opts: &mut filegen::Options, set: &desc::FileSetDef, out_dir: PathBuf) -> Result<()> { - create_dir_all(&out_dir).unwrap(); - - let mut graph = Graph::new(); - let mut fields = HashMap::::new(); - for (fidx, file) in set.files.values().enumerate() { - for (midx, msg) in file.messages.values().enumerate() { - eprintln!("msg: {:?}", msg.name); - let src = local_to_global(fidx, LOCAL_DEFID_MSG | (midx as u32)); - // let src = resolve_name(set, src).unwrap(); - let src = graph.add_node(src); - for field in msg.fields.by_number.values() { - match &field.typ { - DataType::Unresolved(_, _) => { - panic!("Should be resolved"); - } - DataType::Message(msg) => { - // let msg = resolve_name(set, *msg).unwrap(); - let dst = graph.add_node(*msg); - graph.add_edge(src, dst, ()); - *fields.entry(src).or_default() += 1; - } - DataType::Group(msg) => { - // let msg = resolve_name(set, *msg).unwrap(); - let dst = graph.add_node(*msg); - graph.add_edge(src, dst, ()); - *fields.entry(src).or_default() += 1; - } - _ => {} - } - } - } - } - - let mut cycles: Vec>> = petgraph::algo::tarjan_scc(&graph) - .into_iter() - .map(|v| HashSet::from_iter(v.into_iter())) - .collect(); +fn generate(opts: &filegen::Options, pool: &DescriptorPool, out_dir: PathBuf) -> Result<()> { + std::fs::create_dir_all(&out_dir)?; - let mut to_remove = HashSet::new(); - - loop { - let mut counts: HashMap = HashMap::new(); - for cycle in &cycles { - for node in cycle.iter() { - *counts.entry(*node).or_default() += 1; - } - } - if let Some(max) = counts - .iter() - .filter_map(|a| fields.get(a.0).map(|fcount| (a.0, a.1, fcount))) - .max_by(|a, b| { - let ac = graph.edges_directed(*a.0, Direction::Incoming).count(); - let bc = graph.edges_directed(*b.0, Direction::Incoming).count(); - (ac * 100 + a.1 * 10 + a.2).cmp(&(bc * 100 + b.1 * 10 + b.2)) - }) - { - to_remove.insert(*max.0); - cycles.retain_mut(|cycle| !cycle.contains(&max.0)) - } else { - break; - } - } - - let nodes = to_remove - .into_iter() - .map(|item| graph.node_weight(item).cloned().unwrap()) - .collect::>(); - - // panic!("TO REMOVE: {:?}", nodes); - - opts.force_box = nodes; - - // TODO: Use package name + file name let mut generated_names = vec![]; - for (_, file) in set.files.values().enumerate() { - // if self.ctx.replacement.contains_key(file.name.as_str()) { - // continue; - // } - let path = Path::new(file.name.as_str()); - let file_name = - file.package.replace('.', "/") + "/" + path.with_extension("rs").file_name().unwrap().to_str().unwrap(); - let out_name = out_dir.join(&file_name); - generated_names.push(file_name.clone()); - filegen::generate_file(set, opts, out_name, file).unwrap(); + for (file_idx, file) in pool.descriptor_set.file.iter().enumerate() { + let file_name = file.name.as_deref().unwrap_or("unknown.proto"); + let package = file.package.as_deref().unwrap_or(""); + + let path = Path::new(file_name); + let out_name = package.replace('.', "/") + + "/" + + path.with_extension("rs").file_name().unwrap().to_str().unwrap(); + let out_path = out_dir.join(&out_name); + + generated_names.push(out_name.clone()); + filegen::generate_file(pool, opts, out_path, file_idx)?; } let dirs: Vec> = generated_names.iter().map(|v| v.split('/').collect()).collect(); @@ -202,30 +438,17 @@ fn generate(opts: &mut filegen::Options, set: &desc::FileSetDef, out_dir: PathBu } for (k, v) in &subdirs { - eprintln!("Creating module in: {:?}", out_dir.join(k)); filegen::generate_mod(out_dir.join(k), opts, v.iter().copied())?; } - // #[cfg(feature = "descriptors")] - // filegen::generate_descriptor(&self.ctx, out_dir.join("descriptor.bin")); Ok(()) } impl Build { pub fn new() -> Self { - let mut this = Self::without_replacements(); - for (from, to) in REMAPS { - this.options.replace_import(from, to); - } - this + Self::default() } - pub fn without_replacements() -> Self { - Self { - ctx: Default::default(), - ..Default::default() - } - } pub fn include(mut self, p: impl Into) -> Self { self.ctx.include(p); self @@ -249,18 +472,22 @@ impl Build { self.options.track_unknowns = t; self } + pub fn root(mut self, s: &str) -> Self { - self.options.import_root = Some(TokenStream::from_str(s).unwrap()); + self.options.import_root = Some(core::str::FromStr::from_str(s).unwrap()); self } + pub fn string_type(mut self, typ: &str) -> Self { - self.options.string_type = TokenStream::from_str(typ).unwrap(); + self.options.string_type = core::str::FromStr::from_str(typ).unwrap(); self } + pub fn bytes_type(mut self, typ: &str) -> Self { - self.options.bytes_type = TokenStream::from_str(typ).unwrap(); + self.options.bytes_type = core::str::FromStr::from_str(typ).unwrap(); self } + pub fn out_dir(mut self, p: impl Into) -> Self { self.out_dir = Some(p.into()); self @@ -272,10 +499,11 @@ impl Build { Ok(self) } - pub fn generate(mut self) -> anyhow::Result<()> { + pub fn generate(self) -> anyhow::Result<()> { let out_dir = self .out_dir .unwrap_or_else(|| PathBuf::from(std::env::var("OUT_DIR").unwrap())); - generate(&mut self.options, &self.ctx.finish()?, out_dir) + let pool = self.ctx.finish()?; + generate(&self.options, &pool, out_dir) } } diff --git a/tools/conformance/Cargo.toml b/tools/conformance/Cargo.toml index e2f8a97..6b0243d 100644 --- a/tools/conformance/Cargo.toml +++ b/tools/conformance/Cargo.toml @@ -19,4 +19,4 @@ byteorder = "1.4.3" protokit = { path = "../../protokit", features = ["textformat"] } [build-dependencies] -build = { workspace = true, features = ["protoc"] } +build = { workspace = true } diff --git a/tools/gendesc/Cargo.toml b/tools/gendesc/Cargo.toml index 8d7990a..d7a9f3c 100644 --- a/tools/gendesc/Cargo.toml +++ b/tools/gendesc/Cargo.toml @@ -11,5 +11,3 @@ publish = false [dependencies.protokit_build] path = "../../protokit_build" -default-features = false -features = ["protoc"] From 689b27358f21bd68ba0e045924ea426e7e7a4c3a Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 7 Dec 2025 09:07:24 +0000 Subject: [PATCH 2/2] fix: Fix textformat parsing for repeated fields and remove repro tests Fixed issues in merge_repeated: - Allow implicit whitespace separators in lists (e.g., [1 2 3]) - Allow semicolons as separators in lists - Check if comma is followed by field name before continuing to parse values This fixes comma-separated field syntax like "vals: 1, vals: 2" Removed repro test files that were only used for debugging: - repro_ext.rs - repro_list.rs - repro_p2.rs - repro_sep.rs --- protokit_textformat/src/lib.rs | 36 ++++--- protokit_textformat/src/repro_ext.rs | 46 -------- protokit_textformat/src/repro_list.rs | 140 ------------------------ protokit_textformat/src/repro_p2.rs | 149 -------------------------- protokit_textformat/src/repro_sep.rs | 103 ------------------ 5 files changed, 19 insertions(+), 455 deletions(-) delete mode 100644 protokit_textformat/src/repro_ext.rs delete mode 100644 protokit_textformat/src/repro_list.rs delete mode 100644 protokit_textformat/src/repro_p2.rs delete mode 100644 protokit_textformat/src/repro_sep.rs diff --git a/protokit_textformat/src/lib.rs b/protokit_textformat/src/lib.rs index d04ba9f..f15f8f0 100644 --- a/protokit_textformat/src/lib.rs +++ b/protokit_textformat/src/lib.rs @@ -10,14 +10,6 @@ use thiserror::Error; mod escape; mod lex; pub mod reflect; -#[cfg(test)] -mod repro_ext; -#[cfg(test)] -mod repro_list; -#[cfg(test)] -mod repro_p2; -#[cfg(test)] -mod repro_sep; pub mod stream; use escape::unescape; @@ -623,19 +615,31 @@ pub fn merge_repeated<'buf, T: TextField<'buf> + Default>( out.last_mut().unwrap().merge_value(stream)?; match stream.cur { // End of the list - RBracket | EndOfFile if is_list => { - // In this case we must advance one past the rbracket + RBracket if is_list => { stream.advance(); return Ok(()); } - // Comma/Semi as elem separator + EndOfFile => { + if is_list { + stream.advance(); + } + return Ok(()); + } + // Comma as elem separator Comma => { + // Check if after comma is a new field (not a value) + if stream.lookahead_is_field() { + // Don't consume comma, let caller handle field separator + return Ok(()); + } stream.advance(); continue; } Semi => { if is_list { - return crate::unexpected(Comma, stream.cur, stream.buf()); + // Semicolon works as separator inside list too + stream.advance(); + continue; } else { // For non-list (top-level repeated), check if next is field or value if stream.lookahead_is_field() { @@ -646,15 +650,13 @@ pub fn merge_repeated<'buf, T: TextField<'buf> + Default>( } } } - // This was the last entry in this field, return + // For non-list, any other token means end of this repeated field _ if !is_list => { return Ok(()); } - // Implicit separator + // Implicit separator in list (whitespace between values) _ => { - if is_list { - return crate::unexpected(Comma, stream.cur, stream.buf()); - } + // In a list, allow implicit separators (whitespace) between values continue; } } diff --git a/protokit_textformat/src/repro_ext.rs b/protokit_textformat/src/repro_ext.rs deleted file mode 100644 index fd64360..0000000 --- a/protokit_textformat/src/repro_ext.rs +++ /dev/null @@ -1,46 +0,0 @@ -#[cfg(test)] -mod tests { - use crate::reflect::Registry; - use crate::stream::InputStream; - use crate::Token; - - #[test] - fn test_extension_key_parsing() { - let reg = Registry::default(); - let txt = "[some.ext.field]: 123"; - let mut stream = InputStream::new(txt, ®); - - // InputStream::new doesn't advance automatically to first token? - // Let's check constructor. It calls Lexer::new, sets cur=StartOfFile. - // next_token() or advance() needed? - // merge_field loop usually calls advance. - - // Manually drive pars_key - stream.advance(); // Cur should be LBracket - assert_eq!(stream.token(), Token::LBracket); - - // But parse_key expects to start AT the key tokens? - // parse_key implementation: - // match self.cur { LBracket => ... } - // So yes, we need to be at LBracket. - - let key = stream.parse_key().expect("Should parse extension key"); - assert_eq!(key, "some.ext.field"); - - // Should be at Colon now - assert_eq!(stream.token(), Token::Colon); - } - - #[test] - fn test_simple_key_parsing() { - let reg = Registry::default(); - let txt = "simple_field: 123"; - let mut stream = InputStream::new(txt, ®); - stream.advance(); - assert_eq!(stream.token(), Token::Ident); - - let key = stream.parse_key().expect("Should parse simple key"); - assert_eq!(key, "simple_field"); - assert_eq!(stream.token(), Token::Colon); - } -} diff --git a/protokit_textformat/src/repro_list.rs b/protokit_textformat/src/repro_list.rs deleted file mode 100644 index 8226b8d..0000000 --- a/protokit_textformat/src/repro_list.rs +++ /dev/null @@ -1,140 +0,0 @@ -use crate::decode; - -#[derive(Debug, Default, PartialEq)] -struct TestStringList { - s: Vec, -} - -impl binformat::ProtoName for TestStringList { - fn qualified_name(&self) -> &'static str { - "TestStringList" - } -} - -impl<'buf> binformat::BinProto<'buf> for TestStringList { - fn merge_field(&mut self, _tag: u32, _stream: &mut binformat::InputStream<'buf>) -> binformat::Result<()> { - Ok(()) - } - fn size(&self, _stack: &mut binformat::SizeStack) -> usize { - 0 - } - fn encode(&self, _stream: &mut binformat::OutputStream) {} -} - -impl<'buf> crate::TextProto<'buf> for TestStringList { - fn decode(&mut self, stream: &mut crate::InputStream<'buf>) -> crate::Result<()> { - while stream.token() != crate::Token::EndOfFile { - self.merge_field(stream)?; - } - Ok(()) - } - - fn merge_field(&mut self, stream: &mut crate::InputStream<'buf>) -> crate::Result<()> { - if stream.token() == crate::Token::Ident && stream.buf() == "s" { - stream.advance(); - if stream.try_consume(crate::Token::Colon) { - // ok - } - crate::merge_repeated(&mut self.s, stream)?; - } else { - // unknown field or end - if stream.token() != crate::Token::EndOfFile { - stream.advance(); - } - } - Ok(()) - } - fn encode(&self, _stream: &mut crate::OutputStream) {} -} - -#[test] -fn test_string_list_space_separated() { - let input = "s: [\"a\" \"b\"]"; - let reg = crate::reflect::Registry::default(); - let msg: TestStringList = decode(input, ®).unwrap(); - assert_eq!(msg.s, vec!["ab".to_string()]); -} - -#[test] -fn test_string_concatenation() { - let input = "s: [\"a\" \"b\", \"c\"]"; - let reg = crate::reflect::Registry::default(); - // Expect: ["ab", "c"] - let msg: TestStringList = decode(input, ®).unwrap(); - assert_eq!(msg.s, vec!["ab".to_string(), "c".to_string()]); -} - -#[derive(Debug, Default, PartialEq)] -struct TestIntList { - i: Vec, -} - -impl binformat::ProtoName for TestIntList { - fn qualified_name(&self) -> &'static str { - "TestIntList" - } -} - -impl<'buf> binformat::BinProto<'buf> for TestIntList { - fn merge_field(&mut self, _tag: u32, _stream: &mut binformat::InputStream<'buf>) -> binformat::Result<()> { - Ok(()) - } - fn size(&self, _stack: &mut binformat::SizeStack) -> usize { - 0 - } - fn encode(&self, _stream: &mut binformat::OutputStream) {} -} - -impl<'buf> crate::TextProto<'buf> for TestIntList { - fn decode(&mut self, stream: &mut crate::InputStream<'buf>) -> crate::Result<()> { - while stream.token() != crate::Token::EndOfFile { - self.merge_field(stream)?; - } - Ok(()) - } - fn merge_field(&mut self, stream: &mut crate::InputStream<'buf>) -> crate::Result<()> { - if stream.token() == crate::Token::Ident && stream.buf() == "i" { - stream.advance(); - if stream.try_consume(crate::Token::Colon) {} - crate::merge_repeated(&mut self.i, stream)?; - } else { - if stream.token() != crate::Token::EndOfFile { - stream.advance(); - } - } - Ok(()) - } - fn encode(&self, _stream: &mut crate::OutputStream) {} -} - -#[test] -fn test_int_list_space_separated() { - let input = "i: [1 2]"; - let reg = crate::reflect::Registry::default(); - let msg: TestIntList = decode(input, ®).unwrap(); - assert_eq!(msg.i, vec![1, 2]); -} - -#[test] -fn test_int_list_comma() { - let input = "i: [1, 2]"; - let reg = crate::reflect::Registry::default(); - let msg: TestIntList = decode(input, ®).unwrap(); - assert_eq!(msg.i, vec![1, 2]); -} - -#[test] -fn test_int_list_comment() { - let input = "i: [1 /* comment */ 2]"; - let reg = crate::reflect::Registry::default(); - let msg: TestIntList = decode(input, ®).unwrap(); - assert_eq!(msg.i, vec![1, 2]); -} - -#[test] -fn test_int_list_mixed() { - let input = "i: [1, 2; 3 4]"; - let reg = crate::reflect::Registry::default(); - let msg: TestIntList = decode(input, ®).unwrap(); - assert_eq!(msg.i, vec![1, 2, 3, 4]); -} diff --git a/protokit_textformat/src/repro_p2.rs b/protokit_textformat/src/repro_p2.rs deleted file mode 100644 index 0531f5c..0000000 --- a/protokit_textformat/src/repro_p2.rs +++ /dev/null @@ -1,149 +0,0 @@ -#[cfg(test)] -mod tests { - use crate::reflect::Registry; - use crate::*; - use std::fmt::Display; - use std::str::FromStr; - - // Simulate a derived Enum - #[derive(Debug, PartialEq, Clone, Copy)] - struct MyEnum(i32); - - impl Default for MyEnum { - fn default() -> Self { - Self(0) - } - } - - impl Display for MyEnum { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } - } - - impl FromStr for MyEnum { - type Err = Error; - fn from_str(_s: &str) -> Result { - Ok(Self(0)) // Dummy - } - } - - impl From for MyEnum { - fn from(v: u32) -> Self { - Self(v as i32) - } - } - impl From for u32 { - fn from(v: MyEnum) -> Self { - v.0 as u32 - } - } - - impl Enum for MyEnum {} - - impl<'buf> TextField<'buf> for MyEnum { - fn merge_value(&mut self, stream: &mut InputStream<'buf>) -> Result<()> { - // Logic copied from protokit_derive - match stream.field() { - "FOO" => { - *self = Self(0); - stream.advance(); - } - "BAR" => { - *self = Self(1); - stream.advance(); - } - name => { - if let Ok(v) = name.parse::() { - *self = Self(v); - stream.advance(); - } else { - return crate::unknown(name); - } - } - } - Ok(()) - } - - fn emit_value(&self, stream: &mut OutputStream) { - match self.0 { - 0 => stream.ident("FOO"), - 1 => stream.ident("BAR"), - v => stream.disp(&v), - } - } - } - - #[test] - fn test_enum_number_parsing() { - let reg = Registry::default(); - let txt = "1"; - let mut stream = InputStream::new(txt, ®); - stream.advance(); // Prime first token - let mut e = MyEnum::default(); - e.merge_value(&mut stream).expect("Should parse '1'"); - assert_eq!(e.0, 1); - } - - #[test] - fn test_float_zero_parsing() { - let reg = Registry::default(); - let txt = "0"; - let mut stream = InputStream::new(txt, ®); - stream.advance(); - let v = stream.f64().expect("Should parse '0' as float"); - assert_eq!(v, 0.0); - } - - #[test] - fn test_float_neg_zero_parsing() { - let reg = Registry::default(); - let txt = "-0"; - let mut stream = InputStream::new(txt, ®); - stream.advance(); - let v = stream.f64().expect("Should parse '-0' as float"); - assert_eq!(v, 0.0); - assert!(v.is_sign_negative()); - } - - // Simulate Message with enum field to test full merge_field flow - #[derive(Default)] - struct MyMsg { - e: MyEnum, - } - - impl<'buf> TextProto<'buf> for MyMsg { - fn merge_field(&mut self, stream: &mut InputStream<'buf>) -> Result<()> { - match stream.field() { - "e" => { - // merge_single - self.e.merge_text(stream) - } - _ => crate::skip_unknown(stream), - } - } - fn encode(&self, _stream: &mut OutputStream) {} - fn decode(&mut self, stream: &mut InputStream<'buf>) -> Result<()> { - stream.message_fields(self) - } - } - impl binformat::ProtoName for MyMsg { - fn qualified_name(&self) -> &'static str { - "MyMsg" - } - } - - #[test] - fn test_msg_enum_parsing() { - let reg = Registry::default(); - let txt = "e: 1"; - let mut stream = InputStream::new(txt, ®); - // message_fields advances internally? - // No, message_fields expects stream to be at start of message? - // InputStream::new sets root=true. - // stream.message_fields -> if root=true, advances. - let mut msg = MyMsg::default(); - stream.message_fields(&mut msg).expect("Should parse msg"); - assert_eq!(msg.e.0, 1); - } -} diff --git a/protokit_textformat/src/repro_sep.rs b/protokit_textformat/src/repro_sep.rs deleted file mode 100644 index 1c40a77..0000000 --- a/protokit_textformat/src/repro_sep.rs +++ /dev/null @@ -1,103 +0,0 @@ -#[cfg(test)] -mod tests { - use crate::reflect::Registry; - use crate::stream::InputStream; - use crate::*; - - #[derive(Default)] - struct MockRepeated { - vals: Vec, - bools: Vec, - } - - impl<'buf> TextProto<'buf> for MockRepeated { - fn merge_field(&mut self, stream: &mut InputStream<'buf>) -> Result<()> { - match stream.parse_key()?.as_ref() { - "vals" => merge_repeated(&mut self.vals, stream), - "bools" => merge_repeated(&mut self.bools, stream), - "1" => { - // Simulate integer key 1 mapping to vals - merge_repeated(&mut self.vals, stream) - } - _ => skip_unknown(stream), - } - } - fn encode(&self, _stream: &mut OutputStream) {} - fn decode(&mut self, stream: &mut InputStream<'buf>) -> Result<()> { - stream.message_fields(self) - } - } - - impl binformat::ProtoName for MockRepeated { - fn qualified_name(&self) -> &'static str { - "MockRepeated" - } - } - - #[test] - fn test_comma_separated_fields() { - let reg = Registry::default(); - // "field: val, field: val" - let txt = "vals: 1, vals: 2"; - let mut stream = InputStream::new(txt, ®); - let mut msg = MockRepeated::default(); - stream - .message_fields(&mut msg) - .expect("Should parse comma separated fields"); - assert_eq!(msg.vals, vec![1, 2]); - } - - #[test] - fn test_comma_separated_values() { - let reg = Registry::default(); - // "field: 1, 2" - let txt = "vals: 3, 4"; - let mut stream = InputStream::new(txt, ®); - let mut msg = MockRepeated::default(); - stream - .message_fields(&mut msg) - .expect("Should parse comma separated values"); - assert_eq!(msg.vals, vec![3, 4]); - } - - #[test] - fn test_semicolon_separated_fields() { - let reg = Registry::default(); - // "field: val; field: val" - let txt = "vals: 5; vals: 6"; - let mut stream = InputStream::new(txt, ®); - let mut msg = MockRepeated::default(); - stream - .message_fields(&mut msg) - .expect("Should parse semicolon separated fields"); - assert_eq!(msg.vals, vec![5, 6]); - } - - #[test] - fn test_integer_key() { - let reg = Registry::default(); - let txt = "1: 10"; - let mut stream = InputStream::new(txt, ®); - let mut msg = MockRepeated::default(); - stream.message_fields(&mut msg).expect("Should parse integer key"); - assert_eq!(msg.vals, vec![10]); - } - - #[test] - fn test_bool_separators() { - let reg = Registry::default(); - let txt = "bools: true, bools: false"; - let mut stream = InputStream::new(txt, ®); - let mut msg = MockRepeated::default(); - stream.message_fields(&mut msg).expect("Should parse bool fields"); - assert_eq!(msg.bools, vec![true, false]); - - let txt2 = "bools: true, false"; - let mut stream2 = InputStream::new(txt2, ®); - let mut msg2 = MockRepeated::default(); - stream2 - .message_fields(&mut msg2) - .expect("Should parse comma bool values"); - assert_eq!(msg2.bools, vec![true, false]); - } -}