Source: webai.mjs

import cv from './opencv.mjs'
import ort from 'https://cdn.jsdelivr.net/npm/onnxruntime-web@dev/dist/ort.all.min.mjs'


/**
 * Model Class.
 * @class
 */
class Model {
    /**
     * create a base model.
     * @param modelURL model URL
     * @param sessionOption onnxruntime session options
     * @param init init function
     * @param preProcess preprocess function
     * @param postProcess postprocess function
     * @returns base model object
     */
    constructor(
        modelURL,
        sessionOption,
        init,
        preProcess,
        postProcess
    ) {
        this.promises = Promise.all([
            ort.InferenceSession.create(modelURL, sessionOption)
                .then(session => this.session = session)
        ])
        if (typeof init != 'undefined') {
            init(this)
        }
        if (typeof preProcess != 'undefined') {
            this.preProcess = preProcess
        }
        if (typeof postProcess != 'undefined') {
            this.postProcess = postProcess
        }
    }

    /**
     * base model infer function.
     * @param args model infer paramters
     * @returns model infer results
     */
    async infer(...args) {
        await this.promises;
        console.time('Infer');

        console.time('Infer.Preprocess');
        let feeds = this.preProcess(...args);

        console.timeEnd('Infer.Preprocess');

        console.time('Infer.Run');

        let resultsTensors = await this.session.run(feeds);

        console.timeEnd('Infer.Run');

        console.time('Infer.Postprocess');

        let results = this.postProcess(resultsTensors, ...args);

        console.timeEnd('Infer.Postprocess');

        console.timeEnd('Infer');

        return results
    }
}

/**
 * CV Model Class.
 * @class
 * @extends Model
 */
class CV extends Model {
    /**
     * create a base CV model.
     * @param modelURL model URL
     * @param inferConfig model infer config URL
     * @param sessionOption onnxruntime session options
     * @param getFeeds get infer session feeds function
     * @param postProcess postprocess function
     * @returns base CV model object
     */
    constructor(
        modelURL,
        inferConfig,
        sessionOption,
        getFeeds,
        postProcess
    ) {
        super(modelURL, sessionOption, undefined, undefined, postProcess)
        this.loadConfigs(inferConfig)
        if (typeof getFeeds != 'undefined') {
            this.getFeeds = getFeeds
        }
    }

    /**
     * load infer configs
     * @param inferConfig model infer config URL
     */
    loadConfigs(inferConfig) {
        let inferConfigs = JSON.parse(WebAI.loadText(inferConfig));
        let preProcess = inferConfigs.Preprocess;
        this.isPermute = false
        this.isCrop = false
        this.isResize = false
        for (let i = 0; i < preProcess.length; i++) {
            let OP = preProcess[i]
            if (OP.type == 'Decode') {
                this.mode = OP.mode
                if (!(this.mode == 'RGB' || this.mode == 'BGR')) {
                    throw `Not support ${OP.mode} mode.`
                }
            }
            else if (OP.type == 'Resize') {
                this.isResize = true;
                this.interp = OP.interp;
                this.keepRatio = OP.keep_ratio;
                this.targetSize = OP.target_size;
                this.limitMax = OP.limit_max;
            }
            else if (OP.type == 'Normalize') {
                this.isScale = OP.is_scale
                if (this.isScale) {
                    this.scale = new cv.Scalar(255.0, 255.0, 255.0)
                }
                this.mean = new cv.Scalar(...OP.mean)
                this.std = new cv.Scalar(...OP.std)
            }
            else if (OP.type == 'Crop') {
                this.isCrop = true;
                this.cropSize = OP.crop_size
            }
            else if (OP.type == 'Permute') {
                this.isPermute = true
            }
            else {
                throw `Not support ${OP.type} OP.`
            }
        }

        if (inferConfigs.hasOwnProperty('label_list')) {
            this.labelList = inferConfigs.label_list;
            this.colorMap = WebAI.getColorMap(this.labelList);
        }

        console.info('model info: ', {
            mode: this.mode,
            isResize: this.isResize,
            interp: this.interp,
            keepRatio: this.keepRatio,
            targetSize: this.targetSize,
            isScale: this.isScale,
            limitMax: this.limitMax,
            mean: this.mean,
            std: this.std,
            isCrop: this.isCrop,
            cropSize: this.cropSize,
            isPermute: this.isPermute,
            labelList: this.labelList
        })
    }

    /**
     * model preprocess function. 
     * @param args preprocess args
     * @returns session infer feeds
     */
    preProcess(...args) {
        let [imgRGBA, height, width] = args.slice(0, 3)
        let imgResize, imScaleX, imScaleY
        if (this.isResize) {
            [imgResize, imScaleX, imScaleY] = WebAI.resize(imgRGBA, height, width, this.targetSize, this.keepRatio, this.limitMax, this.interp);
        }
        else {
            imgResize = imgRGBA.clone();
        }

        let imgCvt;
        if (this.isCrop) {
            let imgCrop = WebAI.crop(imgResize, this.cropSize);
            if (this.mode == 'RGB') {
                imgCvt = WebAI.rgba2rgb(imgCrop);
            }
            else if (this.mode == 'BGR') {
                imgCvt = WebAI.rgba2bgr(imgCrop);
            }
            imgCrop.delete();
        }
        else {
            if (this.mode == 'RGB') {
                imgCvt = WebAI.rgba2rgb(imgResize);
            }
            else if (this.mode == 'BGR') {
                imgCvt = WebAI.rgba2bgr(imgResize);
            }
            imgResize.delete();
        }
        let imgNorm = WebAI.normalize(imgCvt, this.scale, this.mean, this.std, this.isScale);

        let imgTensor;
        let [h, w] = [imgNorm.rows, imgNorm.cols];
        if (this.isPermute) {
            imgTensor = new ort.Tensor('float32', WebAI.permute(imgNorm), [1, 3, h, w]);
        }
        else {
            imgTensor = new ort.Tensor('float32', imgNorm.data32F, [1, h, w, 3]);
            imgNorm.delete()
        }

        return this.getFeeds(imgTensor, imScaleX, imScaleY)
    }
}


/**
 * Detection Model Class.
 * @class
 * @extends CV
 */
class Det extends CV {
    /**
     * get session infer feeds.
     * @param imgTensor image tensor
     * @param imScaleX image scale factor of x axis
     * @param imScaleY image scale factor of y axis
     * @returns session infer feeds
     */
    getFeeds(imgTensor, imScaleX, imScaleY) {
        let inputNames = this.session.inputNames;
        let _feeds = {
            im_shape: new ort.Tensor('float32', Float32Array.from(imgTensor.dims.slice(2, 4)), [1, 2]),
            image: imgTensor,
            scale_factor: new ort.Tensor('float32', Float32Array.from([imScaleY, imScaleX]), [1, 2])
        }
        let feeds = {}
        inputNames.forEach(name => {
            feeds[name] = _feeds[name]
        })
        return feeds
    }

    /**
     * detection postprocess.
     * @param resultsTensors result tensors
     * @param args postprocess args
     * @returns bboxes of the detection
     */
    postProcess(
        resultsTensors,
        ...args
    ) {
        let [height, width, drawThreshold] = args.slice(1, 4)
        let bboxesTensor = Object.values(resultsTensors)[0];
        let bboxes = [];
        let bboxesNum = bboxesTensor.dims[0];
        let bboxesDatas = bboxesTensor.data;

        for (let i = 0; i < bboxesNum; i++) {
            let classID = bboxesDatas[i * 6 + 0];
            let score = bboxesDatas[i * 6 + 1];

            let x1 = Math.max(0, Math.round(bboxesDatas[i * 6 + 2]));
            let y1 = Math.max(0, Math.round(bboxesDatas[i * 6 + 3]));
            let x2 = Math.min(width, Math.round(bboxesDatas[i * 6 + 4]));
            let y2 = Math.min(height, Math.round(bboxesDatas[i * 6 + 5]));

            let label = this.labelList[classID];
            let color = this.colorMap[classID].color;
            if (score > drawThreshold) {
                let bbox = {
                    label: label,
                    color: color,
                    score: score,
                    x1: x1,
                    y1: y1,
                    x2: x2,
                    y2: y2
                };

                bboxes.push(bbox);
            }
        }
        return bboxes
    }

    /**
     * detection infer.
     * @param imgRGBA RGBA image
     * @param drawThreshold threshold of detection
     * @returns bboxes of the detection
     */
    infer(
        imgRGBA,
        drawThreshold = 0.5
    ) {
        return super.infer(imgRGBA, imgRGBA.rows, imgRGBA.cols, drawThreshold)
    }
}


/**
 * Classification Model Class.
 * @class
 * @extends CV
 */
class Cls extends CV {
    /**
     * get the feeds of the infer session.
     * @param imgTensor image tensor
     * @returns feeds of the infer session
     */
    getFeeds(imgTensor) {
        return { x: imgTensor }
    }

    /**
     * classification postprocess.
     * @param resultsTensors result tensors
     * @param args postprocess args
     * @returns probs of the classification
     */
    postProcess(resultsTensors, ...args) {
        let topK = args[3];
        let probsTensor = Object.values(resultsTensors)[0];
        let data = probsTensor.data
        let probs = []

        for (let i = 0; i < this.labelList.length; i++) {
            probs.push({
                label: this.labelList[i],
                prob: data[i]
            })
        }
        if (topK > 0) {
            return probs.sort((a, b) => b.prob - a.prob).slice(0, topK)
        }
        else {
            return probs.sort((a, b) => b.prob - a.prob)
        }
    }

    /**
     * classification infer.
     * @param imgRGBA RGBA image
     * @param topK probs top K
     * @returns probs of the classification
     */
    infer(
        imgRGBA,
        topK = 5
    ) {
        return super.infer(imgRGBA, imgRGBA.rows, imgRGBA.cols, topK)
    }
}

/**
 * Segmentation Model Class.
 * @class
 * @extends CV
 */
class Seg extends CV {
    /**
     * get the feeds of the infer session.
     * @param imgTensor image tensor
     * @returns feeds of the infer session
     */
    getFeeds(imgTensor) {
        return { x: imgTensor }
    }

    /**
     * segmentation postprocess.
     * @param resultsTensors result tensors
     * @returns segmentation results
     */
    postProcess(resultsTensors) {
        let segTensor = Object.values(resultsTensors)[0];
        let data = segTensor.data
        let [N, C, H, W] = segTensor.dims;
        let numPixel = H * W

        let pixelArrs = []
        for (let i = 0; i < C; i++) {
            pixelArrs.push(data.slice(i * numPixel, (i + 1) * numPixel))
        }
        let colorRGBA = [];
        let gray = [];
        let tmp, index;
        for (let i = 0; i < numPixel; i++) {
            let tmp = []
            for (let j = 0; j < C; j++) {
                tmp.push(pixelArrs[j][i])
            }
            index = WebAI.argmax(tmp)
            gray.push(index)
            colorRGBA.push(...this.colorMap[index].color)
        }

        return {
            gray: cv.matFromArray(H, W, cv.CV_8UC1, gray),
            colorRGBA: cv.matFromArray(H, W, cv.CV_8UC4, colorRGBA),
            colorMap: this.colorMap,
            delete: function () {
                if (!this.gray.isDeleted()) {
                    this.gray.delete()
                }
                if (!this.colorRGBA.isDeleted()) {
                    this.colorRGBA.delete()
                }
            }
        }
    }

    /**
     * segmentation infer.
     * @param imgRGBA RGBA image
     * @returns segmentation results
     */
    infer(
        imgRGBA
    ) {
        return super.infer(imgRGBA, imgRGBA.rows, imgRGBA.cols)
    }
}

/**
 * Namespace of WebAI.
 * @namespace
 */
const WebAI = {
    /**
     * get the index of the max value of the array.
     * @param arr array
     * @returns the index of the max value of the array
     */
    argmax(arr) {
        let max = Math.max.apply(null, arr);
        let index = arr.findIndex(
            function (value) {
                if (value == max) {
                    return true
                }
                else {
                    return false
                }
            }
        )
        return index
    },

    /**
     * get image scale.
     * @param height image height
     * @param width image width
     * @param targetSize target size [h, w]
     * @param keepRatio is keep the ratio of image size
     * @param limitMax is limit max size of image
     * @returns [scale factor of x axis, , scale factor of y axis]
     */
    getIMScale(height, width, targetSize, keepRatio, limitMax) {
        let imScaleX, imScaleY;
        if (keepRatio) {
            let imSizeMin = Math.min(height, width);
            let targetSizeMin = Math.min(targetSize[0], targetSize[1]);
            let imScale = targetSizeMin / imSizeMin;

            if (limitMax) {
                let imSizeMax = Math.max(height, width);
                let targetSizeMax = Math.max(targetSize[0], targetSize[1]);
                if (Math.round(imScale * imSizeMax) > targetSizeMax) {
                    imScale = targetSizeMax / imSizeMax;
                }
            }

            imScaleX = imScale;
            imScaleY = imScale;
        }
        else {
            imScaleY = targetSize[0] / height;
            imScaleX = targetSize[1] / width;
        }
        return [imScaleX, imScaleY]
    },

    /**
     * RGBA -> RGB image.
     * @param imgRGBA RGBA image
     * @returns RGB image
     */
    rgba2rgb(imgRGBA) {
        let imgRGB = new cv.Mat();
        cv.cvtColor(imgRGBA, imgRGB, cv.COLOR_RGBA2RGB);
        return imgRGB
    },

    /**
     * RGBA -> BGR image.
     * @param imgRGBA RGBA image
     * @returns BGR image
     */
    rgba2bgr(imgRGBA) {
        let imgBGR = new cv.Mat();
        cv.cvtColor(imgRGBA, imgBGR, cv.COLOR_RGBA2BGR);
        return imgBGR
    },

    /**
     * image resize.
     * @param img image mat
     * @param height image height
     * @param width image width
     * @param targetSize target size [h, w]
     * @param keepRatio is keep the ratio of image size
     * @param limitMax is limit max size of image
     * @param interp interpolation method
     * @returns [image resized, scale factor of x axis, , scale factor of y axis]
     */
    resize(img, height, width, targetSize, keepRatio, limitMax, interp) {
        let [imScaleX, imScaleY] = WebAI.getIMScale(height, width, targetSize, keepRatio, limitMax);
        let imgResize = new cv.Mat();
        cv.resize(img, imgResize, new cv.Size(0, 0), imScaleX, imScaleY, interp);
        return [imgResize, imScaleX, imScaleY]
    },

    /**
     * image center crop.
     * @param img image mat
     * @param cropSize crop size [h, w]
     * @returns cropped image
     */
    crop(img, cropSize) {
        let imgCrop = img.roi(
            new cv.Rect(
                Math.ceil((img.cols - cropSize[1]) / 2),
                Math.ceil((img.rows - cropSize[0]) / 2),
                cropSize[1],
                cropSize[0]
            )
        )
        img.delete()
        return imgCrop
    },

    /**
     * image normalize.
     * @param img image mat
     * @param scale normalize scale
     * @param mean normalize mean
     * @param std normalize std
     * @param isScale is scale the image
     * @returns normalized image
     */
    normalize(img, scale, mean, std, isScale) {
        img.convertTo(img, cv.CV_32F);

        if (isScale) {
            let imgScale = new cv.Mat(img.rows, img.cols, cv.CV_32FC3, scale);
            cv.divide(img, imgScale, img);
            imgScale.delete();
        }

        let imgMean = new cv.Mat(img.rows, img.cols, cv.CV_32FC3, mean);
        cv.subtract(img, imgMean, img);
        imgMean.delete();

        let imgStd = new cv.Mat(img.rows, img.cols, cv.CV_32FC3, std);
        cv.divide(img, imgStd, img);
        imgStd.delete();

        return img
    },

    /**
     * permute hwc -> chw.
     * @param img image mat
     * @returns image data
     */
    permute(img) {
        let rgbPlanes = new cv.MatVector();
        cv.split(img, rgbPlanes);
        let R = rgbPlanes.get(0);
        let G = rgbPlanes.get(1);
        let B = rgbPlanes.get(2);
        rgbPlanes.delete();

        let imgData = new Float32Array(R.data32F.length * 3)
        imgData.set(R.data32F, 0)
        imgData.set(G.data32F, R.data32F.length)
        imgData.set(B.data32F, R.data32F.length * 2)
        R.delete();
        G.delete();
        B.delete();
        img.delete();
        return imgData
    },

    /**
     * load text content.
     * @param textURL text URL
     * @returns content of the text
     */
    loadText(textURL) {
        let xhr = new XMLHttpRequest();
        xhr.open('get', textURL, false);
        xhr.send(null);
        return xhr.responseText
    },

    /**
     * get color map of label list.
     * @param labelList label list
     * @returns color map of label list
     */
    getColorMap(labelList) {
        let classNum = labelList.length
        let colorMap = []
        let colorSlice = Math.ceil((256 * 256 * 256) / classNum)
        for (let i = 0; i < classNum; i++) {
            let color = (colorSlice * i).toString(16)
            let colorRGBA = []
            for (let j = 0; j < 6; j += 2) {
                let tmp = color.slice(j, j + 2)
                if (tmp == '') {
                    colorRGBA.push(0);
                }
                else {
                    colorRGBA.push(parseInt("0x" + tmp));
                }
            }
            colorRGBA.push(255)
            colorMap.push({
                label: labelList[i],
                color: colorRGBA
            })
        }
        return colorMap
    },

    /**
     * draw bboxes onto the image.
     * @param img image mat
     * @param bboxes bboxes of detection
     * @param withLabel draw with label 
     * @param withScore draw with score
     * @param thickness line thickness
     * @param lineType line type
     * @param fontFace font face
     * @param fontScale font scale
     * @returns drawed image
     */
    drawBBoxes(
        img,
        bboxes,
        withLabel = true,
        withScore = true,
        thickness = 2.0,
        lineType = 8,
        fontFace = 0,
        fontScale = 0.7
    ) {
        let imgShow = img.clone()
        for (let i = 0; i < bboxes.length; i++) {
            let bbox = bboxes[i];
            cv.rectangle(imgShow, new cv.Point(bbox.x1, bbox.y1), new cv.Point(bbox.x2, bbox.y2), bbox.color, thickness, lineType);
            if (withLabel && withScore) {
                cv.putText(imgShow, `${bbox.label} ${(bbox.score * 100).toFixed(2)}%`, new cv.Point(bbox.x1, bbox.y2), fontFace, fontScale, bbox.color, thickness, lineType);
            }
            else if (withLabel) {
                cv.putText(imgShow, `${bbox.label}`, new cv.Point(bbox.x1, bbox.y2), fontFace, fontScale, bbox.color, thickness, lineType);
            }
            else if (withScore) {
                cv.putText(imgShow, `${(bbox.score * 100).toFixed(2)}%`, new cv.Point(bbox.x1, bbox.y2), fontFace, fontScale, bbox.color, thickness, lineType);
            }
        }
        return imgShow
    },

    Model: Model,
    CV: CV,
    Det: Det,
    Cls: Cls,
    Seg: Seg,
}

export { WebAI as default, WebAI, cv, ort }