package com.alibaba.alink.operator.common.recommendation;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.MTableUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.MapperChain;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter;
import com.alibaba.alink.operator.common.recommendation.RecommUtils;
import com.alibaba.alink.params.recommendation.RecommendationRankingParams;
import com.alibaba.alink.pipeline.ModelExporterUtils;
import java.util.ArrayList;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/recommendation/RecommendationRankingMapper.class */
public class RecommendationRankingMapper extends ModelMapper {
    private static final long serialVersionUID = -3353498411027168031L;

    /** Ranking pipeline loaded from the model rows; built lazily on the first call to {@code map}. */
    private MapperChain mapperList;
    /** Serialized pipeline stages handed over by {@link #loadModel(List)}. */
    private List<Row> modelRows;
    /** Index of the MTable (candidate list) column in the input schema. */
    private int itemListIdx;
    /** Width of the row fed to the ranking pipeline: data columns minus the MTable column plus candidate columns. */
    private int schemaLen;
    /** Column names of the candidate MTable, taken from the first MTable seen. */
    private String[] recallNames;
    /** Name of the pipeline output column that holds the ranking score. */
    private String scoreCol;
    /** Index of {@link #scoreCol} in the ranking pipeline's output schema. */
    private int scoreIndex;
    /** Maximum number of candidates kept per input row. */
    private int topN;
    /** Schema string of the output MTable: candidate columns followed by the score column as DOUBLE. */
    private String mTableSchemaStr;

    public RecommendationRankingMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
    }

    /**
     * Stores the model rows; the ranking pipeline itself is only materialized on the first {@code map} call,
     * because its input schema depends on the schema of the candidate MTable.
     *
     * @param list serialized pipeline stages
     */
    @Override
    public void loadModel(List<Row> list) {
        this.modelRows = list;
    }

    /**
     * Resolves the MTable column, ranking column and topN, and declares a single MTable output column.
     *
     * @param tableSchema  model schema
     * @param tableSchema2 input data schema
     * @param params       operator parameters
     * @return (input columns, output column names, output column types, reserved columns)
     * @throws IllegalArgumentException if the configured MTable column is not part of the input schema
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(
            TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        String[] dataFieldNames = tableSchema2.getFieldNames();
        String mTableCol = (String) params.get(RecommendationRankingParams.M_TABLE_COL);

        this.itemListIdx = TableUtil.findColIndex(dataFieldNames, mTableCol);
        if (this.itemListIdx < 0) {
            throw new IllegalArgumentException(
                "MTable column '" + mTableCol + "' is not found in the input schema: "
                    + String.join(", ", dataFieldNames));
        }

        String outputCol = (String) params.get(RecommendationRankingParams.OUTPUT_COL);
        if (outputCol == null) {
            // Without an explicit output column the ranked MTable replaces the input MTable column.
            outputCol = mTableCol;
        }

        String[] reservedCols = params.contains(RecommendationRankingParams.RESERVED_COLS)
            ? (String[]) params.get(RecommendationRankingParams.RESERVED_COLS)
            : dataFieldNames;

        this.scoreCol = params.contains(RecommendationRankingParams.RANKING_COL)
            ? (String) params.get(RecommendationRankingParams.RANKING_COL)
            : KObjectUtil.SCORE_NAME;
        this.topN = ((Integer) params.get(RecommendationRankingParams.TOP_N)).intValue();

        return Tuple4.of(dataFieldNames, new String[] {outputCol},
            new TypeInformation<?>[] {AlinkTypes.M_TABLE}, reservedCols);
    }

    /**
     * Scores every candidate of the input row's MTable with the ranking pipeline and outputs the
     * {@code topN} candidates chosen by {@link RecommUtils.RecommPriorityQueue} as a new MTable.
     *
     * <p>Each candidate row is spliced into the input row at the position of the MTable column before being
     * fed to the pipeline. An empty candidate list (or {@code topN <= 0}) yields an empty MTable.
     *
     * <p>NOTE(review): the lazy initialization below is not synchronized; confirm a mapper instance is not
     * shared across threads.
     */
    @Override
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult)
            throws Exception {
        MTable mTable = MTableUtil.getMTable(slicedSelectedSample.get(this.itemListIdx));
        if (this.mapperList == null) {
            initRankingChain(mTable.getSchema());
        }

        List<Row> rows = mTable.getRows();
        int numCandidates = rows.size();
        int keep = Math.min(numCandidates, this.topN);
        if (keep <= 0) {
            // Nothing to rank: avoid building a zero-capacity priority queue.
            slicedResult.set(0, new MTable(new ArrayList<Row>(), this.mTableSchemaStr));
            return;
        }

        Object[][] columns = extractCandidateColumns(mTable, rows);
        double[] scores = scoreCandidates(slicedSelectedSample, columns, numCandidates);
        slicedResult.set(0, selectTopCandidates(columns, scores, keep));
    }

    /** Builds and opens the ranking pipeline for the given candidate schema and derives the output schema. */
    private void initRankingChain(TableSchema candidateSchema) throws Exception {
        String[] dataNames = getDataSchema().getFieldNames();
        TypeInformation<?>[] dataTypes = getDataSchema().getFieldTypes();
        String[] candidateNames = candidateSchema.getFieldNames();
        TypeInformation<?>[] candidateTypes = candidateSchema.getFieldTypes();

        int rankWidth = candidateNames.length + dataNames.length - 1;
        String[] rankNames = new String[rankWidth];
        TypeInformation<?>[] rankTypes = new TypeInformation<?>[rankWidth];
        int pos = 0;
        for (int i = 0; i < dataNames.length; i++) {
            if (i == this.itemListIdx) {
                // The candidate columns take the place of the MTable column.
                for (int j = 0; j < candidateNames.length; j++) {
                    rankNames[pos] = candidateNames[j];
                    rankTypes[pos] = candidateTypes[j];
                    pos++;
                }
            } else {
                rankNames[pos] = dataNames[i];
                rankTypes[pos] = dataTypes[i];
                pos++;
            }
        }

        MapperChain chain = ModelExporterUtils.loadMapperListFromStages(
            this.modelRows, getModelSchema(), new TableSchema(rankNames, rankTypes));
        chain.open();

        String[] outNames = chain.getOutTableSchema().getFieldNames();
        int idx = this.scoreCol != null ? TableUtil.findColIndex(outNames, this.scoreCol) : outNames.length - 1;
        if (idx < 0) {
            throw new IllegalArgumentException(
                "Ranking column '" + this.scoreCol + "' is not found in the output of the ranking model: "
                    + String.join(", ", outNames));
        }

        StringBuilder schema = new StringBuilder();
        for (int j = 0; j < candidateNames.length; j++) {
            schema.append(candidateNames[j]).append(" ")
                .append(FlinkTypeConverter.getTypeString(candidateTypes[j]))
                .append(", ");
        }
        schema.append(this.scoreCol).append(" DOUBLE");

        this.recallNames = candidateNames;
        this.schemaLen = rankWidth;
        this.scoreIndex = idx;
        this.mTableSchemaStr = schema.toString();
        // Publish the chain last so a failure above leaves the mapper uninitialized rather than half-initialized.
        this.mapperList = chain;
    }

    /** Transposes the MTable rows into one array per candidate column, in {@link #recallNames} order. */
    private Object[][] extractCandidateColumns(MTable mTable, List<Row> rows) {
        // Looked up per MTable so candidate tables with a different column order are still read correctly.
        int[] colIndices = TableUtil.findColIndices(mTable.getColNames(), this.recallNames);
        Object[][] columns = new Object[this.recallNames.length][rows.size()];
        for (int r = 0; r < rows.size(); r++) {
            Row candidate = rows.get(r);
            for (int c = 0; c < columns.length; c++) {
                columns[c][r] = candidate.getField(colIndices[c]);
            }
        }
        return columns;
    }

    /** Runs every candidate through the ranking pipeline and returns its score. */
    private double[] scoreCandidates(Mapper.SlicedSelectedSample sample, Object[][] columns, int numCandidates)
            throws Exception {
        double[] scores = new double[numCandidates];
        Row rankInput = new Row(this.schemaLen);
        for (int r = 0; r < numCandidates; r++) {
            int pos = 0;
            for (int i = 0; i < sample.length(); i++) {
                if (i == this.itemListIdx) {
                    for (Object[] column : columns) {
                        rankInput.setField(pos++, column[r]);
                    }
                } else {
                    rankInput.setField(pos++, sample.get(i));
                }
            }
            scores[r] = toScore(this.mapperList.map(rankInput).getField(this.scoreIndex));
        }
        return scores;
    }

    /** Converts a pipeline score value to double, parsing its string form for non-Double values. */
    private double toScore(Object value) {
        if (value == null) {
            throw new IllegalStateException("Ranking column '" + this.scoreCol + "' produced a null score.");
        }
        if (value instanceof Double) {
            // Identical to parsing Double.toString(value), which round-trips exactly.
            return (Double) value;
        }
        return Double.parseDouble(value.toString());
    }

    /** Keeps the {@code keep} candidates selected by the priority queue and packs them with their scores. */
    private MTable selectTopCandidates(Object[][] columns, double[] scores, int keep) {
        RecommUtils.RecommPriorityQueue queue = new RecommUtils.RecommPriorityQueue(keep);
        for (int r = 0; r < scores.length; r++) {
            queue.addOrReplace(Integer.valueOf(r), scores[r]);
        }
        Tuple2<List<Object>, List<Double>> ordered = queue.getOrderedObjects();

        int count = Math.min(keep, ordered.f0.size());
        List<Row> ranked = new ArrayList<>(count);
        for (int rank = 0; rank < count; rank++) {
            int candidate = (Integer) ordered.f0.get(rank);
            Row out = new Row(columns.length + 1);
            for (int c = 0; c < columns.length; c++) {
                out.setField(c, columns[c][candidate]);
            }
            out.setField(columns.length, ordered.f1.get(rank));
            ranked.add(out);
        }
        return new MTable(ranked, this.mTableSchemaStr);
    }
}
