import { fetchWithRetry } from "@/lib/fetchWithRetry";
import { noop } from "ts-essentials";
import WebPXMux from "webpxmux";

let xMux: ReturnType<typeof WebPXMux> | null = null;

/**
 * Lazily creates the shared webpxmux instance backed by "/webpmux/webpxmux.wasm".
 * Once the instance exists, further calls do nothing.
 */
const setXMuxOnce = () => {
    if (xMux !== null) {
        return;
    }
    xMux = WebPXMux("/webpmux/webpxmux.wasm");
};

// Pending/settled promise for the one-time WASM runtime initialization; null until it starts.
let initializingPromise: Promise<null> | null = null;
// Starts as a no-op; code that creates initializingPromise may swap in that promise's resolver.
let initializingResolve = noop;

/**
 * Unpacks pixels stored one per 32-bit word as 0xRRGGBBAA (red in the most significant
 * byte, alpha in the least significant) into a flat RGBA byte array usable by `ImageData`.
 *
 * @param uint32Array Packed pixels, one per element.
 * @returns A new Uint8ClampedArray of length `4 * uint32Array.length`, ordered R, G, B, A per pixel.
 */
function convertUint32ArrayToUint8ClampedArray(uint32Array: Uint32Array) {
    const bytes = new Uint8ClampedArray(uint32Array.length * 4);
    // Writing each word big-endian emits its most significant byte (red) first,
    // regardless of the platform's native byte order.
    const view = new DataView(bytes.buffer);
    uint32Array.forEach((pixel, index) => {
        view.setUint32(index * 4, pixel, false);
    });
    return bytes;
}

/** One decoded frame of a WebP image, rendered to an ImageBitmap. */
export interface WebpBitmapFrame {
    /** Sum of the durations of all preceding frames, multiplied by 1e3. */
    timestamp: number;
    /** This frame's duration as reported by webpxmux, multiplied by 1e3. */
    duration: number;
    /** The frame's pixels. */
    bitmap: ImageBitmap;
}

/**
 * Fetches a (possibly animated) WebP image and decodes every frame into an `ImageBitmap`.
 *
 * The shared webpxmux instance is created lazily and its WASM runtime is initialized once;
 * concurrent callers wait on the same initialization. If initialization fails, the error is
 * propagated to every caller waiting on it and the next call retries, instead of leaving a
 * promise that never settles.
 *
 * All times are webpxmux frame durations multiplied by 1e3
 * (NOTE(review): presumably milliseconds -> microseconds; confirm against callers).
 *
 * @param imageUrl URL of the WebP image, fetched via `fetchWithRetry`.
 * @returns `duration`: the sum of all frame durations (x1e3); `bitmapFrames`: one entry per
 *   frame, in decode order.
 * @throws If runtime initialization, fetching, decoding, or bitmap creation fails. Bitmaps
 *   created before a failure are closed.
 */
export async function decodeWebp(
    imageUrl: string,
): Promise<{ duration: number; bitmapFrames: WebpBitmapFrame[] }> {
    setXMuxOnce();
    const mux = xMux;
    if (mux === null) {
        throw new Error("webpxmux instance was not created");
    }

    if (initializingPromise === null) {
        // Cache the initialization so concurrent callers share it. On failure, drop the cache
        // so a later call can retry; rethrowing rejects the shared promise so no caller hangs.
        initializingPromise = mux.waitRuntime().then(
            () => null,
            (err: unknown) => {
                initializingPromise = null;
                throw err;
            },
        );
    }
    await initializingPromise;

    const response = await fetchWithRetry(imageUrl);
    const buffer = await response.arrayBuffer();
    const decoded = await mux.decodeFrames(new Uint8Array(buffer));

    const bitmapFrames: WebpBitmapFrame[] = [];
    let elapsed = 0;
    try {
        for (const frame of decoded.frames) {
            // Assumes each frame's rgba covers the full decoded.width x decoded.height canvas
            // — TODO confirm against webpxmux.
            const pixels = convertUint32ArrayToUint8ClampedArray(frame.rgba);
            const bitmap = await createImageBitmap(new ImageData(pixels, decoded.width, decoded.height));
            bitmapFrames.push({
                timestamp: elapsed * 1e3,
                duration: frame.duration * 1e3,
                bitmap,
            });
            elapsed += frame.duration;
        }
    } catch (err) {
        // Release bitmaps created so far instead of leaving them to the garbage collector.
        for (const { bitmap } of bitmapFrames) {
            bitmap.close();
        }
        throw err;
    }

    return { duration: elapsed * 1e3, bitmapFrames };
}