백준 25323 - 수 정렬하기, 근데 이제 제곱수를 곁들인

Updated:

Categories:

Tags: , , ,

백준 25323 - 수 정렬하기, 근데 이제 제곱수를 곁들인

풀이

$a \times b$ 가 제곱수이고 $b \times c$가 제곱수이면, 이 2개를 곱한 $a \times b^2 \times c$는 제곱수이다.
근데 $b^2$가 제곱수이므로 $a \times c$는 제곱수이어야 한다.

위 사실을 생각해보면 우리는 서로 곱하면 제곱수가 되는 애들끼리 그룹으로 묶을 수 있다.
그룹 내의 임의의 숫자 2개를 골라서 곱하면 그 곱은 제곱수가 된다.

기존 수열을 a, a를 정렬한 수열을 srt라 하자.
i번째 위치에는 결국 srt[i]의 수가 와야 한다. 즉, a[i]와 srt[i]의 값이 같은 그룹에 있다면 swap가능하다는 의미다.

따라서 모든 i에 대해서 $a[i] \times srt[i]$가 제곱수인지를 확인하면 풀리는 문제다.
하지만 a[i]의 값이 $10^{18}$이기 때문에 곱하면 long long의 범위를 넘어간다. 이를 c++의 int128을 이용해서 해결해줘야 한다.
또한 수의 범위가 크기 때문에 sqrt 함수를 사용하지 못 하는데 이를 이분탐색으로 찾아주었다.

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <bits/stdc++.h>

#define endl "\n"
#define all(v) (v).begin(), (v).end()
#define For_IMPL(condition, i, a, b, increment, decrement) \
    for (int i = (a); condition; (a < b ? increment : decrement))
#define For(i, a, b) For_IMPL((a < b ? i < b : i > b), i, a, b, ++i, --i)
#define For_(i, a, b) For_IMPL((a < b ? i <= b : i >= b), i, a, b, ++i, --i)
#define ft first
#define sd second

using namespace std;
using ll = long long;
using lll = __int128_t;
using ulll = __uint128_t;
using ull = unsigned long long;
using ld = long double;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
using ti3 = tuple<int, int, int>;
using tl3 = tuple<ll, ll, ll>;

template<class T> bool ckmin(T& a, const T& b) { return b < a ? a = b, 1 : 0; }
template<class T> bool ckmax(T& a, const T& b) { return a < b ? a = b, 1 : 0; }

const int INF = 987654321;
const int INF0 = numeric_limits<int>::max();
const ll LNF = 987654321987654321;
const ll LNF0 = numeric_limits<ll>::max();

ulll sqrt(ulll x) {
    ulll lo=1, hi=1e18;
    while(lo < hi) {
        ulll mid = lo + (hi-lo)/2;
        ulll res = mid * mid;
        if(res >= x) hi = mid;
        else lo = mid+1;
    }
    return hi;
}

void solve() {
    int n; cin >> n;
    vector<ulll> a(n), srt(n);
    For(i,0,n) {
        ll x; cin >> x;
        a[i] = x;
        srt[i] = x;
    }
    sort(all(srt));
    For(i,0,n) {
        ulll res = a[i]*srt[i];
        ulll rt = sqrt(res);
        if(rt*rt != res) {
            cout << "NO\n";
            return;
        }
    }
    cout << "YES\n";
}

int main(void) {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int tc = 1;
//    cin >> tc;
    while(tc--) {
        solve();
//        cout << solve() << endl;
    }


    return 0;
}

Leave a comment