package com.alibaba.alink.operator.common.associationrule;

import com.alibaba.alink.common.mapper.SISOModelMapper;
import com.alibaba.alink.operator.batch.associationrule.PrefixSpanBatchOp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/**
 * Model mapper that applies association rules to a single string column of delimited items.
 *
 * <p>Each model row holds one rule in field 0, formatted as
 * {@code <antecedent items joined by sep>} + {@link PrefixSpanBatchOp#RULE_SEPARATOR} + {@code <consequence>}.
 * For an input item string, every rule whose antecedent items are all present in the input
 * contributes its consequence to the output, unless the consequence is already in the input.
 */
public class ApplyAssociationRuleModelMapper extends SISOModelMapper {
    private static final long serialVersionUID = 3709131767975976366L;

    /** Delimiter between items, both in the input column and in each rule's antecedent part. */
    private final String sep = ",";

    /** Antecedent item set of each rule; index-aligned with {@link #consequences}. */
    private transient List<Set<String>> antecedents;

    /** Consequence of each rule; index-aligned with {@link #antecedents}. */
    private transient List<String> consequences;

    public ApplyAssociationRuleModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
        // sep is a final field with an initializer; it must not be reassigned here (compile error).
    }

    /** The prediction column is a delimited string of recommended items. */
    @Override // com.alibaba.alink.common.mapper.SISOModelMapper
    protected TypeInformation initPredResultColType() {
        return Types.STRING;
    }

    /**
     * Collects the consequences of all rules that fire on the given input.
     *
     * @param obj the input value, cast to {@code String} and split on {@link #sep}
     * @return the distinct fired consequences that are not already in the input, in rule order,
     *     joined with {@link #sep}; an empty string when no rule fires
     */
    @Override // com.alibaba.alink.common.mapper.SISOModelMapper
    protected Object predictResult(Object obj) throws Exception {
        Set<String> items = new HashSet<>(Arrays.asList(((String) obj).split(sep)));
        // LinkedHashSet de-duplicates while keeping rule order, so the output is deterministic.
        Set<String> recommended = new LinkedHashSet<>();
        for (int i = 0; i < antecedents.size(); i++) {
            if (items.containsAll(antecedents.get(i))) {
                String consequence = consequences.get(i);
                // Do not recommend an item the input already contains.
                if (!items.contains(consequence)) {
                    recommended.add(consequence);
                }
            }
        }
        return String.join(sep, recommended);
    }

    /**
     * Parses the rule rows into index-aligned antecedent sets and consequences.
     *
     * @param list model rows; field 0 of each row is a rule string
     * @throws IllegalArgumentException if a rule string does not contain
     *     {@link PrefixSpanBatchOp#RULE_SEPARATOR} followed by a consequence
     */
    @Override // com.alibaba.alink.common.mapper.ModelMapper
    public void loadModel(List<Row> list) {
        int size = list.size();
        this.antecedents = new ArrayList<>(size);
        this.consequences = new ArrayList<>(size);
        for (Row row : list) {
            String rule = (String) row.getField(0);
            String[] parts = rule.split(PrefixSpanBatchOp.RULE_SEPARATOR);
            if (parts.length < 2) {
                throw new IllegalArgumentException(
                    "Invalid association rule, expected '<antecedents>" + PrefixSpanBatchOp.RULE_SEPARATOR
                        + "<consequence>': " + rule);
            }
            this.consequences.add(parts[1]);
            this.antecedents.add(new HashSet<>(Arrays.asList(parts[0].split(sep))));
        }
    }
}
