package com.alibaba.alink.operator.common.image;

import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.io.filesystem.FilePath;
import com.alibaba.alink.common.linalg.tensor.FloatTensor;
import com.alibaba.alink.common.linalg.tensor.TensorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.image.HasImageType;
import com.alibaba.alink.params.image.WriteTensorToImageParams;
import java.awt.image.BufferedImage;
import java.awt.image.WritableRaster;
import java.io.IOException;
import java.io.OutputStream;
import javax.imageio.ImageIO;
import javax.imageio.stream.ImageOutputStream;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.core.fs.FileSystem;
import org.apache.flink.core.fs.Path;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/**
 * Mapper that writes the tensor found in the configured tensor column to an image file.
 *
 * <p>The destination of each image is the value of the relative-file-path column resolved against
 * the configured root file path. Both input columns are copied unchanged into the result.
 */
public class WriteTensorToImageMapper extends Mapper {
    /** Root location (path plus file system) that every relative image path is resolved against. */
    private final FilePath rootPath;

    /** String form of the {@code IMAGE_TYPE} parameter; fallback format when a file name has no extension. */
    private final String imageType;

    /** Static helpers that turn a {@link FloatTensor} into a {@link BufferedImage} and store it. */
    public static class FloatTensorToImage {

        /**
         * Chooses the image format for a target file.
         *
         * <p>The extension of the file name is used when present, otherwise {@code str}.
         *
         * @param filePath target file; only the last path component is inspected
         * @param str fallback format used when the file name has no extension
         * @return the chosen suffix, which is also used as the ImageIO format name
         * @throws IllegalArgumentException if ImageIO has no writer registered for the suffix
         */
        private static String getFormat(FilePath filePath, String str) {
            String fileName = filePath.getPath().getName();
            int dot = fileName.lastIndexOf(".");
            // "dot > 0" (not ">= 0"): a name such as ".png" counts as having no extension.
            String suffix = dot > 0 ? fileName.substring(dot + 1) : str;
            if (!ImageIO.getImageWritersBySuffix(suffix).hasNext()) {
                throw new IllegalArgumentException(
                    String.format("Could not write the image with suffix: %s", suffix));
            }
            return suffix;
        }

        /**
         * Converts the tensor to an image and writes it to {@code filePath}.
         *
         * @param floatTensor tensor to render, see {@link #writeToImage(FloatTensor)}
         * @param filePath destination file
         * @param str fallback image format when the file name has no extension
         * @throws IOException if the file cannot be created or the image cannot be encoded
         */
        public static void write(FloatTensor floatTensor, FilePath filePath, String str) throws IOException {
            BufferedImage image = writeToImage(floatTensor);
            writeToFile(image, filePath, str);
        }

        /**
         * Encodes an image and stores it at {@code filePath}, overwriting any existing file.
         *
         * <p>Parent directories are created first.
         *
         * @param bufferedImage image to encode
         * @param filePath destination file
         * @param str fallback image format when the file name has no extension
         * @throws IllegalArgumentException if no ImageIO writer exists for the resolved suffix
         * @throws IOException if the file system fails, or no writer can encode this particular image
         *     in the resolved format
         */
        public static void writeToFile(BufferedImage bufferedImage, FilePath filePath, String str) throws IOException {
            // Validate the format before touching the file system, so an unsupported suffix does not
            // leave an empty or truncated file behind.
            String format = getFormat(filePath, str);
            filePath.getFileSystem().mkdirs(filePath.getPath().getParent());
            // The file stream is managed as its own resource: closing the JDK's cache-backed
            // ImageOutputStream implementations does not close the destination stream. Resources are
            // closed in reverse order, so the image stream is flushed before the file stream closes.
            try (OutputStream fileOut =
                     filePath.getFileSystem().create(filePath.getPath(), FileSystem.WriteMode.OVERWRITE);
                 ImageOutputStream imageOut = ImageIO.createImageOutputStream(fileOut)) {
                // ImageIO.write reports "no suitable writer for this image" by returning false.
                if (!ImageIO.write(bufferedImage, format, imageOut)) {
                    throw new IOException(String.format(
                        "No image writer could encode the image as \"%s\": %s", format, filePath.getPath()));
                }
            }
        }

        /**
         * Builds a {@link BufferedImage} from a rank-3 tensor indexed as (row, column, band).
         *
         * <p>Three bands produce a {@link BufferedImage#TYPE_INT_RGB} image and four bands a
         * {@link BufferedImage#TYPE_INT_ARGB} image.
         *
         * <p>NOTE(review): {@code floatTensor.scale(255.0f)} is called and its return value is
         * discarded, so it presumably rescales the argument in place (e.g. from [0, 1] to [0, 255]);
         * confirm against {@code FloatTensor}, since that would also modify the caller's tensor.
         *
         * @param floatTensor tensor whose shape is (height, width, bands)
         * @return the image holding the tensor's samples
         * @throws IllegalArgumentException if the number of bands is neither 3 nor 4
         */
        public static BufferedImage writeToImage(FloatTensor floatTensor) {
            long[] shape = floatTensor.shape();
            AkPreconditions.checkArgument(shape != null && shape.length == 3);
            int height = (int) shape[0];
            int width = (int) shape[1];
            int numBands = (int) shape[2];

            int bufferedImageType;
            switch (numBands) {
                case 3:
                    bufferedImageType = BufferedImage.TYPE_INT_RGB;
                    break;
                case 4:
                    bufferedImageType = BufferedImage.TYPE_INT_ARGB;
                    break;
                default:
                    throw new IllegalArgumentException(
                        String.format("Unsupported tensor to image. num bands: %d", numBands));
            }

            floatTensor.scale(255.0f);

            BufferedImage image = new BufferedImage(width, height, bufferedImageType);
            WritableRaster raster = image.getRaster();
            int minX = raster.getMinX();
            int minY = raster.getMinY();
            for (int row = 0; row < height; row++) {
                for (int col = 0; col < width; col++) {
                    for (int band = 0; band < numBands; band++) {
                        // Raster coordinates are (x, y) = (column, row).
                        raster.setSample(minX + col, minY + row, band, floatTensor.getFloat(row, col, band));
                    }
                }
            }
            return image;
        }
    }

    /**
     * Creates the mapper and reads the root path and image type from {@code params}.
     *
     * @param tableSchema schema of the input data
     * @param params parameters; {@code ROOT_FILE_PATH} and {@code IMAGE_TYPE} are read here
     */
    public WriteTensorToImageMapper(TableSchema tableSchema, Params params) {
        super(tableSchema, params);
        String serializedRoot = (String) params.get(WriteTensorToImageParams.ROOT_FILE_PATH);
        this.rootPath = FilePath.deserialize(serializedRoot);
        HasImageType.ImageType configuredType =
            (HasImageType.ImageType) params.get(WriteTensorToImageParams.IMAGE_TYPE);
        this.imageType = configuredType.toString();
    }

    /**
     * Writes the tensor at selected index 0 to the file named by selected index 1 (relative to the
     * root path), then copies both selected values into the result.
     *
     * <p>NOTE(review): the relative path comes from the data and is not normalised or checked here,
     * so a value containing {@code ..} may resolve outside the root path; confirm whether callers
     * guarantee trusted paths.
     *
     * @param slicedSelectedSample selected input values: tensor at 0, relative file path at 1
     * @param slicedResult output row; receives the same two values at indices 0 and 1
     * @throws Exception if the image cannot be written
     */
    @Override
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult) throws Exception {
        FloatTensor tensor = FloatTensor.of(TensorUtil.getTensor(slicedSelectedSample.get(0)));
        String relativePath = (String) slicedSelectedSample.get(1);
        FilePath target = new FilePath(
            new Path(this.rootPath.getPath(), relativePath),
            this.rootPath.getFileSystem());

        FloatTensorToImage.write(tensor, target, this.imageType);

        slicedResult.set(0, slicedSelectedSample.get(0));
        slicedResult.set(1, slicedSelectedSample.get(1));
    }

    /**
     * Declares the columns this mapper reads and produces.
     *
     * @param tableSchema schema of the input data, used to look up the two column types
     * @param params parameters supplying {@code TENSOR_COL}, {@code RELATIVE_FILE_PATH_COL} and
     *     {@code RESERVED_COLS}
     * @return a tuple of: the two column names (tensor, relative path), the same two names again,
     *     their types taken from the input schema, and the {@code RESERVED_COLS} parameter value
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(TableSchema tableSchema, Params params) {
        String tensorCol = (String) params.get(WriteTensorToImageParams.TENSOR_COL);
        String relativePathCol = (String) params.get(WriteTensorToImageParams.RELATIVE_FILE_PATH_COL);

        // Two distinct arrays on purpose: the tuple must not share one mutable array between slots.
        String[] firstCols = new String[] {tensorCol, relativePathCol};
        String[] secondCols = new String[] {tensorCol, relativePathCol};
        TypeInformation<?>[] colTypes = new TypeInformation<?>[] {
            TableUtil.findColTypeWithAssertAndHint(tableSchema, tensorCol),
            TableUtil.findColTypeWithAssertAndHint(tableSchema, relativePathCol)
        };

        return Tuple4.of(firstCols, secondCols, colTypes, params.get(WriteTensorToImageParams.RESERVED_COLS));
    }
}
