From 929b90fbef8773c66c29208c1a596af9fee6139d Mon Sep 17 00:00:00 2001 From: jayson Date: Thu, 28 Sep 2023 12:58:40 +0800 Subject: [PATCH] speed up polynomial computation --- polynomials.go | 110 +++++++++----------------- polynomials_test.go | 189 ++++++++++++++++++++++++-------------------- 2 files changed, 142 insertions(+), 157 deletions(-) diff --git a/polynomials.go b/polynomials.go index f03730b..96ce3b2 100644 --- a/polynomials.go +++ b/polynomials.go @@ -20,38 +20,29 @@ func (x Pol) Add(y Pol) Pol { } // Mul returns x*y. When an overflow occurs, Mul panics. -func (x Pol) Mul(y Pol) Pol { - switch { - case x == 0 || y == 0: - return 0 - case x == 1: - return y - case y == 1: - return x - case y == 2: - return x.mul2() +func (x Pol) Mul(y Pol) (p Pol) { + if x == 0 || y == 0 { + return } - var res Pol - for i := 0; i <= y.Deg(); i++ { - if (y & (1 << uint(i))) > 0 { - res = res.Add(x << uint(i)) + if y&(y-1) == 0 { + if x.Deg()+y.Deg() >= 64 { + panic("multiplication would overflow uint64") } + return x << uint(y.Deg()) } - if res.Div(y) != x { - panic("multiplication would overflow uint64") + for i := 0; i <= y.Deg(); i++ { + if (y & (1 << uint(i))) != 0 { + p = p.Add(x << uint(i)) + } } - return res -} - -// 2*x. -func (x Pol) mul2() Pol { - if x&(1<<63) != 0 { + if p.Div(y) != x { panic("multiplication would overflow uint64") } - return x << 1 + + return p } // Deg returns the degree of the polynomial x. If x is zero, -1 is returned. @@ -90,31 +81,31 @@ func (x Pol) Expand() string { // DivMod returns x / d = q, and remainder r, // see https://en.wikipedia.org/wiki/Division_algorithm -func (x Pol) DivMod(d Pol) (Pol, Pol) { +func (x Pol) DivMod(d Pol) (q Pol, r Pol) { if x == 0 { - return 0, 0 + return q, r } if d == 0 { panic("division by zero") } + r = x D := d.Deg() diff := x.Deg() - D if diff < 0 { - return 0, x + return q, r } - var q Pol for diff >= 0 { m := d << uint(diff) - q |= (1 << uint(diff)) - x = x.Add(m) + q |= 1 << uint(diff) + r = r.Add(m) - diff = x.Deg() - D + diff = r.Deg() - D } - return q, x + return q, r } // Div returns the integer division result x / d. @@ -182,14 +173,6 @@ func (x Pol) GCD(f Pol) Pol { return x } - if x == 0 { - return f - } - - if x.Deg() < f.Deg() { - x, f = f, x - } - return f.GCD(x.Mod(f)) } @@ -200,7 +183,8 @@ func (x Pol) GCD(f Pol) Pol { // Finite Fields". func (x Pol) Irreducible() bool { for i := 1; i <= x.Deg()/2; i++ { - if x.GCD(qp(uint(i), x)) != 1 { + // computes the polynomial (x^(2^p)-x) mod g + if x.GCD(Pol(4).PowMod(1< 0 { - a := x - for j := 0; j < i; j++ { - a = a.Mul(2).Mod(g) - } - res = res.Add(a).Mod(g) +func (x Pol) MulMod(f, g Pol) (r Pol) { + for b := x; b != 0 && f != 0; f >>= 1 { + if f&1 != 0 { + r = r.Add(b).Mod(g) } + b = (b << 1).Mod(g) // f'(x) = f(x) * x } - - return res + return } -// qp computes the polynomial (x^(2^p)-x) mod g. This is needed for the -// reducibility test. -func qp(p uint, g Pol) Pol { - num := (1 << p) - i := 1 - - // start with x - res := Pol(2) - - for i < num { - // repeatedly square res - res = res.MulMod(res, g) - i *= 2 +// PowMod computes x^n mod g. This is needed for the reducibility test. +func (x Pol) PowMod(n uint, g Pol) (r Pol) { + var b Pol + for b, r = x, 1; n != 0; n >>= 1 { + if n&1 != 0 { + r = r.MulMod(b, g) + } + b = b.MulMod(b, g) } - - // add x - return res.Add(2).Mod(g) + return } // MarshalJSON returns the JSON representation of the Pol. diff --git a/polynomials_test.go b/polynomials_test.go index aca31c4..fd3f1a8 100644 --- a/polynomials_test.go +++ b/polynomials_test.go @@ -39,6 +39,9 @@ var polMulTests = []struct { x, y Pol res Pol }{ + {0, 0, 0}, + {0, 1, 0}, + {1, 0, 0}, {1, 2, 2}, { parseBin("1101"), @@ -92,6 +95,14 @@ func TestPolMul(t *testing.T) { } } +func BenchmarkPolMul(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, tt := range polMulTests { + tt.x.Mul(tt.y) + } + } +} + func TestPolMulOverflow(t *testing.T) { defer func() { // try to recover overflow error @@ -111,41 +122,20 @@ func TestPolMulOverflow(t *testing.T) { t.Fatal("overflow test did not panic") } -var polDivTests = []struct { - x, y Pol - res Pol -}{ - {10, 50, 0}, - {0, 1, 0}, - { - parseBin("101101000"), // 0x168 - parseBin("1010"), // 0xa - parseBin("100100"), // 0x24 - }, - {2, 2, 1}, - { - 0x8000000000000000, - 0x8000000000000000, - 1, - }, - { - parseBin("1100"), - parseBin("100"), - parseBin("11"), - }, - { - parseBin("1100001111"), - parseBin("10011"), - parseBin("110101"), - }, -} - func TestPolDiv(t *testing.T) { - for i, test := range polDivTests { + for i, test := range polDivModTests { m := test.x.Div(test.y) - if test.res != m { - t.Errorf("TestPolDiv failed for test %d: %v * %v: want %v, got %v", - i, test.x, test.y, test.res, m) + if test.q != m { + t.Errorf("TestPolDiv failed for test %d: %v / %v: want %v, got %v", + i, test.x, test.y, test.q, m) + } + } +} + +func BenchmarkPolDiv(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, tt := range polDivModTests { + tt.x.Div(tt.y) } } } @@ -169,85 +159,93 @@ func TestPolDeg(t *testing.T) { } } -var polModTests = []struct { +func BenchmarkPolDeg(t *testing.B) { + f := Pol(0x3af4b284899) + d := f.Deg() + if d != 41 { + t.Fatalf("BenchmalPolDeg: Wrong degree %d returned, expected %d", + d, 41) + } + + var sum int + for i := 0; i < t.N; i++ { + sum += f.Deg() + } + // Make sure Deg call isn't optimized away. + t.Log("sum of Deg:", sum) +} + +func TestPolMod(t *testing.T) { + for i, test := range polDivModTests { + res := test.x.Mod(test.y) + if test.r != res { + t.Errorf("test %d failed: want %v, got %v", i, test.r, res) + } + } +} + +func BenchmarkPolMod(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, tt := range polDivModTests { + tt.x.Mod(tt.y) + } + } +} + +var polDivModTests = []struct { x, y Pol - res Pol + q, r Pol }{ - {10, 50, 10}, - {0, 1, 0}, + {10, 50, 0, 10}, + {0, 1, 0, 0}, { - parseBin("101101001"), - parseBin("1010"), - parseBin("1"), + parseBin("101101000"), // 0x168 + parseBin("1010"), // 0xa + parseBin("100100"), // 0x24 + parseBin("0"), // 0 }, - {2, 2, 0}, + {2, 2, 1, 0}, { 0x8000000000000000, 0x8000000000000000, + 1, 0, }, { parseBin("1100"), parseBin("100"), + parseBin("11"), parseBin("0"), }, { parseBin("1100001111"), parseBin("10011"), + parseBin("110101"), parseBin("0"), }, + { + 0x2482734cacca49, + 0x3af4b284899, + 0x1972, + 0x4229e6268b, + }, } -func TestPolModt(t *testing.T) { - for i, test := range polModTests { - res := test.x.Mod(test.y) - if test.res != res { - t.Errorf("test %d failed: want %v, got %v", i, test.res, res) +func TestPolDivMod(t *testing.T) { + for i, test := range polDivModTests { + q, r := test.x.DivMod(test.y) + if test.q != q || test.r != r { + t.Errorf("test %d failed: want (%v, %v), got (%v, %v)", i, test.q, test.r, q, r) } } } -func BenchmarkPolDivMod(t *testing.B) { - f := Pol(0x2482734cacca49) - g := Pol(0x3af4b284899) - - for i := 0; i < t.N; i++ { - g.DivMod(f) - } -} - -func BenchmarkPolDiv(t *testing.B) { - f := Pol(0x2482734cacca49) - g := Pol(0x3af4b284899) - - for i := 0; i < t.N; i++ { - g.Div(f) - } -} - -func BenchmarkPolMod(t *testing.B) { - f := Pol(0x2482734cacca49) - g := Pol(0x3af4b284899) - - for i := 0; i < t.N; i++ { - g.Mod(f) - } -} - -func BenchmarkPolDeg(t *testing.B) { - f := Pol(0x3af4b284899) - d := f.Deg() - if d != 41 { - t.Fatalf("BenchmalPolDeg: Wrong degree %d returned, expected %d", - d, 41) - } - - var sum int - for i := 0; i < t.N; i++ { - sum += f.Deg() +func BenchmarkPolDivMod(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, tt := range polDivModTests { + tt.x.DivMod(tt.y) + } } - // Make sure Deg call isn't optimized away. - t.Log("sum of Deg:", sum) } func TestRandomPolynomial(t *testing.T) { @@ -396,12 +394,23 @@ func TestPolGCD(t *testing.T) { } } +func BenchmarkPolGCD(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, tt := range polGCDTests { + tt.f1.GCD(tt.f2) + } + } +} + var polMulModTests = []struct { f1 Pol f2 Pol g Pol mod Pol }{ + {0, 0, 0x11111, 0}, + {0, 1, 0x11111, 0}, + {1, 0, 0x11111, 0}, { 0x1230, 0x230, @@ -425,3 +434,11 @@ func TestPolMulMod(t *testing.T) { } } } + +func BenchmarkPolMulMod(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, tt := range polMulModTests { + tt.f1.MulMod(tt.f2, tt.g) + } + } +}