/*	val.c

	Word-level value support
*/

#include "ite.h"
#include "val.h"
#include "mesg.h"

val
reduceor(val t)
{
	/* Collapse t to a single bit that is the iteor() of all its bits.
	   The result is a 1-bit unsigned val.

	   The bits are first ordered by increasing POSITE() key, and a bit
	   equal to its immediate predecessor is not or-ed in a second time.
	*/
	register int lo, hi, acc, tmp;

	/* Exchange sort on POSITE() key */
	for (lo = 0; lo < t.bits; ++lo) {
		for (hi = lo + 1; hi < t.bits; ++hi) {
			if (POSITE(t.itev[lo]) > POSITE(t.itev[hi])) {
				tmp = t.itev[lo];
				t.itev[lo] = t.itev[hi];
				t.itev[hi] = tmp;
			}
		}
	}

	/* Fold with iteor, skipping adjacent duplicates */
	acc = t.itev[0];
	for (hi = 1; hi < t.bits; ++hi) {
		if (t.itev[hi] == t.itev[hi - 1]) continue;
		acc = iteor(acc, t.itev[hi]);
	}

	t.itev[0] = acc;
	t.bits = 1;
	t.sign = 0;
	return(t);
}

val
extend(val a,
register int bits)
{
	/* Resize a to exactly bits bits, with bits clipped to MAXBITS.
	   Growing copies the top bit upward when a is signed and pads with
	   ITE0 when it is unsigned; shrinking just drops the high bits.
	*/
	register int n;

	bits = ((bits > MAXBITS) ? MAXBITS : bits);

	/* Runs only when a is narrower than the target width */
	for (n = a.bits; n < bits; ++n) {
		a.itev[n] = (a.sign ? a.itev[n - 1] : ITE0);
	}

	a.bits = bits;
	return(a);
}

val
promote(val a,
val b)
{
	/* Return a widened to a type that covers both a and b.
	   The width is the larger of the two.  When exactly one operand is
	   signed and it is not wider than the unsigned one, one extra bit is
	   added.  The result is signed if either a or b is signed.
	*/
	register int width = ((a.bits < b.bits) ? b.bits : a.bits);

	if (a.sign && !b.sign) {
		if (a.bits <= b.bits) ++width;
	} else if (b.sign && !a.sign) {
		if (b.bits <= a.bits) ++width;
	}

	/* Widen using a's own signedness, then adopt the covering sign */
	a = extend(a, width);
	a.sign = (a.sign || b.sign);
	return(a);
}

val
load(val var)
{
	/* Value propagation: every bit of var whose ites[] entry has an .a
	   field different from ITEX is replaced by that .a ite; other bits
	   are left as they are.
	   NOTE(review): ITEX presumably means "no known value" -- confirm
	   in ite.h.
	*/
	register int n = var.bits;

	while (--n >= 0) {
		if (ites[var.itev[n]].a == ITEX) continue;
		var.itev[n] = ites[var.itev[n]].a;
	}
	return(var);
}

void
store(val var,
val t)
{
	/* Store value t into var bit by bit via mkst(var bit, value bit).
	   t is first resized to var's width with extend(), so it is padded
	   or truncated to fit.  Bits are handled lowest first.
	*/
	val src = extend(t, var.bits);
	register int n;

	for (n = 0; n < var.bits; n++) mkst(var.itev[n], src.itev[n]);
}

void
killstores(val var)
{
	/* Remove info about the previous value stored in var by calling
	   mkkill() on each of its bits, lowest bit first.
	*/
	register int n = 0;

	while (n < var.bits) {
		mkkill(var.itev[n]);
		++n;
	}
}

void
minsize(register val *a)
{
	/* Trim redundant high-order bits of *a in place without changing the
	   value it represents.  For a signed val, a top bit equal to the bit
	   below it is dropped; for an unsigned val, a top bit that is ITE0 is
	   dropped.  Trimming never goes below one bit.
	*/
	while (a->bits > 1) {
		register int top = a->bits - 1;
		register int redundant;

		if (a->sign) {
			redundant = (a->itev[top] == a->itev[top - 1]);
		} else {
			redundant = (a->itev[top] == ITE0);
		}
		if (!redundant) break;
		--(a->bits);
	}
}

val
konst(register int v,
register int sign)
{
	/* Build the val for integer constant v; sign selects a signed val.

	   Bits are produced lowest first until the remaining value is
	   exhausted or MAXBITS bits exist.  A negative v never exhausts,
	   because its sign bit is replicated on every shift.  It therefore
	   yields MAXBITS bits of two's-complement pattern, which minsize()
	   then trims.  A signed, non-negative v gets an explicit ITE0 sign
	   bit on top (room permitting).

	   The shifting is done on an unsigned copy, with the sign replicated
	   by hand, because right-shifting a negative int is
	   implementation-defined in C.
	*/
	register unsigned int rest = (unsigned int) v;
	register unsigned int topbit = ~(~0u >> 1);
	val t;

	t.sign = sign;
	t.bits = 0;
	do {
		t.itev[t.bits] = ((rest & 1u) ? ITE1 : ITE0);
		++(t.bits);
		rest >>= 1;
		if (v < 0) rest |= topbit;	/* arithmetic-shift behaviour */
	} while (rest && (t.bits < MAXBITS));

	if (sign && (v >= 0) && (t.bits < MAXBITS)) {
		/* Positive sign bit: use ITE0 like every other constant-0 bit
		   (this used to be a bare literal 0) */
		t.itev[t.bits] = ITE0;
		++(t.bits);
	}

	minsize(&t);
	return(t);
}

int
konstint(val v)
{
	/* Return the int value of a constant val.

	   Positions beyond v.bits are filled with 1s when v is signed and its
	   top bit is not ITE0; otherwise they are 0.  A bit within v.bits
	   that is neither ITE1 nor ITE0 counts as 0.  Such a bit causes a
	   single "value not a constant" warning.

	   Bits are accumulated in an unsigned int because 1 << 31 on a
	   32-bit int shifts into the sign bit, which is undefined behavior.
	   The final conversion back to int is implementation-defined for
	   values above INT_MAX (two's-complement wrap on common compilers).
	*/
	register unsigned int r = 0;
	register int i;
	register int warned = 0;
	/* Guard bits == 0 so itev[-1] is never read */
	register int fill = (v.sign && (v.bits > 0) &&
			     (v.itev[v.bits-1] != ITE0));

	for (i=0; i<(int)(8*sizeof(int)); ++i) {
		if (i >= v.bits) {
			if (fill) {
				r |= (1u << i);
			}
		} else if (v.itev[i] == ITE1) {
			r |= (1u << i);
		} else if (v.itev[i] != ITE0) {
			warned = 1;
		}
	}
	if (warned) {
		warn("value not a constant");
	}
	return((int) r);
}

val
trinary(val x,
val y,
val z)
{
	/* Word-level select: (reduceor(x) ? y : z).
	   If the reduced selector is the constant ITE0 or ITE1, z or y is
	   returned as-is.  Otherwise y and z are promoted to a common type,
	   each result bit is mkite(selector, y bit, z bit), and the result
	   is trimmed with minsize().
	*/
	val cond = reduceor(x);
	val out;
	register int sel = cond.itev[0];
	register int n;

	if (sel == ITE0) return(z);
	if (sel == ITE1) return(y);

	out = promote(y, z);
	z = promote(z, out);
	for (n = 0; n < out.bits; n++)
		out.itev[n] = mkite(sel, out.itev[n], z.itev[n]);

	minsize(&out);
	return(out);
}


val
binop(int op,
val a,
val b)
{
	register int i, j, c, gt, eq, sign;
	val v, one;

	switch (op) {
	case OROR:
		a = reduceor(a);
		b = reduceor(b);
		a.itev[0] = iteor(a.itev[0], b.itev[0]);
		return(a);
	case ANDAND:
		a = reduceor(a);
		b = reduceor(b);
		a.itev[0] = iteand(a.itev[0], b.itev[0]);
		return(a);
	case '|':
		a = promote(a, b);
		b = promote(b, a);
		for (i=0; i<a.bits; ++i) {
			a.itev[i] = iteor(a.itev[i], b.itev[i]);
		}
		minsize(&a);
		return(a);
	case '^':
		a = promote(a, b);
		b = promote(b, a);
		for (i=0; i<a.bits; ++i) {
			a.itev[i] = itexor(a.itev[i], b.itev[i]);
		}
		minsize(&a);
		return(a);
	case '&':
		a = promote(a, b);
		b = promote(b, a);
		for (i=0; i<a.bits; ++i) {
			a.itev[i] = iteand(a.itev[i], b.itev[i]);
		}
		minsize(&a);
		return(a);
	case EQ:
		a = reduceor(binop('-', a, b));
		a.itev[0] = itenot(a.itev[0]);
		return(a);
	case NE:
		return(reduceor(binop('-', a, b)));
	case '>':
		a = promote(a, b);
		b = promote(b, a);

		if (a.sign) {
			/* Signed > */
			gt = iteand(itenot(a.itev[a.bits-1]), b.itev[b.bits-1]);
			eq = itenot(itexor(a.itev[a.bits-1], b.itev[b.bits-1]));
			sign = iteand(a.itev[a.bits-1], b.itev[b.bits-1]);
		} else {
			/* Unsigned > */
			gt = ITE0;
			eq = ITE1;
		}

		for (i=a.bits-1; i>=0; --i) {
			gt = iteor(gt, iteand(eq, iteand(a.itev[i], itenot(b.itev[i]))));
			eq = iteand(eq, itenot(itexor(a.itev[i], b.itev[i])));
		}

		v.itev[0] = gt;
		v.bits = 1;
		v.sign = 0;
		return(v);
	case '<':
		return(binop('>', b, a));
	case LE:
		a = binop(OROR, binop('<', a, b), binop(EQ, a, b));
		return(a);
	case GE:
		a = binop(OROR, binop('>', a, b), binop(EQ, a, b));
		return(a);
	case MIN:
		a = promote(a, b);
		b = promote(b, a);
		v = binop('<', a, b);
		c = v.itev[0];
		for (i=0; i<a.bits; ++i) {
			a.itev[i] = mkite(c, a.itev[i], b.itev[i]);
		}
		return(a);
	case MAX:
		a = promote(a, b);
		b = promote(b, a);
		v = binop('>', a, b);
		c = v.itev[0];
		for (i=0; i<a.bits; ++i) {
			a.itev[i] = mkite(c, a.itev[i], b.itev[i]);
		}
		return(a);
	case SHR:
		/* The following treats shifts as signed/unsigned,
		   but treats the shift count as always unsigned;
		   i.e., (a >> -5) is not (a << 5)
		*/
		v = a;
		for (i=(b.bits-1); i>=0; --i) {
			for (j=0; j<a.bits; ++j) {
				register int x;

				if ((j + (1 << i)) >= a.bits) {
					x = (a.sign ? a.itev[a.bits-1] : ITE0);
				} else {
					x = v.itev[j + (1 << i)];
				}
				v.itev[j] = mkite(b.itev[i],
						  x,
						  v.itev[j]);
			}
		}
		minsize(&v);
		return(v);
	case SHL:
		v = extend(a, (a.bits + (1 << b.bits) - 1));
		for (i=0; i<b.bits; ++i) {
			for (j=v.bits-1; j>=0; --j) {
				v.itev[j] = mkite(b.itev[i],
						  ((j >= (1 << i)) ?
						   v.itev[j - (1 << i)] :
						   ITE0),
						  v.itev[j]);
			}
		}
		minsize(&v);
		return(v);
	case '+':
		a = promote(a, b);
		a = extend(a, (a.bits + 1));
		b = promote(b, a);

		/* define CARRYSELECT 4 to eneable that method for
		   4-bit chunks...  but simple ripple seems faster
		*/
#ifdef	CARRYSELECT
		if (a.bits > (CARRYSELECT+1)) {
			val la = a, lb = b, hb = b, h0 = a, h1;

			/* Unsigned add CARRYSELECT lowest bits */
			lb.bits = (la.bits = CARRYSELECT);
			lb.sign = (la.sign = 0);
			la = binop('+', la, lb);

			/* Create high bits values */
			for (i=CARRYSELECT; i<a.bits; ++i) {
				h0.itev[i-CARRYSELECT] = h0.itev[i];
				hb.itev[i-CARRYSELECT] = hb.itev[i];
			}
			hb.bits = (h0.bits = (a.bits - CARRYSELECT));
			hb.sign = (h0.sign = 0);
			h1 = h0;

			/* Add high bits both ways */
			c = ITE0;
			for (i=0; i<h0.bits; ++i) {
				register int x = itexor(h0.itev[i], hb.itev[i]);
				register int n = iteand(h0.itev[i], hb.itev[i]);
				h0.itev[i] = itexor(x, c);
				c = iteor(n, iteand(x, c));
			}
			c = ITE1;
			for (i=0; i<h1.bits; ++i) {
				register int x = itexor(h1.itev[i], hb.itev[i]);
				register int n = iteand(h1.itev[i], hb.itev[i]);
				h1.itev[i] = itexor(x, c);
				c = iteor(n, iteand(x, c));
			}

			/* Merge the two high halves */
			lb.bits = 1;
			lb.sign = 0;
			lb.itev[0] = la.itev[CARRYSELECT];
			h0 = trinary(lb, h1, h0);

			/* Reassemble the result */
			for (i=0; i<CARRYSELECT; ++i) {
				a.itev[i] = la.itev[i];
			}
			for (i=CARRYSELECT; i<a.bits; ++i) {
				a.itev[i] = h0.itev[i-CARRYSELECT];
			}

			minsize(&a);
			return(a);
		}
#endif

		c = ITE0;
		for (i=0; i<a.bits; ++i) {
			register int x = itexor(a.itev[i], b.itev[i]);
			register int n = iteand(a.itev[i], b.itev[i]);
			a.itev[i] = itexor(x, c);
			c = iteor(n, iteand(x, c));
		}
		minsize(&a);
		return(a);
	case '-':
		a = promote(a, b);
		a = extend(a, (a.bits + 1));
		b = promote(b, a);
		c = ITE1;
		for (i=0; i<a.bits; ++i) {
			register int n = itenot(b.itev[i]);
			register int x = itexor(a.itev[i], n);
			n = iteand(a.itev[i], n);
			a.itev[i] = itexor(x, c);
			c = iteor(n, iteand(x, c));
		}
		a.sign = 1;
		minsize(&a);
		return(a);
	case FIXMUL:
		/* Make b the smaller value (if possible) */
		if (b.bits > a.bits) {
			v = b;
			b = a;
			a = v;
		}

		v.itev[0] = ITE0;
		v.bits = 1;
		v.sign = 0;
		one.itev[0] = ITE1;
		one.bits = 1;
		one.sign = 0;
		for (i=0; i<b.bits; ++i) {
			val t = a;

			for (j=0; j<a.bits; ++j) {
				t.itev[j] = mkite(b.itev[i],
						  t.itev[j],
						  ITE0);
			}
			v = extend(binop('+', v, t), a.bits);
			a = extend(binop(SHL, a, one), a.bits);
		}
		minsize(&v);
		return(v);
	case '*':
		/* Make b the smaller value (if possible) */
		if (b.bits > a.bits) {
			v = b;
			b = a;
			a = v;
		}

		v.itev[0] = ITE0;
		v.bits = 1;
		v.sign = 0;
		one.itev[0] = ITE1;
		one.bits = 1;
		one.sign = 0;
		for (i=0; i<b.bits; ++i) {
			val t = a;

			for (j=0; j<a.bits; ++j) {
				t.itev[j] = mkite(b.itev[i],
						  t.itev[j],
						  ITE0);
			}
			v = binop('+', v, t);
			a = binop(SHL, a, one);
		}
		minsize(&v);
		return(v);
	case FIXDIV:
	case '/':
		/* By enumeration... */
		a = promote(a, b);
		b = promote(b, a);
		v = konst(0, 0);
		for (i=0; i<(1<<a.bits); ++i) {
			register int si = i;
			val it = konst(i, 0);
			val at = a;
			if (a.sign && (si & (1<<(a.bits-1)))) {
				si |= ~((1<<a.bits)-1);
			}
			at.sign = 0;
			at = binop(EQ, at, it);
			if (at.itev[0] != 0) {
				for (j=1; j<(1<<b.bits); ++j) {
					val jt = konst(j, 0);
					val bt = b;
					bt.sign = 0;
					bt = binop(EQ, bt, jt);
					bt = binop(ANDAND, at, bt);
					if (bt.itev[0] != 0) {
						register int sj = j;

						if (b.sign && (sj & (1<<(b.bits-1)))) {
							sj |= ~((1<<b.bits)-1);
						}
						c = si / sj;
						if (c != 0) {
							v = trinary(bt, binop('|', v, konst(c, 0)), v);
						}
					}
				}
			}
		}
		v = extend(v, a.bits);
		v.sign = (a.sign || b.sign);
		return(v);
	case FIXMOD:
	case '%':
		/* By enumeration... */
		a = promote(a, b);
		b = promote(b, a);
		v = konst(0, 0);
		for (i=0; i<(1<<a.bits); ++i) {
			register int si = i;
			val it = konst(i, 0);
			val at = a;
			if (a.sign && (si & (1<<(a.bits-1)))) {
				si |= ~((1<<a.bits)-1);
			}
			at.sign = 0;
			at = binop(EQ, at, it);
			if (at.itev[0] != 0) {
				for (j=1; j<(1<<b.bits); ++j) {
					val jt = konst(j, 0);
					val bt = b;
					bt.sign = 0;
					bt = binop(EQ, bt, jt);
					bt = binop(ANDAND, at, bt);
					if (bt.itev[0] != 0) {
						register int sj = j;

						if (b.sign && (sj & (1<<(b.bits-1)))) {
							sj |= ~((1<<b.bits)-1);
						}
						c = si % sj;
						if (c != 0) {
							v = trinary(bt, binop('|', v, konst(c, 0)), v);
						}
					}
				}
			}
		}
		v = extend(v, a.bits);
		v.sign = (a.sign || b.sign);
		return(v);
	}

	sprintf(mesg,
		"Unimplemented binary operation %c\n",
		op);
	error(mesg);
}

val
unop(register int op,
val a)
{
	/* Apply unary operator op to word-level value a:
	     '-'  negation, computed as binop('-', 0, a)
	     '!'  a single unsigned bit: itenot of reduceor(a)
	     '~'  itenot of every bit (width and signedness unchanged)
	     '$'  population count: the sum of a's bits
	   An unknown op is reported through error().
	*/
	register int i, j, k;
	val b, v;

	switch (op){
	case '-':
		v.itev[0] = ITE0;
		v.bits = 1;
		v.sign = 0;
		return(binop('-', v, a));
	case '!':
		v = reduceor(a);
		v.itev[0] = itenot(v.itev[0]);
		return(v);
	case '~':
		for (i=0; i<a.bits; ++i) {
			a.itev[i] = itenot(a.itev[i]);
		}
		return(a);
	case '$':
		/* Population count...
		   start by sorting bits so constants come first;
		   this allows the sum of constant bits to be
		   completely folded out, generating no tuples
		*/
		for (i=0; i<a.bits; ++i) {
			for (j=i+1; j<a.bits; ++j) {
				if (POSITE(a.itev[i]) > POSITE(a.itev[j])) {
					k = a.itev[i];
					a.itev[i] = a.itev[j];
					a.itev[j] = k;
				}
			}
		}
		/* Sum the bits one at a time as 1-bit unsigned values */
		b.bits = 1;
		b.sign = 0;
		v.itev[0] = a.itev[0];
		v.bits = 1;
		v.sign = 0;
		for (i=1; i<a.bits; ++i) {
			b.itev[0] = a.itev[i];
			v = binop('+', v, b);
		}
		return(v);
	}

	sprintf(mesg,
		"Unimplemented unary operation %c\n",
		op);
	error(mesg);

	/* NOTE(review): error() is declared elsewhere; if it can return,
	   give the caller the operand unchanged instead of falling off the
	   end of a non-void function (undefined behavior when the value is
	   used). */
	return(a);
}
