Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 40 additions & 16 deletions src/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@ impl<'tcx> ReplacePlacesVisitor<'tcx> {
}

#[derive(Debug, Clone)]
struct DeferredDefTy {
cache: Rc<RefCell<HashMap<rty::TypeArgs, rty::RefinedType>>>,
struct DeferredDefTy<'tcx> {
cache: Rc<RefCell<HashMap<mir_ty::GenericArgsRef<'tcx>, rty::RefinedType>>>,
}

#[derive(Debug, Clone)]
enum DefTy {
enum DefTy<'tcx> {
Concrete(rty::RefinedType),
Deferred(DeferredDefTy),
Deferred(DeferredDefTy<'tcx>),
}

#[derive(Clone)]
Expand All @@ -123,7 +123,7 @@ pub struct Analyzer<'tcx> {
/// currently contains only local-def templates,
/// but will be extended to contain externally known def's refinement types
/// (at least for every defs referenced by local def bodies)
defs: HashMap<DefId, DefTy>,
defs: HashMap<DefId, DefTy<'tcx>>,

/// Resulting CHC system.
system: Rc<RefCell<chc::System>>,
Expand Down Expand Up @@ -241,15 +241,17 @@ impl<'tcx> Analyzer<'tcx> {
pub fn def_ty_with_args(
&mut self,
def_id: DefId,
rty_args: rty::TypeArgs,
generic_args: mir_ty::GenericArgsRef<'tcx>,
) -> Option<rty::RefinedType> {
let deferred_ty = match self.defs.get(&def_id)? {
DefTy::Concrete(rty) => {
let type_builder = TypeBuilder::new(self.tcx, def_id);

let mut def_ty = rty.clone();
def_ty.instantiate_ty_params(
rty_args
.clone()
.into_iter()
generic_args
.types()
.map(|ty| type_builder.build(ty))
.map(rty::RefinedType::unrefined)
.collect(),
);
Expand All @@ -259,21 +261,17 @@ impl<'tcx> Analyzer<'tcx> {
};

let deferred_ty_cache = Rc::clone(&deferred_ty.cache); // to cut reference to allow &mut self
if let Some(rty) = deferred_ty_cache.borrow().get(&rty_args) {
if let Some(rty) = deferred_ty_cache.borrow().get(&generic_args) {
return Some(rty.clone());
}

let type_builder = TypeBuilder::new(self.tcx, def_id).with_param_mapper({
let rty_args = rty_args.clone();
move |ty: rty::ParamType| rty_args[ty.idx].clone()
});
let mut analyzer = self.local_def_analyzer(def_id.as_local()?);
analyzer.type_builder(type_builder);
analyzer.generic_args(generic_args);

let expected = analyzer.expected_ty();
deferred_ty_cache
.borrow_mut()
.insert(rty_args, expected.clone());
.insert(generic_args, expected.clone());

analyzer.run(&expected);
Some(expected)
Expand Down Expand Up @@ -340,4 +338,30 @@ impl<'tcx> Analyzer<'tcx> {
self.tcx.dcx().err(format!("verification error: {:?}", err));
}
}

/// Computes the signature of the local function.
///
/// This is a drop-in replacement of `self.tcx.fn_sig(local_def_id).instantiate_identity().skip_binder()`,
/// but extracts parameter and return types directly from the given `body` to obtain a signature that
/// reflects potential type instantiations happened after `optimized_mir`.
pub fn local_fn_sig_with_body(
&self,
local_def_id: LocalDefId,
body: &mir::Body<'tcx>,
) -> mir_ty::FnSig<'tcx> {
let ty = self.tcx.type_of(local_def_id).instantiate_identity();
let sig = if let mir_ty::TyKind::Closure(_, substs) = ty.kind() {
substs.as_closure().sig().skip_binder()
} else {
ty.fn_sig(self.tcx).skip_binder()
};

self.tcx.mk_fn_sig(
body.args_iter().map(|arg| body.local_decls[arg].ty),
body.return_ty(),
sig.c_variadic,
sig.unsafety,
sig.abi,
)
}
}
34 changes: 12 additions & 22 deletions src/analyze/basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
_ty,
) => {
let func_ty = match operand.const_fn_def() {
Some((def_id, args)) => {
let rty_args: IndexVec<_, _> =
args.types().map(|ty| self.type_builder.build(ty)).collect();
self.ctx
.def_ty_with_args(def_id, rty_args)
.expect("unknown def")
.ty
.clone()
}
Some((def_id, args)) => self
.ctx
.def_ty_with_args(def_id, args)
.expect("unknown def")
.ty
.clone(),
_ => unimplemented!(),
};
PlaceType::with_ty_and_term(func_ty.vacuous(), chc::Term::null())
Expand Down Expand Up @@ -471,14 +468,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
let ret = rty::RefinedType::new(rty::Type::unit(), ret_formula.into());
rty::FunctionType::new([param1, param2].into_iter().collect(), ret).into()
}
Some((def_id, args)) => {
let rty_args = args.types().map(|ty| self.type_builder.build(ty)).collect();
self.ctx
.def_ty_with_args(def_id, rty_args)
.expect("unknown def")
.ty
.vacuous()
}
Some((def_id, args)) => self
.ctx
.def_ty_with_args(def_id, args)
.expect("unknown def")
.ty
.vacuous(),
_ => self.operand_type(func.clone()).ty,
};
let expected_args: IndexVec<_, _> = args
Expand Down Expand Up @@ -988,11 +983,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
self
}

pub fn type_builder(&mut self, type_builder: TypeBuilder<'tcx>) -> &mut Self {
self.type_builder = type_builder;
self
}

pub fn run(&mut self, expected: &BasicBlockType) {
let span = tracing::info_span!("bb", bb = ?self.basic_block);
let _guard = span.enter();
Expand Down
53 changes: 50 additions & 3 deletions src/analyze/crate_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {

// check polymorphic function def by replacing type params with some opaque type
// (and this is no-op if the function is mono)
let type_builder = TypeBuilder::new(self.tcx, local_def_id.to_def_id())
.with_param_mapper(|_| rty::Type::int());
let mut expected = expected.clone();
let subst = rty::TypeParamSubst::new(
expected
Expand All @@ -87,13 +85,62 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
.collect(),
);
expected.subst_ty_params(&subst);
let generic_args = self.placeholder_generic_args(*local_def_id);
self.ctx
.local_def_analyzer(*local_def_id)
.type_builder(type_builder)
.generic_args(generic_args)
.run(&expected);
}
}

fn placeholder_generic_args(&self, local_def_id: LocalDefId) -> mir_ty::GenericArgsRef<'tcx> {
let mut constrained_params = HashSet::new();
let predicates = self.tcx.predicates_of(local_def_id);
let sized_trait = self.tcx.lang_items().sized_trait().unwrap();
for (clause, _) in predicates.predicates {
let mir_ty::ClauseKind::Trait(pred) = clause.kind().skip_binder() else {
continue;
};
if pred.def_id() == sized_trait {
continue;
};
for arg in pred.trait_ref.args.iter().flat_map(|ty| ty.walk()) {
let Some(ty) = arg.as_type() else {
continue;
};
let mir_ty::TyKind::Param(param_ty) = ty.kind() else {
continue;
};
constrained_params.insert(param_ty.index);
}
}

let mut args: Vec<mir_ty::GenericArg<'tcx>> = Vec::new();

let generics = self.tcx.generics_of(local_def_id);
for idx in 0..generics.count() {
let param = generics.param_at(idx, self.tcx);
let arg = match param.kind {
mir_ty::GenericParamDefKind::Type { .. } => {
if constrained_params.contains(&param.index) {
panic!(
"unable to check generic function with constrained type parameter: {}",
self.tcx.def_path_str(local_def_id)
);
}
self.tcx.types.i32.into()
}
mir_ty::GenericParamDefKind::Const { .. } => {
unimplemented!()
}
mir_ty::GenericParamDefKind::Lifetime { .. } => self.tcx.lifetimes.re_erased.into(),
};
args.push(arg);
}

self.tcx.mk_args(&args)
}

fn assert_callable_entry(&mut self) {
if let Some((def_id, _)) = self.tcx.entry_fn(()) {
// we want to assert entry function is safe to execute without any assumption
Expand Down
11 changes: 6 additions & 5 deletions src/analyze/local_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
}

pub fn expected_ty(&mut self) -> rty::RefinedType {
let sig = self.tcx.fn_sig(self.local_def_id);
let sig = sig.instantiate_identity().skip_binder();
let sig = self
.ctx
.local_fn_sig_with_body(self.local_def_id, &self.body);

let mut param_resolver = analyze::annot::ParamResolver::default();
for (input_ident, input_ty) in self
Expand Down Expand Up @@ -532,7 +533,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
.basic_block_analyzer(self.local_def_id, bb)
.body(self.body.clone())
.drop_points(drop_points)
.type_builder(self.type_builder.clone())
.run(&rty);
}
}
Expand Down Expand Up @@ -649,8 +649,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
}
}

pub fn type_builder(&mut self, type_builder: TypeBuilder<'tcx>) -> &mut Self {
self.type_builder = type_builder;
pub fn generic_args(&mut self, generic_args: mir_ty::GenericArgsRef<'tcx>) -> &mut Self {
self.body =
mir_ty::EarlyBinder::bind(self.body.clone()).instantiate(self.tcx, generic_args);
self
}

Expand Down
8 changes: 4 additions & 4 deletions src/chc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ impl Function {
}

/// A logical term.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone)]
pub enum Term<V = TermVarIdx> {
Null,
Var(V),
Expand Down Expand Up @@ -984,7 +984,7 @@ impl Pred {
}

/// An atom is a predicate applied to a list of terms.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone)]
pub struct Atom<V = TermVarIdx> {
pub pred: Pred,
pub args: Vec<Term<V>>,
Expand Down Expand Up @@ -1077,7 +1077,7 @@ impl<V> Atom<V> {
/// While it allows arbitrary [`Atom`] in its `Atom` variant, we only expect atoms with known
/// predicates (i.e., predicates other than `Pred::Var`) to appear in formulas. It is our TODO to
/// enforce this restriction statically. Also see the definition of [`Body`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone)]
pub enum Formula<V = TermVarIdx> {
Atom(Atom<V>),
Not(Box<Formula<V>>),
Expand Down Expand Up @@ -1296,7 +1296,7 @@ impl<V> Formula<V> {
}

/// The body part of a clause, consisting of atoms and a formula.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone)]
pub struct Body<V = TermVarIdx> {
pub atoms: Vec<Atom<V>>,
/// NOTE: This doesn't contain predicate variables. Also see [`Formula`].
Expand Down
39 changes: 1 addition & 38 deletions src/refine/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,6 @@ where
}
}

trait ParamTypeMapper {
fn map_param_ty(&self, ty: rty::ParamType) -> rty::Type<rty::Closed>;
}

impl<F> ParamTypeMapper for F
where
F: Fn(rty::ParamType) -> rty::Type<rty::Closed>,
{
fn map_param_ty(&self, ty: rty::ParamType) -> rty::Type<rty::Closed> {
self(ty)
}
}

/// Translates [`mir_ty::Ty`] to [`rty::Type`].
///
/// This struct implements a translation from Rust MIR types to Thrust types.
Expand All @@ -87,9 +74,6 @@ pub struct TypeBuilder<'tcx> {
/// mapped when we translate a [`mir_ty::ParamTy`] to [`rty::ParamType`].
/// See [`rty::TypeParamIdx`] for more details.
param_idx_mapping: HashMap<u32, rty::TypeParamIdx>,
/// Optionally also want to further map rty::ParamType to other rty::Type before generating
/// templates. This is no-op by default.
param_type_mapper: std::rc::Rc<dyn ParamTypeMapper>,
}

impl<'tcx> TypeBuilder<'tcx> {
Expand All @@ -109,25 +93,15 @@ impl<'tcx> TypeBuilder<'tcx> {
Self {
tcx,
param_idx_mapping,
param_type_mapper: std::rc::Rc::new(|ty: rty::ParamType| ty.into()),
}
}

pub fn with_param_mapper<F>(mut self, mapper: F) -> Self
where
F: Fn(rty::ParamType) -> rty::Type<rty::Closed> + 'static,
{
self.param_type_mapper = std::rc::Rc::new(mapper);
self
}

fn translate_param_type(&self, ty: &mir_ty::ParamTy) -> rty::Type<rty::Closed> {
let index = *self
.param_idx_mapping
.get(&ty.index)
.expect("unknown type param idx");
let param_ty = rty::ParamType::new(index);
self.param_type_mapper.map_param_ty(param_ty)
rty::ParamType::new(index).into()
}

// TODO: consolidate two impls
Expand Down Expand Up @@ -400,17 +374,6 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> {
self.ret_rty = Some(rty);
self
}

pub fn would_contain_template(&self) -> bool {
if self.param_tys.is_empty() {
return self.ret_rty.is_none();
}

let last_param_idx = rty::FunctionParamIdx::from(self.param_tys.len() - 1);
let param_annotated =
self.param_refinement.is_some() || self.param_rtys.contains_key(&last_param_idx);
self.ret_rty.is_none() || !param_annotated
}
}

impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R>
Expand Down
Loading