package net.y3n20u.aeszip;

import static net.y3n20u.aeszip.CommonValues.DEFAULT_ENCRYPT_STRENGTH_MODE;
import static net.y3n20u.aeszip.CommonValues.ITERATION_COUNT;
import static net.y3n20u.aeszip.CommonValues.LENGTH_AUTHENTICATION_CODE;
import static net.y3n20u.aeszip.CommonValues.LENGTH_PASSWORD_VERIFICATION_VALUE;
import static net.y3n20u.aeszip.CommonValues.MAX_FOUR_BYTE_FIELD_LONG;
import static net.y3n20u.aeszip.CommonValues.MESSAGE_OFFSET_INVALID;
import static net.y3n20u.aeszip.CommonValues.METHOD_AES;
import static net.y3n20u.aeszip.CommonValues.METHOD_DEFLATED;
import static net.y3n20u.aeszip.CommonValues.METHOD_STORED;

import java.io.ByteArrayOutputStream;
import java.security.SecureRandom;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.Calendar;
import java.util.zip.ZipEntry;

import net.y3n20u.rfc2898.Pbkdf2;
import net.y3n20u.util.ByteHelper;

/**
 * A {@link ZipEntry} whose data is encrypted with AES and described by an AE-2 extra
 * field (header ID 0x9901, vendor ID "AE").
 * <p>
 * {@link #getMethod()} reports the method stored in this class (the AES method by
 * default), while the compression method actually applied to the data is kept in the
 * superclass and exposed through {@link #getActualCompressionMethod()}.
 * <p>
 * The encryption key, authentication key and password verification value are derived
 * from a password with PBKDF2. A random salt is generated for every instance at
 * construction time.
 *
 * @author y3n20u@gmail.com
 */
public class AesZipEntry extends ZipEntry {

	/*
	 * Fixed 8-byte prefix of the AES extra field:
	 *   extra field header ID: 0x9901 - 2 bytes
	 *   data size: 0x0007 - 2 bytes
	 *   integer version number: 0x0002 (AE-2) - 2 bytes
	 *   vendor ID: "AE" - 2 bytes
	 * i.e. 01 99 07 00 02 00 41 45
	 */
	private static final byte[] AES_EXTRA_BYTES = ByteHelper.getBytes("01 99 07 00 02 00 41 45");

	private static final String MESSAGE_SIZE_WRONG = "size is too big or too small: {0}";

	/** Message for key getters called before a successful key derivation. */
	private static final String MESSAGE_KEYS_NOT_DERIVED = "keys have not been derived for entry: {0}";

	/** Shared source of salt bytes; SecureRandom instances are safe to share. */
	private static final SecureRandom SALT_RANDOM = new SecureRandom();

	/** Key strength chosen at construction; determines salt and key lengths. */
	private final EncryptionStrengthMode encryptionStrengthMode;
	/** Random salt generated at construction time. */
	private final byte[] saltValue;
	private byte[] encryptionKey;
	private byte[] authenticationKey;
	private byte[] passwordVerificationValue;
	/** Method reported by getMethod(): METHOD_AES (default) or METHOD_STORED. */
	private int method = METHOD_AES;
	private long relativeOffsetOfLocalFileHeader;

	/** True once a key derivation has completed successfully. */
	private boolean derivedKeys;

	/**
	 * Creates an entry using the default encryption strength; keys are not derived yet.
	 *
	 * @param name entry name
	 */
	public AesZipEntry(String name) {
		this(name, DEFAULT_ENCRYPT_STRENGTH_MODE);
	}

	/**
	 * Creates an entry with the given encryption strength and a fresh random salt;
	 * keys are not derived yet.
	 *
	 * @param name entry name
	 * @param mode encryption strength mode (determines salt and key lengths)
	 */
	public AesZipEntry(String name, EncryptionStrengthMode mode) {
		super(name);
		derivedKeys = false;
		encryptionStrengthMode = mode;
		saltValue = generateSaltValue(encryptionStrengthMode.getSaltLength());
	}

	/**
	 * Creates an entry using the default encryption strength and derives its keys
	 * from {@code password}.
	 *
	 * @param name entry name
	 * @param password password bytes
	 */
	public AesZipEntry(String name, byte[] password) {
		this(name, DEFAULT_ENCRYPT_STRENGTH_MODE, password);
	}

	/**
	 * Creates an entry with the given encryption strength and derives its keys from
	 * {@code password} and this entry's salt.
	 *
	 * @param name entry name
	 * @param mode encryption strength mode
	 * @param password password bytes
	 */
	public AesZipEntry(String name, EncryptionStrengthMode mode, byte[] password) {
		this(name, mode);
		// Use the private helpers instead of the overridable deriveKeys(byte[]) so that
		// no subclass override runs on a partially constructed object.
		storeKeys(runPbkdf2(password, saltValue));
	}

	/**
	 * Derives the encryption key, authentication key and password verification value
	 * from {@code password} and this entry's own salt ({@link #getSaltValue()}).
	 *
	 * @param password password bytes
	 */
	public void deriveKeys(byte[] password) {
		storeKeys(runPbkdf2(password, saltValue));
	}

	/**
	 * Derives keys from {@code password} and the given salt, and accepts them only if
	 * the derived password verification value equals the expected one.
	 * <p>
	 * If verification fails, this entry's key state is left exactly as it was before
	 * the call. NOTE(review): the salt passed here is not stored;
	 * {@link #getSaltValue()} keeps returning the salt generated at construction time.
	 *
	 * @param password password bytes
	 * @param saltValue salt to derive the keys with
	 * @param passwordVerificationValueExpected verification value the password must produce
	 * @throws PasswordNotCorrectException if the derived verification value does not match
	 */
	public void deriveKeys(byte[] password, byte[] saltValue, byte[] passwordVerificationValueExpected)
			throws PasswordNotCorrectException {
		byte[] derived = runPbkdf2(password, saltValue);
		int keyLength = encryptionStrengthMode.getKeyLength();
		byte[] verificationValue = slice(derived, keyLength + keyLength, LENGTH_PASSWORD_VERIFICATION_VALUE);
		if (!Arrays.equals(passwordVerificationValueExpected, verificationValue)) {
			// Verify before touching any field so a wrong password cannot clobber keys
			// from an earlier successful derivation.
			throw new PasswordNotCorrectException(this.getName());
		}
		storeKeys(derived);
	}

	/**
	 * Runs PBKDF2 and returns encryption key || authentication key || verification value.
	 */
	private byte[] runPbkdf2(byte[] password, byte[] salt) {
		int keyLength = encryptionStrengthMode.getKeyLength();
		int derivedKeyLen = keyLength + keyLength + LENGTH_PASSWORD_VERIFICATION_VALUE;
		return new Pbkdf2().deriveKey(password, salt, ITERATION_COUNT, derivedKeyLen);
	}

	/** Splits PBKDF2 output into the three key fields and marks keys as derived. */
	private void storeKeys(byte[] derived) {
		int keyLength = encryptionStrengthMode.getKeyLength();
		encryptionKey = slice(derived, 0, keyLength);
		authenticationKey = slice(derived, keyLength, keyLength);
		passwordVerificationValue = slice(derived, keyLength + keyLength, LENGTH_PASSWORD_VERIFICATION_VALUE);
		derivedKeys = true;
	}

	/**
	 * Copies {@code length} bytes of {@code src} starting at {@code offset}. Unlike
	 * Arrays.copyOfRange this throws instead of zero-padding when src is too short.
	 */
	private static byte[] slice(byte[] src, int offset, int length) {
		byte[] out = new byte[length];
		System.arraycopy(src, offset, out, 0, length);
		return out;
	}

	/** Bytes added to the content by encryption: salt, verification value and auth code. */
	private long getEncryptionOverhead() {
		return (long) encryptionStrengthMode.getSaltLength() + LENGTH_PASSWORD_VERIFICATION_VALUE
				+ LENGTH_AUTHENTICATION_CODE;
	}

	/**
	 * Sets the compressed size from the size of the compressed content alone; the value
	 * stored in the superclass also includes the salt, password verification value and
	 * authentication code.
	 *
	 * @param contentCompressedSize compressed size of the content, must be non-negative
	 * @throws IllegalArgumentException if {@code contentCompressedSize} is negative
	 * @throws InvalidFieldException if the total size does not fit a 4-byte field
	 */
	public void setContentCompressedSize(long contentCompressedSize) {
		if (contentCompressedSize < 0) {
			throw new IllegalArgumentException(MessageFormat.format(MESSAGE_SIZE_WRONG, contentCompressedSize));
		}
		long compressedSize = contentCompressedSize + getEncryptionOverhead();

		if (compressedSize > MAX_FOUR_BYTE_FIELD_LONG || compressedSize < 0) {
			throw new InvalidFieldException(MessageFormat.format(MESSAGE_SIZE_WRONG, compressedSize));
		}
		super.setCompressedSize(compressedSize);
	}

	/**
	 * Returns the compressed size of the content alone (superclass compressed size minus
	 * the encryption overhead).
	 *
	 * @return 0 for directories, -1 if the compressed size is unknown, otherwise the size
	 * @throws InvalidFieldException if the stored size is smaller than the overhead
	 */
	public long getContentCompressedSize() {
		if (this.isDirectory()) {
			return 0L;
		}
		long originalCompressedSize = super.getCompressedSize();
		if (originalCompressedSize < 0) {
			return -1L;
		}
		long compressedSize = originalCompressedSize - getEncryptionOverhead();

		if (compressedSize > MAX_FOUR_BYTE_FIELD_LONG || compressedSize < 0) {
			throw new InvalidFieldException(MessageFormat.format(MESSAGE_SIZE_WRONG, compressedSize));
		}
		return compressedSize;
	}

	/**
	 * Sets the method reported by {@link #getMethod()}; only METHOD_AES and METHOD_STORED
	 * are accepted. Use {@link #setActualCompressionMethod(int)} for the real compression.
	 *
	 * @throws IllegalArgumentException wrapping an InvalidMethodException for other values
	 */
	@Override
	public void setMethod(int method) {
		if (method != METHOD_AES && method != METHOD_STORED) {
			throw new IllegalArgumentException(new InvalidMethodException((short) method));
		}
		this.method = method;
	}

	/** Returns the method set via {@link #setMethod(int)} (METHOD_AES by default). */
	@Override
	public int getMethod() {
		return method;
	}

	/**
	 * Returns the superclass extra data, followed (only when {@link #getMethod()} is
	 * METHOD_AES) by the AES extra field: the 8-byte fixed prefix, 1 byte of strength
	 * mode and the 2-byte little-endian actual compression method.
	 */
	@Override
	public byte[] getExtra() {
		ByteArrayOutputStream baos = new ByteArrayOutputStream();
		byte[] originalExtra = super.getExtra();
		if (originalExtra != null) {
			baos.write(originalExtra, 0, originalExtra.length);
		}
		if (this.getMethod() != METHOD_AES) {
			// Not AES: no AES extra field to append.
			return baos.toByteArray();
		}

		// Header ID, data size, AE-2 version and "AE" vendor ID.
		baos.write(AES_EXTRA_BYTES, 0, AES_EXTRA_BYTES.length);

		// AES encryption strength mode - 1 byte.
		baos.write(this.getStrengthMode().getModeValue());

		// Actual compression method - 2 bytes, low byte first.
		short value = this.getActualCompressionMethod();
		baos.write(value & 0xff);
		baos.write((value >>> 8) & 0xff);

		return baos.toByteArray();
	}

	/**
	 * Sets the compression method actually applied to the data (stored in the superclass).
	 *
	 * @param method METHOD_STORED or METHOD_DEFLATED
	 * @throws InvalidMethodException for any other value
	 */
	public void setActualCompressionMethod(int method) {
		// Validate the real int value rather than a short-truncated copy of it.
		if (method != METHOD_STORED && method != METHOD_DEFLATED) {
			throw new InvalidMethodException((short) method);
		}
		super.setMethod(method);
	}

	/** Returns the superclass compression method, narrowed to a short. */
	public short getActualCompressionMethod() {
		return (short) super.getMethod();
	}

	/**
	 * set the offset of this entry in the zip file.
	 *
	 * @param offset
	 *            offset (4-byte value)
	 * @throws IllegalArgumentException
	 *             the parameter is too big or too small.
	 */
	public void setRelativeOffsetOfLocalFileHeader(long offset) {
		if (offset < 0 || offset > MAX_FOUR_BYTE_FIELD_LONG) {
			throw new IllegalArgumentException(MessageFormat.format(MESSAGE_OFFSET_INVALID, offset,
					MAX_FOUR_BYTE_FIELD_LONG));
		}
		relativeOffsetOfLocalFileHeader = offset;
	}

	/** Returns the offset set via {@link #setRelativeOffsetOfLocalFileHeader(long)} (0 if unset). */
	public long getRelativeOffsetOfLocalFileHeader() {
		return relativeOffsetOfLocalFileHeader;
	}

	/** Returns the modification time packed as MS-DOS time (0 for dates before 1980). */
	public short getLastModTime() {
		return generateTime(super.getTime());
	}

	/** Returns the modification date packed as MS-DOS date (1980-01-01 for earlier dates). */
	public short getLastModDate() {
		return generateDate(super.getTime());
	}

	/** Returns the encryption strength mode chosen at construction. */
	public EncryptionStrengthMode getStrengthMode() {
		return encryptionStrengthMode;
	}

	/** Returns the salt generated at construction (internal array, not a copy). */
	public byte[] getSaltValue() {
		return saltValue;
	}

	/**
	 * Returns the derived encryption key (internal array, not a copy).
	 *
	 * @throws IllegalStateException if keys have not been derived yet
	 */
	public byte[] getEncryptionKey() {
		ensureKeysDerived();
		return encryptionKey;
	}

	/**
	 * Returns the derived authentication key (internal array, not a copy).
	 *
	 * @throws IllegalStateException if keys have not been derived yet
	 */
	public byte[] getAuthenticationKey() {
		ensureKeysDerived();
		return authenticationKey;
	}

	/**
	 * Returns the derived password verification value (internal array, not a copy).
	 *
	 * @throws IllegalStateException if keys have not been derived yet
	 */
	public byte[] getPasswordVerificationValue() {
		ensureKeysDerived();
		return passwordVerificationValue;
	}

	/** Throws IllegalStateException unless a key derivation has succeeded. */
	private void ensureKeysDerived() {
		if (!derivedKeys) {
			throw new IllegalStateException(MessageFormat.format(MESSAGE_KEYS_NOT_DERIVED, this.getName()));
		}
	}

	/** Returns {@code length} bytes from a SecureRandom, used as the PBKDF2 salt. */
	private static byte[] generateSaltValue(int length) {
		byte[] r = new byte[length];
		SALT_RANDOM.nextBytes(r);
		return r;
	}

	/**
	 * Packs the time of day of {@code time} (default time zone) as hour<<11 | minute<<5 |
	 * second/2; returns 0 for years before 1980.
	 */
	private static short generateTime(long time) {
		Calendar c = Calendar.getInstance();
		c.setTimeInMillis(time);
		if (c.get(Calendar.YEAR) < 1980) {
			return 0;
		}
		int hour = c.get(Calendar.HOUR_OF_DAY);
		int minute = c.get(Calendar.MINUTE);
		int second = c.get(Calendar.SECOND);
		return (short) (hour << 11 | minute << 5 | second >> 1);
	}

	/**
	 * Packs the date of {@code time} (default time zone) as (year-1980)<<9 | month<<5 | day,
	 * with month 1-12; years before 1980 are clamped to 1980-01-01.
	 * NOTE(review): years after 2107 do not fit the 7-bit year field and are truncated.
	 */
	private static short generateDate(long time) {
		Calendar c = Calendar.getInstance();
		c.setTimeInMillis(time);
		int year = c.get(Calendar.YEAR);
		// Calendar.MONTH is zero-based (JANUARY == 0); the packed field uses 1-12.
		int month = c.get(Calendar.MONTH) + 1;
		int day = c.get(Calendar.DAY_OF_MONTH);
		if (year < 1980) {
			year = 1980;
			month = 1;
			day = 1;
		}
		return (short) ((year - 1980) << 9 | month << 5 | day);
	}
}
