diff --git a/src/libexpr/eval.cc b/src/libexpr/eval.cc index 402de78ad..ec6e6e263 100644 --- a/src/libexpr/eval.cc +++ b/src/libexpr/eval.cc @@ -127,6 +127,7 @@ void printValue(std::ostream & str, std::set & active, const Valu break; case tThunk: case tApp: + case tPartialApp: str << ""; break; case tLambda: @@ -1275,35 +1276,28 @@ void EvalState::callFunction(Value & fun, size_t nrArgs, Value * * args, Value & } }; - while (nrArgs > 0) { + auto callLambda = [&](Env * env, ExprLambda & lambda, Value * * args) + { + Env & env2(allocEnv(lambda.envSize)); + env2.up = env; - if (vCur.isLambda()) { + Displacement displ = 0; - ExprLambda & lambda(*vCur.lambda.fun); + for (auto & arg : lambda.args) { + auto vArg = *args++; - auto size = - (lambda.arg.empty() ? 0 : 1) + - (lambda.hasFormals() ? lambda.formals->formals.size() : 0); - Env & env2(allocEnv(size)); - env2.up = vCur.lambda.env; + if (arg.arg != sEpsilon) + env2.values[displ++] = vArg; - Displacement displ = 0; - - if (!lambda.hasFormals()) - env2.values[displ++] = args[0]; - - else { - forceAttrs(*args[0], pos); - - if (!lambda.arg.empty()) - env2.values[displ++] = args[0]; + if (arg.formals) { + forceAttrs(*vArg, pos); /* For each formal argument, get the actual argument. If there is no matching actual argument but the formal argument has a default, use the default. */ size_t attrsUsed = 0; - for (auto & i : lambda.formals->formals) { - auto j = args[0]->attrs->get(i.name); + for (auto & i : arg.formals->formals) { + auto j = vArg->attrs->get(i.name); if (!j) { if (!i.def) throwTypeError(pos, "%1% called without required argument '%2%'", lambda, i.name); @@ -1316,35 +1310,96 @@ void EvalState::callFunction(Value & fun, size_t nrArgs, Value * * args, Value & /* Check that each actual argument is listed as a formal argument (unless the attribute match specifies a `...'). */ - if (!lambda.formals->ellipsis && attrsUsed != args[0]->attrs->size()) { + if (!arg.formals->ellipsis && attrsUsed != vArg->attrs->size()) { /* Nope, so show the first unexpected argument to the user. */ - for (auto & i : *args[0]->attrs) - if (lambda.formals->argNames.find(i.name) == lambda.formals->argNames.end()) + for (auto & i : *vArg->attrs) + if (arg.formals->argNames.find(i.name) == arg.formals->argNames.end()) throwTypeError(pos, "%1% called with unexpected argument '%2%'", lambda, i.name); abort(); // can't happen } } + } - nrFunctionCalls++; - if (countCalls) incrFunctionCall(&lambda); + assert(displ == lambda.envSize); - /* Evaluate the body. */ - try { - lambda.body->eval(*this, env2, vCur); - } catch (Error & e) { - if (loggerSettings.showTrace.get()) { - addErrorTrace(e, lambda.pos, "while evaluating %s", - (lambda.name.set() - ? "'" + (string) lambda.name + "'" - : "anonymous lambda")); - addErrorTrace(e, pos, "from call site%s", ""); - } - throw; + nrFunctionCalls++; + if (countCalls) incrFunctionCall(&lambda); + + /* Evaluate the body. */ + try { + lambda.body->eval(*this, env2, vCur); + } catch (Error & e) { + if (loggerSettings.showTrace) { + addErrorTrace(e, lambda.pos, "while evaluating %s", + (lambda.name.set() + ? "'" + (string) lambda.name + "'" + : "anonymous lambda")); + addErrorTrace(e, pos, "from call site%s", ""); } + throw; + } + }; - nrArgs--; - args += 1; + while (nrArgs > 0) { + + if (vCur.isLambda()) { + + ExprLambda & lambda(*vCur.lambda.fun); + + if (nrArgs < lambda.args.size()) { + vRes = vCur; + for (size_t i = 0; i < nrArgs; ++i) { + auto fun2 = allocValue(); + *fun2 = vRes; + vRes.mkPartialApp(fun2, args[i]); + } + return; + } else { + callLambda(vCur.lambda.env, lambda, args); + nrArgs -= lambda.args.size(); + args += lambda.args.size(); + } + } + + else if (vCur.isPartialApp()) { + /* Figure out the number of arguments still needed. */ + size_t argsDone = 0; + Value * lambda = &vCur; + while (lambda->isPartialApp()) { + argsDone++; + lambda = lambda->app.left; + } + assert(lambda->isLambda()); + auto arity = lambda->lambda.fun->args.size(); + auto argsLeft = arity - argsDone; + + if (nrArgs < argsLeft) { + /* We still don't have enough arguments, so extend the tPartialApp chain. */ + vRes = vCur; + for (size_t i = 0; i < nrArgs; ++i) { + auto fun2 = allocValue(); + *fun2 = vRes; + vRes.mkPartialApp(fun2, args[i]); + } + return; + } else { + /* We have all the arguments, so call the function + with the previous and new arguments. */ + + Value * vArgs[arity]; + auto n = argsDone; + for (Value * arg = &vCur; arg->isPartialApp(); arg = arg->app.left) + vArgs[--n] = arg->app.right; + + for (size_t i = 0; i < argsLeft; ++i) + vArgs[argsDone + i] = args[i]; + + nrArgs -= argsLeft; + args += argsLeft; + + callLambda(lambda->lambda.env, *lambda->lambda.fun, vArgs); + } } else if (vCur.isPrimOp()) { @@ -1458,42 +1513,48 @@ void EvalState::autoCallFunction(Bindings & args, Value & fun, Value & res) } } - if (!fun.isLambda() || !fun.lambda.fun->hasFormals()) { + if (!fun.isLambda()) { res = fun; return; } - Value * actualArgs = allocValue(); - mkAttrs(*actualArgs, std::max(static_cast(fun.lambda.fun->formals->formals.size()), args.size())); + Value * actualArgs[fun.lambda.fun->args.size()]; - if (fun.lambda.fun->formals->ellipsis) { - // If the formals have an ellipsis (eg the function accepts extra args) pass - // all available automatic arguments (which includes arguments specified on - // the command line via --arg/--argstr) - for (auto& v : args) { - actualArgs->attrs->push_back(v); + for (const auto & [i, arg] : enumerate(fun.lambda.fun->args)) { + if (!arg.formals) { + res = fun; + return; } - } else { - // Otherwise, only pass the arguments that the function accepts - for (auto & i : fun.lambda.fun->formals->formals) { - Bindings::iterator j = args.find(i.name); - if (j != args.end()) { - actualArgs->attrs->push_back(*j); - } else if (!i.def) { - throwMissingArgumentError(i.pos, R"(cannot evaluate a function that has an argument without a value ('%1%') + + actualArgs[i] = allocValue(); + mkAttrs(*actualArgs[i], std::max(arg.formals->formals.size(), static_cast(args.size()))); + + if (arg.formals->ellipsis) { + /* If the formals have an ellipsis (i.e. the function + accepts extra args), pass all available automatic + arguments. */ + for (auto & v : args) + actualArgs[i]->attrs->push_back(v); + } else { + /* Otherwise, only pass the arguments that the function + accepts. */ + for (auto & j : arg.formals->formals) { + if (auto attr = args.get(j.name)) + actualArgs[i]->attrs->push_back(*attr); + else if (!j.def) + throwMissingArgumentError(j.pos, R"(cannot evaluate a function that has an argument without a value ('%1%') Nix attempted to evaluate a function as a top level expression; in this case it must have its arguments supplied either by default values, or passed explicitly with '--arg' or '--argstr'. See -https://nixos.org/manual/nix/stable/#ss-functions.)", i.name); - +https://nixos.org/manual/nix/stable/#ss-functions.)", j.name); } } + + actualArgs[i]->attrs->sort(); } - actualArgs->attrs->sort(); - - callFunction(fun, *actualArgs, res, noPos); + callFunction(fun, fun.lambda.fun->args.size(), actualArgs, res, noPos); } diff --git a/src/libexpr/flake/flake.cc b/src/libexpr/flake/flake.cc index c9d848495..0d24e90a9 100644 --- a/src/libexpr/flake/flake.cc +++ b/src/libexpr/flake/flake.cc @@ -230,8 +230,13 @@ static Flake getFlake( if (auto outputs = vInfo.attrs->get(sOutputs)) { expectType(state, nFunction, *outputs->value, *outputs->pos); - if (outputs->value->isLambda() && outputs->value->lambda.fun->hasFormals()) { - for (auto & formal : outputs->value->lambda.fun->formals->formals) { + if (outputs->value->lambda.fun->args.size() != 1) + throw Error("the 'outputs' attribute of flake '%s' is not a unary function", lockedRef); + + auto & arg = outputs->value->lambda.fun->args[0]; + + if (arg.formals) { + for (auto & formal : arg.formals->formals) { if (formal.name != state.sSelf) flake.inputs.emplace(formal.name, FlakeInput { .ref = parseFlakeRef(formal.name) diff --git a/src/libexpr/nixexpr.cc b/src/libexpr/nixexpr.cc index 57c2f6e44..70f2ddb1c 100644 --- a/src/libexpr/nixexpr.cc +++ b/src/libexpr/nixexpr.cc @@ -124,23 +124,26 @@ void ExprList::show(std::ostream & str) const void ExprLambda::show(std::ostream & str) const { str << "("; - if (hasFormals()) { - str << "{ "; - bool first = true; - for (auto & i : formals->formals) { - if (first) first = false; else str << ", "; - str << i.name; - if (i.def) str << " ? " << *i.def; + for (auto & arg : args) { + if (arg.formals) { + str << "{ "; + bool first = true; + for (auto & i : arg.formals->formals) { + if (first) first = false; else str << ", "; + str << i.name; + if (i.def) str << " ? " << *i.def; + } + if (arg.formals->ellipsis) { + if (!first) str << ", "; + str << "..."; + } + str << " }"; + if (!arg.arg.empty()) str << " @ "; } - if (formals->ellipsis) { - if (!first) str << ", "; - str << "..."; - } - str << " }"; - if (!arg.empty()) str << " @ "; + if (!arg.arg.empty()) str << arg.arg; + str << ": "; } - if (!arg.empty()) str << arg; - str << ": " << *body << ")"; + str << *body << ")"; } void ExprCall::show(std::ostream & str) const @@ -279,8 +282,7 @@ void ExprVar::bindVars(const StaticEnv & env) if (curEnv->isWith) { if (withLevel == -1) withLevel = level; } else { - auto i = curEnv->find(name); - if (i != curEnv->vars.end()) { + if (auto i = curEnv->get(name)) { fromWith = false; this->level = level; displ = i->second; @@ -354,25 +356,48 @@ void ExprList::bindVars(const StaticEnv & env) void ExprLambda::bindVars(const StaticEnv & env) { - StaticEnv newEnv( - false, &env, - (hasFormals() ? formals->formals.size() : 0) + - (arg.empty() ? 0 : 1)); + /* The parser adds arguments in reverse order. Let's fix that + now. */ + std::reverse(args.begin(), args.end()); + + envSize = 0; + + for (auto & arg :args) { + if (!arg.arg.empty()) envSize++; + if (arg.formals) envSize += arg.formals->formals.size(); + } + + StaticEnv newEnv(false, &env, envSize); Displacement displ = 0; - if (!arg.empty()) newEnv.vars.emplace_back(arg, displ++); + for (auto & arg : args) { + if (!arg.arg.empty()) { + if (auto i = const_cast(newEnv.get(arg.arg))) + i->second = displ++; + else + newEnv.vars.emplace_back(arg.arg, displ++); + } - if (hasFormals()) { - for (auto & i : formals->formals) - newEnv.vars.emplace_back(i.name, displ++); + if (arg.formals) { + for (auto & i : arg.formals->formals) { + if (auto j = const_cast(newEnv.get(i.name))) + j->second = displ++; + else + newEnv.vars.emplace_back(i.name, displ++); + } - newEnv.sort(); + newEnv.sort(); - for (auto & i : formals->formals) - if (i.def) i.def->bindVars(newEnv); + for (auto & i : arg.formals->formals) + if (i.def) i.def->bindVars(newEnv); + } } + assert(displ == envSize); + + newEnv.sort(); + body->bindVars(newEnv); } diff --git a/src/libexpr/nixexpr.hh b/src/libexpr/nixexpr.hh index 13256272c..48ee13739 100644 --- a/src/libexpr/nixexpr.hh +++ b/src/libexpr/nixexpr.hh @@ -233,21 +233,24 @@ struct ExprLambda : Expr { Pos pos; Symbol name; - Symbol arg; - Formals * formals; - Expr * body; - ExprLambda(const Pos & pos, const Symbol & arg, Formals * formals, Expr * body) - : pos(pos), arg(arg), formals(formals), body(body) + + struct Arg { - if (!arg.empty() && formals && formals->argNames.find(arg) != formals->argNames.end()) - throw ParseError({ - .msg = hintfmt("duplicate formal function argument '%1%'", arg), - .errPos = pos - }); + Symbol arg; + Formals * formals; }; + + std::vector args; + + Expr * body; + + Displacement envSize = 0; // initialized by bindVars() + + ExprLambda(const Pos & pos, Expr * body) + : pos(pos), body(body) + { }; void setName(Symbol & name); string showNamePos() const; - inline bool hasFormals() const { return formals != nullptr; } COMMON_METHODS }; @@ -368,12 +371,12 @@ struct StaticEnv [](const Vars::value_type & a, const Vars::value_type & b) { return a.first < b.first; }); } - Vars::const_iterator find(const Symbol & name) const + const Vars::value_type * get(const Symbol & name) const { Vars::value_type key(name, 0); auto i = std::lower_bound(vars.begin(), vars.end(), key); - if (i != vars.end() && i->first == name) return i; - return vars.end(); + if (i != vars.end() && i->first == name) return &*i; + return {}; } }; diff --git a/src/libexpr/parser.y b/src/libexpr/parser.y index 2e8a04143..4f002996e 100644 --- a/src/libexpr/parser.y +++ b/src/libexpr/parser.y @@ -160,6 +160,24 @@ static void addFormal(const Pos & pos, Formals * formals, const Formal & formal) } +static Expr * addArg(const Pos & pos, Expr * e, ExprLambda::Arg && arg) +{ + if (!arg.arg.empty() && arg.formals && arg.formals->argNames.count(arg.arg)) + throw ParseError({ + .msg = hintfmt("duplicate formal function argument '%1%'", arg.arg), + .errPos = pos + }); + + auto e2 = dynamic_cast(e); // FIXME: slow? + if (!e2) + e2 = new ExprLambda(pos, e); + else + e2->pos = pos; + e2->args.emplace_back(std::move(arg)); + return e2; +} + + static Expr * stripIndentation(const Pos & pos, SymbolTable & symbols, vector & es) { if (es.empty()) return new ExprString(symbols.create("")); @@ -332,13 +350,13 @@ expr: expr_function; expr_function : ID ':' expr_function - { $$ = new ExprLambda(CUR_POS, data->symbols.create($1), 0, $3); } + { $$ = addArg(CUR_POS, $3, {data->symbols.create($1), nullptr}); } | '{' formals '}' ':' expr_function - { $$ = new ExprLambda(CUR_POS, data->symbols.create(""), $2, $5); } + { $$ = addArg(CUR_POS, $5, {data->state.sEpsilon, $2}); } | '{' formals '}' '@' ID ':' expr_function - { $$ = new ExprLambda(CUR_POS, data->symbols.create($5), $2, $7); } + { $$ = addArg(CUR_POS, $7, {data->symbols.create($5), $2}); } | ID '@' '{' formals '}' ':' expr_function - { $$ = new ExprLambda(CUR_POS, data->symbols.create($1), $4, $7); } + { $$ = addArg(CUR_POS, $7, {data->symbols.create($1), $4}); } | ASSERT expr ';' expr_function { $$ = new ExprAssert(CUR_POS, $2, $4); } | WITH expr ';' expr_function @@ -456,7 +474,7 @@ expr_simple string_parts : STR | string_parts_interpolated { $$ = new ExprConcatStrings(CUR_POS, true, $1); } - | { $$ = new ExprString(data->symbols.create("")); } + | { $$ = new ExprString(data->state.sEpsilon); } ; string_parts_interpolated diff --git a/src/libexpr/primops.cc b/src/libexpr/primops.cc index e4107dbe1..1dec12395 100644 --- a/src/libexpr/primops.cc +++ b/src/libexpr/primops.cc @@ -2386,23 +2386,38 @@ static RegisterPrimOp primop_catAttrs({ static void prim_functionArgs(EvalState & state, const Pos & pos, Value * * args, Value & v) { state.forceValue(*args[0], pos); + if (args[0]->isPrimOpApp() || args[0]->isPrimOp()) { state.mkAttrs(v, 0); return; } - if (!args[0]->isLambda()) + + if (!args[0]->isLambda() && !args[0]->isPartialApp()) throw TypeError({ .msg = hintfmt("'functionArgs' requires a function"), .errPos = pos }); - if (!args[0]->lambda.fun->hasFormals()) { + size_t argsDone = 0; + auto lambda = args[0]; + while (lambda->isPartialApp()) { + argsDone++; + lambda = lambda->app.left; + } + assert(lambda->isLambda()); + + assert(argsDone < lambda->lambda.fun->args.size()); + + // FIXME: handle partially applied functions + auto formals = lambda->lambda.fun->args[argsDone].formals; + + if (!formals) { state.mkAttrs(v, 0); return; } - state.mkAttrs(v, args[0]->lambda.fun->formals->formals.size()); - for (auto & i : args[0]->lambda.fun->formals->formals) { + state.mkAttrs(v, formals->formals.size()); + for (auto & i : formals->formals) { // !!! should optimise booleans (allocate only once) Value * value = state.allocValue(); v.attrs->push_back(Attr(i.name, value, ptr(&i.pos))); diff --git a/src/libexpr/value-to-xml.cc b/src/libexpr/value-to-xml.cc index b44455f5f..1b7135324 100644 --- a/src/libexpr/value-to-xml.cc +++ b/src/libexpr/value-to-xml.cc @@ -126,24 +126,28 @@ static void printValueAsXML(EvalState & state, bool strict, bool location, } case nFunction: { + if (!v.isLambda()) { - // FIXME: Serialize primops and primopapps + // FIXME: Serialize primops and partial apps doc.writeEmptyElement("unevaluated"); break; } + XMLAttrs xmlAttrs; if (location) posToXML(xmlAttrs, v.lambda.fun->pos); XMLOpenElement _(doc, "function", xmlAttrs); - if (v.lambda.fun->hasFormals()) { + auto & arg = v.lambda.fun->args[0]; + + if (arg.formals) { XMLAttrs attrs; - if (!v.lambda.fun->arg.empty()) attrs["name"] = v.lambda.fun->arg; - if (v.lambda.fun->formals->ellipsis) attrs["ellipsis"] = "1"; + if (arg.arg != state.sEpsilon) attrs["name"] = arg.arg; + if (arg.formals->ellipsis) attrs["ellipsis"] = "1"; XMLOpenElement _(doc, "attrspat", attrs); - for (auto & i : v.lambda.fun->formals->formals) + for (auto & i : arg.formals->formals) doc.writeEmptyElement("attr", singletonAttrs("name", i.name)); } else - doc.writeEmptyElement("varpat", singletonAttrs("name", v.lambda.fun->arg)); + doc.writeEmptyElement("varpat", singletonAttrs("name", arg.arg)); break; } diff --git a/src/libexpr/value.hh b/src/libexpr/value.hh index a1f131f9e..e48c4555b 100644 --- a/src/libexpr/value.hh +++ b/src/libexpr/value.hh @@ -21,6 +21,7 @@ typedef enum { tListN, tThunk, tApp, + tPartialApp, tLambda, tBlackhole, tPrimOp, @@ -125,6 +126,7 @@ public: // type() == nFunction inline bool isLambda() const { return internalType == tLambda; }; + inline bool isPartialApp() const { return internalType == tPartialApp; }; inline bool isPrimOp() const { return internalType == tPrimOp; }; inline bool isPrimOpApp() const { return internalType == tPrimOpApp; }; @@ -196,7 +198,7 @@ public: case tNull: return nNull; case tAttrs: return nAttrs; case tList1: case tList2: case tListN: return nList; - case tLambda: case tPrimOp: case tPrimOpApp: return nFunction; + case tLambda: case tPartialApp: case tPrimOp: case tPrimOpApp: return nFunction; case tExternal: return nExternal; case tFloat: return nFloat; case tThunk: case tApp: case tBlackhole: return nThunk; @@ -307,6 +309,13 @@ public: app.right = r; } + inline void mkPartialApp(Value * l, Value * r) + { + internalType = tPartialApp; + app.left = l; + app.right = r; + } + inline void mkExternal(ExternalValueBase * e) { clearValue(); diff --git a/src/nix/flake.cc b/src/nix/flake.cc index 5eeb5498a..1e5be06ba 100644 --- a/src/nix/flake.cc +++ b/src/nix/flake.cc @@ -355,14 +355,12 @@ struct CmdFlakeCheck : FlakeCommand try { state->forceValue(v, pos); if (!v.isLambda() - || v.lambda.fun->hasFormals() - || !argHasName(v.lambda.fun->arg, "final")) - throw Error("overlay does not take an argument named 'final'"); - auto body = dynamic_cast(v.lambda.fun->body); - if (!body - || body->hasFormals() - || !argHasName(body->arg, "prev")) - throw Error("overlay does not take an argument named 'prev'"); + || v.lambda.fun->args.size() != 2 + || v.lambda.fun->args[0].formals + || !argHasName(v.lambda.fun->args[0].arg, "final") + || v.lambda.fun->args[1].formals + || !argHasName(v.lambda.fun->args[1].arg, "prev")) + throw Error("overlay is not a binary function with arguments 'final' and 'prev'"); // FIXME: if we have a 'nixpkgs' input, use it to // evaluate the overlay. } catch (Error & e) { @@ -375,7 +373,9 @@ struct CmdFlakeCheck : FlakeCommand try { state->forceValue(v, pos); if (v.isLambda()) { - if (!v.lambda.fun->hasFormals() || !v.lambda.fun->formals->ellipsis) + if (v.lambda.fun->args.size() != 1 + || !v.lambda.fun->args[0].formals + || !v.lambda.fun->args[0].formals->ellipsis) throw Error("module must match an open attribute set ('{ config, ... }')"); } else if (v.type() == nAttrs) { for (auto & attr : *v.attrs) @@ -473,12 +473,12 @@ struct CmdFlakeCheck : FlakeCommand auto checkBundler = [&](const std::string & attrPath, Value & v, const Pos & pos) { try { state->forceValue(v, pos); - if (!v.isLambda()) - throw Error("bundler must be a function"); - if (!v.lambda.fun->formals || - !v.lambda.fun->formals->argNames.count(state->symbols.create("program")) || - !v.lambda.fun->formals->argNames.count(state->symbols.create("system"))) - throw Error("bundler must take formal arguments 'program' and 'system'"); + if (!v.isLambda() + || v.lambda.fun->args.size() != 1 + || !v.lambda.fun->args[0].formals + || !v.lambda.fun->args[0].formals->argNames.count(state->symbols.create("program")) + || !v.lambda.fun->args[0].formals->argNames.count(state->symbols.create("system"))) + throw Error("bundler must be a function that takes take arguments 'program' and 'system'"); } catch (Error & e) { e.addTrace(pos, hintfmt("while checking the template '%s'", attrPath)); reportError(e);