1 /// Copyright: Copyright (c) 2017-2019 Andrey Penechko.
2 /// License: $(WEB boost.org/LICENSE_1_0.txt, Boost License 1.0).
3 /// Authors: Andrey Penechko.
4 module vox.fe.ast.expr.member_access;
5 
6 import vox.all;
7 
/// What a MemberExprNode refers to after member lookup.
/// `unresolved` is the zero value; every other value is assigned by MemberExprNode.resolve
/// or by the pre-resolved constructor.
enum MemberSubType
{
	unresolved           = 0,
	builtin_member       = 1, // member is decl_builtin (sizeof, offsetof, min/max, static array ptr/length)
	static_struct_member = 2, // non-instance struct member: static variable, nested enum/alias/struct
	struct_member        = 3, // instance field of a struct
	struct_method        = 4, // function declared inside a struct
	struct_templ_method  = 5, // template declared inside a struct (its body must be a function)
	enum_member          = 6, // member of an enum
	slice_member         = 7, // `.ptr` / `.length` of a slice
	alias_array_length   = 8, // length of an alias array
}
20 
/// Node-specific flags for MemberExprNode, stored in the node's `flags` field.
/// Values start at AstFlags.userFlag (each flag is a further shift of it).
enum MemberExprFlags : ushort {
	// Set by lookupMember when the aggregate is a pointer to a struct; lowerMember
	// then wraps the aggregate in an explicit dereference.
	needsDeref = AstFlags.userFlag << 0,
}
24 
// Expression of the `aggregate.member` form.
// loc points to the member identifier, not to the aggregate.
// Before resolution only the member name is known (_memberId); resolution overwrites
// that storage (the union below) with the resolved member node and its index.
@(AstType.expr_member)
struct MemberExprNode {
	mixin ExpressionNodeData!(AstType.expr_member);

	ScopeIndex parentScope; // set in parser; also used for UFCS lookup
	AstIndex aggregate;     // expression to the left of the dot
	union
	{
		// when unresolved
		struct {
			private Identifier _memberId; // member name before resolution
		}
		// when resolved
		struct {
			private AstIndex _member; // member node after resolution
			private uint _memberIndex; // resolved index of member being accessed. TODO: remove
		}
	}

	/// Whether a sub-type (and therefore `_member`) has been assigned.
	bool isSymResolved() { return MemberSubType.unresolved != subType; }

	/// Whether the aggregate is a struct pointer that must be dereferenced during lowering.
	bool needsDeref() { return (flags & MemberExprFlags.needsDeref) != 0; }

	/// Switches the node to the resolved state.
	/// Overwrites the stored identifier, because it shares storage with `_member`.
	void resolve(MemberSubType subType, AstIndex member, uint memberIndex, CompilationContext* c)
	{
		assert(subType != MemberSubType.unresolved);
		assert(member);
		_member = member;
		_memberIndex = memberIndex;
		this.subType = subType;
	}

	/// Resolved member node. Asserts that the node is resolved.
	ref AstIndex member(CompilationContext* c) return {
		c.assertf(isSymResolved, loc, "Member access is %s, %s", cast(MemberSubType)subType, state);
		return _member;
	}

	/// Resolved member index. Asserts that the node is resolved.
	uint memberIndex(CompilationContext* c) {
		c.assertf(isSymResolved, loc, "Member access is %s, %s", cast(MemberSubType)subType, state);
		return _memberIndex;
	}

	/// Member name: the stored identifier while unresolved, the resolved node's id afterwards.
	ref Identifier memberId(CompilationContext* c) return {
		if (!isSymResolved) return _memberId;
		return _member.get_node_id(c);
	}

	/// Unresolved node, as produced by the parser.
	this(TokenIndex loc, ScopeIndex parentScope, AstIndex aggregate, Identifier memberId, AstIndex type = AstIndex.init)
	{
		this.astType = AstType.expr_member;
		this.state = AstNodeState.name_register_self_done;
		this.loc = loc;
		this.type = type;
		this.parentScope = parentScope;
		this.aggregate = aggregate;
		this._memberId = memberId;
	}

	/// Already-resolved node; no identifier is stored.
	this(TokenIndex loc, ScopeIndex parentScope, AstIndex aggregate, AstIndex member, uint memberIndex, MemberSubType subType)
	{
		this.astType = AstType.expr_member;
		this.state = AstNodeState.name_register_self_done;
		this.loc = loc;
		this.parentScope = parentScope;
		this.aggregate = aggregate;
		this.subType = subType;
		this._member = member;
		this._memberIndex = memberIndex;
	}
}
94 
/// Debug dump: prints "MEMBER <type> <member name> <sub-type>" and then the aggregate subtree.
void print_member(MemberExprNode* node, ref AstPrintState state)
{
	auto ctx = state.context;
	auto kind = cast(MemberSubType)node.subType;
	state.print("MEMBER ", node.type.printer(ctx), " ", ctx.idString(node.memberId(ctx)), " ", kind);
	print_ast(node.aggregate, state);
}
100 
/// Fixes up the indices of a freshly cloned member expression.
/// Only unresolved nodes are expected here, so the union still holds the identifier.
void post_clone_member(MemberExprNode* node, ref CloneState state)
{
	assert(!node.isSymResolved);

	state.fixScope(node.parentScope);   // enclosing scope of the clone
	state.fixAstIndex(node.aggregate);  // aggregate expression of the clone
}
107 
/// Nested name-registration pass: only the aggregate expression needs registering.
void name_register_nested_member(MemberExprNode* node, ref NameRegisterState state)
{
	alias S = AstNodeState;
	node.state = S.name_register_nested;
	require_name_register(node.aggregate, state);
	node.state = S.name_register_nested_done;
}
113 
/// Name-resolution pass. Only the aggregate is resolved here: the member itself can only
/// be looked up once the aggregate's type is known, which happens in type_check_member.
void name_resolve_member(MemberExprNode* node, ref NameResolveState state)
{
	alias S = AstNodeState;
	node.state = S.name_resolve;
	assert(!node.isSymResolved);

	require_name_resolve(node.aggregate, state);

	node.state = S.name_resolve_done;
}
121 
/// Type checks `aggregate.member`.
/// 1. A real member of the aggregate's type is tried first, then lowered (this may replace nodeIndex).
/// 2. Otherwise a UFCS call `member(aggregate)` is tried; on success nodeIndex becomes the call.
/// 3. Otherwise an error is reported and the node gets the error type.
void type_check_member(ref AstIndex nodeIndex, MemberExprNode* node, ref TypeCheckState state)
{
	CompilationContext* c = state.context;
	node.state = AstNodeState.type_check;

	// lookupMember also runs require_type_check on the aggregate
	if (lookupMember(nodeIndex, node, state) == LookupResult.success)
	{
		lowerMember(nodeIndex, node, state);
		// lowering may have replaced nodeIndex (e.g. with a call), mark whatever it points to now
		nodeIndex.get_node(c).state = AstNodeState.type_check_done;
		return;
	}

	AstIndex ufcsCall;
	if (tryUFCSCall(ufcsCall, node, state) == LookupResult.success)
	{
		nodeIndex = ufcsCall;
		return;
	}

	// neither a member nor a UFCS function
	node.type = CommonAstNodes.type_error;
	c.error(node.loc, "`%s` has no member `%s`",
		node.aggregate.get_node_type(c).printer(c),
		c.idString(node.memberId(c)));
	nodeIndex.get_node(c).state = AstNodeState.type_check_done;
}
153 
// Tries to treat `aggregate.name` as the call `name(aggregate)`, where `name` is a function
// visible from the member expression's scope.
// On success a call node is created in callIndex if callIndex is undefined (nothing is
// created when the lookup fails), and the call is type checked.
LookupResult tryUFCSCall(ref AstIndex callIndex, MemberExprNode* memberNode, ref TypeCheckState state)
{
	CompilationContext* c = state.context;

	AstIndex candidate = lookupScopeIdRecursive(memberNode.parentScope.get_scope(c), memberNode.memberId(c), memberNode.loc, c);

	// only function declarations can be called this way
	if (candidate == CommonAstNodes.node_error || candidate.astType(c) != AstType.decl_function)
		return LookupResult.failure;

	// rewrite as call
	createMethodCall(callIndex, memberNode, candidate, state);
	return LookupResult.success;
}
173 
/// Turns `aggregate.member` into the call `member(aggregate)` and type checks it.
/// A new CallExprNode is allocated into callIndex only when callIndex is undefined.
/// For member functions the aggregate is first adapted by lowerThisArgument.
void createMethodCall(ref AstIndex callIndex, MemberExprNode* memberNode, AstIndex member, ref TypeCheckState state)
{
	CompilationContext* c = state.context;
	if (callIndex.isUndefined)
		callIndex = c.appendAst!CallExprNode(memberNode.loc, AstIndex(), memberNode.parentScope);

	auto call = callIndex.get!CallExprNode(c);
	call.state = AstNodeState.name_resolve_done;
	call.callee = member;

	auto func = member.get!FunctionDeclNode(c);
	auto sig = func.signature.get!FunctionSignatureNode(c);

	// the aggregate becomes the first argument (possibly wrapped in `&`)
	AstIndex receiver = memberNode.aggregate;
	if (func.isMember)
		lowerThisArgument(sig, receiver, memberNode.loc, c);
	call.args.putFront(c.arrayArena, receiver);

	type_check_func_call(call, sig, memberNode.memberId(c), state);
}
192 
/// Makes sure that aggregate is of pointer type.
/// The first parameter of the signature is expected to be `Struct*`. If the aggregate is of
/// type `Struct`, it is marked as an lvalue and replaced by `&aggregate`. In every case the
/// resulting aggregate node is marked name_resolve_done.
void lowerThisArgument(FunctionSignatureNode* signature, ref AstIndex aggregate, TokenIndex loc, CompilationContext* c)
{
	AstIndex thisPtrType = signature.parameters[0].get_node_type(c); // Struct*
	c.assertf(thisPtrType.isDefined, "null");
	AstIndex thisStructType = thisPtrType.get!PtrTypeNode(c).base.get_node_type(c); // Struct

	bool passedByValue = aggregate.get_node_type(c) == thisStructType;
	if (passedByValue)
	{
		// taking the address requires an lvalue
		aggregate.flags(c) |= AstFlags.isLvalue;
		aggregate = c.appendAst!UnaryExprNode(loc, AstIndex.init, UnOp.addrOf, aggregate);
	}

	aggregate.get_node(c).state = AstNodeState.name_resolve_done;
}
206 
/// Look up member by Identifier. Searches aggregate scope for identifier.
/// The aggregate is type checked first, except when it is an alias array.
/// Checks, in order:
/// - already-resolved nodes
/// - alias-array length
/// - `.sizeof` and `.offsetof`
/// - the member scope of the aggregate's type, looking through one level of struct pointer.
/// Returns LookupResult.error for an invalid `.offsetof` and for types without member lookup.
LookupResult lookupMember(ref AstIndex nodeIndex, MemberExprNode* expr, ref TypeCheckState state)
{
	CompilationContext* c = state.context;

	// Resolves expr to a builtin whose value has type u64.
	LookupResult resolveU64(MemberSubType kind, BuiltinId builtin)
	{
		expr.resolve(kind, c.builtinNodes(builtin), 0, c);
		expr.type = CommonAstNodes.type_u64;
		return LookupResult.success;
	}

	// Pre-resolved node: only the aggregate and a missing type need to be filled in.
	if (expr.isSymResolved)
	{
		require_type_check(expr.aggregate, c, IsNested.no);
		if (expr.type.isUndefined) expr.type = expr._member.get_expr_type(c);
		return LookupResult.success;
	}

	// NOTE(review): every member name on an alias array resolves to its length here;
	// the identifier is never compared against `length` — confirm this is intended.
	if (expr.aggregate.astType(c) == AstType.decl_alias_array)
		return resolveU64(MemberSubType.alias_array_length, BuiltinId.array_length);

	require_type_check(expr.aggregate, c, IsNested.no);
	TypeNode* objType = expr.aggregate.get_type(c);
	Identifier memberId = expr.memberId(c);

	// `.sizeof` applies to any aggregate type (checked before looking through pointers)
	if (memberId == CommonIds.id_sizeof)
		return resolveU64(MemberSubType.builtin_member, BuiltinId.type_sizeof);

	// `.offsetof` is only valid as `agg.field.offsetof` where `field` is an instance field of a struct
	if (memberId == CommonIds.id_offsetof)
	{
		if (expr.aggregate.astType(c) != AstType.expr_member) return LookupResult.error;
		auto field = expr.aggregate.get!MemberExprNode(c);
		if (!field.aggregate.get_type(c).isStruct) return LookupResult.error;
		if (cast(MemberSubType)field.subType != MemberSubType.struct_member) return LookupResult.error;
		return resolveU64(MemberSubType.builtin_member, BuiltinId.type_offsetof);
	}

	// Allow member access for pointers to structs; the dereference is inserted when lowering
	if (objType.isPointer)
	{
		auto pointee = objType.as_ptr.base.get_type(c);
		if (pointee.isStruct)
		{
			objType = pointee;
			expr.flags |= MemberExprFlags.needsDeref;
		}
	}

	switch (objType.astType)
	{
		case AstType.type_slice:
			return lookupSliceMember(expr, objType.as_slice, memberId, c);
		case AstType.type_static_array:
			return lookupStaticArrayMember(expr, objType.as_static_array, memberId, c);
		case AstType.decl_struct:
			return lookupStructMember(nodeIndex, expr, objType.as_struct, memberId, state);
		case AstType.decl_enum:
			return lookupEnumMember(expr, objType.as_enum, memberId, c);
		case AstType.type_basic:
			return lookupBasicMember(expr, objType.as_basic, memberId, c);
		default:
			return LookupResult.error;
	}
}
276 
/// Resolves `Enum.member` by searching the enum's member scope.
/// Anonymous enums are rejected by an internal assertion.
LookupResult lookupEnumMember(MemberExprNode* expr, EnumDeclaration* enumDecl, Identifier id, CompilationContext* c)
{
	c.assertf(!enumDecl.isAnonymous, expr.loc,
		"Trying to get member from anonymous enum defined at %s",
		c.tokenLoc(enumDecl.loc));

	AstIndex found = enumDecl.memberScope.lookup_scope(id, c);
	if (found)
	{
		auto member = found.get!EnumMemberDecl(c);
		expr.resolve(MemberSubType.enum_member, found, member.scopeIndex, c);
		expr.type = member.type;
		return LookupResult.success;
	}

	return LookupResult.failure;
}
292 
/// Resolves `.min` / `.max` on integer basic types. The result has the integer type itself.
/// Every other basic type and member name fails the lookup.
LookupResult lookupBasicMember(MemberExprNode* expr, BasicTypeNode* basicType, Identifier id, CompilationContext* c)
{
	if (!basicType.isInteger) return LookupResult.failure;

	BuiltinId builtin;
	if (id == CommonIds.id_min)
		builtin = BuiltinId.int_min;
	else if (id == CommonIds.id_max)
		builtin = BuiltinId.int_max;
	else
		return LookupResult.failure;

	expr.resolve(MemberSubType.builtin_member, c.builtinNodes(builtin), 0, c);
	expr.type = basicType.get_ast_index(c);
	return LookupResult.success;
}
313 
/// Resolves `.length` and `.ptr` on slices.
/// The slice is accessed like a struct, so the member index is a field number:
/// `length` is field 0 and `ptr` is field 1.
/// `.ptr` gets a newly appended pointer-to-element type; `.length` is u64.
LookupResult lookupSliceMember(MemberExprNode* expr, SliceTypeNode* sliceType, Identifier id, CompilationContext* c)
{
	if (id == CommonIds.id_length)
	{
		expr.resolve(MemberSubType.slice_member, c.builtinNodes(BuiltinId.slice_length), 0, c);
		expr.type = CommonAstNodes.type_u64;
	}
	else if (id == CommonIds.id_ptr)
	{
		expr.resolve(MemberSubType.slice_member, c.builtinNodes(BuiltinId.slice_ptr), 1, c);
		expr.type = c.appendAst!PtrTypeNode(sliceType.loc, CommonAstNodes.type_type, sliceType.base);
	}
	else
	{
		return LookupResult.failure;
	}

	expr.type.setState(c, AstNodeState.type_check_done);
	return LookupResult.success;
}
334 
/// Resolves `.length` and `.ptr` on static arrays as builtin members.
/// `.ptr` gets a newly appended pointer-to-element type; `.length` is u64.
LookupResult lookupStaticArrayMember(MemberExprNode* expr, StaticArrayTypeNode* arrType, Identifier id, CompilationContext* c)
{
	if (id == CommonIds.id_length)
	{
		expr.resolve(MemberSubType.builtin_member, c.builtinNodes(BuiltinId.array_length), 0, c);
		expr.type = CommonAstNodes.type_u64;
	}
	else if (id == CommonIds.id_ptr)
	{
		expr.resolve(MemberSubType.builtin_member, c.builtinNodes(BuiltinId.array_ptr), 0, c);
		expr.type = c.appendAst!PtrTypeNode(arrType.loc, CommonAstNodes.type_type, arrType.base);
	}
	else
	{
		return LookupResult.failure;
	}

	expr.type.setState(c, AstNodeState.type_check_done);
	return LookupResult.success;
}
354 
/// Resolves `id` in the struct's member scope and classifies the found declaration.
/// Returns failure when the struct has no such symbol. Nested structs and unknown
/// declaration kinds are internal errors. A template whose body is not a function is an
/// unrecoverable error.
LookupResult lookupStructMember(ref AstIndex nodeIndex, MemberExprNode* node, StructDeclNode* structDecl, Identifier id, ref TypeCheckState state)
{
	CompilationContext* c = state.context;
	AstIndex entity = c.getAstScope(structDecl.memberScope).symbols.get(id, AstIndex.init);
	if (!entity) return LookupResult.failure;

	AstType kind = entity.astType(c);
	switch (kind)
	{
		case AstType.decl_function:
			node.resolve(MemberSubType.struct_method, entity, 0, c);
			node.type = entity.get_node_type(c);
			return LookupResult.success;

		case AstType.decl_var:
			// instance fields vs. static variables declared inside the struct
			auto field = entity.get!VariableDeclNode(c);
			MemberSubType sub = field.isMember ? MemberSubType.struct_member : MemberSubType.static_struct_member;
			node.resolve(sub, entity, field.scopeIndex, c);
			node.type = entity.get_node_type(c);
			return LookupResult.success;

		case AstType.decl_struct:
			node.resolve(MemberSubType.static_struct_member, entity, 0, c);
			node.type = entity.get_node_type(c);
			c.internal_error("member structs are not implemented");

		case AstType.decl_enum, AstType.decl_alias:
			node.resolve(MemberSubType.static_struct_member, entity, 0, c);
			node.type = entity.get_node_type(c);
			return LookupResult.success;

		case AstType.decl_enum_member:
			node.resolve(MemberSubType.enum_member, entity, 0, c);
			node.type = entity.get_node_type(c);
			return LookupResult.success;

		case AstType.decl_template:
			node.resolve(MemberSubType.struct_templ_method, entity, 0, c);
			node.type = CommonAstNodes.type_alias;
			// only function templates can be used as methods
			auto templ = entity.get!TemplateDeclNode(c);
			if (templ.body.astType(c) != AstType.decl_function)
				c.unrecoverable_error(node.loc, "Cannot call template of %s", templ.body.astType(c));
			return LookupResult.success;

		default:
			c.internal_error("Unexpected struct member %s", kind);
	}
}
413 
/// Rewrites a resolved member expression into its final form.
/// - struct_method: the parentheses-less `s.method` becomes the call `method(s)`, and
///   nodeIndex is replaced with the call node.
/// - struct_templ_method: the template is instantiated (with an empty explicit-argument
///   list) and then called the same way. On instantiation failure the node gets the error type.
/// - any other sub-type: if the aggregate is a struct pointer (needsDeref), it is wrapped
///   in an explicit dereference that is already marked type_check_done.
void lowerMember(ref AstIndex nodeIndex, MemberExprNode* node, ref TypeCheckState state)
{
	CompilationContext* c = state.context;
	switch(node.subType) with(MemberSubType)
	{
		case MemberSubType.struct_method:
			// parentheses-less method call
			AstIndex method = node.member(c).get_effective_node(c);
			AstIndex callIndex;
			createMethodCall(callIndex, node, method, state);
			nodeIndex = callIndex;
			return;

		case MemberSubType.struct_templ_method:
			AstIndex templ = node.member(c).get_effective_node(c);
			AstNodes noExplicitArgs;
			AstIndex instance = get_template_instance(templ, node.loc, noExplicitArgs, state);

			if (instance == CommonAstNodes.node_error) {
				node.type = CommonAstNodes.type_error;
				return;
			}

			// Same lowering as a regular method, applied to the instantiated function.
			// Previously this duplicated createMethodCall's body inline.
			AstIndex callIndex;
			createMethodCall(callIndex, node, instance, state);
			nodeIndex = callIndex;
			return;

		default: break;
	}

	// member of a struct accessed through a pointer: make the dereference explicit
	if (node.needsDeref) {
		TypeNode* objType = node.aggregate.get_type(c);
		auto baseType = objType.as_ptr.base.get_type(c);
		node.aggregate = c.appendAst!UnaryExprNode(node.loc, c.getAstNodeIndex(baseType), UnOp.deref, node.aggregate);
		node.aggregate.setState(c, AstNodeState.type_check_done);
	}
}
460 
/// Generates IR for a resolved member expression.
/// Every successful path links `currentBlock` to `nextStmt` via addJumpToLabel before
/// returning, so the caller can continue at nextStmt.
ExprValue ir_gen_member(ref IrGenState gen, IrIndex currentBlock, ref IrLabel nextStmt, MemberExprNode* m)
{
	CompilationContext* c = gen.context;

	switch(m.subType) with(MemberSubType)
	{
		case struct_member, slice_member:
			// evaluate the aggregate, then select field number `memberIndex` of it
			IrLabel afterAggr = IrLabel(currentBlock);
			ExprValue aggr = ir_gen_expr(gen, m.aggregate, currentBlock, afterAggr);
			currentBlock = afterAggr.blockIndex;

			IrIndex memberIndex = c.constants.add(makeIrType(IrBasicType.i32), m.memberIndex(c));
			ExprValue result = aggr.member(gen, m.loc, currentBlock, memberIndex);
			gen.builder.addJumpToLabel(currentBlock, nextStmt);
			return result;
		case static_struct_member:
			auto v = m.member(c).get!VariableDeclNode(c);
			ir_gen_decl_var(c, v);
			ExprValue result = v.irValue;
			gen.builder.addJumpToLabel(currentBlock, nextStmt);
			return result;
		case struct_method:
			c.unreachable("Not implemented");
		case enum_member:
			IrIndex result = m.member(c).get!EnumMemberDecl(c).gen_init_value_enum_member(c);
			gen.builder.addJumpToLabel(currentBlock, nextStmt);
			return ExprValue(result);
		case builtin_member:
			BuiltinId builtin = m.member(c).get!BuiltinNode(c).builtin;
			switch(builtin) with(BuiltinId)
			{
				case array_ptr:
					// address of the first element: GEP 0, 0 on the array value
					IrLabel afterAggr = IrLabel(currentBlock);
					ExprValue aggr = ir_gen_expr(gen, m.aggregate, currentBlock, afterAggr);
					currentBlock = afterAggr.blockIndex;
					IrIndex ZERO = c.constants.addZeroConstant(makeIrType(IrBasicType.i32));
					IrIndex ptr = buildGEPEx(gen, m.loc, currentBlock, aggr, ZERO, ZERO);
					gen.builder.addJumpToLabel(currentBlock, nextStmt);
					return ExprValue(ptr);
				default:
					// Evaluated without the IR builder. Control flow must still continue at
					// nextStmt; this jump was previously missing, unlike in every other branch.
					auto builtinValue = eval_builtin_member(builtin, m.aggregate, m.loc, c);
					gen.builder.addJumpToLabel(currentBlock, nextStmt);
					return ExprValue(builtinValue);
			}
		default:
			// report the offending sub-type (astType is always expr_member here)
			c.internal_error(m.loc, "Unexpected member sub-type %s", cast(MemberSubType)m.subType);
	}
}