aboutsummaryrefslogtreecommitdiff
path: root/src/type/util.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/type/util.rs')
-rw-r--r--src/type/util.rs61
1 files changed, 60 insertions, 1 deletions
diff --git a/src/type/util.rs b/src/type/util.rs
index 85e64f1..23f8d1f 100644
--- a/src/type/util.rs
+++ b/src/type/util.rs
@@ -1,5 +1,6 @@
-use crate::r#type::{*, Type::*};
+use crate::r#type::{*, Type::*, TypeError::*, FunDefError::*};
+use std::collections::HashMap;
pub fn arr(a: impl Into<Box<Type>>, b: impl Into<Box<Type>>) -> Type {
Arrow(a.into(), b.into())
@@ -12,3 +13,61 @@ pub fn vt(name: &str) -> Type {
pub fn vecof(ty: impl Into<Box<Type>>) -> Type {
VecOf(ty.into())
}
+
+use crate::sexp::{SExp::*, SLeaf::*};
+impl SExp {
+ pub fn get_fun_type(self, mut ctx: HashMap<String, Type>) -> Result<Type, TypeError> {
+ let ls = self.clone().parts();
+ ls.get(0)
+ .filter(|t| **t == Atom(Fun))
+ .ok_or(InvalidFunDef(self.clone(), NoFunToken))?;
+ let argnames = ls.get(1)
+ .ok_or(InvalidFunDef(self.clone(), NoArgumentList))?
+ .clone().parts();
+ let argtypes = ls.get(2)
+ .ok_or(InvalidFunDef(self.clone(), NoTypeList))?
+ .clone();
+ let rettype = ls.get(3)
+ .ok_or(InvalidFunDef(self.clone(), NoReturnType))?;
+ let funbody = ls.get(4)
+ .ok_or(InvalidFunDef(self.clone(), NoFunctionBody))?;
+
+ let mut argnamevec = vec![];
+ for name in argnames {
+ argnamevec.push(match name {
+ Atom(Var(s)) => Ok(s),
+ _ => Err(InvalidFunDef(self.clone(), InvalidArgumentList)),
+ }?);
+ }
+
+ let argtypes = match argtypes.clone().multistep() {
+ Ok(Atom(Ty(List(v)))) => Ok(v),
+ Ok(Atom(Ty(t))) => Ok(vec![t]),
+ _ => {
+ Err(InvalidFunDef(self.clone(), InvalidArgumentList))
+ },
+ }?;
+
+ let rettype = match rettype.clone().multistep() {
+ Ok(Atom(Ty(t))) => Ok(t),
+ _ => Err(InvalidFunDef(self.clone(), InvalidReturnType))
+ }?;
+
+ let additional_ctx = argnamevec.into_iter().zip(argtypes.clone());
+ for (name, ty) in additional_ctx {
+ ctx.insert(name, ty);
+ }
+
+ let argtype = if argtypes.len() == 0 {
+ NilType
+ } else if argtypes.len() == 1 {
+ argtypes[0].clone()
+ } else {
+ List(argtypes)
+ };
+
+ funbody.infer_type(ctx)?;
+
+ Ok(arr(argtype, rettype))
+ }
+}