import java.util.ArrayList;
import java.util.List;

/**
 * Counts periodic points of the map {@code f(x) = (x^k + x) mod p}, with {@code k = (p + 4) / 5},
 * summed over primes {@code p} with {@code p % 5 == 1} (presumably Project Euler problem 971 —
 * confirm against the problem statement).
 *
 * <p>Why the fast method works: let {@code t = (p - 1) / 5}. Then {@code k = t + 1}, so
 * {@code f(x) = x * (1 + x^t)}. For {@code x != 0}, {@code (x^t)^5 = x^(p-1) = 1} (Fermat), so
 * {@code x^t} is a fifth root of unity {@code omega^r}. This splits the nonzero residues into five
 * cosets of size {@code t}. On coset {@code r}, {@code f} is multiplication by the constant
 * {@code 1 + omega^r}. That constant is nonzero because a fifth root of unity is never {@code -1}
 * for odd {@code p}, so {@code f} maps the coset bijectively onto the coset whose index is shifted
 * by the discrete "log" of {@code (1 + omega^r)^t}. The periodic points are {@code 0} (a fixed point)
 * plus every element of each coset lying on a cycle of that 5-node map.
 */
public class Euler971 {

    /** Walk states used by {@link #countCycleNodes(int[])}. */
    private static final int UNSEEN = 0;
    private static final int DONE = 1;
    private static final int ON_PATH = 2;

    /**
     * Computes {@code base^exp mod mod} by binary exponentiation.
     *
     * <p>Intermediate products are at most {@code (mod - 1)^2}, which fits in a {@code long} for any
     * {@code int} modulus (every caller in this class passes one).
     *
     * @param base the base; negative values are reduced into {@code [0, mod)} first
     * @param exp the exponent; a non-positive exponent returns {@code 1 % mod}
     * @param mod the modulus; must be positive (zero throws {@link ArithmeticException})
     * @return {@code base^exp mod mod}, in {@code [0, mod)}
     */
    static long modPow(long base, long exp, long mod) {
        long result = 1 % mod;
        base = Math.floorMod(base, mod);
        while (exp > 0) {
            if ((exp & 1) != 0) {
                result = (result * base) % mod;
            }
            base = (base * base) % mod;
            exp >>= 1;
        }
        return result;
    }

    /**
     * Sieve of Eratosthenes storing compositeness, so the default {@code false} means "not yet
     * crossed out" and no initialization pass is needed.
     *
     * @param n inclusive upper bound
     * @return an empty array when {@code n < 2}; otherwise an array of length {@code n + 1} where,
     *     for {@code 2 <= i <= n}, entry {@code i} is {@code true} iff {@code i} is composite
     *     (entries 0 and 1 are meaningless)
     */
    private static boolean[] compositeSieve(int n) {
        if (n < 2) {
            return new boolean[0];
        }
        boolean[] composite = new boolean[n + 1];
        // i <= n / i instead of i * i <= n: the product overflows for n near Integer.MAX_VALUE.
        for (int i = 2; i <= n / i; ++i) {
            if (!composite[i]) {
                // long index: j += i could wrap past Integer.MAX_VALUE before failing j <= n.
                for (long j = (long) i * i; j <= n; j += i) {
                    composite[(int) j] = true;
                }
            }
        }
        return composite;
    }

    /**
     * Lists all primes in {@code [2, n]} in increasing order.
     *
     * @param n inclusive upper bound; any value below 2 yields an empty list
     * @return a new mutable list of the primes up to {@code n}
     */
    static List<Integer> primesUpTo(int n) {
        boolean[] composite = compositeSieve(n);
        List<Integer> primes = new ArrayList<>();
        for (int i = 2; i <= n; ++i) {
            if (!composite[i]) {
                primes.add(i);
            }
        }
        return primes;
    }

    /**
     * Counts the nodes lying on cycles of a functional graph (every node has exactly one successor).
     *
     * <p>Each unvisited node starts a walk that marks nodes {@code ON_PATH} until it reaches an
     * already-visited node. If that node is still {@code ON_PATH}, the walk closed a new cycle whose
     * length is the distance back to it along the current path.
     *
     * @param next successor table; every entry must be a valid index into {@code next}
     * @return the number of nodes that lie on some cycle (0 for an empty table)
     */
    static int countCycleNodes(int[] next) {
        int size = next.length;
        int[] state = new int[size];
        // Shared across walks: each node is pushed at most once overall, and pathIndex[v] is read
        // only while v is ON_PATH, i.e. after it was written during the current walk.
        int[] path = new int[size];
        int[] pathIndex = new int[size];
        int cycleNodes = 0;

        for (int start = 0; start < size; ++start) {
            if (state[start] != UNSEEN) {
                continue;
            }

            int length = 0;
            int cur = start;
            while (state[cur] == UNSEEN) {
                state[cur] = ON_PATH;
                pathIndex[cur] = length;
                path[length++] = cur;
                cur = next[cur];
            }

            if (state[cur] == ON_PATH) {
                cycleNodes += length - pathIndex[cur];
            }

            for (int i = 0; i < length; ++i) {
                state[path[i]] = DONE;
            }
        }

        return cycleNodes;
    }

    /**
     * Returns a primitive fifth root of unity modulo {@code p}: the first {@code a^t} with
     * {@code a >= 2} and {@code a^t != 1}. Its order divides 5 and is not 1, so it is exactly 5.
     *
     * @throws IllegalStateException if no such {@code a} exists (cannot happen for a prime
     *     {@code p} with {@code p % 5 == 1})
     */
    private static long primitiveFifthRootOfUnity(int p, long t) {
        for (long a = 2; a < p; ++a) {
            long candidate = modPow(a, t, p);
            if (candidate != 1) {
                return candidate;
            }
        }
        throw new IllegalStateException("no primitive fifth root of unity modulo " + p);
    }

    /**
     * Counts the periodic points of {@code f(x) = (x^k + x) mod p}, {@code k = (p + 4) / 5}, in
     * O(log p) expected time using the five-coset reduction described on the class.
     *
     * @param p a prime with {@code p % 5 == 1}; primality itself is not checked
     * @return {@code 1 + t * (number of cosets on a cycle)}, where {@code t = (p - 1) / 5}
     * @throws IllegalArgumentException if {@code p < 11} or {@code p % 5 != 1}
     * @throws IllegalStateException if the coset map cannot be built (indicates {@code p} is not
     *     prime)
     */
    static long countPeriodicPointsForPrime(int p) {
        if (p < 11 || p % 5 != 1) {
            throw new IllegalArgumentException("p must be a prime with p % 5 == 1, got " + p);
        }
        long t = (p - 1) / 5;
        long omega = primitiveFifthRootOfUnity(p, t);

        long[] powOmega = new long[5];
        powOmega[0] = 1;
        for (int i = 1; i < 5; ++i) {
            powOmega[i] = (powOmega[i - 1] * omega) % p;
        }

        int[] next = new int[5];
        for (int r = 0; r < 5; ++r) {
            // f scales coset r by (1 + omega^r); raising that factor to t gives the coset offset.
            long image = modPow((1 + powOmega[r]) % p, t, p);
            int shift = 0;
            while (shift < 5 && powOmega[shift] != image) {
                ++shift;
            }
            if (shift == 5) {
                throw new IllegalStateException(
                        "(1 + omega^" + r + ")^t is not a fifth root of unity; is " + p + " prime?");
            }
            next[r] = (r + shift) % 5;
        }

        return 1 + t * countCycleNodes(next);
    }

    /**
     * Reference implementation: builds the full successor table of {@code f} on {@code [0, p)} and
     * counts its cycle nodes directly. O(p log p) time and O(p) memory, intended for validation.
     *
     * @param p a positive modulus (normally a prime with {@code p % 5 == 1})
     * @return the number of periodic points of {@code f} modulo {@code p}
     */
    static long bruteCountPeriodicPointsForPrime(int p) {
        int k = (p + 4) / 5;
        int[] next = new int[p];
        for (int x = 0; x < p; ++x) {
            next[x] = (int) ((modPow(x, k, p) + x) % p);
        }
        return countCycleNodes(next);
    }

    /**
     * Sums {@link #countPeriodicPointsForPrime(int)} over all primes {@code p <= limit} with
     * {@code p % 5 == 1}.
     *
     * <p>Iterates the sieve array directly rather than {@link #primesUpTo(int)}, which avoids
     * boxing millions of {@code Integer}s for large limits.
     *
     * @param limit inclusive upper bound on {@code p}
     * @return the total (0 when no qualifying prime exists)
     */
    static long solveImpl(int limit) {
        boolean[] composite = compositeSieve(limit);
        long total = 0;
        for (int p = 2; p <= limit; ++p) {
            if (!composite[p] && p % 5 == 1) {
                total += countPeriodicPointsForPrime(p);
            }
        }
        return total;
    }

    /** Returns the answer for {@code limit = 100_000_000} as a decimal string. */
    public static String solve() {
        return Long.toString(solveImpl(100_000_000));
    }

    /** Cross-checks the fast method against the brute-force one and a known small total. */
    private static boolean selfCheck() {
        if (countPeriodicPointsForPrime(11) != 7) {
            return false;
        }
        int[] testPrimes = {11, 31, 41, 61};
        for (int p : testPrimes) {
            if (countPeriodicPointsForPrime(p) != bruteCountPeriodicPointsForPrime(p)) {
                return false;
            }
        }
        return solveImpl(100) == 127;
    }

    public static void main(String[] args) {
        if (!selfCheck()) {
            System.out.println("Validation failed");
            return;
        }
        System.out.println(solve());
    }
}
