package com.alibaba.alink.operator.common.dataproc;

import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter;
import com.alibaba.alink.params.dataproc.HasTargetType;
import com.alibaba.alink.params.dataproc.NumericalTypeCastParams;
import java.util.Arrays;
import java.util.function.Function;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/* loaded from: input_file:com/alibaba/alink/operator/common/dataproc/NumericalTypeCastMapper.class */
/**
 * Mapper that casts every selected column to one numerical target type (DOUBLE, LONG, INT or FLOAT),
 * as configured by {@link NumericalTypeCastParams#TARGET_TYPE}.
 *
 * <p>Each input value is converted as follows:
 * <ul>
 *   <li>{@code null} stays {@code null};</li>
 *   <li>a {@link String} is trimmed and parsed with the target type's {@code parseXxx} method
 *       (a {@link NumberFormatException} is thrown for unparsable text);</li>
 *   <li>any other value is cast to {@link Number} and converted with the matching {@code xxxValue()}
 *       method (a {@link ClassCastException} is thrown for non-numeric values).</li>
 * </ul>
 */
public class NumericalTypeCastMapper extends Mapper {
    private static final long serialVersionUID = 767160752523041431L;

    /** Flink type that every output column is cast to, resolved from {@code TARGET_TYPE}. */
    private final TypeInformation<?> targetType;

    /**
     * Per-value conversion for {@link #targetType}. {@code java.util.function.Function} is not
     * Serializable, so the field is transient and rebuilt lazily on first use in {@link #map}.
     */
    private transient Function<Object, Object> castFunc;

    public NumericalTypeCastMapper(TableSchema tableSchema, Params params) {
        super(tableSchema, params);
        this.targetType = resolveTargetType(params);
    }

    /**
     * Casts each selected value of the current row into the output slot with the same index.
     *
     * @throws AkIllegalDataException if the configured target type is not DOUBLE, LONG, INT or FLOAT
     *     (detected on the first call)
     */
    @Override // com.alibaba.alink.common.mapper.Mapper
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult) throws Exception {
        if (this.castFunc == null) {
            this.castFunc = createCastFunc(this.targetType);
        }
        for (int i = 0; i < slicedSelectedSample.length(); i++) {
            slicedResult.set(i, this.castFunc.apply(slicedSelectedSample.get(i)));
        }
    }

    /**
     * Declares the IO schema: the selected columns are read, written to the output columns (the
     * selected columns themselves when none are given), all typed as the target type; reserved
     * columns default to every input column.
     */
    @Override // com.alibaba.alink.common.mapper.Mapper
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(TableSchema tableSchema, Params params) {
        // Read everything from the argument: this method is invoked from the superclass constructor,
        // before any field of this class (including targetType) has been assigned.
        String[] selectedCols = (String[]) params.get(NumericalTypeCastParams.SELECTED_COLS);

        String[] outputCols = (String[]) params.get(NumericalTypeCastParams.OUTPUT_COLS);
        if (outputCols == null || outputCols.length == 0) {
            outputCols = selectedCols;
        }

        TypeInformation<?>[] outputTypes = new TypeInformation<?>[outputCols.length];
        Arrays.fill(outputTypes, resolveTargetType(params));

        String[] reservedCols = (String[]) params.get(NumericalTypeCastParams.RESERVED_COLS);
        if (reservedCols == null || reservedCols.length == 0) {
            reservedCols = tableSchema.getFieldNames();
        }

        return Tuple4.of(selectedCols, outputCols, outputTypes, reservedCols);
    }

    /** Maps the {@code TARGET_TYPE} parameter to its Flink {@link TypeInformation}. */
    private static TypeInformation<?> resolveTargetType(Params params) {
        HasTargetType.TargetType type = (HasTargetType.TargetType) params.get(NumericalTypeCastParams.TARGET_TYPE);
        return FlinkTypeConverter.getFlinkType(type.toString());
    }

    /**
     * Builds the per-value conversion for {@code targetType}.
     *
     * @throws AkIllegalDataException if {@code targetType} is not DOUBLE, LONG, INT or FLOAT
     */
    private static Function<Object, Object> createCastFunc(TypeInformation<?> targetType) {
        if (targetType.equals(Types.DOUBLE)) {
            return nullSafeCast(Double::parseDouble, Number::doubleValue);
        }
        if (targetType.equals(Types.LONG)) {
            return nullSafeCast(Long::parseLong, Number::longValue);
        }
        if (targetType.equals(Types.INT)) {
            return nullSafeCast(Integer::parseInt, Number::intValue);
        }
        if (targetType.equals(Types.FLOAT)) {
            return nullSafeCast(Float::parseFloat, Number::floatValue);
        }
        throw new AkIllegalDataException("Unsupported target type:" + targetType.getTypeClass().getCanonicalName());
    }

    /**
     * Wraps a string parser and a number converter into one conversion that passes {@code null}
     * through, parses trimmed strings, and converts every other value as a {@link Number}.
     */
    private static Function<Object, Object> nullSafeCast(Function<String, Object> fromString,
                                                         Function<Number, Object> fromNumber) {
        return value -> {
            if (value == null) {
                return null;
            }
            if (value instanceof String) {
                // Double/Float.parseXxx already ignore surrounding whitespace, Long/Integer.parseXxx
                // reject it; trimming makes every target type accept the same textual input.
                return fromString.apply(((String) value).trim());
            }
            return fromNumber.apply((Number) value);
        };
    }
}
