/*
 * 	22C:021 Java Structures 
 *	Project 1 - MatrixMultiplication Implementations 
 *	
 * 	By TA: Zhihong Wang (2004-10-25)
 * 
 * 	Methods implemented:
 *	- multiplySparseMatrix(SparseMatrix sm1, SparseMatrix sm2)
 *	- multiplyMatrix(Matrix m1, Matrix m2)
 *	- randomInt(double p)
 *	- createRandomMatrix(int n, double p)
 *	- isIdentical(Matrix m1, SparseMatrix m2)
 */ 

import structure.*;
import java.util.Random;

public class MatrixMultiplication
{	
	/** Shared source of randomness used by randomInt / createRandomMatrix. */
	protected static Random rand = new Random();
	
	/**
	 * Multiply two SparseMatrix objects sm1 and sm2.
	 * Pre (sm1!=null&&sm2!=null) &&(sm1.width()==sm2.height())
	 * Return a SparseMatrix object which is the product of sm1 and sm2.
	 *
	 * The row/column representations are read as follows (this is exactly
	 * how the loop below indexes them):
	 *   sm1.getRow(i): width = number of stored entries in row i;
	 *                  row 0 holds the values, row 1 their column indices.
	 *   sm2.getCol(j): height = number of stored entries in column j;
	 *                  column 0 holds the values, column 1 their row indices.
	 * The merge below assumes both index sequences are in increasing order.
	 * NOTE(review): confirm this ordering guarantee against SparseMatrix.
	 */		
	public static SparseMatrix multiplySparseMatrix(SparseMatrix sm1, SparseMatrix sm2)
	{		
		Assert.pre(sm1!=null&&sm2!=null, "Non-null SparseMatrixReference(s)");
		Assert.pre(sm1.width()==sm2.height(), "Muliplyable matrices");
		
		int rows = sm1.height();
		int cols = sm2.width();
		SparseMatrix sm3 = new SparseMatrix(rows, cols);
		
		// Extract every column of sm2 once up front; previously each column
		// was re-extracted for every row of sm1.
		Matrix[] colMats = new Matrix[cols];
		for (int j = 0; j < cols; j++)
		{
			colMats[j] = sm2.getCol(j);
		}
		
		for (int i = 0; i < rows; i++)
		{
			// Row i of sm1 does not depend on j, so extract it once per row.
			Matrix rowMat = sm1.getRow(i);
			int rowLen = rowMat.width();
			
			for (int j = 0; j < cols; j++)
			{
				Matrix colMat = colMats[j];
				int colLen = colMat.height();
				
				int sum = 0;
				int m = 0;	// position within the stored entries of row i of sm1
				int k = 0;	// position within the stored entries of column j of sm2
				
				// Merge the two index-sorted entry lists: only entries whose
				// column index (in sm1) equals the row index (in sm2) contribute.
				while (m < rowLen && k < colLen)
				{
					int rowValId = ((Integer)rowMat.get(1,m)).intValue();	// column index in sm1
					int colValId = ((Integer)colMat.get(k,1)).intValue();	// row index in sm2
					
					if (rowValId == colValId)
					{
						int rowVal = ((Integer)rowMat.get(0,m)).intValue();
						int colVal = ((Integer)colMat.get(k,0)).intValue();
						sum += rowVal * colVal;
						m++;
						k++;
					}
					else if (rowValId < colValId)
					{
						m++;
					}
					else
					{
						k++;
					}
				}
				
				// Only non-zero results are stored in the sparse product.
				if (sum != 0)
				{
					sm3.set(i, j, Integer.valueOf(sum));
				}
			}
		}
		
		return sm3;
	}
	
	
	/**
	 * Multiply two Matrix objects m1 and m2.
	 * Pre (m1!=null&&m2!=null) &&(m1.width()==m2.height())
	 * Return a Matrix object which is the product of m1 and m2.
	 * Entries of m1/m2 are expected to be Integers; an unset (null) entry
	 * is treated as 0.
	 */	
	public static Matrix multiplyMatrix(Matrix m1, Matrix m2)
	{		
		Assert.pre(m1!=null&&m2!=null, "Non-null Matrix Reference(s)");
		Assert.pre(m1.width()==m2.height(), "Muliplyable matrices");
		
		int rows = m1.height();
		int inner = m1.width();
		int cols = m2.width();
		Matrix m3 = new Matrix(rows, cols);
				
		for (int i = 0; i < rows; i++)
		{
			for (int j = 0; j < cols; j++)
			{
				// Dot product of row i of m1 with column j of m2.
				int sum = 0;
				for (int k = 0; k < inner; k++)
				{
					sum += entryValue(m1.get(i,k)) * entryValue(m2.get(k,j));
				}
				m3.set(i, j, Integer.valueOf(sum));
			}
		}
		
		return m3;
	}
	
	/**
	 * Return the int held by a matrix entry, treating an unset (null)
	 * entry as 0. Non-null entries must be Integers.
	 */
	private static int entryValue(Object entry)
	{
		return (entry == null) ? 0 : ((Integer)entry).intValue();
	}
	
	/**
	 * Generate a random integer which is non-zero with probability p.
	 * Returns 0 with probability 1-p; otherwise a value drawn uniformly
	 * from 1..99.
	 */
	public static int randomInt(double p)
	{
		if (rand.nextDouble() < p)
		{
			// nextInt(99) is uniform on 0..98, so this is uniform on 1..99.
			// (The old approach mapped a 0 from nextInt(100) to 1, which made
			// 1 twice as likely as any other value.)
			return 1 + rand.nextInt(99);
		}
		return 0;
	}
	
	/**
	 * Construct a nxn Matrix, in which each entry is non-zero
	 * with probability p.
	 * Pre n>=0 && 0<= p <= 1
	 * Return the generated Matrix object
	 */		
	public static Matrix createRandomMatrix(int n, double p)
	{		
		Assert.pre(n>=0, "Positive Matrix dimension");
		Assert.pre(p>=0 && p<=1, "Probability between 0.0 and 1.0");
		
		Matrix rand_mat = new Matrix(n,n);
				
		for (int i = 0; i < n; i++)
		{
			for (int j = 0; j < n; j++)
			{
				rand_mat.set(i, j, Integer.valueOf(randomInt(p)));
			}
		}
		
		return rand_mat;
	}	
		
	/**
	 * Check whether Matrix m1 and SparseMatrix m2 hold the same values.
	 * Returns false if their dimensions differ. An unset (null) entry in
	 * either matrix is treated as 0.
	 */
	public static boolean isIdentical(Matrix m1, SparseMatrix m2)
	{
		int h = m1.height();
		int w = m1.width();
		
		// Different sizes can never be identical.
		if (h != m2.height() || w != m2.width())
		{
			return false;
		}
		
		for (int i = 0; i < h; i++)
		{
			for (int j = 0; j < w; j++)
			{
				if (entryValue(m1.get(i,j)) != entryValue(m2.get(i,j)))
				{
					return false;
				}
			}
		}
		
		return true;
	}
	
	
	/**
	 * Benchmark dense vs. sparse multiplication of two random n x n
	 * matrices and check that both products agree.
	 */
	public static void main(String args[])
	{
		Matrix m1, m2, m3;
		SparseMatrix sm1, sm2, sm3;
		
		int n = 300;	// matrix dimension: nxn
		
		// Matrix entry is non-zero with probability p, chosen so each matrix
		// has about 10000 non-zero entries on average. Floating-point division
		// is required: the former integer expression 10000/(n*n) evaluated to
		// 0 for n=300, producing all-zero matrices. Clamped to 1.0 so small n
		// still satisfies createRandomMatrix's p<=1 precondition.
		double p = Math.min(1.0, 10000.0 / ((double) n * n));
		
		long time, newtime;
		long construct_t1, construct_t2;
		long multiply_t1, multiply_t2;
		
		//Timing construction of Matrix objects
		time = System.currentTimeMillis();
		m1 = createRandomMatrix(n,p);
		m2 = createRandomMatrix(n,p);
		newtime = System.currentTimeMillis();
		construct_t1 = newtime - time;
		
		System.out.println("Construction time of two Matrix objects = "+construct_t1+" milliseconds");
		
		//Timing multiplying time of Matrix objects
		time = System.currentTimeMillis();
		m3 = multiplyMatrix(m1,m2);
		newtime = System.currentTimeMillis();
		multiply_t1 = newtime - time;
		
		System.out.println("Mulitplication time of Matrix objects = "+multiply_t1+" milliseconds");
		
		//Timing construction of SparseMatrix objects
		time = System.currentTimeMillis();
		sm1 = new SparseMatrix(m1);
		sm2 = new SparseMatrix(m2);
		newtime = System.currentTimeMillis();
		construct_t2 = newtime - time;
		System.out.println("Construction time of SparseMatrix objects = "+construct_t2+" milliseconds");
		
		//Timing multiplying time of SparseMatrix objects
		time = System.currentTimeMillis();
		sm3 = multiplySparseMatrix(sm1,sm2);
		newtime = System.currentTimeMillis();
		multiply_t2 = newtime - time;
		
		System.out.println("Mulitplication time of SparseMatrix objects = "+multiply_t2+" milliseconds");
		
		String s1 = "are identical.";
		String s2 = "are not identical.";
		System.out.println("The Matrix product and SparseMatrix product "+(isIdentical(m3,sm3)?s1:s2));
	}
		
}