Implement closures

What a pain.

Next steps: More happy path tests, then large values/mutabilty, then redo in bytecode.
closure
Alona EM 2021-12-28 01:42:17 +00:00
parent 93266a980e
commit bc00dbe69e
14 changed files with 816 additions and 825 deletions

View File

@ -22,3 +22,8 @@ lalrpop = "0.19.6"
[dev-dependencies] [dev-dependencies]
assert_cmd = "2.0.2" assert_cmd = "2.0.2"
insta = { version = "1.9.0", features = ["glob"] } insta = { version = "1.9.0", features = ["glob"] }
[profile.dev]
# https://github.com/rust-lang/rust/issues/92163
# https://github.com/rust-lang/rust/issues/92315
incremental=false

View File

@ -1,20 +1,24 @@
use std::rc::Rc; use std::rc::Rc;
#[derive(Debug, debug2::Debug, PartialEq)] #[derive(Debug /*/*, debug2::Debug*/*/, PartialEq)]
crate enum Tree { crate enum Tree {
Leaf(Literal), Leaf(Literal),
Define(String, Box<Tree>), Define(String, Box<Tree>),
If(Box<[Tree; 3]>), If(Box<[Tree; 3]>),
Lambda(Lambda), // Its easier to box the lambdas in the parser than the vm, as
// here we see all of them exactly once
Func(Rc<Func>),
Branch(Vec<Tree>), Branch(Vec<Tree>),
} }
#[derive(Debug, debug2::Debug, PartialEq, Clone)] #[derive(Debug /*/*, debug2::Debug*/*/, PartialEq)]
crate struct Func {
crate args: Vec<String>,
crate body: Vec<Tree>,
}
crate struct Lambda(crate Rc<[String]>, crate Rc<[Tree]>); #[derive(Debug /*/*, debug2::Debug*/*/, PartialEq)]
#[derive(Debug, debug2::Debug, PartialEq)]
crate enum Literal { crate enum Literal {
Sym(String), Sym(String),

View File

@ -1,58 +1,77 @@
use std::assert_matches::assert_matches; use std::assert_matches::assert_matches;
use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::rc::Rc;
use crate::value::Value; use crate::value::Value;
use crate::{ast, prims}; use crate::{ast, prims};
pub(crate) fn eval(t: &ast::Tree, env: &mut Env) -> Result<Value, RTError> { #[derive(/*debug2::Debug,*/ Debug, Clone)]
crate struct Lambda {
crate func: Rc<ast::Func>,
crate captures: Rc<RefCell<Env>>,
}
pub(crate) fn eval(t: &ast::Tree, env: Rc<RefCell<Env>>) -> Result<Value, RTError> {
Ok(match t { Ok(match t {
ast::Tree::Leaf(l) => match l { ast::Tree::Leaf(l) => match l {
ast::Literal::Sym(s) => match env.lookup(s) { ast::Literal::Sym(s) => match env.borrow().lookup(s) {
Some(v) => v, Some(v) => v,
None => return err(format!("Undefined variable `{}`", s)), None => return err(format!("Undefined variable `{}`", s)),
}, },
ast::Literal::Num(v) => Value::Num(*v), ast::Literal::Num(v) => Value::Num(*v),
ast::Literal::Bool(b) => Value::Bool(*b), ast::Literal::Bool(b) => Value::Bool(*b),
}, },
ast::Tree::Lambda(l) => Value::Lambda(l.clone()), ast::Tree::Func(l) => Value::Lambda(Lambda {
func: Rc::clone(l),
captures: Rc::clone(&env),
}),
ast::Tree::Branch(args) => { ast::Tree::Branch(args) => {
let Some(fun) = args.get(0) else { return err("No argument given".to_owned()) }; let Some(fun) = args.get(0) else { return err("No func given".to_owned()) };
let fun = eval(fun, env)?;
let fun = fun.as_func()?; let fun = eval(fun, Rc::clone(&env))?.as_func()?;
let args = args let args = args
.iter() .iter()
.skip(1) .skip(1)
.map(|a| eval(a, env)) .map(|a| eval(a, Rc::clone(&env)))
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
// return fun(&args);
match fun { match fun {
Callable::Func(f) => f(&args)?, Callable::Func(f) => f(&args)?,
Callable::Lambda(l) => { Callable::Lambda(l) => {
if l.0.len() == args.len() { if l.func.args.len() == args.len() {
let mut env = env.child(); // let mut env = env.child();
for (x, y) in l.0.iter().zip(args) { let mut env = Env::child(Rc::clone(&l.captures));
for (x, y) in l.func.args.iter().zip(args) {
env.define(x.clone(), y) env.define(x.clone(), y)
} }
let env = Rc::new(RefCell::new(env));
let [main @ .., tail] = &l.1[..] else { unreachable!("Body has 1+ element by parser") }; let [main @ .., tail] = &l.func.body[..] else { unreachable!("Body has 1+ element by parser") };
for i in main { for i in main {
eval(i, &mut env)?; eval(i, Rc::clone(&env))?;
} }
eval(tail, &mut env)? eval(tail, env)?
} else { } else {
return err(format!("Need {} args, got {}", l.0.len(), args.len())); return err(format!(
"Need {} args, got {}",
l.func.args.len(),
args.len()
));
} }
} }
} }
} }
ast::Tree::Define(name, to) => { ast::Tree::Define(name, to) => {
let val = eval(to, env)?; let val = eval(to, Rc::clone(&env))?;
env.define(name.to_owned(), val); env.borrow_mut().define(name.to_owned(), val);
Value::Trap Value::Trap
} }
ast::Tree::If(box [cond, tcase, fcase]) => { ast::Tree::If(box [cond, tcase, fcase]) => {
let b = eval(cond, env)?.as_bool()?; let b = eval(cond, Rc::clone(&env))?.as_bool()?;
let body = if b { tcase } else { fcase }; let body = if b { tcase } else { fcase };
eval(body, env)? eval(body, env)?
} }
@ -66,40 +85,60 @@ pub(crate) fn err<T>(s: String) -> Result<T, RTError> {
Err(RTError(s)) Err(RTError(s))
} }
pub(crate) enum Callable<'a> { pub(crate) enum Callable {
Func(prims::Func), Func(prims::NativeFunc),
Lambda(&'a ast::Lambda), Lambda(Lambda),
} }
pub(crate) struct Env<'a> { #[derive(Clone /*, debug2::Debug*/)]
pub(crate) vars: HashMap<String, Value>, pub(crate) struct Env {
pub(crate) enclosing: Option<&'a Env<'a>>, vars: HashMap<String, Value>,
enclosing: Option<Rc<RefCell<Env>>>,
} }
impl<'a> Env<'a> { impl std::fmt::Debug for Env {
pub(crate) fn child(&'a self) -> Env<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
struct Vars<'a>(&'a HashMap<String, Value>);
impl std::fmt::Debug for Vars<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_set().entries(self.0.keys()).finish()
}
}
f.debug_struct("Env")
.field("vars", &Vars(&self.vars))
.field("enclosing", &self.enclosing)
.finish()
}
}
impl Env {
pub(crate) fn child(this: Rc<RefCell<Self>>) -> Env {
Env { Env {
vars: HashMap::new(), vars: HashMap::new(),
enclosing: Some(self), enclosing: Some(this),
} }
} }
pub(crate) fn lookup(&self, s: &str) -> Option<Value> { pub(crate) fn lookup(&self, s: &str) -> Option<Value> {
if let Some(v) = self.vars.get(s) { if let Some(v) = self.vars.get(s) {
Some(v.clone()) Some(v.clone())
} else if let Some(parent) = self.enclosing { } else if let Some(parent) = &self.enclosing {
parent.lookup(s) parent.borrow().lookup(s)
} else { } else {
None None
} }
} }
pub(crate) fn define(&mut self, name: String, val: Value) { pub(crate) fn define(&mut self, name: String, val: Value) {
assert_ne!(val, Value::Trap); // TODO: Better error
self.vars.insert(name, val); self.vars.insert(name, val);
} }
} }
pub(crate) fn default_env() -> Env<'static> { pub(crate) fn default_env() -> Env {
let mut vars = HashMap::new(); let mut vars = HashMap::new();
for (name, fun) in prims::prims() { for (name, fun) in prims::prims() {

View File

@ -17,9 +17,10 @@ Trees = Tree+;
pub(crate) Tree: Tree = { pub(crate) Tree: Tree = {
"(" <Tree+> ")" => Tree::Branch(<>), "(" <Tree+> ")" => Tree::Branch(<>),
"(" "define" <Sym> <BTree> ")" => Tree::Define(<>), "(" "define" <Sym> <BTree> ")" => Tree::Define(<>),
"(" "define" "(" <name:Sym> <args:RcSlice<Sym>> ")" <body:RcSlice<Tree>> ")" => Tree::Define(name, Box::new(Tree::Lambda(Lambda(args, body)))), "(" "define" "(" <name:Sym> <args:Sym*> ")" <body:Trees> ")"
=> Tree::Define(name, Box::new(Tree::Func(Rc::new(Func{args, body})))),
"(" "if" <Tree> <Tree> <Tree> ")" => Tree::If(Box::new([<>])), "(" "if" <Tree> <Tree> <Tree> ")" => Tree::If(Box::new([<>])),
"(" "lambda (" <RcSlice<Sym>> ")" <RcSlice<Tree>> ")" => Tree::Lambda(Lambda(<>)), "(" "lambda (" <args:Sym*> ")" <body:Trees> ")" => Tree::Func(Rc::new(Func{<>})),
Literal => Tree::Leaf(<>), Literal => Tree::Leaf(<>),
} }

File diff suppressed because it is too large Load Diff

View File

@ -4,7 +4,9 @@
#![feature(crate_visibility_modifier)] #![feature(crate_visibility_modifier)]
#![feature(let_else)] #![feature(let_else)]
use std::cell::RefCell;
use std::io; use std::io;
use std::rc::Rc;
use rustyline::validate::{ use rustyline::validate::{
MatchingBracketValidator, ValidationContext, ValidationResult, Validator, MatchingBracketValidator, ValidationContext, ValidationResult, Validator,
@ -33,9 +35,9 @@ fn main() {
fn run_file(file: &str) -> io::Result<()> { fn run_file(file: &str) -> io::Result<()> {
let src = std::fs::read_to_string(file)?; let src = std::fs::read_to_string(file)?;
let tree = grammar::FileParser::new().parse(&src).unwrap(); let tree = grammar::FileParser::new().parse(&src).unwrap();
let mut env = eval::default_env(); let env = Rc::new(RefCell::new(eval::default_env()));
for i in tree { for i in tree {
eval::eval(&i, &mut env).unwrap(); eval::eval(&i, Rc::clone(&env)).unwrap();
} }
Ok(()) Ok(())
} }
@ -46,14 +48,23 @@ fn repl() {
brackets: MatchingBracketValidator::new(), brackets: MatchingBracketValidator::new(),
})); }));
let mut env = eval::default_env(); let env = Rc::new(RefCell::new(eval::default_env()));
while let Ok(line) = rl.readline("> ") { while let Ok(line) = rl.readline("> ") {
rl.add_history_entry(&line); rl.add_history_entry(&line);
let tree = grammar::TreeParser::new().parse(&line).unwrap(); let tree = grammar::TreeParser::new().parse(&line).unwrap();
// dbg!(&tree); // dbg!(&tree);
println!("< {:?}", eval::eval(&tree, &mut env)) match eval::eval(&tree, Rc::clone(&env)) {
Ok(v) => {
// Cant use == because it will panic. This is the only
// valid comparison to trap
if !matches!(v, value::Value::Trap) {
println!("{}", v)
}
}
Err(e) => println!("! {}", e.0),
}
} }
} }

View File

@ -1,9 +1,12 @@
use crate::eval::{err, RTError}; use crate::eval::{err, RTError};
use crate::value::Value; use crate::value::Value;
crate type Func = fn(&[Value]) -> Result<Value, RTError>;
crate fn prims() -> &'static [(&'static str, Func)] { type Result = std::result::Result<Value, RTError>;
crate type NativeFunc = for<'a> fn(&'a [Value]) -> Result;
crate fn prims() -> &'static [(&'static str, NativeFunc)] {
&[ &[
("*", mul), ("*", mul),
("+", add), ("+", add),
@ -21,25 +24,27 @@ crate fn prims() -> &'static [(&'static str, Func)] {
(">", gt), (">", gt),
("<=", le), ("<=", le),
(">=", ge), (">=", ge),
("_Z_debug", z_debug)
] ]
} }
// TODO: DRY +-/* // TODO: DRY +-/*
fn add(args: &[Value]) -> Result<Value, RTError> { fn add(args: &[Value]) -> Result {
args.iter() args.iter()
.map(Value::as_num) .map(Value::as_num)
.try_fold(0.0, |a, b| Ok(a + b?)) .try_fold(0.0, |a, b| Ok(a + b?))
.map(Value::Num) .map(Value::Num)
} }
fn mul(args: &[Value]) -> Result<Value, RTError> { fn mul(args: &[Value]) -> Result {
args.iter() args.iter()
.map(Value::as_num) .map(Value::as_num)
.try_fold(1.0, |a, b| Ok(a * b?)) .try_fold(1.0, |a, b| Ok(a * b?))
.map(Value::Num) .map(Value::Num)
} }
fn div(args: &[Value]) -> Result<Value, RTError> { fn div(args: &[Value]) -> Result {
let init = args let init = args
.get(0) .get(0)
.ok_or_else(|| RTError("`div` needs at least one argument".to_owned()))? .ok_or_else(|| RTError("`div` needs at least one argument".to_owned()))?
@ -53,7 +58,7 @@ fn div(args: &[Value]) -> Result<Value, RTError> {
})) }))
} }
fn sub(args: &[Value]) -> Result<Value, RTError> { fn sub(args: &[Value]) -> Result {
let init = args let init = args
.get(0) .get(0)
.ok_or_else(|| RTError("`sub` needs at least one argument".to_owned()))? .ok_or_else(|| RTError("`sub` needs at least one argument".to_owned()))?
@ -67,17 +72,17 @@ fn sub(args: &[Value]) -> Result<Value, RTError> {
})) }))
} }
fn equals(args: &[Value]) -> Result<Value, RTError> { fn equals(args: &[Value]) -> Result {
Ok(Value::Bool(args.array_windows().all(|[l, r]| l == r))) Ok(Value::Bool(args.array_windows().all(|[l, r]| l == r)))
} }
fn display(args: &[Value]) -> Result<Value, RTError> { fn display(args: &[Value]) -> Result {
let [arg] = args else {return err("To many args to `display`".to_owned())}; let [arg] = args else {return err("To many args to `display`".to_owned())};
print!("{:?}", arg); print!("{}", arg);
Ok(Value::Trap) Ok(Value::Trap)
} }
fn newline(args: &[Value]) -> Result<Value, RTError> { fn newline(args: &[Value]) -> Result {
if !args.is_empty() { if !args.is_empty() {
return err("Newline takes no args".to_owned()); return err("Newline takes no args".to_owned());
} }
@ -85,7 +90,7 @@ fn newline(args: &[Value]) -> Result<Value, RTError> {
Ok(Value::Trap) Ok(Value::Trap)
} }
fn abs(args: &[Value]) -> Result<Value, RTError> { fn abs(args: &[Value]) -> Result {
let [v] = args else { return err("abs takes 1 arg".to_owned()) }; let [v] = args else { return err("abs takes 1 arg".to_owned()) };
let ans = v.as_num()?.abs(); let ans = v.as_num()?.abs();
Ok(Value::Num(ans)) Ok(Value::Num(ans))
@ -95,7 +100,7 @@ crate fn compare_core(
args: &[Value], args: &[Value],
f: fn(f64, f64) -> bool, f: fn(f64, f64) -> bool,
name: &'static str, name: &'static str,
) -> Result<Value, RTError> { ) -> Result {
for [l, r] in args.array_windows() { for [l, r] in args.array_windows() {
let (Value::Num(l), Value::Num(r)) = (l,r) else let (Value::Num(l), Value::Num(r)) = (l,r) else
{ {
@ -114,11 +119,16 @@ crate fn compare_core(
macro_rules! cmps { macro_rules! cmps {
($(($name:ident $op:tt))*) => { ($(($name:ident $op:tt))*) => {
$( $(
fn $name(args: &[Value]) -> Result<Value, RTError> { fn $name(args: &[Value]) -> Result {
compare_core(args, |l, r| l $op r, stringify!($op)) compare_core(args, |l, r| l $op r, stringify!($op))
} }
)* )*
}; };
} }
cmps! { (lt <) (gt >) (le <=) (ge >=) } cmps! { (lt <) (gt >) (le <=) (ge >=) }
fn z_debug(args: &[Value]) -> Result {
eprintln!("{:?}", args);
Ok(Value::Trap)
}

View File

@ -0,0 +1,13 @@
---
source: src/tests.rs
assertion_line: 41
expression: run-pass ambig-scope.scm
---
42.0
1.0
2.0
3.0
2.0
1.0

View File

@ -0,0 +1,9 @@
---
source: src/tests.rs
assertion_line: 41
expression: run-pass capture.scm
---
#<procedure>
2.0

View File

@ -0,0 +1,8 @@
---
source: src/tests.rs
assertion_line: 41
expression: run-pass curry.scm
---
3.0

View File

@ -0,0 +1,18 @@
#lang scheme
(define (displayln x) (newline) (display x))
(define (ambig1 x) (lambda (x) x))
(displayln ((ambig1 13) 42))
(define (ambig2 a)
(displayln a)
(define a- a)
(lambda (b) (displayln b)
(lambda (a)
(displayln a)
(displayln b)
(displayln a-))))
(((ambig2 1) 2) 3)

View File

@ -0,0 +1,12 @@
#lang scheme
(define (displayln x) (display x) (newline))
(define (const x) (lambda () x))
(define two-thunk (const 2))
(displayln two-thunk)
(displayln (two-thunk))

View File

@ -0,0 +1,7 @@
#lang scheme
(define (displayln x) (display x) (newline))
(define (curry2 f) (lambda (l) (lambda (r) (f l r))))
(displayln (((curry2 +) 1) 2))

View File

@ -1,16 +1,32 @@
use crate::{ast, eval, prims}; use crate::{eval, prims};
use std::rc::Rc; use std::rc::Rc;
#[derive(Clone)] #[derive(Clone /*, debug2::Debug*/)]
crate enum Value { crate enum Value {
Num(f64), Num(f64),
Func(prims::Func), // TODO: implement debug2::Debug for `fn` type, and do it right
Lambda(ast::Lambda), // https://godbolt.org/z/vr9erGeKq
Func(prims::NativeFunc),
Lambda(eval::Lambda),
Bool(bool), Bool(bool),
/// Result of things that shouldnt have values, like (define x 3) /// Result of things that shouldnt have values, like (define x 3)
/// TODO: Figure this out
Trap, Trap,
} }
impl std::fmt::Debug for Value {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Num(arg0) => f.debug_tuple("Num").field(arg0).finish(),
// We cant just pass the field because of livetime bs, see https://godbolt.org/z/vr9erGeKq
Self::Func(_) => f.debug_struct("Func").finish_non_exhaustive(),
Self::Lambda(arg0) => f.debug_tuple("Lambda").field(arg0).finish(),
Self::Bool(arg0) => f.debug_tuple("Bool").field(arg0).finish(),
Self::Trap => write!(f, "Trap"),
}
}
}
impl PartialEq for Value { impl PartialEq for Value {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
use Value::*; use Value::*;
@ -18,7 +34,9 @@ impl PartialEq for Value {
(Num(l), Num(r)) => l == r, (Num(l), Num(r)) => l == r,
(Func(l), Func(r)) => *l as usize == *r as usize, (Func(l), Func(r)) => *l as usize == *r as usize,
(Bool(l), Bool(r)) => l == r, (Bool(l), Bool(r)) => l == r,
(Lambda(l), Lambda(r)) => Rc::ptr_eq(&l.0, &r.0) && Rc::ptr_eq(&l.1, &r.1), (Lambda(l), Lambda(r)) => {
Rc::ptr_eq(&l.func, &r.func) && Rc::ptr_eq(&l.captures, &r.captures)
}
(Num(_), _) => false, (Num(_), _) => false,
(Func(_), _) => false, (Func(_), _) => false,
@ -30,10 +48,10 @@ impl PartialEq for Value {
} }
} }
impl std::fmt::Debug for Value { impl std::fmt::Display for Value {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
Self::Num(n) => n.fmt(f), Self::Num(n) => std::fmt::Debug::fmt(n, f), // TODO: Change to displays
Self::Func(_) => f.write_str("#<procedure>"), Self::Func(_) => f.write_str("#<procedure>"),
Self::Lambda(_) => f.write_str("#<procedure>"), Self::Lambda(_) => f.write_str("#<procedure>"),
Self::Bool(b) => f.write_str(if *b { "#t" } else { "#f" }), Self::Bool(b) => f.write_str(if *b { "#t" } else { "#f" }),
@ -51,14 +69,11 @@ impl Value {
} }
} }
crate fn as_func(&self) -> Result<eval::Callable, eval::RTError> { crate fn as_func(self) -> Result<eval::Callable, eval::RTError> {
match self { match self {
Self::Func(f) => Ok(eval::Callable::Func(*f)), Self::Func(f) => Ok(eval::Callable::Func(f)),
Self::Lambda(l) => Ok(eval::Callable::Lambda(l)), Self::Lambda(l) => Ok(eval::Callable::Lambda(l)),
_ => Err(eval::RTError(format!( _ => Err(eval::RTError(format!("Expected a function, got {}", self))),
"Expected a function, got {:?}",
self
))),
} }
} }