/*	pbpdemo.cpp

	This is a trivial demo of the reference C++ library
	implementation of the new Parallel Bit Pattern
	computation model.
*/

// This program generates a lot of output to stderr
// if REWAYS has a large value; 8 is sufficient for
// all the demos to run correctly
#define	REWAYS	8

#include "pbp.hpp"

/*	pbitripple()

	Builds a 4-bit adder out of NOT / CNOT / CCNOT calls on
	individual pbit operands, then measures the two registers
	and prints them as integers.  a0 and b0 are the least
	significant bits (see the shifts in the printf below).

	NOTE(review): if the gate list matches the cited circuit,
	b should come out as a+b with a left unchanged -- this has
	not been verified here.
*/
void
pbitripple()
{
	// 4-bit wide ripple-carry adder
	// per Cuccaro et al
	// arXiv:quant-ph/0410184v1

	// a register, every bit constructed from 0
	pbit a0(0), a1(0), a2(0), a3(0);
	// b register; only b0 is constructed from 1, so b
	// presumably starts out as the constant 1 -- confirm
	pbit b0(1), b1(0), b2(0), b3(0);
	// extra working bits; neither one is printed
	pbit z(0), x(0);

	// each a bit gets its own H index 0..3; presumably this
	// makes a range over all 16 values -- confirm in pbp.hpp
	H(a0, 0);
	H(a1, 1);
	H(a2, 2);
	H(a3, 3);

	// gate list for the adder; the statements share operands,
	// so keep them in exactly this order
	CNOT(a1,b1); CNOT(a2,b2);
	CNOT(a3,b3); CNOT(a1,x);
	CCNOT(a0,b0,x); CNOT(a2,a1);
	CCNOT(x,b1,a1); CNOT(a3,a2);
	CCNOT(a1,b2,a2); CNOT(a3,z);
	CCNOT(a2,b3,z); NOT(b1);
	NOT(b2); CNOT(x,b1);
	CNOT(a1,b2); CNOT(a2,b3);
	CCNOT(a1,b2,a2);
	CCNOT(x,b1,a1);
	CNOT(a3,a2); NOT(b2);
	CCNOT(a0,b0,x); CNOT(a2,a1);
	NOT(b1); CNOT(a1,x);
	CNOT(a0,b0); CNOT(a1,b1);
	CNOT(a2,b2); CNOT(a3,b3);

	// NOTE(review): SETMEAS() presumably fixes which sample
	// the MEAS() calls below read from -- confirm in pbp.hpp
	SETMEAS();
	// reassemble each register from its measured bits,
	// bit 0 least significant
	printf("a=%d b=%d\n",
	       MEAS(a0)+(MEAS(a1)<<1)+(MEAS(a2)<<2)+(MEAS(a3)<<3),
	       MEAS(b0)+(MEAS(b1)<<1)+(MEAS(b2)<<2)+(MEAS(b3)<<3));
}

void
pintsqrt(int val)
{
	// Compute square root of val
	pint a(val); // 8-bit number
	pint b = pint(0).Had(4); // 4-bit possible square roots
	pint c = (b * b); // square them
	pint d = (c == a); // which were 169?
	int pos = d.First(); // first non-0 is answer
	printf("Square root of %d is %d\n", val, pos);
}

void
pintfactor(int val)
{
	// Factor val
	pint a(val); // 8-bit number
	pint b = pint(0).Had(4); // 4-bit possible 1st factor
	pint c = pint(0).Had(4,4); // 4-bit possible 2nd factor
	pint d = b * c; // multiply 'em
	pint e = (d == a); // which were 143?
	int spot = e.First(); // factors
	int one = c.Meas(spot);
	int two = b.Meas(spot);
	printf("%d, %d are factors of %d\n", one, two, val);
}

void
pintpi(int bits)
{
	pint intervals(1 << bits); // intervals in pbits
	pint w(1 << (2 * bits)); // int scaling factor
	pint x = pint(0).Had(bits); // all x values
	pint y = pint(0).Had(bits,bits); // all y values
	pint h = w / (((x * x) >> bits) + intervals);
	pint r = (h > y); // r is 1 where below curve
	// count 1s; quantum would sample probability
	double pi = (4.0 * r.Pop()) / (1 << REWAYS);
	printf("Pi is roughly %f\n", pi);
}

/*	main()

	Runs each demo once, in a fixed order, then has re
	(declared outside the code visible here) print its
	statistics.  The command line is not examined.
*/
int
main(int argc, char **argv)
{
	// seed rand() from the clock so that random
	// measurements differ from run to run
	srand(time(0));

	pbitripple();		// 4-bit ripple-carry adder
	pintsqrt(169);		// square root of 169
	pintfactor(143);	// factors of 143
	pintpi(2);		// more bits needs a larger REWAYS

	re.Stats();		// lower-level statistics
	return 0;
}
