/*	parser.c

	Parser for BitC programs
*/

#include <ctype.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "main.h"
#include "ite.h"
#include "val.h"
#include "fsm.h"
#include "mesg.h"
#include "parser.h"

/*
 * Symbol table entry: one declared variable.  The struct tag avoids a
 * leading underscore, which is reserved for the implementation at file
 * scope.
 */
typedef struct sym_entry {
	char	*text;		/* heap copy of the identifier (see enter()) */
	val	var;		/* the variable's value descriptor */
} sym;

static	int	nextt;			/* current lookahead token */
static	int	nextc = ' ';		/* lookahead char; ' ' makes lexhelp() read first */
static	char	lexbuf[513];		/* text of the most recent word token */
static	int	lexnum;			/* value of the most recent NUM token */
static	sym	symtab[513];		/* symbol table, used as a scope stack */
static	sym	*symptr = &(symtab[0]);	/* next free entry in symtab */

static	val	type(void);

static sym *
enter(s)
register char *s;
{
	extern char *malloc();
	register char *p = malloc(strlen(s) + 1);
	strcpy(p, s);
	symptr->text = p;
	return(symptr++);
}

/*
 * lookup - find the most recently entered symbol named s.
 *
 * Entries are searched newest-first, so a declaration in an inner
 * block shadows an outer one with the same name.  Returns NULL when
 * no entry matches.
 */
static sym *
lookup(const char *s)
{
	register sym *p;

	/*
	 * Test p > symtab before touching p[-1]: decrementing a pointer
	 * below &symtab[0] (as a "p >= symtab" loop does) is undefined.
	 */
	for (p = symptr; p > &(symtab[0]); --p) {
		if (strcmp(p[-1].text, s) == 0) {
			return(p - 1);
		}
	}
	return(NULL);
}

static	int nextlineno = 1;	/* bumped each time a '\n' is read */

/*
 * advance - read the next character of infile into nextc.
 *
 * EOF is sticky: once nextc holds EOF no further reads are made.
 * Returns the new value of nextc.
 */
static int
advance(void)
{
	if (nextc == EOF) {
		return nextc;
	}

	nextc = fgetc(infile);
	if (nextc == '\n') {
		nextlineno++;
	}
	return nextc;
}

/*
 * lexhelp - scan one token from the input.
 *
 * Skips white space, records the token's starting line in lineno and
 * returns the token code.  Words leave their text in lexbuf (keywords
 * return their own token code, anything else WORD); numbers leave
 * their value in lexnum and return NUM.  Single-character operators
 * and punctuation are returned as the character itself; EOF returns
 * MYEOF.  Illegal characters are reported with warn() and skipped.
 */
static int
lexhelp(void)
{
restart:
	/* Eat leading spaces */
	while (isspace(nextc)) advance();

	/* Try to get line numbers exact */
	lineno = nextlineno;

	/* recognize a word/keyword */
	if (isalpha(nextc) || (nextc == '_')) {
		register char *p = &(lexbuf[0]);
		/* Last usable slot: one byte is reserved for the NUL. */
		char *const last = &(lexbuf[sizeof lexbuf - 1]);
		int truncated = 0;

		do {
			if (p < last) {
				*(p++) = nextc;
			} else {
				/* Keep consuming the word but never write past lexbuf. */
				truncated = 1;
			}
			advance();
		} while (isalnum(nextc) || (nextc == '_'));
		*p = '\000';

		if (truncated) {
			sprintf(mesg,
				"identifier truncated to %d characters",
				(int) (sizeof lexbuf - 1));
			warn(mesg);
		}

		if (!strcmp(&(lexbuf[0]), "int")) return(INT);
		if (!strcmp(&(lexbuf[0]), "unsigned")) return(UNSIGNED);
		if (!strcmp(&(lexbuf[0]), "if")) return(IF);
		if (!strcmp(&(lexbuf[0]), "else")) return(ELSE);
		if (!strcmp(&(lexbuf[0]), "while")) return(WHILE);
		if (!strcmp(&(lexbuf[0]), "do")) return(DO);
		if (!strcmp(&(lexbuf[0]), "for")) return(FOR);
		if (!strcmp(&(lexbuf[0]), "goto")) return(GOTO);
		if (!strcmp(&(lexbuf[0]), "break")) return(BREAK);
		if (!strcmp(&(lexbuf[0]), "continue")) return(CONTINUE);
		if (!strcmp(&(lexbuf[0]), "typeof")) return(TYPEOF);
		if (!strcmp(&(lexbuf[0]), "sizeof")) return(SIZEOF);
		return(WORD);
	}

	/* recognize a (decimal) number */
	if (isdigit(nextc)) {
		int overflow = 0;

		lexnum = 0;
		do {
			register int digit = nextc - '0';

			/* Saturate instead of overflowing: signed overflow is UB. */
			if (lexnum > (INT_MAX - digit) / 10) {
				overflow = 1;
				lexnum = INT_MAX;
			} else {
				lexnum = (lexnum * 10) + digit;
			}
			advance();
		} while (isdigit(nextc));

		if (overflow) {
			sprintf(mesg,
				"numeric constant too large; %d used",
				INT_MAX);
			warn(mesg);
		}
		return(NUM);
	}

	/* must be something else...  */
	switch (nextt = nextc) {
	case EOF:
		return(MYEOF);
	case '{':	case '}':	case ';':
	case ':':	case '^':
	case '*':	case '/':
	case '$':	case '~':	case '(':
	case ')':	case '@':	case ',':
	case '[':	case ']':
		advance();
		break;
	case '%':
		advance();
		switch (nextc) {
		case '*':
			advance();
			return(FIXMUL);
		case '/':
			advance();
			return(FIXDIV);
		case '%':
			advance();
			return(FIXMOD);
		}
		break;
	case '?':
		advance();
		switch (nextc) {
		case '<':
			advance();
			return(MIN);
		case '>':
			advance();
			return(MAX);
		}
		break;
	case '|':
		advance();
		if (nextc == '|') {
			advance();
			return(OROR);
		}
		break;
	case '&':
		advance();
		if (nextc == '&') {
			advance();
			return(ANDAND);
		}
		break;
	case '=':
		advance();
		if (nextc == '=') {
			advance();
			return(EQ);
		}
		break;
	case '!':
		advance();
		if (nextc == '=') {
			advance();
			return(NE);
		}
		break;
	case '.':
		advance();
		if (nextc == '.') {
			advance();
			return(DOTDOT);
		}
		break;
	case '+':
		advance();
		if (nextc == '+') {
			advance();
			return(PLUSPLUS);
		}
		break;
	case '-':
		advance();
		if (nextc == '-') {
			advance();
			return(MINUSMINUS);
		}
		break;
	case '<':
		advance();
		switch (nextc) {
		case '=':
			advance();
			return(LE);
		case '<':
			advance();
			return(SHL);
		}
		break;
	case '>':
		advance();
		switch (nextc) {
		case '=':
			advance();
			return(GE);
		case '>':
			advance();
			return(SHR);
		}
		break;
	default:
		sprintf(mesg,
			"illegal character 0x%02x ignored",
			nextc);
		warn(mesg);
		advance();
		/* Loop rather than recurse so long runs of junk can't
		 * exhaust the stack. */
		goto restart;
	}

	return(nextt);
}

/*
 * lex - move the one-token lookahead forward.
 *
 * Stores the next token code in nextt and also returns it.
 */
static inline int
lex(void)
{
	return (nextt = lexhelp());
}


static val expr(void);

static val
e12(void)
{
	register sym *p;
	val l;
	register int i, j, k;

	switch (nextt) {
	case '(':
		lex();
		switch (nextt) {
		case TYPEOF:
		case UNSIGNED:
		case INT:
		case ':':
			/* Type cast */
			{
				val t = type();
				if (nextt != ')') {
					warn("type cast missing ) assumed");
				} else {
					lex();
				}
				l = expr();
				l = extend(l, t.bits);
				l.sign = t.sign;
				break;
			}
		}

		/* A normal parenthesized expression */
		l = expr();
		if (nextt != ')') {
			warn("missing ) assumed");
		} else {
			lex();
		}
		break;
	case '-':
		lex();
		l = (unop('-', e12()));
		break;
	case '!':
		lex();
		l = (unop('!', e12()));
		break;
	case '~':
		lex();
		l = (unop('~', e12()));
		break;
	case '+':
		/* Extension: allow unary to mean cast unsigned */
		lex();
		l = e12();
		l.sign = 0;
		break;
	case SIZEOF:
		/* Variance: sizeof counts bits in a variable */
		lex();
		p = lookup(&(lexbuf[0]));
		lex();
		l = konst(p->var.bits, 0);
		break;
	case WORD:
		p = lookup(&(lexbuf[0]));
		lex();
		l = (load(p->var));
		break;
	case NUM:
		l = konst(lexnum, 0);
		lex();
		break;
	default:
		error("ill-formed expression");
	}

	/* Suffix stuff */
	switch (nextt) {
	case '[':
		/* Extraction of a bit field */
		lex();
		k = (j = konstint(expr()));
		if (nextt == DOTDOT) {
			lex();
			k = konstint(expr());
		}
		if ((j < 0) || (k < j)) {
			error("invalid bitfield extraction specified");
		}
		l = extend(l, (k + 1));
		for (i=0; i<(1+k-j); ++i) {
			l.itev[i] = l.itev[i + j];
		}
		l.bits = (1 + k - j);
		l.sign = 0;
		if (nextt == ']') {
			lex();
		} else {
			warn("missing ']' assumed");
		}
	}

	return(l);
}

/*
 * e11 - multiplicative level: left-associative '*', '/', '%' and the
 * two-character operators %*, %/, %% (FIXMUL, FIXDIV, FIXMOD).
 *
 * Before binop() is applied, a warning is issued when the wider operand
 * exceeds 12 bits (for '*', '/', '%') or 13 bits (for the FIX forms).
 */
static val
e11(void)
{
	val acc = e12();

	while ((nextt == '*') || (nextt == '/') || (nextt == '%') ||
	       (nextt == FIXMUL) || (nextt == FIXDIV) || (nextt == FIXMOD)) {
		register int oper = nextt;
		register int plain = ((oper == '*') || (oper == '/') || (oper == '%'));
		register int widest;
		val rhs;

		lex();
		rhs = e12();
		widest = (acc.bits > rhs.bits) ? acc.bits : rhs.bits;

		if (plain) {
			if (widest > 12) {
				sprintf(mesg,
					"%d-bit %c may be too complex; consider using %%%c",
					widest, oper, oper);
				warn(mesg);
			}
		} else if (widest > 13) {
			sprintf(mesg,
				"%d-bit multiplicative operation may be too complex",
				widest);
			warn(mesg);
		}

		acc = binop(oper, acc, rhs);
	}
	return acc;
}

/*
 * e10 - additive level: left-associative binary '+' and '-'.
 *
 * Warns (without rejecting) when either operand is wider than 12 bits.
 */
static val
e10(void)
{
	val l = e11();

	for (;;) {
		register int op;
		val r;

		switch (nextt) {
		case '+':
		case '-':
			op = nextt;
			lex();
			r = e11();
			if ((l.bits > 12) || (r.bits > 12)) {
				/* Format has one %d and one %c: pass exactly two
				 * arguments (no stray extra op). */
				sprintf(mesg,
					"%d-bit %c may be too complex",
					((l.bits > r.bits) ? l.bits : r.bits),
					op);
				warn(mesg);
			}
			l = binop(op, l, r);
			break;
		default:
			return(l);
		}
	}
}

/*
 * e9 - shift level: left-associative SHL ("<<") and SHR (">>"),
 * operands parsed by e10().
 */
static val
e9(void)
{
	val acc = e10();

	while ((nextt == SHL) || (nextt == SHR)) {
		register int shiftop = nextt;

		lex();
		acc = binop(shiftop, acc, e10());
	}
	return acc;
}

/*
 * e8 - relational level: left-associative '<', '>', LE, GE and the
 * MIN ("?<") / MAX ("?>") operators, operands parsed by e9().
 */
static val
e8(void)
{
	val acc = e9();

	for (;;) {
		register int relop = nextt;

		if ((relop != '<') && (relop != '>') &&
		    (relop != LE) && (relop != GE) &&
		    (relop != MIN) && (relop != MAX)) {
			return acc;
		}
		lex();
		acc = binop(relop, acc, e9());
	}
}

/*
 * e7 - equality level: left-associative EQ ("==") and NE ("!="),
 * operands parsed by e8().
 */
static val
e7(void)
{
	val acc;

	for (acc = e8(); (nextt == EQ) || (nextt == NE); ) {
		register int eqop = nextt;

		lex();
		acc = binop(eqop, acc, e8());
	}
	return acc;
}

/*
 * e6 - bitwise AND level: left-associative '&'.
 *
 * The right operand is parsed at the next-higher precedence level,
 * e7(), the same way the '^' and '|' levels use the level above them.
 * "a & b & c" therefore builds ((a & b) & c) iteratively instead of
 * recursing into a right-nested tree.
 */
static val
e6(void)
{
	val l = e7();

	while (nextt == '&') {
		lex();
		l = binop('&', l, e7());
	}
	return(l);
}

/*
 * e5 - bitwise XOR level: left-associative '^', operands from e6().
 */
static val
e5(void)
{
	val acc;

	for (acc = e6(); nextt == '^'; ) {
		lex();
		acc = binop('^', acc, e6());
	}
	return acc;
}

/*
 * e4 - bitwise OR level: left-associative '|', operands from e5().
 */
static val
e4(void)
{
	val acc = e5();

	while (nextt == '|') {
		val rhs;

		lex();
		rhs = e5();
		acc = binop('|', acc, rhs);
	}
	return acc;
}

/*
 * e3 - logical AND level: left-associative ANDAND ("&&"), operands
 * from e4().
 */
static val
e3(void)
{
	val acc = e4();

	for (;;) {
		if (nextt != ANDAND) {
			break;
		}
		lex();
		acc = binop(ANDAND, acc, e4());
	}
	return acc;
}

/*
 * e2 - logical OR level: left-associative OROR ("||"), operands
 * from e3().
 */
static val
e2(void)
{
	val acc = e3();

	for (;;) {
		if (nextt != OROR) {
			return acc;
		}
		lex();
		acc = binop(OROR, acc, e3());
	}
}

/*
 * e1 - conditional level: "cond ? a : b", built with trinary().
 *
 * As in C, the middle operand may be any expression and the last
 * operand may itself be a conditional.  "a ? b : c ? d : e" therefore
 * groups as "a ? b : (c ? d : e)", and "a ? b ? c : d : e" parses.
 * Parsing those operands with e2() left a trailing '?' unconsumed and
 * stalled the parser.  Expressions without nested conditionals parse
 * exactly as before.
 */
static val
e1(void)
{
	val l = e2();

	if (nextt == '?') {
		val m;

		lex();
		m = e1();
		if (nextt == ':') {
			lex();
		} else {
			warn("trinary missing : assumed");
		}
		return(trinary(l, m, e1()));
	}

	return(l);
}

/*
 * expr - parse one complete expression; the entry point of the
 * precedence chain e1 (lowest) through e12 (primaries).
 */
static val
expr(void)
{
	val result = e1();

	return result;
}

/*
 * type - parse an optional type specifier and return it as a val whose
 * bits and sign fields describe the type.
 *
 * Accepted forms:
 *	typeof expr, typeof ( expr )	the type (whole val) of an expression
 *	[unsigned] [int] [: width]	width defaults to DEFAULTBITS and
 *					must lie in 1..MAXBITS; signed
 *					unless "unsigned" is given
 * With no specifier at all the default signed DEFAULTBITS type results.
 * Non-typeof types wider than 12 bits draw a warning.
 */
static val
type(void)
{
	val t;

	t.bits = DEFAULTBITS;
	t.sign = 1;

	/* Parse type specifier */
	if (nextt == TYPEOF) {
		register int paren;

		lex();
		paren = (nextt == '(');
		if (paren) {
			lex();
		}
		t = expr();
		/*
		 * Only close a paren we opened: in "(typeof x) y" the ')'
		 * belongs to the cast and must be left for the caller.
		 */
		if (paren) {
			if (nextt == ')') {
				lex();
			} else {
				warn("typeof missing ) assumed");
			}
		}
		return(t);
	}

	if (nextt == UNSIGNED) {
		lex();
		t.sign = 0;
	}
	if (nextt == INT) {
		lex();
	}
	if (nextt == ':') {
		lex();
		t.bits = konstint(expr());
		if ((t.bits < 1) ||
		    (t.bits > MAXBITS)) {
			error("invalid bit precision");
		}
	}

	if (t.bits > 12) {
		sprintf(mesg,
			"complicated arithmetic on %d-bit values should be avoided",
			t.bits);
		warn(mesg);
	}

	return(t);
}

static void
decls(void)
{
	/* Process declarations */
	while ((nextt == TYPEOF) ||
	       (nextt == UNSIGNED) ||
	       (nextt == INT) ||
	       (nextt == ':')) {
		val t = type();
		register sym *s;
		register int i;

another:
		if (nextt != WORD) {
			error("declaration missing identifier");
		}
		s = enter(&(lexbuf[0]));
		lex();

		/* Allocate PE memory bits */
		if (nextt == '@') {
			/* Explicit allocation of I/O registers */
			register int num;

			lex();
			num = konstint(expr());
			for (i=0; i<t.bits; ++i) {
				if ((num <= ITE0) || (num >= pemem)) {
					sprintf(mesg,
						"there is no I/O register %d",
						num);
					error(mesg);
				}
				if (ites[num].op == ITEVAR) {
					sprintf(mesg,
						"register %d was previously allocated as a regular register",
						num);
					error(mesg);
				}
				t.itev[i] = num;
				ites[num].op = ITEIO;
				ites[num].a = ITEX;	/* Value unknown */
				ites[num].b = ITEX;	/* No previous store */
				++num;
			}
		} else {
			/* Normal (implicit) allocation of registers */
			--bitp;
			for (i=0; i<t.bits; ++i) {
				/* Skip explicitly allocated I/O registers */
				do {
					++bitp;
					if (bitp >= pemem) {
						sprintf(mesg,
							"not enough registers for variable %s",
							s->text);
						error(mesg);
					}
				} while (ites[bitp].op == ITEIO);

				t.itev[i] = bitp;
				ites[bitp].op = ITEVAR;
				ites[bitp].a = ITEX;	/* Value unknown */
				ites[bitp].b = ITEX;	/* No previous store */
			}
			++bitp;
		}

		s->var = t;

		if (nextt == ',') {
			lex();
			goto another;
		}

		if (nextt == ';') {
			lex();
		} else {
			warn("declaration missing ; assumed");
		}
	}
}

static int
stat(void)
{
	val t;

	switch (nextt) {
	case '{':
		lex();
		{
			register sym *symsave = symptr;
			register int savebitp = bitp;

			decls();

			while (nextt != '}') {
				stat();
			}
			lex();	/* gobble the '}' */

			if (symptr != symsave) {
				while (symptr > symsave) {
					--symptr;
					killstores(symptr->var);
					free((char *) symptr->text);
				}
			}
			bitp = savebitp;
		}
		break;
	case IF:
		{
			register int *thenplace, *elseplace,
				     *thenendplace, *elseendplace;

			lex();
			t = expr();
			t = reduceor(t);
			thenplace = addgo(t.itev[0]);
			elseplace = addgo(itenot(t.itev[0]));
			newblock();
			gohere(thenplace);
			stat();
			thenendplace = addgo(ITE1);
			newblock();
			if (nextt == ELSE) {
				lex();
				gohere(elseplace);
				stat();
				elseendplace = addgo(ITE1);
				newblock();
				gohere(elseendplace);
			} else{
				gohere(elseplace);
			}
			gohere(thenendplace);
		}
		break;
	case WHILE:
		{
			register int *preplace, *whileendplace,
				     *whilebodyplace, whiletop;

			lex();
			preplace = addgo(ITE1);
			newblock();
			gohere(preplace);
			whiletop = statep;
			t = expr();
			t = reduceor(t);
			whileendplace = addgo(itenot(t.itev[0]));
			whilebodyplace = addgo(t.itev[0]);
			newblock();
			gohere(whilebodyplace);
			stat();
			gothere(addgo(ITE1), whiletop);
			newblock();
			gohere(whileendplace);
		}
		break;
	case DO:
		{
			register int dotop, *preplace,
				     *doendplace;

			lex();
			preplace = addgo(ITE1);
			newblock();
			dotop = statep;
			gohere(preplace);
			stat();
			if (nextt == WHILE) {
				lex();
			} else {
				error("do loop missing while clause");
			}
			t = expr();
			t = reduceor(t);
			gothere(addgo(t.itev[0]), dotop);
			doendplace = addgo(itenot(t.itev[0]));
			newblock();
			gohere(doendplace);
			if (nextt == ';') {
				lex();
			} else {
				warn("do while missing ; assumed");
			}
		}
		break;
	case WORD:
		{
			register sym *p = lookup(&(lexbuf[0]));

			lex();
			if (nextt != '=') {
				warn("assignment missing = assumed");
			} else {
				lex();
			}

			t = expr();
			store(p->var, t);
		}
		if (nextt != ';') {
			warn("assignment missing ; assumed");
		} else {
			lex();
		}
		break;
	case ';':
		lex();
		break;
	default:
		return(0);
	}
	return(1);
}

/*
 * prog - compile a whole program read from infile.
 *
 * Calls mkstate(), sets up ITE nodes 0 and 1 (op/a = ITE0 and ITE1
 * respectively, b = ITEX), and marks nodes bitp..pemem-1 as ITENULL.
 * It then parses top-level declarations followed by statements up to
 * end of input, and finally calls newblock(), codegen() and
 * dumpcode(0).
 */
void
prog(void)
{
	register int r;

	mkstate();

	ites[0].op = ITE0;
	ites[0].a = ITE0;
	ites[0].b = ITEX;

	ites[1].op = ITE1;
	ites[1].a = ITE1;
	ites[1].b = ITEX;

	for (r = bitp; r < pemem; r++) {
		ites[r].op = ITENULL;
		ites[r].a = ITEX;
		ites[r].b = ITEX;
	}

	(void) lex();	/* prime the lookahead; lex() stores it in nextt */

	decls();
	while (stat()) {
		continue;
	}
	if (nextt != MYEOF) {
		error("parser stuck before EOF read");
	}

	newblock();
	codegen();
	dumpcode(0);
}
