💻 백준 17626: Four Squares

문제 소개 🧐

모든 자연수는 넷 혹은 그 이하의 제곱수의 합으로 표현할 수 있다는 라그랑주의 네 제곱수 정리를 바탕으로, 주어진 자연수 n을 최소 개수의 제곱수 합으로 표현하는 문제입니다.

입력 📝

  • 자연수 n (1 ≤ n ≤ 50,000)

출력 📤

  • n을 표현하는 최소 제곱수의 개수를 출력합니다.

제한 ❌

  • 시간 제한: 0.5 초
  • 메모리 제한: 512 MB

첫 번째 시도: 브루트포스 👊

Idea

가장 큰 제곱수부터 빼나가면서 최소 개수를 찾으려고 시도했다. 예를 들어 n이 주어지면, n보다 작거나 같은 가장 큰 제곱수 s^2를 찾고, n - s^2에 대해 이 과정을 반복하는 탐욕적인 접근이다.

Code

import sys
import math

def get_cnt_from_integer(x_sq, n_val):
    cnt = 0
    li = [i*i for i in range(1, math.ceil(math.sqrt(n_val))+1)]

    while n_val > 0:
        while n_val < x_sq:
            x_sq = li.pop()
        cnt += 1
        n_val -= x_sq

    return cnt

# ... (main logic)

Result

  • 실패 (57%): 이 방식은 항상 최적의 해를 보장하지 않는다는 것을 깨달았다. 예를 들어 25987은 159^2 + 26^2 + 3^2로 3개의 제곱수 합으로 표현되지만, 가장 큰 제곱수인 161^2 (25921)부터 빼면 더 많은 개수가 필요할 수 있다.

두 번째 시도: BFS ✨

Idea

BFS를 사용하여 최소 개수를 찾는 방법이다. 큐에 (현재 숫자, 사용된 제곱수 개수)를 저장하고, 각 단계에서 모든 제곱수를 더해보며 n에 도달하는 가장 빠른 경로(최소 개수)를 찾는 방식이다.

Code

from collections import deque
import math

n = int(sys.stdin.readline().rstrip())
sqrt_list = [i*i for i in range(1, math.ceil(math.sqrt(n))+1)]
counted_num = [False] * 50001
que = deque()

for sqr in sqrt_list:
    if sqr == n:
        print(1)
        # ...
    que.append((sqr, 1))
    counted_num[sqr] = True

while que:
    cur_num, cur_cnt = que.popleft()

    for sqr in sqrt_list:
        next_num = cur_num + sqr
        if next_num == n:
            print(cur_cnt + 1)
            # ...
        if 0 < next_num < 50001 and not counted_num[next_num]:
            que.append((next_num, cur_cnt + 1))
            counted_num[next_num] = True

Result

  • 시간 초과: 모든 경우를 탐색하기 때문에 n이 커질수록 탐색 공간이 너무 넓어져 시간 내에 해결할 수 없는 것 같다.

세 번째 시도: 다이나믹 프로그래밍 (DP) ✅

Idea

DP를 사용하여 문제를 해결하는 방법이다. dp[i]i를 만드는 데 필요한 최소 제곱수의 개수라고 정의하고, dp[i] = min(dp[i - j*j]) + 1 이라는 점화식을 세울 수 있다. 여기서 j*ji보다 작은 모든 제곱수이다.

Code

import sys
import math

n = int(sys.stdin.readline().rstrip())
dp = [5] * (n+1)

for i in range(1, n+1):
    if math.sqrt(i).is_integer():
        dp[i] = 1
        continue
    for j in range(1, int(math.sqrt(i)) + 1):
        if dp[i] > dp[i - j*j] + 1:
            dp[i] = dp[i - j*j] + 1

print(dp[n])

Result

  • PyPy3 통과 (시간 초과): Python 3에서는 시간 초과가 발생했지만, PyPy3에서는 통과했다. 시간 복잡도는 O(N√N)으로, 여전히 빠듯한 것 같다.

네 번째 시도: 정수론 (라그랑주 & 르장드르 정리) ✅

Idea

수학적 정리를 활용하여 더 효율적으로 해결하는 방법이다.

  1. 1개: n이 제곱수인 경우
  2. 2개: n이 두 제곱수의 합인 경우 (n = a^2 + b^2)
  3. 4개: n4^k(8m+7) 꼴인 경우 (르장드르의 세 제곱수 정리)
  4. 3개: 위의 경우에 해당하지 않는 모든 경우

Code

import sys
import math

n = int(sys.stdin.readline().rstrip())

# 1. n이 제곱수인 경우
if int(math.sqrt(n)) ** 2 == n:
    print(1)
    sys.exit()

# 2. n이 두 제곱수의 합인 경우
for i in range(1, int(math.sqrt(n)) + 1):
    if int(math.sqrt(n - i*i)) ** 2 == (n - i*i):
        print(2)
        sys.exit()

# 4. 4^k(8m+7) 꼴인지 확인
temp_n = n
while temp_n % 4 == 0:
    temp_n //= 4
if temp_n % 8 == 7:
    print(4)
    sys.exit()

# 3. 그 외의 모든 경우
print(3)

Result

  • 성공! 수학적 정리를 활용하니 매우 빠르고 효율적으로 문제를 해결할 수 있었다. 이 풀이가 이 문제의 핵심 아이디어인 것 같다.

개선할 부분 🤔

  • DP 풀이의 경우, 시간 복잡도를 줄이기 위한 추가적인 최적화 방법을 고민해볼 수 있을 것 같다.
  • 정수론 풀이가 가장 효율적이므로, 비슷한 유형의 문제가 나왔을 때 수학적 접근이 가능한지 먼저 생각해보는 습관을 들이면 좋을 것 같다.