/**
 * Sums the cubes of {@code M(n)} for {@code n = 1..1000} and prints the total.
 *
 * <p>Judging by the class name this is a solution to Project Euler problem 391;
 * the quantity {@code M(n)} is whatever {@link Solver#computeM(int)} returns. The
 * run is self-checking: the values for n = 2, 7 and 20 and the partial sum at
 * n = 20 are compared against fixed expected numbers.
 */
public class Euler391 {
    /** Largest {@code n} included in the sum produced by {@link #solve()}. */
    private static final int LIMIT = 1000;

    /**
     * Computes {@code M(n)} with two reusable scratch tables.
     *
     * <p>Each table is a {@code STRIDE x STRIDE} grid stored row-major in a flat
     * array; cell (row, column) lives at index {@code (row << SHIFT) + column}.
     * Row 0 of both tables is the identity mapping and is never overwritten.
     * All other rows are pure scratch space: every cell read during a call was
     * written earlier in that same call, so results do not depend on call history.
     */
    static class Solver {
        static final int SHIFT = 10;
        static final int STRIDE = 1 << SHIFT;

        /**
         * Largest supported {@code n}: the computation touches rows
         * {@code 0..n + 1}, and the highest available row is {@code STRIDE - 1}.
         */
        static final int MAX_N = STRIDE - 2;

        private final int[] tableA = new int[STRIDE * STRIDE];
        private final int[] tableB = new int[STRIDE * STRIDE];

        /** Seeds row 0 of both tables with the identity mapping over the full row width. */
        Solver() {
            for (int v = 0; v < STRIDE; v++) {
                tableA[v] = v;
                tableB[v] = v;
            }
        }

        /**
         * Computes {@code M(n)}.
         *
         * <p>With {@code step = n + 1}, stage {@code s} (for s = 1, 2, ...) builds
         * rows {@code 1..s} of the current table from the current table's previous
         * row and the previous stage's table:
         * {@code cur[m][v] = prev[m - 1][wrap(cur[m - 1][v] + (s - m + 1))]},
         * where {@code wrap(x)} is {@code x} when {@code x < step} and 0 otherwise.
         * The first stage whose row {@code s} holds one single value across columns
         * {@code 0..n} ends the computation, and that value is returned.
         *
         * @param n game parameter, {@code 0 <= n <= MAX_N}
         * @return the common value of the first constant row {@code s}
         * @throws IllegalArgumentException if {@code n} is outside {@code [0, MAX_N]}
         */
        int computeM(int n) {
            if (n < 0 || n > MAX_N) {
                throw new IllegalArgumentException(
                        "n must be in [0, " + MAX_N + "] but was " + n);
            }

            int[] previous = tableA;
            int[] current = tableB;
            final int step = n + 1;

            for (int s = 1; s <= step; ++s) {
                for (int m = 1; m <= s; ++m) {
                    final int add = s - m + 1;
                    final int prevRow = (m - 1) << SHIFT;
                    final int curRow = m << SHIFT;
                    for (int v = 0; v <= n; ++v) {
                        int value = current[prevRow + v] + add;
                        if (value >= step) {
                            value = 0;
                        }
                        current[curRow + v] = previous[prevRow + value];
                    }
                }

                final int row = s << SHIFT;
                if (isConstantRow(current, row, n)) {
                    return current[row];
                }

                // The table just built becomes the "previous stage" for the next pass.
                int[] swap = previous;
                previous = current;
                current = swap;
            }

            // At s == step the shift for m == 1 equals step, so every column wraps
            // to 0; row 1 and all rows derived from it are constant, which forces
            // the return above on the last iteration at the latest.
            throw new AssertionError("no constant row found for n = " + n);
        }

        /** Returns whether cells {@code rowStart .. rowStart + n} of {@code table} are all equal. */
        private static boolean isConstantRow(int[] table, int rowStart, int n) {
            final int first = table[rowStart];
            for (int v = 1; v <= n; ++v) {
                if (table[rowStart + v] != first) {
                    return false;
                }
            }
            return true;
        }
    }

    /**
     * Sums {@code M(n)^3} for {@code n = 1..1000}.
     *
     * @return the sum as a decimal string
     * @throws IllegalStateException if, on reaching n = 20, M(2), M(7), M(20) or the
     *         partial sum differ from the expected 2, 1, 4 and 8150
     */
    static String solve() {
        Solver solver = new Solver();
        long sum = 0;
        int m2 = -1;
        int m7 = -1;
        int m20 = -1;

        for (int n = 1; n <= LIMIT; n++) {
            int m = solver.computeM(n);
            if (n == 2) {
                m2 = m;
            }
            if (n == 7) {
                m7 = m;
            }
            if (n == 20) {
                m20 = m;
            }

            sum += (long) m * m * m;

            // Fail fast on the known small cases before doing the expensive part of the run.
            if (n == 20 && (m2 != 2 || m7 != 1 || m20 != 4 || sum != 8150)) {
                throw new IllegalStateException("Validation failed.");
            }
        }
        return Long.toString(sum);
    }

    /** Prints the result of {@link #solve()} to standard output. */
    public static void main(String[] args) {
        System.out.println(solve());
    }
}
