1 /// Copyright: Copyright (c) 2017-2019 Andrey Penechko.
2 /// License: $(WEB boost.org/LICENSE_1_0.txt, Boost License 1.0).
3 /// Authors: Andrey Penechko.
4 
/// Constant folding and compile-time function evaluation (CTFE).
/// Evaluated nodes must be type checked; eval_static_expr runs any missing analysis passes on demand.
7 module vox.fe.passes.eval;
8 
9 import vox.all;
10 
/// Evaluates the expression at `nodeIndex` at compile time.
///
/// If the node is not yet type checked, the missing analysis passes
/// (name register, name resolve, type check) are run first, calling
/// `throwOnErrors` after each of them. Nodes in any other unexpected
/// state are reported as an internal error.
///
/// Returns: an IR constant holding the value. Type, function and struct
/// declarations evaluate to an i32 constant holding their AST storage index.
IrIndex eval_static_expr(AstIndex nodeIndex, CompilationContext* context)
{
	AstNode* node = context.getAstNode(nodeIndex);
	// Must be checked before the first dereference (`node.state` below)
	context.assertf(node !is null, "null node");

	switch(node.state) with(AstNodeState)
	{
		case name_register_self_done:
			require_name_register(nodeIndex, context);
			context.throwOnErrors;
			goto case;
		case name_register_nested_done:
			require_name_resolve(nodeIndex, context);
			context.throwOnErrors;
			goto case;
		case name_resolve_done:
			// perform type checking of forward referenced node
			require_type_check(nodeIndex, context);
			context.throwOnErrors;
			break;
		case type_check_done: break; // all requirements are done
		default: context.internal_error(node.loc, "Node %s in %s state", node.astType, node.state);
	}

	switch (node.astType) with(AstType)
	{
		case decl_enum_member: return node.as!EnumMemberDecl(context).gen_init_value_enum_member(context);
		case expr_name_use: return eval_static_expr_name_use(cast(NameUseExprNode*)node, context);
		case expr_member: return eval_static_expr_member(cast(MemberExprNode*)node, context);
		case expr_bin_op: return eval_static_expr_bin_op(cast(BinaryExprNode*)node, context);
		case expr_un_op: return eval_static_expr_un_op(cast(UnaryExprNode*)node, context);
		case expr_type_conv: return eval_type_conv(cast(TypeConvExprNode*)node, eval_static_expr((cast(TypeConvExprNode*)node).expr, context), context);
		case expr_call: return eval_static_expr_call(cast(CallExprNode*)node, context);
		case literal_int: return ir_gen_literal_int(context, cast(IntLiteralExprNode*)node);
		case literal_float: return ir_gen_literal_float(context, cast(FloatLiteralExprNode*)node);
		case literal_string: return ir_gen_literal_string(context, cast(StringLiteralExprNode*)node);
		case literal_null: return ir_gen_literal_null(context, cast(NullLiteralExprNode*)node);
		case literal_bool: return ir_gen_literal_bool(context, cast(BoolLiteralExprNode*)node);
		case literal_special: return ir_gen_literal_special(context, cast(SpecialLiteralExprNode*)node);

		// Types and declarations are encoded as an i32 constant holding their AST storage index
		case type_basic, type_ptr, type_slice, type_static_array, decl_function, decl_struct:
			return context.constants.add(makeIrType(IrBasicType.i32), nodeIndex.storageIndex);
		default:
			context.internal_error(node.loc, "Cannot evaluate static expression %s", node.astType);
	}
}
60 
/// Evaluates the expression at `nodeIndex`, whose type must be $alias or $type,
/// and decodes the resulting i32 constant back into an AstIndex.
/// Any mismatch (wrong expression type, non-constant or non-i32 value) is an internal error.
AstIndex eval_static_expr_alias(AstIndex nodeIndex, CompilationContext* c)
{
	IrIndex value = eval_static_expr(nodeIndex, c);
	AstNode* exprNode = c.getAstNode(nodeIndex);

	AstIndex exprType = nodeIndex.get_node_alias(c);
	bool isAliasLike = exprType == CommonAstNodes.type_alias || exprType == CommonAstNodes.type_type;
	if (!isAliasLike)
		c.internal_error(exprNode.loc, "Cannot evaluate static expression %s as $alias", exprNode.astType);

	if (!value.isSomeConstant)
		c.internal_error(exprNode.loc, "Cannot obtain $alias from %s", value);

	IrConstant constant = c.constants.get(value);
	c.assertf(constant.type == makeIrType(IrBasicType.i32), exprNode.loc, "Cannot obtain $alias from %s", value);
	return AstIndex(constant.i32);
}
76 
/// Evaluates the expression at `nodeIndex`, whose type must be $type,
/// and decodes the resulting i32 constant back into the AstIndex of the type.
/// Any mismatch (wrong expression type, non-constant or non-i32 value) is an internal error.
AstIndex eval_static_expr_type(AstIndex nodeIndex, CompilationContext* c)
{
	IrIndex value = eval_static_expr(nodeIndex, c);
	AstNode* exprNode = c.getAstNode(nodeIndex);

	bool isTypeExpr = nodeIndex.get_node_alias(c) == CommonAstNodes.type_type;
	if (!isTypeExpr)
		c.internal_error(exprNode.loc, "Cannot evaluate static expression %s as $type", exprNode.astType);

	if (!value.isSomeConstant)
		c.internal_error(exprNode.loc, "Cannot obtain $type from %s", value);

	IrConstant constant = c.constants.get(value);
	c.assertf(constant.type == makeIrType(IrBasicType.i32), exprNode.loc, "Cannot obtain $type from %s", value);
	return AstIndex(constant.i32);
}
90 
/// Evaluates a name use by evaluating the entity the name refers to.
IrIndex eval_static_expr_name_use(NameUseExprNode* node, CompilationContext* context)
{
	AstIndex referencedEntity = node.entity;
	return eval_static_expr(referencedEntity, context);
}
95 
/// Evaluates a member access expression at compile time.
/// Supported: builtin members (via eval_builtin_member), enum members, and the
/// length of an alias array (as an i64 constant). Other member kinds raise an
/// unrecoverable error.
IrIndex eval_static_expr_member(MemberExprNode* node, CompilationContext* c)
{
	switch(node.subType) with(MemberSubType)
	{
		case builtin_member:
			auto builtinNode = node.member(c).get!BuiltinNode(c);
			return eval_builtin_member(builtinNode.builtin, node.aggregate, node.loc, c);
		case enum_member:
			AstIndex enumMemberIndex = node.member(c);
			return eval_static_expr(enumMemberIndex, c);
		case alias_array_length:
			auto aliasArray = node.aggregate.get_effective_node(c).get!AliasArrayDeclNode(c);
			return c.constants.add(makeIrType(IrBasicType.i64), aliasArray.items.length);
		default:
			AstIndex memberExprIndex = get_ast_index(node, c);
			c.unrecoverable_error(node.loc,
				"Cannot access .%s member of %s while in CTFE (%s)",
				c.idString(get_node_id(memberExprIndex, c)),
				get_node_kind_name(memberExprIndex, c),
				cast(MemberSubType)node.subType);
	}
}
116 
/// Evaluates builtin member `builtin` applied to `obj` at compile time.
/// Handles int_min/int_max (typed as the basic type), array_length, type_sizeof
/// and type_offsetof (all as i64). Any other builtin raises an unrecoverable
/// error reported at `loc`.
IrIndex eval_builtin_member(BuiltinId builtin, AstIndex obj, TokenIndex loc, CompilationContext* c)
{
	AstIndex objType = obj.get_node_type(c);
	switch(builtin) with(BuiltinId)
	{
		case int_min:
			auto minBasicType = objType.get!BasicTypeNode(c);
			return c.constants.add(gen_ir_type_basic(minBasicType, c), minBasicType.minValue);
		case int_max:
			auto maxBasicType = objType.get!BasicTypeNode(c);
			return c.constants.add(gen_ir_type_basic(maxBasicType, c), maxBasicType.maxValue);
		case array_length:
			// type check the array type before reading its length
			require_type_check(objType, c);
			auto staticArrayType = objType.get!StaticArrayTypeNode(c);
			return c.constants.add(makeIrType(IrBasicType.i64), staticArrayType.length);
		case type_sizeof:
			SizeAndAlignment sizeAlign = objType.require_type_size(c);
			return c.constants.add(makeIrType(IrBasicType.i64), sizeAlign.size);
		case type_offsetof:
			// `obj` is the member expression itself; its aggregate must lower to an IR struct
			auto memberExpr = obj.get!MemberExprNode(c);
			IrIndex aggregateIrType = memberExpr.aggregate.gen_ir_type(c);
			c.assertf(aggregateIrType.isTypeStruct, "%s", aggregateIrType);
			IrTypeStructMember[] structMembers = c.types.get!IrTypeStruct(aggregateIrType).members;
			uint fieldIndex = memberExpr.memberIndex(c);
			c.assertf(structMembers.length > fieldIndex, "member index (%s) out of bounds (%s)", fieldIndex, structMembers.length);
			return c.constants.add(makeIrType(IrBasicType.i64), structMembers[fieldIndex].offset);
		default:
			c.unrecoverable_error(loc,
				"Cannot access .%s member of %s while in CTFE",
				builtinIdStrings[builtin],
				get_node_kind_name(objType, c));
	}
}
149 
/// Evaluates a binary operation at compile time.
/// `&&` and `||` short-circuit: the right operand is evaluated only when the
/// left one does not already decide the result (an i8 0 or 1 constant).
/// All other operators evaluate both operands and fold them with calcBinOp.
IrIndex eval_static_expr_bin_op(BinaryExprNode* node, CompilationContext* c)
{
	switch (node.op)
	{
		case BinOp.LOGIC_AND, BinOp.LOGIC_OR:
			IrIndex lhs = eval_static_expr(node.left, c);
			bool lhsIsTrue = c.constants.get(lhs).i64 != 0;
			if (node.op == BinOp.LOGIC_AND && !lhsIsTrue)
				return c.constants.addZeroConstant(makeIrType(IrBasicType.i8));
			if (node.op == BinOp.LOGIC_OR && lhsIsTrue)
				return c.constants.add(makeIrType(IrBasicType.i8), 1);
			return eval_static_expr(node.right, c);
		default:
			IrIndex lhs = eval_static_expr(node.left, c);
			IrIndex rhs = eval_static_expr(node.right, c);
			return calcBinOp(node.op, lhs, rhs, c);
	}
}
169 
/// Evaluates a unary operation at compile time.
/// For address-of, only a name referring to a function or a global variable is
/// supported; anything else raises an unrecoverable error. All other operators
/// evaluate the operand and fold it with calcUnOp.
IrIndex eval_static_expr_un_op(UnaryExprNode* node, CompilationContext* c)
{
	ExpressionNode* child = node.child.get_expr(c);

	if (node.op != UnOp.addrOf)
	{
		IrIndex operand = eval_static_expr(node.child, c);
		return calcUnOp(node.op, operand, c);
	}

	if (child.astType != AstType.expr_name_use)
		c.unrecoverable_error(node.loc, "Cannot take address of %s while in CTFE", child.astType);

	AstNode* entity = child.as!NameUseExprNode(c).entity.get_node(c);

	// type is not pointer to function sig, but sig itself
	if (entity.astType == AstType.decl_function)
		return entity.as!FunctionDeclNode(c).getIrIndex(c);

	if (entity.astType != AstType.decl_var)
		c.unrecoverable_error(node.loc, "Cannot take address of %s while in CTFE", entity.astType);

	// variable must be global
	auto variable = entity.as!VariableDeclNode(c);
	if (!variable.isGlobal)
		c.unrecoverable_error(node.loc, "Can only take address of global variable while in CTFE");

	ir_gen_decl_var(c, variable);
	return variable.getIrIndex(c);
}
206 
/// Evaluates a call expression at compile time.
/// Calls to functions are executed via eval_call; calls to structs are
/// evaluated as constructors. Anything else is an internal error.
IrIndex eval_static_expr_call(CallExprNode* node, CompilationContext* c)
{
	AstIndex calleeIndex = node.callee.get_effective_node(c);
	auto calleeKind = calleeIndex.astType(c);

	if (calleeKind == AstType.decl_function)
		return eval_call(node, calleeIndex, c);
	if (calleeKind == AstType.decl_struct)
		return eval_constructor(node, calleeIndex, c);

	c.internal_error(node.loc, "Cannot call %s at compile-time", calleeIndex.get_node_type(c).get_type(c).printer(c));
}
221 
/// Evaluates a struct constructor call `S(args...)` at compile time.
///
/// With no arguments, the struct's default initializer value is returned.
/// Otherwise, member variables (the `decl_var` declarations of the struct, in
/// declaration order) are filled from the call arguments first; the remaining
/// members use their own initializers.
/// Returns: an aggregate constant of the struct's IR type.
IrIndex eval_constructor(CallExprNode* node, AstIndex callee, CompilationContext* c)
{
	StructDeclNode* s = callee.get!StructDeclNode(c);

	if (node.args.length == 0) {
		return s.gen_init_value_struct(c);
	}

	IrIndex structType = s.gen_ir_type_struct(c);
	uint numStructMembers = c.types.get!IrTypeStruct(structType).numMembers;
	IrIndex[] args = c.allocateTempArray!IrIndex(numStructMembers);
	scope(exit) c.freeTempArray(args);

	uint memberIndex;
	foreach(AstIndex member; s.declarations)
	{
		AstNode* memberVarNode = member.get_node(c);
		// only variable declarations contribute values; skip any other declaration
		if (memberVarNode.astType != AstType.decl_var) continue;
		VariableDeclNode* memberVar = memberVarNode.as!VariableDeclNode(c);

		if (node.args.length > memberIndex) { // init from constructor argument
			args[memberIndex] = eval_static_expr(node.args[memberIndex], c);
		} else { // init with initializer from struct definition
			args[memberIndex] = memberVar.gen_init_value_var(c);
		}

		++memberIndex;
	}

	return c.constants.addAggrecateConstant(structType, args);
}
256 
/// Makes sure IR has been generated for function `callee` (stored at `calleeIndex`),
/// so that it can be executed at compile time.
///
/// Runs whichever analysis passes the function has not completed yet
/// (name register self/nested, name resolve, type check), calling
/// `throwOnErrors` after each one, and then generates its IR.
/// Returns immediately if the IR is already generated.
/// Reports a circular dependency if a pass is currently running for the function.
void force_callee_ir_gen(FunctionDeclNode* callee, AstIndex calleeIndex, CompilationContext* c)
{
	switch(callee.state) with(AstNodeState)
	{
		// A pass is in progress for this function: requesting its IR body now is a cycle
		case name_register_self, name_register_nested, name_resolve, type_check, ir_gen:
			c.circular_dependency(calleeIndex, NodeProperty.ir_body);
		// Each following state runs the next pending pass and continues via `goto case`
		case parse_done:
			auto name_state = NameRegisterState(c);
			require_name_register_self(0, calleeIndex, name_state);
			c.throwOnErrors;
			goto case;
		case name_register_self_done:
			auto name_state = NameRegisterState(c);
			require_name_register(calleeIndex, name_state);
			c.throwOnErrors;
			goto case;
		case name_register_nested_done:
			require_name_resolve(calleeIndex, c);
			c.throwOnErrors;
			goto case;
		case name_resolve_done:
			require_type_check(calleeIndex, c);
			c.throwOnErrors;
			goto case;
		case type_check_done:
			break; // all requirement are done
		case ir_gen_done: return; // already has IR
		default: c.internal_error(callee.loc, "Node %s in %s state", callee.astType, callee.state);
	}

	// Record this node as being analysed while its IR is generated.
	// NOTE(review): popped only on success (scope(success)); on an exception the entry stays pushed — confirm that is intended.
	c.push_analized_node(AnalysedNode(calleeIndex, NodeProperty.ir_body));
	scope(success) c.pop_analized_node;

	IrGenState state = IrGenState(c);
	ir_gen_function(state, callee);
}
293 
/// Executes a call to function `callee` at compile time (CTFE).
///
/// IR for the callee is generated on demand. Call arguments are evaluated to
/// constants; parameters without an argument use their default values.
/// Builtin functions are handled by eval_call_builtin; other functions are run
/// in the IR VM and their return value is converted back to a constant.
/// Calling a function that returns void is an assertion failure.
IrIndex eval_call(CallExprNode* node, AstIndex callee, CompilationContext* c)
{
	auto func = callee.get!FunctionDeclNode(c);

	force_callee_ir_gen(func, callee, c);
	if (func.state != AstNodeState.ir_gen_done)
		c.internal_error(node.loc, "Function's IR is not yet generated");

	auto signature = func.signature.get!FunctionSignatureNode(c);
	uint numArgs = node.args.length;
	uint numParams = signature.parameters.length;

	IrIndex[] argValues = c.allocateTempArray!IrIndex(numParams);
	scope(exit) c.freeTempArray(argValues);

	// explicitly passed arguments, left to right
	foreach (argIndex, AstIndex argExpr; node.args)
		argValues[argIndex] = eval_static_expr(argExpr, c);

	// trailing parameters without an argument take their default value
	foreach (paramIndex; numArgs..numParams)
	{
		VariableDeclNode* param = c.getAst!VariableDeclNode(signature.parameters[paramIndex]);
		c.assertf(param.initializer.isDefined, param.loc, "Undefined default arg %s", c.idString(param.id));
		argValues[paramIndex] = param.gen_init_value_var(c);
	}

	if (func.isBuiltin)
		return eval_call_builtin(node.loc, null, callee, argValues, c);

	IrFunction* irData = c.getAst!IrFunction(func.backendData.irData);
	ubyte* vmBuffer = c.vmBuffer.bufPtr;

	IrIndex retType = c.types.getReturnType(irData.type, c);
	c.assertf(!retType.isTypeVoid, node.loc, "Cannot eval call to function returning void");

	// reserve VM stack memory for the return value
	uint returnSize = c.types.typeSize(retType);
	IrVmSlotInfo returnMem = c.pushVmStack(returnSize);

	IrVm vm = IrVm(c, irData);
	vm.pushFrame;

	// copy argument constants into the parameter slots of the new frame
	foreach (uint slotIndex, IrVmSlotInfo paramSlot; vm.parameters)
	{
		ubyte[] paramMem = vmBuffer[paramSlot.offset..paramSlot.offset+paramSlot.length];
		constantToMem(paramMem, argValues[slotIndex], c);
	}

	vm.run(returnMem);

	ubyte[] returnBytes = vmBuffer[returnMem.offset..returnMem.offset+returnMem.length];
	IrIndex returnValue = memToConstant(returnBytes, retType, c);

	vm.popFrame;
	c.popVmStack(returnMem);
	return returnValue;
}
353 
/// Evaluates a call to a builtin function at compile time.
///
/// Params:
///   loc    = location of the call, used for error reporting
///   vm     = VM the argument values are read through; may be `null` (eval_call passes `null`)
///   callee = AST index of the builtin function declaration
///   args   = evaluated call arguments
///   c      = compilation context
/// Returns: the result constant. The predicates `is_slice`, `is_integer` and
/// `is_pointer` produce an i8 0/1. `base_of` produces an i32 holding the AST
/// storage index of the base type of a pointer, slice or static array type,
/// or 0 when there is none. `compile_error` never returns.
IrIndex eval_call_builtin(TokenIndex loc, IrVm* vm, AstIndex callee, IrIndex[] args, CompilationContext* c)
{
	switch (callee.storageIndex) {
		case CommonAstNodes.compile_error.storageIndex:
			// report the string argument as an unrecoverable error at the call site
			c.unrecoverable_error(loc, "%s", stringFromIrValue(vm, args[0], c));
		case CommonAstNodes.is_slice.storageIndex:
			AstIndex nodeIndex = astIndexFromIrValue(vm, args[0], c);
			if (nodeIndex.isUndefined) return c.constants.addZeroConstant(makeIrType(IrBasicType.i8));
			AstNode* node = c.getAstNode(nodeIndex);
			if (node.astType != AstType.type_slice) return c.constants.addZeroConstant(makeIrType(IrBasicType.i8));
			return c.constants.add(makeIrType(IrBasicType.i8), 1);
		case CommonAstNodes.is_integer.storageIndex:
			AstIndex nodeIndex = astIndexFromIrValue(vm, args[0], c);
			if (nodeIndex.isUndefined) return c.constants.addZeroConstant(makeIrType(IrBasicType.i8));
			AstNode* node = c.getAstNode(nodeIndex);
			if (node.astType != AstType.type_basic) return c.constants.addZeroConstant(makeIrType(IrBasicType.i8));
			return c.constants.add(makeIrType(IrBasicType.i8), cast(ubyte)node.as!BasicTypeNode(c).isInteger);
		case CommonAstNodes.is_pointer.storageIndex:
			AstIndex nodeIndex = astIndexFromIrValue(vm, args[0], c);
			if (nodeIndex.isUndefined) return c.constants.addZeroConstant(makeIrType(IrBasicType.i8));
			return c.constants.add(makeIrType(IrBasicType.i8), nodeIndex.astType(c) == AstType.type_ptr);
		case CommonAstNodes.base_of.storageIndex:
			AstIndex nodeIndex = astIndexFromIrValue(vm, args[0], c);
			if (nodeIndex.isUndefined) return c.constants.addZeroConstant(makeIrType(IrBasicType.i32));
			AstNode* node = c.getAstNode(nodeIndex);
			AstIndex baseType;
			switch(node.astType) {
				case AstType.type_ptr: baseType = node.as!PtrTypeNode(c).base; break;
				case AstType.type_slice: baseType = node.as!SliceTypeNode(c).base; break;
				case AstType.type_static_array: baseType = node.as!StaticArrayTypeNode(c).base; break;
				default: return c.constants.addZeroConstant(makeIrType(IrBasicType.i32));
			}
			baseType = baseType.get_effective_node(c);
			return c.constants.add(makeIrType(IrBasicType.i32), baseType.storageIndex);
		default:
			// pass the call location so the diagnostic points at the offending call
			c.internal_error(loc, "Unknown builtin function %s", c.idString(callee.get_node_id(c)));
	}
}