💻 백준 17626: Four Squares
- 문제 링크: https://www.acmicpc.net/problem/17626
- 알고리즘 분류: 다이나믹 프로그래밍, 정수론, BFS
문제 소개 🧐
모든 자연수는 넷 혹은 그 이하의 제곱수의 합으로 표현할 수 있다는 라그랑주의 네 제곱수 정리를 바탕으로, 주어진 자연수 n
을 최소 개수의 제곱수 합으로 표현하는 문제입니다.
입력 📝
- 자연수
n
(1 ≤ n ≤ 50,000)
출력 📤
n
을 표현하는 최소 제곱수의 개수를 출력합니다.
제한 ❌
- 시간 제한: 0.5 초
- 메모리 제한: 512 MB
첫 번째 시도: 브루트포스 👊
Idea
가장 큰 제곱수부터 빼나가면서 최소 개수를 찾으려고 시도했다. 예를 들어 n
이 주어지면, n
보다 작거나 같은 가장 큰 제곱수 s^2
를 찾고, n - s^2
에 대해 이 과정을 반복하는 탐욕적인 접근이다.
Code
import sys
import math
def get_cnt_from_integer(x_sq, n_val):
cnt = 0
li = [i*i for i in range(1, math.ceil(math.sqrt(n_val))+1)]
while n_val > 0:
while n_val < x_sq:
x_sq = li.pop()
cnt += 1
n_val -= x_sq
return cnt
# ... (main logic)
Result
- 실패 (57%): 이 방식은 항상 최적의 해를 보장하지 않는다는 것을 깨달았다. 예를 들어 25987은 159^2 + 26^2 + 3^2로 3개의 제곱수 합으로 표현되지만, 가장 큰 제곱수인 161^2 (25921)부터 빼면 더 많은 개수가 필요할 수 있다.
두 번째 시도: BFS ✨
Idea
BFS를 사용하여 최소 개수를 찾는 방법이다. 큐에 (현재 숫자, 사용된 제곱수 개수)
를 저장하고, 각 단계에서 모든 제곱수를 더해보며 n
에 도달하는 가장 빠른 경로(최소 개수)를 찾는 방식이다.
Code
from collections import deque
import math
n = int(sys.stdin.readline().rstrip())
sqrt_list = [i*i for i in range(1, math.ceil(math.sqrt(n))+1)]
counted_num = [False] * 50001
que = deque()
for sqr in sqrt_list:
if sqr == n:
print(1)
# ...
que.append((sqr, 1))
counted_num[sqr] = True
while que:
cur_num, cur_cnt = que.popleft()
for sqr in sqrt_list:
next_num = cur_num + sqr
if next_num == n:
print(cur_cnt + 1)
# ...
if 0 < next_num < 50001 and not counted_num[next_num]:
que.append((next_num, cur_cnt + 1))
counted_num[next_num] = True
Result
- 시간 초과: 모든 경우를 탐색하기 때문에
n
이 커질수록 탐색 공간이 너무 넓어져 시간 내에 해결할 수 없는 것 같다.
세 번째 시도: 다이나믹 프로그래밍 (DP) ✅
Idea
DP를 사용하여 문제를 해결하는 방법이다. dp[i]
를 i
를 만드는 데 필요한 최소 제곱수의 개수라고 정의하고, dp[i] = min(dp[i - j*j]) + 1
이라는 점화식을 세울 수 있다. 여기서 j*j
는 i
보다 작은 모든 제곱수이다.
Code
import sys
import math
n = int(sys.stdin.readline().rstrip())
dp = [5] * (n+1)
for i in range(1, n+1):
if math.sqrt(i).is_integer():
dp[i] = 1
continue
for j in range(1, int(math.sqrt(i)) + 1):
if dp[i] > dp[i - j*j] + 1:
dp[i] = dp[i - j*j] + 1
print(dp[n])
Result
- PyPy3 통과 (시간 초과): Python 3에서는 시간 초과가 발생했지만, PyPy3에서는 통과했다. 시간 복잡도는 O(N√N)으로, 여전히 빠듯한 것 같다.
네 번째 시도: 정수론 (라그랑주 & 르장드르 정리) ✅
Idea
수학적 정리를 활용하여 더 효율적으로 해결하는 방법이다.
- 1개:
n
이 제곱수인 경우 - 2개:
n
이 두 제곱수의 합인 경우 (n = a^2 + b^2
) - 4개:
n
이4^k(8m+7)
꼴인 경우 (르장드르의 세 제곱수 정리) - 3개: 위의 경우에 해당하지 않는 모든 경우
Code
import sys
import math
n = int(sys.stdin.readline().rstrip())
# 1. n이 제곱수인 경우
if int(math.sqrt(n)) ** 2 == n:
print(1)
sys.exit()
# 2. n이 두 제곱수의 합인 경우
for i in range(1, int(math.sqrt(n)) + 1):
if int(math.sqrt(n - i*i)) ** 2 == (n - i*i):
print(2)
sys.exit()
# 4. 4^k(8m+7) 꼴인지 확인
temp_n = n
while temp_n % 4 == 0:
temp_n //= 4
if temp_n % 8 == 7:
print(4)
sys.exit()
# 3. 그 외의 모든 경우
print(3)
Result
- 성공! 수학적 정리를 활용하니 매우 빠르고 효율적으로 문제를 해결할 수 있었다. 이 풀이가 이 문제의 핵심 아이디어인 것 같다.
개선할 부분 🤔
- DP 풀이의 경우, 시간 복잡도를 줄이기 위한 추가적인 최적화 방법을 고민해볼 수 있을 것 같다.
- 정수론 풀이가 가장 효율적이므로, 비슷한 유형의 문제가 나왔을 때 수학적 접근이 가능한지 먼저 생각해보는 습관을 들이면 좋을 것 같다.