aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoel Kronqvist <joel.kronqvist@iki.fi>2025-08-12 20:00:56 +0300
committerJoel Kronqvist <joel.kronqvist@iki.fi>2025-08-12 20:00:56 +0300
commita8a4c5b567ea6a58809dc8232ea5f1d3c93879b9 (patch)
treed9ed2fc0970900bb4b8ff790158ae3ff3c0213d9
parentdb736d795b759edd913d96195747f0881c4e950f (diff)
downloadmyslip-a8a4c5b567ea6a58809dc8232ea5f1d3c93879b9.tar.gz
myslip-a8a4c5b567ea6a58809dc8232ea5f1d3c93879b9.zip
feat: type checking for case expressions
-rw-r--r--src/sexp/util.rs8
-rw-r--r--src/type/case.rs132
-rw-r--r--src/type/check.rs56
-rw-r--r--src/type/display.rs7
-rw-r--r--src/type/mod.rs8
5 files changed, 203 insertions, 8 deletions
diff --git a/src/sexp/util.rs b/src/sexp/util.rs
index a28aaad..3de45d4 100644
--- a/src/sexp/util.rs
+++ b/src/sexp/util.rs
@@ -86,4 +86,12 @@ impl SExp {
_ => None
}
}
+
+ pub fn check_case(self) -> Option<(SExp, Vec<SExp>)> {
+ match &(self.parts())[..] {
+ [casekw, scrutinee, cases @ ..] if casekw.clone() == Atom(Case) =>
+ Some((scrutinee.clone(), cases.to_vec())),
+ _ => None,
+ }
+ }
}
diff --git a/src/type/case.rs b/src/type/case.rs
index 2403acd..70597b6 100644
--- a/src/type/case.rs
+++ b/src/type/case.rs
@@ -1,6 +1,8 @@
use crate::r#type::{Type, TypeError, Type::*, TypeError::*, PatFail::*, util::*};
use crate::sexp::{SExp, SExp::*, SLeaf::*, util::*};
+use std::collections::{HashMap, HashSet};
+use std::iter;
impl SExp {
@@ -95,8 +97,132 @@ impl SExp {
/// ])
/// );
/// ```
- ///
- pub fn matches_type(&self, ty: &Type) -> Result<Vec<(String, Type)>, TypeError> {
- todo!()
+ ///
+ /// Vector matching should work without the rest pattern too:
+ /// ```rust
+ /// use myslip::sexp::{SExp::*, SLeaf::*, util::*};
+ /// use myslip::r#type::{Type::*, util::*, TypeError::*, PatFail::*};
+ /// assert_eq!(
+ /// scons(var("h"), scons(2, Nil)).matches_type(&vecof(Integer)),
+ /// Ok(vec![("h".to_string(), Integer)])
+ /// );
+ /// ```
+ /// TODO: Nil / empty list
+ pub fn matches_type(
+ &self,
+ ty: &Type
+ ) -> Result<Vec<(String, Type)>, TypeError> {
+
+ let mut checks = HashSet::new();
+
+ let res = self.clone()
+ .matches_type_ctx(ty.clone(), HashMap::new())?;
+
+ for (k, _) in res.clone() {
+ if !checks.insert(k.clone()) {
+ return Err(InvalidPattern(RepeatedVariable(k, self.clone())));
+ }
+ }
+
+ Ok(res)
+
+ }
+
+ fn matches_type_ctx(
+ self,
+ ty: Type,
+ ctx: HashMap<String, Type>
+ ) -> Result<Vec<(String, Type)>, TypeError> {
+
+ match (self, ty) {
+
+ (a, b) if a.infer_list_type(ctx.clone()) == Ok(b.clone()) =>
+ Ok(ctx.into_iter().collect()),
+
+ (a, b) if a.infer_type(ctx.clone()) == Ok(b.clone()) =>
+ Ok(ctx.into_iter().collect()),
+
+ (Atom(Var(name)), t) =>
+ Ok(ctx.into_iter()
+ .chain(iter::once((name, t)))
+ .collect()),
+
+ (exp, VecOf(ty)) => {
+ let mut res: Vec<(String, Type)> =
+ ctx.clone().into_iter().collect();
+ let mut exps = exp.clone().parts();
+ // TODO: Nil or empty exp
+ let restpat = exps.remove(exps.len() - 1);
+ for exp in exps {
+ for et in exp.matches_type_ctx(*ty.clone(), ctx.clone())? {
+ res.push(et);
+ }
+ }
+ match restpat {
+ Atom(RestPat(name)) => {
+ res.push((name, vecof(ty)));
+ Ok(res)
+ },
+ t => {
+ for et in t.matches_type_ctx(*ty.clone(), ctx.clone())? {
+ res.push(et);
+ }
+ Ok(res)
+ }
+ }
+ },
+
+ (SCons(e1, e2), List(typelist)) => {
+ let explist = scons(e1.clone(), e2.clone()).parts();
+ let mut res: Vec<(String, Type)> =
+ ctx.clone().into_iter().collect();
+ if explist.len() == typelist.len() {
+ for (exp, ty) in explist.into_iter().zip(typelist) {
+ for (e, t) in exp.matches_type_ctx(ty, ctx.clone())? {
+ res.push((e, t));
+ }
+ }
+ Ok(res)
+ } else {
+ match explist.last().cloned() {
+ Some(Atom(RestPat(name))) => {
+ for (exp, ty) in explist.clone()
+ .into_iter()
+ .rev().skip(1).rev()
+ .zip(typelist.clone())
+ {
+ for (e, t) in exp.matches_type_ctx(ty, ctx.clone())? {
+ res.push((e, t));
+ }
+ }
+ res.push((
+ name,
+ List(typelist
+ .into_iter()
+ .skip(explist.len() - 1)
+ .collect())
+ ));
+ Ok(res)
+ },
+ _ => Err(InvalidPattern(TypeMismatch {
+ pattern: scons(e1.clone(), e2.clone()),
+ expected: List(typelist),
+ found: scons(e1, e2).infer_list_type(ctx)?,
+ })),
+ }
+ }
+ },
+
+ (e, t) => {
+ let found_ty = e.infer_list_type(ctx)?;
+ Err(InvalidPattern(TypeMismatch {
+ pattern: e,
+ expected: t,
+ found: found_ty
+ }))
+ },
+
+ }
+
}
}
diff --git a/src/type/check.rs b/src/type/check.rs
index 9d7c52b..7a2d231 100644
--- a/src/type/check.rs
+++ b/src/type/check.rs
@@ -1,5 +1,5 @@
-use crate::r#type::{Type, TypeError, Type::*, TypeError::*, util::*};
+use crate::r#type::{Type, TypeError, Type::*, TypeError::*, PatFail::*, util::*};
use crate::sexp::{SExp, SExp::*, SLeaf::*, util::*};
use std::collections::HashMap;
@@ -156,7 +156,7 @@ impl SExp {
///
/// assert_eq!(
/// parse_to_ast(
- /// "(case (1 2 3 true false) ((x 2 3 false true) (+ x 1)) ((0 0 0 true true) 0) (_ 1))"
+ /// "+ (case (quote 1 2 3 true false) ((x 2 3 false true) (+ x 1)) ((0 0 0 true true) 0) (_ 1)) 0"
/// ).unwrap().type_check(),
/// Ok(Integer)
/// );
@@ -251,8 +251,8 @@ impl SExp {
Atom(Ty(_)) => Ok(TypeLit),
Atom(Arr) => Ok(arr(List(vec![TypeLit, TypeLit]), TypeLit)),
Atom(Fun) => Err(FunAsAtom),
- Atom(Case) => todo!(),
- Atom(RestPat(_)) => todo!(),
+ Atom(Case) => Err(CaseAsAtom),
+ Atom(RestPat(_)) => Err(RestAsAtom),
SCons(op, l) => {
@@ -274,6 +274,52 @@ impl SExp {
return scons(op.clone(), l.clone()).get_fun_type(ctx);
}
+ // Case expressions
+ if let Some((scrutinee, patarms)) = scons(op.clone(), l.clone()).check_case() {
+ let scruty = scrutinee.infer_type(ctx.clone())?;
+ let scruty = match scruty {
+ List(v) if (
+ v.get(0) == Some(&QuoteTy)
+ || v.get(0) == Some(&VecType)
+ ) && v.get(1).is_some() =>
+ v[1].clone(),
+ t => t,
+ };
+ let mut ty: Option<Type> = None;
+ let mut has_wildcard = false;
+ for patandarm in patarms {
+ let (pat, arm) = match patandarm {
+ SCons(pat, arm) => Ok((*pat, *arm)),
+ _ => Err(InvalidPattern(NoArm(patandarm))),
+ }?;
+ let arm = match arm {
+ SCons(x, n) if *n == Atom(Nil) => *x,
+ t => t,
+ };
+ if let Atom(Var(_)) = pat {
+ has_wildcard = true;
+ }
+ let mut newctx = ctx.clone();
+ for (name, ty) in pat.matches_type(&scruty)? {
+ newctx.insert(name, ty);
+ }
+ match &ty {
+ None => ty = Some(arm.infer_type(newctx)?),
+ Some(t) => if arm.infer_type(newctx.clone())? != *t {
+ println!("different types: {}, {}", t, arm.infer_list_type(newctx)?);
+ return Err(OtherError);
+ },
+ }
+ }
+ if !has_wildcard {
+ return Err(NoWildcardInCase(scrutinee));
+ }
+ match ty {
+ Some(t) => return Ok(t),
+ None => return Err(OtherError),
+ }
+ }
+
// Normal operation
let opertype = (*op).infer_type(ctx.clone())?;
let argstype = (*l).infer_list_type(ctx)?;
@@ -333,7 +379,7 @@ impl SExp {
}
- fn infer_list_type(&self, ctx: HashMap<String, Type>) -> Result<Type, TypeError> {
+ pub fn infer_list_type(&self, ctx: HashMap<String, Type>) -> Result<Type, TypeError> {
let mut res = vec![];
for exp in self.clone().parts() {
res.push(exp.infer_type(ctx.clone())?);
diff --git a/src/type/display.rs b/src/type/display.rs
index e945eba..290e24e 100644
--- a/src/type/display.rs
+++ b/src/type/display.rs
@@ -123,7 +123,10 @@ impl fmt::Display for TypeError {
)
},
FunAsAtom => write!(f, "'fn' used as atom doesn't make sense"),
+ CaseAsAtom => write!(f, "'case' used as atom doesn't make sense"),
+ RestAsAtom => write!(f, "'..[name]' used as atom doesn't make sense"),
InvalidFunDef(exp, err) => write!(f, "invalid function definition '{exp}': {err}"),
+ NoWildcardInCase(exp) => write!(f, "no wildcard in cases: '{exp}'"),
OtherError => write!(f, "uncategorized error"),
}
}
@@ -133,6 +136,10 @@ impl fmt::Display for TypeError {
impl fmt::Display for PatFail {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
+ RestNotAtEnd(exp) =>
+ write!(f, "rest pattern should be at end in '{exp}'"),
+ NoArm(exp) =>
+ write!(f, "'{exp}' should consist of a pattern and a respective arm"),
RepeatedVariable(name, exp_in) =>
write!(f, "repeated pattern variable '{name}' in '{exp_in}'"),
TypeMismatch { pattern, expected, found } =>
diff --git a/src/type/mod.rs b/src/type/mod.rs
index 34506d0..e979595 100644
--- a/src/type/mod.rs
+++ b/src/type/mod.rs
@@ -71,8 +71,14 @@ pub enum TypeError {
FunAsAtom,
+ CaseAsAtom,
+
+ RestAsAtom,
+
InvalidFunDef(SExp, FunDefError),
+ NoWildcardInCase(SExp),
+
OtherError
}
@@ -93,6 +99,8 @@ pub enum FunDefError {
#[derive(Debug,PartialEq)]
pub enum PatFail {
RepeatedVariable(String, SExp),
+ NoArm(SExp),
+ RestNotAtEnd(SExp),
TypeMismatch {
pattern: SExp,
expected: Type,