jsonnet-microkanren

microKanren implementation in Jsonnet
git clone https://ccx.te2000.cz/git/jsonnet-microkanren
Log | Files | Refs

commit 2866b340daec3a4dbf5570f205adabbceaeab701
parent ef15df36ea3b11c3a01be57108d9bdc98018efa6
Author: Jan Pobříslo <ccx@te2000.cz>
Date:   Wed, 29 Nov 2023 14:04:53 +0000

major refactor

Diffstat:
Mexample.jsonnet | 17++++++++++++++---
Mmicrokanren.libsonnet | 361++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------------
Mmicrokanren_checks.libsonnet | 16+++++++++++++---
3 files changed, 265 insertions(+), 129 deletions(-)

diff --git a/example.jsonnet b/example.jsonnet @@ -7,9 +7,20 @@ local uKc = import 'microkanren_checks.libsonnet'; assert uKc.Goal(goal); assert uKc.trace('goal', goal, true); local stream = goal.pursue(uK.emptyState); - assert std.type(stream); - true, /* assert uKc.Stream(stream); - stream.takeAll(), */ + stream.takeAll(), + + anotherFive: uK.runSingleVar(function(q) uK.eq(q, 5)), + + a_and_b: uK.takeAll(uK.conj( + uK.callFresh(function(a) ['eq', a, 7]), + uK.callFresh(function(b) ['or', b.eq(5), ['eq', b, 6]]), + )), + +/* + another_a_and_b: uK.runWithVars(['a', 'b'], function(vars) + ['and', ['eq', vars.a, 7], ['or', vars.b.eq(5), ['eq', vars.b, 6]]] + ), + */ } // vim: sts=2 ts=2 sw=2 et diff --git a/microkanren.libsonnet b/microkanren.libsonnet @@ -1,5 +1,25 @@ local uKc = import 'microkanren_checks.libsonnet'; -// helper functions + +// custom types + +local type(value) = + local t = std.type(value); + if t == 'object' then + if std.objectHas(value, 'µK:var') then + assert uKc.Variable(value); + 'variable' + else if std.objectHas(value, 'µK:goal') then + assert uKc.Goal(value); + 'goal' + else if std.objectHas(value, 'µK:stream') then + assert uKc.Stream(value); + 'stream' + else + t + else + t; + +// stream functions local baseStream(streamType) = { ['µK:stream']: streamType, @@ -8,7 +28,7 @@ local baseStream(streamType) = { self.call().pull() else self, - takeAll():: + takeAll():: local takeRecursive(accumulator, stream) = local mature = stream.pull(); if mature['µK:stream'] == 'empty' then @@ -38,19 +58,6 @@ local immatureStream(func) = baseStream('immature') + { call: func, }; -local makeVariable(module, number) = { - ['µK:var']: '%d' % number, - eq(value):: module.eq(self, value), -}; - -local makeGoal(callable) = { - ['µK:goal']: callable, - pursue(state):: - assert uKc.trace('pursuing goal in state', state, true); - assert uKc.State(state); - uKc.traceValue('goal returned', self['µK:goal'](state)), -}; - local mplus(stream1, stream2) = local t1 = stream1['µK:stream']; if t1 == 'empty' then @@ -73,139 +80,247 @@ local bind(stream, goal) = else error 'Invalid stream'; -// public interface -{ - type(value): - local t = std.type(value); - if t == 'object' then - if std.type(std.get(value, 'µK:var')) == 'number' then - 'variable' - else if std.type(std.get(value, 'µK:goal')) == 'function' then - 'goal' - else if std.type(std.get(value, 'µK:stream')) == 'string' then - 'stream' +// substitution functions +local walk(variable, substitution) = + if type(variable) == 'variable' then + if std.objectHas(substitution, variable['µK:var']) then + substitution.walk(substitution[variable['µK:var']]) + else + variable + else + variable; + +local unify(value1, value2, substitution) = + local w1 = substitution.walk(value1); + local w2 = substitution.walk(value2); + local t1 = type(w1); + local t2 = type(w2); + assert uKc.trace('unify walked', [t1, w1, t2, w2], true); + if t1 == 'variable' && t2 == 'variable' && w1 == w2 then + substitution + else if t1 == 'variable' then + substitution.extend(w1, w2) + else if t2 == 'variable' then + substitution.extend(w2, w1) + else if t1 == 'array' && t2 == 'array' then + if std.length(w1) == std.length(w2) then + if std.length(w1) == 0 then + substitution else - t + local s1 = substitution.unify(w1[0], w2[0]); + if s1 == null then + null + else + s1.unify(w1[1::], w2[1::]) else - t, + null + else if t1 == 'object' && t2 == 'object' then + if std.objectFields(w1) == std.objectFields(w2) then + assert uKc.trace("unifying objects", [w1, w2], true); + std.foldl( + function(field, prev_subst) + if prev_subst == null then + null + else + prev_subst.unify(w1[field], w2[field]), + std.objectFields(w1), + substitution + ) + else + null + else if w1 == w2 then + substitution + else + null; +// templates for objects with methods +local baseObjects = { // state emptyState: { variableCount: 0, + substitution: { - walk(var):: - if $.type(var) == 'variable' then - local bound = std.get(self, '%d' % var['µK:var']); - if bound == null then - var - else - self.walk(bound) - else - var, - unify(value1, value2):: - local w1 = self.walk(value1); - local w2 = self.walk(value1); - local t1 = $.type(w1); - local t2 = $.type(w2); - if t1 == 'variable' && t2 == 'variable' && w1 == w2 then - self - else if t1 == 'variable' then - self.extend(w1, w2) - else if t2 == 'variable' then - self.extend(w2, w1) - else if t1 == 'array' && t2 == 'array' then - if std.length(w1) == std.length(w2) then - if std.length(w1) == 0 then - self - else - local s1 = self.unify(w1[0], w2[0]); - if s1 == null then - null - else - s1.unify(w1[1::], w2[1::]) - else - null - else if t1 == 'object' && t2 == 'object' then - if std.objectFields(w1) == std.objectFields(w2) then - std.foldl( - function(field, prev_subst) - if prev_subst == null then - null - else - prev_subst.unify(w1[field], w2[field]), - std.objectFields(w1), - self - ) - else - null - else if w1 == w2 then - self - else - null, + extend(variable, value):: + assert uKc.Variable(variable); + self + {[variable['µK:var']]: value}, + walk(var):: walk(var, self), + unify(value1, value2):: unify(value1, value2, self), }, - freshVar():: - local current = self; - { - variable: makeVariable($, current.variableCount), - newState: uKc.checkState(current + {variableCount: current.variableCount + 1}), - }, }, // streams emptyStream: baseStream('empty'), + unitStream(state): + assert uKc.State(state); + matureStream(state, $.emptyStream), - unitStream(state): matureStream(state, $.emptyStream), +}; - // goal constructors - eq(value1, value2): makeGoal(function(state) - assert uKc.State(state); - local newSubst = state.substitution.unify(value1, value2); - if newSubst == null then - $.emptyStream - else - $.unitStream(state + {substitution: newSubst}) - ), - conj(goal1, goal2): +// goal functions + +local makeGoal(callable) = { + ['µK:goal']: callable, + pursue(state):: + assert uKc.trace('pursuing goal in state', state, true); + assert uKc.State(state); + uKc.traceValue('goal returned', self['µK:goal'](state)), +}; + +local conj(goal1, goal2) = assert uKc.Goal(goal1); assert uKc.Goal(goal2); - makeGoal(function(state) + local _conj(state) = assert uKc.State(state); - $.bind(goal1.pursue(state), goal2) - ), + bind(goal1.pursue(state), goal2); + makeGoal(_conj); - disj(goal1, goal2): +local disj(goal1, goal2) = assert uKc.Goal(goal1); assert uKc.Goal(goal2); - makeGoal(function(state) + local _disj(state) = assert uKc.State(state); - $.mplus(goal1.pursue(state), goal2.pursue(state)) - ), + mplus(goal1.pursue(state), goal2.pursue(state)); + makeGoal(_disj); + +local eq(value1, value2) = + local _eq(state) = + assert uKc.State(state); + assert uKc.trace('eq', [type(value1), value1, type(value2), value2], true); + local newSubst = state.substitution.unify(value1, value2); + assert uKc.trace('unify result', newSubst, true); + if newSubst == null then + baseObjects.emptyStream + else + baseObjects.unitStream(state + {substitution: newSubst}); + makeGoal(_eq); + +local maybeSExpGoal(sexp) = + local t = type(sexp); + if t == 'goal' then + sexp + else if t != 'array' then + error 'Invalid goal type: %s' % [t] + else + assert std.assertEqual(std.type(sexp), 'array'); + assert std.length(sexp) >= 3; + local head = sexp[0]; + if head == 'eq' then + assert std.length(sexp) == 3; + eq(sexp[1], sexp[2]) + else + local subgoals = std.map(maybeSExpGoal, std.reverse(sexp[1::])); + std.foldl( + if head == 'and' then conj + else if head == 'or' then disj + else error 'Invalid s-exp head: %s' % [head], + subgoals[1::], + subgoals[0] + ); + +local sExpGoal(sexp) = + assert std.assertEqual(std.type(sexp), 'array'); + maybeSExpGoal(sexp); + +// variable creation + +local makeVariable(number) = { + ['µK:var']: '%d' % number, + eq(value):: eq(self, value), +}; + +local makeFreshVariable(state) = + assert uKc.State(state); + { + variable: makeVariable(state.variableCount), + state: state + {variableCount: state.variableCount + 1}, + }; - callFresh(func): +local callFresh(func) = assert std.assertEqual(std.type(func), 'function'); - makeGoal(function(state) + local _callFresh(state) = assert uKc.State(state); - local fresh = state.freshVar(); + local fresh = makeFreshVariable(state); assert uKc.Variable(fresh.variable); - assert uKc.State(fresh.newState); + assert uKc.State(fresh.state); local newGoal = func(fresh.variable); - local t = $.type(newGoal); - (if t == 'goal' then - newGoal - else if t == 'array' then - local subgoals = std.reverse(newGoal[1::]); - (if newGoal[0] == 'and' then - std.foldl($.conj, subgoals[1::], subgoals[0]) - else if newGoal[0] == 'or' then - std.foldl($.disj, subgoals[1::], subgoals[0]) - else if newGoal[0] == 'eq' then - std.foldl($.disj, subgoals[1::], subgoals[0]) - else - error 'Invalid goal' - ) - else - error 'Invalid goal' - ).pursue(fresh.newState) - ), + assert uKc.trace('callFresh newGoal', newGoal, true); + local adaptedGoal = maybeSExpGoal(newGoal); + assert uKc.trace('callFresh adaptedGoal', adaptedGoal, true); + adaptedGoal.pursue(fresh.state); + makeGoal(_callFresh); + +// resolution helpers +local takeAll(goal) = + assert uKc.Goal(goal); + local stream = goal.pursue(baseObjects.emptyState); + // assert uKc.trace('runAll stream', stream, true); + assert uKc.Stream(stream); + stream.takeAll(); + +local take(count, goal) = + assert uKc.Goal(goal); + local stream = goal.pursue(baseObjects.emptyState); + assert uKc.Stream(stream); + stream.take(count); + +local runSingleVar(func, count=null, state=null) = + assert std.assertEqual(std.type(func), 'function'); + local fresh = makeFreshVariable(if state == null then baseObjects.emptyState else state); + local goal = maybeSExpGoal(func(fresh.variable)); + local stream = goal.pursue(fresh.state); + local states = + if count == null then + stream.takeAll() + else + stream.take(count); + [state.substitution.walk(fresh.variable) for state in states]; + +local runWithVars(variableNames, func, count=null, state=null) = + assert std.assertEqual(std.type(func), 'function'); + local baseState = + if state == null then baseObjects.emptyState else state; + assert uKc.State(baseState); + + assert std.assertEqual(std.type(variableNames), 'array'); + assert std.length(variableNames) >= 1; + + local named = std.foldl( + function(curr, name) + local fresh = makeFreshVariable(curr.state); + { + state: curr.state, + vars: curr.vars + {[name]: fresh.variable}, + }, + variableNames, + {state: baseState, vars: {}} + ); + + local goal = maybeSExpGoal(func(named.vars)); + local stream = goal.pursue(named.state); + local states = + if count == null then + stream.takeAll() + else + stream.take(count); + [std.mapWithKey(function(name, var) state.substitution.walk(var), named.vars) for state in states]; + + +// public interface +baseObjects + { + type: type, + + // goal constructors + eq: eq, + conj: conj, + disj: disj, + sExpGoal: sExpGoal, + maybeSExpGoal: maybeSExpGoal, + callFresh: callFresh, + + // resolution helpers + takeAll: takeAll, + take: take, + runSingleVar: runSingleVar, + runWithVars: runWithVars, } // vim: sts=2 ts=2 sw=2 et diff --git a/microkanren_checks.libsonnet b/microkanren_checks.libsonnet @@ -47,7 +47,9 @@ local checkType(type) = function(value) std.assertEqual(std.type(value), type); Substitution(subst): assert std.assertEqual(std.type(subst), 'object'); std.all(std.map( - function(f) std.assertEqual(std.type(f.name), 'number'), + function(f) + assert std.parseInt(f.key) >= 0; + true, std.objectKeysValues(subst) )), checkSubstitution(subst): @@ -55,10 +57,17 @@ local checkType(type) = function(value) std.assertEqual(std.type(value), type); subst, State(state): + local checkVariableMaximum = std.all(std.map( + function(f) + assert std.parseInt(f.key) >= 0; + assert std.parseInt(f.key) < state.variableCount; + true, + std.objectKeysValues(state.substitution) + )); $.objectFields(state, { variableCount: $.VariableCount, substitution: $.Substitution, - }), + }) && checkVariableMaximum, checkState(state): assert $.State(state); state, @@ -77,7 +86,8 @@ local checkType(type) = function(value) std.assertEqual(std.type(value), type); else if t == 'mature' then $.objectFields(stream, { ['µK:stream']: checkType('string'), - call: checkType('function'), + state: $.State, + next: $.Stream, }) else error "Incorrect stream type",