package com.alibaba.alink.operator.common.tensorflow;

import com.alibaba.alink.common.dl.utils.PythonFileUtils;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.common.classification.tensorflow.TFTableModelClassificationModelData;
import com.alibaba.alink.operator.common.regression.tensorflow.TFTableModelRegressionModelData;
import com.google.common.collect.Iterables;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;

/**
 * Converts TF table model data (regression and classification) to and from the triple
 * (meta params, serialized rows, labels).
 *
 * <p>Each serialized row is the JSON array of the fields of one model {@link Row}.
 * {@link #serializeRegressionModel} emits the preprocess-pipeline rows first, followed by all
 * TF model rows. It records only the size of the preprocess part, so the TF part extends to the
 * end of the row stream.
 */
public class TFModelDataConverterUtils {
    private static final ParamInfo<Long> TF_MODEL_PARTITION_START = ParamInfoFactory.createParamInfo("tfModelPartitionStart", Long.class).setDescription("tfModelPartitionStart").build();
    private static final ParamInfo<Long> TF_MODEL_PARTITION_SIZE = ParamInfoFactory.createParamInfo("tfModelPartitionSize", Long.class).setDescription("tfModelPartitionSize").build();
    private static final ParamInfo<String> PREPROCESS_PIPELINE_MODEL_SCHEMA_STR = ParamInfoFactory.createParamInfo("preprocessPipelineModelSchemaStr", String.class).setDescription("preprocessPipelineModelSchemaStr").setHasDefaultValue(null).build();
    private static final ParamInfo<Long> PREPROCESS_PIPELINE_MODEL_PARTITION_START = ParamInfoFactory.createParamInfo("preprocessPipelineModelPartitionStart", Long.class).setDescription("preprocessPipelineModelPartitionStart").build();
    private static final ParamInfo<Long> PREPROCESS_PIPELINE_MODEL_PARTITION_SIZE = ParamInfoFactory.createParamInfo("preprocessPipelineModelPartitionSize", Long.class).setDescription("preprocessPipelineModelPartitionSize").build();
    public static final ParamInfo<String[]> TF_INPUT_COLS = ParamInfoFactory.createParamInfo("tfInputCols", String[].class).setDescription("tfInputCols").build();
    public static final ParamInfo<String> TF_OUTPUT_SIGNATURE_DEF = ParamInfoFactory.createParamInfo("tfOutputSignatureDef", String.class).setDescription("tfOutputSignatureDef").build();
    public static final ParamInfo<Boolean> IS_OUTPUT_LOGITS = ParamInfoFactory.createParamInfo("isOutputLogits", Boolean.class).setDescription("isOutputLogits").setHasDefaultValue(false).build();

    /**
     * Lazily concatenates several row iterables and exposes each row as the JSON array of its
     * fields.
     */
    public static class ModelRowsIterable implements Iterable<String> {
        private final Iterable<Row> concat;

        /** Iterator that converts each underlying {@link Row} into a JSON array string. */
        public static class ModelRowsIterator implements Iterator<String> {
            private final Iterator<Row> iter;

            public ModelRowsIterator(Iterator<Row> it) {
                this.iter = it;
            }

            @Override
            public boolean hasNext() {
                return this.iter.hasNext();
            }

            @Override
            public String next() {
                Row row = this.iter.next();
                Object[] fields = new Object[row.getArity()];
                for (int i = 0; i < fields.length; i++) {
                    fields[i] = row.getField(i);
                }
                return JsonConverter.toJson(fields);
            }
        }

        /**
         * @param iterableArr row sources, iterated in the given order
         */
        @SafeVarargs
        public ModelRowsIterable(Iterable<Row>... iterableArr) {
            this.concat = Iterables.concat(iterableArr);
        }

        @Override
        public Iterator<String> iterator() {
            return new ModelRowsIterator(this.concat.iterator());
        }
    }

    /**
     * Serializes regression model data.
     *
     * @param tFTableModelRegressionModelData the model to serialize
     * @return a triple of:
     *     <ul>
     *       <li>a copy of the model meta, extended with the TF input columns, output signature
     *           def, preprocess schema and partition bookkeeping;
     *       <li>the JSON rows, with preprocess-pipeline rows first and then the TF model rows;
     *       <li>an empty label list.
     *     </ul>
     */
    public static Tuple3<Params, Iterable<String>, Iterable<Object>> serializeRegressionModel(TFTableModelRegressionModelData tFTableModelRegressionModelData) {
        Params meta = tFTableModelRegressionModelData.getMeta().clone();
        meta.set(TF_INPUT_COLS, tFTableModelRegressionModelData.getTfInputCols());
        meta.set(TF_OUTPUT_SIGNATURE_DEF, tFTableModelRegressionModelData.getTfOutputSignatureDef());
        meta.set(PREPROCESS_PIPELINE_MODEL_SCHEMA_STR, tFTableModelRegressionModelData.getPreprocessPipelineModelSchemaStr());

        Iterable<Row> tfModelRows = tFTableModelRegressionModelData.getTfModelRows();
        List<Row> preprocessRows = tFTableModelRegressionModelData.getPreprocessPipelineModelRows();
        if (preprocessRows == null) {
            // No preprocess pipeline: serialize it as an empty partition instead of failing.
            preprocessRows = new ArrayList<>();
        }
        Iterable<String> rows = new ModelRowsIterable(preprocessRows, tfModelRows);

        long preprocessSize = preprocessRows.size();
        meta.set(PREPROCESS_PIPELINE_MODEL_PARTITION_START, 0L);
        meta.set(PREPROCESS_PIPELINE_MODEL_PARTITION_SIZE, preprocessSize);
        // The TF partition size is deliberately not recorded: the TF rows run to the end of the
        // stream, which deserializeRegressionModel handles in its "no TF size" branch.
        meta.set(TF_MODEL_PARTITION_START, preprocessSize);

        Iterable<Object> noLabels = new ArrayList<>();
        return Tuple3.of(meta, rows, noLabels);
    }

    /**
     * Serializes classification model data.
     *
     * <p>This works like {@link #serializeRegressionModel}, and additionally records the logits
     * flag in the meta and returns the model's sorted labels as the third element.
     */
    public static Tuple3<Params, Iterable<String>, Iterable<Object>> serializeClassificationModel(TFTableModelClassificationModelData tFTableModelClassificationModelData) {
        Tuple3<Params, Iterable<String>, Iterable<Object>> serialized = serializeRegressionModel(tFTableModelClassificationModelData);
        Params meta = serialized.f0;
        meta.set(IS_OUTPUT_LOGITS, tFTableModelClassificationModelData.getIsLogits());
        return Tuple3.of(meta, serialized.f1, tFTableModelClassificationModelData.getSortedLabels());
    }

    /**
     * Reads exactly {@code j} JSON rows from {@code it} and rebuilds them as {@link Row}s.
     *
     * <p>The first field of each row is converted to {@code Long}, because the JSON round trip
     * loses the original numeric type.
     *
     * @throws java.util.NoSuchElementException if {@code it} yields fewer than {@code j} rows
     */
    private static List<Row> extractModelRows(Iterator<String> it, long j) {
        List<Row> rows = new ArrayList<>();
        for (long k = 0; k < j; k++) {
            Object[] fields = (Object[]) JsonConverter.fromJson(it.next(), Object[].class);
            Row row = new Row(fields.length);
            for (int i = 0; i < fields.length; i++) {
                row.setField(i, fields[i]);
            }
            // Accept any JSON number representation (Integer, Long, Double) for the id column.
            row.setField(0, ((Number) row.getField(0)).longValue());
            rows.add(row);
        }
        return rows;
    }

    /** Parses one serialized model row and returns its field at index 1 as a string. */
    private static String secondField(String json) {
        return (String) ((Object[]) JsonConverter.fromJson(json, Object[].class))[1];
    }

    /**
     * Consumes the first TF row, which carries the file name, and resolves that name inside a
     * fresh temp directory.
     */
    private static Path resolveModelFilePath(Iterator<String> it) {
        // NOTE(review): the file name comes from model data and is resolved without
        // normalization; confirm that model sources are trusted.
        return PythonFileUtils.createTempDir("saved_model_").resolve(secondField(it.next()));
    }

    /**
     * Writes a TF model stored in exactly {@code j} rows to a temp file.
     *
     * <p>The first row holds the file name; each of the remaining {@code j - 1} rows holds a
     * Base64-encoded chunk of the file content.
     *
     * @return the absolute path of the written file
     * @throws AkUnclassifiedErrorException if writing the file fails
     */
    private static String writeModelRowsToFile(Iterator<String> it, long j) {
        Path target = resolveModelFilePath(it);
        Base64.Decoder decoder = Base64.getDecoder();
        // try-with-resources guarantees the stream is closed even if decode/write throws.
        try (FileOutputStream out = new FileOutputStream(target.toFile())) {
            for (long k = 1; k < j; k++) {
                out.write(decoder.decode(secondField(it.next())));
            }
        } catch (IOException e) {
            throw new AkUnclassifiedErrorException("Failed to write model rows to file.", e);
        }
        return target.toFile().getAbsolutePath();
    }

    /**
     * Writes a TF model that occupies all remaining rows of {@code it} to a temp file.
     *
     * <p>The first row holds the file name; each subsequent row holds a Base64-encoded content
     * chunk.
     *
     * @return the absolute path of the written file
     * @throws AkUnclassifiedErrorException if writing the file fails
     */
    private static String writeModelRowsToFile(Iterator<String> it) {
        Path target = resolveModelFilePath(it);
        Base64.Decoder decoder = Base64.getDecoder();
        try (FileOutputStream out = new FileOutputStream(target.toFile())) {
            while (it.hasNext()) {
                out.write(decoder.decode(secondField(it.next())));
            }
        } catch (IOException e) {
            throw new AkUnclassifiedErrorException("Failed to write model rows to file.", e);
        }
        return target.toFile().getAbsolutePath();
    }

    /**
     * Restores regression model data from meta params and serialized rows.
     *
     * <p>Two row layouts are supported:
     * <ul>
     *   <li>If the meta contains the TF partition size, the TF rows come first, followed by the
     *       optional preprocess rows.
     *   <li>Otherwise (the layout produced by {@link #serializeRegressionModel}), the optional
     *       preprocess rows come first and the TF rows run to the end.
     * </ul>
     */
    public static void deserializeRegressionModel(TFTableModelRegressionModelData tFTableModelRegressionModelData, Params params, Iterable<String> iterable) {
        tFTableModelRegressionModelData.setMeta(params);
        tFTableModelRegressionModelData.setTfInputCols(params.get(TF_INPUT_COLS));
        tFTableModelRegressionModelData.setPreprocessPipelineModelSchemaStr(params.get(PREPROCESS_PIPELINE_MODEL_SCHEMA_STR));
        // NOTE(review): TF_OUTPUT_SIGNATURE_DEF is written by serializeRegressionModel but not
        // restored here; presumably the model data reads it from the meta — confirm.
        boolean hasTfSize = params.contains(TF_MODEL_PARTITION_SIZE);
        boolean hasPreprocessSize = params.contains(PREPROCESS_PIPELINE_MODEL_PARTITION_SIZE);
        Iterator<String> it = iterable.iterator();
        if (hasTfSize) {
            tFTableModelRegressionModelData.setTfModelZipPath(writeModelRowsToFile(it, params.get(TF_MODEL_PARTITION_SIZE)));
            if (hasPreprocessSize) {
                tFTableModelRegressionModelData.setPreprocessPipelineModelRows(extractModelRows(it, params.get(PREPROCESS_PIPELINE_MODEL_PARTITION_SIZE)));
            }
        } else {
            if (hasPreprocessSize) {
                tFTableModelRegressionModelData.setPreprocessPipelineModelRows(extractModelRows(it, params.get(PREPROCESS_PIPELINE_MODEL_PARTITION_SIZE)));
            }
            tFTableModelRegressionModelData.setTfModelZipPath(writeModelRowsToFile(it));
        }
    }

    /**
     * Restores classification model data.
     *
     * <p>This works like {@link #deserializeRegressionModel}, and additionally sets the sorted
     * labels.
     */
    public static void deserializeClassificationModel(TFTableModelClassificationModelData tFTableModelClassificationModelData, Params params, Iterable<String> iterable, Iterable<Object> iterable2) {
        deserializeRegressionModel(tFTableModelClassificationModelData, params, iterable);
        // NOTE(review): IS_OUTPUT_LOGITS is written on serialization but not restored here;
        // presumably read from the meta by the model data — confirm.
        tFTableModelClassificationModelData.setSortedLabels(iterable2);
    }
}
