问题描述
小蓝随手写出了含有 n
个正整数的数组 {a1, a2, …, an}
,他发现可以轻松地算出有多少个有序二元组 (i, j)
满足 aj
是 ai
的一个因数。因此,他定义一个整数对 (x1, y1)
是一个整数对 (x2, y2)
的 因数,当且仅当 x1
和 y1
分别是 x2
和 y2
的因数。
他想知道有多少个有序四元组 (i, j, k, l)
满足 (ai, aj)
是 (ak, al)
的因数,其中 i, j, k, l
互不相等。
输入格式
- 第一行包含一个正整数
n
。 - 第二行包含
n
个正整数a1, a2, …, an
,相邻整数之间使用一个空格分隔。
输出格式
- 输出一行包含一个整数,表示满足条件的
(i, j, k, l)
的个数。
样例输入
5
3 6 2 2 7
样例输出
4
样例说明
有效的四元组 (i, j, k, l)
:
(1, 4, 2, 3)
:(3, 2)
是(6, 2)
的因子。(1, 3, 2, 4)
:(3, 2)
是(6, 2)
的因子。(4, 1, 3, 2)
:(2, 3)
是(2, 6)
的因子。(3, 1, 4, 2)
:(2, 3)
是(2, 6)
的因子。
约束
- 对于 20% 的测试用例,
n ≤ 50
- 对于 40% 的测试用例,
n ≤ 10^4
- 对于 100% 的测试用例,
1 ≤ n ≤ 10^5, 1 ≤ ai ≤ 10^5
c++代码
#include<bits/stdc++.h>
#include<stdio.h>using namespace std;typedef long long ll;vector<ll> myleft, myright, arr;
unordered_map<ll, vector<ll>> mps;
ll maxval, ans, n;int main() {scanf("%lld", &n);arr = vector<ll>(n);myleft = vector<ll>(100001, 0);myright = vector<ll>(100001, 0);for (ll i = 0; i < n; i++) {scanf("%lld", &arr[i]);mps[arr[i]].push_back(i);maxval = max(maxval, arr[i]);}for (ll i = 0; i < n; i++) {if (mps[arr[i]].size() > 1) {ans += mps[arr[i]].size() - 1;myleft[i] += mps[arr[i]].size() - 1;for (int x : mps[arr[i]]) {if (x != i) {myright[x]++;}}}for (int k = arr[i] * 2; k <= maxval; k += arr[i]) {if (mps.find(k) != mps.end()) {ans += mps[k].size();myleft[i] += mps[k].size();for (int x : mps[k]) {myright[x]++;}}}}ans = ans * (ans - 1);for (ll i = 0; i < n; i++) {if (myleft[i] - 1 > 0) ans -= myleft[i] * (myleft[i] - 1);if (myright[i] - 1 > 0) ans -= myright[i] * (myright[i] - 1);if (myleft[i] > 0 && myright[i] > 0) ans -= (myleft[i] * myright[i] * 2);}for (auto it = mps.begin(); it != mps.end(); it++) {if (it->second.size() > 1) ans += it->second.size() * (it->second.size() - 1);}printf("%lld\n", ans);return 0;
}//by wqs
算法解析
本题目用容斥原理
合法方案=总方案-不合法方案
总方案
ans = ans * (ans - 1);
ai = ak
if (myleft[i] - 1 > 0) ans -= myleft[i] * (myleft[i] - 1);
aj = al
if (myright[i] - 1 > 0) ans -= myright[i] * (myright[i] - 1);
ai = al 或者 aj = ak
if (myleft[i] > 0 && myright[i] > 0) ans -= (myleft[i] * myright[i] * 2);
如果ai = al 并且 aj = ak 多减了一次,加回来
for (auto it = mps.begin(); it != mps.end(); it++) {if (it->second.size() > 1) ans += it->second.size() * (it->second.size() - 1);
}