feat: fetch problem by pwsh script.

Remove old C++ fetcher.
This commit is contained in:
2026-03-11 22:31:31 +08:00
parent b3dba32504
commit 9333e5e144
16 changed files with 367 additions and 837 deletions

View File

@@ -1,163 +1 @@
# public, protected, private 修饰符对齐
# AccessModifierOffset: 2
# 长函数调用时,参数对齐, 括号形式
# someLongFunction(
# argument1, argument2
# )
#
AlignAfterOpenBracket: BlockIndent
AlignArrayOfStructures: None
# 连续的赋值语句对齐
AlignConsecutiveAssignments:
Enabled: true
AcrossEmptyLines: false
AcrossComments: false
AlignCompound: true
PadOperators: true
#AlignConsecutiveBitFields: false
AlignConsecutiveDeclarations: None
AlignConsecutiveMacros: None
AlignEscapedNewlines: Right
# Align
# x = aaaaaaaa +
# bbbbbbbb
#
# when BreakBeforeBinaryOperators is set
#
# x = aaaaaaaa +
# bbbbbbbb
#
# AlignAfterOperator
# x = aaaaaaaa
# + bbbbbbbb
#AlignOperands: AlignAfterOperator
AlignTrailingComments:
Kind: Always
OverEmptyLines: 2
# true:
# callFunction(
# a, b, c, d);
#
# false:
# callFunction(a,
# b,
# c,
# d);
AllowAllArgumentsOnNextLine: false
AllowAllConstructorInitializersOnNextLine: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: Never
AllowShortCaseLabelsOnASingleLine: true
#AllowShortEnumsOnASingleLine: true
AllowShortFunctionsOnASingleLine: Inline
AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: Empty
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: false
# 在template声明时是否换行
# template <typename T>
# T foo() {
# }
# template <typename T>
# T foo(int aaaaaaaaaaaaaaaaaaaaa,
# int bbbbbbbbbbbbbbbbbbbbb) {
# }
AlwaysBreakTemplateDeclarations: Yes
BinPackArguments: false
BinPackParameters: false
#BitFieldColonSpacing: Both
BreakBeforeBraces: "Allman"
# true:
# veryVeryVeryVeryVeryVeryVeryVeryVeryVeryVeryLongDescription
# ? firstValue
# : SecondValueVeryVeryVeryVeryLong;
#
# false:
# veryVeryVeryVeryVeryVeryVeryVeryVeryVeryVeryLongDescription ?
# firstValue :
# SecondValueVeryVeryVeryVeryLong;
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: BeforeComma
BreakStringLiterals: false
ColumnLimit: 0 # 0: no limit
CompactNamespaces: false
ConstructorInitializerIndentWidth: 4
Cpp11BracedListStyle: true
FixNamespaceComments: true # 加上丢失的namespace注释
IncludeBlocks: Preserve
#IndentCaseBlocks: false
IndentCaseLabels: true
IndentGotoLabels: false
# #if FOO
# #if BAR
# #include <foo>
# #endif
# #endif
IndentPPDirectives: BeforeHash
IndentWidth: 4
AccessModifierOffset: -4
KeepEmptyLinesAtTheStartOfBlocks: false
MaxEmptyLinesToKeep: 3
NamespaceIndentation: None
# Left:
# int* a;
# Right:
# int *a;
# Middle:
# int * a;
PointerAlignment: Right
# QualifierOrder
ReferenceAlignment: Right
# 按照列数限制, 将注释进行换行
ReflowComments: false
SortIncludes: CaseSensitive
SortUsingDeclarations: true
SpaceAfterCStyleCast: false
SpaceAfterLogicalNot: false
SpaceAfterTemplateKeyword: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeCpp11BracedList: false
SpaceBeforeCaseColon: false
SpaceBeforeCtorInitializerColon: false
SpaceBeforeInheritanceColon: false
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: true
SpaceBeforeSquareBrackets: false
SpaceInEmptyBlock: false
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 4
SpacesInAngles: Leave
SpacesInCStyleCastParentheses: false
SpacesInConditionalStatement: false
SpacesInContainerLiterals: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
# Constructor
ConstructorInitializerAllOnOneLineOrOnePerLine: true
Standard: c++20
TabWidth: 4
UseTab: Never
BasedOnStyle: LLVM

View File

@@ -4,14 +4,6 @@ project(leetcode-cpp)
set(CMAKE_CXX_STANDARD 20)
include_directories(include)
find_package(CURL REQUIRED)
find_package(nlohmann_json REQUIRED)
find_package(GTest REQUIRED)
add_subdirectory(src)
add_executable(leetcode-fetcher main.cpp
src/fetcher.cpp)
target_link_libraries(leetcode-fetcher PRIVATE CURL::libcurl)
target_link_libraries(leetcode-fetcher PRIVATE nlohmann_json::nlohmann_json)

89
build.ps1 Executable file
View File

@@ -0,0 +1,89 @@
#!pwsh
[CmdletBinding()]
param(
[Parameter(Mandatory = $true, Position = 0)]
[string]$problemId
)
begin
{
$ErrorActionPreference = 'Stop'
}
end
{
Write-Host "Try to fetch problem $($problemId)"
$graphQuery = @'
query questionData($titleSlug: String!) {
question(titleSlug: $titleSlug) {
content
stats
codeDefinition
sampleTestCase
metaData
}
}
'@
$response = Invoke-RestMethod -Uri "https://leetcode.cn/api/problems/algorithms/" -Method Get
$problem = @($response.stat_status_pairs | Where-Object { $_.stat.frontend_question_id -eq $problemId })
if ($problem.Count -ne 1)
{
Write-Error "Failed to find target problem $($problemId)"
}
$problem = $problem[0]
$problemNumber = "{0:D4}" -f ($problem.stat.frontend_question_id -as [int])
$filename = "$problemNumber-$($problem.stat.question__title_slug).cpp"
$existedFile = @(Get-ChildItem ./src/problems | Where-Object { $_.Name -eq $filename})
if ($existedFile.Count -gt 0)
{
Write-Error "Problem $($problem.stat.question__title_slug) has been fetched, see src/problems/$filename."
}
$variables = @{
titleSlug = $problem.stat.question__title_slug
} | ConvertTo-Json
$query = @{
operationName = "questionData"
variables = $variables
query = $graphQuery
}
Write-Host "Try to fetch details of $($problem.stat.question__title_slug)"
$response = Invoke-RestMethod -Uri "https://leetcode.cn/graphql" -Body $query -Method Post
$codeDefinition = @($response.data.question.codeDefinition | ConvertFrom-Json | Where-Object { $_.value -eq "cpp" })
if ($codeDefinition.Count -ne 1)
{
Write-Error "Failed to find C++ code definition"
}
$testcaseName = "P$($problem.stat.frontend_question_id)"
$outputCode = @"
/**
* [$($problem.stat.frontend_question_id)] $($problem.stat.question__title_slug)
*/
#include <gtest/gtest.h>
using namespace std;
// submission codes start here
$($codeDefinition[0].defaultCode)
// submission codes end
TEST($testcaseName, Test1)
{
}
"@
Set-Content -Path "./src/problems/$($filename)" -Value $outputCode
Write-Host "Saved as $filename."
}

View File

@@ -26,6 +26,3 @@ commit: test
git add -A
git commit -m "$message"
git push
pull id: build
./cmake-build-debug-clang/leetcode-fetcher {{ id }}

View File

@@ -1,15 +0,0 @@
#include "fetcher.h"
int main(int argc, char **argv)
{
Fetcher fetcher;
if (argc != 2)
{
throw std::invalid_argument("The fetcher expect the program id.");
}
fetcher.fetchProblem(argv[1]);
return 0;
}

View File

@@ -1,276 +0,0 @@
//
// Created by ricardo on 12/06/25.
//
#include <fstream>
#include "fetcher.h"
std::vector<LeetCodeProblem> Fetcher::getProblems() const
{
std::string body;
curl_easy_setopt(client.get(), CURLOPT_URL, kProblemsUrl.c_str());
curl_easy_setopt(client.get(), CURLOPT_WRITEFUNCTION, Fetcher::httpWriteCallback);
curl_easy_setopt(client.get(), CURLOPT_WRITEDATA, &body);
const CURLcode code = curl_easy_perform(client.get());
if (code != CURLE_OK)
{
std::cout << "Failed to fetch problems." << std::endl;
return {};
}
std::vector<LeetCodeProblem> problems;
try
{
nlohmann::json jsonResponse = nlohmann::json::parse(body);
for (const auto &item : jsonResponse["stat_status_pairs"])
{
bool paidOnly = item["paid_only"];
const auto &stat_item = item["stat"];
std::string frontendQuestionId = stat_item["frontend_question_id"];
int questionId = stat_item["question_id"];
std::string questionTitle = stat_item["question__title"];
std::string questionTitleSlug = stat_item["question__title_slug"];
problems.emplace_back(paidOnly, frontendQuestionId, questionId, questionTitle, questionTitleSlug);
}
}
catch (const nlohmann::json::parse_error &e)
{
std::cout << "JSON parse error: " << e.what() << std::endl;
std::cout << "Raw response: " << body << std::endl;
}
return problems;
}
std::unique_ptr<ProblemContent> Fetcher::fetchProblemContent(const LeetCodeProblem &problem) const
{
curl_easy_setopt(client.get(), CURLOPT_URL, kGraphQlUrl.c_str());
curl_easy_setopt(client.get(), CURLOPT_POST, 1L);
curl_slist *headers = nullptr;
headers = curl_slist_append(headers, "Content-Type: application/json");
curl_easy_setopt(client.get(), CURLOPT_HTTPHEADER, headers);
std::string requestBody = formatQueryJson(problem.questionTitleSlug).dump();
curl_easy_setopt(client.get(), CURLOPT_POSTFIELDS, requestBody.c_str());
curl_easy_setopt(client.get(), CURLOPT_POSTFIELDSIZE, requestBody.size());
std::string responseBody;
curl_easy_setopt(client.get(), CURLOPT_WRITEFUNCTION, Fetcher::httpWriteCallback);
curl_easy_setopt(client.get(), CURLOPT_WRITEDATA, &responseBody);
CURLcode code = curl_easy_perform(client.get());
if (code != CURLE_OK)
{
throw std::runtime_error("Failed to fetch problem.");
}
const nlohmann::json jsonResponse = nlohmann::json::parse(responseBody);
return extractContentFromJson(jsonResponse, problem);
}
static std::string replaceString(const std::string &original,
const std::string_view old_sub,
const std::string_view new_sub)
{
if (old_sub.empty())
return original; // 防止空字符串导致死循环
std::string result;
size_t pos = 0;
size_t old_len = old_sub.size();
size_t new_len = new_sub.size();
// 预计算总长度以减少内存分配
size_t total_length = original.size();
size_t count = 0;
size_t start = 0;
while ((pos = original.find(old_sub, start)) != std::string::npos)
{
total_length += (new_len - old_len); // 更新总长度
count++;
start = pos + old_len; // 下一次查找起始位置
}
result.reserve(total_length); // 预分配内存
start = 0;
while ((pos = original.find(old_sub, start)) != std::string::npos)
{
// 添加旧子串之前的部分
result.append(original.data() + start, pos - start);
// 添加新子串
result.append(new_sub.data(), new_sub.size());
// 更新起始位置
start = pos + old_len;
}
// 添加剩余部分
result.append(original.data() + start, original.size() - start);
return result;
}
std::string ProblemContent::formatTemplate(const std::string &templateContent) const
{
const auto it = std::ranges::find_if(codeDefinitions,
[](const CodeDefinition &definition)
{
return definition.value == "cpp";
});
if (it == codeDefinitions.end())
{
throw std::runtime_error("The target problem has no C++ template.");
}
std::string result = replaceString(templateContent, "__PROBLEM_ID__", std::to_string(questionId));
result = replaceString(result, "__PROBLEM_TITLE__", title);
result = replaceString(result, "__PROBLEM_DEFAULT_CODE__", it->defaultCode);
std::string testCaseName = "P" + std::to_string(questionId);
result = replaceString(result, "__TEST_CASE_NAME__", testCaseName);
return result;
}
std::string ProblemContent::formatFilename() const
{
std::ostringstream stream("p");
stream << questionId;
stream << "-" << titleSlug << ".cpp";
return stream.str();
}
void Fetcher::fetchProblem(const std::string &idString) const
{
std::vector<LeetCodeProblem> problems = getProblems();
const auto it = std::ranges::find_if(problems,
[idString](const LeetCodeProblem &problem)
{
return problem.frontendQuestionId == idString;
});
if (it == problems.end())
{
throw std::runtime_error("The target problem does not exist.");
}
std::unique_ptr<ProblemContent> problemContent = fetchProblemContent(*it);
const std::string templateFile = readTemplateFile();
std::string problemFileContent = problemContent->formatTemplate(templateFile);
std::string problemFilename = problemContent->formatFilename();
auto problemFile = std::ofstream("src/problems/" + problemFilename);
if (!problemFile.is_open())
{
throw std::runtime_error("Failed to open problem file.");
}
problemFile << problemFileContent;
problemFile.close();
}
nlohmann::json Fetcher::formatQueryJson(const std::string &title)
{
nlohmann::json result;
result["operationName"] = "questionData";
result["query"] = R"(query questionData($titleSlug: String!) {
question(titleSlug: $titleSlug) {
content
stats
codeDefinition
sampleTestCase
}
})";
nlohmann::json variables;
variables["titleSlug"] = title;
result["variables"] = variables;
return result;
}
std::unique_ptr<ProblemContent> Fetcher::extractContentFromJson(const nlohmann::json &json, const LeetCodeProblem &problem)
{
const auto questionJson = json["data"]["question"];
std::string content = questionJson["content"];
std::vector<CodeDefinition> codeDefinitions;
const std::string codeDefinitionString = questionJson["codeDefinition"];
for (nlohmann::json codeDefinitionJson = nlohmann::json::parse(codeDefinitionString);
const auto &codeDefinitionItem : codeDefinitionJson)
{
std::string value = codeDefinitionItem["value"];
std::string text = codeDefinitionItem["text"];
std::string defaultCode = codeDefinitionItem["defaultCode"];
codeDefinitions.emplace_back(value, text, defaultCode);
}
return std::make_unique<ProblemContent>(
problem.questionTitle,
problem.questionTitleSlug,
content,
codeDefinitions,
std::stoi(problem.frontendQuestionId));
}
std::string Fetcher::readTemplateFile()
{
std::ifstream file;
file.open("src/template.cpp");
std::stringstream buffer;
if (file.is_open())
{
buffer << file.rdbuf();
}
else
{
throw std::runtime_error("Failed to read template file.");
}
file.close();
return buffer.str();
}
bool Fetcher::validateExistedProblem(const ProblemContent &problem)
{
std::filesystem::path problemDirectory = "src/problems";
std::string defaultFilename = problem.formatFilename();
if (!std::filesystem::exists(problemDirectory))
{
throw std::runtime_error("The problem directory is not exitsed.");
}
return std::ranges::any_of(std::filesystem::directory_iterator(problemDirectory),
[&](const std::filesystem::directory_entry &entry)
{
if (!entry.is_regular_file())
{
return false;
}
return entry.path().filename() == defaultFilename;
});
}

View File

@@ -0,0 +1,34 @@
/**
* [1] two-sum
*/
#include <gtest/gtest.h>
#include <unordered_map>
#include <vector>
using namespace std;
// submission codes start here
class Solution {
public:
vector<int> twoSum(vector<int> &nums, int target) {
unordered_map<int, int> map;
int index = 0;
for (const auto &i : nums) {
const auto &it = map.find(i);
if (it != map.end()) {
return {index, it->second};
}
map.insert({target - i, index});
++index;
}
return {};
}
};
// submission codes end
TEST(P1, Test1) {}

View File

@@ -1,40 +0,0 @@
/**
* [1] Two Sum
*/
#include <bits/stdc++.h>
#include <gtest/gtest.h>
using namespace std;
// submission codes start here
class Solution
{
public:
vector<int> twoSum(vector<int> &nums, int target)
{
unordered_map<int, int> map;
for (int i = 0; i < nums.size(); ++i)
{
if (const auto &it = map.find(target - nums[i]); it != map.end())
{
return {it->second, i};
}
map.insert({nums[i], i});
}
return {};
}
};
// submission codes endo
TEST(P1, Test1)
{
Solution s;
vector nums = {2, 7, 11, 15};
vector result = {0, 1};
ASSERT_EQ(s.twoSum(nums, 9), result);
}

View File

@@ -5,22 +5,17 @@
#include <gtest/gtest.h>
using namespace std;
// submission codes start here
class Solution
{
class Solution {
public:
int maxDiff(int num)
{
int maxDiff(int num) {
string numString = to_string(num);
// Select the first, not 9 number.
int targetPos = 0;
while (targetPos < numString.size() - 1)
{
if (numString[targetPos] != '9')
{
while (targetPos < numString.size() - 1) {
if (numString[targetPos] != '9') {
break;
}
@@ -33,21 +28,17 @@ public:
char minimumChar = '1';
// If the first number is not 1, select the first number.
// If the first number is 1, select next, not zero number.
if (numString[targetPos] == '1')
{
if (numString[targetPos] == '1') {
targetPos = 1;
minimumChar = '0';
while (targetPos < numString.size())
{
if (numString[targetPos] == '1')
{
while (targetPos < numString.size()) {
if (numString[targetPos] == '1') {
// Can not replace 1 when the first number is 1.
targetPos += 1;
continue;
}
if (numString[targetPos] != '0')
{
if (numString[targetPos] != '0') {
break;
}
@@ -56,25 +47,19 @@ public:
}
int minNumber = num;
if (targetPos != numString.size())
{
if (targetPos != numString.size()) {
minNumber = replaceDigit(numString, numString[targetPos], minimumChar);
}
return maxNumber - minNumber;
}
static auto replaceDigit(const string &num, char source, char target) -> int
{
static auto replaceDigit(const string &num, char source, char target) -> int {
int result = 0;
for (const char i : num)
{
if (i == source)
{
for (const char i : num) {
if (i == source) {
result = result * 10 + target - '0';
}
else
{
} else {
result = result * 10 + i - '0';
}
}
@@ -85,8 +70,7 @@ public:
// submission codes end
TEST(P1432, Test1)
{
TEST(P1432, Test1) {
ASSERT_EQ(888, Solution().maxDiff(555));
ASSERT_EQ(8, Solution().maxDiff(9));
ASSERT_EQ(888, Solution().maxDiff(111));

View File

@@ -5,27 +5,22 @@
#include <gtest/gtest.h>
using namespace std;
// submission codes start here
class Solution
{
class Solution {
public:
int maximumDifference(vector<int> &nums)
{
int maximumDifference(vector<int> &nums) {
vector<int> heap;
heap.reserve(nums.size());
ranges::make_heap(heap, greater());
int result = -1;
for (int i = 1; i < nums.size(); ++i)
{
for (int i = 1; i < nums.size(); ++i) {
heap.push_back(nums[i - 1]);
ranges::push_heap(heap, greater());
if (heap[0] < nums[i])
{
if (heap[0] < nums[i]) {
result = max(result, nums[i] - heap[0]);
}
}
@@ -36,8 +31,7 @@ public:
// submission codes end
TEST(P2016, Test1)
{
TEST(P2016, Test1) {
vector nums1 = {7, 1, 5, 4};
ASSERT_EQ(4, Solution().maximumDifference(nums1));

View File

@@ -5,14 +5,11 @@
#include <gtest/gtest.h>
using namespace std;
// submission codes start here
class Solution
{
class Solution {
public:
int partitionArray(vector<int> &nums, int k)
{
int partitionArray(vector<int> &nums, int k) {
ranges::sort(nums);
// At least to split into one segment.
@@ -20,15 +17,11 @@ public:
int pos = 1;
int minValue = nums[0];
while (pos < nums.size())
{
if (nums[pos] > minValue + k)
{
while (pos < nums.size()) {
if (nums[pos] > minValue + k) {
result += 1;
minValue = nums[pos];
}
else
{
} else {
pos += 1;
}
}
@@ -39,8 +32,7 @@ public:
// submission codes end
TEST(P2294, Test1)
{
TEST(P2294, Test1) {
vector nums1 = {3, 6, 1, 2, 5};
ASSERT_EQ(2, Solution().partitionArray(nums1, 2));
vector nums2 = {1, 2, 3};

View File

@@ -5,22 +5,17 @@
#include <gtest/gtest.h>
using namespace std;
// submission codes start here
class Solution
{
class Solution {
public:
int minMaxDifference(int num)
{
int minMaxDifference(int num) {
const string numString = to_string(num);
// Select the first not 9 number when converting to maximum value.
int targetPos = 0;
while (targetPos < numString.size() - 1)
{
if (numString[targetPos] != '9')
{
while (targetPos < numString.size() - 1) {
if (numString[targetPos] != '9') {
break;
}
@@ -30,14 +25,10 @@ public:
int maxNum = 0;
char targetChar = numString[targetPos];
for (const auto c : numString)
{
if (c == targetChar)
{
for (const auto c : numString) {
if (c == targetChar) {
maxNum = maxNum * 10 + 9;
}
else
{
} else {
maxNum = maxNum * 10 + c - '0';
}
}
@@ -46,14 +37,10 @@ public:
targetChar = numString[0];
int minNum = 0;
for (const auto c : numString)
{
if (c == targetChar)
{
for (const auto c : numString) {
if (c == targetChar) {
minNum = minNum * 10;
}
else
{
} else {
minNum = minNum * 10 + c - '0';
}
}
@@ -64,8 +51,7 @@ public:
// submission codes end
TEST(P2566, Test1)
{
TEST(P2566, Test1) {
ASSERT_EQ(99009, Solution().minMaxDifference(11891));
ASSERT_EQ(99, Solution().minMaxDifference(90));
}

View File

@@ -5,24 +5,18 @@
#include <gtest/gtest.h>
using namespace std;
// submission codes start here
class Solution
{
class Solution {
public:
int minimizeMax(vector<int> &nums, int p)
{
int minimizeMax(vector<int> &nums, int p) {
ranges::sort(nums);
auto check = [&](int value) -> bool
{
auto check = [&](int value) -> bool {
int count = 0;
for (int i = 0; i < nums.size() - 1; ++i)
{
if (nums[i + 1] - nums[i] <= value)
{
for (int i = 0; i < nums.size() - 1; ++i) {
if (nums[i + 1] - nums[i] <= value) {
count += 1;
i += 1;
}
@@ -33,15 +27,11 @@ public:
int left = 0, right = nums.back() - nums[0];
while (left < right)
{
while (left < right) {
int middle = (left + right) >> 1;
if (check(middle))
{
if (check(middle)) {
right = middle;
}
else
{
} else {
left = middle + 1;
}
}
@@ -52,8 +42,7 @@ public:
// submission codes end.
TEST(P2616, Test1)
{
TEST(P2616, Test1) {
vector nums = {10, 1, 2, 7, 1, 3};
Solution s;
ASSERT_EQ(s.minimizeMax(nums, 2), 1);

View File

@@ -5,24 +5,20 @@
#include <gtest/gtest.h>
using namespace std;
// submission codes start here
class Solution
{
class Solution {
public:
vector<vector<int>> divideArray(vector<int> &nums, int k)
{
vector<vector<int>> divideArray(vector<int> &nums, int k) {
ranges::sort(nums);
vector<vector<int>> result;
bool flag = true;
result.reserve(nums.size() / 3);
for (int i = 0; i < nums.size(); i += 3)
{
if (nums[i + 2] - nums[i + 1] > k || nums[i + 1] - nums[i] > k || nums[i + 2] - nums[i] > k)
{
for (int i = 0; i < nums.size(); i += 3) {
if (nums[i + 2] - nums[i + 1] > k || nums[i + 1] - nums[i] > k ||
nums[i + 2] - nums[i] > k) {
flag = false;
break;
}
@@ -31,8 +27,7 @@ public:
result.push_back(move(array));
}
if (flag)
{
if (flag) {
return result;
}
@@ -42,8 +37,7 @@ public:
// submission codes end
TEST(P2966, Test1)
{
TEST(P2966, Test1) {
vector nums = {1, 3, 4, 8, 7, 9, 3, 5, 1};
vector<vector<int>> result = {{1, 1, 3}, {3, 4, 5}, {7, 8, 9}};

View File

@@ -5,7 +5,6 @@
#include <gtest/gtest.h>
using namespace std;
// submission codes start here
constexpr long long MOD = 1e9 + 7;
@@ -14,16 +13,12 @@ constexpr long long UPPER_BOUND = 1e5;
static long long fact[UPPER_BOUND];
static long long inverseFact[UPPER_BOUND];
class Solution
{
static long long quickPower(long long x, int n)
{
class Solution {
static long long quickPower(long long x, int n) {
long long result = 1;
while (n > 0)
{
if ((n & 1) == 1)
{
while (n > 0) {
if ((n & 1) == 1) {
result = result * x % MOD;
}
@@ -34,45 +29,39 @@ class Solution
return result;
}
static long long combine(int n, int m)
{
static long long combine(int n, int m) {
return fact[n] * inverseFact[m] % MOD * inverseFact[n - m] % MOD;
}
static void init()
{
if (fact[0] != 0)
{
static void init() {
if (fact[0] != 0) {
return;
}
fact[0] = 1;
for (int i = 1; i < UPPER_BOUND; ++i)
{
for (int i = 1; i < UPPER_BOUND; ++i) {
fact[i] = fact[i - 1] * i % MOD;
}
// Modular Multiplicative Inverse is calculated by the quick power.
inverseFact[UPPER_BOUND - 1] = quickPower(fact[UPPER_BOUND - 1], MOD - 2);
for (int i = UPPER_BOUND - 1; i > 0; --i)
{
for (int i = UPPER_BOUND - 1; i > 0; --i) {
inverseFact[i - 1] = inverseFact[i] * i % MOD;
}
}
public:
int countGoodArrays(int n, int m, int k)
{
int countGoodArrays(int n, int m, int k) {
init();
const long long result = combine(n - 1, k) * m % MOD * quickPower(m - 1, n - k - 1) % MOD;
const long long result =
combine(n - 1, k) * m % MOD * quickPower(m - 1, n - k - 1) % MOD;
return result;
}
};
// submission codes end
TEST(P3405, Test1)
{
TEST(P3405, Test1) {
Solution s;
ASSERT_EQ(4, s.countGoodArrays(3, 2, 1));
ASSERT_EQ(6, s.countGoodArrays(4, 2, 2));

View File

@@ -1,17 +0,0 @@
/**
* [__PROBLEM_ID__] __PROBLEM_TITLE__
*/
#include <bits/stdc++.h>
#include <gtest/gtest.h>
using namespace std;
// submission codes start here
__PROBLEM_DEFAULT_CODE__
// submission codes end
TEST(__TEST_CASE_NAME__, Test1)
{
}