1
0
Fork 0
mirror of https://github.com/NixOS/nix synced 2025-07-07 14:21:48 +02:00

Merge environments of nested functions

Previously an expression like 'x: y: ...' would create two
environments with one value. Now it creates one environment with two
values. This reduces the number of allocations and the distance in the
environment chain that variable lookups need to traverse.

On

  $ nix-instantiate --dry-run '<nixpkgs/nixos/release-combined.nix>' -A nixos.tests.simple.x86_64-linux

this gives a ~30% reduction in the number of Env allocations.
This commit is contained in:
Eelco Dolstra 2021-11-05 13:05:03 +01:00
parent a1c1b0e553
commit 904d0ec5c0
9 changed files with 276 additions and 136 deletions

View file

@ -127,6 +127,7 @@ void printValue(std::ostream & str, std::set<const Value *> & active, const Valu
break; break;
case tThunk: case tThunk:
case tApp: case tApp:
case tPartialApp:
str << "<CODE>"; str << "<CODE>";
break; break;
case tLambda: case tLambda:
@ -1275,35 +1276,28 @@ void EvalState::callFunction(Value & fun, size_t nrArgs, Value * * args, Value &
} }
}; };
while (nrArgs > 0) { auto callLambda = [&](Env * env, ExprLambda & lambda, Value * * args)
{
Env & env2(allocEnv(lambda.envSize));
env2.up = env;
if (vCur.isLambda()) { Displacement displ = 0;
ExprLambda & lambda(*vCur.lambda.fun); for (auto & arg : lambda.args) {
auto vArg = *args++;
auto size = if (arg.arg != sEpsilon)
(lambda.arg.empty() ? 0 : 1) + env2.values[displ++] = vArg;
(lambda.hasFormals() ? lambda.formals->formals.size() : 0);
Env & env2(allocEnv(size));
env2.up = vCur.lambda.env;
Displacement displ = 0; if (arg.formals) {
forceAttrs(*vArg, pos);
if (!lambda.hasFormals())
env2.values[displ++] = args[0];
else {
forceAttrs(*args[0], pos);
if (!lambda.arg.empty())
env2.values[displ++] = args[0];
/* For each formal argument, get the actual argument. If /* For each formal argument, get the actual argument. If
there is no matching actual argument but the formal there is no matching actual argument but the formal
argument has a default, use the default. */ argument has a default, use the default. */
size_t attrsUsed = 0; size_t attrsUsed = 0;
for (auto & i : lambda.formals->formals) { for (auto & i : arg.formals->formals) {
auto j = args[0]->attrs->get(i.name); auto j = vArg->attrs->get(i.name);
if (!j) { if (!j) {
if (!i.def) throwTypeError(pos, "%1% called without required argument '%2%'", if (!i.def) throwTypeError(pos, "%1% called without required argument '%2%'",
lambda, i.name); lambda, i.name);
@ -1316,35 +1310,96 @@ void EvalState::callFunction(Value & fun, size_t nrArgs, Value * * args, Value &
/* Check that each actual argument is listed as a formal /* Check that each actual argument is listed as a formal
argument (unless the attribute match specifies a `...'). */ argument (unless the attribute match specifies a `...'). */
if (!lambda.formals->ellipsis && attrsUsed != args[0]->attrs->size()) { if (!arg.formals->ellipsis && attrsUsed != vArg->attrs->size()) {
/* Nope, so show the first unexpected argument to the /* Nope, so show the first unexpected argument to the
user. */ user. */
for (auto & i : *args[0]->attrs) for (auto & i : *vArg->attrs)
if (lambda.formals->argNames.find(i.name) == lambda.formals->argNames.end()) if (arg.formals->argNames.find(i.name) == arg.formals->argNames.end())
throwTypeError(pos, "%1% called with unexpected argument '%2%'", lambda, i.name); throwTypeError(pos, "%1% called with unexpected argument '%2%'", lambda, i.name);
abort(); // can't happen abort(); // can't happen
} }
} }
}
nrFunctionCalls++; assert(displ == lambda.envSize);
if (countCalls) incrFunctionCall(&lambda);
/* Evaluate the body. */ nrFunctionCalls++;
try { if (countCalls) incrFunctionCall(&lambda);
lambda.body->eval(*this, env2, vCur);
} catch (Error & e) { /* Evaluate the body. */
if (loggerSettings.showTrace.get()) { try {
addErrorTrace(e, lambda.pos, "while evaluating %s", lambda.body->eval(*this, env2, vCur);
(lambda.name.set() } catch (Error & e) {
? "'" + (string) lambda.name + "'" if (loggerSettings.showTrace) {
: "anonymous lambda")); addErrorTrace(e, lambda.pos, "while evaluating %s",
addErrorTrace(e, pos, "from call site%s", ""); (lambda.name.set()
} ? "'" + (string) lambda.name + "'"
throw; : "anonymous lambda"));
addErrorTrace(e, pos, "from call site%s", "");
} }
throw;
}
};
nrArgs--; while (nrArgs > 0) {
args += 1;
if (vCur.isLambda()) {
ExprLambda & lambda(*vCur.lambda.fun);
if (nrArgs < lambda.args.size()) {
vRes = vCur;
for (size_t i = 0; i < nrArgs; ++i) {
auto fun2 = allocValue();
*fun2 = vRes;
vRes.mkPartialApp(fun2, args[i]);
}
return;
} else {
callLambda(vCur.lambda.env, lambda, args);
nrArgs -= lambda.args.size();
args += lambda.args.size();
}
}
else if (vCur.isPartialApp()) {
/* Figure out the number of arguments still needed. */
size_t argsDone = 0;
Value * lambda = &vCur;
while (lambda->isPartialApp()) {
argsDone++;
lambda = lambda->app.left;
}
assert(lambda->isLambda());
auto arity = lambda->lambda.fun->args.size();
auto argsLeft = arity - argsDone;
if (nrArgs < argsLeft) {
/* We still don't have enough arguments, so extend the tPartialApp chain. */
vRes = vCur;
for (size_t i = 0; i < nrArgs; ++i) {
auto fun2 = allocValue();
*fun2 = vRes;
vRes.mkPartialApp(fun2, args[i]);
}
return;
} else {
/* We have all the arguments, so call the function
with the previous and new arguments. */
Value * vArgs[arity];
auto n = argsDone;
for (Value * arg = &vCur; arg->isPartialApp(); arg = arg->app.left)
vArgs[--n] = arg->app.right;
for (size_t i = 0; i < argsLeft; ++i)
vArgs[argsDone + i] = args[i];
nrArgs -= argsLeft;
args += argsLeft;
callLambda(lambda->lambda.env, *lambda->lambda.fun, vArgs);
}
} }
else if (vCur.isPrimOp()) { else if (vCur.isPrimOp()) {
@ -1458,42 +1513,48 @@ void EvalState::autoCallFunction(Bindings & args, Value & fun, Value & res)
} }
} }
if (!fun.isLambda() || !fun.lambda.fun->hasFormals()) { if (!fun.isLambda()) {
res = fun; res = fun;
return; return;
} }
Value * actualArgs = allocValue(); Value * actualArgs[fun.lambda.fun->args.size()];
mkAttrs(*actualArgs, std::max(static_cast<uint32_t>(fun.lambda.fun->formals->formals.size()), args.size()));
if (fun.lambda.fun->formals->ellipsis) { for (const auto & [i, arg] : enumerate(fun.lambda.fun->args)) {
// If the formals have an ellipsis (eg the function accepts extra args) pass if (!arg.formals) {
// all available automatic arguments (which includes arguments specified on res = fun;
// the command line via --arg/--argstr) return;
for (auto& v : args) {
actualArgs->attrs->push_back(v);
} }
} else {
// Otherwise, only pass the arguments that the function accepts actualArgs[i] = allocValue();
for (auto & i : fun.lambda.fun->formals->formals) { mkAttrs(*actualArgs[i], std::max(arg.formals->formals.size(), static_cast<size_t>(args.size())));
Bindings::iterator j = args.find(i.name);
if (j != args.end()) { if (arg.formals->ellipsis) {
actualArgs->attrs->push_back(*j); /* If the formals have an ellipsis (i.e. the function
} else if (!i.def) { accepts extra args), pass all available automatic
throwMissingArgumentError(i.pos, R"(cannot evaluate a function that has an argument without a value ('%1%') arguments. */
for (auto & v : args)
actualArgs[i]->attrs->push_back(v);
} else {
/* Otherwise, only pass the arguments that the function
accepts. */
for (auto & j : arg.formals->formals) {
if (auto attr = args.get(j.name))
actualArgs[i]->attrs->push_back(*attr);
else if (!j.def)
throwMissingArgumentError(j.pos, R"(cannot evaluate a function that has an argument without a value ('%1%')
Nix attempted to evaluate a function as a top level expression; in Nix attempted to evaluate a function as a top level expression; in
this case it must have its arguments supplied either by default this case it must have its arguments supplied either by default
values, or passed explicitly with '--arg' or '--argstr'. See values, or passed explicitly with '--arg' or '--argstr'. See
https://nixos.org/manual/nix/stable/#ss-functions.)", i.name); https://nixos.org/manual/nix/stable/#ss-functions.)", j.name);
} }
} }
actualArgs[i]->attrs->sort();
} }
actualArgs->attrs->sort(); callFunction(fun, fun.lambda.fun->args.size(), actualArgs, res, noPos);
callFunction(fun, *actualArgs, res, noPos);
} }

View file

@ -230,8 +230,13 @@ static Flake getFlake(
if (auto outputs = vInfo.attrs->get(sOutputs)) { if (auto outputs = vInfo.attrs->get(sOutputs)) {
expectType(state, nFunction, *outputs->value, *outputs->pos); expectType(state, nFunction, *outputs->value, *outputs->pos);
if (outputs->value->isLambda() && outputs->value->lambda.fun->hasFormals()) { if (outputs->value->lambda.fun->args.size() != 1)
for (auto & formal : outputs->value->lambda.fun->formals->formals) { throw Error("the 'outputs' attribute of flake '%s' is not a unary function", lockedRef);
auto & arg = outputs->value->lambda.fun->args[0];
if (arg.formals) {
for (auto & formal : arg.formals->formals) {
if (formal.name != state.sSelf) if (formal.name != state.sSelf)
flake.inputs.emplace(formal.name, FlakeInput { flake.inputs.emplace(formal.name, FlakeInput {
.ref = parseFlakeRef(formal.name) .ref = parseFlakeRef(formal.name)

View file

@ -124,23 +124,26 @@ void ExprList::show(std::ostream & str) const
void ExprLambda::show(std::ostream & str) const void ExprLambda::show(std::ostream & str) const
{ {
str << "("; str << "(";
if (hasFormals()) { for (auto & arg : args) {
str << "{ "; if (arg.formals) {
bool first = true; str << "{ ";
for (auto & i : formals->formals) { bool first = true;
if (first) first = false; else str << ", "; for (auto & i : arg.formals->formals) {
str << i.name; if (first) first = false; else str << ", ";
if (i.def) str << " ? " << *i.def; str << i.name;
if (i.def) str << " ? " << *i.def;
}
if (arg.formals->ellipsis) {
if (!first) str << ", ";
str << "...";
}
str << " }";
if (!arg.arg.empty()) str << " @ ";
} }
if (formals->ellipsis) { if (!arg.arg.empty()) str << arg.arg;
if (!first) str << ", "; str << ": ";
str << "...";
}
str << " }";
if (!arg.empty()) str << " @ ";
} }
if (!arg.empty()) str << arg; str << *body << ")";
str << ": " << *body << ")";
} }
void ExprCall::show(std::ostream & str) const void ExprCall::show(std::ostream & str) const
@ -279,8 +282,7 @@ void ExprVar::bindVars(const StaticEnv & env)
if (curEnv->isWith) { if (curEnv->isWith) {
if (withLevel == -1) withLevel = level; if (withLevel == -1) withLevel = level;
} else { } else {
auto i = curEnv->find(name); if (auto i = curEnv->get(name)) {
if (i != curEnv->vars.end()) {
fromWith = false; fromWith = false;
this->level = level; this->level = level;
displ = i->second; displ = i->second;
@ -354,25 +356,48 @@ void ExprList::bindVars(const StaticEnv & env)
void ExprLambda::bindVars(const StaticEnv & env) void ExprLambda::bindVars(const StaticEnv & env)
{ {
StaticEnv newEnv( /* The parser adds arguments in reverse order. Let's fix that
false, &env, now. */
(hasFormals() ? formals->formals.size() : 0) + std::reverse(args.begin(), args.end());
(arg.empty() ? 0 : 1));
envSize = 0;
for (auto & arg :args) {
if (!arg.arg.empty()) envSize++;
if (arg.formals) envSize += arg.formals->formals.size();
}
StaticEnv newEnv(false, &env, envSize);
Displacement displ = 0; Displacement displ = 0;
if (!arg.empty()) newEnv.vars.emplace_back(arg, displ++); for (auto & arg : args) {
if (!arg.arg.empty()) {
if (auto i = const_cast<StaticEnv::Vars::value_type *>(newEnv.get(arg.arg)))
i->second = displ++;
else
newEnv.vars.emplace_back(arg.arg, displ++);
}
if (hasFormals()) { if (arg.formals) {
for (auto & i : formals->formals) for (auto & i : arg.formals->formals) {
newEnv.vars.emplace_back(i.name, displ++); if (auto j = const_cast<StaticEnv::Vars::value_type *>(newEnv.get(i.name)))
j->second = displ++;
else
newEnv.vars.emplace_back(i.name, displ++);
}
newEnv.sort(); newEnv.sort();
for (auto & i : formals->formals) for (auto & i : arg.formals->formals)
if (i.def) i.def->bindVars(newEnv); if (i.def) i.def->bindVars(newEnv);
}
} }
assert(displ == envSize);
newEnv.sort();
body->bindVars(newEnv); body->bindVars(newEnv);
} }

View file

@ -233,21 +233,24 @@ struct ExprLambda : Expr
{ {
Pos pos; Pos pos;
Symbol name; Symbol name;
Symbol arg;
Formals * formals; struct Arg
Expr * body;
ExprLambda(const Pos & pos, const Symbol & arg, Formals * formals, Expr * body)
: pos(pos), arg(arg), formals(formals), body(body)
{ {
if (!arg.empty() && formals && formals->argNames.find(arg) != formals->argNames.end()) Symbol arg;
throw ParseError({ Formals * formals;
.msg = hintfmt("duplicate formal function argument '%1%'", arg),
.errPos = pos
});
}; };
std::vector<Arg> args;
Expr * body;
Displacement envSize = 0; // initialized by bindVars()
ExprLambda(const Pos & pos, Expr * body)
: pos(pos), body(body)
{ };
void setName(Symbol & name); void setName(Symbol & name);
string showNamePos() const; string showNamePos() const;
inline bool hasFormals() const { return formals != nullptr; }
COMMON_METHODS COMMON_METHODS
}; };
@ -368,12 +371,12 @@ struct StaticEnv
[](const Vars::value_type & a, const Vars::value_type & b) { return a.first < b.first; }); [](const Vars::value_type & a, const Vars::value_type & b) { return a.first < b.first; });
} }
Vars::const_iterator find(const Symbol & name) const const Vars::value_type * get(const Symbol & name) const
{ {
Vars::value_type key(name, 0); Vars::value_type key(name, 0);
auto i = std::lower_bound(vars.begin(), vars.end(), key); auto i = std::lower_bound(vars.begin(), vars.end(), key);
if (i != vars.end() && i->first == name) return i; if (i != vars.end() && i->first == name) return &*i;
return vars.end(); return {};
} }
}; };

View file

@ -160,6 +160,24 @@ static void addFormal(const Pos & pos, Formals * formals, const Formal & formal)
} }
static Expr * addArg(const Pos & pos, Expr * e, ExprLambda::Arg && arg)
{
if (!arg.arg.empty() && arg.formals && arg.formals->argNames.count(arg.arg))
throw ParseError({
.msg = hintfmt("duplicate formal function argument '%1%'", arg.arg),
.errPos = pos
});
auto e2 = dynamic_cast<ExprLambda *>(e); // FIXME: slow?
if (!e2)
e2 = new ExprLambda(pos, e);
else
e2->pos = pos;
e2->args.emplace_back(std::move(arg));
return e2;
}
static Expr * stripIndentation(const Pos & pos, SymbolTable & symbols, vector<Expr *> & es) static Expr * stripIndentation(const Pos & pos, SymbolTable & symbols, vector<Expr *> & es)
{ {
if (es.empty()) return new ExprString(symbols.create("")); if (es.empty()) return new ExprString(symbols.create(""));
@ -332,13 +350,13 @@ expr: expr_function;
expr_function expr_function
: ID ':' expr_function : ID ':' expr_function
{ $$ = new ExprLambda(CUR_POS, data->symbols.create($1), 0, $3); } { $$ = addArg(CUR_POS, $3, {data->symbols.create($1), nullptr}); }
| '{' formals '}' ':' expr_function | '{' formals '}' ':' expr_function
{ $$ = new ExprLambda(CUR_POS, data->symbols.create(""), $2, $5); } { $$ = addArg(CUR_POS, $5, {data->state.sEpsilon, $2}); }
| '{' formals '}' '@' ID ':' expr_function | '{' formals '}' '@' ID ':' expr_function
{ $$ = new ExprLambda(CUR_POS, data->symbols.create($5), $2, $7); } { $$ = addArg(CUR_POS, $7, {data->symbols.create($5), $2}); }
| ID '@' '{' formals '}' ':' expr_function | ID '@' '{' formals '}' ':' expr_function
{ $$ = new ExprLambda(CUR_POS, data->symbols.create($1), $4, $7); } { $$ = addArg(CUR_POS, $7, {data->symbols.create($1), $4}); }
| ASSERT expr ';' expr_function | ASSERT expr ';' expr_function
{ $$ = new ExprAssert(CUR_POS, $2, $4); } { $$ = new ExprAssert(CUR_POS, $2, $4); }
| WITH expr ';' expr_function | WITH expr ';' expr_function
@ -456,7 +474,7 @@ expr_simple
string_parts string_parts
: STR : STR
| string_parts_interpolated { $$ = new ExprConcatStrings(CUR_POS, true, $1); } | string_parts_interpolated { $$ = new ExprConcatStrings(CUR_POS, true, $1); }
| { $$ = new ExprString(data->symbols.create("")); } | { $$ = new ExprString(data->state.sEpsilon); }
; ;
string_parts_interpolated string_parts_interpolated

View file

@ -2386,23 +2386,38 @@ static RegisterPrimOp primop_catAttrs({
static void prim_functionArgs(EvalState & state, const Pos & pos, Value * * args, Value & v) static void prim_functionArgs(EvalState & state, const Pos & pos, Value * * args, Value & v)
{ {
state.forceValue(*args[0], pos); state.forceValue(*args[0], pos);
if (args[0]->isPrimOpApp() || args[0]->isPrimOp()) { if (args[0]->isPrimOpApp() || args[0]->isPrimOp()) {
state.mkAttrs(v, 0); state.mkAttrs(v, 0);
return; return;
} }
if (!args[0]->isLambda())
if (!args[0]->isLambda() && !args[0]->isPartialApp())
throw TypeError({ throw TypeError({
.msg = hintfmt("'functionArgs' requires a function"), .msg = hintfmt("'functionArgs' requires a function"),
.errPos = pos .errPos = pos
}); });
if (!args[0]->lambda.fun->hasFormals()) { size_t argsDone = 0;
auto lambda = args[0];
while (lambda->isPartialApp()) {
argsDone++;
lambda = lambda->app.left;
}
assert(lambda->isLambda());
assert(argsDone < lambda->lambda.fun->args.size());
// FIXME: handle partially applied functions
auto formals = lambda->lambda.fun->args[argsDone].formals;
if (!formals) {
state.mkAttrs(v, 0); state.mkAttrs(v, 0);
return; return;
} }
state.mkAttrs(v, args[0]->lambda.fun->formals->formals.size()); state.mkAttrs(v, formals->formals.size());
for (auto & i : args[0]->lambda.fun->formals->formals) { for (auto & i : formals->formals) {
// !!! should optimise booleans (allocate only once) // !!! should optimise booleans (allocate only once)
Value * value = state.allocValue(); Value * value = state.allocValue();
v.attrs->push_back(Attr(i.name, value, ptr(&i.pos))); v.attrs->push_back(Attr(i.name, value, ptr(&i.pos)));

View file

@ -126,24 +126,28 @@ static void printValueAsXML(EvalState & state, bool strict, bool location,
} }
case nFunction: { case nFunction: {
if (!v.isLambda()) { if (!v.isLambda()) {
// FIXME: Serialize primops and primopapps // FIXME: Serialize primops and partial apps
doc.writeEmptyElement("unevaluated"); doc.writeEmptyElement("unevaluated");
break; break;
} }
XMLAttrs xmlAttrs; XMLAttrs xmlAttrs;
if (location) posToXML(xmlAttrs, v.lambda.fun->pos); if (location) posToXML(xmlAttrs, v.lambda.fun->pos);
XMLOpenElement _(doc, "function", xmlAttrs); XMLOpenElement _(doc, "function", xmlAttrs);
if (v.lambda.fun->hasFormals()) { auto & arg = v.lambda.fun->args[0];
if (arg.formals) {
XMLAttrs attrs; XMLAttrs attrs;
if (!v.lambda.fun->arg.empty()) attrs["name"] = v.lambda.fun->arg; if (arg.arg != state.sEpsilon) attrs["name"] = arg.arg;
if (v.lambda.fun->formals->ellipsis) attrs["ellipsis"] = "1"; if (arg.formals->ellipsis) attrs["ellipsis"] = "1";
XMLOpenElement _(doc, "attrspat", attrs); XMLOpenElement _(doc, "attrspat", attrs);
for (auto & i : v.lambda.fun->formals->formals) for (auto & i : arg.formals->formals)
doc.writeEmptyElement("attr", singletonAttrs("name", i.name)); doc.writeEmptyElement("attr", singletonAttrs("name", i.name));
} else } else
doc.writeEmptyElement("varpat", singletonAttrs("name", v.lambda.fun->arg)); doc.writeEmptyElement("varpat", singletonAttrs("name", arg.arg));
break; break;
} }

View file

@ -21,6 +21,7 @@ typedef enum {
tListN, tListN,
tThunk, tThunk,
tApp, tApp,
tPartialApp,
tLambda, tLambda,
tBlackhole, tBlackhole,
tPrimOp, tPrimOp,
@ -125,6 +126,7 @@ public:
// type() == nFunction // type() == nFunction
inline bool isLambda() const { return internalType == tLambda; }; inline bool isLambda() const { return internalType == tLambda; };
inline bool isPartialApp() const { return internalType == tPartialApp; };
inline bool isPrimOp() const { return internalType == tPrimOp; }; inline bool isPrimOp() const { return internalType == tPrimOp; };
inline bool isPrimOpApp() const { return internalType == tPrimOpApp; }; inline bool isPrimOpApp() const { return internalType == tPrimOpApp; };
@ -196,7 +198,7 @@ public:
case tNull: return nNull; case tNull: return nNull;
case tAttrs: return nAttrs; case tAttrs: return nAttrs;
case tList1: case tList2: case tListN: return nList; case tList1: case tList2: case tListN: return nList;
case tLambda: case tPrimOp: case tPrimOpApp: return nFunction; case tLambda: case tPartialApp: case tPrimOp: case tPrimOpApp: return nFunction;
case tExternal: return nExternal; case tExternal: return nExternal;
case tFloat: return nFloat; case tFloat: return nFloat;
case tThunk: case tApp: case tBlackhole: return nThunk; case tThunk: case tApp: case tBlackhole: return nThunk;
@ -307,6 +309,13 @@ public:
app.right = r; app.right = r;
} }
inline void mkPartialApp(Value * l, Value * r)
{
internalType = tPartialApp;
app.left = l;
app.right = r;
}
inline void mkExternal(ExternalValueBase * e) inline void mkExternal(ExternalValueBase * e)
{ {
clearValue(); clearValue();

View file

@ -355,14 +355,12 @@ struct CmdFlakeCheck : FlakeCommand
try { try {
state->forceValue(v, pos); state->forceValue(v, pos);
if (!v.isLambda() if (!v.isLambda()
|| v.lambda.fun->hasFormals() || v.lambda.fun->args.size() != 2
|| !argHasName(v.lambda.fun->arg, "final")) || v.lambda.fun->args[0].formals
throw Error("overlay does not take an argument named 'final'"); || !argHasName(v.lambda.fun->args[0].arg, "final")
auto body = dynamic_cast<ExprLambda *>(v.lambda.fun->body); || v.lambda.fun->args[1].formals
if (!body || !argHasName(v.lambda.fun->args[1].arg, "prev"))
|| body->hasFormals() throw Error("overlay is not a binary function with arguments 'final' and 'prev'");
|| !argHasName(body->arg, "prev"))
throw Error("overlay does not take an argument named 'prev'");
// FIXME: if we have a 'nixpkgs' input, use it to // FIXME: if we have a 'nixpkgs' input, use it to
// evaluate the overlay. // evaluate the overlay.
} catch (Error & e) { } catch (Error & e) {
@ -375,7 +373,9 @@ struct CmdFlakeCheck : FlakeCommand
try { try {
state->forceValue(v, pos); state->forceValue(v, pos);
if (v.isLambda()) { if (v.isLambda()) {
if (!v.lambda.fun->hasFormals() || !v.lambda.fun->formals->ellipsis) if (v.lambda.fun->args.size() != 1
|| !v.lambda.fun->args[0].formals
|| !v.lambda.fun->args[0].formals->ellipsis)
throw Error("module must match an open attribute set ('{ config, ... }')"); throw Error("module must match an open attribute set ('{ config, ... }')");
} else if (v.type() == nAttrs) { } else if (v.type() == nAttrs) {
for (auto & attr : *v.attrs) for (auto & attr : *v.attrs)
@ -473,12 +473,12 @@ struct CmdFlakeCheck : FlakeCommand
auto checkBundler = [&](const std::string & attrPath, Value & v, const Pos & pos) { auto checkBundler = [&](const std::string & attrPath, Value & v, const Pos & pos) {
try { try {
state->forceValue(v, pos); state->forceValue(v, pos);
if (!v.isLambda()) if (!v.isLambda()
throw Error("bundler must be a function"); || v.lambda.fun->args.size() != 1
if (!v.lambda.fun->formals || || !v.lambda.fun->args[0].formals
!v.lambda.fun->formals->argNames.count(state->symbols.create("program")) || || !v.lambda.fun->args[0].formals->argNames.count(state->symbols.create("program"))
!v.lambda.fun->formals->argNames.count(state->symbols.create("system"))) || !v.lambda.fun->args[0].formals->argNames.count(state->symbols.create("system")))
throw Error("bundler must take formal arguments 'program' and 'system'"); throw Error("bundler must be a function that takes take arguments 'program' and 'system'");
} catch (Error & e) { } catch (Error & e) {
e.addTrace(pos, hintfmt("while checking the template '%s'", attrPath)); e.addTrace(pos, hintfmt("while checking the template '%s'", attrPath));
reportError(e); reportError(e);