use std::rc::Rc;
use std::time::{SystemTime, UNIX_EPOCH};
use super::{fast_power, PCG32};
#[derive(Debug)]
struct CustomFiniteFiled {
modulus: u64,
i_square: u64,
}
impl CustomFiniteFiled {
pub fn new(modulus: u64, i_square: u64) -> Self {
Self { modulus, i_square }
}
}
#[derive(Clone, Debug)]
struct CustomComplexNumber {
real: u64,
imag: u64,
f: Rc<CustomFiniteFiled>,
}
impl CustomComplexNumber {
pub fn new(real: u64, imag: u64, f: Rc<CustomFiniteFiled>) -> Self {
Self { real, imag, f }
}
pub fn mult_other(&mut self, rhs: &Self) {
let tmp = (self.imag * rhs.real + self.real * rhs.imag) % self.f.modulus;
self.real = (self.real * rhs.real
+ ((self.imag * rhs.imag) % self.f.modulus) * self.f.i_square)
% self.f.modulus;
self.imag = tmp;
}
pub fn mult_self(&mut self) {
let tmp = (self.imag * self.real + self.real * self.imag) % self.f.modulus;
self.real = (self.real * self.real
+ ((self.imag * self.imag) % self.f.modulus) * self.f.i_square)
% self.f.modulus;
self.imag = tmp;
}
pub fn fast_power(mut base: Self, mut power: u64) -> Self {
let mut result = CustomComplexNumber::new(1, 0, base.f.clone());
while power != 0 {
if (power & 1) != 0 {
result.mult_other(&base);
}
base.mult_self();
power >>= 1;
}
result
}
}
fn is_residue(x: u64, modulus: u64) -> bool {
let power = (modulus - 1) >> 1;
x != 0 && fast_power(x as usize, power as usize, modulus as usize) == 1
}
pub fn cipolla(a: u32, p: u32, seed: Option<u64>) -> Option<(u32, u32)> {
let a = a as u64;
let p = p as u64;
if a == 0 {
return Some((0, 0));
}
if !is_residue(a, p) {
return None;
}
let seed = match seed {
Some(seed) => seed,
None => SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs(),
};
let mut rng = PCG32::new_default(seed);
let r = loop {
let r = rng.get_u64() % p;
if r == 0 || !is_residue((p + r * r - a) % p, p) {
break r;
}
};
let filed = Rc::new(CustomFiniteFiled::new(p, (p + r * r - a) % p));
let comp = CustomComplexNumber::new(r, 1, filed);
let power = (p + 1) >> 1;
let x0 = CustomComplexNumber::fast_power(comp, power).real as u32;
let x1 = p as u32 - x0 as u32;
if x0 < x1 {
Some((x0, x1))
} else {
Some((x1, x0))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn small_numbers() {
assert_eq!(cipolla(1, 43, None), Some((1, 42)));
assert_eq!(cipolla(2, 23, None), Some((5, 18)));
assert_eq!(cipolla(17, 83, Some(42)), Some((10, 73)));
}
#[test]
fn random_numbers() {
assert_eq!(cipolla(392203, 852167, None), Some((413252, 438915)));
assert_eq!(
cipolla(379606557, 425172197, None),
Some((143417827, 281754370))
);
assert_eq!(
cipolla(585251669, 892950901, None),
Some((192354555, 700596346))
);
assert_eq!(
cipolla(404690348, 430183399, Some(19260817)),
Some((57227138, 372956261))
);
assert_eq!(
cipolla(210205747, 625380647, Some(998244353)),
Some((76810367, 548570280))
);
}
#[test]
fn no_answer() {
assert_eq!(cipolla(650927, 852167, None), None);
}
}