학사 나부랭이

Strassen function 본문

塵箱/코드 삽질

Strassen function

태양왕 해킹 (14세) 2021. 4. 4. 01:09
#include <stdio.h>
#include<stdlib.h>
#include <math.h>
/*
 * Release a square matrix stored as an array of n individually allocated rows.
 * Accepts NULL (and NULL rows), so it can clean up partially built matrices.
 * Named free_matrix rather than free: <stdlib.h> already declares
 * free(void *), and redefining it with another signature does not compile.
 */
static void free_matrix(int **mtx, int n)
{
	if (mtx == NULL)
		return;
	for (int i = 0; i < n; i++)
		free(mtx[i]);
	free(mtx);
}

/*
 * Allocate an n x n matrix of zeros as n separately allocated rows.
 * Returns NULL (with nothing leaked) if any allocation fails.
 */
static int **alloc_matrix(int n)
{
	int **mtx = calloc((size_t)n, sizeof *mtx);
	if (mtx == NULL)
		return NULL;
	for (int i = 0; i < n; i++) {
		mtx[i] = calloc((size_t)n, sizeof *mtx[i]);
		if (mtx[i] == NULL) {
			/* Rows not yet allocated are still NULL from calloc. */
			free_matrix(mtx, n);
			return NULL;
		}
	}
	return mtx;
}

/*
 * Element-wise sum (chk != 0) or difference mtx1 - mtx2 (chk == 0) of two
 * rn x rn matrices. Returns a newly allocated rn x rn matrix owned by the
 * caller (rows individually allocated), or NULL on allocation failure.
 */
int **addmtx(int **mtx1, int **mtx2, int rn, int chk)
{
	int **res = alloc_matrix(rn);
	if (res == NULL)
		return NULL;
	for (int i = 0; i < rn; i++)
		for (int j = 0; j < rn; j++)
			res[i][j] = chk ? mtx1[i][j] + mtx2[i][j]
			                : mtx1[i][j] - mtx2[i][j];
	return res;
}

static int **strassen_pow2(int m, int **x, int **y);

/*
 * One Strassen product term: (x1 op x2) * (y1 op y2) for h x h operands.
 * A NULL x2 (or y2) means that side is x1 (or y1) alone; otherwise
 * xadd/yadd select addition (nonzero) or subtraction (zero), as in addmtx.
 * Temporary sums are freed here; returns the product or NULL on failure.
 */
static int **product_term(int h, int **x1, int **x2, int xadd,
                          int **y1, int **y2, int yadd)
{
	int **x = x2 != NULL ? addmtx(x1, x2, h, xadd) : x1;
	int **y = y2 != NULL ? addmtx(y1, y2, h, yadd) : y1;
	int **prod = NULL;

	if (x != NULL && y != NULL)
		prod = strassen_pow2(h, x, y);
	/* Only free what this function allocated; x1/y1 belong to the caller. */
	if (x2 != NULL)
		free_matrix(x, h);
	if (y2 != NULL)
		free_matrix(y, h);
	return prod;
}

/*
 * Strassen multiplication of two m x m matrices, m a power of two.
 * Quadrants are passed down as row-pointer views into the parent matrix
 * (row i of the top-right quadrant is x[i] + h), so quadrant data is not
 * copied. Each level computes the seven half-size products M1..M7 once.
 * Returns a new m x m matrix or NULL on allocation failure.
 */
static int **strassen_pow2(int m, int **x, int **y)
{
	int **res = alloc_matrix(m);
	if (res == NULL)
		return NULL;
	if (m == 1) {
		res[0][0] = x[0][0] * y[0][0];
		return res;
	}

	int h = m / 2;
	/* One allocation holds the row pointers of all eight quadrant views. */
	int **views = malloc(sizeof *views * 8 * (size_t)h);
	if (views == NULL) {
		free_matrix(res, m);
		return NULL;
	}
	int **a11 = views,         **a12 = views + h;
	int **a21 = views + 2 * h, **a22 = views + 3 * h;
	int **b11 = views + 4 * h, **b12 = views + 5 * h;
	int **b21 = views + 6 * h, **b22 = views + 7 * h;
	for (int i = 0; i < h; i++) {
		a11[i] = x[i];
		a12[i] = x[i] + h;
		a21[i] = x[i + h];
		a22[i] = x[i + h] + h;
		b11[i] = y[i];
		b12[i] = y[i] + h;
		b21[i] = y[i + h];
		b22[i] = y[i + h] + h;
	}

	/*
	 * Operand order matters: matrix multiplication is not commutative,
	 * so M3 is A11 * (B12 - B22), not (B12 - B22) * A11 (likewise M4).
	 */
	int **mt[7];
	mt[0] = product_term(h, a11, a22, 1, b11, b22, 1);  /* M1 = (A11 + A22)(B11 + B22) */
	mt[1] = product_term(h, a21, a22, 1, b11, NULL, 1); /* M2 = (A21 + A22) B11 */
	mt[2] = product_term(h, a11, NULL, 1, b12, b22, 0); /* M3 = A11 (B12 - B22) */
	mt[3] = product_term(h, a22, NULL, 1, b21, b11, 0); /* M4 = A22 (B21 - B11) */
	mt[4] = product_term(h, a11, a12, 1, b22, NULL, 1); /* M5 = (A11 + A12) B22 */
	mt[5] = product_term(h, a21, a11, 0, b11, b12, 1);  /* M6 = (A21 - A11)(B11 + B12) */
	mt[6] = product_term(h, a12, a22, 0, b21, b22, 1);  /* M7 = (A12 - A22)(B21 + B22) */

	int ok = 1;
	for (int k = 0; k < 7; k++)
		if (mt[k] == NULL)
			ok = 0;

	if (ok) {
		for (int i = 0; i < h; i++) {
			for (int j = 0; j < h; j++) {
				int m1 = mt[0][i][j], m2 = mt[1][i][j], m3 = mt[2][i][j];
				int m4 = mt[3][i][j], m5 = mt[4][i][j], m6 = mt[5][i][j];
				int m7 = mt[6][i][j];
				res[i][j]         = m1 + m4 - m5 + m7; /* C11 */
				res[i][j + h]     = m3 + m5;           /* C12 */
				res[i + h][j]     = m2 + m4;           /* C21 */
				res[i + h][j + h] = m1 - m2 + m3 + m6; /* C22 */
			}
		}
	}

	for (int k = 0; k < 7; k++)
		free_matrix(mt[k], h);
	free(views);
	if (!ok) {
		free_matrix(res, m);
		return NULL;
	}
	return res;
}

/*
 * Multiply two square int matrices with Strassen's algorithm.
 *
 * n:    total number of elements in each matrix (side length squared),
 *       e.g. 16 for 4x4. Must be a positive perfect square.
 * mtx1: left operand, rn rows of rn ints each (rn = sqrt(n)).
 * mtx2: right operand, same shape.
 *
 * Returns a newly allocated rn x rn product whose rows are individually
 * malloc'd (caller frees each row, then the row array). Returns NULL on
 * allocation failure. Also returns NULL, after printing an error message,
 * if n is not a positive perfect square or an operand is NULL.
 * Inputs are not modified. Sizes that are not a power of two are zero-padded
 * internally. Arithmetic is plain int; products that do not fit in int
 * overflow (undefined behavior).
 */
int **strassen(int n, int **mtx1, int **mtx2)
{
	/* Integer square root of n (the side length), avoiding floating-point sqrt. */
	int rn = 0;
	while (n > 0 && (long long)(rn + 1) * (rn + 1) <= n)
		rn++;
	if (n <= 0 || (long long)rn * rn != n || mtx1 == NULL || mtx2 == NULL) {
		printf("정방행렬X || 같은 크기X");
		return NULL;
	}

	/* Smallest power of two >= rn; the recursion halves the size each level. */
	int p = 1;
	while (p < rn)
		p *= 2;
	if (p == rn)
		return strassen_pow2(rn, mtx1, mtx2);

	/* Zero-pad to p x p (padding does not change the top-left rn x rn block). */
	int **pa = alloc_matrix(p);
	int **pb = alloc_matrix(p);
	int **pc = NULL;
	int **res = NULL;
	if (pa != NULL && pb != NULL) {
		for (int i = 0; i < rn; i++) {
			for (int j = 0; j < rn; j++) {
				pa[i][j] = mtx1[i][j];
				pb[i][j] = mtx2[i][j];
			}
		}
		pc = strassen_pow2(p, pa, pb);
	}
	if (pc != NULL) {
		res = alloc_matrix(rn);
		if (res != NULL)
			for (int i = 0; i < rn; i++)
				for (int j = 0; j < rn; j++)
					res[i][j] = pc[i][j];
	}
	free_matrix(pa, p);
	free_matrix(pb, p);
	free_matrix(pc, p);
	return res;
}
/*
 * Demo: multiply the two 4x4 matrices of Example 2.5 (p. 70) in Neapolitan's
 * "Foundations of Algorithms" with strassen() and print the product.
 * Assumes strassen() returns a matrix whose rows are individually malloc'd,
 * as the row-by-row cleanup below requires. All allocations, including
 * strassen()'s result, are released before exit.
 */
int main(void)
{
	enum { N = 4 };
	int status = EXIT_FAILURE;
	int tmp = 1;
	int **res = NULL;
	/* calloc: rows not yet allocated stay NULL, so cleanup can free them safely. */
	int **a = calloc(N, sizeof *a);
	int **b = calloc(N, sizeof *b);

	if (a == NULL || b == NULL)
		goto cleanup;

	for (int i = 0; i < N; i++) {
		a[i] = malloc(sizeof *a[i] * N);
		b[i] = malloc(sizeof *b[i] * N);
		if (a[i] == NULL || b[i] == NULL)
			goto cleanup;
		/* a counts 1..9 row by row (wrapping); b is a's value + 7, wrapped into 1..9. */
		for (int j = 0; j < N; j++) {
			a[i][j] = tmp;
			b[i][j] = tmp + 7;
			if (b[i][j] > 9)
				b[i][j] -= 9;
			tmp++;
			if (tmp > 9)
				tmp = 1;
		}
	}

	/* strassen() allocates its own result, so res is not preallocated (that leaked). */
	res = strassen(N * N, a, b);
	if (res == NULL)
		goto cleanup;

	printf("\n==================================================================\n");
	for (int i = 0; i < N; i++) {
		for (int j = 0; j < N; j++)
			printf("%d ", res[i][j]);
		printf("\n");
	}
	status = EXIT_SUCCESS;

cleanup:
	for (int i = 0; i < N; i++) {
		if (a != NULL)
			free(a[i]);
		if (b != NULL)
			free(b[i]);
		if (res != NULL)
			free(res[i]);
	}
	free(a);
	free(b);
	free(res);
	return status;
}

결과
올바른 결과

+3 +5 -3 +13
-5 -3 +40 -3
+3 +5 -3 +13
-5 -3 +40 -3

흐아ㅏㅏ아아아가닛비ㅏ 거의 다 온거 같은데ㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠ

'塵箱 > 코드 삽질' 카테고리의 다른 글

백준 15829 - 모듈러 연산  (0) 2021.04.20
백준 10866 - 덱  (0) 2021.04.19
백준 1494 - 다이나믹 프로그래밍, 이진 탐색  (0) 2021.04.17
Floyd algorithm  (0) 2021.04.16
백준 2164 - 큐 구현  (0) 2021.04.16
Comments