package com.alibaba.alink.common.mapper;

import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.utils.OutputColsHelper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.io.Serializable;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.types.DataType;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/common/mapper/Mapper.class */
/**
 * Base class for operators that turn one input {@link Row} into one output {@link Row}.
 *
 * <p>A concrete mapper describes its input/output layout through {@link #prepareIoSchema}, which
 * returns a {@link Tuple4}:
 * <ul>
 *   <li>{@code f0} - names of the selected (input) columns;</li>
 *   <li>{@code f1} - names of the output columns;</li>
 *   <li>{@code f2} - types of the output columns;</li>
 *   <li>{@code f3} - passed as the last argument to {@link OutputColsHelper}
 *       (presumably the reserved column names; confirm against {@code OutputColsHelper}).</li>
 * </ul>
 *
 * <p>The actual computation is implemented in
 * {@link #map(SlicedSelectedSample, SlicedResult)}, which only sees the selected input columns and
 * the output columns through index-translating views. Columns reported by
 * {@code OutputColsHelper.getReservedColumns()} are copied from the input row to the output row by
 * a {@link MemoryTransformer} before the computation runs.
 *
 * <p>The views are kept in {@link ThreadLocal}s, so each thread works on its own view objects.
 */
public abstract class Mapper implements Serializable {
    private static final long serialVersionUID = -3634328096241559957L;

    /** Field names of the input schema, captured at construction time. */
    private final String[] dataFieldNames;

    /** Field data types of the input schema, captured at construction time. */
    private final DataType[] dataFieldTypes;

    /** Private copy of the parameters handed to the constructor; never {@code null}. */
    protected final Params params;

    /** Input/output description, see the class comment. May be {@code null} until a subclass sets it. */
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> ioSchema;

    /** Copies the reserved columns from an input row to an output row. */
    private MemoryTransformer transformer;

    /** Per-thread view on the selected columns, used by the single-row entry points. */
    private SlicedSelectedSampleThreadLocal selection;

    /** Per-thread view on the output columns, used by the single-row entry points. */
    private SlicedSlicedResultThreadLocal result;

    /** Per-thread views on the selected columns for batches; created lazily by the batch entry point. */
    private SlicedSelectedSampleArrayThreadLocal selections;

    /** Per-thread views on the output columns for batches; created lazily by the batch entry point. */
    private SlicedSlicedResultArrayThreadLocal results;

    /**
     * Copies a fixed set of fields from one row to another. Position {@code k} of the source
     * index array is paired with position {@code k} of the target index array.
     */
    public static class MemoryTransformer implements Serializable {
        private final int[] reservedSelectedIndices;
        private final int[] reservedResultIndices;

        /**
         * @param sourceIndices field positions to read from the source row
         * @param targetIndices field positions to write in the target row; must have at least as
         *                      many entries as {@code sourceIndices}
         */
        public MemoryTransformer(int[] sourceIndices, int[] targetIndices) {
            this.reservedSelectedIndices = sourceIndices;
            this.reservedResultIndices = targetIndices;
        }

        /**
         * Copies every configured field of {@code row} into the paired position of {@code row2}.
         *
         * @param row  the row that is read
         * @param row2 the row that is written
         */
        public void transform(Row row, Row row2) {
            final int count = this.reservedSelectedIndices.length;
            for (int k = 0; k < count; k++) {
                Object value = row.getField(this.reservedSelectedIndices[k]);
                row2.setField(this.reservedResultIndices[k], value);
            }
        }
    }

    /**
     * Writable view on a subset of the fields of a {@link Row}. Logical position {@code i} of the
     * view addresses the row field {@code columnIndices[i]}.
     *
     * <p>The backing row is transient and must be attached with {@link #resetInstance(Row)} before
     * {@link #get(int)} or {@link #set(int, Object)} is used.
     */
    public static final class SlicedResult implements Serializable {
        private static final long serialVersionUID = -702606581670930534L;
        private int[] columnIndices;
        private transient Row instance;

        private SlicedResult(int[] indices) {
            this.columnIndices = indices;
        }

        /** @return the number of fields visible through this view */
        public int length() {
            return this.columnIndices.length;
        }

        /**
         * @param i logical position inside the view
         * @return the value of the backing row at the translated position
         */
        public Object get(int i) {
            return this.instance.getField(this.columnIndices[i]);
        }

        /**
         * Stores a value into the backing row.
         *
         * @param i   logical position inside the view
         * @param obj the value to store
         */
        public void set(int i, Object obj) {
            this.instance.setField(this.columnIndices[i], obj);
        }

        /** Replaces the logical-to-physical index translation of this view. */
        void resetColumnIndices(int[] indices) {
            this.columnIndices = indices;
        }

        /** Attaches the row that subsequent reads and writes operate on. */
        void resetInstance(Row row) {
            this.instance = row;
        }
    }

    /**
     * Read-only view on a subset of the fields of a {@link Row}. Logical position {@code i} of the
     * view addresses the row field {@code columnIndices[i]}.
     *
     * <p>The backing row is transient and must be attached with {@link #resetInstance(Row)} before
     * any value is read.
     */
    public static final class SlicedSelectedSample implements Serializable {
        private static final long serialVersionUID = 2774288987465829372L;
        private int[] columnIndices;
        private transient Row instance;

        private SlicedSelectedSample(int[] indices) {
            this.columnIndices = indices;
        }

        /** @return the number of fields visible through this view */
        public int length() {
            return this.columnIndices.length;
        }

        /**
         * @param i logical position inside the view
         * @return the value of the backing row at the translated position
         */
        public Object get(int i) {
            return this.instance.getField(this.columnIndices[i]);
        }

        /**
         * Same as {@link #fillDenseVector(DenseVector, boolean, int[])} without a position
         * restriction, i.e. every field of the view is copied.
         */
        public void fillDenseVector(DenseVector denseVector, boolean z) {
            fillDenseVector(denseVector, z, null);
        }

        /**
         * Writes the values of this view into {@code denseVector} as doubles.
         *
         * <p>When {@code z} is {@code true}, element 0 of the vector is set to {@code 1.0} and the
         * copied values start at element 1; otherwise they start at element 0. Values that are not
         * a {@link Number} (including {@code null}) are written as {@code Criteria.INVALID_GAIN}.
         * NOTE(review): that constant was probably substituted for a literal by the decompiler;
         * confirm its value against {@code Criteria}.
         *
         * @param denseVector the vector to fill; must be large enough for all written elements
         * @param z           whether a leading constant element {@code 1.0} is written
         * @param iArr        logical view positions to copy, in order; {@code null} copies
         *                    positions {@code 0 .. length() - 1}
         */
        public void fillDenseVector(DenseVector denseVector, boolean z, int[] iArr) {
            int offset = 0;
            if (z) {
                denseVector.set(0, 1.0d);
                offset = 1;
            }
            final int count = iArr != null ? iArr.length : length();
            for (int k = 0; k < count; k++) {
                // Read the field once instead of once for the type test and once for the cast.
                Object value = get(iArr != null ? iArr[k] : k);
                if (value instanceof Number) {
                    denseVector.set(offset + k, ((Number) value).doubleValue());
                } else {
                    denseVector.set(offset + k, Criteria.INVALID_GAIN);
                }
            }
        }

        /**
         * Copies every value of this view into positions {@code 0 .. length() - 1} of {@code row}.
         *
         * @param row the row to write into
         */
        public void fillRow(Row row) {
            final int count = length();
            for (int k = 0; k < count; k++) {
                row.setField(k, get(k));
            }
        }

        /** Replaces the logical-to-physical index translation of this view. */
        void resetColumnIndices(int[] indices) {
            this.columnIndices = indices;
        }

        /** Attaches the row that subsequent reads operate on. */
        void resetInstance(Row row) {
            this.instance = row;
        }
    }

    /**
     * Supplies each thread with its own array of {@link SlicedSelectedSample} views that all share
     * the same column indices.
     */
    public static class SlicedSelectedSampleArrayThreadLocal extends ThreadLocal<SlicedSelectedSample[]> implements Serializable {
        private static final long serialVersionUID = -880790266294357596L;
        private final int[] columnIndices;
        private final int size;

        /**
         * @param indices column indices handed to every created view
         * @param size    number of views created for a thread on first access
         */
        public SlicedSelectedSampleArrayThreadLocal(int[] indices, int size) {
            this.columnIndices = indices;
            this.size = size;
        }

        /** Creates {@code size} fresh views for the calling thread. */
        @Override // java.lang.ThreadLocal
        public SlicedSelectedSample[] initialValue() {
            SlicedSelectedSample[] views = new SlicedSelectedSample[this.size];
            for (int k = 0; k < views.length; k++) {
                views[k] = new SlicedSelectedSample(this.columnIndices);
            }
            return views;
        }

        /**
         * Returns the calling thread's views, enlarging the array first when it holds fewer than
         * {@code required} entries. Existing views are kept; only the missing tail is created.
         */
        private SlicedSelectedSample[] getAtLeast(int required) {
            SlicedSelectedSample[] current = get();
            if (current.length >= required) {
                return current;
            }
            SlicedSelectedSample[] grown = new SlicedSelectedSample[required];
            System.arraycopy(current, 0, grown, 0, current.length);
            for (int k = current.length; k < required; k++) {
                grown[k] = new SlicedSelectedSample(this.columnIndices);
            }
            set(grown);
            return grown;
        }
    }

    /** Supplies each thread with its own {@link SlicedSelectedSample} view. */
    public static class SlicedSelectedSampleThreadLocal extends ThreadLocal<SlicedSelectedSample> implements Serializable {
        private static final long serialVersionUID = -880790266294357596L;
        private final int[] columnIndices;

        /** @param indices column indices handed to every created view */
        public SlicedSelectedSampleThreadLocal(int[] indices) {
            this.columnIndices = indices;
        }

        /** Creates a fresh view for the calling thread. */
        @Override // java.lang.ThreadLocal
        public SlicedSelectedSample initialValue() {
            return new SlicedSelectedSample(this.columnIndices);
        }
    }

    /**
     * Supplies each thread with its own array of {@link SlicedResult} views that all share the
     * same column indices.
     */
    public static class SlicedSlicedResultArrayThreadLocal extends ThreadLocal<SlicedResult[]> implements Serializable {
        private static final long serialVersionUID = 3929812061012978137L;
        private final int[] columnIndices;
        private final int size;

        /**
         * @param indices column indices handed to every created view
         * @param size    number of views created for a thread on first access
         */
        public SlicedSlicedResultArrayThreadLocal(int[] indices, int size) {
            this.columnIndices = indices;
            this.size = size;
        }

        /** Creates {@code size} fresh views for the calling thread. */
        @Override // java.lang.ThreadLocal
        public SlicedResult[] initialValue() {
            SlicedResult[] views = new SlicedResult[this.size];
            for (int k = 0; k < views.length; k++) {
                views[k] = new SlicedResult(this.columnIndices);
            }
            return views;
        }

        /**
         * Returns the calling thread's views, enlarging the array first when it holds fewer than
         * {@code required} entries. Existing views are kept; only the missing tail is created.
         */
        private SlicedResult[] getAtLeast(int required) {
            SlicedResult[] current = get();
            if (current.length >= required) {
                return current;
            }
            SlicedResult[] grown = new SlicedResult[required];
            System.arraycopy(current, 0, grown, 0, current.length);
            for (int k = current.length; k < required; k++) {
                grown[k] = new SlicedResult(this.columnIndices);
            }
            set(grown);
            return grown;
        }
    }

    /** Supplies each thread with its own {@link SlicedResult} view. */
    public static class SlicedSlicedResultThreadLocal extends ThreadLocal<SlicedResult> implements Serializable {
        private static final long serialVersionUID = 3929812061012978137L;
        private final int[] columnIndices;

        /** @param indices column indices handed to every created view */
        public SlicedSlicedResultThreadLocal(int[] indices) {
            this.columnIndices = indices;
        }

        /** Creates a fresh view for the calling thread. */
        @Override // java.lang.ThreadLocal
        public SlicedResult initialValue() {
            return new SlicedResult(this.columnIndices);
        }
    }

    /**
     * Captures the input schema, copies the parameters and builds the column views.
     *
     * <p>Note that {@link #prepareIoSchema} is invoked from this constructor, i.e. before the
     * fields of the subclass have been initialised.
     *
     * @param tableSchema schema of the rows that will be passed to the mapper
     * @param params      mapper parameters; {@code null} is treated as an empty parameter set
     */
    public Mapper(TableSchema tableSchema, Params params) {
        this.dataFieldNames = tableSchema.getFieldNames();
        this.dataFieldTypes = tableSchema.getFieldDataTypes();
        this.params = null == params ? new Params() : params.clone();
        // A caller-supplied instance is forwarded unchanged; only a missing one is replaced, so
        // that subclasses never have to deal with a null parameter set.
        this.ioSchema = prepareIoSchema(tableSchema, null == params ? this.params : params);
        checkIoSchema();
        initializeSliced();
    }

    /** Lifecycle hook; does nothing in this base class. */
    public void open() {
    }

    /** Lifecycle hook; does nothing in this base class. */
    public void close() {
    }

    /**
     * Maps a single row.
     *
     * <p>A new output row with the arity of {@link #getOutputSchema()} is created, the reserved
     * columns are copied over and {@link #map(SlicedSelectedSample, SlicedResult)} fills in the
     * output columns.
     *
     * @param row the input row
     * @return the newly created output row
     * @throws Exception whatever the concrete mapping raises
     */
    public Row map(Row row) throws Exception {
        Row output = new Row(getOutputSchema().getFieldNames().length);
        this.transformer.transform(row, output);
        SlicedSelectedSample inputView = this.selection.get();
        SlicedResult outputView = this.result.get();
        inputView.resetInstance(row);
        outputView.resetInstance(output);
        map(inputView, outputView);
        return output;
    }

    /**
     * Maps a batch of rows and returns one newly created output row per input row.
     *
     * @param rowArr the input rows
     * @return the output rows, in input order; an empty array for an empty input
     * @throws Exception whatever the concrete mapping raises
     */
    public Row[] bunchMap(Row[] rowArr) throws Exception {
        final int count = rowArr.length;
        Row[] outputs = new Row[count];
        if (count > 0) {
            // The output arity is the same for every row, so the schema is built only once.
            final int arity = getOutputSchema().getFieldNames().length;
            for (int k = 0; k < count; k++) {
                outputs[k] = new Row(arity);
            }
        }
        map(rowArr, outputs, count);
        return outputs;
    }

    /**
     * Maps the first {@code i} rows of {@code rowArr} into the first {@code i} rows of
     * {@code rowArr2}.
     *
     * <p>Missing ({@code null}) entries among the first {@code i} output rows are created on
     * demand. If {@code rowArr2} itself is {@code null} a temporary array is used; the mapping is
     * still executed, but the caller has no way to obtain the produced rows.
     *
     * <p>The per-thread views are enlarged when a batch is bigger than any batch seen before on
     * the calling thread.
     *
     * @param rowArr  the input rows
     * @param rowArr2 the output rows, or {@code null}
     * @param i       number of rows to process; nothing happens for values {@code <= 0}
     * @throws Exception whatever the concrete mapping raises
     */
    public void map(Row[] rowArr, Row[] rowArr2, int i) throws Exception {
        if (i <= 0) {
            // Nothing to map; this also avoids touching element 0 of an empty output array.
            return;
        }
        Row[] outputs = rowArr2 != null ? rowArr2 : new Row[i];
        int arity = -1;
        for (int k = 0; k < i; k++) {
            if (outputs[k] == null) {
                if (arity < 0) {
                    // Built lazily and only once per call: creating the schema is not free.
                    arity = getOutputSchema().getFieldNames().length;
                }
                outputs[k] = new Row(arity);
            }
        }
        if (this.selections == null) {
            this.selections = new SlicedSelectedSampleArrayThreadLocal(
                TableUtil.findColIndicesWithAssertAndHint(getDataSchema(), this.ioSchema.f0), i);
        }
        if (this.results == null) {
            this.results = new SlicedSlicedResultArrayThreadLocal(
                TableUtil.findColIndicesWithAssertAndHint(getOutputSchema(), this.ioSchema.f1), i);
        }
        for (int k = 0; k < i; k++) {
            this.transformer.transform(rowArr[k], outputs[k]);
        }
        // The thread-local arrays were sized by an earlier call; make sure they cover this batch.
        SlicedSelectedSample[] inputViews = this.selections.getAtLeast(i);
        SlicedResult[] outputViews = this.results.getAtLeast(i);
        for (int k = 0; k < i; k++) {
            inputViews[k].resetInstance(rowArr[k]);
            outputViews[k].resetInstance(outputs[k]);
        }
        bunchMap(inputViews, outputViews, i);
    }

    /**
     * Maps in place: input values are read from and results are written to the same row, using
     * the given field positions instead of the ones derived from the schemas.
     *
     * <p>NOTE(review): the supplied index arrays are stored in the calling thread's single-row
     * views and are not restored afterwards, so a later {@link #map(Row)} on the same thread keeps
     * using them. Confirm that callers do not mix both entry points on one mapper instance.
     *
     * @param row   the row that is read and written
     * @param iArr  positions in {@code row} of the selected columns
     * @param iArr2 positions in {@code row} of the output columns
     * @throws Exception whatever the concrete mapping raises
     */
    public void bufferMap(Row row, int[] iArr, int[] iArr2) throws Exception {
        SlicedSelectedSample inputView = this.selection.get();
        SlicedResult outputView = this.result.get();
        inputView.resetColumnIndices(iArr);
        inputView.resetInstance(row);
        outputView.resetColumnIndices(iArr2);
        outputView.resetInstance(row);
        map(inputView, outputView);
    }

    /**
     * Computes the output columns for one row.
     *
     * @param slicedSelectedSample view on the selected input columns
     * @param slicedResult         view on the output columns to fill
     * @throws Exception if the mapping fails
     */
    public abstract void map(SlicedSelectedSample slicedSelectedSample, SlicedResult slicedResult) throws Exception;

    /**
     * Computes the output columns for the first {@code i} view pairs. This default implementation
     * maps them one after another; subclasses may override it with a batched computation.
     *
     * @param slicedSelectedSampleArr views on the selected input columns
     * @param slicedResultArr         views on the output columns to fill
     * @param i                       number of pairs to process
     * @throws Exception if the mapping fails
     */
    protected void bunchMap(SlicedSelectedSample[] slicedSelectedSampleArr, SlicedResult[] slicedResultArr, int i) throws Exception {
        for (int k = 0; k < i; k++) {
            map(slicedSelectedSampleArr[k], slicedResultArr[k]);
        }
    }

    /** @return a schema rebuilt from the field names and types captured at construction time */
    public final TableSchema getDataSchema() {
        return TableSchema.builder().fields(this.dataFieldNames, this.dataFieldTypes).build();
    }

    /** @return the selected column names, i.e. {@code ioSchema.f0} */
    public final String[] getSelectedCols() {
        return this.ioSchema.f0;
    }

    /** @return the output column names, i.e. {@code ioSchema.f1} */
    public final String[] getResultCols() {
        return this.ioSchema.f1;
    }

    /**
     * @return the result schema reported by an {@link OutputColsHelper} built from the data schema
     *         and the {@code f1}, {@code f2} and {@code f3} entries of {@link #ioSchema}
     */
    public TableSchema getOutputSchema() {
        OutputColsHelper helper =
            new OutputColsHelper(getDataSchema(), this.ioSchema.f1, this.ioSchema.f2, this.ioSchema.f3);
        return helper.getResultSchema();
    }

    /**
     * Describes the input/output layout of the mapper, see the class comment for the meaning of
     * the tuple entries. Called from the constructor.
     *
     * @param tableSchema schema of the input rows
     * @param params      the parameters given to the constructor, or an empty parameter set if
     *                    none were given
     * @return the layout, or {@code null} if it is not known yet; in that case the subclass has to
     *         assign {@link #ioSchema} and call {@link #checkIoSchema()} and
     *         {@link #initializeSliced()} itself
     */
    protected abstract Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(TableSchema tableSchema, Params params);

    /**
     * Builds the reserved-column copier and the single-row views from {@link #ioSchema}. Does
     * nothing while {@link #ioSchema} is {@code null}.
     */
    public final void initializeSliced() {
        if (null == this.ioSchema) {
            return;
        }
        OutputColsHelper helper =
            new OutputColsHelper(getDataSchema(), this.ioSchema.f1, this.ioSchema.f2, this.ioSchema.f3);
        // Positions of the reserved columns in the input row and in the output row, pairwise.
        int[] reservedInInput =
            TableUtil.findColIndicesWithAssertAndHint(getDataSchema(), helper.getReservedColumns());
        int[] reservedInOutput =
            TableUtil.findColIndicesWithAssertAndHint(getOutputSchema(), helper.getReservedColumns());
        this.transformer = new MemoryTransformer(reservedInInput, reservedInOutput);
        this.selection = new SlicedSelectedSampleThreadLocal(
            TableUtil.findColIndicesWithAssertAndHint(getDataSchema(), this.ioSchema.f0));
        this.result = new SlicedSlicedResultThreadLocal(
            TableUtil.findColIndicesWithAssertAndHint(getOutputSchema(), this.ioSchema.f1));
    }

    /**
     * Verifies that the selected columns, output columns and output types of {@link #ioSchema} are
     * present. Does nothing while {@link #ioSchema} is {@code null}.
     */
    public final void checkIoSchema() {
        if (null == this.ioSchema) {
            return;
        }
        AkPreconditions.checkState(this.ioSchema.f0 != null, "Selected columns in mapper should not be null.");
        AkPreconditions.checkState(this.ioSchema.f1 != null, "Output columns in mapper should not be null.");
        AkPreconditions.checkState(this.ioSchema.f2 != null, "Output types in mapper should not be null.");
    }
}
