package com.alibaba.alink.operator.common.nlp.bert;

import com.alibaba.alink.common.linalg.tensor.FloatTensor;
import com.alibaba.alink.common.linalg.tensor.IntTensor;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.params.shared.colname.HasOutputCol;
import com.alibaba.alink.params.shared.colname.HasReservedColsDefaultAsNull;
import com.alibaba.alink.params.tensorflow.bert.HasHiddenStatesCol;
import com.alibaba.alink.params.tensorflow.bert.HasLayer;
import com.alibaba.alink.params.tensorflow.bert.HasLengthCol;
import com.google.common.primitives.Floats;
import java.util.Arrays;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/* loaded from: input_file:com/alibaba/alink/operator/common/nlp/bert/BertEmbeddingExtractorMapper.class */
/**
 * Mapper that turns BERT hidden states into a fixed-size embedding string.
 *
 * <p>For every row it reads two selected columns:
 * <ul>
 *   <li>a length column, an {@link IntTensor} of which only element 0 is used;</li>
 *   <li>a hidden-states column, a rank-3 {@link FloatTensor}.</li>
 * </ul>
 * It averages the vectors at coordinates {@code [layer, 0 .. length-1, :]} of the hidden states. The
 * mean vector is written to the output column as floats joined by {@link #SEP_CHAR}.
 *
 * <p>NOTE(review): the axes are presumably (layers, token positions, hidden size). Only the indexing
 * pattern is established here, so confirm against the producer of the hidden-states column.
 */
public class BertEmbeddingExtractorMapper extends Mapper {
    /** Separator placed between the float components of the output embedding string. */
    public static final String SEP_CHAR = " ";

    /**
     * Layer to extract, read from {@link HasLayer#LAYER}.
     *
     * <p>Negative values count from the end of the first axis of the hidden-states tensor, so -1 selects
     * the last entry. Non-negative values are used as absolute indices.
     */
    int layer;

    public BertEmbeddingExtractorMapper(TableSchema tableSchema, Params params) {
        super(tableSchema, params);
        this.layer = params.get(HasLayer.LAYER);
    }

    /**
     * Computes the row-major flat offset of a coordinate inside a tensor of the given shape.
     *
     * @param coordinates per-axis coordinate; must have at least {@code shape.length} entries
     * @param shape       tensor shape
     * @return the flat offset {@code ((c0 * s1 + c1) * s2 + c2) ...}
     * @throws ArithmeticException if the offset does not fit in an {@code int}; previously such an
     *                             offset silently wrapped around to a wrong value
     */
    protected static int calcIndex(int[] coordinates, long[] shape) {
        long offset = 0;
        for (int axis = 0; axis < shape.length; axis++) {
            // For axis 0 the accumulated offset is 0, so the multiplication is a no-op.
            offset = offset * shape[axis] + coordinates[axis];
        }
        return Math.toIntExact(offset);
    }

    /**
     * Averages the hidden-state vectors of the first {@code length} positions of the configured layer.
     *
     * <p>If {@code length <= 0}, no position is visited and an all-zero vector is emitted.
     *
     * @throws IllegalArgumentException if the hidden-states tensor is not rank 3, or if {@link #layer}
     *                                  does not resolve to an index on its first axis
     */
    @Override
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult)
        throws Exception {
        int length = ((IntTensor) slicedSelectedSample.get(0)).getInt(0);
        FloatTensor hiddenStates = (FloatTensor) slicedSelectedSample.get(1);
        long[] shape = hiddenStates.shape();
        // The element lookup below uses exactly three coordinates, so any other rank cannot be indexed.
        if (shape.length != 3) {
            throw new IllegalArgumentException(
                "Hidden states must be a rank-3 tensor, but got shape " + Arrays.toString(shape));
        }
        int embeddingSize = (int) shape[shape.length - 1];
        long layerIndex = resolveLayerIndex(layer, shape[0]);

        // Java zero-initialises new arrays, so no explicit fill is required.
        float[] embedding = new float[embeddingSize];
        long[] index = {layerIndex, 0, 0};
        for (int position = 0; position < length; position++) {
            index[1] = position;
            for (int dim = 0; dim < embeddingSize; dim++) {
                index[2] = dim;
                // Each term is divided before summing, to keep numeric output identical to earlier versions.
                embedding[dim] += hiddenStates.getFloat(index) / length;
            }
        }
        slicedResult.set(0, Floats.join(SEP_CHAR, embedding));
    }

    /**
     * Resolves a possibly negative layer number against the size of the first hidden-states axis.
     *
     * <p>Negative values count from the end; non-negative values are absolute.
     */
    private static long resolveLayerIndex(int layer, long numLayers) {
        long resolved = layer < 0 ? numLayers + layer : layer;
        if (resolved < 0 || resolved >= numLayers) {
            throw new IllegalArgumentException(
                "Layer " + layer + " is out of range for hidden states with " + numLayers
                    + " entries on the first axis.");
        }
        return resolved;
    }

    /**
     * Declares the I/O schema for this mapper.
     *
     * <p>It selects the length and hidden-states columns as input, and produces a single STRING output
     * column. Reserved columns are passed through as configured.
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(
        TableSchema tableSchema, Params params) {
        String lengthCol = params.get(HasLengthCol.LENGTH_COL);
        String hiddenStatesCol = params.get(HasHiddenStatesCol.HIDDEN_STATES_COL);
        String outputCol = params.get(HasOutputCol.OUTPUT_COL);
        return Tuple4.of(
            new String[] {lengthCol, hiddenStatesCol},
            new String[] {outputCol},
            new TypeInformation<?>[] {Types.STRING},
            params.get(HasReservedColsDefaultAsNull.RESERVED_COLS));
    }
}
