카라츠바의 빠른 곱셈 알고리즘은 수백 자리, 수만 자리 되는 큰 두 정수들을 곱하는 알고리즘이다.
두 자연수는 정수형 변수에 저장되는 것이 아니라, 각 자리의 십진수 표기가 배열에 저장된다.
두 정수를 곱한 결과를 계산하는 가장 기본적인 방법은 다음과 같다.
위 그림과 같이 각 정수 배열들은 각 자릿수를 맨 아래 자리부터 저장하고 있다.
이렇게 순서를 반대로 저장하면 a[i]에 주어진 자릿수의 크기를 10^i 로 쉽게 구할 수 있다는 장점이 있다.
따라서 위 그림과 같이 a[i] 와 b[j] 를 곱한 결과를 c[i+j] 에 저장할 수 있게 된다.
이를 코드로 구현하면 다음과 같다.
// 두 큰수를 곱하는 O(n^2) 시간 알고리즘
// 자릿수 올림을 처리하는 함수
void normalize(vector<int>& num) {
num.push_back(0); // 제일 윗자리로 올려질 경우를 대비해 0을 삽입
for (int i = 0; i < num.size(); i++) {
if (num[i] < 0) { // 자릿수가 음수일 때 윗 자리에서 빌려서 계산한다.
int borrow = (abs(num[i]) + 9) / 10;
num[i + 1] -= borrow;
num[i] += borrow * 10;
}
else { // 자릿수의 십의 자리는 위로 올리고 일의자리만 남긴다.
num[i + 1] += num[i] / 10;
num[i] = num[i] % 10;
}
}
if (num.back() == 0) num.pop_back(); // 제일 위로 올려진 값이 없을 때 앞에서 추가한 0을 삭제
}
// 배열 a와 배열 b의 각 자리를 곱해 다른 배열 c에 저장하는 함수
vector<int> multiply(vector<int>& a, vector<int>& b) {
vector<int> c(a.size() + b.size() + 1, 0);
for (int i = 0; i < a.size(); i++) {
for (int j = 0; j < b.size(); j++) {
c[i + j] += a[i] * b[j];
}
}
normalize(c); // 자릿수 올림을 실행한다.
return c;
}
normalize(...)는 매개변수로 받은 벡터에 저장된 수를 자릿수 올림 처리하는 함수이다.
예를 들어 위 그림에서 c[0]에 32가 저장되어 있으면 여기서 30은 십의 자리로 올림해주어야 하기 때문에
32를 10으로 나눈 몫을 c[1]에 더해주고, 나머지를 c[0]에 남겨둔다.
이 알고리즘의 시간 복잡도는 두 정수의 길이가 모두 n 이라고 할 때 O(n^2) 이다.
n번 실행되는 for문이 두번 겹쳐있기 때문이다.
이 알고리즘보다 조금 더 빠른 알고리즘이 바로 카라츠바 알고리즘이다.
카라츠바 알고리즘
카라츠바의 빠른 곱셈 알고리즘은 두 정수를 각각 절반으로 나눈다.
a와 b 모두 256자리 수라고 가정하고 두 수를 반으로 나누면,
a의 앞의 128자리 수는 a1, 뒤의 128자리 수는 a0로 나타낼 수 있고, b의 앞의 128자리 수는 b1, 뒤의 128자리 수는 b0로 나타낼 수 있다.
그렇게 나타낸 식으로 두 수를 곱하면,
큰 수를 1번 곱하는 식을 절반 크기로 나눈 수들을 4번 곱하는 식으로 바꿀 수 있다.
( a1*b1, a0*b1, a1*b0, a0*b0)
10의 거듭제곱은 시프트 연산으로 수행하면 되므로 곱셈 연산에 포함시키지 않도록 한다.
그러면 두 수를 곱하는 데 걸리는 시간은 시프트 연산을 하는데 걸리는 시간 O(n) 시간과 네번의 n/2자리 곱셈 연산 시간이다.
따라서 T(n) = O(n) + 4*T(n/2) 인데,
이를 마스터정리에 의해 정리하면 O(n^2) 의 시간이 걸리게 된다.
그러면 이전의 알고리즘에 비해서 더 나아진게 없게 된다.
알고리즘 수행시간을 줄이기 위해 곱셈의 횟수를 다음과 같이 세번으로 줄이도록한다.
위의 식에 의해 곱셈 연산, 즉 재귀호출을 세번만 실행하여 수행 시간을 줄일 수 있다.
(a0*b0, (a0+a1)*(b0+b1), a1*b1)
이를 코드로 구현하면 다음과 같다.
// 카라츠바의 빠른 곱셈 알고리즘
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
// 자릿수 올림을 처리하는 함수
void normalize(vector<int>& num) {
num.push_back(0); // 제일 윗자리로 올려질 경우를 대비해 0을 삽입
for (int i = 0; i < num.size(); i++) {
if (num[i] < 0) { // 자릿수가 음수일 때 윗 자리에서 빌려서 계산한다.
int borrow = (abs(num[i]) + 9) / 10;
num[i + 1] -= borrow;
num[i] += borrow * 10;
}
else { // 자릿수의 십의 자리는 위로 올리고 일의자리만 남긴다.
num[i + 1] += num[i] / 10;
num[i] = num[i] % 10;
}
}
if (num.back() == 0) num.pop_back(); // 제일 위로 올려진 값이 없을 때 앞에서 추가한 0을 삭제
}
// 배열 a와 배열 b의 각 자리를 곱해 다른 배열 c에 저장하는 함수
vector<int> multiply(vector<int>& a, vector<int>& b) {
vector<int> c(a.size() + b.size() + 1, 0);
for (int i = 0; i < a.size(); i++) {
for (int j = 0; j < b.size(); j++) {
c[i + j] += a[i] * b[j];
}
}
normalize(c); // 자릿수 올림을 실행한다.
return c;
}
// a = a + b * 10^k 식을 수행하는 함수
void addTo(vector<int>& a, vector<int>& b, int k) {
a.resize(max(a.size(), b.size() + k));
for (int i = 0; i < b.size(); i++) {
a[i + k] += b[i]; // a에서 k자리만큼 이동한 곳에 b의 값을 더한다.
}
normalize(a);
}
// a = a - b 식을 수행하는 함수
void subFrom(vector<int>& a, vector<int>& b) {
a.resize(max(a.size(), b.size()) + 1);
for (int i = 0; i < b.size(); i++) {
a[i] -= b[i];
}
normalize(a);
}
// karatsuba 곱셈을 수행하는 함수
vector<int> karatsuba(vector<int>& a, vector<int>& b) {
int an = a.size();
int bn = b.size();
if (an < bn) return karatsuba(b, a); // a보다 b가 자릿수가 더 크면 자리를 바꿔줌
if (an == 0 || bn == 0) return vector<int>(); // a 또는 b가 0이면 서로 곱하면 0이 된다.
if (an <= 50) multiply(a, b); // a와 b의 크기가 일정크기 이하로 작아지면 O(n^2) 곱셈으로 해결한다.
int half = an / 2;
// a와 b를 밑에서 half 자리와 나머지로 분리한다.
vector<int> a0(a.begin(), a.begin() + half);
vector<int> a1(a.begin() + half, a.end());
vector<int> b0(b.begin(), b.begin() + min(b.size(), half)); // a 자릿수의 절반보다 b의 자릿수가 작은경우가 있을 수 있으므로
vector<int> b1(b.begin() + min(b.size(), half), b.end());
vector<int> z2 = karatsuba(a1, b1); // z2 = a1 * b1 // 재귀호출 1번
vector<int> z0 = karatsuba(a0, b0); // z0 = a0 * b0 // 재귀호출 2번
addTo(a0, a1, 0); // a0 = a0 + a1
addTo(b0, b1, 0); // b0 = b0 + b1
// z1 = (a0 + a1) * (b0 + b1) - z0 - z2
vector<int> z1 = karatsuba(a0, b0); // 재귀호출 3번
subFrom(z1, z0);
subFrom(z1, z2);
// ret = z0 + z1 * 10^half + z2 * 10^(half*2)
vector<int> ret;
addTo(ret, z0, 0);
addTo(ret, z1, half);
addTo(ret, z2, half * 2);
return ret;
}
시간 복잡도 분석
카라츠바 알고리즘의 수행시간을 병합 단계와 base case의 두 부분으로 나누어서 볼 수 있다.
병합 단계의 수행시간은 addTo(...)와 subFrom(...) 의 수행시간에 의해 지배되고,
base case의 처리시간은 multiply(...)의 수행시간에 의해 지배된다.
먼저 base case 처리시간에 대해 생각해보자.
위 알고리즘에서는 50자리 이하일 때 multiply(...) 로 곱셈을 수행하지만 편의를 위해 한자리 숫자에 도달해야 multiply(...)로 수행한다고 한다.
그러면 자릿수 n이 2의 거듭제곱 2^k 라고 할 때, 재귀 호출의 깊이는 k가 된다.
한 단계마다 곱셈의 수가 3배씩 늘어나기 때문에 마지막 단계에서는 3^k개의 부분 문제가 있는데,
마지막 단계에서는 두 수 모두 한자리 수 이므로 곱셈 한번이면 충분하다.
따라서 곱셈의 수는 O(3^k) 가 된다.
n = 2^k 이므로 k = logn 이 되고 O(3^k) = O(3^logn) = O(n^log3) 이 된다.
다음 병합단계의 수행시간에 대해 생각해보자.
addTo(...)와 subFrom(...)은 숫자의 길이에 비례하는 시간이 걸린다. 즉, O(n)의 시간이 걸린다는 뜻이다.
그런데 단계가 내려갈 때마다 숫자의 길이는 절반으로 줄어들고, 부분 문제의 개수는 세 배 늘기 때문에,
i번째 단계에서 필요한 연산 수는 (3/2)^i * n 이 된다.
따라서 모든 단계에서 필요한 전체 연산의 수는 다음 식에 비례한다.
n * ∑(3/2)^i ( 0<=i<=logn)
그런데 위 함수는 3^logn보다 훨씬 느리게 증가하므로 무시 가능하다.
따라서 카라츠바의 최종 시간 복잡도는 O(n^log3) 이 된다.
이때 log3 = 1.585....이기 때문에 O(n^2)보다 훨씬 적은 곱셈을 필요로 하고 시간이 줄어든다는 것을 확인할 수 있다.
단, 카라츠바 알고리즘의 구현은 단순한 O(n^2) 알고리즘보다 훨씬 복잡하기 때문에
입력의 크기가 작을 경우 O(n^2) 알고리즘보다 느린 경우가 많게 된다.
따라서 위의 코드처럼 입력된 숫자가 짧을 경우 O(n^2) 알고리즘을 사용하도록 한다.
팬미팅 문제
이 문제는 멤버들과 팬들이 만나는 모든 경우의 수를 다 세기에는 입력값이 커지면 너무 오랜 시간이 걸리게 된다.
따라서 이 문제를 두 큰 수의 곱셈으로 변형시키도록 한다.
두 큰 수의 곱셈에서 multiply(...) 함수는 각 자릿수들의 곱을 계산하고, 각 세로줄의 합 C[i]를 계산한다.
이 때 C[i]를 구하기 위해 더하는 값들이 다음과 같은 형태라는 것을 볼 수 있다.
문제에서 각 멤버들의 성별을 배열 A, 각 팬들의 성별을 배열 B에 저장한다고 보면 된다.
그러면 각 멤버들과 팬들의 성별을 곱하여 더한값을 배열 C에서 얻을 수 있다.
그런데 기존 식은 A와 B의 순서가 반대가 되어서 서로 곱하므로, 곱셈을 하기 전에 A의 원소들을 순서를 앞뒤로 반전시켜주어야 한다.
그러면 B의 숫자들을 왼쪽으로 한칸씩 이동하면서 A의 숫자들과 곱한 결과를 얻을 수 있다.
위 식에서는 C[2] 부터 C[3] 까지 배열 C의 값을 확인하면 된다.
성별이 남성일 때는 1을, 여성일 때는 0을 배열에 저장하면 둘 다 남성일때만 곱셈의 결과가 1이 나오기 때문에
C의 값이 0일 때만 결과값(멤버 전체가 포옹하는 횟수)을 1씩 증가시켜주면 된다.
이를 코드로 구현하면 다음과 같다.
// 카라츠바의 빠른 곱셈을 이용한 팬미팅 문제 함수
int hugs(const string& members, const string& fans){
int n = members.size();
int m = fans.size();
vector<int> A(n), B(m);
for(int i = 0; i < n; i++) A[n-1-i] = (member[i] == 'M'); // 남성일 때는 1, 여성일 때는 0을 저장
for(int i = 0; i < m; i++) B[i] = (fans[i] == 'M');
vector<int> C = karatsuba(A,B);
int allHugs = 0;
for(int i = n - 1; i < m; i++) { // 멤버 수 -1 부터 팬 수 -1 까지
if(C[i] == 0) allHugs++;
}
return allHugs;
}
위의 코드에서 카라추바 알고리즘을 통해 구한 결과 벡터 C를 normalize, 즉 자리올림을 하지 않는 것을 확인할 수 있다.
이는 단순히 각 자릿수를 곱한 결과들을 더한 결과값을 확인하기 위함이다.
즉 C를 20만진수라고 생각하면 편하다.
'DEVELOP > 알고리즘' 카테고리의 다른 글
2. Divide & Conquer - 히스토그램에서 가장 큰 직사각형 문제 (6549) (0) | 2019.01.17 |
---|---|
2. Divide & Conquer - 쿼드트리(quadtree) 문제 (1992) (0) | 2019.01.16 |
2. Divide & Conquer (0) | 2019.01.15 |
1. Brute-Force - TSP 문제 알고리즘 (0) | 2019.01.15 |
1. Brute-Force - 게임판 덮기(Boardcover) 문제 (0) | 2019.01.14 |