/* bsgs.c -- (C) 2016 Mark Rodenkirch

   Implementation of a baby step giant step algorithm for finding all
   n in the range nmin <= n <= nmax satisfying b^n=d_i (mod p) where b
   and each d_i are relatively prime to p.
*/

#include <assert.h>
#include <inttypes.h>
#include <limits.h>
#include <math.h>
#include <stdlib.h>
#include <stdio.h>
#include "cksieve.h"
#include "arithmetic.h"
#include "hashtable.h"

// There are 4 sequences.  2 for the Carol form, 2 for the Kynea form
// All of them will be sieved concurrently.
// ROOT_COUNT sizes the sequence[], D64[], C32[] and D32[] arrays in this file.
#define ROOT_COUNT	4
// Sentinel stored in C32[] / first_solved by bsgs_small() meaning "nothing
// recorded for this sequence yet".
#define NO_SOLUTION 	UINT32_MAX
// When defined, the giant-step loop in bsgs_big() updates D64[] with
// vec4_mulmod64() instead of one mulmod64() call per sequence.
#define USE_VEC_MULMOD

/* One target sequence for the current prime p.  bsgs_small()/bsgs_big()
   fill in all ROOT_COUNT entries for every prime they process. */
typedef struct {
   uint64_t root;   /* target value for b^n: a square root of 2 (mod p), minus or plus 1 */
   int32_t  c;      /* +1 for the "root - 1" targets, -1 for the "root + 1" targets;
                       passed through to eliminate_term().
                       NOTE(review): presumably distinguishes Kynea from Carol -- confirm. */
} seq_t;

/* m and M are chosen once in init_bsgs() so that m*M covers n_min..n_max. */
static uint32_t m;    /* Number of baby steps */
static uint32_t M;    /* Number of giant steps */

/* Sieve window, set in init_bsgs(): sieve_low = n_min, sieve_range = m*M.
   Hash-table indices j map to exponents n = sieve_low + j. */
static uint32_t sieve_low;
static uint32_t sieve_range;

/* Per-prime targets; overwritten on every call to bsgs_small()/bsgs_big(). */
static seq_t    sequence[ROOT_COUNT];

/* Working state of bsgs_big(): D64[] holds the current giant-step value of
   each sequence, bQ64 holds 1/b (mod p). */
static uint64_t D64[ROOT_COUNT];
static uint64_t bQ64;

/* Per-prime callbacks handed to prime_sieve() by sieve(). */
void bsgs_small(uint64_t p);
void bsgs_big(uint64_t p);
/* Baby-step table builders; both return the order of b (mod p) if it is
   found within m steps, otherwise 0. */
uint32_t baby_steps32(uint32_t b, uint32_t bj0, uint32_t p);
uint32_t baby_steps64(uint64_t b, uint64_t bj0, uint64_t p);
 
/*
  Giant step baby step algorithm, from Wikipedia:

  input: A cyclic group G of order n, having a generator a and an element b.

  Output: A value x satisfying a^x = b (mod n).

  1. m <-- ceiling(sqrt(n))
  2. For all j where 0 <= j < m:
     1. Compute a^j mod n and store the pair (j, a^j) in a table.
  3. Compute a^(-m).
  4. g <-- b.
  5. For i = 0 to (m - 1):
     1. Check if g is the second component (a^j) of any pair in the table.
     2. If so, return im + j.
     3. If not, g <-- ga^(-m) mod n.
*/
/*
  Sieve one prime p using 32-bit baby-step/giant-step values.

  sieve() routes primes p <= MAX(b_term, 257) here.  For each of the
  ROOT_COUNT targets (root-1 and root+1 for both square roots of 2 mod p)
  the function looks for offsets j with b^(sieve_low+j) = target (mod p)
  and reports them through eliminate_term().  Because p is small, the
  order of b (mod p) is also determined when possible so that every
  solution in the window can be generated from the first one.

  Nothing is returned; the side effects are total_primes_tested++, the
  contents of sequence[] and the hash table, and the eliminate_term()
  calls.
*/
void bsgs_small(uint64_t p)
{
   uint64_t  root1, root2;
   uint64_t  b, bm;
   uint32_t  i, j, k;
   uint32_t  order_of_b_mod_p;
   uint32_t  solutions = 0;
   uint32_t  first_solved = NO_SOLUTION;
   uint32_t  C32[ROOT_COUNT];   /* first solution offset found per sequence */
   uint32_t  D32[ROOT_COUNT];   /* current giant-step value per sequence */

   total_primes_tested++;

   // Exit the function if there are no values x such that x^2 = 2 (mod p)
   if (!isQuadraticResidue(2, p))
      return;

   // From here on every exit must go through "done" so that mod64_init()
   // is always paired with mod64_fini().
   mod64_init(p);

   // Find root of x^2 = 2 (mod p)
   root1 = findRoot(p);
   root2 = p - root1;

   // It is possible that findRoot returns where x^2 = -2 (mod p)
   if (mulmod64(root1, root1, p) != 2)
   {
      warning("%"PRIu64" is not a root (mod %"PRIu64")\n", root1, p);
      goto done;
   }
   if (mulmod64(root2, root2, p) != 2)
   {
      warning("%"PRIu64" is not a root (mod %"PRIu64")\n", root2, p);
      goto done;
   }

   sequence[0].root = root1 - 1;
   sequence[0].c    = +1;
   sequence[1].root = root2 - 1;
   sequence[1].c    = +1;
   sequence[2].root = root1 + 1;
   sequence[2].c    = -1;
   sequence[3].root = root2 + 1;
   sequence[3].c    = -1;

   b = b_term % p;

   // p divides the base: b has no inverse, nothing to search for.
   if (b == 0)
      goto done;

   bm = invmod32_64(b, p);   /* 1/b (mod p); raised to the m-th power below */

   for (i = 0; i < ROOT_COUNT; i++)
   {
      D32[i] = sequence[i].root % p;
      C32[i] = NO_SOLUTION;
   }

   /* Baby steps. */
   order_of_b_mod_p = baby_steps32(b, powmod64(b, sieve_low, p), p);

   if (order_of_b_mod_p > 0)
   {
      // The whole cycle of b is in the table: each target is either there
      // or has no solution at all, and solutions repeat every order steps.
      for (k = 0; k < ROOT_COUNT; k++)
      {
         j = lookup32(D32[k]);

         // Test the sentinel explicitly instead of relying on
         // HASH_NOT_FOUND being >= sieve_range.
         if (j == HASH_NOT_FOUND)
            continue;

         for (; j < sieve_range; j += order_of_b_mod_p)
            eliminate_term(sieve_low+j, sequence[k].c, p);
      }

      goto done;
   }

   /* First giant step. */
   for (k = 0; k < ROOT_COUNT; k++)
   {
      if ((j = lookup32(D32[k])) != HASH_NOT_FOUND)
      {
         solutions++;
         C32[k] = j;
         if (solutions == 1) /* first solution */
            first_solved = k;
      }
   }

   bm = (uint32_t) powmod64(bm, m, p); /* bm <- 1/b^m (mod p) */

   // Keep stepping the unsolved sequences, plus the first solved one: its
   // second hit gives the order of b.  ROOT_COUNT solutions plus that one
   // repeat is the most that can be found, hence "solutions <= ROOT_COUNT".
   for (i = 1; i < M && solutions <= ROOT_COUNT; i++)
   {
      for (k = 0; k < ROOT_COUNT; k++)
      {
         if (C32[k] != NO_SOLUTION && first_solved != k)
            continue;

         D32[k] = (uint32_t) mulmod64(D32[k], bm, p);
         j = lookup32(D32[k]);

         if (j == HASH_NOT_FOUND)
            continue;

         solutions++;
         if (first_solved == k) /* repeat solution */
         {
            order_of_b_mod_p = i*m + j - C32[k];
            first_solved = NO_SOLUTION;   /* no more repeats needed */
         }
         else
         {
            C32[k] = i*m + j;
            if (solutions == 1) /* first solution */
               first_solved = k;
         }
      }
   }

   if (order_of_b_mod_p > 0)
   {
      // Order known: walk every solution in the window.  Unsolved entries
      // hold NO_SOLUTION (UINT32_MAX), which is never < sieve_range.
      for (k = 0; k < ROOT_COUNT; k++)
         for (j = C32[k]; j < sieve_range; j += order_of_b_mod_p)
            eliminate_term(sieve_low+j, sequence[k].c, p);
   }
   else
   {
      // Order unknown: only the single solution found per sequence.
      for (k = 0; k < ROOT_COUNT; k++)
         if (C32[k] != NO_SOLUTION)
            eliminate_term(sieve_low+C32[k], sequence[k].c, p);
   }

done:
   mod64_fini();
}

/*
  Build the baby-step table for a small prime.

  Clears the hash table, then stores bj0*b^j (mod p) under index j for
  j = 0, 1, ... up to m entries.  If the running power comes back to bj0
  after some step, that step count is the order of b (mod p) and is
  returned immediately; otherwise 0 is returned after m insertions.
*/
uint32_t baby_steps32(uint32_t b, uint32_t bj0, uint32_t p)
{
   uint32_t step  = 0;
   uint32_t power = bj0;

   clear_hashtable();

   while (step < m)
   {
      insert32(step, power);

      power = (uint32_t) mulmod64(power, b, p);
      step++;

      /* Back at the starting value: the cycle length is the order of b. */
      if (power == bj0)
         return step;
   }

   return 0;
}

// Both 32-bit and 64-bit primes are handled here because so little time is
// spent in 32-bit sieving that it isn't worth the effort to code a 32-bit version.
/*
  Sieve one prime p using 64-bit baby-step/giant-step values.

  sieve() routes primes p > MAX(b_term, 257) here.  For each of the
  ROOT_COUNT targets (root-1 and root+1 for both square roots of 2 mod p)
  every offset j in the window with b^(sieve_low+j) = target (mod p) that
  the search finds is reported through eliminate_term().

  Nothing is returned; the side effects are total_primes_tested++, the
  contents of sequence[], D64[], bQ64 and the hash table, and the
  eliminate_term() calls.
*/
void bsgs_big(uint64_t p)
{
   uint64_t  root1, root2;
   uint64_t  b, bj0, inv_bm;
   uint32_t  i, j, k;

   total_primes_tested++;

   // Exit the function if there are no values x such that x^2 = 2 (mod p)
   if (!isQuadraticResidue(2, p))
      return;

   // From here on every exit must go through "done" so that mod64_init()
   // is always paired with mod64_fini().
   mod64_init(p);

   // Find root of x^2 = 2 (mod p)
   root1 = findRoot(p);
   root2 = p - root1;

   // It is possible that findRoot returns where x^2 = -2 (mod p)
   if (mulmod64(root1, root1, p) != 2)
   {
      warning("%"PRIu64" is not a root (mod %"PRIu64")\n", root1, p);
      goto done;
   }
   if (mulmod64(root2, root2, p) != 2)
   {
      warning("%"PRIu64" is not a root (mod %"PRIu64")\n", root2, p);
      goto done;
   }

   sequence[0].root = root1 - 1;
   sequence[0].c    = +1;
   sequence[1].root = root2 - 1;
   sequence[1].c    = +1;
   sequence[2].root = root1 + 1;
   sequence[2].c    = -1;
   sequence[3].root = root2 + 1;
   sequence[3].c    = -1;

   b = b_term % p;

   // p divides the base: b has no inverse, nothing to search for.
   if (b == 0)
      goto done;

   // 1/b (mod p); raised to the m-th power for the giant steps below.
   bQ64 = invmod32_64(b, p);

   for (i = 0; i < ROOT_COUNT; i++)
      D64[i] = sequence[i].root % p;

   // Baby steps start at b^sieve_low so that table index j corresponds to
   // the exponent sieve_low+j used in the eliminate_term() calls.  The
   // reduced b is used throughout so the mulmod operands stay below p.
   bj0 = powmod64(b, sieve_low, p);

   if ((i = baby_steps64(b, bj0, p)) > 0)
   {
      // i is the order of b (mod p). This is all the information we need to
      // determine every solution for this p, no giant steps are needed.
      for (k = 0; k < ROOT_COUNT; k++)
      {
         j = lookup64(D64[k]);

         // Test the sentinel explicitly instead of relying on
         // HASH_NOT_FOUND being >= sieve_range.
         if (j == HASH_NOT_FOUND)
            continue;

         for (; j < sieve_range; j += i)
            eliminate_term(sieve_low+j, sequence[k].c, p);
      }

      goto done;
   }

   // First giant step
   for (k = 0; k < ROOT_COUNT; k++)
      if ((j = lookup64(D64[k])) != HASH_NOT_FOUND)
         eliminate_term(sieve_low+j, sequence[k].c, p);

   // Remaining giant steps
   inv_bm = powmod64(bQ64, m, p); /* inv_bm <- 1/b^m (mod p) */

#ifdef USE_VEC_MULMOD
   vec_mulmod64_initp(p);
   vec_mulmod64_initb(inv_bm);
#endif

   for (i = 1; i < M; i++)
   {
      const uint32_t n0 = sieve_low + i*m;   /* exponent of table index 0 in this step */

#ifdef USE_VEC_MULMOD
      // NOTE(review): assumes one vec4_mulmod64(.., .., 1) call multiplies
      // all ROOT_COUNT (= 4) entries of D64 in place -- confirm.
      vec4_mulmod64(D64, D64, 1);
#else
      for (k = 0; k < ROOT_COUNT; k++)
         D64[k] = mulmod64(D64[k], inv_bm, p);
#endif

      for (k = 0; k < ROOT_COUNT; k++)
         if ((j = lookup64(D64[k])) != HASH_NOT_FOUND)
            eliminate_term(n0+j, sequence[k].c, p);
   }

#ifdef USE_VEC_MULMOD
   vec_mulmod64_finib();
   vec_mulmod64_finip();
#endif

done:
   mod64_fini();
}

/*
  Build the baby-step table for a 64-bit prime.

  Clears the hash table, then stores bj0*b^j (mod p) under index j for
  j = 0, 1, ... up to m entries, multiplying with the PRE2_MULMOD64
  macros.  If the running power comes back to bj0 after some step, that
  step count (the order of b mod p) is returned; otherwise 0 is returned
  after m insertions.  PRE2_MULMOD64_FINI() runs on both outcomes.
*/
uint32_t baby_steps64(uint64_t b, uint64_t bj0, uint64_t p)
{
   uint64_t power = bj0;
   uint32_t step  = 0;
   uint32_t order = 0;   /* stays 0 unless the cycle closes within m steps */

   clear_hashtable();

   PRE2_MULMOD64_INIT(b);

   while (step < m)
   {
      insert64(step, power);

      power = PRE2_MULMOD64(power, b, p);
      step++;

      if (power == bj0)
      {
         order = step;
         break;
      }
   }

   PRE2_MULMOD64_FINI();
   return order;
}

/*
  Choose the baby-step count m and giant-step count M for the range
  n_min..n_max, record the sieve window (sieve_low, sieve_range) and
  size the hash table for m entries.
*/
void  init_bsgs(void)
{
   uint32_t span = n_max - n_min + 1;   /* how many exponents must be covered */
   double   giant_estimate;
   double   baby_estimate;

   // Worst case cost: one table insertion and one mulmod for each of the
   // m baby steps, then s table lookups and s mulmods for each of the M
   // giant steps (s = ROOT_COUNT). The average case depends on how many
   // solutions are found and how early, which is not analysed here.
   // Minimising m + s*M subject to m*M >= span gives m = sqrt(s*span),
   // i.e. M = sqrt(span/s).

   giant_estimate = sqrt((double) span/ROOT_COUNT);
   M = MAX(1, giant_estimate);

   baby_estimate = ceil((double) span/M);
   m = MIN(span, baby_estimate);

   // The hash table caps the number of baby steps; if the ideal m does not
   // fit, use as many giant steps as needed to bring m under the cap.
   if (m > HASH_MAX_ELTS)
   {
#if SHORT_HASHTABLE
      report(1, "NOTE: cksieve was compiled with SHORT_HASHTABLE=1. It may be");
      report(1, "NOTE: worthwhile compiling with SHORT_HASHTABLE=0 for this job.");
#endif

      M = ceil((double) span/HASH_MAX_ELTS);
      m = ceil((double) span/M);
   }

   sieve_range = m*M;
   sieve_low   = n_min;

   //report(1, "Sieve range is %"PRIu32" to %"PRIu32
   //      ", using %"PRIu32" baby steps, %"PRIu32" giant steps.",
   //      sieve_low, (sieve_low+sieve_range)-1, m, M);

   // The window [sieve_low, sieve_low+sieve_range) must contain n_min..n_max.
   assert(sieve_low <= n_min);
   assert(n_max < sieve_low+sieve_range);

   init_hashtable(m);
}

/* Counterpart of init_bsgs(): shuts down the hash table module that
   init_bsgs() set up with init_hashtable(m).  Called by sieve() once all
   primes have been processed. */
void  fini_bsgs(void)
{
   fini_hashtable();
}

/*
  Top-level driver: sieve every prime in p_min..p_max.

  Primes up to MIN(p_max, MAX(b_term, 257)) go through bsgs_small(), the
  rest through bsgs_big().  p_min is raised to at least 3 and is advanced
  past the small-prime range as a side effect.
*/
void sieve(void)
{
   uint64_t small_limit;   /* largest prime handed to bsgs_small() */

   init_bsgs();

   // Never start below the first odd prime.
   if (p_min < 3)
      p_min = 3;

   init_prime_sieve(p_max);
   start_cksieve();

   small_limit = MIN(p_max, MAX(b_term, 257));

   // Phase 1: small primes.
   if (p_min <= small_limit)
   {
      prime_sieve(p_min, small_limit, bsgs_small);
      p_min = small_limit + 1;
   }

   // Phase 2: whatever remains up to p_max.
   if (p_min <= p_max)
   {
      prime_sieve(p_min, p_max, bsgs_big);
   }

   fini_prime_sieve();
   fini_bsgs();

   finish_cksieve("--pmax was reached", p_max);
}
