#include <stdio.h>
#include <stdlib.h>
#include <malloc.h>

/*
 * multiply - compute a single entry of the matrix product C = A * B.
 *
 * Stores into C[row][column] the dot product of row `row` of A with
 * column `column` of B (all indices are 0-based):
 *
 *     C[row][column] = sum for k = 0 .. n-1 of A[row][k] * B[k][column]
 *
 * If A is m x n and B is n x p, then `n` is the shared inner dimension:
 * the product is only defined when A's column count equals B's row count.
 *
 *   n       inner dimension (columns of A == rows of B); n <= 0 stores 0
 *   row     row index into A and C
 *   column  column index into B and C
 *   A, B    input matrices, stored as arrays of row pointers (only read)
 *   C       output matrix; only C[row][column] is written
 *
 * The sum is accumulated in int, so very large entries can overflow;
 * callers are responsible for keeping values in range.
 */
void multiply (int n, int row, int column, int** A, int** B, int** C)
{
    int *cell = &C[row][column];
    int k;

    *cell = 0;
    for (k = 0; k < n; k++)
        *cell += A[row][k] * B[k][column];
}


/* Release a matrix created by alloc_matrix. Accepts m == NULL and rows that
 * were never allocated (NULL slots), so it is safe on partially built
 * matrices. */
static void free_matrix(int **m, int rows)
{
    if (m == NULL)
        return;
    for (int r = 0; r < rows; ++r)
        free(m[r]);
    free(m);
}

/* Allocate a zero-filled rows x cols int matrix as an array of row pointers.
 * calloc is used so each count * size product is overflow-checked and so
 * not-yet-allocated row slots are NULL when cleaning up after a failure.
 * Returns NULL, with nothing leaked, if any allocation fails. */
static int **alloc_matrix(int rows, int cols)
{
    int **m = calloc((size_t)rows, sizeof *m);
    if (m == NULL)
        return NULL;
    for (int r = 0; r < rows; ++r) {
        m[r] = calloc((size_t)cols, sizeof *m[r]);
        if (m[r] == NULL) {
            free_matrix(m, rows);
            return NULL;
        }
    }
    return m;
}

/* Read rows x cols integers from stdin, row by row, into m.
 * Returns 0 on success, -1 if input ends or a token is not an integer. */
static int read_matrix(int **m, int rows, int cols)
{
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
            if (scanf("%d", &m[r][c]) != 1)
                return -1;
    return 0;
}

/*
 * Interactive driver: reads the dimensions and entries of matrices A and B
 * from stdin and prints their product C = A * B (rowA x colB).
 *
 * Exit status:
 *   0  success
 *   1  A's column count differs from B's row count
 *   2  a dimension was missing, non-numeric, or not positive
 *   3  memory allocation failed
 *   4  a matrix entry was missing or non-numeric
 */
int main(int argc, char *argv[])
{
    int rowA, colA, rowB, colB;
    int **A = NULL, **B = NULL, **C = NULL;
    int status = 0;

    (void)argc; /* no command-line options are used */
    (void)argv;

    printf ("Enter dimensions of A, row first:\n");
    if (scanf ("%d", &rowA) != 1 || scanf ("%d", &colA) != 1)
    {
        printf ("Invalid Dimensions entered. Try again.\n");
        return 2;
    }
    printf ("Enter dimensions of B, row first:\n");
    if (scanf ("%d", &rowB) != 1 || scanf ("%d", &colB) != 1)
    {
        printf ("Invalid Dimensions entered. Try again.\n");
        return 2;
    }

    /* Compatibility is checked before positivity to keep the original
     * exit-code precedence (1 before 2). */
    if (colA != rowB)
    {
        printf ("A and B are not product compatible. Try again.\n");
        return 1;
    }

    if (rowA <= 0 || colA <= 0 || rowB <= 0 || colB <= 0)
    {
        printf ("Invalid Dimensions entered. Try again.\n");
        return 2;
    }

    /* C has one row per row of A and one column per column of B. */
    A = alloc_matrix(rowA, colA);
    B = alloc_matrix(rowB, colB);
    C = alloc_matrix(rowA, colB);
    if (A == NULL || B == NULL || C == NULL)
    {
        printf ("Out of memory.\n");
        status = 3;
        goto cleanup;
    }

    printf("Enter values of matrix A, row by row\n");
    if (read_matrix(A, rowA, colA) != 0)
    {
        printf ("Invalid matrix values entered. Try again.\n");
        status = 4;
        goto cleanup;
    }

    printf("Enter values of matrix B, row by row\n");
    if (read_matrix(B, rowB, colB) != 0)
    {
        printf ("Invalid matrix values entered. Try again.\n");
        status = 4;
        goto cleanup;
    }

    /* colA (== rowB) is the shared inner dimension of the product. */
    for (int row = 0; row < rowA; ++row)
        for (int column = 0; column < colB; ++column)
            multiply (colA, row, column, A, B, C);

    printf("This is the product matrix: \n");
    for (int i = 0; i < rowA; ++i)
    {
        for (int j = 0; j < colB; ++j)
            printf("%d ", C[i][j]);
        printf("\n");
    }

cleanup:
    /* Frees rows as well as the row-pointer arrays; NULL matrices are
     * accepted, so this is safe on every path above. */
    free_matrix(A, rowA);
    free_matrix(B, rowB);
    free_matrix(C, rowA);

    return status;
}
