/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sshd.common.config.keys.loader.openssh.kdf;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.StreamCorruptedException;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.Key;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import org.apache.sshd.common.NamedResource;
import org.apache.sshd.common.RuntimeSshException;
import org.apache.sshd.common.cipher.BuiltinCiphers;
import org.apache.sshd.common.cipher.CipherFactory;
import org.apache.sshd.common.config.keys.KeyEntryResolver;
import org.apache.sshd.common.config.keys.loader.openssh.OpenSSHKdfOptions;
import org.apache.sshd.common.config.keys.loader.openssh.kdf.BCrypt;
import org.apache.sshd.common.session.SessionContext;
import org.apache.sshd.common.util.ExceptionUtils;
import org.apache.sshd.common.util.NumberUtils;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.BufferUtils;
import org.apache.sshd.common.util.security.SecurityUtils;

/**
 * {@link OpenSSHKdfOptions} implementation for the {@value #NAME} key derivation function. It parses the salt and
 * rounds count from the encoded KDF options. It then uses them to derive the key and IV that decrypt the private key
 * data.
 */
public class BCryptKdfOptions implements OpenSSHKdfOptions {
    /** KDF name handled by this class (matched case-insensitively in {@link #initialize(String, byte[])}). */
    public static final String NAME = "bcrypt";
    /** Initial value of the maximum rounds count accepted by {@link #setNumRounds(int)}. */
    public static final int DEFAULT_MAX_ROUNDS = 255;
    /** Current upper bound for the rounds count - see {@link #setMaxAllowedRounds(int)}. */
    private static final AtomicInteger MAX_ROUNDS_HOLDER = new AtomicInteger(DEFAULT_MAX_ROUNDS);

    private byte[] salt;
    private int numRounds;

    /**
     * Parses the encoded KDF options.
     *
     * @param  name                      KDF name - must be {@value #NAME} (case-insensitive)
     * @param  kdfOptions                Encoded options: a length-prefixed salt followed by the rounds count
     * @throws StreamCorruptedException  If the name does not match, the options are missing or too short, or the
     *                                   decoded salt length does not match the available data
     * @throws IOException               If failed to decode the options
     * @throws BCryptBadRoundsException  If the decoded rounds count is out of range
     */
    @Override
    public void initialize(String name, byte[] kdfOptions) throws IOException {
        if (!NAME.equalsIgnoreCase(name)) {
            throw new StreamCorruptedException("Mismatched KDF name: " + name);
        }
        if (NumberUtils.isEmpty(kdfOptions)) {
            throw new StreamCorruptedException("Missing KDF options for " + name);
        }

        // Everything except the two 4-byte integers (salt length prefix + rounds count) is salt data
        int expectedSaltLength = kdfOptions.length - 8;
        if (expectedSaltLength < 0) {
            throw new StreamCorruptedException(
                    "KDF options too short for " + name + ": length=" + kdfOptions.length);
        }

        try (ByteArrayInputStream stream = new ByteArrayInputStream(kdfOptions)) {
            initialize(stream, expectedSaltLength);
        }

        byte[] saltValue = getSalt();
        int actualSaltLength = NumberUtils.length(saltValue);
        if (actualSaltLength != expectedSaltLength) {
            throw new StreamCorruptedException(
                    "Mismatched salt data length: expected=" + expectedSaltLength + ", actual=" + actualSaltLength);
        }
    }

    /**
     * Reads the salt followed by the rounds count from the given stream.
     *
     * @param  stream                   Source of the encoded options
     * @param  maxSaltSize              Maximum salt length passed to {@link KeyEntryResolver#readRLEBytes}
     * @throws IOException              If failed to read the data
     * @throws BCryptBadRoundsException If the rounds count is rejected by {@link #setNumRounds(int)}
     */
    protected void initialize(InputStream stream, int maxSaltSize) throws IOException {
        setSalt(KeyEntryResolver.readRLEBytes(stream, maxSaltSize));
        setNumRounds(KeyEntryResolver.decodeInt(stream));
    }

    /**
     * @return Always {@code true} - the {@value #NAME} KDF is only used for encrypted keys
     */
    @Override
    public boolean isEncrypted() {
        return true;
    }

    /**
     * Derives the cipher key and IV from the password via {@link #bcryptKdf(byte[], byte[])} and decrypts the data.
     * The {@code session} and {@code resourceKey} arguments are not used by this implementation.
     *
     * @param  session                   Ignored
     * @param  resourceKey               Ignored
     * @param  cipherName                Name of the cipher used to encrypt the data
     * @param  privateDataBytes          Encrypted data - returned as-is if {@code null}/empty
     * @param  password                  Decryption password - must not be {@code null} (it is dereferenced)
     * @return                           The decrypted data
     * @throws NoSuchAlgorithmException  If the cipher is unknown or not supported
     * @throws StreamCorruptedException  If the data length is not a multiple of the cipher block size
     * @throws IOException               If a (possibly wrapped) I/O error occurred
     * @throws GeneralSecurityException  If a (possibly wrapped) security error occurred
     */
    @Override
    public byte[] decodePrivateKeyBytes(
            SessionContext session, NamedResource resourceKey, String cipherName, byte[] privateDataBytes, String password)
            throws IOException, GeneralSecurityException {
        if (NumberUtils.isEmpty(privateDataBytes)) {
            return privateDataBytes;
        }

        CipherFactory cipherSpec = BuiltinCiphers.resolveFactory(cipherName);
        if ((cipherSpec == null) || (!cipherSpec.isSupported())) {
            throw new NoSuchAlgorithmException("Unsupported cipher: " + cipherName);
        }

        int blockSize = cipherSpec.getCipherBlockSize();
        if ((privateDataBytes.length % blockSize) != 0) {
            throw new StreamCorruptedException("Encrypted data size (" + privateDataBytes.length + ") is not aligned to  "
                                               + cipherName + " block size (" + blockSize + ")");
        }

        byte[] pwd = password.getBytes(StandardCharsets.UTF_8);
        int keySize = cipherSpec.getKdfSize();
        int ivSize = cipherSpec.getIVSize();
        // KDF output is the cipher key immediately followed by the IV
        byte[] cipherInput = new byte[keySize + ivSize];
        try {
            bcryptKdf(pwd, cipherInput);
            return decryptWithDerivedKey(cipherSpec, cipherInput, keySize, privateDataBytes);
        } catch (RuntimeException e) {
            // Covers both the KDF and the decryption - surface wrapped checked exceptions as their real type
            rethrowCheckedCause(e);
            throw e;
        } finally {
            // Don't keep password/derived material in memory longer than necessary
            Arrays.fill(pwd, (byte) 0);
            Arrays.fill(cipherInput, (byte) 0);
        }
    }

    /**
     * Splits the KDF output into key and IV and decrypts the data with them.
     * The key and IV copies are wiped before returning.
     */
    private static byte[] decryptWithDerivedKey(
            CipherFactory cipherSpec, byte[] cipherInput, int keySize, byte[] encrypted)
            throws GeneralSecurityException {
        byte[] kv = Arrays.copyOfRange(cipherInput, 0, keySize);
        byte[] iv = Arrays.copyOfRange(cipherInput, keySize, cipherInput.length);
        try {
            Cipher cipher = SecurityUtils.getCipher(cipherSpec.getTransformation());
            SecretKeySpec keySpec = new SecretKeySpec(kv, cipherSpec.getAlgorithm());
            IvParameterSpec ivSpec = new IvParameterSpec(iv);
            cipher.init(Cipher.DECRYPT_MODE, keySpec, ivSpec);
            return cipher.doFinal(encrypted);
        } finally {
            Arrays.fill(kv, (byte) 0);
            Arrays.fill(iv, (byte) 0);
        }
    }

    /**
     * Throws the first {@link IOException} or {@link GeneralSecurityException} found by first peeling and then
     * resolving the cause of the given runtime exception. Returns normally if none is found.
     */
    private static void rethrowCheckedCause(RuntimeException e) throws IOException, GeneralSecurityException {
        Throwable t = ExceptionUtils.peelException(e);
        if (!((t instanceof IOException) || (t instanceof GeneralSecurityException))) {
            t = ExceptionUtils.resolveExceptionCause(e);
        }

        if (t instanceof IOException) {
            throw (IOException) t;
        }
        if (t instanceof GeneralSecurityException) {
            throw (GeneralSecurityException) t;
        }
    }

    /**
     * Runs the bcrypt PBKDF using the current salt and rounds count.
     *
     * @param  password                 Password bytes
     * @param  output                   Buffer filled with the derived bytes (its length is the requested output size)
     * @throws IOException              If failed to derive the bytes
     * @throws GeneralSecurityException If failed to derive the bytes
     */
    protected void bcryptKdf(byte[] password, byte[] output) throws IOException, GeneralSecurityException {
        BCrypt bcrypt = new BCrypt();
        bcrypt.pbkdf(password, getSalt(), getNumRounds(), output);
    }

    @Override
    public final String getName() {
        return NAME;
    }

    /**
     * @return The salt - never {@code null} (empty if not set). Note: the internal array is returned, not a copy.
     */
    public byte[] getSalt() {
        return NumberUtils.emptyIfNull(salt);
    }

    /**
     * @param salt The salt - {@code null} is replaced by an empty array
     */
    public void setSalt(byte[] salt) {
        this.salt = NumberUtils.emptyIfNull(salt);
    }

    public int getNumRounds() {
        return numRounds;
    }

    /**
     * @param  numRounds                The rounds count - must be in the range {@code [1, getMaxAllowedRounds()]}
     * @throws BCryptBadRoundsException If the value is out of range
     */
    public void setNumRounds(int numRounds) {
        int maxAllowed = getMaxAllowedRounds();
        if ((numRounds <= 0) || (numRounds > maxAllowed)) {
            throw new BCryptBadRoundsException(
                    numRounds, "Bad rounds value (" + numRounds + ") - max. allowed " + maxAllowed);
        }
        this.numRounds = numRounds;
    }

    @Override
    public int hashCode() {
        return 31 * getNumRounds() + Arrays.hashCode(getSalt());
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == null) {
            return false;
        }
        if (this == obj) {
            return true;
        }
        if (getClass() != obj.getClass()) {
            return false;
        }

        BCryptKdfOptions other = (BCryptKdfOptions) obj;
        return (getNumRounds() == other.getNumRounds()) && Arrays.equals(getSalt(), other.getSalt());
    }

    @Override
    public String toString() {
        return getName() + ": rounds=" + getNumRounds() + ", salt=" + BufferUtils.toHex(':', getSalt());
    }

    /**
     * @return The current maximum rounds count accepted by {@link #setNumRounds(int)}
     */
    public static int getMaxAllowedRounds() {
        return MAX_ROUNDS_HOLDER.get();
    }

    /**
     * @param value The new maximum rounds count - must be positive. Applies to all instances.
     */
    public static void setMaxAllowedRounds(int value) {
        ValidateUtils.checkTrue(value > 0, "Invalid max. rounds value: %d", value);
        MAX_ROUNDS_HOLDER.set(value);
    }

    /**
     * Thrown when a rounds count outside the allowed range is encountered.
     */
    public static class BCryptBadRoundsException extends RuntimeSshException {
        private static final long serialVersionUID = 1724985268892193553L;
        private final int rounds;

        public BCryptBadRoundsException(int rounds) {
            this(rounds, "Bad rounds value: " + rounds);
        }

        public BCryptBadRoundsException(int rounds, String message) {
            this(rounds, message, null);
        }

        public BCryptBadRoundsException(int rounds, String message, Throwable reason) {
            super(message, reason);
            this.rounds = rounds;
        }

        /**
         * @return The rejected rounds value
         */
        public int getRounds() {
            return rounds;
        }
    }
}

