/*	fe_optimizer.c

	Front end optimizations.
*/


#include "stdpccts.h"
#include "swartypes.h"
#include "oputils.h"
#include "messages.h"

static int
foldbinary(int op,
int a,
int b)
{
	/* Constant fold operation "op" on the two integers a and b.

	   Comparison and logical operators yield -1 for true and 0 for
	   false.  An op that is not listed here returns a unchanged.

	   DIV and MOD are guarded so that folding can never trap inside
	   the compiler itself: integer division or remainder by zero,
	   and INT_MIN / -1, are undefined behavior in C.  A zero divisor
	   folds to 0 (the source expression has no defined value, so any
	   result is as good as another); a divisor of -1 is folded
	   without executing a division. */

	switch (op) {
	case ADD:	return(a + b);
	case AND:	return(a & b);
	case AVG:	return((a + b) / 2);
	case DIV:
		if (b == 0) return(0);
		/* a / -1 is -a; negate in unsigned arithmetic so that the
		   most negative int cannot overflow */
		if (b == -1) return((int) (0u - (unsigned) a));
		return(a / b);
	case EQ:	return(-(a == b));
	case GE:	return(-(a >= b));
	case GT:	return(-(a > b));
	case LAND:	return(-(a && b));
	case LE:	return(-(a <= b));
	case LOR:	return(-(a || b));
	case LT:	return(-(a < b));
	case MAX:	return((a > b) ? a : b);
	case MIN:	return((a > b) ? b : a);
	case MOD:
		/* a % -1 is always 0; a % 0 is undefined */
		if ((b == 0) || (b == -1)) return(0);
		return(a % b);
	case MUL:	return(a * b);
	case NE:	return(-(a != b));
	case OR:	return(a | b);
	case SHL:	return(a << b);
	case SHR:	return(a >> b);
	case SUB:	return(a - b);
	case XOR:	return(a ^ b);
	}

	return(a);
}

static float
iasf(int i)
{
	/* Reinterpret the bits of an int as a float: the value is stored
	   through the int member of a union and read back through the
	   float member, so no numeric conversion takes place */

	union { int bits; float value; } pun;

	pun.bits = i;
	return(pun.value);
}

static int
fasi(float f)
{
	/* Reinterpret the bits of a float as an int: the value is stored
	   through the float member of a union and read back through the
	   int member, so no numeric conversion takes place */

	union { int bits; float value; } pun;

	pun.value = f;
	return(pun.bits);
}

static float
ffoldbinary(int op,
float a,
float b)
{
	/* Constant fold operation "op" on the two floats a and b.
	   Comparison and logical operators yield -1 for true and 0 for
	   false (converted to float).  Any op without a float rule here
	   is folded by foldbinary() on the operands converted to int,
	   and that result is converted back to float. */

	float result;

	switch (op) {
	case ADD:	result = a + b; break;
	case AVG:	result = (a + b) / 2.0; break;
	case DIV:	result = a / b; break;
	case EQ:	result = -(a == b); break;
	case GE:	result = -(a >= b); break;
	case GT:	result = -(a > b); break;
	case LAND:	result = -(a && b); break;
	case LE:	result = -(a <= b); break;
	case LOR:	result = -(a || b); break;
	case LT:	result = -(a < b); break;
	case MAX:	result = (a > b) ? a : b; break;
	case MIN:	result = (a > b) ? b : a; break;
	case MUL:	result = a * b; break;
	case NE:	result = -(a != b); break;
	case SUB:	result = a - b; break;
	default:
		/* No float rule: fall back on the integer folder */
		result = (float) foldbinary(op, ((int) a), ((int) b));
		break;
	}

	return(result);
}

static int
foldreduce(int op,
int a,
int dim)
{
	/* Constant fold reduction "op" over the dim integers stored in
	   numbuf starting at index a, and return the reduced value */

	register int acc;
	register int k;

	/* Seed the accumulator from the first element; ALL and ANY
	   start from its truth value (0 or 1) */
	if ((op == ALL) || (op == ANY)) {
		acc = (numbuf[a] != 0);
	} else {
		acc = numbuf[a];
	}

	/* Combine the remaining elements into the accumulator */
	for (k = a + 1; k < a + dim; ++k) {
		switch (op) {
		case ALL:
			acc = (acc && numbuf[k]);
			break;
		case ANY:
			acc = (acc || numbuf[k]);
			break;
		case REDUCEAVG:
			/* Summed here, divided after the loop */
		case REDUCEADD:
			acc = (acc + numbuf[k]);
			break;
		case REDUCEAND:
			acc = (acc & numbuf[k]);
			break;
		case REDUCEMAX:
			if (acc < numbuf[k]) acc = numbuf[k];
			break;
		case REDUCEMIN:
			if (acc > numbuf[k]) acc = numbuf[k];
			break;
		case REDUCEMUL:
			acc = (acc * numbuf[k]);
			break;
		case REDUCEOR:
			acc = (acc | numbuf[k]);
			break;
		case REDUCEXOR:
			acc = (acc ^ numbuf[k]);
			break;
		}
	}

	/* The average is the sum divided by the element count */
	if (op == REDUCEAVG) {
		acc = acc / dim;
	}

	return(acc);
}

static int
co_isconst(tree *t)
{
	/* Tell whether tree node t is a constant: returns 1 when its op
	   is NUM or VNUM, 0 for anything else.
	   "t" must not be null */

	switch (t->op) {
	case NUM:
	case VNUM:
		return(1);
	default:
		return(0);
	}
}

tree *
co_fold(tree *t)
{
	/* Fold any constant-valued expressions.

	   The children of t are folded first (recursively); t itself is
	   then folded when its operands turned out to be constants:
	     - binary ops on two NUMs or on two VNUMs,
	     - LNOT, NEG, NOT and CAST of a NUM or a VNUM,
	     - reductions of a NUM or a VNUM,
	     - IF with a NUM condition (replaced by the selected branch,
	       or by an empty statement).
	   A NUM keeps its value in ->num (a float is kept as its bit
	   pattern); a VNUM keeps in ->num the index of its first element
	   in numbuf.

	   Returns the node that takes the place of t.  t is returned
	   untouched when optnofecf is set or when t is null. */
	if (optnofecf) return t;

	if (t) {
		register tree *n = t;

		/* Fold the first child, then each of its right siblings */
		if (t->down) {
			t->down = co_fold(t->down);
			t = t->down;
			while (t->right) {
				t->right = co_fold(t->right);
				if (t->right) t = t->right;
			}
		}
		t = n;

		switch (t->op) {
		case ADD:
		case AND:
		case DIV:
		case EQ:
		case GE:
		case GT:
		case LAND:
		case LE:
		case LOR:
		case LT:
		case MAX:
		case MIN:
		case MOD:
		case MUL:
		case NE:
		case OR:
		case SHL:
		case SHR:
		case SUB:
		case XOR:
		{
			/* Binary operation... */
			register tree *lhs = t->down;
			register tree *rhs = lhs->right;

			if ((lhs->op == NUM) && (rhs->op == NUM)) {
				if ((lhs->type.attr & TYP_FLOAT) ||
				    (rhs->type.attr & TYP_FLOAT)) {
					/* Float stuff: an operand that is
					   already a float is read as its bit
					   pattern, an integer operand is
					   converted.  The operands always
					   stay in source order, which matters
					   for SUB, DIV and the comparisons. */
					register float a;
					register float b;

					if (lhs->type.attr & TYP_FLOAT) {
						a = iasf(lhs->num);
					} else {
						a = (float) lhs->num;
					}
					if (rhs->type.attr & TYP_FLOAT) {
						b = iasf(rhs->num);
					} else {
						b = (float) rhs->num;
					}
					t->num = fasi(ffoldbinary(t->op, a, b));
					t->type = typfloat;
				} else {
					/* Fold simple ints */
					t->num = foldbinary(t->op,
							    lhs->num,
							    rhs->num);
				}
				t->op = NUM;
				t->down = 0;
			} else if ((lhs->op == VNUM) && (rhs->op == VNUM)) {
				/* Fold in vectors, element by element; the
				   results are appended to numbuf */
				register int i;

				t->num = numsp;
				for (i=0; i<(t->type.dim); ++i) {
					numbuf[numsp++] = foldbinary(t->op,
					 numbuf[i + lhs->num],
					 numbuf[i + rhs->num]);
				}
				t->op = VNUM;
				t->down = 0;
			}
			break;
		}
		case LNOT:
		{
			switch ((t->down)->op) {
			case NUM:
			{
				t->num = !((t->down)->num);
				t->op = NUM;
				t->down = 0;
				break;
			}
			case VNUM:
			{
				/* ->num of a VNUM is the index of its first
				   element, so negate the elements, not the
				   index */
				register int i;
				register int src = (t->down)->num;

				t->num = numsp;
				for (i=0; i<(t->type.dim); ++i) {
					numbuf[numsp++] = !numbuf[i + src];
				}
				t->op = VNUM;
				t->down = 0;
				break;
			}
			}
			break;
		}
		case NEG:
		{
			switch ((t->down)->op) {
			case NUM:
			{
				if ((t->down)->type.attr & TYP_FLOAT) {
					/* Negate as a float, keep the bits */
					t->num = fasi(-iasf((t->down)->num));
				} else {
					t->num = -((t->down)->num);
				}
				t->op = NUM;
				t->down = 0;
				break;
			}
			case VNUM:
			{
				/* Negate each element of the operand (its
				   ->num is the index of the first one) */
				register int i;
				register int src = (t->down)->num;
				register int isfloat =
				 (((t->down)->type.attr & TYP_FLOAT) != 0);

				t->num = numsp;
				for (i=0; i<(t->type.dim); ++i) {
					if (isfloat) {
						numbuf[numsp++] =
						 fasi(-iasf(numbuf[i + src]));
					} else {
						numbuf[numsp++] =
						 -numbuf[i + src];
					}
				}
				t->op = VNUM;
				t->down = 0;
				break;
			}
			}
			break;
		}
		case NOT:
		{
			switch ((t->down)->op) {
			case NUM:
			{
				t->num = ~((t->down)->num);
				t->op = NUM;
				t->down = 0;
				break;
			}
			case VNUM:
			{
				/* Complement each element of the operand (its
				   ->num is the index of the first one) */
				register int i;
				register int src = (t->down)->num;

				t->num = numsp;
				for (i=0; i<(t->type.dim); ++i) {
					numbuf[numsp++] = ~numbuf[i + src];
				}
				t->op = VNUM;
				t->down = 0;
				break;
			}
			}
			break;
		}
		case CAST:
		{
			switch ((t->down)->op) {
			case NUM:
			{
				if (t->type.dim > 1) {
					/* Need to widen number: replicate it
					   into every element of a new vector */
					register int i;
					t->num = numsp;
					for (i=0; i<(t->type.dim); ++i) {
						numbuf[numsp++] =
						 (t->down)->num;
					}
					t->op = VNUM;
				} else {
					t->num = (t->down)->num;
					t->op = NUM;
				}
				t->down = 0;
				break;
			}
			case VNUM:
			{
				/* Reuse the operand's elements in place */
				t->num = (t->down)->num;
				t->op = VNUM;
				t->down = 0;
				break;
			}
			}
			break;
		}
		case ALL:
		case ANY:
		case REDUCEADD:
		case REDUCEAND:
		case REDUCEAVG:
		case REDUCEMAX:
		case REDUCEMIN:
		case REDUCEMUL:
		case REDUCEOR:
		case REDUCEXOR:
		{
			/* Reduction operation... */
			if ((t->down)->op == NUM) {
				/* Only one thing to reduce */
				if ((t->op == ANY) ||
				    (t->op == ALL)) {
					t->num = ((t->down)->num != 0);
				} else {
					t->num = (t->down)->num;
				}
				t->op = NUM;
				t->down = 0;
			} else if ((t->down)->op == VNUM) {
				/* Do the vector reduce.
				   NOTE(review): the element count passed is
				   the dim of the reduction node itself, not
				   of its operand -- confirm that this is the
				   intended count */
				t->num = foldreduce(t->op,
						    (t->down)->num,
						    t->type.dim);
				t->op = NUM;
				t->down = 0;
			}
			break;
		}

		/* Even some control flow can fold away... */
		case IF:
		{
			register tree *cond = t->down;

			if (cond->op == NUM) {
				/* Constant condition: whichever node takes
				   the place of the IF must inherit the IF's
				   right sibling, or the nodes that follow it
				   would be lost */
				register tree *thenpart = cond->right;
				register tree *elsepart = thenpart->right;
				register tree *next = t->right;

				if (cond->num == 0) {
					/* Is there an else? */
					if (elsepart) {
						/* Replace with else */
						elsepart->right = next;
						t = elsepart;
					} else {
						/* Replace with ; */
						t = mk_leaf(SEMI,
						 ((sym *) 0),
						 0,
						 typnull);
						if (t) t->right = next;
					}
				} else {
					/* Remove condition and else */
					thenpart->right = next;
					t = thenpart;
				}
			}
			break;
		}

		/* Other stuff does not change... */
		}
	}

	return(t);
}

tree *
bvt_vector(tree *t)
{
	/* Performs BVT as per sect. 4.2 in Fisher&Dietz LCPC99

	   The subtrees of t are transformed first.  t itself (op1) is
	   then rewritten when
	     1. op1 is a binary op with exactly one constant operand; its
		other operand is op2,
	     2. op2 is a binary op with exactly one constant operand and
		is distributive over its other operand, op3,
	     3. op3 is a binary op with at least one constant operand and
		is associative with op1.
	   The shape that was found is recorded in the bits of Case:
	     0x01 / 0x02	first / second operand of op1 is constant
	     0x04 / 0x08	first / second operand of op2 is constant
	     0x10 / 0x20	first / second operand of op3 is constant
				(the "primed" / "unprimed" cases)

	   Returns the node that takes the place of t; t is returned
	   unchanged whenever a condition is not met. */

	register tree *n;		/* Temporary for walking tree */
	register int Case = 0;
	register tree *op1, *op2=NULL, *op3=NULL, *clone;

	/* Don't walk off tree */
	if (!t) return (t);

	/* Apply in depth-first order */
	n = t;
	if (n->down) {
		n->down = bvt_vector(n->down);
		n = n->down;
		while (n->right) {
			n->right = bvt_vector(n->right);
			if (n->right) n = n->right;
		}
	}


	/* Test the conditions that must be met in order to apply BVT */

	op1 = t;
	/* Condition 1a: op1 must be a binary op */
	if (!((op1->down) && (op1->down->right) &&
	      (!op1->down->right->right))) {
		/* op1 doesn't have a binary structure, so ... */
		return(t);
	}

	/* Condition 1b: op1 must have exactly one constant operand */
	info (9, "Testing condition 1b for BVT");
	if (co_isconst(op1->down)) {
		/* Case A or C */
		Case = 0x1;
		op2 = op1->down->right;
	}
	if (co_isconst(op1->down->right)) {
		/* Case B or D */
		Case |= 0x2;
		op2 = op1->down;
	}
	if ((!Case) || (Case==0x3)) {
		/* Either neither of op1's operands is constant,
		   or both are, so ... */
		return(t);
	}

	info (9, "Passed condition 1 for BVT");
	info (9, "Testing condition 2a for BVT");
	/* Condition 2a: op2 must be a binary op and have exactly one
	   constant operand */
	if (!((op2->down) && (op2->down->right) &&
	      (!op2->down->right->right))) {
		/* op2 is not a binary op, so there is no op3 to look at */
		return(t);
	}
	if (co_isconst(op2->down)) {
		/* Case A or B */
		Case |= 0x4;
		op3 = op2->down->right;
	}
	if (co_isconst(op2->down->right)) {
		/* Case C or D */
		Case |= 0x8;
		op3 = op2->down;
	}
	if ( (!(Case&0xC)) || ((Case&0xC)==0xC) ) {
		/* Either neither of op2's operands is constant,
		   or both are, so ... */
		return(t);
	}

	info (9, "Passed condition 2a for BVT");
	info (9, "Testing condition 2b for BVT");
	/* Condition 2b: op2 must be distributive over op3 */
	if ( !distributive(op2->op, op3->op) ) return(t);

	info (9, "Passed condition 2b for BVT");
	info (9, "Testing condition 3a for BVT");
	/* Condition 3a: op3 must be a binary op and have at least one
	   constant operand */
	if (!((op3->down) && (op3->down->right) &&
	      (!op3->down->right->right))) {
		/* op3 is not a binary op, so ... */
		return(t);
	}
	if (co_isconst(op3->down)) {
		/* Primed */
		Case |= 0x10;
	}
	if (co_isconst(op3->down->right)) {
		/* Unprimed */
		Case |= 0x20;
	}
	if ( !(Case&0x30) ) {
		/* Neither of op3's operands is constant, so ... */
		return(t);
	}

	info (9, "Passed condition 3a for BVT");
	info (9, "Testing condition 3b for BVT");
	/* Condition 3b: op3 must be associative with op1 */
	if ( !associative(op1->op, op3->op) ) return(t);


	/* Extra Conditions:
	   Depending on the case, some of the operations may need to
	   be commutative; otherwise, the optimization cannot be done
	   (or I don't know how to do it at this time), and we should
	   just return now.
	*/
	info (9, "Passed condition 3b for BVT");
	info (9, "Testing extra conditions for BVT");
	switch (Case) {
	case 0x25:
	case 0x35:
		/* Case A */
	case 0x16:
		/* Case B' */
	case 0x29:
	case 0x39:
		/* Case C */
	case 0x1A:
		/* Case D' */
		if ( (!commutative(op1->op)) &&
		     (!commutative(op3->op)) ) {
			return (t);
		}
		break;
	case 0x15:
		/* Case A' */
	case 0x26:
	case 0x36:
		/* Case B */
	case 0x19:
		/* Case C' */
	case 0x2A:
	case 0x3A:
		/* Case D */
		break;
	default:
		/* Not a shape we know how to rewrite */
		return(t);
	}


	/* If we are here, all the conditions have been met, so we
	   can perform the optimization... */

	/* I have this broken into the steps as per our LCPC99 paper,
	   but because we do so much restructuring, and step 1 looks
	   different for so many cases, we may as well combine steps
	   1 and 3 if step 2 isn't actually necessary.
	*/

	info (9, "Passed extra conditions for BVT");
	info (9, "Performing BVT step 1");
	/* Step 1: Distribute op2 over op3.  Afterwards op3 is the parent
	   of op2 and clone, each of which applies op2's operation to one
	   of op3's original operands.  The primed and unprimed variants
	   of a case are restructured identically. */
	clone = mk_cloned(op2);
	switch (Case) {
	case 0x25:
	case 0x35:
		/* Case A */
	case 0x15:
		/* Case A' */
		/* op3 becomes (op2, clone) and stays op1's 2nd operand */
		op2->right = clone;
		clone->down = mk_cloned(op2->down);
		op2->down->right = op3->down;
		clone->down->right = op3->down->right;
		op1->down->right = op3;
		op3->down->right = (tree *)0;
		op3->down = op2;
		break;
	case 0x26:
	case 0x36:
		/* Case B */
	case 0x16:
		/* Case B' */
		/* op3 becomes (clone, op2) and replaces op2 as op1's
		   1st operand */
		clone->right = op2;
		clone->down = mk_cloned(op2->down);
		clone->down->right = op3->down;
		op2->down->right = op3->down->right;
		op3->right = op2->right;
		op3->down->right = (tree *)0;
		op3->down = clone;
		op2->right = (tree *)0;
		op1->down = op3;
		break;
	case 0x29:
	case 0x39:
		/* Case C */
	case 0x19:
		/* Case C' */
		/* op3 becomes (op2, clone) and stays op1's 2nd operand */
		op2->right = clone;
		clone->down = op3->down->right;
		clone->down->right = mk_cloned(op3->right);
		op2->down = op3->down;
		op2->down->right = op3->right;
		op3->right = (tree *)0;
		op1->down->right = op3;
		op3->down = op2;
		break;
	case 0x2A:
	case 0x3A:
		/* Case D */
	case 0x1A:
		/* Case D' */
		/* op3 becomes (clone, op2) and replaces op2 as op1's
		   1st operand */
		clone->right = op2;
		clone->down = op3->down;
		op2->down = op3->down->right;
		clone->down->right = mk_cloned(op3->right);
		op2->down->right = op3->right;
		op3->right = op2->right;
		op2->right = (tree *)0;
		op3->down = clone;
		op1->down = op3;
		break;
	}

	/* Step 2: Constant fold tree which replaces constant operand
		   of op2 in the original tree.  Unless we explicitly
		   check for constant values in later steps, this
		   step is unnecessary, and should be removed.
	*/
	info (9, "Performing BVT step 2");
	switch (Case) {
	case 0x25:
	case 0x35:
		/* Case A */
	case 0x16:
		/* Case B' */
	case 0x29:
	case 0x39:
		/* Case C */
	case 0x1A:
		/* Case D' */
		/* Here clone holds the two constants */
		clone = co_fold(clone);
		break;
	case 0x15:
		/* Case A' */
	case 0x26:
	case 0x36:
		/* Case B */
	case 0x19:
		/* Case C' */
	case 0x2A:
	case 0x3A:
		/* Case D */
		/* Here op2 holds the two constants */
		op2 = co_fold(op2);
		break;
	}

	/* Step 3: Combine op1 and op3.  op3 becomes the root; op1 moves
	   below it and pairs its own constant with the constant half of
	   the distributed tree, so that step 4 can fold the two. */
	info (9, "Performing BVT step 3");
	switch (Case) {
	case 0x25:
	case 0x35:
		/* Case A */
	case 0x29:
	case 0x39:
		/* Case C */
		if ( commutative(op1->op) || commutative(op3->op) ) {
			/* op3 = (op2, op1), op1 = (constant, clone) */
			op1->down->right = op2->right;
			op2->right = op1;
			op3->right = op1->right;
			op1->right = (tree *) 0;
			t = op3;
		} else {
			/* This is blocked above for now, but should
			   handle things like op1==op3=="<<" */
		}
		break;
	case 0x15:
		/* Case A' */
	case 0x19:
		/* Case C' */
		/* op3 = (op1, clone), op1 = (constant, op2).  op2 is
		   unlinked from clone here, so clone (the non-constant
		   half) must be re-attached as op1's right sibling. */
		op1->down->right = op3->down;
		op1->down->right->right = (tree *) 0;
		op3->right = op1->right;
		op1->right = clone;
		op3->down = op1;
		t = op3;
		break;
	case 0x26:
	case 0x36:
		/* Case B */
	case 0x2A:
	case 0x3A:
		/* Case D */
		/* op3 = (clone, op1), op1 = (op2, constant).  Step 1 left
		   op2 as clone's right sibling with no sibling of its
		   own, and parked op1's constant in op3->right. */
		op1->down = op2;
		op2->right = op3->right;
		op3->right = op1->right;
		op1->right = (tree *) 0;
		clone->right = op1;
		t = op3;
		break;
	case 0x16:
		/* Case B' */
	case 0x1A:
		/* Case D' */
		if ( commutative(op1->op) || commutative(op3->op) ) {
			/* op3 = (op1, op2), op1 = (clone, constant) */
			op3->down->right = op3->right;
			op3->right = op1->right;
			op1->down = op3->down;
			op1->right = op2;
			op3->down = op1;
			t = op3;
		} else {
			/* This is blocked above for now, but should
			   handle things like op1==op3=="<<" */
		}
		break;
	}

	/* Step 4: Constant fold the tree.  */
	info (9, "Performing BVT step 4");
	return(co_fold(t));
}

