package com.android.tools.mlkit;

import com.android.testutils.TestUtils;
import com.android.tools.mlkit.TensorInfo;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.util.Arrays;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:com/android/tools/mlkit/ModelInfoTest.class */
/**
 * Tests for {@link ModelInfo}.
 *
 * <p>Verifies that {@link ModelInfo#buildFrom} extracts tensor names, content types, normalization
 * and quantization parameters, label-file information and tensor groups from the bundled
 * {@code .tflite} test models. Also verifies that a {@link ModelInfo} is equal to itself after a
 * {@code save} / {@code ModelInfo(DataInputStream)} round trip.
 */
public class ModelInfoTest {
    /** Tolerance used when comparing floating-point parameters read from the models. */
    private static final double DELTA = 1.0E-5d;

    /** Workspace-relative directory containing the test models. */
    private static final String MODEL_DIR = "prebuilts/tools/common/mlkit/testData/models/";

    @Test
    public void testImageClassificationQuantModelMetadataExtractedCorrectly() throws TfliteModelException, IOException {
        ModelInfo modelInfo = ModelInfo.buildFrom(extractByteBufferFromModel(MODEL_DIR + "mobilenet_quant_metadata.tflite"));
        Assert.assertEquals(1L, modelInfo.getInputs().size());
        Assert.assertEquals(1L, modelInfo.getOutputs().size());
        Assert.assertTrue(modelInfo.isMetadataExisted());

        TensorInfo input = (TensorInfo) modelInfo.getInputs().get(0);
        Assert.assertEquals(TensorInfo.ContentType.IMAGE, input.getContentType());
        Assert.assertEquals("image", input.getIdentifierName());
        Assert.assertEquals(new TensorInfo.ImageProperties(TensorInfo.ImageProperties.ColorSpaceType.RGB), input.getImageProperties());
        TensorInfo.NormalizationParams inputNormalization = input.getNormalizationParams();
        Assert.assertEquals(127.5d, inputNormalization.getMean()[0], DELTA);
        Assert.assertEquals(127.5d, inputNormalization.getStd()[0], DELTA);
        Assert.assertEquals(0.0d, inputNormalization.getMin()[0], DELTA);
        Assert.assertEquals(255.0d, inputNormalization.getMax()[0], DELTA);
        TensorInfo.QuantizationParams inputQuantization = input.getQuantizationParams();
        Assert.assertEquals(128.0d, inputQuantization.getZeroPoint(), DELTA);
        Assert.assertEquals(0.0078125d, inputQuantization.getScale(), DELTA);

        TensorInfo output = (TensorInfo) modelInfo.getOutputs().get(0);
        Assert.assertEquals("probability", output.getIdentifierName());
        Assert.assertEquals(TensorInfo.FileType.TENSOR_AXIS_LABELS, output.getFileType());
        TensorInfo.QuantizationParams outputQuantization = output.getQuantizationParams();
        Assert.assertEquals(0.0d, outputQuantization.getZeroPoint(), DELTA);
        Assert.assertEquals(0.00390625d, outputQuantization.getScale(), DELTA);
    }

    @Test
    public void testImageClassificationQuantModelExtractedCorrectly() throws TfliteModelException, IOException {
        ModelInfo modelInfo = ModelInfo.buildFrom(extractByteBufferFromModel(MODEL_DIR + "mobilenet_quant_no_metadata.tflite"));
        Assert.assertEquals(1L, modelInfo.getInputs().size());
        Assert.assertEquals(1L, modelInfo.getOutputs().size());
        Assert.assertFalse(modelInfo.isMetadataExisted());
        // Without metadata, identifiers fall back to generated names.
        Assert.assertEquals("inputFeature0", ((TensorInfo) modelInfo.getInputs().get(0)).getIdentifierName());
        Assert.assertEquals("outputFeature0", ((TensorInfo) modelInfo.getOutputs().get(0)).getIdentifierName());
    }

    @Test
    public void testImageClassificationModelWithMultipleLabelFiles() throws TfliteModelException, IOException {
        ModelInfo modelInfo = ModelInfo.buildFrom(extractByteBufferFromModel(MODEL_DIR + "cropnet_classifier_multi_labels.tflite"));
        Assert.assertTrue(modelInfo.isMetadataExisted());
        Assert.assertEquals(1L, modelInfo.getOutputs().size());
        Assert.assertEquals("probability-labels-en.txt", ((TensorInfo) modelInfo.getOutputs().get(0)).getFileName());
    }

    @Test
    public void testObjectDetectionModelMetadataExtractedCorrectly() throws TfliteModelException, IOException {
        ModelInfo modelInfo = ModelInfo.buildFrom(extractByteBufferFromModel(MODEL_DIR + "ssd_mobilenet_odt_metadata.tflite"));
        Assert.assertEquals(1L, modelInfo.getInputs().size());
        Assert.assertEquals(4L, modelInfo.getOutputs().size());
        Assert.assertTrue(modelInfo.isMetadataExisted());

        TensorInfo input = (TensorInfo) modelInfo.getInputs().get(0);
        Assert.assertEquals(TensorInfo.ContentType.IMAGE, input.getContentType());
        Assert.assertEquals("image", input.getIdentifierName());
        Assert.assertEquals(new TensorInfo.ImageProperties(TensorInfo.ImageProperties.ColorSpaceType.RGB), input.getImageProperties());
        TensorInfo.NormalizationParams inputNormalization = input.getNormalizationParams();
        Assert.assertEquals(127.5d, inputNormalization.getMean()[0], DELTA);
        Assert.assertEquals(127.5d, inputNormalization.getStd()[0], DELTA);
        Assert.assertEquals(0.0d, inputNormalization.getMin()[0], DELTA);
        Assert.assertEquals(255.0d, inputNormalization.getMax()[0], DELTA);
        TensorInfo.QuantizationParams inputQuantization = input.getQuantizationParams();
        Assert.assertEquals(128.0d, inputQuantization.getZeroPoint(), DELTA);
        Assert.assertEquals(0.0078125d, inputQuantization.getScale(), DELTA);

        TensorInfo locations = (TensorInfo) modelInfo.getOutputs().get(0);
        Assert.assertEquals("locations", locations.getName());
        Assert.assertEquals(TensorInfo.ContentType.BOUNDING_BOX, locations.getContentType());
        TensorInfo.BoundingBoxProperties boundingBox = locations.getBoundingBoxProperties();
        Assert.assertEquals(TensorInfo.BoundingBoxProperties.Type.BOUNDARIES, boundingBox.type);
        Assert.assertEquals(TensorInfo.BoundingBoxProperties.CoordinateType.RATIO, boundingBox.coordinateType);
        Assert.assertArrayEquals(new int[]{1, 0, 3, 2}, boundingBox.index);
        TensorInfo.ContentRange contentRange = locations.getContentRange();
        Assert.assertEquals(2L, contentRange.min);
        Assert.assertEquals(2L, contentRange.max);

        TensorInfo classes = (TensorInfo) modelInfo.getOutputs().get(1);
        Assert.assertEquals("classes", classes.getName());
        Assert.assertEquals(TensorInfo.FileType.TENSOR_VALUE_LABELS, classes.getFileType());
        Assert.assertEquals("labelmap.txt", classes.getFileName());

        Assert.assertEquals("scores", ((TensorInfo) modelInfo.getOutputs().get(2)).getName());
        Assert.assertEquals("number of detections", ((TensorInfo) modelInfo.getOutputs().get(3)).getName());
    }

    @Test
    public void testV2ObjectDetectionModelGroupMetadataExtractedCorrectly() throws TfliteModelException, IOException {
        ModelInfo modelInfo = ModelInfo.buildFrom(extractByteBufferFromModel(MODEL_DIR + "ssd_mobilenet_odt_metadata_v1.2.tflite"));
        Assert.assertEquals(1L, modelInfo.getInputs().size());
        Assert.assertEquals(4L, modelInfo.getOutputs().size());
        Assert.assertTrue(modelInfo.isMetadataExisted());
        Assert.assertEquals(0L, modelInfo.getInputTensorGroups().size());
        Assert.assertEquals(1L, modelInfo.getOutputTensorGroups().size());

        TensorGroupInfo group = (TensorGroupInfo) modelInfo.getOutputTensorGroups().get(0);
        Assert.assertEquals("detection result", group.getName());
        Assert.assertEquals(Arrays.asList("locations", "classes", "scores"), group.getTensorNames());
    }

    @Test
    public void testModelInfoSerialization() throws TfliteModelException, IOException {
        testModelInfoSerialization(MODEL_DIR + "mobilenet_quant_metadata.tflite");
        testModelInfoSerialization(MODEL_DIR + "ssd_mobilenet_odt_metadata.tflite");
        testModelInfoSerialization(MODEL_DIR + "ssd_mobilenet_odt_metadata_v1.2.tflite");
    }

    /**
     * Builds a {@link ModelInfo} from the given model, writes it with {@code save}, reads it back
     * through the {@code ModelInfo(DataInputStream)} constructor and asserts the two are equal.
     *
     * @param modelPath workspace-relative path of the {@code .tflite} model
     * @throws TfliteModelException if the model cannot be parsed
     * @throws IOException if reading the model or (de)serializing the {@link ModelInfo} fails
     */
    private static void testModelInfoSerialization(String modelPath) throws TfliteModelException, IOException {
        ModelInfo original = ModelInfo.buildFrom(extractByteBufferFromModel(modelPath));

        ByteArrayOutputStream serialized = new ByteArrayOutputStream();
        // Closing the DataOutputStream flushes it, so every byte written by save() is visible
        // to toByteArray() below.
        try (DataOutputStream out = new DataOutputStream(serialized)) {
            original.save(out);
        }

        ModelInfo restored = new ModelInfo(new DataInputStream(new ByteArrayInputStream(serialized.toByteArray())));
        Assert.assertEquals(original, restored);
    }

    /**
     * Reads the complete contents of a workspace-relative file into a heap {@link ByteBuffer}.
     *
     * @param modelPath workspace-relative path of the model file
     * @return a buffer wrapping the file's bytes
     * @throws IOException if the file cannot be opened or read, or is too large for a byte array
     */
    private static ByteBuffer extractByteBufferFromModel(String modelPath) throws IOException {
        try (RandomAccessFile file = new RandomAccessFile(TestUtils.resolveWorkspacePath(modelPath).toFile(), "r")) {
            long length = file.length();
            // Guard the narrowing cast: a file larger than an array can hold would otherwise
            // produce a wrong (possibly negative) buffer size.
            if (length > Integer.MAX_VALUE) {
                throw new IOException("Model file too large to load: " + modelPath + " (" + length + " bytes)");
            }
            byte[] bytes = new byte[(int) length];
            file.readFully(bytes);
            return ByteBuffer.wrap(bytes);
        }
    }
}
