use crate::diagnostics::{Diagnostic, Level, Span};
use crate::engine::Engine;
use crate::id::Id;
use crate::types::Type;
use crate::value::{UValue, Value};
use crate::TypedSignature;
use comemo::TrackedMut;
use ecow::EcoVec;
use rustre_parser::ast::expr_visitor::ExpressionWalker;
use rustre_parser::ast::{
AstNode, CallByPosExpressionNode, EffectiveNodeNode, IdRefNode, Ident, NodeNode, StaticArgNode,
StaticArgsNode, StaticParamNode, WithExpressionNode,
};
use std::cmp::Ordering;
use std::collections::HashSet;
use std::fmt::{Debug, Formatter};
/// The static parameters declared by a node, in declaration order.
///
/// Each entry pairs the parameter's identifier (`None` when the declaration
/// carries no identifier) with the kind of static argument it accepts.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct StaticParams {
    inner: EcoVec<(Option<Ident>, StaticParamType)>,
}

impl StaticParams {
    /// Number of declared static parameters.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// `true` when the node declares no static parameter at all.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// The kind of static argument a static parameter accepts.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum StaticParamType {
    /// A parameter declared with the `type` keyword.
    Type,
    /// A parameter declared with the `const` keyword.
    Const {
        /// The declared type of the constant, or `None` when the declaration
        /// has no type annotation.
        typ: Option<Type>,
    },
    /// A parameter declared with `node`, `function` or `unsafe`.
    Node {
        /// Whether the parameter was declared with the `function` keyword.
        is_function: bool,
        /// Expected signature of the node argument. `static_param_of_node`
        /// currently always leaves this as `None`.
        signature: Option<TypedSignature>,
    },
}
/// Classifies one static parameter declaration by the keyword that introduces it.
///
/// * `type` keyword → [`StaticParamType::Type`].
/// * `const` keyword → [`StaticParamType::Const`]. The optional type annotation
///   is resolved with no caller context.
/// * `node`, `function` or `unsafe` → [`StaticParamType::Node`]. The signature
///   is not computed here and stays `None`.
///
/// # Panics
///
/// Hits `todo!` for any other form of static parameter.
#[memoize]
fn static_param_of_node(engine: TrackedMut<Engine>, node: StaticParamNode) -> StaticParamType {
    if node.r#type().is_some() {
        return StaticParamType::Type;
    }

    if node.r#const().is_some() {
        let typ = node
            .type_node()
            .map(|annotation| crate::types::type_of_ast_type(engine, None, annotation));
        return StaticParamType::Const { typ };
    }

    let is_function = node.is_function();
    if is_function || node.is_node() || node.is_unsafe() {
        return StaticParamType::Node {
            is_function,
            signature: None,
        };
    }

    todo!("unimplemented")
}
#[memoize]
pub fn static_params_of_node(mut engine: TrackedMut<Engine>, node: NodeNode) -> StaticParams {
if let Some(static_params) = node.static_params_node() {
StaticParams {
inner: static_params
.all_static_param_node()
.map(|param| {
let id = param.id_node().and_then(|n| n.ident());
(
id,
static_param_of_node(TrackedMut::reborrow_mut(&mut engine), param),
)
})
.collect(),
}
} else {
StaticParams::default()
}
}
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct StaticArgs {
inner: EcoVec<(Box<Id>, StaticArg)>,
}
impl StaticArgs {
pub fn values(&self) -> impl Iterator<Item = &StaticArg> {
self.inner.iter().map(|(_, value)| value)
}
pub fn get(&self, id: &Id) -> Option<&StaticArg> {
self.inner
.iter()
.find(|(par_id, _)| par_id.as_ref() == id)
.map(|(_, arg)| arg)
}
pub fn resolve_type(&self, id: &Id) -> Option<&Type> {
self.get(id).and_then(StaticArg::as_type)
}
pub fn resolve_const(&self, id: &Id) -> Option<&Value> {
self.get(id).and_then(StaticArg::as_const)
}
pub fn resolve_node(&self, id: &Id) -> Option<&NodeInstance> {
self.get(id).and_then(StaticArg::as_node)
}
}
impl PartialOrd for StaticArgs {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Orders static arguments primarily by their values, in parameter order.
///
/// Parameter names are compared only to break ties. The derived `PartialEq`
/// also compares the names, so without the tie-break two argument lists with
/// equal values but different parameter names would compare `Equal` while
/// being `!=`. That violates the `Ord` contract, and sorted collections and
/// binary searches rely on it.
impl Ord for StaticArgs {
    fn cmp(&self, other: &Self) -> Ordering {
        self.values().cmp(other.values()).then_with(|| {
            self.inner
                .iter()
                .map(|(id, _)| id)
                .cmp(other.inner.iter().map(|(id, _)| id))
        })
    }
}
/// A single resolved static argument.
///
/// The variant order is significant: the derived `Ord` sorts all types before
/// all constants, and all constants before all node instances.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum StaticArg {
    /// A type passed to a `type` static parameter.
    Type(Type),
    /// A constant value passed to a `const` static parameter.
    Const(Value),
    /// A node instance passed to a `node`/`function` static parameter.
    Node(NodeInstance),
}
/// Borrows the payload of `$self` if it is the single-field tuple variant
/// `$variant`, and yields `None` for any other variant.
macro_rules! as_impl {
    ($self:ident as $variant:path) => {
        match $self {
            $variant(payload) => Some(payload),
            _ => None,
        }
    };
}

impl StaticArg {
    /// The wrapped type, if this argument is a [`StaticArg::Type`].
    pub fn as_type(&self) -> Option<&Type> {
        as_impl!(self as Self::Type)
    }

    /// The wrapped value, if this argument is a [`StaticArg::Const`].
    pub fn as_const(&self) -> Option<&Value> {
        as_impl!(self as Self::Const)
    }

    /// The wrapped instance, if this argument is a [`StaticArg::Node`].
    pub fn as_node(&self) -> Option<&NodeInstance> {
        as_impl!(self as Self::Node)
    }
}
impl From<Type> for StaticArg {
fn from(value: Type) -> Self {
Self::Type(value)
}
}
impl From<Value> for StaticArg {
fn from(value: Value) -> Self {
Self::Const(value)
}
}
impl From<NodeInstance> for StaticArg {
fn from(value: NodeInstance) -> Self {
Self::Node(value)
}
}
/// A node declaration together with the static arguments it is instantiated with.
///
/// Instances of non-parametric nodes carry an empty [`StaticArgs`].
#[derive(Clone, Eq, Hash, PartialEq)]
pub struct NodeInstance {
    pub node: NodeNode,
    pub static_args: StaticArgs,
}

impl NodeInstance {
    /// Instantiates `node` without any static argument.
    pub fn new_non_parametric(node: NodeNode) -> Self {
        Self {
            node,
            static_args: StaticArgs::default(),
        }
    }

    /// Instantiates `node` with the given resolved static arguments.
    pub fn new_parametric(node: NodeNode, static_args: StaticArgs) -> Self {
        Self { node, static_args }
    }

    /// `true` when this instance carries at least one static argument.
    pub fn is_parametric(&self) -> bool {
        !self.static_args.inner.is_empty()
    }

    /// The declared name of the node.
    ///
    /// Returns `None` when the declaration has no id node, or when its id node
    /// carries no identifier. The latter previously panicked. It is presumably
    /// reachable on malformed input (TODO confirm against the parser's error
    /// recovery). This helper backs `Debug` and `Ord`, which must never panic.
    fn node_name(&self) -> Option<Ident> {
        self.node.id_node().and_then(|id| id.ident())
    }
}
/// Renders an instance as e.g. `node foo`, `function bar<<arg1, arg2>>`, or
/// `node [unnamed]` when the declaration has no name. Static arguments use
/// their `Debug` form.
impl Debug for NodeInstance {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let keyword = if self.node.is_function() {
            "function "
        } else {
            "node "
        };
        f.write_str(keyword)?;

        match self.node_name() {
            Some(name) => write!(f, "{}", <&Id>::from(&name))?,
            None => f.write_str("[unnamed]")?,
        }

        if !self.is_parametric() {
            return Ok(());
        }

        f.write_str("<<")?;
        for (position, arg) in self.static_args.values().enumerate() {
            if position != 0 {
                f.write_str(", ")?;
            }
            write!(f, "{arg:?}")?;
        }
        f.write_str(">>")
    }
}
impl PartialOrd for NodeInstance {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders instances in three steps: functions before nodes, then by declared
/// name (unnamed declarations first), then by static arguments.
///
/// NOTE(review): the derived `Eq` compares the `node` syntax nodes themselves.
/// This ordering only looks at kind, name and static arguments, so two
/// distinct declarations sharing a name may compare `Equal` without being
/// `==`. Confirm that names are unique, or add a tie-breaker.
impl Ord for NodeInstance {
    fn cmp(&self, other: &Self) -> Ordering {
        // `!is_function()` is `false` for functions, and `false < true`.
        let is_node = |instance: &Self| !instance.node.is_function();
        is_node(self).cmp(&is_node(other)).then_with(|| {
            let lhs = self.node_name();
            let rhs = other.node_name();
            let lhs_id: Option<&Id> = lhs.as_ref().map(<&Id>::from);
            let rhs_id: Option<&Id> = rhs.as_ref().map(<&Id>::from);
            lhs_id
                .cmp(&rhs_id)
                .then_with(|| self.static_args.cmp(&other.static_args))
        })
    }
}
#[memoize]
fn static_arg_of_static_arg_node(
mut engine: TrackedMut<Engine>,
caller: Option<NodeInstance>,
expected: StaticParamType,
arg: StaticArgNode,
) -> Option<StaticArg> {
if let Some(n) = arg.type_node() {
assert!(matches!(expected, StaticParamType::Type));
let ty = crate::types::type_of_ast_type(engine, caller, n);
Some(ty.into())
} else if let Some(n) = arg.expression_node() {
assert!(matches!(expected, StaticParamType::Const { .. }));
crate::eval::eval_const_node(TrackedMut::reborrow_mut(&mut engine), n, caller)
.map(StaticArg::from)
} else if let Some(n) = arg.effective_node_node() {
assert!(matches!(expected, StaticParamType::Node { .. }));
instance_of_effective_node(TrackedMut::reborrow_mut(&mut engine), caller, n)
.map(StaticArg::from)
} else if let Some(n) = arg.id_ref_node() {
match expected {
StaticParamType::Type => todo!("proper guidance to add `type ` in front"),
StaticParamType::Const { .. } => {
crate::eval::eval_id_ref(TrackedMut::reborrow_mut(&mut engine), n, caller)
.map(StaticArg::from)
}
StaticParamType::Node { .. } => {
instance_of_id_ref(TrackedMut::reborrow_mut(&mut engine), caller, n)
.map(StaticArg::from)
}
}
} else {
unimplemented!()
}
}
#[memoize]
pub fn static_args_of_effective_node(
mut engine: TrackedMut<Engine>,
caller: Option<NodeInstance>,
callee: NodeNode,
static_args_node: Option<StaticArgsNode>,
) -> Option<StaticArgs> {
let static_params = static_params_of_node(TrackedMut::reborrow_mut(&mut engine), callee);
match static_args_node {
None => Some(StaticArgs::default()),
Some(args) => {
let all_args = args.all_static_arg_node().collect::<Vec<_>>();
if all_args.is_empty() {
let span = Span::of_node(args.syntax());
Diagnostic::build(Level::Error, "no static arguments inside << >>")
.with_attachment(span, "hint: remove this")
.emit(&mut engine);
}
if all_args.len() != static_params.len() {
let span = Span::of_node(args.syntax());
let expected = static_params.len();
Diagnostic::build(Level::Error, "invalid static argument count")
.with_attachment(span, format!("expected {expected} static arguments"))
.emit(&mut engine);
}
let resolved_args = static_params
.inner
.iter()
.zip(all_args)
.map(|((id, expected), arg)| {
let id = id.as_ref().map(Box::<Id>::from).unwrap_or_default();
static_arg_of_static_arg_node(
TrackedMut::reborrow_mut(&mut engine),
caller.clone(),
expected.clone(),
arg,
)
.map(|arg| (id, arg))
})
.collect::<Option<EcoVec<(Box<Id>, StaticArg)>>>()?;
Some(StaticArgs {
inner: resolved_args,
})
}
}
}
/// Instantiates the node called by a positional call expression inside `caller`.
///
/// The callee name is looked up with the caller's static arguments (if any)
/// handed to `find_node`. The callee is then instantiated with the static
/// arguments written at the call site. Returns `None` if the call has no
/// callee name, the name does not resolve, or a static argument fails to
/// resolve.
#[memoize]
pub fn instance_of_call_expression(
    mut engine: TrackedMut<Engine>,
    caller: Option<NodeInstance>,
    expr: CallByPosExpressionNode,
) -> Option<NodeInstance> {
    let callee_ref = expr.node_ref()?;
    let lookup_scope = caller
        .as_ref()
        .map(|instance| instance.static_args.clone());
    let callee = crate::name_resolution::find_node(
        TrackedMut::reborrow_mut(&mut engine),
        &lookup_scope,
        callee_ref,
    )?;

    static_args_of_effective_node(engine, caller, callee.clone(), expr.static_args_node())
        .map(|static_args| NodeInstance::new_parametric(callee, static_args))
}
/// Instantiates an effective node such as `N<<args>>`.
///
/// The node name is resolved with `&None` passed as the static-argument scope
/// to `find_node`, unlike `instance_of_call_expression`.
/// NOTE(review): confirm the caller's static arguments are deliberately not
/// used for this lookup. The static arguments themselves are still resolved in
/// the context of `caller`.
#[memoize]
pub fn instance_of_effective_node(
    mut engine: TrackedMut<Engine>,
    caller: Option<NodeInstance>,
    effective_node: EffectiveNodeNode,
) -> Option<NodeInstance> {
    let target_ref = effective_node.id_ref_node()?;
    let target =
        crate::name_resolution::find_node(TrackedMut::reborrow_mut(&mut engine), &None, target_ref)?;

    static_args_of_effective_node(
        TrackedMut::reborrow_mut(&mut engine),
        caller,
        target.clone(),
        effective_node.static_args_node(),
    )
    .map(|static_args| NodeInstance::new_parametric(target, static_args))
}
/// Instantiates the node named by a bare identifier reference, without any
/// static arguments.
///
/// `_scope` is currently unused. The lookup passes `&None` as the
/// static-argument scope to `find_node`.
#[memoize]
pub fn instance_of_id_ref(
    mut engine: TrackedMut<Engine>,
    _scope: Option<NodeInstance>,
    id_ref: IdRefNode,
) -> Option<NodeInstance> {
    crate::name_resolution::find_node(TrackedMut::reborrow_mut(&mut engine), &None, id_ref)
        .map(NodeInstance::new_non_parametric)
}
/// Expression walker that records every node instance called from the body of
/// `within`.
struct CallSiteWalker<'a, 'world> {
    /// Engine used to evaluate static conditions and resolve callees.
    engine: TrackedMut<'a, Engine<'world>>,
    /// The instance whose body is walked; call sites are resolved in its context.
    within: &'a NodeInstance,
    /// Output buffer receiving each resolved callee. Duplicates are kept.
    instances: &'a mut EcoVec<NodeInstance>,
}

impl ExpressionWalker for CallSiteWalker<'_, '_> {
    /// Handles a static `with` expression by walking only the branch selected
    /// by its condition.
    ///
    /// The condition is evaluated as a constant in the context of `within`.
    /// If it is neither `true` nor `false` (or cannot be evaluated), neither
    /// branch is walked. Always returns `false`, presumably telling the walker
    /// not to descend into the children itself (TODO confirm against
    /// `ExpressionWalker`).
    fn walk_with(&mut self, e: WithExpressionNode) -> bool {
        let Some(cond) = e.cond() else {
            return false;
        };
        let evaluated = crate::eval::eval_const_node(
            TrackedMut::reborrow_mut(&mut self.engine),
            cond,
            Some(self.within.clone()),
        );
        match evaluated.as_ref().map(Value::unpack) {
            Some(UValue::Bool(true)) => {
                self.walk_expr_opt(e.with_body());
            }
            Some(UValue::Bool(false)) => {
                self.walk_expr_opt(e.else_body());
            }
            _ => {}
        }
        false
    }

    /// Resolves a positional call and records the callee instance if it resolves.
    ///
    /// NOTE(review): the argument expressions are not walked here. Confirm that
    /// calls nested inside call arguments (e.g. `f(g(x))`) are still visited
    /// by the walker's default traversal.
    fn walk_call_by_pos(&mut self, e: CallByPosExpressionNode) {
        let context = Some(self.within.clone());
        let resolved =
            instance_of_call_expression(TrackedMut::reborrow_mut(&mut self.engine), context, e);
        if let Some(callee) = resolved {
            self.instances.push(callee);
        }
    }
}
/// Lists the node instances directly referenced by `instance`.
///
/// * For an alias declaration (detected via `node.equal()`), the result is the
///   aliased effective node instantiated in the context of `instance`. It is
///   empty if there is no alias or it fails to resolve.
/// * Otherwise, the body's equations and then its assertions are walked with a
///   `CallSiteWalker`, and every resolved call is returned. Duplicates are
///   kept. A declaration without a body yields nothing.
#[instrument(level = "TRACE", skip(engine), ret)]
fn node_call_sites(engine: TrackedMut<Engine>, instance: NodeInstance) -> EcoVec<NodeInstance> {
    if instance.node.equal().is_some() {
        let Some(alias) = instance.node.alias() else {
            return EcoVec::new();
        };
        return instance_of_effective_node(engine, Some(instance), alias)
            .into_iter()
            .collect();
    }

    let Some(body) = instance.node.body_node() else {
        return EcoVec::new();
    };

    let mut found = EcoVec::new();
    let mut walker = CallSiteWalker {
        engine,
        within: &instance,
        instances: &mut found,
    };
    for equation in body.all_equals_equation_node() {
        walker.walk_expr_opt(equation.expression_node());
    }
    for assertion in body.all_assert_equation_node() {
        walker.walk_expr_opt(assertion.expression_node());
    }
    found
}
#[memoize]
pub fn all_instances(mut engine: TrackedMut<Engine>) -> EcoVec<NodeInstance> {
let sources = crate::all_sources(TrackedMut::reborrow_mut(&mut engine));
let mut non_parametric_nodes = Vec::new();
let mut parametric_nodes = Vec::new();
for source in sources {
for node in source.root().all_node_node() {
if node.static_params_node().is_some() {
parametric_nodes.push(node);
} else {
non_parametric_nodes.push(node);
}
}
}
let mut instances = non_parametric_nodes
.iter()
.cloned()
.map(NodeInstance::new_non_parametric)
.collect::<EcoVec<_>>();
let mut visited = instances.iter().cloned().collect::<HashSet<_>>();
let mut idx = 0;
while let Some(instance) = instances.get(idx) {
let indirect_instances =
node_call_sites(TrackedMut::reborrow_mut(&mut engine), instance.clone());
for indirect_instance in indirect_instances {
if visited.insert(indirect_instance.clone()) {
instances.push(indirect_instance);
}
}
idx += 1;
}
info!("{} node instances visited", visited.len());
instances
}