commit 2866b340daec3a4dbf5570f205adabbceaeab701
parent ef15df36ea3b11c3a01be57108d9bdc98018efa6
Author: Jan Pobříslo <ccx@te2000.cz>
Date: Wed, 29 Nov 2023 14:04:53 +0000
major refactor
Diffstat:
3 files changed, 265 insertions(+), 129 deletions(-)
diff --git a/example.jsonnet b/example.jsonnet
@@ -7,9 +7,20 @@ local uKc = import 'microkanren_checks.libsonnet';
assert uKc.Goal(goal);
assert uKc.trace('goal', goal, true);
local stream = goal.pursue(uK.emptyState);
- assert std.type(stream);
- true, /*
assert uKc.Stream(stream);
- stream.takeAll(), */
+ stream.takeAll(),
+
+ anotherFive: uK.runSingleVar(function(q) uK.eq(q, 5)),
+
+ a_and_b: uK.takeAll(uK.conj(
+ uK.callFresh(function(a) ['eq', a, 7]),
+ uK.callFresh(function(b) ['or', b.eq(5), ['eq', b, 6]]),
+ )),
+
+/*
+ another_a_and_b: uK.runWithVars(['a', 'b'], function(vars)
+ ['and', ['eq', vars.a, 7], ['or', vars.b.eq(5), ['eq', vars.b, 6]]]
+ ),
+ */
}
// vim: sts=2 ts=2 sw=2 et
diff --git a/microkanren.libsonnet b/microkanren.libsonnet
@@ -1,5 +1,25 @@
local uKc = import 'microkanren_checks.libsonnet';
-// helper functions
+
+// custom types
+
+local type(value) =
+ local t = std.type(value);
+ if t == 'object' then
+ if std.objectHas(value, 'µK:var') then
+ assert uKc.Variable(value);
+ 'variable'
+ else if std.objectHas(value, 'µK:goal') then
+ assert uKc.Goal(value);
+ 'goal'
+ else if std.objectHas(value, 'µK:stream') then
+ assert uKc.Stream(value);
+ 'stream'
+ else
+ t
+ else
+ t;
+
+// stream functions
local baseStream(streamType) = {
['µK:stream']: streamType,
@@ -8,7 +28,7 @@ local baseStream(streamType) = {
self.call().pull()
else
self,
- takeAll()::
+ takeAll()::
local takeRecursive(accumulator, stream) =
local mature = stream.pull();
if mature['µK:stream'] == 'empty' then
@@ -38,19 +58,6 @@ local immatureStream(func) = baseStream('immature') + {
call: func,
};
-local makeVariable(module, number) = {
- ['µK:var']: '%d' % number,
- eq(value):: module.eq(self, value),
-};
-
-local makeGoal(callable) = {
- ['µK:goal']: callable,
- pursue(state)::
- assert uKc.trace('pursuing goal in state', state, true);
- assert uKc.State(state);
- uKc.traceValue('goal returned', self['µK:goal'](state)),
-};
-
local mplus(stream1, stream2) =
local t1 = stream1['µK:stream'];
if t1 == 'empty' then
@@ -73,139 +80,247 @@ local bind(stream, goal) =
else
error 'Invalid stream';
-// public interface
-{
- type(value):
- local t = std.type(value);
- if t == 'object' then
- if std.type(std.get(value, 'µK:var')) == 'number' then
- 'variable'
- else if std.type(std.get(value, 'µK:goal')) == 'function' then
- 'goal'
- else if std.type(std.get(value, 'µK:stream')) == 'string' then
- 'stream'
+// substitution functions
+local walk(variable, substitution) =
+ if type(variable) == 'variable' then
+ if std.objectHas(substitution, variable['µK:var']) then
+ substitution.walk(substitution[variable['µK:var']])
+ else
+ variable
+ else
+ variable;
+
+local unify(value1, value2, substitution) =
+ local w1 = substitution.walk(value1);
+ local w2 = substitution.walk(value2);
+ local t1 = type(w1);
+ local t2 = type(w2);
+ assert uKc.trace('unify walked', [t1, w1, t2, w2], true);
+ if t1 == 'variable' && t2 == 'variable' && w1 == w2 then
+ substitution
+ else if t1 == 'variable' then
+ substitution.extend(w1, w2)
+ else if t2 == 'variable' then
+ substitution.extend(w2, w1)
+ else if t1 == 'array' && t2 == 'array' then
+ if std.length(w1) == std.length(w2) then
+ if std.length(w1) == 0 then
+ substitution
else
- t
+ local s1 = substitution.unify(w1[0], w2[0]);
+ if s1 == null then
+ null
+ else
+ s1.unify(w1[1::], w2[1::])
else
- t,
+ null
+ else if t1 == 'object' && t2 == 'object' then
+ if std.objectFields(w1) == std.objectFields(w2) then
+ assert uKc.trace("unifying objects", [w1, w2], true);
+ std.foldl(
+ function(field, prev_subst)
+ if prev_subst == null then
+ null
+ else
+ prev_subst.unify(w1[field], w2[field]),
+ std.objectFields(w1),
+ substitution
+ )
+ else
+ null
+ else if w1 == w2 then
+ substitution
+ else
+ null;
+// templates for objects with methods
+local baseObjects = {
// state
emptyState: {
variableCount: 0,
+
substitution: {
- walk(var)::
- if $.type(var) == 'variable' then
- local bound = std.get(self, '%d' % var['µK:var']);
- if bound == null then
- var
- else
- self.walk(bound)
- else
- var,
- unify(value1, value2)::
- local w1 = self.walk(value1);
- local w2 = self.walk(value1);
- local t1 = $.type(w1);
- local t2 = $.type(w2);
- if t1 == 'variable' && t2 == 'variable' && w1 == w2 then
- self
- else if t1 == 'variable' then
- self.extend(w1, w2)
- else if t2 == 'variable' then
- self.extend(w2, w1)
- else if t1 == 'array' && t2 == 'array' then
- if std.length(w1) == std.length(w2) then
- if std.length(w1) == 0 then
- self
- else
- local s1 = self.unify(w1[0], w2[0]);
- if s1 == null then
- null
- else
- s1.unify(w1[1::], w2[1::])
- else
- null
- else if t1 == 'object' && t2 == 'object' then
- if std.objectFields(w1) == std.objectFields(w2) then
- std.foldl(
- function(field, prev_subst)
- if prev_subst == null then
- null
- else
- prev_subst.unify(w1[field], w2[field]),
- std.objectFields(w1),
- self
- )
- else
- null
- else if w1 == w2 then
- self
- else
- null,
+ extend(variable, value)::
+ assert uKc.Variable(variable);
+ self + {[variable['µK:var']]: value},
+ walk(var):: walk(var, self),
+ unify(value1, value2):: unify(value1, value2, self),
},
- freshVar()::
- local current = self;
- {
- variable: makeVariable($, current.variableCount),
- newState: uKc.checkState(current + {variableCount: current.variableCount + 1}),
- },
},
// streams
emptyStream: baseStream('empty'),
+ unitStream(state):
+ assert uKc.State(state);
+ matureStream(state, $.emptyStream),
- unitStream(state): matureStream(state, $.emptyStream),
+};
- // goal constructors
- eq(value1, value2): makeGoal(function(state)
- assert uKc.State(state);
- local newSubst = state.substitution.unify(value1, value2);
- if newSubst == null then
- $.emptyStream
- else
- $.unitStream(state + {substitution: newSubst})
- ),
- conj(goal1, goal2):
+// goal functions
+
+local makeGoal(callable) = {
+ ['µK:goal']: callable,
+ pursue(state)::
+ assert uKc.trace('pursuing goal in state', state, true);
+ assert uKc.State(state);
+ uKc.traceValue('goal returned', self['µK:goal'](state)),
+};
+
+local conj(goal1, goal2) =
assert uKc.Goal(goal1);
assert uKc.Goal(goal2);
- makeGoal(function(state)
+ local _conj(state) =
assert uKc.State(state);
- $.bind(goal1.pursue(state), goal2)
- ),
+ bind(goal1.pursue(state), goal2);
+ makeGoal(_conj);
- disj(goal1, goal2):
+local disj(goal1, goal2) =
assert uKc.Goal(goal1);
assert uKc.Goal(goal2);
- makeGoal(function(state)
+ local _disj(state) =
assert uKc.State(state);
- $.mplus(goal1.pursue(state), goal2.pursue(state))
- ),
+ mplus(goal1.pursue(state), goal2.pursue(state));
+ makeGoal(_disj);
+
+local eq(value1, value2) =
+ local _eq(state) =
+ assert uKc.State(state);
+ assert uKc.trace('eq', [type(value1), value1, type(value2), value2], true);
+ local newSubst = state.substitution.unify(value1, value2);
+ assert uKc.trace('unify result', newSubst, true);
+ if newSubst == null then
+ baseObjects.emptyStream
+ else
+ baseObjects.unitStream(state + {substitution: newSubst});
+ makeGoal(_eq);
+
+local maybeSExpGoal(sexp) =
+ local t = type(sexp);
+ if t == 'goal' then
+ sexp
+ else if t != 'array' then
+ error 'Invalid goal type: %s' % [t]
+ else
+ assert std.assertEqual(std.type(sexp), 'array');
+ assert std.length(sexp) >= 3;
+ local head = sexp[0];
+ if head == 'eq' then
+ assert std.length(sexp) == 3;
+ eq(sexp[1], sexp[2])
+ else
+ local subgoals = std.map(maybeSExpGoal, std.reverse(sexp[1::]));
+ std.foldl(
+ if head == 'and' then conj
+ else if head == 'or' then disj
+ else error 'Invalid s-exp head: %s' % [head],
+ subgoals[1::],
+ subgoals[0]
+ );
+
+local sExpGoal(sexp) =
+ assert std.assertEqual(std.type(sexp), 'array');
+ maybeSExpGoal(sexp);
+
+// variable creation
+
+local makeVariable(number) = {
+ ['µK:var']: '%d' % number,
+ eq(value):: eq(self, value),
+};
+
+local makeFreshVariable(state) =
+ assert uKc.State(state);
+ {
+ variable: makeVariable(state.variableCount),
+ state: state + {variableCount: state.variableCount + 1},
+ };
- callFresh(func):
+local callFresh(func) =
assert std.assertEqual(std.type(func), 'function');
- makeGoal(function(state)
+ local _callFresh(state) =
assert uKc.State(state);
- local fresh = state.freshVar();
+ local fresh = makeFreshVariable(state);
assert uKc.Variable(fresh.variable);
- assert uKc.State(fresh.newState);
+ assert uKc.State(fresh.state);
local newGoal = func(fresh.variable);
- local t = $.type(newGoal);
- (if t == 'goal' then
- newGoal
- else if t == 'array' then
- local subgoals = std.reverse(newGoal[1::]);
- (if newGoal[0] == 'and' then
- std.foldl($.conj, subgoals[1::], subgoals[0])
- else if newGoal[0] == 'or' then
- std.foldl($.disj, subgoals[1::], subgoals[0])
- else if newGoal[0] == 'eq' then
- std.foldl($.disj, subgoals[1::], subgoals[0])
- else
- error 'Invalid goal'
- )
- else
- error 'Invalid goal'
- ).pursue(fresh.newState)
- ),
+ assert uKc.trace('callFresh newGoal', newGoal, true);
+ local adaptedGoal = maybeSExpGoal(newGoal);
+ assert uKc.trace('callFresh adaptedGoal', adaptedGoal, true);
+ adaptedGoal.pursue(fresh.state);
+ makeGoal(_callFresh);
+
+// resolution helpers
+local takeAll(goal) =
+ assert uKc.Goal(goal);
+ local stream = goal.pursue(baseObjects.emptyState);
+ // assert uKc.trace('runAll stream', stream, true);
+ assert uKc.Stream(stream);
+ stream.takeAll();
+
+local take(count, goal) =
+ assert uKc.Goal(goal);
+ local stream = goal.pursue(baseObjects.emptyState);
+ assert uKc.Stream(stream);
+ stream.take(count);
+
+local runSingleVar(func, count=null, state=null) =
+ assert std.assertEqual(std.type(func), 'function');
+ local fresh = makeFreshVariable(if state == null then baseObjects.emptyState else state);
+ local goal = maybeSExpGoal(func(fresh.variable));
+ local stream = goal.pursue(fresh.state);
+ local states =
+ if count == null then
+ stream.takeAll()
+ else
+ stream.take(count);
+ [state.substitution.walk(fresh.variable) for state in states];
+
+local runWithVars(variableNames, func, count=null, state=null) =
+ assert std.assertEqual(std.type(func), 'function');
+ local baseState =
+ if state == null then baseObjects.emptyState else state;
+ assert uKc.State(baseState);
+
+ assert std.assertEqual(std.type(variableNames), 'array');
+ assert std.length(variableNames) >= 1;
+
+ local named = std.foldl(
+ function(curr, name)
+ local fresh = makeFreshVariable(curr.state);
+ {
+ state: curr.state,
+ vars: curr.vars + {[name]: fresh.variable},
+ },
+ variableNames,
+ {state: baseState, vars: {}}
+ );
+
+ local goal = maybeSExpGoal(func(named.vars));
+ local stream = goal.pursue(named.state);
+ local states =
+ if count == null then
+ stream.takeAll()
+ else
+ stream.take(count);
+ [std.mapWithKey(function(name, var) state.substitution.walk(var), named.vars) for state in states];
+
+
+// public interface
+baseObjects + {
+ type: type,
+
+ // goal constructors
+ eq: eq,
+ conj: conj,
+ disj: disj,
+ sExpGoal: sExpGoal,
+ maybeSExpGoal: maybeSExpGoal,
+ callFresh: callFresh,
+
+ // resolution helpers
+ takeAll: takeAll,
+ take: take,
+ runSingleVar: runSingleVar,
+ runWithVars: runWithVars,
}
// vim: sts=2 ts=2 sw=2 et
diff --git a/microkanren_checks.libsonnet b/microkanren_checks.libsonnet
@@ -47,7 +47,9 @@ local checkType(type) = function(value) std.assertEqual(std.type(value), type);
Substitution(subst):
assert std.assertEqual(std.type(subst), 'object');
std.all(std.map(
- function(f) std.assertEqual(std.type(f.name), 'number'),
+ function(f)
+ assert std.parseInt(f.key) >= 0;
+ true,
std.objectKeysValues(subst)
)),
checkSubstitution(subst):
@@ -55,10 +57,17 @@ local checkType(type) = function(value) std.assertEqual(std.type(value), type);
subst,
State(state):
+ local checkVariableMaximum = std.all(std.map(
+ function(f)
+ assert std.parseInt(f.key) >= 0;
+ assert std.parseInt(f.key) < state.variableCount;
+ true,
+ std.objectKeysValues(state.substitution)
+ ));
$.objectFields(state, {
variableCount: $.VariableCount,
substitution: $.Substitution,
- }),
+ }) && checkVariableMaximum,
checkState(state):
assert $.State(state);
state,
@@ -77,7 +86,8 @@ local checkType(type) = function(value) std.assertEqual(std.type(value), type);
else if t == 'mature' then
$.objectFields(stream, {
['µK:stream']: checkType('string'),
- call: checkType('function'),
+ state: $.State,
+ next: $.Stream,
})
else
error "Incorrect stream type",