1 /**
2 Copyright: Copyright (c) 2017-2019 Andrey Penechko.
3 License: $(WEB boost.org/LICENSE_1_0.txt, Boost License 1.0).
4 Authors: Andrey Penechko.
5 */
6 
7 module vox.be.optimize;
8 
9 import std.stdio;
10 import vox.all;
11 
12 alias FuncPassIr = void function(CompilationContext*, IrFunction*, IrIndex, ref IrBuilder);
13 alias FuncPass = void function(CompilationContext*, IrFunction*);
14 
/// Applies `pass` to the LIR of every function in every compiled module.
/// When IR validation is enabled, each function is validated immediately
/// after the pass runs, reporting `passName` on failure.
void apply_lir_func_pass(CompilationContext* context, FuncPass pass, string passName)
{
	foreach (ref SourceFileInfo file; context.files.data)
	{
		foreach (IrFunction* lir; file.mod.lirModule.functions)
		{
			pass(context, lir);
			if (context.validateIr) validateIrFunction(context, lir, passName);
		}
	}
}
24 
/// Runs the IR-level optimization pipeline on a single function.
/// Copies the function's IR into a fresh `optimizedIrData` slot, then applies
/// inlining, branch-condition inversion and dead-code elimination in order.
/// External functions have no IR body and are skipped.
void pass_optimize_ir(ref CompilationContext c, ref ModuleDeclNode mod, ref FunctionDeclNode func)
{
	if (func.isExternal) return;

	// DCE runs last so instructions orphaned by earlier passes get cleaned up.
	FuncPassIr dcePass = &func_pass_remove_dead_code;
	if (c.disableDCE) dcePass = null;

	// null entries mark passes disabled by compiler flags; skipped in the loop below.
	FuncPassIr[3] passes = [
		c.disableInline ? null : &func_pass_inline,
		&func_pass_invert_conditions,
		dcePass,
	];

	IrBuilder builder;

	// Optimize a copy: the original IR in func.backendData.irData stays untouched.
	IrFunction* irData = c.getAst!IrFunction(func.backendData.irData);
	func.backendData.optimizedIrData = c.appendAst!IrFunction;
	IrFunction* optimizedIrData = c.getAst!IrFunction(func.backendData.optimizedIrData);
	*optimizedIrData = *irData; // copy

	builder.beginDup(optimizedIrData, &c);

	IrIndex funcIndex = func.getIrIndex(&c);

	foreach (FuncPassIr pass; passes) {
		if (pass is null) continue;
		pass(&c, optimizedIrData, funcIndex, builder);
		// Optionally validate and/or dump the IR after every individual pass.
		if (c.validateIr)
			validateIrFunction(&c, optimizedIrData);
		if (c.printIrOptEach && c.printDumpOf(&func)) dumpFunction(&c, optimizedIrData, "IR opt");
	}
	// Single dump after the whole pipeline when per-pass dumping is off.
	if (!c.printIrOptEach && c.printIrOpt && c.printDumpOf(&func)) dumpFunction(&c, optimizedIrData, "IR opt all");
	builder.finalizeIr;
}
59 
/// Expands calls marked `alwaysInline` into the caller.
/// Walks blocks in layout order; after a call is inlined, iteration resumes at
/// the first inlined instruction, so nested always-inline calls get expanded
/// too. A stack of function indices guards against (mutually) recursive
/// expansion.
void func_pass_inline(CompilationContext* c, IrFunction* ir, IrIndex funcIndex, ref IrBuilder builder)
{
	// The inline stack lives at the tail of c.tempBuffer.
	// NOTE(review): `inlineStack` caches the buffer's data pointer; this
	// assumes tempBuffer does not reallocate while entries are pushed below —
	// TODO confirm the buffer has sufficient reserved capacity.
	IrIndex* inlineStack = cast(IrIndex*)c.tempBuffer.nextPtr;
	uint inlineStackLen = 0;
	void pushFunc(IrIndex index) {
		c.tempBuffer.put(index.asUint);
		++inlineStackLen;
	}
	void popFunc() {
		c.tempBuffer.unput(1);
		--inlineStackLen;
	}
	// Returns true if `index` is already being expanded (recursion guard).
	bool isOnStack(IrIndex index)
	{
		foreach(IrIndex slot; inlineStack[0..inlineStackLen])
			if (slot == index) return true;
		return false;
	}
	pushFunc(funcIndex); // the function being optimized is the stack bottom

	IrIndex blockIndex = ir.entryBasicBlock;
	while (blockIndex.isDefined)
	{
		IrBasicBlock* block = ir.getBlock(blockIndex);
		IrIndex instrIndex = block.firstInstr;

		while (instrIndex.isInstruction)
		{
			// Capture the successor up front: try_inline may redirect it into
			// the freshly inlined body.
			IrIndex nextInstr = ir.nextInstr(instrIndex);
			IrInstrHeader* instrHeader = ir.getInstr(instrIndex);

			// Attempts to inline the call at `instrIndex`; returns silently
			// when inlining is not possible or not requested.
			void try_inline()
			{
				if (!instrHeader.alwaysInline) return; // inlining is not requested for this call

				IrIndex calleeIndex = instrHeader.arg(ir, 0);
				if (!calleeIndex.isFunction) return; // cannot inline indirect calls

				FunctionDeclNode* callee = c.getFunction(calleeIndex);
				if (callee.isExternal) return; // cannot inline external functions

				if (isOnStack(calleeIndex)) return; // recursive call

				IrFunction* calleeIr = c.getAst!IrFunction(callee.backendData.irData);

				// Inliner returns the next instruction to visit
				// We will visit inlined code next
				nextInstr = inline_call(&builder, calleeIr, instrIndex, blockIndex);
				pushFunc(calleeIndex);
			}

			switch(cast(IrOpcode)instrHeader.op)
			{
				case IrOpcode.call: try_inline(); break;
				case IrOpcode.inline_marker:
					// End of an inlined body: drop the marker instruction and
					// pop the corresponding callee from the inline stack.
					removeInstruction(ir, instrIndex);
					popFunc;
					break;
				default:
					break;
			}

			instrIndex = nextInstr;
		}

		blockIndex = block.nextBlock;
	}
	popFunc; // balances the initial pushFunc(funcIndex)
}
129 
/// Inverts two-way branch conditions so that successor 0 is not the block
/// laid out immediately after the current one; this lets the backend emit a
/// fall-through for the common path instead of a jump.
/// Fix: removed the dead local `seqIndex1` (computed but never used).
void func_pass_invert_conditions(CompilationContext* context, IrFunction* ir, IrIndex funcIndex, ref IrBuilder builder)
{
	// seqIndex reflects the sequential layout order used below.
	ir.assignSequentialBlockIndices();

	foreach (IrIndex blockIndex, ref IrBasicBlock block; ir.blocks)
	{
		// Blocks without a terminating instruction cannot end in a branch.
		if (!block.lastInstr.isDefined) continue;

		IrInstrHeader* instrHeader = ir.getInstr(block.lastInstr);
		ubyte invertedCond;

		switch(instrHeader.op) with(IrOpcode)
		{
			case branch_unary:
				invertedCond = invertUnaryCond(cast(IrUnaryCondition)instrHeader.cond);
				break;
			case branch_binary:
				invertedCond = invertBinaryCond(cast(IrBinaryCondition)instrHeader.cond);
				break;

			default: continue; // not a conditional branch
		}

		uint seqIndex0 = ir.getBlock(block.successors[0, ir]).seqIndex;
		// If the taken successor immediately follows this block in layout,
		// invert the condition and swap successors to make it fall-through.
		if (block.seqIndex + 1 == seqIndex0)
		{
			instrHeader.cond = invertedCond;
			IrIndex succIndex0 = block.successors[0, ir];
			IrIndex succIndex1 = block.successors[1, ir];
			block.successors[0, ir] = succIndex1;
			block.successors[1, ir] = succIndex0;
		}
	}
}
165 
/// Single-pass dead-code elimination.
/// Visits blocks and instructions in reverse order so that removing an
/// instruction can make its operands dead before they are visited, letting
/// chains of dead instructions disappear in one sweep.
void func_pass_remove_dead_code(CompilationContext* context, IrFunction* ir, IrIndex funcIndex, ref IrBuilder builder)
{
	// Per-opcode metadata (side-effect flags) for this function's instruction set.
	auto funcInstrInfos = allInstrInfos[ir.instructionSet];
	foreach (IrIndex blockIndex, ref IrBasicBlock block; ir.blocksReverse)
	{
		// A phi whose result virtual register has no users is dead.
		foreach (IrIndex phiIndex, ref IrPhi phi; block.phis(ir))
		{
			if (ir.getVirtReg(phi.result).users.length == 0) {
				removePhi(context, ir, phiIndex);
				//writefln("removed dead %s", phiIndex);
			}
		}

		foreach(IrIndex instrIndex, ref IrInstrHeader instrHeader; block.instructionsReverse(ir))
		{
			// Instructions with side effects must be kept even when unused.
			if (funcInstrInfos[instrHeader.op].hasSideEffects) continue;

			if (instrHeader.hasResult) {
				// Non-virtual-register results (e.g. physical registers) may
				// be observed elsewhere; keep the instruction.
				if (!instrHeader.result(ir).isVirtReg) continue;
				// A result with users is live.
				if (ir.getVirtReg(instrHeader.result(ir)).users.length > 0) continue;
			}

			// instruction without side effects
			// instruction's result is unused or has no result
			// remove that instruction
			// Unregister this instruction as a user of each argument first, so
			// the arguments' defining instructions can become dead in turn.
			foreach(ref IrIndex arg; instrHeader.args(ir)) {
				removeUser(context, ir, instrIndex, arg);
			}
			removeInstruction(ir, instrIndex);
			//writefln("removed dead %s", instrIndex);
		}
	}
}
199 
200 /*
201 void lir_func_pass_simplify(ref CompilationContext context, ref IrFunction ir)
202 {
203 	foreach (IrIndex blockIndex, ref IrBasicBlock block; ir.blocksReverse)
204 	{
205 		foreach(IrIndex instrIndex, ref IrInstrHeader instrHeader; block.instructionsReverse(ir))
206 		{
207 			switch(cast(Amd64Opcode)instrHeader.op) with(Amd64Opcode)
208 			{
209 				case mov:
210 					static assert(LirAmd64Instr_xor.sizeof == LirAmd64Instr_mov.sizeof);
211 					// replace 'mov reg, 0' with xor reg reg
212 					IrIndex dst = instrHeader.result;
213 					IrIndex src = instrHeader.args[0];
214 					if (src.isSimpleConstant && context.constants.get(src).i64 == 0)
215 					{
216 
217 					}
218 				default: break;
219 			}
220 		}
221 	}
222 }
223 */
/// Entry point for LIR optimization: runs the jump-elimination pass over the
/// LIR of every function in every compiled module.
void pass_optimize_lir(CompilationContext* context)
{
	FuncPass optPass = &pass_optimize_lir_func;
	apply_lir_func_pass(context, optPass, "Optimize LIR");
}
228 
/// Deletes a jump at the end of a block when its target is the block laid out
/// immediately after it in sequential order — the fall-through makes the jump
/// redundant.
void pass_optimize_lir_func(CompilationContext* context, IrFunction* ir)
{
	// seqIndex reflects the sequential layout order checked below.
	ir.assignSequentialBlockIndices();

	foreach (IrIndex blockIndex, ref IrBasicBlock block; ir.blocks)
	{
		// Skip blocks with no terminating instruction.
		if (!block.lastInstr.isDefined) continue;

		IrInstrHeader* lastInstr = ir.getInstr(block.lastInstr);
		if (!context.machineInfo.instrInfo[lastInstr.op].isJump) continue;

		// Jump whose target is the very next block in layout is a no-op.
		uint targetSeqIndex = ir.getBlock(block.successors[0, ir]).seqIndex;
		if (targetSeqIndex == block.seqIndex + 1)
		{
			removeInstruction(ir, block.lastInstr);
		}
	}
}