package jp.sourceforge.armadillo.lzh;

import java.io.*;

import jp.sourceforge.armadillo.io.*;

/**
 * Huffman decoder for LZH compressed data.
 * Used for the LH4, LH5, LH6 and LH7 methods.
 * <p>
 * The input is a sequence of blocks. Each block starts with a 16-bit count of the
 * symbols it contains, followed by the Huffman tables for the symbol alphabet and for
 * the offset codes; the Huffman-coded data follows.
 * </p>
 * <p>
 * The header fields of the offset table have a method-dependent bit width: 4 bits for
 * LH4/LH5 and 5 bits for LH6/LH7. The constructors without an explicit width use 4;
 * use {@link #LzhHuffmanDecoder(InputStream, long, int)} for LH6/LH7.
 * </p>
 */
final class LzhHuffmanDecoder implements LzssDecoderReadable {

    /**
     * Size of the work tables indexed by code length (16 bits + 1).
     */
    private static final int WORK_TABLE_BITLENGTH = 16 + 1;

    /** Longest Huffman code length accepted by this decoder. */
    private static final int MAX_CODE_LENGTH = WORK_TABLE_BITLENGTH - 1;

    /** Bit width of the per-block symbol count. */
    private static final int BLOCK_SIZE_BITS = 16;

    /** Bit width of the entry count (and single-code value) of the code-length table. */
    private static final int LENGTH_TABLE_BITS = 5;

    /** Entry index after which the code-length table stores a 2-bit count of skipped zero lengths. */
    private static final int LENGTH_TABLE_SPECIAL_INDEX = 3;

    /** Bit width of the symbol count (and single-symbol value) of the symbol table. */
    private static final int SYMBOL_COUNT_BITS = 9;

    /** Offset-table header bit width used by the constructors that do not take one (LH4/LH5). */
    private static final int DEFAULT_OFFSET_CODE_BITS = 4;

    private BitInputStream bin;
    private final int offsetCodeBits;
    private int blockSize;
    private HuffmanTable symbolTable;
    private HuffmanTable offsetTable;

    /**
     * Creates a decoder that reads {@code in} until EOF, using the LH4/LH5 offset table width.
     * @param in InputStream
     */
    public LzhHuffmanDecoder(InputStream in) {
        this(in, -1);
    }

    /**
     * Creates a decoder using the LH4/LH5 offset table width.
     * @param in InputStream
     * @param limit number of bytes to read from {@code in};
     *              a negative value means no limit (read until EOF)
     */
    public LzhHuffmanDecoder(InputStream in, long limit) {
        this(in, limit, DEFAULT_OFFSET_CODE_BITS);
    }

    /**
     * Creates a decoder.
     * @param in InputStream
     * @param limit number of bytes to read from {@code in};
     *              a negative value means no limit (read until EOF)
     * @param offsetCodeBits bit width of the offset table's header fields:
     *                       4 for LH4/LH5, 5 for LH6/LH7
     * @throws IllegalArgumentException if {@code offsetCodeBits} is not positive
     */
    public LzhHuffmanDecoder(InputStream in, long limit, int offsetCodeBits) {
        if (offsetCodeBits <= 0) {
            throw new IllegalArgumentException("offsetCodeBits=" + offsetCodeBits);
        }
        this.bin = new BitInputStream((limit < 0) ? in : new LimitedInputStream(in, limit));
        this.offsetCodeBits = offsetCodeBits;
        this.blockSize = 0;
    }

    /* @see jp.sourceforge.armadillo.lzh.LzssDecoderReadable#read() */
    public int read() throws IOException {
        try {
            assert blockSize >= 0;
            if (blockSize == 0) {
                // start of a new block: symbol count, then the two Huffman tables
                if (bin.prefetch() == -1) {
                    return -1;
                }
                int n = bin.readBits(BLOCK_SIZE_BITS);
                if (n == -1) {
                    return -1;
                }
                // The reference LHA decoder keeps this count in an unsigned 16-bit
                // counter, so a stored 0 wraps around to 65536 symbols.
                this.blockSize = (n == 0) ? (1 << BLOCK_SIZE_BITS) : n;
                createSymbolTables();
                createOffsetTables();
            }
            --blockSize;
            int code = symbolTable.decode(bin);
            if (code < 0) {
                throw new LzhException("EOF appeared while reading a symbol");
            }
            return code;
        } catch (RuntimeException ex) {
            throw new LzhException("decode error", ex);
        }
    }

    /* @see jp.sourceforge.armadillo.lzh.LzssDecoderReadable#readOffset() */
    public int readOffset() throws IOException {
        try {
            int code = offsetTable.decode(bin);
            if (code < 0) {
                throw new LzhException("EOF appeared while reading an offset");
            }
            if (code <= 1) {
                // codes 0 and 1 are the offset value itself
                return code;
            }
            // code c >= 2: an offset with c significant bits, i.e. an implied leading 1
            // followed by (c - 1) bits stored verbatim
            int extraBits = code - 1;
            int low = readRequiredBits(extraBits, "an offset");
            return (1 << extraBits) | low;
        } catch (RuntimeException ex) {
            throw new LzhException("decode error", ex);
        }
    }

    /**
     * Reads the symbol table of a block.
     * The symbol code lengths are themselves Huffman coded with a preceding code-length table.
     * @throws IOException if an I/O error occurs or the data is malformed
     */
    private void createSymbolTables() throws IOException {
        HuffmanTable lengthTable = readCodeLengthTable(LENGTH_TABLE_BITS, LENGTH_TABLE_SPECIAL_INDEX);
        int n = readRequiredBits(SYMBOL_COUNT_BITS, "the number of symbol code lengths");
        if (n == 0) {
            // only one symbol occurs in this block: its value is stored verbatim and
            // it is coded with zero bits
            this.symbolTable = HuffmanTable.singleSymbol(readRequiredBits(SYMBOL_COUNT_BITS, "a single symbol"));
            return;
        }
        short[] codeLengthList = new short[n];
        for (int i = 0; i < n;) {
            int length = lengthTable.decode(bin);
            if (length < 0) {
                throw new LzhException("EOF appeared while reading symbol length list");
            }
            switch (length) {
                case 0: // one zero length
                    ++i;
                    break;
                case 1: // a run of 3..18 zero lengths
                    i += readRequiredBits(4, "symbol length list") + 3;
                    break;
                case 2: // a run of 20..531 zero lengths
                    i += readRequiredBits(SYMBOL_COUNT_BITS, "symbol length list") + 20;
                    break;
                default: // a code length of (length - 2) bits
                    if (length - 2 > MAX_CODE_LENGTH) {
                        throw new LzhException("invalid compressed data: code length=" + (length - 2));
                    }
                    codeLengthList[i++] = (short)(length - 2);
            }
        }
        this.symbolTable = HuffmanTable.fromCodeLengths(codeLengthList);
    }

    /**
     * Reads the offset table of a block.
     * @throws IOException if an I/O error occurs or the data is malformed
     */
    private void createOffsetTables() throws IOException {
        this.offsetTable = readCodeLengthTable(offsetCodeBits, -1);
    }

    /**
     * Reads a list of code lengths and builds a table from it.
     * An empty list is followed by the {@code nBits}-bit value of the only symbol,
     * which is then coded with zero bits.
     * @param nBits bit width of the entry count (and of the single-symbol value)
     * @param special index after which a 2-bit zero-length skip count follows, or -1 for none
     * @return decoding table
     * @throws IOException if an I/O error occurs or the data is malformed
     */
    private HuffmanTable readCodeLengthTable(int nBits, int special) throws IOException {
        short[] lengths = readCodeLengthList(nBits, special);
        if (lengths.length == 0) {
            return HuffmanTable.singleSymbol(readRequiredBits(nBits, "a single-code table"));
        }
        return HuffmanTable.fromCodeLengths(lengths);
    }

    /**
     * Reads a list of code lengths.
     * Each length is 3 bits; the value 7 is extended by one per following 1 bit,
     * up to a terminating 0 bit.
     * @param nBits bit width of the entry count
     * @param special index after which a 2-bit count of zero lengths to skip follows, or -1
     * @return code length list (possibly empty)
     * @throws IOException if an I/O error occurs or the data is malformed
     */
    private short[] readCodeLengthList(int nBits, int special) throws IOException {
        final int n = readRequiredBits(nBits, "the number of code lengths");
        short[] list = new short[n];
        int i = 0;
        while (i < n) {
            int length = readRequiredBits(3, "a code length");
            if (length == 7) {
                while (bin.readBit() == 1) {
                    if (++length > MAX_CODE_LENGTH) {
                        throw new LzhException("invalid compressed data: code length=" + length);
                    }
                }
            }
            list[i++] = (short)length;
            // The skip count comes right after the entry at index (special - 1). It is
            // read even when that entry was the last one. Skipped entries remain 0.
            if (i == special) {
                i += readRequiredBits(2, "a zero-length run");
            }
        }
        return list;
    }

    /**
     * Reads bits that must be present.
     * @param nBits number of bits
     * @param context description used in the error message
     * @return the bits read
     * @throws IOException if an I/O error occurs, or EOF is reached
     */
    private int readRequiredBits(int nBits, String context) throws IOException {
        int value = bin.readBits(nBits);
        if (value < 0) {
            throw new LzhException("EOF appeared while reading " + context);
        }
        return value;
    }

    /**
     * Returns the largest bit length in the list.
     * @param bitLengthList bit length list
     * @return maximum bit length (0 for an empty or all-zero list)
     */
    private static int getMaxBitSize(short[] bitLengthList) {
        int max = 0;
        for (int i = 0; i < bitLengthList.length; i++) {
            if (bitLengthList[i] > max) {
                max = bitLengthList[i];
            }
        }
        return max;
    }

    /**
     * Builds a lookup table of {@code 2^maxBitLength} entries.
     * Each entry maps a {@code maxBitLength}-bit prefix to the symbol whose code it starts with.
     * @param lengthList code length of each symbol (0 = unused)
     * @param maxBitLength maximum code length in {@code lengthList}
     * @return lookup table
     */
    private static short[] createCodeTable(short[] lengthList, int maxBitLength) {
        int[] codeList = createCodeList(lengthList);
        final int tableSize = (1 << maxBitLength);
        short[] table = new short[tableSize];
        for (int i = 0; i < lengthList.length; i++) {
            if (lengthList[i] > 0) {
                // a code of length L covers 2^(maxBitLength - L) consecutive table slots
                int rangeBits = maxBitLength - lengthList[i];
                int start = codeList[i] << rangeBits;
                int next = start + (1 << rangeBits);
                for (int index = start; index < next; index++) {
                    table[index] = (short)i;
                }
            }
        }
        return table;
    }

    /**
     * Assigns canonical Huffman codes from code lengths.
     * Shorter codes come first; codes of equal length are assigned in symbol order.
     * @param codeLengthList code length of each symbol (0 = unused)
     * @return code of each symbol (0 for unused symbols)
     */
    private static int[] createCodeList(short[] codeLengthList) {
        assert codeLengthList.length > 0;
        if (codeLengthList.length == 1) {
            return new int[1];
        }
        int[] counts = new int[WORK_TABLE_BITLENGTH];
        for (int i = 0; i < codeLengthList.length; i++) {
            ++counts[codeLengthList[i]];
        }
        // baseCodes[L - 1] = first code of length L
        int[] baseCodes = new int[WORK_TABLE_BITLENGTH];
        for (int i = 0; i < WORK_TABLE_BITLENGTH - 1; i++) { // i = bit length - 1
            baseCodes[i + 1] = baseCodes[i] + counts[i + 1] << 1;
        }
        // a complete code fills the 17-bit space exactly
        assert baseCodes[WORK_TABLE_BITLENGTH - 1] == 1 << WORK_TABLE_BITLENGTH : baseCodes[WORK_TABLE_BITLENGTH - 1];
        int[] codeList = new int[codeLengthList.length];
        for (int i = 0; i < codeList.length; i++) {
            int codeLength = codeLengthList[i];
            if (codeLength > 0) {
                codeList[i] = baseCodes[codeLength - 1]++;
            }
        }
        return codeList;
    }

    /* @see jp.sourceforge.armadillo.lzh.LzssDecoderReadable#close() */
    public void close() throws IOException {
        final BitInputStream stream = bin;
        bin = null;
        symbolTable = null;
        offsetTable = null;
        if (stream != null) {
            stream.close();
        }
    }

    /**
     * A Huffman decoding table: the code length of each symbol, plus a lookup array
     * indexed by the next {@code lookupBits} bits of the stream.
     */
    private static final class HuffmanTable {

        /** Number of bits peeked per lookup; 0 for a single symbol coded with zero bits. */
        private final int lookupBits;
        /** Code length in bits of each symbol. */
        private final short[] codeLengths;
        /** Maps each {@code lookupBits}-bit prefix to the symbol whose code it starts with. */
        private final short[] lookup;

        private HuffmanTable(int lookupBits, short[] codeLengths, short[] lookup) {
            this.lookupBits = lookupBits;
            this.codeLengths = codeLengths;
            this.lookup = lookup;
        }

        /** Builds a table from per-symbol code lengths. */
        static HuffmanTable fromCodeLengths(short[] codeLengths) {
            final int maxBitLength = getMaxBitSize(codeLengths);
            return new HuffmanTable(maxBitLength, codeLengths, createCodeTable(codeLengths, maxBitLength));
        }

        /** Builds a table whose only symbol is {@code symbol}, coded with zero bits. */
        static HuffmanTable singleSymbol(int symbol) {
            return new HuffmanTable(0, new short[symbol + 1], new short[]{(short)symbol});
        }

        /**
         * Decodes one symbol and consumes its code bits.
         * @param in bit stream
         * @return symbol, or -1 if no bits could be prefetched
         * @throws IOException if an I/O error occurs
         */
        int decode(BitInputStream in) throws IOException {
            if (lookupBits == 0) {
                return lookup[0];
            }
            int bits = in.prefetchBits(lookupBits);
            if (bits < 0) {
                return -1;
            }
            int symbol = lookup[bits];
            int length = codeLengths[symbol];
            if (length > 0) {
                in.readBits(length);
            }
            return symbol;
        }
    }

    /**
     * Input stream that returns at most {@code limit} bytes from the wrapped stream.
     * <p>
     * {@link #close()} is inherited from {@link InputStream} and does nothing, so
     * closing this stream leaves the wrapped stream open.
     * </p>
     */
    private static final class LimitedInputStream extends InputStream {

        private final InputStream in;
        private long remaining;

        LimitedInputStream(InputStream in, long limit) {
            this.in = in;
            this.remaining = limit;
        }

        public int read() throws IOException {
            if (remaining <= 0) {
                return -1;
            }
            int read = in.read();
            if (read != -1) {
                --remaining;
            }
            return read;
        }
    }

}
