import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Euler269 {

    /**
     * Largest exponent accepted by {@link #solve(int)}. Z(10^k) is at most 10^k,
     * which is guaranteed to fit in a {@code long} for k <= 18.
     */
    static final int MAX_SUPPORTED_POWER = 18;

    /**
     * Computes 10 raised to {@code exp}, failing loudly instead of silently overflowing.
     *
     * @param exp the exponent; values <= 0 yield 1
     * @return 10^exp
     * @throws ArithmeticException if the result does not fit in a {@code long}
     */
    static long pow10i(int exp) {
        long value = 1;
        for (int i = 0; i < exp; i++) {
            value = Math.multiplyExact(value, 10L);
        }
        return value;
    }

    /**
     * Hashable DP state: one carry value per candidate root.
     * The constructor copies its argument, so callers may reuse their array.
     * The stored array must not be modified after construction, because it
     * determines the key's hash code.
     */
    static final class StateKey {
        final int[] carry;

        StateKey(int[] c) {
            this.carry = c.clone();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof StateKey)) {
                return false;
            }
            return Arrays.equals(carry, ((StateKey) o).carry);
        }

        @Override
        public int hashCode() {
            return Arrays.hashCode(carry);
        }
    }

    /**
     * Returns the digits allowed at position {@code pos} of a number whose most
     * significant position is {@code msdPos}. Positions are counted from the units
     * digit (0).
     *
     * <p>Past the leading digit only an implicit 0 is allowed. The leading digit
     * itself must be non-zero. Every other position may be any of 0..9.
     */
    private static int[] allowedDigits(int pos, int msdPos) {
        if (pos > msdPos) {
            return new int[] {0};
        }
        int lo = (pos == msdPos) ? 1 : 0;
        int[] digits = new int[10 - lo];
        for (int d = lo; d <= 9; d++) {
            digits[d - lo] = d;
        }
        return digits;
    }

    /**
     * Counts {@code length}-digit numbers (no leading zero) whose units digit is
     * {@code lastDigit} and whose digit polynomial P(x) = sum_i d_i x^i satisfies
     * P(-t) = 0 for EVERY t in {@code roots}.
     *
     * <p>P(-t) = sum_j (d_{2j} - t*d_{2j+1}) * (t^2)^j. Digits are therefore consumed
     * two at a time from the units end. For each root, the DP tracks the (negated)
     * carry left after dividing the running sum by t^2. Each step's value must be
     * divisible by t^2, and P(-t) = 0 exactly when the final carry is 0.
     *
     * @param length    number of digits in the counted integers
     * @param lastDigit value of the units digit
     * @param roots     root magnitudes t (the tested roots are -t); assumed positive,
     *                  since t = 0 would divide by zero
     * @return the count, or 0 if {@code roots} is empty or {@code length <= 0}
     */
    static long countSubsetForLength(int length, int lastDigit, List<Integer> roots) {
        if (length <= 0 || roots.isEmpty()) {
            return 0;
        }

        // Hoist root values and their squares out of the hot loop.
        int k = roots.size();
        int[] ts = new int[k];
        int[] denoms = new int[k];
        for (int i = 0; i < k; i++) {
            ts[i] = roots.get(i);
            denoms[i] = ts[i] * ts[i];
        }

        int pairCount = (length + 1) / 2;
        int msdPos = length - 1;

        StateKey zeroState = new StateKey(new int[k]);
        Map<StateKey, Long> states = new HashMap<>();
        states.put(zeroState, 1L);

        for (int j = 0; j < pairCount; j++) {
            int evenPos = 2 * j;
            int oddPos = evenPos + 1;

            int[] evenDigits;
            if (j == 0) {
                // Units digit is fixed; a 1-digit number may not have a zero leading digit.
                evenDigits = (evenPos == msdPos && lastDigit == 0) ? new int[0] : new int[] {lastDigit};
            } else {
                evenDigits = allowedDigits(evenPos, msdPos);
            }
            int[] oddDigits = allowedDigits(oddPos, msdPos);

            Map<StateKey, Long> nextStates = new HashMap<>();
            for (Map.Entry<StateKey, Long> entry : states.entrySet()) {
                int[] carry = entry.getKey().carry;
                long ways = entry.getValue();

                for (int ed : evenDigits) {
                    for (int od : oddDigits) {
                        int[] nextCarry = new int[k];
                        boolean ok = true;
                        for (int idx = 0; idx < k; idx++) {
                            long num = (long) ts[idx] * od + carry[idx] - ed;
                            // Exact divisibility is required; otherwise P(-t) cannot vanish.
                            if (num % denoms[idx] != 0) {
                                ok = false;
                                break;
                            }
                            nextCarry[idx] = (int) (num / denoms[idx]);
                        }
                        if (ok) {
                            nextStates.merge(new StateKey(nextCarry), ways, Long::sum);
                        }
                    }
                }
            }

            states = nextStates;
            if (states.isEmpty()) {
                return 0;
            }
        }

        return states.getOrDefault(zeroState, 0L);
    }

    /**
     * Counts {@code length}-digit positive integers whose digit polynomial has at
     * least one integer root.
     *
     * <p>The coefficients are non-negative and the leading digit is non-zero, so any
     * integer root is 0 or negative. Root 0 occurs exactly when the units digit is 0.
     * Otherwise, by the rational root theorem, a root -t requires t to divide the
     * units digit d0. The union over those t is counted by inclusion-exclusion.
     *
     * @param length number of digits
     * @return the count, or 0 when {@code length <= 0}
     */
    static long countForLength(int length) {
        if (length <= 0) {
            return 0;
        }

        long total = 0;
        if (length >= 2) {
            // Units digit 0: leading digit 1..9, the length-2 middle digits are free.
            total += 9L * pow10i(length - 2);
        }

        for (int d0 = 1; d0 <= 9; d0++) {
            List<Integer> divisors = new ArrayList<>();
            for (int t = 1; t <= d0; t++) {
                if (d0 % t == 0) {
                    divisors.add(t);
                }
            }

            int m = divisors.size();
            long unionCount = 0;
            for (int mask = 1; mask < (1 << m); mask++) {
                List<Integer> roots = new ArrayList<>();
                for (int i = 0; i < m; i++) {
                    if (((mask >> i) & 1) != 0) {
                        roots.add(divisors.get(i));
                    }
                }

                long cnt = countSubsetForLength(length, d0, roots);
                // Inclusion-exclusion: add odd-sized intersections, subtract even-sized ones.
                if ((Integer.bitCount(mask) & 1) != 0) {
                    unionCount += cnt;
                } else {
                    unionCount -= cnt;
                }
            }
            total += unionCount;
        }

        return total;
    }

    /**
     * Computes Z(10^maxPower): the number of positive integers n <= 10^maxPower whose
     * digit polynomial has at least one integer root.
     *
     * @param maxPower exponent of the bound, in [0, {@link #MAX_SUPPORTED_POWER}]
     * @return Z(10^maxPower)
     * @throws IllegalArgumentException if {@code maxPower} is out of range
     */
    static long solve(int maxPower) {
        if (maxPower < 0 || maxPower > MAX_SUPPORTED_POWER) {
            throw new IllegalArgumentException(
                    "maxPower must be in [0, " + MAX_SUPPORTED_POWER + "]: " + maxPower);
        }
        if (maxPower == 0) {
            // Only n = 1 qualifies for the bound, and P(x) = 1 has no root.
            return 0;
        }

        long total = 0;
        for (int len = 1; len <= maxPower; len++) {
            total += countForLength(len);
        }
        // 10^maxPower itself has P(x) = x^maxPower, which has root 0.
        return total + 1;
    }

    public static void main(String[] args) {
        System.out.println(solve(16));
    }
}
