diff --git a/src/analyze.rs b/src/analyze.rs index a1ebd34..fff56b4 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -104,14 +104,14 @@ impl<'tcx> ReplacePlacesVisitor<'tcx> { } #[derive(Debug, Clone)] -struct DeferredDefTy { - cache: Rc>>, +struct DeferredDefTy<'tcx> { + cache: Rc, rty::RefinedType>>>, } #[derive(Debug, Clone)] -enum DefTy { +enum DefTy<'tcx> { Concrete(rty::RefinedType), - Deferred(DeferredDefTy), + Deferred(DeferredDefTy<'tcx>), } #[derive(Clone)] @@ -123,7 +123,7 @@ pub struct Analyzer<'tcx> { /// currently contains only local-def templates, /// but will be extended to contain externally known def's refinement types /// (at least for every defs referenced by local def bodies) - defs: HashMap, + defs: HashMap>, /// Resulting CHC system. system: Rc>, @@ -241,15 +241,17 @@ impl<'tcx> Analyzer<'tcx> { pub fn def_ty_with_args( &mut self, def_id: DefId, - rty_args: rty::TypeArgs, + generic_args: mir_ty::GenericArgsRef<'tcx>, ) -> Option { let deferred_ty = match self.defs.get(&def_id)? { DefTy::Concrete(rty) => { + let type_builder = TypeBuilder::new(self.tcx, def_id); + let mut def_ty = rty.clone(); def_ty.instantiate_ty_params( - rty_args - .clone() - .into_iter() + generic_args + .types() + .map(|ty| type_builder.build(ty)) .map(rty::RefinedType::unrefined) .collect(), ); @@ -259,21 +261,17 @@ impl<'tcx> Analyzer<'tcx> { }; let deferred_ty_cache = Rc::clone(&deferred_ty.cache); // to cut reference to allow &mut self - if let Some(rty) = deferred_ty_cache.borrow().get(&rty_args) { + if let Some(rty) = deferred_ty_cache.borrow().get(&generic_args) { return Some(rty.clone()); } - let type_builder = TypeBuilder::new(self.tcx, def_id).with_param_mapper({ - let rty_args = rty_args.clone(); - move |ty: rty::ParamType| rty_args[ty.idx].clone() - }); let mut analyzer = self.local_def_analyzer(def_id.as_local()?); - analyzer.type_builder(type_builder); + analyzer.generic_args(generic_args); let expected = analyzer.expected_ty(); deferred_ty_cache .borrow_mut() - .insert(rty_args, expected.clone()); + .insert(generic_args, expected.clone()); analyzer.run(&expected); Some(expected) @@ -340,4 +338,30 @@ impl<'tcx> Analyzer<'tcx> { self.tcx.dcx().err(format!("verification error: {:?}", err)); } } + + /// Computes the signature of the local function. + /// + /// This is a drop-in replacement of `self.tcx.fn_sig(local_def_id).instantiate_identity().skip_binder()`, + /// but extracts parameter and return types directly from the given `body` to obtain a signature that + /// reflects potential type instantiations happened after `optimized_mir`. + pub fn local_fn_sig_with_body( + &self, + local_def_id: LocalDefId, + body: &mir::Body<'tcx>, + ) -> mir_ty::FnSig<'tcx> { + let ty = self.tcx.type_of(local_def_id).instantiate_identity(); + let sig = if let mir_ty::TyKind::Closure(_, substs) = ty.kind() { + substs.as_closure().sig().skip_binder() + } else { + ty.fn_sig(self.tcx).skip_binder() + }; + + self.tcx.mk_fn_sig( + body.args_iter().map(|arg| body.local_decls[arg].ty), + body.return_ty(), + sig.c_variadic, + sig.unsafety, + sig.abi, + ) + } } diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 9f6ee49..81f7705 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -263,15 +263,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { _ty, ) => { let func_ty = match operand.const_fn_def() { - Some((def_id, args)) => { - let rty_args: IndexVec<_, _> = - args.types().map(|ty| self.type_builder.build(ty)).collect(); - self.ctx - .def_ty_with_args(def_id, rty_args) - .expect("unknown def") - .ty - .clone() - } + Some((def_id, args)) => self + .ctx + .def_ty_with_args(def_id, args) + .expect("unknown def") + .ty + .clone(), _ => unimplemented!(), }; PlaceType::with_ty_and_term(func_ty.vacuous(), chc::Term::null()) @@ -471,14 +468,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let ret = rty::RefinedType::new(rty::Type::unit(), ret_formula.into()); rty::FunctionType::new([param1, param2].into_iter().collect(), ret).into() } - Some((def_id, args)) => { - let rty_args = args.types().map(|ty| self.type_builder.build(ty)).collect(); - self.ctx - .def_ty_with_args(def_id, rty_args) - .expect("unknown def") - .ty - .vacuous() - } + Some((def_id, args)) => self + .ctx + .def_ty_with_args(def_id, args) + .expect("unknown def") + .ty + .vacuous(), _ => self.operand_type(func.clone()).ty, }; let expected_args: IndexVec<_, _> = args @@ -988,11 +983,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self } - pub fn type_builder(&mut self, type_builder: TypeBuilder<'tcx>) -> &mut Self { - self.type_builder = type_builder; - self - } - pub fn run(&mut self, expected: &BasicBlockType) { let span = tracing::info_span!("bb", bb = ?self.basic_block); let _guard = span.enter(); diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index e15ca49..d6e343d 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -76,8 +76,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { // check polymorphic function def by replacing type params with some opaque type // (and this is no-op if the function is mono) - let type_builder = TypeBuilder::new(self.tcx, local_def_id.to_def_id()) - .with_param_mapper(|_| rty::Type::int()); let mut expected = expected.clone(); let subst = rty::TypeParamSubst::new( expected @@ -87,13 +85,62 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .collect(), ); expected.subst_ty_params(&subst); + let generic_args = self.placeholder_generic_args(*local_def_id); self.ctx .local_def_analyzer(*local_def_id) - .type_builder(type_builder) + .generic_args(generic_args) .run(&expected); } } + fn placeholder_generic_args(&self, local_def_id: LocalDefId) -> mir_ty::GenericArgsRef<'tcx> { + let mut constrained_params = HashSet::new(); + let predicates = self.tcx.predicates_of(local_def_id); + let sized_trait = self.tcx.lang_items().sized_trait().unwrap(); + for (clause, _) in predicates.predicates { + let mir_ty::ClauseKind::Trait(pred) = clause.kind().skip_binder() else { + continue; + }; + if pred.def_id() == sized_trait { + continue; + }; + for arg in pred.trait_ref.args.iter().flat_map(|ty| ty.walk()) { + let Some(ty) = arg.as_type() else { + continue; + }; + let mir_ty::TyKind::Param(param_ty) = ty.kind() else { + continue; + }; + constrained_params.insert(param_ty.index); + } + } + + let mut args: Vec> = Vec::new(); + + let generics = self.tcx.generics_of(local_def_id); + for idx in 0..generics.count() { + let param = generics.param_at(idx, self.tcx); + let arg = match param.kind { + mir_ty::GenericParamDefKind::Type { .. } => { + if constrained_params.contains(¶m.index) { + panic!( + "unable to check generic function with constrained type parameter: {}", + self.tcx.def_path_str(local_def_id) + ); + } + self.tcx.types.i32.into() + } + mir_ty::GenericParamDefKind::Const { .. } => { + unimplemented!() + } + mir_ty::GenericParamDefKind::Lifetime { .. } => self.tcx.lifetimes.re_erased.into(), + }; + args.push(arg); + } + + self.tcx.mk_args(&args) + } + fn assert_callable_entry(&mut self) { if let Some((def_id, _)) = self.tcx.entry_fn(()) { // we want to assert entry function is safe to execute without any assumption diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 58e4b8b..50a4397 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -169,8 +169,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } pub fn expected_ty(&mut self) -> rty::RefinedType { - let sig = self.tcx.fn_sig(self.local_def_id); - let sig = sig.instantiate_identity().skip_binder(); + let sig = self + .ctx + .local_fn_sig_with_body(self.local_def_id, &self.body); let mut param_resolver = analyze::annot::ParamResolver::default(); for (input_ident, input_ty) in self @@ -532,7 +533,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .basic_block_analyzer(self.local_def_id, bb) .body(self.body.clone()) .drop_points(drop_points) - .type_builder(self.type_builder.clone()) .run(&rty); } } @@ -649,8 +649,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } } - pub fn type_builder(&mut self, type_builder: TypeBuilder<'tcx>) -> &mut Self { - self.type_builder = type_builder; + pub fn generic_args(&mut self, generic_args: mir_ty::GenericArgsRef<'tcx>) -> &mut Self { + self.body = + mir_ty::EarlyBinder::bind(self.body.clone()).instantiate(self.tcx, generic_args); self } diff --git a/src/chc.rs b/src/chc.rs index 6f5db32..8a3309f 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -389,7 +389,7 @@ impl Function { } /// A logical term. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub enum Term { Null, Var(V), @@ -984,7 +984,7 @@ impl Pred { } /// An atom is a predicate applied to a list of terms. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct Atom { pub pred: Pred, pub args: Vec>, @@ -1077,7 +1077,7 @@ impl Atom { /// While it allows arbitrary [`Atom`] in its `Atom` variant, we only expect atoms with known /// predicates (i.e., predicates other than `Pred::Var`) to appear in formulas. It is our TODO to /// enforce this restriction statically. Also see the definition of [`Body`]. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub enum Formula { Atom(Atom), Not(Box>), @@ -1296,7 +1296,7 @@ impl Formula { } /// The body part of a clause, consisting of atoms and a formula. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct Body { pub atoms: Vec>, /// NOTE: This doesn't contain predicate variables. Also see [`Formula`]. diff --git a/src/refine/template.rs b/src/refine/template.rs index c859419..446c450 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -59,19 +59,6 @@ where } } -trait ParamTypeMapper { - fn map_param_ty(&self, ty: rty::ParamType) -> rty::Type; -} - -impl ParamTypeMapper for F -where - F: Fn(rty::ParamType) -> rty::Type, -{ - fn map_param_ty(&self, ty: rty::ParamType) -> rty::Type { - self(ty) - } -} - /// Translates [`mir_ty::Ty`] to [`rty::Type`]. /// /// This struct implements a translation from Rust MIR types to Thrust types. @@ -87,9 +74,6 @@ pub struct TypeBuilder<'tcx> { /// mapped when we translate a [`mir_ty::ParamTy`] to [`rty::ParamType`]. /// See [`rty::TypeParamIdx`] for more details. param_idx_mapping: HashMap, - /// Optionally also want to further map rty::ParamType to other rty::Type before generating - /// templates. This is no-op by default. - param_type_mapper: std::rc::Rc, } impl<'tcx> TypeBuilder<'tcx> { @@ -109,25 +93,15 @@ impl<'tcx> TypeBuilder<'tcx> { Self { tcx, param_idx_mapping, - param_type_mapper: std::rc::Rc::new(|ty: rty::ParamType| ty.into()), } } - pub fn with_param_mapper(mut self, mapper: F) -> Self - where - F: Fn(rty::ParamType) -> rty::Type + 'static, - { - self.param_type_mapper = std::rc::Rc::new(mapper); - self - } - fn translate_param_type(&self, ty: &mir_ty::ParamTy) -> rty::Type { let index = *self .param_idx_mapping .get(&ty.index) .expect("unknown type param idx"); - let param_ty = rty::ParamType::new(index); - self.param_type_mapper.map_param_ty(param_ty) + rty::ParamType::new(index).into() } // TODO: consolidate two impls @@ -400,17 +374,6 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { self.ret_rty = Some(rty); self } - - pub fn would_contain_template(&self) -> bool { - if self.param_tys.is_empty() { - return self.ret_rty.is_none(); - } - - let last_param_idx = rty::FunctionParamIdx::from(self.param_tys.len() - 1); - let param_annotated = - self.param_refinement.is_some() || self.param_rtys.contains_key(&last_param_idx); - self.ret_rty.is_none() || !param_annotated - } } impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> diff --git a/src/rty.rs b/src/rty.rs index ce6ef5e..64b8dfb 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -55,7 +55,7 @@ mod subtyping; pub use subtyping::{relate_sub_closed_type, ClauseScope, Subtyping}; mod params; -pub use params::{RefinedTypeArgs, TypeArgs, TypeParamIdx, TypeParamSubst}; +pub use params::{RefinedTypeArgs, TypeParamIdx, TypeParamSubst}; rustc_index::newtype_index! { /// An index representing function parameter. @@ -88,7 +88,7 @@ where /// In Thrust, function types are closed. Because of that, function types, thus its parameters and /// return type only refer to the parameters of the function itself using [`FunctionParamIdx`] and /// do not accept other type of variables from the environment. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct FunctionType { pub params: IndexVec>, pub ret: Box>, @@ -156,7 +156,7 @@ impl FunctionType { } /// The kind of a reference, which is either mutable or immutable. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RefKind { Mut, Immut, @@ -181,7 +181,7 @@ where } /// The kind of a pointer, which is either a reference or an owned pointer. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PointerKind { Ref(RefKind), Own, @@ -221,7 +221,7 @@ impl PointerKind { } /// A pointer type. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct PointerType { pub kind: PointerKind, pub elem: Box>, @@ -334,7 +334,7 @@ impl PointerType { /// Note that the current implementation uses tuples to represent structs. See /// implementation in `crate::refine::template` module for details. /// It is our TODO to improve the struct representation. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct TupleType { pub elems: Vec>, } @@ -458,7 +458,7 @@ impl EnumDatatypeDef { /// An enum type. /// /// An enum type includes its type arguments and the argument types can refer to outer variables `T`. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct EnumType { pub symbol: chc::DatatypeSymbol, pub args: IndexVec>, @@ -560,7 +560,7 @@ impl EnumType { } /// A type parameter. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct ParamType { pub idx: TypeParamIdx, } @@ -589,7 +589,7 @@ impl ParamType { } /// An underlying type of a refinement type. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub enum Type { Int, Bool, @@ -995,7 +995,7 @@ impl ShiftExistential for RefinedTypeVar { /// A formula, potentially equipped with an existential quantifier. /// /// Note: This is not to be confused with [`crate::chc::Formula`] in the [`crate::chc`] module, which is a different notion. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct Formula { pub existentials: IndexVec, pub body: chc::Body, @@ -1236,7 +1236,7 @@ impl Instantiator { } /// A refinement type. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct RefinedType { pub ty: Type, pub refinement: Refinement,