1 /**
2 Copyright: Copyright (c) 2017-2019 Andrey Penechko.
3 License: $(WEB boost.org/LICENSE_1_0.txt, Boost License 1.0).
4 Authors: Andrey Penechko.
5 */
6 module sandbox;
7 
8 import std.stdio;
9 import amd64asm;
10 import utils;
11 
/// Page size constant in bytes (4096); main() multiplies it to size its
/// executable scratch buffer.
enum int PAGE_SIZE = 4 * 1024;
13 
// Register name tables, kept for printing only. A member's position in its
// table is its numeric value (0 .. 15); there is one table per operand width.

/// 8-bit register names.
enum Reg8 : ubyte
{
	AL, CL, DL, BL, SPL, BPL, SIL, DIL,
	R8B, R9B, R10B, R11B, R12B, R13B, R14B, R15B,
}

/// 16-bit register names.
enum Reg16 : ubyte
{
	AX, CX, DX, BX, SP, BP, SI, DI,
	R8W, R9W, R10W, R11W, R12W, R13W, R14W, R15W,
}

/// 32-bit register names.
enum Reg32 : ubyte
{
	EAX, ECX, EDX, EBX, ESP, EBP, ESI, EDI,
	R8D, R9D, R10D, R11D, R12D, R13D, R14D, R15D,
}

/// 64-bit register names.
enum Reg64 : ubyte
{
	RAX, RCX, RDX, RBX, RSP, RBP, RSI, RDI,
	R8, R9, R10, R11, R12, R13, R14, R15,
}
19 
/// C-linkage helper: writes `param` to stdout without a trailing newline.
/// Always returns 0.
extern(C) int printNum(int param)
{
	write(param);
	return 0;
}
/// C-linkage helper: writes `param` to stdout followed by a newline.
/// Always returns 0.
extern(C) int printNumLn(int param)
{
	writeln(param);
	return 0;
}
22 
/// Sandbox entry point: encodes a single `addq` into a scratch executable
/// buffer, hex-dumps the encoder output, then runs the assembler test suite.
void main()
{
	// Other experiments; uncomment the one you want to run.
	//run_from_rwx();
	//testPrintMemAddress();
	//testVMs();
	//testLang();
	//testLang2();

	{
		CodeGen_x86_64 gen;
		auto scratch = alloc_executable_memory(PAGE_SIZE * 1024);
		gen.encoder.setBuffer(scratch);
		// The buffer is handed back through the encoder when the scope ends.
		scope(exit) free_executable_memory(gen.encoder.freeBuffer);

		gen.addq(Register.CX, Register.AX);
		printHex(gen.encoder.code, 10);
	}

	testAll();

	// Disabled experiments, kept for reference.
	/*
	writefln("main() == %s", runScript(input, "main"));
	writefln("main() == %s", runScript(q{i32 main(){return 42;}}, "main"));
	writefln("main(20) == %s", runScript(q{i32 main(i32 par){return par;}}, "main", 20));
	writefln("main(20) == %s", runScript(q{i32 main(i32 par){return par+par;}}, "main", 20));
	writefln("main(20) == %s", runScript(q{i32 main(i32 par){return sub(par)+10;} i32 sub(i32 a){return a+a;}}, "main", 20));

	{
		LangVM vm;
		vm.setup();
		scope(exit) vm.free;

		writefln("%s", cast(void*)&printNum);
		vm.registerFunction("print", &printNum);
		vm.registerFunction("println", &printNumLn);
		vm.compileModule(q{ void test(i32 i) { print(i); println(i+i); } });
		if (vm.valid)
		{
			printHex(vm.codeGen.code, 16);
			writefln("fun table %s", vm.codeGen.functionTable);
			writefln("%s", vm.run!int("test", __LINE__, __FILE__, 20));
		}
	}

	{
		CodeGen_x86_64 codeGen;
		codeGen.encoder.setBuffer(alloc_executable_memory(PAGE_SIZE * 1024));
		scope(exit) free_executable_memory(codeGen.encoder.freeBuffer);

		//writefln("MOV byte ptr %s, 0x%X", memAddrDisp32(0x55667788), 0xAA);

		alias R = Reg64;
		enum regMax = cast(R)(R.max+1);
		foreach (R regB; R.min..regMax)
		{
			//codeGen.addq(memAddrBase(cast(Register)regB), Imm8(1));
			//writefln("mov %s, qword ptr %s", regB, memAddrBaseIndexDisp8(cast(Register)regB, cast(Register)regB, SibScale(3), 0xFE));
			//codeGen.movq(cast(Register)regB, Imm64(0x24364758AABBCCDD));
		}

		codeGen.call(memAddrRipDisp32(100));
		codeGen.call(codeGen.stubPC+100);
	}

	//printHex(codeGen.encoder.code, 10);
	*/
}
87 
/// Runs every instruction-encoding test from the `asmtest` package against
/// one shared CodegenTester, which is released when the function exits.
void testAll()
{
	import asmtest.utils;
	import asmtest.add, asmtest.mov, asmtest.not, asmtest.mul, asmtest.inc;
	import asmtest.pop, asmtest.push, asmtest.cmp, asmtest.jmp_jcc_setcc, asmtest.imul;

	CodegenTester tester;
	tester.setup();
	scope(exit) tester.free();

	// Call order is kept as-is: all tests share `tester`.
	testAdd(tester);
	testMov(tester);
	testNot(tester);
	testMul(tester);
	testInc(tester);
	testPop(tester);
	testPush(tester);
	testCmp(tester);
	testJmpJccSetcc(tester);
	testImul(tester);
}
118 
/// Prints one example of each memory-address form (disp32, index, base,
/// base+index, with and without 8/32-bit displacement) using writeln.
void testPrintMemAddress()
{
	// Operands shared by the examples below.
	enum base   = Register.AX;
	enum index  = Register.BX;
	enum disp32 = 0x11223344;
	enum disp8  = 0xFE;

	writeln(memAddrDisp32(disp32));
	writeln(memAddrIndexDisp32(base, SibScale(0), disp32));
	writeln(memAddrBase(base));
	writeln(memAddrBaseDisp32(base, disp32));
	writeln(memAddrBaseIndex(base, index, SibScale(1)));
	writeln(memAddrBaseIndexDisp32(base, index, SibScale(2), disp32));
	writeln(memAddrBaseDisp8(base, disp8));
	writeln(memAddrBaseIndexDisp8(base, index, SibScale(3), disp8));
}
130 
/// Allocates a 4096-byte executable buffer, fills it through
/// emit_code_into_memory, calls it as `long function(long)` with the
/// argument 2 and asserts that the call returns 42.
void run_from_rwx()
{
	immutable size_t byteCount = 4096;
	ubyte[] code = alloc_executable_memory(byteCount);
	scope(exit) free_executable_memory(code);
	//writefln("alloc %s bytes at %s", code.length, code.ptr);

	emit_code_into_memory(code);

	// Treat the start of the buffer as the entry point of the emitted code.
	alias EntryPoint = long function(long);
	auto entry = cast(EntryPoint)code.ptr;

	long returned = entry(2);
	assert(returned == 42);

	//writefln("func(2) == %s", returned);
}
148 
/// Encodes a tiny function into `mem` that returns 42.
///
/// run_from_rwx calls the start of `mem` as `long function(long)` and asserts
/// that the result is 42, so every platform branch must leave 42 in the
/// return register.
///
/// Params:
///     mem = writable, executable buffer that receives the machine code
void emit_code_into_memory(ubyte[] mem)
{
	CodeGen_x86_64 codeGen;
	codeGen.encoder.setBuffer(mem);

	version(Windows)
	{
		// main: calls sub_fun. The call target is not known yet, so a
		// placeholder call is emitted and patched below.
		codeGen.beginFunction();
		auto sub_call = codeGen.saveFixup();
		codeGen.call(codeGen.pc);
		codeGen.endFunction();

		// Patch the placeholder so that it targets the start of sub_fun.
		sub_call.call(codeGen.pc);

		// sub_fun: produces the value 42.
		codeGen.beginFunction();
		codeGen.movq(Register.AX, Imm32(42));
		codeGen.endFunction();
	}
	else version(Posix)
	{
		//codeGen.movq(Register.AX, Register.DI);
		//codeGen.addq(Register.AX, Imm8(4));
		//codeGen.movq(memAddrBaseDisp32(Register.AX, 0x55), Imm32(0xAABBCCDD));

		// Load the expected result before returning. A bare `ret` would hand
		// back whatever happened to be in the return register.
		codeGen.movq(Register.AX, Imm32(42));
		codeGen.ret();
	}
	//printHex(codeGen.encoder.code, 16);
}
178 
179 /*---------------------------------------------------------------------------*/
180 // Tiny C
181 
/// Sample programs for the Tiny C front end. Source.reset() selects one of
/// them by index.
string[] testSources = [
	"{ i=1; while (i<100) { while (j < 100) j=j+1; i=i+1;} }",
	"{ i=1; while (i<100) { j=0; while (j < 100) {j=j+1;a=a+1;} i=i+1;} }",
	"{ i=i+1; }",
	"{ i;i;i; }",
	"{ i=125; j=100; while (i-j) if (i<j) j=j-i; else i=i-j; }",
	"{ if (i) j=1; else j=2; }",
	"a=b=c=2<3;",
	"{ i=1; do i=i+10; while (i<50); }",
	"{ i=1; while ((i=i+10)<50) ; }",
	"{ i=7; if (i<5) n=1; if (i<10) y=2; }",
];
194 
/// Character source for the Tiny C lexer: reads either from stdin or from an
/// in-memory slice of a test program.
struct Source
{
	/// Text not yet handed out by testGetter.
	const(char)[] slice;

	/// Reads one character from stdin via getchar().
	char stdinGetter()
	{
		return cast(char)getchar();
	}

	/// Rewinds the in-memory input to the start of testSources[3].
	void reset()
	{
		slice = testSources[3];
	}

	/// Returns the next character of `slice` and advances past it.
	/// Returns 255 once the slice is exhausted.
	char testGetter()
	{
		if (slice.length > 0)
		{
			char next = slice[0];
			slice = slice[1 .. $];
			return next;
		}
		return 255;
	}
}
216 
/// Benchmarks the Tiny C pipeline: parses, compiles and runs the test program
/// `times` times on the bytecode VM, then compiles and runs it `times` times
/// on the JIT VM, printing non-zero globals and per-iteration timings.
void testVMs()
{
	import tinyc;
	import utils;

	auto parseStart = currTime;
	enum times = 1_000;
	enum perIteration = 1.0/times;

	// Stage 1: lex + parse.
	Source source;
	Lexer lexer;
	Parser parser = Parser(&lexer);
	Node* rootNode;
	foreach (_; 0..times)
	{
		source.reset();
		lexer = Lexer(&source.testGetter);
		rootNode = parser.program();
	}

	auto parseEnd = currTime;

	// Stage 2: compile into a fixed-size bytecode buffer.
	byte[1000] _object; // executable

	CodeGenerator codeGen;
	foreach (_; 0..times)
	{
		codeGen = CodeGenerator(_object.ptr);
		codeGen.compile(rootNode);
	}

	auto compileEnd = currTime;

	// Stage 3: run on a fresh VM each iteration.
	VM vm;
	foreach (_; 0..times)
	{
		vm = VM();
		vm.run(_object);
	}

	auto runEnd = currTime;

	foreach (i, v; vm.globals)
	{
		if (v == 0) continue;
		writefln("%s = %s", cast(char)('a'+i), v);
	}

	writefln("Parse: %ss, compile: %ss, run: %ss",
		scaledNumberFmt(parseEnd - parseStart, perIteration),
		scaledNumberFmt(compileEnd - parseEnd, perIteration),
		scaledNumberFmt(runEnd - compileEnd, perIteration));

	// Same program on the JIT VM.
	JitVM jit_vm;
	scope(exit) jit_vm.free;

	auto jitStart = currTime;
	foreach (_; 0..times)
	{
		jit_vm.compile(rootNode);
	}
	auto jitCompiled = currTime;
	foreach (_; 0..times)
	{
		jit_vm.reset();
		jit_vm.run();
	}
	auto jitFinished = currTime;

	foreach (i, v; jit_vm.globals)
	{
		if (v == 0) continue;
		writefln("%s = %s", cast(char)('a'+i), v);
	}

	writefln("Compile: %ss, run: %ss",
		scaledNumberFmt(jitCompiled - jitStart, perIteration),
		scaledNumberFmt(jitFinished - jitCompiled, perIteration));

	//printHex(jit_vm.code, 8);
	//printAST(rootNode);

	// Covers everything from the first parse to the last JIT run.
	writefln("Total %ss", scaledNumberFmt(jitFinished - parseStart));
}
290 
/// Compiles `input` in a throwaway LangVM and runs the function `funcName`
/// with `args`, returning its int result.
///
/// `line` and `file` default to the call site and are forwarded to
/// LangVM.run. The VM is released before the function returns.
auto runScript(int line = __LINE__, string file = __FILE__, Args...)(string input, string funcName, Args args)
{
	LangVM machine;
	machine.setup();
	scope(exit) machine.free();

	machine.compileModule(input);
	return machine.run!int(funcName, line, file, args);
}
298 
/// Wraps the `lang` lexer, parser, semantic analysis and code generator
/// behind a compile-then-run interface.
///
/// Usage: setup(), optionally registerFunction(), compileModule(), run(),
/// and finally free().
///
/// NOTE(review): setup() hands the address of the `lexer` member to the
/// parser, so a copied LangVM presumably keeps pointing at the original
/// instance's lexer. Confirm before passing instances by value.
struct LangVM
{
	import lang;

	private IdentifierMap idMap;
	private Lexer2 lexer;
	private Parser parser;
	private LangCodeGen codeGen;
	private Module moduleDecl;
	private ModuleSemantics moduleSemantics;
	private NativeFunction[] nativeFunctions;
	/// True only while the last compileModule() succeeded and free() has not
	/// been called since.
	private bool valid;

	/// Creates the identifier map, wires the parser to the lexer member and
	/// prepares the code generator. Must be called before any other method.
	void setup()
	{
		idMap = new IdentifierMap();
		parser = Parser(&lexer, idMap);
		parser.setup();
		codeGen.setup();
	}

	/// Releases the code generator and marks the VM as not runnable.
	void free()
	{
		codeGen.free;
		// run() calls through function pointers produced by codeGen; once it
		// has been released they must not be used, so make run() refuse.
		valid = false;
	}

	/// Makes the native function `fun` callable from compiled code as `name`.
	/// The function is registered with the value 1 as the second
	/// NativeFunction field (presumably its parameter count; verify against
	/// NativeFunction).
	void registerFunction(string name, NativeFunPtr fun)
	{
		nativeFunctions ~= NativeFunction(idMap.getOrReg(name), 1, fun);
	}

	/// Lexes, parses, analyzes and compiles `source`.
	///
	/// A CompilationException is reported on stdout and leaves the VM in the
	/// invalid state instead of propagating.
	void compileModule(string source)
	{
		valid = true;
		try {
			lexer = Lexer2(source);
			moduleDecl = parser.parseModule();
			moduleSemantics = analyzeModule(moduleDecl, idMap, nativeFunctions);
			codeGen.compileModule(moduleSemantics);
		} catch(CompilationException e) {
			writefln("[ERROR] %s: %s", e.loc, e.msg);
			writeln(e);
			valid = false;
		}
	}

	/// Calls the compiled function `funcName` with `args`.
	///
	/// Params:
	///     funcName = name of a function in the compiled module
	///     line     = caller line, forwarded to runtime_error
	///     file     = caller file, forwarded to runtime_error
	///     args     = call arguments; only `int` is accepted
	/// Returns: the function's result; `ResultType` must be `int`.
	/// Throws: the result of runtime_error when the module is not valid, the
	///     name is unknown or not a function, or the argument count differs
	///     from the parameter count.
	ResultType run(ResultType, Args...)(string funcName, int line, string file, Args args)
	{
		if (!valid) throw runtime_error("Cannot start '%s'. Module compiled with errors", funcName, line, file);

		// Only int parameters and an int result are supported for now.
		foreach(arg; Args)
		{
			static assert(is(arg == int), "parameter must be int");
		}
		static assert(is(ResultType == int), "return type must be int");
		alias JittedFunc = extern(C) ResultType function(Args);
		auto id = idMap.find(funcName);
		if (id == Identifier.max) throw runtime_error("Unknown function name '%s'", funcName, line, file);
		auto fun = moduleSemantics.tryGetFunction(id);
		if (fun is null) throw runtime_error("'%s' is not a function name", funcName, line, file);

		// The argument count must match the declaration exactly.
		auto numArgs = Args.length;
		auto numParams = fun.node.parameters.length;
		if (numArgs < numParams)
			throw runtime_error("Insufficient parameters to '%s', got %s, expected %s",
				funcName, numArgs, numParams, line, file);
		else if (numArgs > numParams)
			throw runtime_error("Too much parameters to '%s', got %s, expected %s",
				funcName, numArgs, numParams, line, file);

		JittedFunc func = cast(JittedFunc)fun.funcPtr;
		//assert(false);
		auto result = func(args);
		return result;
	}
}
375 
/// Sample module source in the toy language. testLang() parses, analyzes and
/// compiles it; the disabled runScript call in main() also references it.
/// The text between the braces is the program itself, `//` lines included.
string input = q{
	i32 main() {
		i32 localVar;
		struct nestedStruct {}
		//fn i32 nestedFunc(i32 a) { return a + 1; }
		return localVar;
		//return sub(1, 2, 3, 4, 5, 6); // returns 21 as expected
	}
	//fn i32 sub(i32 a, i32 b, i32 c, i32 d, i32 e, i32 f) {
	//	return a + b + c + d + e + f;
	//}
	i32 globalVar;
	struct structWIP {}
};
390 
/// Benchmarks the `lang` pipeline on `input`: parse, semantic analysis, code
/// generation and execution each run `times` times. Prints the per-iteration
/// timings, the source, its AST, the generated code and the last result.
///
/// A ParsingException or SemanticsException is reported on stdout and ends
/// the benchmark early.
void testLang()
{
	import lang;

	enum times = 10_000;
	auto time0 = currTime;

	auto idMap = new IdentifierMap();
	Lexer2 lexer = Lexer2(input);

	Parser parser = Parser(&lexer, idMap);
	Module moduleDecl;
	try
	{
		foreach (_; 0..times)
		{
			parser.setup();
			lexer = Lexer2(input);
			moduleDecl = parser.parseModule();
		}
	}
	catch(ParsingException e)
	{
		auto loc = e.loc;
		writefln("%s: [ERROR] %s", loc, e.msg);
		return;
	}

	auto time1 = currTime;

	ModuleSemantics moduleSemantics;
	try
	{
		foreach (_; 0..times)
		{
			moduleSemantics = analyzeModule(moduleDecl, idMap);
		}
	}
	catch(SemanticsException e)
	{
		auto loc = e.loc;
		writefln("[ERROR] %s: %s", loc, e.msg);
		return;
	}

	auto time2 = currTime;

	LangCodeGen codeGen;
	// Tracks whether codeGen currently holds state from setup(), so that it
	// is released exactly once per setup() call.
	bool codeGenLive = false;
	scope(exit) if (codeGenLive) codeGen.free;
	foreach (_; 0..times)
	{
		// Elsewhere in this file setup() is paired with one free(). Calling
		// setup() again without releasing first presumably leaks the previous
		// state (TODO confirm against LangCodeGen.setup).
		if (codeGenLive)
		{
			codeGen.free;
			codeGenLive = false;
		}
		codeGen.setup();
		codeGenLive = true;
		codeGen.compileModule(moduleSemantics);
	}

	auto time3 = currTime;

	int res;
	foreach (_; 0..times)
	{
		// functions[0] is taken to be `main`; six ints are passed although
		// the current `input` declares main without parameters.
		alias JittedFunc = extern(C) int function(int, int, int, int, int, int);
		JittedFunc func = cast(JittedFunc)moduleSemantics.functions[0].funcPtr; // main
		res = func(1, 2, 3, 4, 5, 6);
	}

	auto time4 = currTime;

	writefln("Lang: parse %ss, semantics %ss, compile %ss, run %ss",
		scaledNumberFmt(time1 - time0, 1.0/times),
		scaledNumberFmt(time2 - time1, 1.0/times),
		scaledNumberFmt(time3 - time2, 1.0/times),
		scaledNumberFmt(time4 - time3, 1.0/times));
	writeln(input);
	printAST(moduleDecl, idMap);
	printHex(codeGen.code, 16);
	writefln("func() == %s", res);
}
468 
/// Sample source in the toy language defining `isNegative`, which yields 1
/// for numbers below zero and 0 otherwise. Not referenced elsewhere in the
/// visible part of this file.
string input2 = q{
	i32 isNegative(i32 number) {
		i32 result;
		if (number < 0) result = 1;
		else result = 0;
		return result;
	}
};
477 
/// Sample source in the toy language defining `sign`: -1 for negative
/// numbers, 1 for positive numbers, 0 for zero. Compiled by testLang2().
///
/// `0-1` is used to write -1.
/// NOTE(review): the earlier text had no `;` after `else result = 0`. If that
/// was a deliberate parser-error test, keep a separate invalid sample.
string input3 = q{
i32 sign(i32 number) {
	i32 result;
	if (number < 0) result = 0-1;
	else if (number > 0) result = 1;
	else result = 0;
	return result;
}
};
487 
/// Compiles `input3` in a LangVM and, when compilation succeeded, dumps the
/// generated code and function table and prints the result of `sign(10)`.
void testLang2()
{
	LangVM vm;
	vm.setup();
	scope(exit) vm.free;

	writefln("%s", cast(void*)&printNum);
	vm.compileModule(input3);

	// compileModule already reported the error; nothing to run.
	if (!vm.valid) return;

	printHex(vm.codeGen.code, 16);
	writefln("fun table %s", vm.codeGen.functionTable);
	writefln("%s", vm.run!int("sign", __LINE__, __FILE__, 10));
}