package com.alibaba.alink.operator.common.dataproc;

import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.MatVecOp;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.operator.common.feature.featurebuilder.FeatureClause;
import com.alibaba.alink.operator.common.feature.featurebuilder.FeatureClauseOperator;
import com.alibaba.alink.operator.common.feature.featurebuilder.FeatureClauseUtil;
import com.alibaba.alink.operator.common.nlp.Word2VecModelDataConverter;
import com.alibaba.alink.params.dataproc.AggLookupParams;
import java.util.HashMap;
import java.util.List;
import java.util.function.BiFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/dataproc/AggLookupModelMapper.class */
/**
 * Model mapper that looks up an embedding vector for every delimiter-separated token of a string
 * column and combines the token vectors into one dense vector per output column.
 *
 * <p>The model is read with {@link Word2VecModelDataConverter}: each model row maps a token
 * (field 0, converted with {@code toString()}) to a dense vector (field 1). Each feature clause
 * names an input column, an output column and an operator:
 *
 * <ul>
 *   <li>{@code CONCAT} - token vectors are written one after another. An optional single clause
 *       parameter fixes the number of token slots: extra tokens are dropped, and missing or
 *       unknown tokens leave their slot at its initial value.
 *   <li>{@code MAX} / {@code MIN} - element-wise maximum / minimum over the known tokens.
 *   <li>{@code SUM} / {@code AVG} - element-wise sum / mean over the known tokens.
 * </ul>
 *
 * <p>Tokens that are not in the model are ignored. For non-CONCAT operators the output is
 * {@code null} when the input is {@code null} or when none of its tokens is known.
 */
public class AggLookupModelMapper extends ModelMapper {
    /** Element-wise maximum; shared so no function object is allocated per token. */
    private static final BiFunction<Double, Double, Double> MAX_FUNC = Math::max;
    /** Element-wise minimum; shared so no function object is allocated per token. */
    private static final BiFunction<Double, Double, Double> MIN_FUNC = Math::min;

    // NOTE(review): operators, sequenceLens and numAgg are assigned in prepareIoSchema, which is
    // presumably invoked by the ModelMapper constructor (confirm). Do NOT give these fields
    // initializers: they would run after super(...) and overwrite the values.

    /** Operator of each feature clause, indexed like the selected input columns. */
    private FeatureClauseOperator[] operators;
    /** Fixed CONCAT slot count per clause; -1 when the clause carries no single parameter. */
    private int[] sequenceLens;
    /** Number of feature clauses, i.e. number of input/output column pairs. */
    private int numAgg;
    /** Token -> embedding vector, rebuilt by every call to {@link #loadModel(List)}. */
    private HashMap<String, DenseVector> embed;
    /** Token delimiter, used as a regular expression by {@link String#split(String)}. */
    private String delimiter;
    /** Dimension of the embedding vectors, taken from the first model row (0 for an empty model). */
    private int vecSize;

    public AggLookupModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
        this.embed = new HashMap<>();
    }

    /**
     * Computes one dense vector per feature clause from the corresponding string input column.
     *
     * <p>A {@code null} input yields a zero-filled vector of length {@code vecSize * seqLen} for
     * CONCAT (length 0 when no sequence length is configured) and {@code null} otherwise.
     *
     * @throws AkUnsupportedOperationException if a non-CONCAT clause uses an operator other than
     *     MAX, MIN, SUM or AVG and at least two of its tokens are found in the model
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // com.alibaba.alink.common.mapper.Mapper
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult) throws Exception {
        for (int i = 0; i < this.numAgg; i++) {
            FeatureClauseOperator operator = this.operators[i];
            // -1 marks "no sequence length configured"; 0 then means "use the token count".
            int seqLen = this.sequenceLens[i] == -1 ? 0 : this.sequenceLens[i];
            String content = (String) slicedSelectedSample.get(i);

            if (content == null) {
                // Keep the CONCAT output width identical to that of a non-null input
                // (vecSize values per slot); previously this emitted only seqLen values.
                slicedResult.set(i, operator == FeatureClauseOperator.CONCAT
                    ? new DenseVector(this.vecSize * seqLen)
                    : null);
                continue;
            }

            String[] tokens = content.split(this.delimiter);
            DenseVector value = operator == FeatureClauseOperator.CONCAT
                ? concatEmbeddings(tokens, seqLen)
                : aggregateEmbeddings(tokens, operator);
            slicedResult.set(i, value);
        }
    }

    /**
     * Lays the embeddings of {@code tokens} out one after another.
     *
     * @param tokens the split input
     * @param seqLen number of token slots, or 0 to use one slot per token
     * @return a vector of {@code vecSize * slots} values; unknown tokens leave their slot untouched
     */
    private DenseVector concatEmbeddings(String[] tokens, int seqLen) {
        int slots = seqLen == 0 ? tokens.length : seqLen;
        int filled = Math.min(slots, tokens.length);
        DenseVector out = new DenseVector(this.vecSize * slots);
        for (int t = 0; t < filled; t++) {
            DenseVector emb = this.embed.get(tokens[t]);
            if (emb == null) {
                continue;
            }
            int offset = t * this.vecSize;
            for (int k = 0; k < this.vecSize; k++) {
                out.set(offset + k, emb.get(k));
            }
        }
        return out;
    }

    /**
     * Combines the embeddings of the known {@code tokens} element-wise with {@code operator}.
     *
     * @return the combined vector, or {@code null} if none of the tokens is in the model
     */
    private DenseVector aggregateEmbeddings(String[] tokens, FeatureClauseOperator operator) {
        DenseVector acc = null;
        int hits = 0;
        for (String token : tokens) {
            DenseVector emb = this.embed.get(token);
            if (emb == null) {
                continue;
            }
            if (acc == null) {
                // Copy so that the in-place updates below never touch the model's vectors.
                acc = emb.mo136clone();
            } else {
                switch (operator) {
                    case MAX:
                        MatVecOp.apply(acc, emb, acc, MAX_FUNC);
                        break;
                    case MIN:
                        MatVecOp.apply(acc, emb, acc, MIN_FUNC);
                        break;
                    case AVG:
                    case SUM:
                        acc.plusScaleEqual(emb, 1.0d);
                        break;
                    default:
                        throw new AkUnsupportedOperationException("not support yet.");
                }
            }
            hits++;
        }
        if (operator == FeatureClauseOperator.AVG && acc != null) {
            acc.scaleEqual(1.0d / hits);
        }
        return acc;
    }

    /**
     * Parses the feature clauses and records the per-clause operator and sequence length.
     *
     * <p>Uses the {@code params} argument rather than {@code this.params}; the latter may not be
     * assigned yet if this is called from the superclass constructor (presumably the case - confirm).
     *
     * @return input column names, output column names, output types (all dense vectors) and
     *     reserved columns (all data columns when none are configured)
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        FeatureClause[] clauses = FeatureClauseUtil.extractFeatureClauses((String) params.get(AggLookupParams.CLAUSE));
        this.numAgg = clauses.length;
        this.sequenceLens = new int[this.numAgg];
        this.operators = new FeatureClauseOperator[this.numAgg];
        String[] inputCols = new String[this.numAgg];
        String[] outputCols = new String[this.numAgg];
        TypeInformation<?>[] outputTypes = new TypeInformation<?>[this.numAgg];

        String[] reservedCols = (String[]) params.get(AggLookupParams.RESERVED_COLS);
        if (reservedCols == null) {
            reservedCols = tableSchema2.getFieldNames();
        }

        for (int i = 0; i < this.numAgg; i++) {
            FeatureClause clause = clauses[i];
            this.operators[i] = clause.op;
            inputCols[i] = clause.inColName;
            outputCols[i] = clause.outColName;
            outputTypes[i] = AlinkTypes.DENSE_VECTOR;
            // A single clause parameter is the CONCAT sequence length; -1 means none given.
            this.sequenceLens[i] = clause.inputParams.length == 1
                ? Integer.parseInt(clause.inputParams[0].toString())
                : -1;
        }
        return Tuple4.of(inputCols, outputCols, outputTypes, reservedCols);
    }

    /**
     * Loads the embedding table from Word2Vec-format model rows and reads the token delimiter.
     *
     * <p>Any previously loaded embeddings are replaced, not merged. An empty model yields an
     * empty table and {@code vecSize == 0} instead of failing.
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public void loadModel(List<Row> list) {
        Word2VecModelDataConverter converter = new Word2VecModelDataConverter();
        converter.load(list);
        this.delimiter = (String) this.params.get(AggLookupParams.DELIMITER);

        HashMap<String, DenseVector> table = new HashMap<>();
        int size = 0;
        for (Row row : converter.modelRows) {
            DenseVector vec = (DenseVector) VectorUtil.getVector(VectorUtil.getVector(row.getField(1)));
            if (table.isEmpty()) {
                // The embedding dimension is taken from the first model row.
                size = vec.size();
            }
            table.put(row.getField(0).toString(), vec);
        }
        this.vecSize = size;
        this.embed = table;
    }

    /**
     * Creates a mapper with the same schemas and parameters that serves the model in {@code list}.
     *
     * <p>The new mapper builds its own embedding table. It must not reuse this mapper's table,
     * which may still be in use and which may hold tokens that are absent from the new model.
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public ModelMapper createNew(List<Row> list) {
        AggLookupModelMapper mapper = new AggLookupModelMapper(getModelSchema(), getDataSchema(), this.params);
        mapper.loadModel(list);
        return mapper;
    }
}
