package com.ss.bytenn;

import com.bytedance.frameworks.apm.trace.MethodCollector;
import com.ss.bytenn.Tensor;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;

/* loaded from: classes4.dex */
/**
 * Java front end for the native ByteNN engine provided by the {@code bytenn} and
 * {@code bytennwrapper} shared libraries, both of which are loaded when this class is initialized.
 *
 * <p>Each instance holds a single native engine handle, where {@code 0} means "no engine". The
 * public methods forward to the {@code native*} JNI entry points declared below. They translate the
 * integer status returned by the native side into an {@link a} value via {@link #toStatus(int)},
 * which maps unknown codes to {@link a#ERR_UNEXPECTED} instead of throwing. Access to the handle is
 * not synchronized.
 */
public class API {
    public static String TAG = "BYTENN.API";

    /** Status table used to decode native return codes by position; private and never mutated. */
    private static final a[] STATUS_CODES = a.values();

    /** Handle returned by {@code nativeCreateEngineInstance()}; {@code 0} when there is no engine. */
    private long engineHandle;

    /** Minimum buffer capacity accepted by {@link #SaveModel}; recorded by a successful {@link #InitEngine}. */
    private int modelBufferSize;

    /**
     * Status codes reported by engine calls.
     *
     * <p>Native integer results are decoded by position, so the declaration order here presumably
     * mirrors the native library's status enumeration. Keep the two in sync.
     */
    /* loaded from: classes4.dex */
    public enum a {
        NO_ERROR,
        ERR_MEMORY_ALLOC,
        NOT_IMPLEMENTED,
        ERR_UNEXPECTED,
        ERR_DATANOMATCH,
        INPUT_DATA_ERROR,
        CALL_BACK_STOP,
        BACKEND_FALLBACK,
        NULL_POINTER,
        INVALID_POINTER,
        INVALID_MODEL,
        INFER_SIZE_ERROR,
        NOT_SUPPORT,
        DESTROYED_ERROR,
        WRONG_LICENSE,
        BROKEN_MODEL,
        EARLY_STOP;

        /**
         * Returns a fresh array of this enum's constants in declaration order.
         *
         * <p>{@code valueOf(String)} is not redeclared here. An enum may not redeclare the
         * compiler-generated {@code valueOf(String)}, and the implicit one has the same signature.
         */
        /* renamed from: values, reason: to resolve conflict with enum method */
        public static a[] valuesCustom() {
            MethodCollector.i(3320);
            // values() already returns a new array, so no extra clone is needed.
            a[] constants = values();
            MethodCollector.o(3320);
            return constants;
        }
    }

    /** Optimizer kinds; passed to {@code nativeSetOptimizer} by {@code ordinal()}. */
    /* loaded from: classes4.dex */
    public enum b {
        SGD,
        RMSProp,
        ADAM;

        /**
         * Returns a fresh array of this enum's constants in declaration order.
         *
         * <p>{@code valueOf(String)} is the compiler-generated method; redeclaring it would not compile.
         */
        /* renamed from: values, reason: to resolve conflict with enum method */
        public static b[] valuesCustom() {
            MethodCollector.i(3318);
            b[] constants = values();
            MethodCollector.o(3318);
            return constants;
        }
    }

    static {
        System.loadLibrary("bytenn");
        System.loadLibrary("bytennwrapper");
    }

    // JNI entry points. Every one except nativeCreateEngineInstance takes the engine handle first.

    public static native long nativeCreateEngineInstance();

    public static native int nativeDestroyEngine(long j);

    public static native Tensor nativeGetEngineGradient(long j, String str);

    public static native Tensor[] nativeGetEngineOutputs(long j);

    public static native Tensor nativeGetEngineWeight(long j, String str);

    public static native Tensor[] nativeGetInputConfig(long j);

    public static native int nativeInference(long j);

    public static native int nativeInitEngine(long j, ByteNNConfig byteNNConfig);

    public static native int nativeReInferShape(long j, int i, int i2);

    public static native int nativeReInferShapeTensors(long j, Tensor[] tensorArr);

    public static native int nativeReleaseEngine(long j);

    public static native int nativeSaveModel(long j, ByteBuffer byteBuffer);

    public static native int nativeSetEngineInputs(long j, Tensor[] tensorArr);

    public static native int nativeSetEngineLabel(long j, Tensor[] tensorArr);

    public static native int nativeSetEngineLossLayer(long j, LossInfo[] lossInfoArr);

    public static native int nativeSetEngineWeight(long j, Tensor tensor);

    public static native int nativeSetOptimizer(long j, int i, float f, String[] strArr);

    public static native float nativeStep(long j);

    /**
     * Decodes a native status code.
     *
     * @param code integer returned by a {@code native*} call
     * @return the matching {@link a} constant, or {@link a#ERR_UNEXPECTED} if {@code code} is out of range
     */
    private static a toStatus(int code) {
        // Guard the index: a bad native code used to surface as ArrayIndexOutOfBoundsException.
        return (code >= 0 && code < STATUS_CODES.length) ? STATUS_CODES[code] : a.ERR_UNEXPECTED;
    }

    /** Converts a tensor list into the array form the JNI layer expects. */
    private static Tensor[] toTensorArray(ArrayList<Tensor> tensors) {
        return tensors.toArray(new Tensor[0]);
    }

    /**
     * Copies shape, data, data format, data type and name from a native-returned tensor into a
     * caller-supplied one.
     *
     * @param source tensor returned by the native layer; may be {@code null}
     * @param target caller's tensor to fill; must be non-null
     * @return {@link a#NO_ERROR}, or {@link a#ERR_UNEXPECTED} if {@code source} is {@code null}
     */
    private static a copyTensorInto(Tensor source, Tensor target) throws Exception {
        if (source == null) {
            return a.ERR_UNEXPECTED;
        }
        target.setBatch(source.getBatch());
        target.setChannel(source.getChannel());
        target.setHeight(source.getHeight());
        target.setWidth(source.getWidth());
        target.setData(source.getData());
        // Format and type are passed across JNI as ordinals and decoded back to the Java enums here.
        target.setDataFormat(Tensor.DataFormat.valuesCustom()[source.getOrdinalOfDataFormat()]);
        target.setDataType(Tensor.DataType.valuesCustom()[source.getOrdinalOfDataType()]);
        target.setName(source.getName());
        return a.NO_ERROR;
    }

    /**
     * Creates a native engine instance and stores its handle.
     *
     * <p>NOTE(review): any existing handle is overwritten without being destroyed. Confirm callers
     * always call {@link #DestroyEngine()} first, otherwise the previous native engine may leak.
     *
     * @return {@link a#ERR_MEMORY_ALLOC} if the native side returned a {@code 0} handle, else {@link a#NO_ERROR}
     */
    public a CreateEngine() {
        long handle = nativeCreateEngineInstance();
        this.engineHandle = handle;
        return handle == 0 ? a.ERR_MEMORY_ALLOC : a.NO_ERROR;
    }

    /**
     * Destroys the native engine and clears the stored handle.
     *
     * @return the decoded native status
     */
    public a DestroyEngine() {
        a status = toStatus(nativeDestroyEngine(this.engineHandle));
        // Cleared regardless of the native result so a stale handle is never reused.
        this.engineHandle = 0L;
        return status;
    }

    /**
     * Fetches the gradient tensor named {@code str} and copies it into {@code tensor}.
     *
     * @return {@link a#INPUT_DATA_ERROR} if {@code tensor} is {@code null}; {@link a#NULL_POINTER} if no
     *     engine exists; {@link a#ERR_UNEXPECTED} if the native side returned nothing; else {@link a#NO_ERROR}
     */
    public a GetEngineGradient(String str, Tensor tensor) throws Exception {
        if (tensor == null) {
            return a.INPUT_DATA_ERROR;
        }
        if (this.engineHandle == 0) {
            return a.NULL_POINTER;
        }
        return copyTensorInto(nativeGetEngineGradient(this.engineHandle, str), tensor);
    }

    /**
     * Appends the engine's input tensor configuration to {@code arrayList}.
     *
     * <p>Each tensor's data buffer is switched to the platform's native byte order before it is
     * returned. This presumably matches how the native side lays out the memory; confirm.
     *
     * @return {@link a#NULL_POINTER} if no engine exists; {@link a#INPUT_DATA_ERROR} if {@code arrayList}
     *     is {@code null}; {@link a#ERR_UNEXPECTED} if the native side returned no tensors; else {@link a#NO_ERROR}
     */
    public a GetEngineInputConfig(ArrayList<Tensor> arrayList) {
        long handle = this.engineHandle;
        if (handle == 0) {
            return a.NULL_POINTER;
        }
        if (arrayList == null) {
            return a.INPUT_DATA_ERROR;
        }
        Tensor[] configs = nativeGetInputConfig(handle);
        if (configs == null || configs.length == 0) {
            return a.ERR_UNEXPECTED;
        }
        for (Tensor tensor : configs) {
            ByteBuffer data = tensor.getData();
            if (data != null) {
                data.order(ByteOrder.nativeOrder());
            }
            arrayList.add(tensor);
        }
        return a.NO_ERROR;
    }

    /**
     * Appends the engine's output tensors to {@code arrayList}.
     *
     * <p>NOTE(review): unlike {@link #GetEngineInputConfig}, the byte order of the output data buffers
     * is not adjusted here. Confirm whether callers set it themselves.
     *
     * @return {@link a#NULL_POINTER} if no engine exists; {@link a#INPUT_DATA_ERROR} if {@code arrayList}
     *     is {@code null}; {@link a#ERR_UNEXPECTED} if the native side returned no tensors; else {@link a#NO_ERROR}
     */
    public a GetEngineOutputs(ArrayList<Tensor> arrayList) {
        long handle = this.engineHandle;
        if (handle == 0) {
            return a.NULL_POINTER;
        }
        if (arrayList == null) {
            return a.INPUT_DATA_ERROR;
        }
        Tensor[] outputs = nativeGetEngineOutputs(handle);
        if (outputs == null || outputs.length == 0) {
            return a.ERR_UNEXPECTED;
        }
        for (Tensor tensor : outputs) {
            arrayList.add(tensor);
        }
        return a.NO_ERROR;
    }

    /**
     * Fetches the weight tensor named {@code str} and copies it into {@code tensor}.
     *
     * @return {@link a#INPUT_DATA_ERROR} if {@code tensor} is {@code null}; {@link a#NULL_POINTER} if no
     *     engine exists; {@link a#ERR_UNEXPECTED} if the native side returned nothing; else {@link a#NO_ERROR}
     */
    public a GetEngineWeight(String str, Tensor tensor) throws Exception {
        if (tensor == null) {
            return a.INPUT_DATA_ERROR;
        }
        if (this.engineHandle == 0) {
            return a.NULL_POINTER;
        }
        return copyTensorInto(nativeGetEngineWeight(this.engineHandle, str), tensor);
    }

    /** Runs inference on the current engine and returns the decoded native status. */
    public a Inference() {
        return toStatus(nativeInference(this.engineHandle));
    }

    /**
     * Initializes the engine from {@code byteNNConfig}.
     *
     * <p>On success, records the config's model buffer size for later {@link #SaveModel} checks.
     *
     * @return {@link a#INPUT_DATA_ERROR} if the config is {@code null}, else the decoded native status
     */
    public a InitEngine(ByteNNConfig byteNNConfig) {
        if (byteNNConfig == null) {
            return a.INPUT_DATA_ERROR;
        }
        a status = toStatus(nativeInitEngine(this.engineHandle, byteNNConfig));
        if (status != a.NO_ERROR) {
            return status;
        }
        this.modelBufferSize = byteNNConfig.getModelBufferSize();
        return a.NO_ERROR;
    }

    /**
     * Asks the engine to re-infer shapes for a new two-dimensional input size.
     *
     * <p>The native call receives the arguments in reverse order ({@code i2} first, then {@code i}).
     * NOTE(review): confirm this swap is intentional, e.g. a width/height vs. height/width convention.
     */
    public a ReInferShape(int i, int i2) {
        return toStatus(nativeReInferShape(this.engineHandle, i2, i));
    }

    /**
     * Re-infers shapes from the given input tensors.
     *
     * @return {@link a#INPUT_DATA_ERROR} if {@code arrayList} is {@code null}, else the decoded native status
     */
    public a ReInferShapeTensors(ArrayList<Tensor> arrayList) {
        if (arrayList == null) {
            return a.INPUT_DATA_ERROR;
        }
        return toStatus(nativeReInferShapeTensors(this.engineHandle, toTensorArray(arrayList)));
    }

    /**
     * Calls {@code nativeReleaseEngine} and returns its decoded status.
     *
     * <p>Unlike {@link #DestroyEngine()}, this keeps the stored handle.
     */
    public a ReleaseEngine() {
        return toStatus(nativeReleaseEngine(this.engineHandle));
    }

    /**
     * Serializes the model into {@code byteBuffer}.
     *
     * @return {@link a#INPUT_DATA_ERROR} unless the buffer is non-null, direct, and has capacity of at
     *     least the size recorded by {@link #InitEngine}; otherwise the decoded native status
     */
    public a SaveModel(ByteBuffer byteBuffer) {
        if (byteBuffer == null || !byteBuffer.isDirect() || byteBuffer.capacity() < this.modelBufferSize) {
            return a.INPUT_DATA_ERROR;
        }
        return toStatus(nativeSaveModel(this.engineHandle, byteBuffer));
    }

    /**
     * Sets the engine's input tensors.
     *
     * @return {@link a#INPUT_DATA_ERROR} if {@code arrayList} is {@code null}, else the decoded native status
     */
    public a SetEngineInputs(ArrayList<Tensor> arrayList) {
        if (arrayList == null) {
            return a.INPUT_DATA_ERROR;
        }
        return toStatus(nativeSetEngineInputs(this.engineHandle, toTensorArray(arrayList)));
    }

    /**
     * Sets the engine's label tensors.
     *
     * @return {@link a#INPUT_DATA_ERROR} if {@code arrayList} is {@code null}, else the decoded native status
     */
    public a SetEngineLabel(ArrayList<Tensor> arrayList) {
        if (arrayList == null) {
            return a.INPUT_DATA_ERROR;
        }
        return toStatus(nativeSetEngineLabel(this.engineHandle, toTensorArray(arrayList)));
    }

    /**
     * Overwrites an engine weight with {@code tensor}.
     *
     * @return {@link a#INPUT_DATA_ERROR} unless the tensor and its data are non-null and the data
     *     buffer is direct; otherwise the decoded native status
     */
    public a SetEngineWeight(Tensor tensor) {
        if (tensor == null) {
            return a.INPUT_DATA_ERROR;
        }
        ByteBuffer data = tensor.getData();
        if (data == null || !data.isDirect()) {
            return a.INPUT_DATA_ERROR;
        }
        return toStatus(nativeSetEngineWeight(this.engineHandle, tensor));
    }

    /**
     * Configures the engine's loss layers.
     *
     * @return {@link a#INPUT_DATA_ERROR} if {@code arrayList} is {@code null}, else the decoded native status
     */
    public a SetLossLayer(ArrayList<LossInfo> arrayList) {
        if (arrayList == null) {
            return a.INPUT_DATA_ERROR;
        }
        return toStatus(nativeSetEngineLossLayer(this.engineHandle, arrayList.toArray(new LossInfo[0])));
    }

    /**
     * Configures the optimizer.
     *
     * @param bVar optimizer kind, passed to the native side by ordinal
     * @param f float parameter forwarded unchanged (presumably the learning rate; confirm)
     * @param arrayList string arguments forwarded to the native side
     * @return {@link a#INPUT_DATA_ERROR} if {@code bVar} or {@code arrayList} is {@code null}, else the
     *     decoded native status
     */
    public a SetOptimizer(b bVar, float f, ArrayList<String> arrayList) {
        if (bVar == null || arrayList == null) {
            return a.INPUT_DATA_ERROR;
        }
        return toStatus(nativeSetOptimizer(this.engineHandle, bVar.ordinal(), f, arrayList.toArray(new String[0])));
    }

    /**
     * Runs {@code nativeStep} on the current engine and returns its float result.
     *
     * <p>NOTE(review): the result is presumably the loss of one training step; confirm. No handle
     * check is done here because a float return cannot carry an {@link a} status.
     */
    public float Step() {
        return nativeStep(this.engineHandle);
    }
}
