microkanren.libsonnet (8973B)
local uKc = import 'microkanren_checks.libsonnet';

// A microKanren implementation in Jsonnet.
//
// Terms are arbitrary Jsonnet values. Three kinds of tagged objects are
// recognised by a visible marker field:
//   - logic variables carry 'µK:var' (a decimal id string),
//   - goals carry 'µK:goal' (a function from state to stream),
//   - streams carry 'µK:stream' ('empty', 'mature' or 'immature').
// Shape validation is delegated to the helpers in microkanren_checks.libsonnet.

// custom types

// Classify a value: 'variable', 'goal' or 'stream' for the tagged objects
// above (each validated with the matching uKc check), otherwise std.type(value).
local type(value) =
  local t = std.type(value);
  if t == 'object' then
    if std.objectHas(value, 'µK:var') then
      assert uKc.Variable(value);
      'variable'
    else if std.objectHas(value, 'µK:goal') then
      assert uKc.Goal(value);
      'goal'
    else if std.objectHas(value, 'µK:stream') then
      assert uKc.Stream(value);
      'stream'
    else
      t
  else
    t;

// stream functions

// Fields and methods shared by every stream. `streamType` becomes the
// visible 'µK:stream' tag.
local baseStream(streamType) = {
  ['µK:stream']: streamType,

  // Force immature streams until an empty or mature stream is reached.
  pull()::
    if self['µK:stream'] == 'immature' then
      self.call().pull()
    else
      self,

  // Collect the state of every answer, in stream order. Only returns once
  // the stream reaches its empty end.
  takeAll()::
    local takeRecursive(accumulator, stream) =
      local mature = stream.pull();
      if mature['µK:stream'] == 'empty' then
        accumulator
      else
        takeRecursive(accumulator + [mature.state], mature.next);
    takeRecursive([], self),

  // Collect the states of at most `n` answers, in stream order.
  take(n)::
    local takeRecursive(accumulator, stream) =
      if std.length(accumulator) >= n then
        accumulator
      else
        local mature = stream.pull();
        if mature['µK:stream'] == 'empty' then
          accumulator
        else
          takeRecursive(accumulator + [mature.state], mature.next);
    takeRecursive([], self),
};

// A stream whose head answer `state` is available now; `next` is the rest.
local matureStream(state, next) = baseStream('mature') + {
  state: state,
  next: next,
};

// A suspended stream: calling `func()` yields the next stream step.
local immatureStream(func) = baseStream('immature') + {
  call: func,
};

// Merge two streams. When stream1 is immature it is forced and the operands
// are swapped, so the two streams are explored alternately. Every result is
// wrapped in an immature stream.
local mplus(stream1, stream2) =
  local t1 = stream1['µK:stream'];
  if t1 == 'empty' then
    immatureStream(function() stream2)
  else if t1 == 'immature' then
    immatureStream(function() mplus(stream2, stream1.call()))
  else if t1 == 'mature' then
    immatureStream(function() matureStream(stream1.state, mplus(stream2, stream1.next)))
  else
    error 'Invalid stream';

// Pursue `goal` in every state of `stream`, merging the resulting streams
// with mplus.
local bind(stream, goal) =
  local t = stream['µK:stream'];
  if t == 'empty' then
    stream
  else if t == 'immature' then
    immatureStream(function() bind(stream.call(), goal))
  else if t == 'mature' then
    immatureStream(function() mplus(goal.pursue(stream.state), bind(stream.next, goal)))
  else
    error 'Invalid stream';

// substitution functions

// Follow bindings for a variable until reaching an unbound variable or a
// non-variable value. Non-variables are returned as-is; variables nested
// inside arrays or objects are not resolved.
local walk(variable, substitution) =
  if type(variable) == 'variable' then
    if std.objectHas(substitution, variable['µK:var']) then
      substitution.walk(substitution[variable['µK:var']])
    else
      variable
  else
    variable;

// Unify two terms under `substitution`. Returns the (possibly extended)
// substitution on success and null on failure.
//   - variables bind to the other side,
//   - arrays unify element-wise when their lengths match,
//   - objects unify field-wise when their visible field names match,
//   - anything else unifies only when equal.
local unify(value1, value2, substitution) =
  local w1 = substitution.walk(value1);
  local w2 = substitution.walk(value2);
  local t1 = type(w1);
  local t2 = type(w2);
  assert uKc.trace('unify walked', [t1, w1, t2, w2], true);
  if t1 == 'variable' && t2 == 'variable' && w1 == w2 then
    substitution
  else if t1 == 'variable' then
    substitution.extend(w1, w2)
  else if t2 == 'variable' then
    substitution.extend(w2, w1)
  else if t1 == 'array' && t2 == 'array' then
    if std.length(w1) == std.length(w2) then
      if std.length(w1) == 0 then
        substitution
      else
        local s1 = substitution.unify(w1[0], w2[0]);
        if s1 == null then
          null
        else
          s1.unify(w1[1::], w2[1::])
    else
      null
  else if t1 == 'object' && t2 == 'object' then
    if std.objectFields(w1) == std.objectFields(w2) then
      assert uKc.trace("unifying objects", [w1, w2], true);
      // std.foldl invokes its callback as func(accumulator, element): the
      // accumulator is the substitution so far (null once a field failed),
      // the element is the field name.
      std.foldl(
        function(prevSubst, field)
          if prevSubst == null then
            null
          else
            prevSubst.unify(w1[field], w2[field]),
        std.objectFields(w1),
        substitution
      )
    else
      null
  else if w1 == w2 then
    substitution
  else
    null;

// templates for objects with methods
local baseObjects = {
  // state: a variable counter plus a substitution whose visible fields map
  // variable ids to bound values and whose hidden fields are its methods.
  emptyState: {
    variableCount: 0,

    substitution: {
      extend(variable, value)::
        assert uKc.Variable(variable);
        self + { [variable['µK:var']]: value },
      walk(var):: walk(var, self),
      unify(value1, value2):: unify(value1, value2, self),
    },
  },

  // streams
  emptyStream: baseStream('empty'),
  // A stream containing exactly one answer, `state`.
  unitStream(state):
    assert uKc.State(state);
    matureStream(state, $.emptyStream),
};

// goal functions

// Wrap `callable` (state -> stream) as a goal; `pursue` validates the state
// before calling it.
local makeGoal(callable) = {
  ['µK:goal']: callable,
  pursue(state)::
    assert uKc.trace('pursuing goal in state', state, true);
    assert uKc.State(state);
    uKc.traceValue('goal returned', self['µK:goal'](state)),
};

// Goal succeeding for states in which both goal1 and goal2 succeed.
local conj(goal1, goal2) =
  assert uKc.Goal(goal1);
  assert uKc.Goal(goal2);
  local _conj(state) =
    assert uKc.State(state);
    bind(goal1.pursue(state), goal2);
  makeGoal(_conj);

// Goal succeeding for states in which goal1 or goal2 succeeds (interleaved).
local disj(goal1, goal2) =
  assert uKc.Goal(goal1);
  assert uKc.Goal(goal2);
  local _disj(state) =
    assert uKc.State(state);
    mplus(goal1.pursue(state), goal2.pursue(state));
  makeGoal(_disj);

// Goal yielding one answer when value1 and value2 unify, otherwise none.
local eq(value1, value2) =
  local _eq(state) =
    assert uKc.State(state);
    assert uKc.trace('eq', [type(value1), value1, type(value2), value2], true);
    local newSubst = state.substitution.unify(value1, value2);
    assert uKc.trace('unify result', newSubst, true);
    if newSubst == null then
      baseObjects.emptyStream
    else
      baseObjects.unitStream(state + { substitution: newSubst });
  makeGoal(_eq);

// Accept a goal, or an s-expression describing one:
//   ['eq', a, b]           -> eq(a, b)
//   ['and', g1, g2, ...]   -> conj(g1, conj(g2, ...))
//   ['or', g1, g2, ...]    -> disj(g1, disj(g2, ...))
// where each gi is itself a goal or s-expression. Subgoals keep their
// source order.
local maybeSExpGoal(sexp) =
  local t = type(sexp);
  if t == 'goal' then
    sexp
  else if t != 'array' then
    error 'Invalid goal type: %s' % [t]
  else
    assert std.length(sexp) >= 3;
    local head = sexp[0];
    if head == 'eq' then
      assert std.length(sexp) == 3;
      eq(sexp[1], sexp[2])
    else
      local combine =
        if head == 'and' then conj
        else if head == 'or' then disj
        else error 'Invalid s-exp head: %s' % [head];
      local subgoals = std.map(maybeSExpGoal, sexp[1::]);
      local last = std.length(subgoals) - 1;
      // std.foldr calls combine(element, accumulator), so this nests to the
      // right: [g1, g2, g3] -> combine(g1, combine(g2, g3)).
      std.foldr(combine, subgoals[0:last], subgoals[last]);

// Like maybeSExpGoal, but requires an s-expression array.
local sExpGoal(sexp) =
  assert std.assertEqual(std.type(sexp), 'array');
  maybeSExpGoal(sexp);

// variable creation

// A logic variable whose id is the decimal string of `number`.
local makeVariable(number) = {
  ['µK:var']: '%d' % number,
  eq(value):: eq(self, value),
};

// Allocate the next variable from `state`; returns the variable and the
// state with its counter advanced.
local makeFreshVariable(state) =
  assert uKc.State(state);
  {
    variable: makeVariable(state.variableCount),
    state: state + { variableCount: state.variableCount + 1 },
  };

// Goal that introduces one fresh variable and pursues func(variable), which
// may return a goal or an s-expression.
local callFresh(func) =
  assert std.assertEqual(std.type(func), 'function');
  local _callFresh(state) =
    assert uKc.State(state);
    local fresh = makeFreshVariable(state);
    assert uKc.Variable(fresh.variable);
    assert uKc.State(fresh.state);
    local newGoal = func(fresh.variable);
    assert uKc.trace('callFresh newGoal', newGoal, true);
    local adaptedGoal = maybeSExpGoal(newGoal);
    assert uKc.trace('callFresh adaptedGoal', adaptedGoal, true);
    adaptedGoal.pursue(fresh.state);
  makeGoal(_callFresh);

// resolution helpers

// All answer states of `goal` pursued from the empty state.
local takeAll(goal) =
  assert uKc.Goal(goal);
  local stream = goal.pursue(baseObjects.emptyState);
  assert uKc.Stream(stream);
  stream.takeAll();

// At most `count` answer states of `goal` pursued from the empty state.
local take(count, goal) =
  assert uKc.Goal(goal);
  local stream = goal.pursue(baseObjects.emptyState);
  assert uKc.Stream(stream);
  stream.take(count);

// Run func(q) for one fresh variable q and return q's walked value in each
// answer (all answers when count is null, else at most `count`). Starts from
// `state` when given, otherwise from the empty state.
local runSingleVar(func, count=null, state=null) =
  assert std.assertEqual(std.type(func), 'function');
  local fresh = makeFreshVariable(if state == null then baseObjects.emptyState else state);
  local goal = maybeSExpGoal(func(fresh.variable));
  local stream = goal.pursue(fresh.state);
  local states =
    if count == null then
      stream.takeAll()
    else
      stream.take(count);
  [answerState.substitution.walk(fresh.variable) for answerState in states];

// Like runSingleVar, but allocates one fresh variable per entry of
// variableNames and passes func an object keyed by those names. Each answer
// is an object mapping every name to its walked value.
local runWithVars(variableNames, func, count=null, state=null) =
  assert std.assertEqual(std.type(func), 'function');
  local baseState =
    if state == null then baseObjects.emptyState else state;
  assert uKc.State(baseState);

  assert std.assertEqual(std.type(variableNames), 'array');
  assert std.length(variableNames) >= 1;

  // Thread the state through the fold so each name gets a distinct id.
  local named = std.foldl(
    function(curr, name)
      local fresh = makeFreshVariable(curr.state);
      {
        state: fresh.state,
        vars: curr.vars + { [name]: fresh.variable },
      },
    variableNames,
    { state: baseState, vars: {} }
  );

  local goal = maybeSExpGoal(func(named.vars));
  local stream = goal.pursue(named.state);
  local states =
    if count == null then
      stream.takeAll()
    else
      stream.take(count);
  [
    std.mapWithKey(function(name, var) answerState.substitution.walk(var), named.vars)
    for answerState in states
  ];


// public interface
baseObjects + {
  type: type,

  // goal constructors
  eq: eq,
  conj: conj,
  disj: disj,
  sExpGoal: sExpGoal,
  maybeSExpGoal: maybeSExpGoal,
  callFresh: callFresh,

  // resolution helpers
  takeAll: takeAll,
  take: take,
  runSingleVar: runSingleVar,
  runWithVars: runWithVars,
}
// vim: sts=2 ts=2 sw=2 et