qbe

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

fold.c (11332B)


      1 #include "all.h"
      2 
      3 enum {
      4 	Bot = -1, /* lattice bottom */
      5 	Top = 0,  /* lattice top (matches UNDEF) */
      6 };
      7 
      8 typedef struct Edge Edge;
      9 
     10 struct Edge {
     11 	int dest;
     12 	int dead;
     13 	Edge *work;
     14 };
     15 
     16 static int *val;
     17 static Edge *flowrk, (*edge)[2];
     18 static Use **usewrk;
     19 static uint nuse;
     20 
     21 static int
     22 iscon(Con *c, int w, uint64_t k)
     23 {
     24 	if (c->type != CBits)
     25 		return 0;
     26 	if (w)
     27 		return (uint64_t)c->bits.i == k;
     28 	else
     29 		return (uint32_t)c->bits.i == (uint32_t)k;
     30 }
     31 
     32 static int
     33 latval(Ref r)
     34 {
     35 	switch (rtype(r)) {
     36 	case RTmp:
     37 		return val[r.val];
     38 	case RCon:
     39 		return r.val;
     40 	default:
     41 		die("unreachable");
     42 	}
     43 }
     44 
     45 static int
     46 latmerge(int v, int m)
     47 {
     48 	return m == Top ? v : (v == Top || v == m) ? m : Bot;
     49 }
     50 
     51 static void
     52 update(int t, int m, Fn *fn)
     53 {
     54 	Tmp *tmp;
     55 	uint u;
     56 
     57 	m = latmerge(val[t], m);
     58 	if (m != val[t]) {
     59 		tmp = &fn->tmp[t];
     60 		for (u=0; u<tmp->nuse; u++) {
     61 			vgrow(&usewrk, ++nuse);
     62 			usewrk[nuse-1] = &tmp->use[u];
     63 		}
     64 		val[t] = m;
     65 	}
     66 }
     67 
     68 static int
     69 deadedge(int s, int d)
     70 {
     71 	Edge *e;
     72 
     73 	e = edge[s];
     74 	if (e[0].dest == d && !e[0].dead)
     75 		return 0;
     76 	if (e[1].dest == d && !e[1].dead)
     77 		return 0;
     78 	return 1;
     79 }
     80 
     81 static void
     82 visitphi(Phi *p, int n, Fn *fn)
     83 {
     84 	int v;
     85 	uint a;
     86 
     87 	v = Top;
     88 	for (a=0; a<p->narg; a++)
     89 		if (!deadedge(p->blk[a]->id, n))
     90 			v = latmerge(v, latval(p->arg[a]));
     91 	update(p->to.val, v, fn);
     92 }
     93 
     94 static int opfold(int, int, Con *, Con *, Fn *);
     95 
     96 static void
     97 visitins(Ins *i, Fn *fn)
     98 {
     99 	int v, l, r;
    100 
    101 	if (rtype(i->to) != RTmp)
    102 		return;
    103 	if (optab[i->op].canfold) {
    104 		l = latval(i->arg[0]);
    105 		if (!req(i->arg[1], R))
    106 			r = latval(i->arg[1]);
    107 		else
    108 			r = CON_Z.val;
    109 		if (l == Bot || r == Bot)
    110 			v = Bot;
    111 		else if (l == Top || r == Top)
    112 			v = Top;
    113 		else
    114 			v = opfold(i->op, i->cls, &fn->con[l], &fn->con[r], fn);
    115 	} else
    116 		v = Bot;
    117 	/* fprintf(stderr, "\nvisiting %s (%p)", optab[i->op].name, (void *)i); */
    118 	update(i->to.val, v, fn);
    119 }
    120 
    121 static void
    122 visitjmp(Blk *b, int n, Fn *fn)
    123 {
    124 	int l;
    125 
    126 	switch (b->jmp.type) {
    127 	case Jjnz:
    128 		l = latval(b->jmp.arg);
    129 		if (l == Bot) {
    130 			edge[n][1].work = flowrk;
    131 			edge[n][0].work = &edge[n][1];
    132 			flowrk = &edge[n][0];
    133 		}
    134 		else if (iscon(&fn->con[l], 0, 0)) {
    135 			assert(edge[n][0].dead);
    136 			edge[n][1].work = flowrk;
    137 			flowrk = &edge[n][1];
    138 		}
    139 		else {
    140 			assert(edge[n][1].dead);
    141 			edge[n][0].work = flowrk;
    142 			flowrk = &edge[n][0];
    143 		}
    144 		break;
    145 	case Jjmp:
    146 		edge[n][0].work = flowrk;
    147 		flowrk = &edge[n][0];
    148 		break;
    149 	case Jhlt:
    150 		break;
    151 	default:
    152 		if (isret(b->jmp.type))
    153 			break;
    154 		die("unreachable");
    155 	}
    156 }
    157 
    158 static void
    159 initedge(Edge *e, Blk *s)
    160 {
    161 	if (s)
    162 		e->dest = s->id;
    163 	else
    164 		e->dest = -1;
    165 	e->dead = 1;
    166 	e->work = 0;
    167 }
    168 
    169 static int
    170 renref(Ref *r)
    171 {
    172 	int l;
    173 
    174 	if (rtype(*r) == RTmp)
    175 		if ((l=val[r->val]) != Bot) {
    176 			*r = CON(l);
    177 			return 1;
    178 		}
    179 	return 0;
    180 }
    181 
/* Sparse conditional constant propagation (SCCP) over fn.
 *
 * require rpo, use, pred: fn->rpo and the use lists of the temporaries
 * must be up to date on entry; the predecessor lists are presumably
 * needed by edgedel (TODO confirm).
 *
 * Phase 1 computes a lattice value for every temporary and discovers
 * which CFG edges can execute.  Phase 2 deletes the blocks that were
 * never reached, substitutes constants for constant temporaries and
 * turns jnz on a constant condition into a plain jmp. */
void
fold(Fn *fn)
{
	Edge *e, start;
	Use *u;
	Blk *b, **pb;
	Phi *p, **pp;
	Ins *i;
	int t, d;
	uint n, a;

	/* per-run state, released at the end of this function */
	val = emalloc(fn->ntmp * sizeof val[0]);
	edge = emalloc(fn->nblk * sizeof edge[0]);
	usewrk = vnew(0, sizeof usewrk[0], PHeap);

	for (t=0; t<fn->ntmp; t++)
		val[t] = Top;
	for (n=0; n<fn->nblk; n++) {
		b = fn->rpo[n];
		b->visit = 0;
		initedge(&edge[n][0], b->s1);
		initedge(&edge[n][1], b->s2);
	}
	/* a pseudo-edge into the entry block seeds the flow worklist */
	initedge(&start, fn->start);
	flowrk = &start;
	nuse = 0;

	/* 1. find out constants and dead cfg edges */
	/* CFG edges are always drained before the use worklist */
	for (;;) {
		e = flowrk;
		if (e) {
			flowrk = e->work;
			e->work = 0;
			/* absent successor, or edge already known executable */
			if (e->dest == -1 || !e->dead)
				continue;
			e->dead = 0;
			n = e->dest;
			b = fn->rpo[n];
			/* phis see every newly executable incoming edge;
			 * instructions and the jump only the first visit */
			for (p=b->phi; p; p=p->link)
				visitphi(p, n, fn);
			if (b->visit == 0) {
				for (i=b->ins; i<&b->ins[b->nins]; i++)
					visitins(i, fn);
				visitjmp(b, n, fn);
			}
			b->visit++;
			/* the only edge of a jmp is live or was just queued */
			assert(b->jmp.type != Jjmp
				|| !edge[n][0].dead
				|| flowrk == &edge[n][0]);
		}
		else if (nuse) {
			u = usewrk[--nuse];
			n = u->bid;
			b = fn->rpo[n];
			/* unreached blocks get evaluated on their first visit */
			if (b->visit == 0)
				continue;
			switch (u->type) {
			case UPhi:
				visitphi(u->u.phi, u->bid, fn);
				break;
			case UIns:
				visitins(u->u.ins, fn);
				break;
			case UJmp:
				visitjmp(b, n, fn);
				break;
			default:
				die("unreachable");
			}
		}
		else
			break;
	}

	if (debug['F']) {
		fprintf(stderr, "\n> SCCP findings:");
		for (t=Tmp0; t<fn->ntmp; t++) {
			if (val[t] == Bot)
				continue;
			fprintf(stderr, "\n%10s: ", fn->tmp[t].name);
			if (val[t] == Top)
				fprintf(stderr, "Top");
			else
				printref(CON(val[t]), fn, stderr);
		}
		fprintf(stderr, "\n dead code: ");
	}

	/* 2. trim dead code, replace constants */
	/* d records whether any block was deleted (debug output only) */
	d = 0;
	for (pb=&fn->start; (b=*pb);) {
		/* never reached: unlink it and drop its outgoing edges */
		if (b->visit == 0) {
			d = 1;
			if (debug['F'])
				fprintf(stderr, "%s ", b->name);
			edgedel(b, &b->s1);
			edgedel(b, &b->s2);
			*pb = b->link;
			continue;
		}
		/* phis with a non-Bot value are removed (their uses are
		 * rewritten by renref); others get constant arguments on
		 * their executable incoming edges */
		for (pp=&b->phi; (p=*pp);)
			if (val[p->to.val] != Bot)
				*pp = p->link;
			else {
				for (a=0; a<p->narg; a++)
					if (!deadedge(p->blk[a]->id, b->id))
						renref(&p->arg[a]);
				pp = &p->link;
			}
		/* an instruction defining a constant becomes a nop */
		for (i=b->ins; i<&b->ins[b->nins]; i++)
			if (renref(&i->to))
				*i = (Ins){.op = Onop};
			else {
				for (n=0; n<2; n++)
					renref(&i->arg[n]);
				/* storing UNDEF is dropped */
				if (isstore(i->op))
				if (req(i->arg[0], UNDEF))
					*i = (Ins){.op = Onop};
			}
		/* jnz on a constant: keep only the taken successor */
		renref(&b->jmp.arg);
		if (b->jmp.type == Jjnz && rtype(b->jmp.arg) == RCon) {
				if (iscon(&fn->con[b->jmp.arg.val], 0, 0)) {
					edgedel(b, &b->s1);
					b->s1 = b->s2;
					b->s2 = 0;
				} else
					edgedel(b, &b->s2);
				b->jmp.type = Jjmp;
				b->jmp.arg = R;
		}
		pb = &b->link;
	}

	if (debug['F']) {
		if (!d)
			fprintf(stderr, "(none)");
		fprintf(stderr, "\n\n> After constant folding:\n");
		printfn(fn, stderr);
	}

	free(val);
	free(edge);
	vfree(usewrk);
}
    327 
    328 /* boring folding code */
    329 
    330 static int
    331 foldint(Con *res, int op, int w, Con *cl, Con *cr)
    332 {
    333 	union {
    334 		int64_t s;
    335 		uint64_t u;
    336 		float fs;
    337 		double fd;
    338 	} l, r;
    339 	uint64_t x;
    340 	Sym sym;
    341 	int typ;
    342 
    343 	memset(&sym, 0, sizeof sym);
    344 	typ = CBits;
    345 	l.s = cl->bits.i;
    346 	r.s = cr->bits.i;
    347 	if (op == Oadd) {
    348 		if (cl->type == CAddr) {
    349 			if (cr->type == CAddr)
    350 				return 1;
    351 			typ = CAddr;
    352 			sym = cl->sym;
    353 		}
    354 		else if (cr->type == CAddr) {
    355 			typ = CAddr;
    356 			sym = cr->sym;
    357 		}
    358 	}
    359 	else if (op == Osub) {
    360 		if (cl->type == CAddr) {
    361 			if (cr->type != CAddr) {
    362 				typ = CAddr;
    363 				sym = cl->sym;
    364 			} else if (!symeq(cl->sym, cr->sym))
    365 				return 1;
    366 		}
    367 		else if (cr->type == CAddr)
    368 			return 1;
    369 	}
    370 	else if (cl->type == CAddr || cr->type == CAddr)
    371 		return 1;
    372 	if (op == Odiv || op == Orem || op == Oudiv || op == Ourem) {
    373 		if (iscon(cr, w, 0))
    374 			return 1;
    375 		if (op == Odiv || op == Orem) {
    376 			x = w ? INT64_MIN : INT32_MIN;
    377 			if (iscon(cr, w, -1))
    378 			if (iscon(cl, w, x))
    379 				return 1;
    380 		}
    381 	}
    382 	switch (op) {
    383 	case Oadd:  x = l.u + r.u; break;
    384 	case Osub:  x = l.u - r.u; break;
    385 	case Oneg:  x = -l.u; break;
    386 	case Odiv:  x = w ? l.s / r.s : (int32_t)l.s / (int32_t)r.s; break;
    387 	case Orem:  x = w ? l.s % r.s : (int32_t)l.s % (int32_t)r.s; break;
    388 	case Oudiv: x = w ? l.u / r.u : (uint32_t)l.u / (uint32_t)r.u; break;
    389 	case Ourem: x = w ? l.u % r.u : (uint32_t)l.u % (uint32_t)r.u; break;
    390 	case Omul:  x = l.u * r.u; break;
    391 	case Oand:  x = l.u & r.u; break;
    392 	case Oor:   x = l.u | r.u; break;
    393 	case Oxor:  x = l.u ^ r.u; break;
    394 	case Osar:  x = (w ? l.s : (int32_t)l.s) >> (r.u & (31|w<<5)); break;
    395 	case Oshr:  x = (w ? l.u : (uint32_t)l.u) >> (r.u & (31|w<<5)); break;
    396 	case Oshl:  x = l.u << (r.u & (31|w<<5)); break;
    397 	case Oextsb: x = (int8_t)l.u;   break;
    398 	case Oextub: x = (uint8_t)l.u;  break;
    399 	case Oextsh: x = (int16_t)l.u;  break;
    400 	case Oextuh: x = (uint16_t)l.u; break;
    401 	case Oextsw: x = (int32_t)l.u;  break;
    402 	case Oextuw: x = (uint32_t)l.u; break;
    403 	case Ostosi: x = w ? (int64_t)cl->bits.s : (int32_t)cl->bits.s; break;
    404 	case Ostoui: x = w ? (uint64_t)cl->bits.s : (uint32_t)cl->bits.s; break;
    405 	case Odtosi: x = w ? (int64_t)cl->bits.d : (int32_t)cl->bits.d; break;
    406 	case Odtoui: x = w ? (uint64_t)cl->bits.d : (uint32_t)cl->bits.d; break;
    407 	case Ocast:
    408 		x = l.u;
    409 		if (cl->type == CAddr) {
    410 			typ = CAddr;
    411 			sym = cl->sym;
    412 		}
    413 		break;
    414 	default:
    415 		if (Ocmpw <= op && op <= Ocmpl1) {
    416 			if (op <= Ocmpw1) {
    417 				l.u = (int32_t)l.u;
    418 				r.u = (int32_t)r.u;
    419 			} else
    420 				op -= Ocmpl - Ocmpw;
    421 			switch (op - Ocmpw) {
    422 			case Ciule: x = l.u <= r.u; break;
    423 			case Ciult: x = l.u < r.u;  break;
    424 			case Cisle: x = l.s <= r.s; break;
    425 			case Cislt: x = l.s < r.s;  break;
    426 			case Cisgt: x = l.s > r.s;  break;
    427 			case Cisge: x = l.s >= r.s; break;
    428 			case Ciugt: x = l.u > r.u;  break;
    429 			case Ciuge: x = l.u >= r.u; break;
    430 			case Cieq:  x = l.u == r.u; break;
    431 			case Cine:  x = l.u != r.u; break;
    432 			default: die("unreachable");
    433 			}
    434 		}
    435 		else if (Ocmps <= op && op <= Ocmps1) {
    436 			switch (op - Ocmps) {
    437 			case Cfle: x = l.fs <= r.fs; break;
    438 			case Cflt: x = l.fs < r.fs;  break;
    439 			case Cfgt: x = l.fs > r.fs;  break;
    440 			case Cfge: x = l.fs >= r.fs; break;
    441 			case Cfne: x = l.fs != r.fs; break;
    442 			case Cfeq: x = l.fs == r.fs; break;
    443 			case Cfo: x = l.fs < r.fs || l.fs >= r.fs; break;
    444 			case Cfuo: x = !(l.fs < r.fs || l.fs >= r.fs); break;
    445 			default: die("unreachable");
    446 			}
    447 		}
    448 		else if (Ocmpd <= op && op <= Ocmpd1) {
    449 			switch (op - Ocmpd) {
    450 			case Cfle: x = l.fd <= r.fd; break;
    451 			case Cflt: x = l.fd < r.fd;  break;
    452 			case Cfgt: x = l.fd > r.fd;  break;
    453 			case Cfge: x = l.fd >= r.fd; break;
    454 			case Cfne: x = l.fd != r.fd; break;
    455 			case Cfeq: x = l.fd == r.fd; break;
    456 			case Cfo: x = l.fd < r.fd || l.fd >= r.fd; break;
    457 			case Cfuo: x = !(l.fd < r.fd || l.fd >= r.fd); break;
    458 			default: die("unreachable");
    459 			}
    460 		}
    461 		else
    462 			die("unreachable");
    463 	}
    464 	*res = (Con){.type=typ, .sym=sym, .bits={.i=x}};
    465 	return 0;
    466 }
    467 
    468 static void
    469 foldflt(Con *res, int op, int w, Con *cl, Con *cr)
    470 {
    471 	float xs, ls, rs;
    472 	double xd, ld, rd;
    473 
    474 	if (cl->type != CBits || cr->type != CBits)
    475 		err("invalid address operand for '%s'", optab[op].name);
    476 	*res = (Con){.type = CBits};
    477 	memset(&res->bits, 0, sizeof(res->bits));
    478 	if (w)  {
    479 		ld = cl->bits.d;
    480 		rd = cr->bits.d;
    481 		switch (op) {
    482 		case Oadd: xd = ld + rd; break;
    483 		case Osub: xd = ld - rd; break;
    484 		case Oneg: xd = -ld; break;
    485 		case Odiv: xd = ld / rd; break;
    486 		case Omul: xd = ld * rd; break;
    487 		case Oswtof: xd = (int32_t)cl->bits.i; break;
    488 		case Ouwtof: xd = (uint32_t)cl->bits.i; break;
    489 		case Osltof: xd = (int64_t)cl->bits.i; break;
    490 		case Oultof: xd = (uint64_t)cl->bits.i; break;
    491 		case Oexts: xd = cl->bits.s; break;
    492 		case Ocast: xd = ld; break;
    493 		default: die("unreachable");
    494 		}
    495 		res->bits.d = xd;
    496 		res->flt = 2;
    497 	} else {
    498 		ls = cl->bits.s;
    499 		rs = cr->bits.s;
    500 		switch (op) {
    501 		case Oadd: xs = ls + rs; break;
    502 		case Osub: xs = ls - rs; break;
    503 		case Oneg: xs = -ls; break;
    504 		case Odiv: xs = ls / rs; break;
    505 		case Omul: xs = ls * rs; break;
    506 		case Oswtof: xs = (int32_t)cl->bits.i; break;
    507 		case Ouwtof: xs = (uint32_t)cl->bits.i; break;
    508 		case Osltof: xs = (int64_t)cl->bits.i; break;
    509 		case Oultof: xs = (uint64_t)cl->bits.i; break;
    510 		case Otruncd: xs = cl->bits.d; break;
    511 		case Ocast: xs = ls; break;
    512 		default: die("unreachable");
    513 		}
    514 		res->bits.s = xs;
    515 		res->flt = 1;
    516 	}
    517 }
    518 
    519 static int
    520 opfold(int op, int cls, Con *cl, Con *cr, Fn *fn)
    521 {
    522 	Ref r;
    523 	Con c;
    524 
    525 	if (cls == Kw || cls == Kl) {
    526 		if (foldint(&c, op, cls == Kl, cl, cr))
    527 			return Bot;
    528 	} else
    529 		foldflt(&c, op, cls == Kd, cl, cr);
    530 	if (!KWIDE(cls))
    531 		c.bits.i &= 0xffffffff;
    532 	r = newcon(&c, fn);
    533 	assert(!(cls == Ks || cls == Kd) || c.flt);
    534 	return r.val;
    535 }