package com.alibaba.alink.operator.common.similarity.dataConverter;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.utils.RowCollector;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.distance.FastCategoricalDistance;
import com.alibaba.alink.operator.common.distance.LevenshteinDistance;
import com.alibaba.alink.operator.common.similarity.Sample;
import com.alibaba.alink.operator.common.similarity.modeldata.StringModelData;
import com.alibaba.alink.operator.common.similarity.similarity.Cosine;
import com.alibaba.alink.operator.common.similarity.similarity.LevenshteinSimilarity;
import com.alibaba.alink.operator.common.similarity.similarity.LongestCommonSubsequence;
import com.alibaba.alink.operator.common.similarity.similarity.LongestCommonSubsequenceSimilarity;
import com.alibaba.alink.operator.common.similarity.similarity.SubsequenceKernelSimilarity;
import com.alibaba.alink.params.similarity.StringTextNearestNeighborTrainParams;
import java.util.ArrayList;
import java.util.List;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/common/similarity/dataConverter/StringModelDataConverter.class */
/**
 * Converts between the row-based representation of a string/text nearest-neighbor model
 * and the in-memory {@link StringModelData}.
 *
 * <p>Each model row has {@code ROW_SIZE} (3) string columns, as declared by
 * {@link #getModelDataSchema()}: {@code ID}, {@code DATA} and {@code LABEL}.
 */
public class StringModelDataConverter extends NearestNeighborDataConverter<StringModelData> {
    private static final long serialVersionUID = 8761170480003926433L;

    /**
     * Boolean parameter named "text" (default {@code false}). Its value is forwarded to
     * {@code FastCategoricalDistance.prepareSample} and to the {@link StringModelData} constructor.
     */
    public static ParamInfo<Boolean> TEXT = ParamInfoFactory.createParamInfo("text", Boolean.class).setDescription("text").setHasDefaultValue(false).build();

    /** Number of fields in a model row. */
    private static final int ROW_SIZE = 3;
    /** Position of the sample id inside a model row. */
    private static final int ID_INDEX = 0;
    /** Position of the prepared string inside a model row. */
    private static final int DATA_INDEX = 1;
    /** Position of the (optional) label inside a model row. */
    private static final int LABEL_INDEX = 2;

    /** Creates a converter whose model rows have {@code ROW_SIZE} fields. */
    public StringModelDataConverter() {
        this.rowSize = ROW_SIZE;
    }

    /**
     * Returns the schema of the model rows produced by this converter.
     *
     * @return a schema with three string columns: {@code ID}, {@code DATA}, {@code LABEL}
     */
    @Override // com.alibaba.alink.operator.common.similarity.dataConverter.NearestNeighborDataConverter
    public TableSchema getModelDataSchema() {
        return new TableSchema(new String[]{"ID", "DATA", "LABEL"}, new TypeInformation[]{Types.STRING, Types.STRING, Types.STRING});
    }

    /**
     * Rebuilds the in-memory model from model rows.
     *
     * <p>Rows whose {@code ID} field is {@code null} are skipped (presumably these are the
     * serialized meta rows - confirm against the base class). Samples are gathered into a
     * growable list, so the method no longer depends on the input containing exactly one
     * such row.
     *
     * @param list model rows to decode
     * @return model data built from every row with a non-null id, the similarity selected by
     *         {@code this.meta}, and the {@link #TEXT} flag read from {@code this.meta}
     */
    @Override // com.alibaba.alink.operator.common.similarity.dataConverter.NearestNeighborDataConverter
    public StringModelData loadModelData(List<Row> list) {
        List<Sample> samples = new ArrayList<Sample>(list.size());
        for (Row row : list) {
            Object id = row.getField(ID_INDEX);
            if (id == null) {
                continue;
            }
            String data = (String) row.getField(DATA_INDEX);
            String label = (String) row.getField(LABEL_INDEX);
            // The label column is optional; it is stored as text and parsed back to a Double.
            samples.add(new Sample(data, Row.of(new Object[]{id}), null == label ? null : Double.valueOf(label)));
        }
        boolean text = ((Boolean) this.meta.get(TEXT)).booleanValue();
        return new StringModelData(samples.toArray(new Sample[0]), initSimilarity(this.meta), text);
    }

    /**
     * Builds the model rows for a batch data set.
     *
     * <p>Every input row is converted with {@link #toModelRow}, then each partition is passed
     * to {@code save2}. Only the subtask with index 0 hands {@code params} to {@code save2};
     * all other subtasks pass {@code null} instead.
     *
     * @param batchOperator source of the input rows (field 0 = id, field 1 = string content)
     * @param params        training parameters (metric, {@link #TEXT}, ...)
     * @return the data set of model rows, named "build_model"
     */
    @Override // com.alibaba.alink.operator.common.similarity.dataConverter.NearestNeighborDataConverter
    public DataSet<Row> buildIndex(BatchOperator batchOperator, final Params params) {
        DataSet<Row> dataSet = batchOperator.getDataSet();
        final FastCategoricalDistance similarity = initSimilarity(params);
        // The flag does not change per row, so it is read once instead of inside map().
        final boolean text = ((Boolean) params.get(TEXT)).booleanValue();
        return dataSet.map(new MapFunction<Row, Row>() { // from class: com.alibaba.alink.operator.common.similarity.dataConverter.StringModelDataConverter.1
            private static final long serialVersionUID = -600268964767461036L;

            public Row map(Row row) throws Exception {
                return StringModelDataConverter.toModelRow(row, similarity, text);
            }
        }).mapPartition(new RichMapPartitionFunction<Row, Row>() { // from class: com.alibaba.alink.operator.common.similarity.dataConverter.StringModelDataConverter.2
            private static final long serialVersionUID = -1078356373351365760L;

            public void mapPartition(Iterable<Row> iterable, Collector<Row> collector) throws Exception {
                Params metaForThisSubtask = null;
                if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                    metaForThisSubtask = params;
                }
                new StringModelDataConverter().save2(Tuple2.of(metaForThisSubtask, iterable), collector);
            }
        }).name("build_model");
    }

    /**
     * Builds the model rows for an in-memory table.
     *
     * @param mTable table whose rows hold the id in field 0 and the string content in field 1
     * @param params training parameters (metric, {@link #TEXT}, ...)
     * @return the rows emitted by {@code save2} for the converted samples
     */
    @Override // com.alibaba.alink.operator.common.similarity.dataConverter.NearestNeighborDataConverter
    public List<Row> buildIndex(MTable mTable, Params params) {
        FastCategoricalDistance similarity = initSimilarity(params);
        boolean text = ((Boolean) params.get(TEXT)).booleanValue();
        List<Row> modelRows = new ArrayList<Row>();
        for (Row row : mTable.getRows()) {
            modelRows.add(toModelRow(row, similarity, text));
        }
        // Typed as Iterable<Row> so the tuple has the same shape as in the DataSet variant.
        Iterable<Row> rowsToSave = modelRows;
        RowCollector rowCollector = new RowCollector();
        save2(Tuple2.of(params, rowsToSave), (Collector<Row>) rowCollector);
        return rowCollector.getRows();
    }

    /**
     * Converts one input row into a model row.
     *
     * @param input      input row; field 1 is cast to {@code String} and prepared, field 0 is
     *                   stringified into the {@code ID} column (must not be {@code null})
     * @param similarity similarity/distance used to prepare the sample
     * @param text       value of the {@link #TEXT} flag
     * @return a row with ID and DATA set, and LABEL set only when the prepared sample has a label
     */
    private static Row toModelRow(Row input, FastCategoricalDistance similarity, boolean text) {
        Sample sample = similarity.prepareSample((String) input.getField(1), text);
        Row modelRow = new Row(ROW_SIZE);
        modelRow.setField(ID_INDEX, input.getField(0).toString());
        if (sample.getLabel() != null) {
            modelRow.setField(LABEL_INDEX, sample.getLabel().toString());
        }
        modelRow.setField(DATA_INDEX, sample.getStr());
        return modelRow;
    }

    /**
     * Creates the similarity/distance implementation selected by the
     * {@code StringTextNearestNeighborTrainParams.METRIC} parameter.
     *
     * @param params parameters holding the metric and, for SSK and COSINE, the window size
     *               (plus lambda for SSK)
     * @return a new similarity/distance instance
     * @throws AkUnsupportedOperationException if the metric is not one of the six handled here
     */
    private static FastCategoricalDistance initSimilarity(Params params) {
        StringTextNearestNeighborTrainParams.Metric metric = (StringTextNearestNeighborTrainParams.Metric) params.get(StringTextNearestNeighborTrainParams.METRIC);
        // Switch on the enum constants directly rather than on ordinal-derived integers.
        switch (metric) {
            case LEVENSHTEIN:
                return new LevenshteinDistance();
            case LEVENSHTEIN_SIM:
                return new LevenshteinSimilarity();
            case LCS:
                return new LongestCommonSubsequence();
            case LCS_SIM:
                return new LongestCommonSubsequenceSimilarity();
            case SSK:
                return new SubsequenceKernelSimilarity(((Integer) params.get(StringTextNearestNeighborTrainParams.WINDOW_SIZE)).intValue(), ((Double) params.get(StringTextNearestNeighborTrainParams.LAMBDA)).doubleValue());
            case COSINE:
                return new Cosine(((Integer) params.get(StringTextNearestNeighborTrainParams.WINDOW_SIZE)).intValue());
            default:
                throw new AkUnsupportedOperationException("unsupported distance type:" + metric.toString());
        }
    }
}
