package org.deeplearning4j.nn.modelimport.keras.layers.recurrent;

import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasActivationUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasInitilizationUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLstm.class */
public class KerasLstm extends KerasLayer {
    private static final Logger log = LoggerFactory.getLogger(KerasLstm.class);
    private final String LSTM_FORGET_BIAS_INIT_ZERO = "zero";
    private final String LSTM_FORGET_BIAS_INIT_ONE = "one";
    private final int NUM_TRAINABLE_PARAMS_KERAS_2 = 3;
    private final int NUM_TRAINABLE_PARAMS = 12;
    private final String KERAS_PARAM_NAME_W_C = "W_c";
    private final String KERAS_PARAM_NAME_W_F = "W_f";
    private final String KERAS_PARAM_NAME_W_I = "W_i";
    private final String KERAS_PARAM_NAME_W_O = "W_o";
    private final String KERAS_PARAM_NAME_U_C = "U_c";
    private final String KERAS_PARAM_NAME_U_F = "U_f";
    private final String KERAS_PARAM_NAME_U_I = "U_i";
    private final String KERAS_PARAM_NAME_U_O = "U_o";
    private final String KERAS_PARAM_NAME_B_C = "b_c";
    private final String KERAS_PARAM_NAME_B_F = "b_f";
    private final String KERAS_PARAM_NAME_B_I = "b_i";
    private final String KERAS_PARAM_NAME_B_O = "b_o";
    private final int NUM_WEIGHTS_IN_KERAS_LSTM = 12;
    protected boolean unroll;
    protected boolean returnSequences;

    public KerasLstm(Integer num) throws UnsupportedKerasConfigurationException {
        super(num);
        this.LSTM_FORGET_BIAS_INIT_ZERO = "zero";
        this.LSTM_FORGET_BIAS_INIT_ONE = "one";
        this.NUM_TRAINABLE_PARAMS_KERAS_2 = 3;
        this.NUM_TRAINABLE_PARAMS = 12;
        this.KERAS_PARAM_NAME_W_C = "W_c";
        this.KERAS_PARAM_NAME_W_F = "W_f";
        this.KERAS_PARAM_NAME_W_I = "W_i";
        this.KERAS_PARAM_NAME_W_O = "W_o";
        this.KERAS_PARAM_NAME_U_C = "U_c";
        this.KERAS_PARAM_NAME_U_F = "U_f";
        this.KERAS_PARAM_NAME_U_I = "U_i";
        this.KERAS_PARAM_NAME_U_O = "U_o";
        this.KERAS_PARAM_NAME_B_C = "b_c";
        this.KERAS_PARAM_NAME_B_F = "b_f";
        this.KERAS_PARAM_NAME_B_I = "b_i";
        this.KERAS_PARAM_NAME_B_O = "b_o";
        this.NUM_WEIGHTS_IN_KERAS_LSTM = 12;
        this.unroll = false;
    }

    public KerasLstm(Map<String, Object> map) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(map, true);
    }

    public KerasLstm(Map<String, Object> map, boolean z) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(map, z, Collections.emptyMap());
    }

    public KerasLstm(Map<String, Object> map, Map<String, ? extends KerasLayer> map2) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(map, true, map2);
    }

    public KerasLstm(Map<String, Object> map, boolean z, Map<String, ? extends KerasLayer> map2) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        super(map, z);
        this.LSTM_FORGET_BIAS_INIT_ZERO = "zero";
        this.LSTM_FORGET_BIAS_INIT_ONE = "one";
        this.NUM_TRAINABLE_PARAMS_KERAS_2 = 3;
        this.NUM_TRAINABLE_PARAMS = 12;
        this.KERAS_PARAM_NAME_W_C = "W_c";
        this.KERAS_PARAM_NAME_W_F = "W_f";
        this.KERAS_PARAM_NAME_W_I = "W_i";
        this.KERAS_PARAM_NAME_W_O = "W_o";
        this.KERAS_PARAM_NAME_U_C = "U_c";
        this.KERAS_PARAM_NAME_U_F = "U_f";
        this.KERAS_PARAM_NAME_U_I = "U_i";
        this.KERAS_PARAM_NAME_U_O = "U_o";
        this.KERAS_PARAM_NAME_B_C = "b_c";
        this.KERAS_PARAM_NAME_B_F = "b_f";
        this.KERAS_PARAM_NAME_B_I = "b_i";
        this.KERAS_PARAM_NAME_B_O = "b_o";
        this.NUM_WEIGHTS_IN_KERAS_LSTM = 12;
        this.unroll = false;
        Pair<WeightInit, Distribution> weightInitFromConfig = KerasInitilizationUtils.getWeightInitFromConfig(map, this.conf.getLAYER_FIELD_INIT(), z, this.conf, this.kerasMajorVersion.intValue());
        WeightInit weightInit = (WeightInit) weightInitFromConfig.getFirst();
        Distribution distribution = (Distribution) weightInitFromConfig.getSecond();
        Pair<WeightInit, Distribution> weightInitFromConfig2 = KerasInitilizationUtils.getWeightInitFromConfig(map, this.conf.getLAYER_FIELD_INNER_INIT(), z, this.conf, this.kerasMajorVersion.intValue());
        WeightInit weightInit2 = (WeightInit) weightInitFromConfig2.getFirst();
        Distribution distribution2 = (Distribution) weightInitFromConfig2.getSecond();
        KerasLayerUtils.getHasBiasFromConfig(map, this.conf);
        this.returnSequences = ((Boolean) KerasLayerUtils.getInnerLayerConfigFromConfig(map, this.conf).get(this.conf.getLAYER_FIELD_RETURN_SEQUENCES())).booleanValue();
        this.unroll = KerasRnnUtils.getUnrollRecurrentLayer(this.conf, map);
        LayerConstraint constraintsFromConfig = KerasConstraintUtils.getConstraintsFromConfig(map, this.conf.getLAYER_FIELD_B_CONSTRAINT(), this.conf, this.kerasMajorVersion.intValue());
        LayerConstraint constraintsFromConfig2 = KerasConstraintUtils.getConstraintsFromConfig(map, this.conf.getLAYER_FIELD_W_CONSTRAINT(), this.conf, this.kerasMajorVersion.intValue());
        LayerConstraint constraintsFromConfig3 = KerasConstraintUtils.getConstraintsFromConfig(map, this.conf.getLAYER_FIELD_RECURRENT_CONSTRAINT(), this.conf, this.kerasMajorVersion.intValue());
        boolean z2 = false;
        for (String str : this.inboundLayerNames) {
            if (map2.containsKey(str)) {
                KerasLayer kerasLayer = map2.get(str);
                if ((kerasLayer instanceof KerasEmbedding) && ((KerasEmbedding) kerasLayer).isHasZeroMasking()) {
                    z2 = true;
                }
            }
        }
        LSTM.Builder l2 = new LSTM.Builder().gateActivationFunction(getGateActivationFromConfig(map)).forgetGateBiasInit(getForgetBiasInitFromConfig(map, z)).name(this.layerName).nOut(KerasLayerUtils.getNOutFromConfig(map, this.conf)).dropOut(this.dropout).activation(KerasActivationUtils.getActivationFromConfig(map, this.conf)).weightInit(weightInit).weightInitRecurrent(weightInit2).biasInit(0.0d).l1(this.weightL1Regularization).l2(this.weightL2Regularization);
        if (distribution != null) {
            l2.dist(distribution);
        }
        if (distribution2 != null) {
            l2.weightInitRecurrent(distribution2);
        }
        if (constraintsFromConfig != null) {
            l2.constrainBias(new LayerConstraint[]{constraintsFromConfig});
        }
        if (constraintsFromConfig2 != null) {
            l2.constrainInputWeights(new LayerConstraint[]{constraintsFromConfig2});
        }
        if (constraintsFromConfig3 != null) {
            l2.constrainRecurrent(new LayerConstraint[]{constraintsFromConfig3});
        }
        this.layer = l2.build();
        if (z2) {
            this.layer = new MaskZeroLayer(this.layer);
        }
        if (this.returnSequences) {
            return;
        }
        this.layer = new LastTimeStep(this.layer);
    }

    public Layer getLSTMLayer() {
        return this.layer;
    }

    @Override // org.deeplearning4j.nn.modelimport.keras.KerasLayer
    public InputType getOutputType(InputType... inputTypeArr) throws InvalidKerasConfigurationException {
        if (inputTypeArr.length > 1) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer accepts only one input (received " + inputTypeArr.length + ")");
        }
        InputPreProcessor inputPreprocessor = getInputPreprocessor(inputTypeArr);
        return inputPreprocessor != null ? this.returnSequences ? inputPreprocessor.getOutputType(inputTypeArr[0]) : getLSTMLayer().getOutputType(-1, inputPreprocessor.getOutputType(inputTypeArr[0])) : getLSTMLayer().getOutputType(-1, inputTypeArr[0]);
    }

    @Override // org.deeplearning4j.nn.modelimport.keras.KerasLayer
    public int getNumParams() {
        return this.kerasMajorVersion.intValue() == 2 ? 3 : 12;
    }

    @Override // org.deeplearning4j.nn.modelimport.keras.KerasLayer
    public InputPreProcessor getInputPreprocessor(InputType... inputTypeArr) throws InvalidKerasConfigurationException {
        if (inputTypeArr.length > 1) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer accepts only one input (received " + inputTypeArr.length + ")");
        }
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputTypeArr[0], this.layerName);
    }

    @Override // org.deeplearning4j.nn.modelimport.keras.KerasLayer
    public void setWeights(Map<String, INDArray> map) throws InvalidKerasConfigurationException {
        INDArray iNDArray;
        INDArray iNDArray2;
        INDArray iNDArray3;
        INDArray iNDArray4;
        INDArray iNDArray5;
        INDArray iNDArray6;
        INDArray iNDArray7;
        INDArray iNDArray8;
        INDArray iNDArray9;
        INDArray iNDArray10;
        INDArray iNDArray11;
        INDArray iNDArray12;
        this.weights = new HashMap();
        if (this.kerasMajorVersion.intValue() == 2) {
            if (!map.containsKey(this.conf.getKERAS_PARAM_NAME_W())) {
                throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter " + this.conf.getKERAS_PARAM_NAME_W());
            }
            INDArray iNDArray13 = map.get(this.conf.getKERAS_PARAM_NAME_W());
            if (!map.containsKey(this.conf.getKERAS_PARAM_NAME_RW())) {
                throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter " + this.conf.getKERAS_PARAM_NAME_RW());
            }
            INDArray iNDArray14 = map.get(this.conf.getKERAS_PARAM_NAME_RW());
            if (!map.containsKey(this.conf.getKERAS_PARAM_NAME_B())) {
                throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter " + this.conf.getKERAS_PARAM_NAME_B());
            }
            INDArray iNDArray15 = map.get(this.conf.getKERAS_PARAM_NAME_B());
            int length = iNDArray15.length() / 4;
            iNDArray4 = iNDArray13.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, length)});
            iNDArray2 = iNDArray13.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(length, 2 * length)});
            iNDArray = iNDArray13.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2 * length, 3 * length)});
            iNDArray3 = iNDArray13.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(3 * length, 4 * length)});
            iNDArray8 = iNDArray14.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, length)});
            iNDArray6 = iNDArray14.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(length, 2 * length)});
            iNDArray5 = iNDArray14.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2 * length, 3 * length)});
            iNDArray7 = iNDArray14.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(3 * length, 4 * length)});
            iNDArray12 = iNDArray15.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, length)});
            iNDArray10 = iNDArray15.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(length, 2 * length)});
            iNDArray9 = iNDArray15.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2 * length, 3 * length)});
            iNDArray11 = iNDArray15.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(3 * length, 4 * length)});
        } else {
            if (!map.containsKey("W_c")) {
                throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter W_c");
            }
            iNDArray = map.get("W_c");
            if (!map.containsKey("W_f")) {
                throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter W_f");
            }
            iNDArray2 = map.get("W_f");
            if (!map.containsKey("W_o")) {
                throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter W_o");
            }
            iNDArray3 = map.get("W_o");
            if (!map.containsKey("W_i")) {
                throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter W_i");
            }
            iNDArray4 = map.get("W_i");
            if (!map.containsKey("U_c")) {
                throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter U_c");
            }
            iNDArray5 = map.get("U_c");
            if (!map.containsKey("U_f")) {
                throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter U_f");
            }
            iNDArray6 = map.get("U_f");
            if (!map.containsKey("U_o")) {
                throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter U_o");
            }
            iNDArray7 = map.get("U_o");
            if (!map.containsKey("U_i")) {
                throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter U_i");
            }
            iNDArray8 = map.get("U_i");
            if (!map.containsKey("b_c")) {
                throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter b_c");
            }
            iNDArray9 = map.get("b_c");
            if (!map.containsKey("b_f")) {
                throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter b_f");
            }
            iNDArray10 = map.get("b_f");
            if (!map.containsKey("b_o")) {
                throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter b_o");
            }
            iNDArray11 = map.get("b_o");
            if (!map.containsKey("b_i")) {
                throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter b_i");
            }
            iNDArray12 = map.get("b_i");
        }
        int columns = iNDArray.columns();
        int rows = iNDArray.rows();
        INDArray zeros = Nd4j.zeros(rows, 4 * columns);
        zeros.put(new INDArrayIndex[]{NDArrayIndex.interval(0, rows), NDArrayIndex.interval(0, columns)}, iNDArray);
        zeros.put(new INDArrayIndex[]{NDArrayIndex.interval(0, rows), NDArrayIndex.interval(columns, 2 * columns)}, iNDArray2);
        zeros.put(new INDArrayIndex[]{NDArrayIndex.interval(0, rows), NDArrayIndex.interval(2 * columns, 3 * columns)}, iNDArray3);
        zeros.put(new INDArrayIndex[]{NDArrayIndex.interval(0, rows), NDArrayIndex.interval(3 * columns, 4 * columns)}, iNDArray4);
        this.weights.put("W", zeros);
        int columns2 = iNDArray5.columns();
        INDArray zeros2 = Nd4j.zeros(iNDArray5.rows(), 4 * columns2);
        zeros2.put(new INDArrayIndex[]{NDArrayIndex.interval(0, zeros2.rows()), NDArrayIndex.interval(0, columns2)}, iNDArray5);
        zeros2.put(new INDArrayIndex[]{NDArrayIndex.interval(0, zeros2.rows()), NDArrayIndex.interval(columns2, 2 * columns2)}, iNDArray6);
        zeros2.put(new INDArrayIndex[]{NDArrayIndex.interval(0, zeros2.rows()), NDArrayIndex.interval(2 * columns2, 3 * columns2)}, iNDArray7);
        zeros2.put(new INDArrayIndex[]{NDArrayIndex.interval(0, zeros2.rows()), NDArrayIndex.interval(3 * columns2, 4 * columns2)}, iNDArray8);
        this.weights.put("RW", zeros2);
        int columns3 = iNDArray9.columns();
        INDArray zeros3 = Nd4j.zeros(iNDArray9.rows(), 4 * columns3);
        zeros3.put(new INDArrayIndex[]{NDArrayIndex.interval(0, zeros3.rows()), NDArrayIndex.interval(0, columns3)}, iNDArray9);
        zeros3.put(new INDArrayIndex[]{NDArrayIndex.interval(0, zeros3.rows()), NDArrayIndex.interval(columns3, 2 * columns3)}, iNDArray10);
        zeros3.put(new INDArrayIndex[]{NDArrayIndex.interval(0, zeros3.rows()), NDArrayIndex.interval(2 * columns3, 3 * columns3)}, iNDArray11);
        zeros3.put(new INDArrayIndex[]{NDArrayIndex.interval(0, zeros3.rows()), NDArrayIndex.interval(3 * columns3, 4 * columns3)}, iNDArray12);
        this.weights.put("b", zeros3);
        if (map.size() > 12) {
            Set<String> keySet = map.keySet();
            keySet.remove("W_c");
            keySet.remove("W_f");
            keySet.remove("W_i");
            keySet.remove("W_o");
            keySet.remove("U_c");
            keySet.remove("U_f");
            keySet.remove("U_i");
            keySet.remove("U_o");
            keySet.remove("b_c");
            keySet.remove("b_f");
            keySet.remove("b_i");
            keySet.remove("b_o");
            String obj = keySet.toString();
            log.warn("Attemping to set weights for unknown parameters: " + obj.substring(1, obj.length() - 1));
        }
    }

    public boolean getUnroll() {
        return this.unroll;
    }

    public IActivation getGateActivationFromConfig(Map<String, Object> map) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> innerLayerConfigFromConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(map, this.conf);
        if (innerLayerConfigFromConfig.containsKey(this.conf.getLAYER_FIELD_INNER_ACTIVATION())) {
            return KerasActivationUtils.mapActivation((String) innerLayerConfigFromConfig.get(this.conf.getLAYER_FIELD_INNER_ACTIVATION()), this.conf);
        }
        throw new InvalidKerasConfigurationException("Keras LSTM layer config missing " + this.conf.getLAYER_FIELD_INNER_ACTIVATION() + " field");
    }

    public double getForgetBiasInitFromConfig(Map<String, Object> map, boolean z) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        String str;
        double d;
        Map<String, Object> innerLayerConfigFromConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(map, this.conf);
        if (innerLayerConfigFromConfig.containsKey(this.conf.getLAYER_FIELD_UNIT_FORGET_BIAS())) {
            str = "one";
        } else {
            if (!innerLayerConfigFromConfig.containsKey(this.conf.getLAYER_FIELD_FORGET_BIAS_INIT())) {
                throw new InvalidKerasConfigurationException("Keras LSTM layer config missing " + this.conf.getLAYER_FIELD_FORGET_BIAS_INIT() + " field");
            }
            str = (String) innerLayerConfigFromConfig.get(this.conf.getLAYER_FIELD_FORGET_BIAS_INIT());
        }
        String str2 = str;
        boolean z2 = -1;
        switch (str2.hashCode()) {
            case 110182:
                if (str2.equals("one")) {
                    z2 = true;
                    break;
                }
                break;
            case 3735208:
                if (str2.equals("zero")) {
                    z2 = false;
                    break;
                }
                break;
        }
        switch (z2) {
            case false:
                d = 0.0d;
                break;
            case true:
                d = 1.0d;
                break;
            default:
                if (!z) {
                    d = 1.0d;
                    log.warn("Unsupported LSTM forget gate bias initialization: " + str + " (using 1 instead)");
                    break;
                } else {
                    throw new UnsupportedKerasConfigurationException("Unsupported LSTM forget gate bias initialization: " + str);
                }
        }
        return d;
    }

    public String getLSTM_FORGET_BIAS_INIT_ZERO() {
        getClass();
        return "zero";
    }

    public String getLSTM_FORGET_BIAS_INIT_ONE() {
        getClass();
        return "one";
    }

    public int getNUM_TRAINABLE_PARAMS_KERAS_2() {
        getClass();
        return 3;
    }

    public int getNUM_TRAINABLE_PARAMS() {
        getClass();
        return 12;
    }

    public String getKERAS_PARAM_NAME_W_C() {
        getClass();
        return "W_c";
    }

    public String getKERAS_PARAM_NAME_W_F() {
        getClass();
        return "W_f";
    }

    public String getKERAS_PARAM_NAME_W_I() {
        getClass();
        return "W_i";
    }

    public String getKERAS_PARAM_NAME_W_O() {
        getClass();
        return "W_o";
    }

    public String getKERAS_PARAM_NAME_U_C() {
        getClass();
        return "U_c";
    }

    public String getKERAS_PARAM_NAME_U_F() {
        getClass();
        return "U_f";
    }

    public String getKERAS_PARAM_NAME_U_I() {
        getClass();
        return "U_i";
    }

    public String getKERAS_PARAM_NAME_U_O() {
        getClass();
        return "U_o";
    }

    public String getKERAS_PARAM_NAME_B_C() {
        getClass();
        return "b_c";
    }

    public String getKERAS_PARAM_NAME_B_F() {
        getClass();
        return "b_f";
    }

    public String getKERAS_PARAM_NAME_B_I() {
        getClass();
        return "b_i";
    }

    public String getKERAS_PARAM_NAME_B_O() {
        getClass();
        return "b_o";
    }

    public int getNUM_WEIGHTS_IN_KERAS_LSTM() {
        getClass();
        return 12;
    }

    public boolean isReturnSequences() {
        return this.returnSequences;
    }

    public void setUnroll(boolean z) {
        this.unroll = z;
    }

    public void setReturnSequences(boolean z) {
        this.returnSequences = z;
    }

    public String toString() {
        return "KerasLstm(LSTM_FORGET_BIAS_INIT_ZERO=" + getLSTM_FORGET_BIAS_INIT_ZERO() + ", LSTM_FORGET_BIAS_INIT_ONE=" + getLSTM_FORGET_BIAS_INIT_ONE() + ", NUM_TRAINABLE_PARAMS_KERAS_2=" + getNUM_TRAINABLE_PARAMS_KERAS_2() + ", NUM_TRAINABLE_PARAMS=" + getNUM_TRAINABLE_PARAMS() + ", KERAS_PARAM_NAME_W_C=" + getKERAS_PARAM_NAME_W_C() + ", KERAS_PARAM_NAME_W_F=" + getKERAS_PARAM_NAME_W_F() + ", KERAS_PARAM_NAME_W_I=" + getKERAS_PARAM_NAME_W_I() + ", KERAS_PARAM_NAME_W_O=" + getKERAS_PARAM_NAME_W_O() + ", KERAS_PARAM_NAME_U_C=" + getKERAS_PARAM_NAME_U_C() + ", KERAS_PARAM_NAME_U_F=" + getKERAS_PARAM_NAME_U_F() + ", KERAS_PARAM_NAME_U_I=" + getKERAS_PARAM_NAME_U_I() + ", KERAS_PARAM_NAME_U_O=" + getKERAS_PARAM_NAME_U_O() + ", KERAS_PARAM_NAME_B_C=" + getKERAS_PARAM_NAME_B_C() + ", KERAS_PARAM_NAME_B_F=" + getKERAS_PARAM_NAME_B_F() + ", KERAS_PARAM_NAME_B_I=" + getKERAS_PARAM_NAME_B_I() + ", KERAS_PARAM_NAME_B_O=" + getKERAS_PARAM_NAME_B_O() + ", NUM_WEIGHTS_IN_KERAS_LSTM=" + getNUM_WEIGHTS_IN_KERAS_LSTM() + ", unroll=" + getUnroll() + ", returnSequences=" + isReturnSequences() + ")";
    }

    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof KerasLstm)) {
            return false;
        }
        KerasLstm kerasLstm = (KerasLstm) obj;
        if (!kerasLstm.canEqual(this)) {
            return false;
        }
        String lstm_forget_bias_init_zero = getLSTM_FORGET_BIAS_INIT_ZERO();
        String lstm_forget_bias_init_zero2 = kerasLstm.getLSTM_FORGET_BIAS_INIT_ZERO();
        if (lstm_forget_bias_init_zero == null) {
            if (lstm_forget_bias_init_zero2 != null) {
                return false;
            }
        } else if (!lstm_forget_bias_init_zero.equals(lstm_forget_bias_init_zero2)) {
            return false;
        }
        String lstm_forget_bias_init_one = getLSTM_FORGET_BIAS_INIT_ONE();
        String lstm_forget_bias_init_one2 = kerasLstm.getLSTM_FORGET_BIAS_INIT_ONE();
        if (lstm_forget_bias_init_one == null) {
            if (lstm_forget_bias_init_one2 != null) {
                return false;
            }
        } else if (!lstm_forget_bias_init_one.equals(lstm_forget_bias_init_one2)) {
            return false;
        }
        if (getNUM_TRAINABLE_PARAMS_KERAS_2() != kerasLstm.getNUM_TRAINABLE_PARAMS_KERAS_2() || getNUM_TRAINABLE_PARAMS() != kerasLstm.getNUM_TRAINABLE_PARAMS()) {
            return false;
        }
        String keras_param_name_w_c = getKERAS_PARAM_NAME_W_C();
        String keras_param_name_w_c2 = kerasLstm.getKERAS_PARAM_NAME_W_C();
        if (keras_param_name_w_c == null) {
            if (keras_param_name_w_c2 != null) {
                return false;
            }
        } else if (!keras_param_name_w_c.equals(keras_param_name_w_c2)) {
            return false;
        }
        String keras_param_name_w_f = getKERAS_PARAM_NAME_W_F();
        String keras_param_name_w_f2 = kerasLstm.getKERAS_PARAM_NAME_W_F();
        if (keras_param_name_w_f == null) {
            if (keras_param_name_w_f2 != null) {
                return false;
            }
        } else if (!keras_param_name_w_f.equals(keras_param_name_w_f2)) {
            return false;
        }
        String keras_param_name_w_i = getKERAS_PARAM_NAME_W_I();
        String keras_param_name_w_i2 = kerasLstm.getKERAS_PARAM_NAME_W_I();
        if (keras_param_name_w_i == null) {
            if (keras_param_name_w_i2 != null) {
                return false;
            }
        } else if (!keras_param_name_w_i.equals(keras_param_name_w_i2)) {
            return false;
        }
        String keras_param_name_w_o = getKERAS_PARAM_NAME_W_O();
        String keras_param_name_w_o2 = kerasLstm.getKERAS_PARAM_NAME_W_O();
        if (keras_param_name_w_o == null) {
            if (keras_param_name_w_o2 != null) {
                return false;
            }
        } else if (!keras_param_name_w_o.equals(keras_param_name_w_o2)) {
            return false;
        }
        String keras_param_name_u_c = getKERAS_PARAM_NAME_U_C();
        String keras_param_name_u_c2 = kerasLstm.getKERAS_PARAM_NAME_U_C();
        if (keras_param_name_u_c == null) {
            if (keras_param_name_u_c2 != null) {
                return false;
            }
        } else if (!keras_param_name_u_c.equals(keras_param_name_u_c2)) {
            return false;
        }
        String keras_param_name_u_f = getKERAS_PARAM_NAME_U_F();
        String keras_param_name_u_f2 = kerasLstm.getKERAS_PARAM_NAME_U_F();
        if (keras_param_name_u_f == null) {
            if (keras_param_name_u_f2 != null) {
                return false;
            }
        } else if (!keras_param_name_u_f.equals(keras_param_name_u_f2)) {
            return false;
        }
        String keras_param_name_u_i = getKERAS_PARAM_NAME_U_I();
        String keras_param_name_u_i2 = kerasLstm.getKERAS_PARAM_NAME_U_I();
        if (keras_param_name_u_i == null) {
            if (keras_param_name_u_i2 != null) {
                return false;
            }
        } else if (!keras_param_name_u_i.equals(keras_param_name_u_i2)) {
            return false;
        }
        String keras_param_name_u_o = getKERAS_PARAM_NAME_U_O();
        String keras_param_name_u_o2 = kerasLstm.getKERAS_PARAM_NAME_U_O();
        if (keras_param_name_u_o == null) {
            if (keras_param_name_u_o2 != null) {
                return false;
            }
        } else if (!keras_param_name_u_o.equals(keras_param_name_u_o2)) {
            return false;
        }
        String keras_param_name_b_c = getKERAS_PARAM_NAME_B_C();
        String keras_param_name_b_c2 = kerasLstm.getKERAS_PARAM_NAME_B_C();
        if (keras_param_name_b_c == null) {
            if (keras_param_name_b_c2 != null) {
                return false;
            }
        } else if (!keras_param_name_b_c.equals(keras_param_name_b_c2)) {
            return false;
        }
        String keras_param_name_b_f = getKERAS_PARAM_NAME_B_F();
        String keras_param_name_b_f2 = kerasLstm.getKERAS_PARAM_NAME_B_F();
        if (keras_param_name_b_f == null) {
            if (keras_param_name_b_f2 != null) {
                return false;
            }
        } else if (!keras_param_name_b_f.equals(keras_param_name_b_f2)) {
            return false;
        }
        String keras_param_name_b_i = getKERAS_PARAM_NAME_B_I();
        String keras_param_name_b_i2 = kerasLstm.getKERAS_PARAM_NAME_B_I();
        if (keras_param_name_b_i == null) {
            if (keras_param_name_b_i2 != null) {
                return false;
            }
        } else if (!keras_param_name_b_i.equals(keras_param_name_b_i2)) {
            return false;
        }
        String keras_param_name_b_o = getKERAS_PARAM_NAME_B_O();
        String keras_param_name_b_o2 = kerasLstm.getKERAS_PARAM_NAME_B_O();
        if (keras_param_name_b_o == null) {
            if (keras_param_name_b_o2 != null) {
                return false;
            }
        } else if (!keras_param_name_b_o.equals(keras_param_name_b_o2)) {
            return false;
        }
        return getNUM_WEIGHTS_IN_KERAS_LSTM() == kerasLstm.getNUM_WEIGHTS_IN_KERAS_LSTM() && getUnroll() == kerasLstm.getUnroll() && isReturnSequences() == kerasLstm.isReturnSequences();
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof KerasLstm;
    }

    public int hashCode() {
        String lstm_forget_bias_init_zero = getLSTM_FORGET_BIAS_INIT_ZERO();
        int hashCode = (1 * 59) + (lstm_forget_bias_init_zero == null ? 43 : lstm_forget_bias_init_zero.hashCode());
        String lstm_forget_bias_init_one = getLSTM_FORGET_BIAS_INIT_ONE();
        int hashCode2 = (((((hashCode * 59) + (lstm_forget_bias_init_one == null ? 43 : lstm_forget_bias_init_one.hashCode())) * 59) + getNUM_TRAINABLE_PARAMS_KERAS_2()) * 59) + getNUM_TRAINABLE_PARAMS();
        String keras_param_name_w_c = getKERAS_PARAM_NAME_W_C();
        int hashCode3 = (hashCode2 * 59) + (keras_param_name_w_c == null ? 43 : keras_param_name_w_c.hashCode());
        String keras_param_name_w_f = getKERAS_PARAM_NAME_W_F();
        int hashCode4 = (hashCode3 * 59) + (keras_param_name_w_f == null ? 43 : keras_param_name_w_f.hashCode());
        String keras_param_name_w_i = getKERAS_PARAM_NAME_W_I();
        int hashCode5 = (hashCode4 * 59) + (keras_param_name_w_i == null ? 43 : keras_param_name_w_i.hashCode());
        String keras_param_name_w_o = getKERAS_PARAM_NAME_W_O();
        int hashCode6 = (hashCode5 * 59) + (keras_param_name_w_o == null ? 43 : keras_param_name_w_o.hashCode());
        String keras_param_name_u_c = getKERAS_PARAM_NAME_U_C();
        int hashCode7 = (hashCode6 * 59) + (keras_param_name_u_c == null ? 43 : keras_param_name_u_c.hashCode());
        String keras_param_name_u_f = getKERAS_PARAM_NAME_U_F();
        int hashCode8 = (hashCode7 * 59) + (keras_param_name_u_f == null ? 43 : keras_param_name_u_f.hashCode());
        String keras_param_name_u_i = getKERAS_PARAM_NAME_U_I();
        int hashCode9 = (hashCode8 * 59) + (keras_param_name_u_i == null ? 43 : keras_param_name_u_i.hashCode());
        String keras_param_name_u_o = getKERAS_PARAM_NAME_U_O();
        int hashCode10 = (hashCode9 * 59) + (keras_param_name_u_o == null ? 43 : keras_param_name_u_o.hashCode());
        String keras_param_name_b_c = getKERAS_PARAM_NAME_B_C();
        int hashCode11 = (hashCode10 * 59) + (keras_param_name_b_c == null ? 43 : keras_param_name_b_c.hashCode());
        String keras_param_name_b_f = getKERAS_PARAM_NAME_B_F();
        int hashCode12 = (hashCode11 * 59) + (keras_param_name_b_f == null ? 43 : keras_param_name_b_f.hashCode());
        String keras_param_name_b_i = getKERAS_PARAM_NAME_B_I();
        int hashCode13 = (hashCode12 * 59) + (keras_param_name_b_i == null ? 43 : keras_param_name_b_i.hashCode());
        String keras_param_name_b_o = getKERAS_PARAM_NAME_B_O();
        return (((((((hashCode13 * 59) + (keras_param_name_b_o == null ? 43 : keras_param_name_b_o.hashCode())) * 59) + getNUM_WEIGHTS_IN_KERAS_LSTM()) * 59) + (getUnroll() ? 79 : 97)) * 59) + (isReturnSequences() ? 79 : 97);
    }
}
