
KotlinJS, ONNX and Deep Learning in the browser

13 min read
Hampus Londögård

One day I had the crazy idea to try two non-mainstream things out at the same time. On top of that I figured I'd combine them in the same project, imagine that!

Preview of the final result, running model inference in the browser using KotlinJS, ONNX & fritz2 (GIF: model inference in browser)

Quick Kotlin JS

KotlinJS resembles TypeScript (TS) in the sense that it's typed and transpiles into JavaScript (JS) at the end of the day. The final JS code runs directly in the browser or through Node.js.

What makes KotlinJS stand out? In my opinion it picks up where TS leaves off. By adding (almost) all of the Kotlin ecosystem we get a really superb toolbox out of the box, which is more than just types. Some of the awesome perks are coroutines and collections.
As someone who has done a lot of backend development in Scala, with some Java, it feels like home because of the familiar appearance and interaction.
With its sweet syntax and superb typing I strongly prefer KotlinJS, even if TS is closer to JS, which makes its transpiled code easier to reason about.

Quick ONNX

info

Open Neural Network Exchange (ONNX) is an open format created by Facebook, Microsoft & others, and is part of Linux Foundation AI.

ONNX is an open polyglot format, meaning that you can run neural networks from multiple programming languages. This in turn promotes innovation and collaboration, especially because you can run your state-of-the-art model almost everywhere, including C# and Java.
ONNX is an impressive feat that lets companies reduce their inference time, in some cases by orders of magnitude. As a cherry on top it reduces code complexity, because models can be deployed directly in the original backend.

Recently ONNX added a new runtime, ONNX Runtime Web, which enables ONNX models to run directly inside the browser. Simply take your PyTorch/TensorFlow model, convert it to ONNX and then run! Incredible! 🎉
ONNX Runtime Web uses WebGL for GPU execution and WASM with SIMD for CPU execution.
Simple edge deployment is here!

The Set Up

The setup is simple:

  • Kotlin JS project
  • fritz2 as web framework
  • onnxruntime-web as deep learning inference tool

For this demo I could've used raw HTML elements in the Kotlin JS code, but it's more fun to use something enjoyable, so I chose fritz2, which I introduce below 👇.

Introducing fritz2

Introducing fritz2, a small but impressive framework.
Because of its size you can understand the full idea and implementation, which is something you cannot say about React. Through simple DSLs and superb usage of Flow<T> you end up with a simple yet powerful model that maps perfectly to my own mind.
In my opinion fritz2 feels less magical while staying very powerful and simple. Everything works with full typing and no hacks. Cherry on top? No virtual DOM!

fritz2 has an extra components library which you can install additionally. This library contains simple components that make your development much faster, with things like File input, Data Tables and much more!

Personally I even built my own wedding website using fritz2, and it ended up pretty great! (Image: fritz2 gui)

My personal wedding site created in fritz2

ONNX typing in KotlinJS

Using dukat (included by default in Kotlin > 1.6 or perhaps earlier) it's possible to generate external types/bindings for any TypeScript project.
Guess what, ONNX Runtime Web is written entirely in TypeScript - awesome!

Unfortunately ONNX has a really weird structure which I'd call non-standard, and this ends up not working great with dukat generation...
Luckily it is easy to write your own bindings. Hold your breath for now, I'll share them later in this post, but for now let's say that it's like a .d.ts-file.

The Implementation

We need to create our project. I usually do it from scratch, but if you want an easy way to set up the MPP project for fritz2 you can use their template project. Make sure to include the fritz2 component library, as it'll be used in the implementation.
Please note that the focus will be ONNX, so I'll save some fritz2 details for another post.

Basic UI

Getting the skeleton UI up
In the "main" file of the js folder (which holds Kotlin, not JS code 😉), you'll need to set up a file input and an image element.

fun main() {
    val imgSrc = RootStore("")
    render {
        val srcImg = img(id = "img-from") {
            src(imgSrc.data)
        }

        // The file component lives inside render, next to the image it feeds.
        file {
            accept("image/*")
            button { text("Single select") }
        }
            .map { file -> "data:${file.type};base64,${file.content}" } handledBy imgSrc.update
    }
}

Breaking down what's done

  1. A RootStore is an abstraction on top of a (Mutable)StateFlow, which is a Flow with a state.
    1. In simple terms a Flow is a collection of asynchronously computed values, just like Sequence and List are collections of synchronously computed values.
  2. A Store is a reactive component that contains our app's state; it can communicate bidirectionally with the DOM/GUI.
    1. We update imgSrc through the file component, whenever the file is updated.
    2. <img> listens for changes from imgSrc, hence it's updated as imgSrc is updated.
    3. All in all we get typed, no-magic dynamic updates in our GUI. This is something I love, compared to React and Svelte where it seems more magical.
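
To make the store idea concrete, here is a minimal sketch in plain Kotlin. It is not how fritz2 implements RootStore (which builds on StateFlow); `TinyStore`, `subscribe` and `update` are hypothetical names used only to illustrate "one piece of state, many listeners".

```kotlin
// Hypothetical, simplified stand-in for a store: it holds one value and
// notifies every subscriber whenever that value is replaced.
class TinyStore<T>(initial: T) {
    private val subscribers = mutableListOf<(T) -> Unit>()

    var current: T = initial
        private set

    // Register a listener and immediately hand it the current value,
    // similar to how a state-holding flow replays its latest state.
    fun subscribe(listener: (T) -> Unit) {
        subscribers += listener
        listener(current)
    }

    // Store the new value and push it to all listeners.
    fun update(value: T) {
        current = value
        subscribers.forEach { it(value) }
    }
}
```

In the UI above, the <img> element plays the role of a subscriber and the file component plays the role of the caller of `update`.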

The connector between file and imgSrc is dirty. I hoped to load the b64 content directly into a Uint8ClampedArray for optimal performance, but because the b64-string actually contains PNG/JPEG headers and other data, the performance gain is not worth the loss of simplicity. Hence I turn the b64-string (data:image/png;base64,<content>) into an image and then extract ImageData from it: annoying but clean.
The detail that <content> in the b64-string is not just the pixel data haunted me for a long time... I couldn't figure out why my arrays had the wrong dimensions! 😅
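
A quick way to convince yourself that the base64 payload is an encoded file rather than raw pixels is to decode it and look at the first bytes. The sketch below is plain Kotlin using the standard library's `kotlin.io.encoding.Base64`; the helper name `startsWithPngSignature` is made up for this example, and it only checks for the 8-byte signature that every PNG file starts with.

```kotlin
import kotlin.io.encoding.Base64
import kotlin.io.encoding.ExperimentalEncodingApi

// The fixed 8-byte signature at the start of every PNG file.
private val pngSignature =
    byteArrayOf(0x89.toByte(), 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A)

// Hypothetical helper: decodes the payload of a data URL and checks
// whether the decoded bytes begin with the PNG signature.
@OptIn(ExperimentalEncodingApi::class)
fun startsWithPngSignature(dataUrl: String): Boolean {
    val payload = dataUrl.substringAfter("base64,")
    val bytes = Base64.decode(payload)
    return bytes.size >= pngSignature.size &&
        bytes.copyOfRange(0, pngSignature.size).contentEquals(pngSignature)
}
```

If the check succeeds, the payload is a compressed PNG container, which is why its length has no simple relation to width * height * 4.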

The next step: transfer image from this component to another, while allowing a transformation (neural network inference) in-between.

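// img and canvas are fritz2 tags; the underlying DOM elements live in domNode.
fun loadImgToCanvas(img: Image, canvas: Canvas, context: CanvasRenderingContext2D) {
    if (img.domNode.src.isNotEmpty()) {
        // Resize the canvas to the image's intrinsic size before drawing.
        canvas.domNode.width = img.domNode.naturalWidth
        canvas.domNode.height = img.domNode.naturalHeight
        context.drawImage(img.domNode, 0.0, 0.0)
    }
}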

fun main() {
    val imgSrc = RootStore("")

    render {
        val srcImg = img(id = "img-from") {
            src(imgSrc.data)
        }
        val targetCanvas = canvas(id = "img-to") { }
        val imgContext = targetCanvas.domNode.getContext("2d") as CanvasRenderingContext2D

        // file { ... } block, same as before

        // Copy the image onto the canvas every time a new image has finished loading.
        srcImg.domNode.onload = { loadImgToCanvas(srcImg, targetCanvas, imgContext) }
    }
}

Whenever the srcImg onload event fires we call loadImgToCanvas, which draws the image on our canvas.
Why did I choose not to use a second <img>? Because we later need ImageData, and this is the way to get the minimum number of data transitions, trust me 😉.

Let's start adding bindings for ONNX!

Binding ONNX and webruntime

Binding TS/JS is as simple as writing a .d.ts-file in TS. You define the component to bind and declare the types, e.g. function name, input and output. Simple as that!

@file:JsModule("onnxruntime-web")   // npm-package
@file:JsNonModule

import kotlin.js.Promise

external abstract class InferenceSession {
    // FeedsType / ReturnType are separately defined the same way as InferenceSession & run.
    fun run(feeds: FeedsType): Promise<ReturnType>
}

Moving on, we'll add a function that extracts ImageData's Uint8ClampedArray from an img element using a canvas element with its CanvasRenderingContext2D (lots of JS/web words, the most I'll have in a sentence, phew! 😅)

fun imgToUInt8ClampedArray(img: HTMLImageElement, ctx: CanvasRenderingContext2D): Uint8ClampedArray {
    val canvas = ctx.canvas
    canvas.width = img.naturalWidth
    canvas.height = img.naturalHeight
    ctx.drawImage(img, 0.0, 0.0)

    // Extract the raw RGBA bytes from ImageData.
    return ctx.getImageData(0.0, 0.0, img.naturalWidth.toDouble(), img.naturalHeight.toDouble()).data
}

The Uint8ClampedArray has to be transformed into the Float32Array that the model expects.
Sounds easy? Think again!
Because JS is not a data science language it's not surprising that the data is "incorrectly" ordered. The model expects the data to be laid out as [3, height, width], where 3 is the number of channels (RGB in our case), but in JS it's the other way around: ImageData stores the channels interleaved, pixel by pixel. On top of the different ordering JS has a fourth channel, namely transparency (alpha). With that knowledge we can transform the array.

fun uInt8ClampedToFloat32Array(data: Uint8ClampedArray): Float32Array {
    val floats = Float32Array(data.length / 4 * 3)
    // Start offsets of the R, G and B planes in the output array.
    val rgb = listOf(0, data.length / 4, data.length / 4 * 2)

    for (i in 0 until data.length step 4) {
        floats[rgb[0] + i / 4] = data[i + 0] / 255f
        floats[rgb[1] + i / 4] = data[i + 1] / 255f
        floats[rgb[2] + i / 4] = data[i + 2] / 255f // Skip i+3 as that's ALPHA
    }

    return floats
}
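
To see the reordering on actual numbers, here is the same idea as a plain-Kotlin sketch that runs outside the browser. It uses IntArray and FloatArray instead of the JS typed arrays, and `rgbaToPlanarRgb` is a name invented for this example.

```kotlin
// Converts interleaved RGBA values (0..255, the layout of ImageData.data)
// into planar, channel-first floats in 0..1: all R values, then all G, then all B.
fun rgbaToPlanarRgb(rgba: IntArray): FloatArray {
    require(rgba.size % 4 == 0) { "RGBA data must have 4 values per pixel" }
    val pixels = rgba.size / 4
    val planar = FloatArray(pixels * 3)
    for (p in 0 until pixels) {
        planar[p] = rgba[p * 4] / 255f                   // R plane
        planar[pixels + p] = rgba[p * 4 + 1] / 255f      // G plane
        planar[2 * pixels + p] = rgba[p * 4 + 2] / 255f  // B plane; alpha at p * 4 + 3 is dropped
    }
    return planar
}
```

For a two-pixel image with one red and one green pixel, `intArrayOf(255, 0, 0, 255, 0, 255, 0, 128)`, the result is `[1.0, 0.0, 0.0, 1.0, 0.0, 0.0]`: the R plane (1, 0), then the G plane (0, 1), then the B plane (0, 0).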

As ONNX expects a Tensor we need to wrap the Float32Array in a Tensor and then put that into a FeedsType, which is a plain JS object keyed by input name. Luckily that's very easy once our binding is done.

fun tensorToInput(tensor: Tensor, inputName: String = "input"): FeedsType {
    val input: dynamic = object {} // To hack JS Objects
    input[inputName] = tensor

    return input.unsafeCast<FeedsType>()
}

// Dims are batch, channels, height, width: rows of `width` pixels are contiguous in the data.
val tensor = Tensor("float32", floats, arrayOf(1, 3, srcImg.domNode.naturalHeight, srcImg.domNode.naturalWidth))
val input = tensorToInput(tensor)

...and it's time to run the model! 🥳

val ir = InferenceSession.create("./dce2.onnx").await()
val out = ir.run(input).await()

val outTensor = out["output"] as Tensor
val outData = outTensor.data as Float32Array

The output then needs the reverse transform applied to be viewable in the browser. That is: interleave the channels again, add the fourth (alpha) channel and cast to int.

// Run this on the output data before converting to Uint8Clamped.
for (i in 0 until outData.length) {
    outData[i] = min(outData[i], 1f) * 255f // `min` to not go above 255
}

fun float32ToUInt8Clamped(data: Float32Array): Uint8ClampedArray {
    // Start offsets of the R, G and B planes in the input array.
    val rgb = arrayOf(0, data.length / 3, data.length / 3 * 2)
    val intOut = Uint8ClampedArray(data.length / 3 * 4)

    for (i in 0 until intOut.length / 4) {
        intOut.asDynamic()[i * 4 + 0] = data[rgb[0] + i].toInt()
        intOut.asDynamic()[i * 4 + 1] = data[rgb[1] + i].toInt()
        intOut.asDynamic()[i * 4 + 2] = data[rgb[2] + i].toInt()
        intOut.asDynamic()[i * 4 + 3] = 255 // fully opaque
    }

    return intOut
}
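
The reverse direction can also be tried in plain Kotlin. This sketch differs from the browser code in two ways that are my own choices, not the post's: it rounds instead of truncating, and it clamps explicitly because an IntArray does not clamp on write the way a Uint8ClampedArray does. `planarRgbToRgba` is a made-up name.

```kotlin
import kotlin.math.roundToInt

// Converts planar, channel-first floats (nominally 0..1) back into
// interleaved RGBA values in 0..255 with a fully opaque alpha channel.
fun planarRgbToRgba(planar: FloatArray): IntArray {
    require(planar.size % 3 == 0) { "Planar RGB data must have 3 values per pixel" }
    val pixels = planar.size / 3
    val rgba = IntArray(pixels * 4)
    for (p in 0 until pixels) {
        for (c in 0 until 3) {
            // Scale to 0..255 and clamp values the model pushed out of range.
            val scaled = (planar[c * pixels + p] * 255f).roundToInt()
            rgba[p * 4 + c] = scaled.coerceIn(0, 255)
        }
        rgba[p * 4 + 3] = 255 // alpha
    }
    return rgba
}
```

Note that `roundToInt` throws on NaN, so a model that can emit NaN would need an extra guard.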

As you might notice we call asDynamic() a lot. This is because of a current bug in Kotlin JS where the typed array takes a signed Byte when it should take an unsigned byte.
See the current issue at youtrack.jetbrains.com.
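
The signed/unsigned mismatch is easy to reproduce in plain Kotlin. The snippet below is only an illustration and is not taken from the linked issue; `unsignedValue` is a name invented here.

```kotlin
// Reinterprets a signed Kotlin Byte (-128..127) as the unsigned value (0..255)
// that the same bit pattern represents.
fun unsignedValue(b: Byte): Int = b.toInt() and 0xFF
```

`200.toByte()` is -56 as a signed Byte, and `unsignedValue(200.toByte())` gives 200 back. A clamped array stores negative numbers as 0, so writing bright pixel values through the Byte-typed setter would turn them black, whereas going through asDynamic() hands the plain number to JS.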

We finally have all the pieces, how about gluing it all together? 😄

Solving the puzzle

The model I want to use has a dynamic input/output size (the image dimensions). I need to recreate the session for every image, or else it throws, because ONNX expects new runs to use the last used shape. That expectation does not hold here, as the images have different sizes.
One solution would be to preprocess the image to always be the same size, but I prefer to return the image in its original dimensions for this use case.

fun imgToUInt8ClampedArray(img: HTMLImageElement, ctx: CanvasRenderingContext2D): Uint8ClampedArray {
    /** same code as previously */
}

fun float32ToUInt8Clamped(data: Float32Array): Uint8ClampedArray {
    /** same code as previously */
}

fun tensorToInput(tensor: Tensor, inputName: String = "input"): FeedsType {
    /** same code as previously */
}

fun uInt8ClampedToFloat32Array(data: Uint8ClampedArray): Float32Array {
    /** same code as previously */
}

@OptIn(ExperimentalTime::class, ExperimentalCoroutinesApi::class)
suspend fun main() {
    val flow = RootStore("")
    val isLoaded = RootStore("")
    val webgl: dynamic = object {}
    webgl["executionProviders"] = arrayOf("webgl") // want that WebGL GPU power

    render {
        val srcImg = img(id = "img-from") {
            src(flow.data)
            domNode.onload = { isLoaded.update(domNode.src) }
        }
        val targetCanvas = canvas(id = "img-to") {}
        val imgContext = targetCanvas.domNode.getContext("2d") as CanvasRenderingContext2D

        isLoaded.data
            .distinctUntilChanged()
            .filter { b64 -> b64.isNotEmpty() }
            .map {
                // Try WebGL first; only create the WASM session if WebGL fails.
                val ir = runCatching { InferenceSession.create("./dce2.onnx", webgl).await() }
                    .onFailure { showAlertToast { alert { title("Could not load WebGL, using WASM.") } } }
                    .getOrElse { InferenceSession.create("./dce2.onnx").await() }
                val intData = imgToUInt8ClampedArray(srcImg.domNode, imgContext)
                val floats = uInt8ClampedToFloat32Array(intData)

                // Dims are batch, channels, height, width.
                val tensor = Tensor("float32", floats, arrayOf(1, 3, srcImg.domNode.naturalHeight, srcImg.domNode.naturalWidth))
                val input = tensorToInput(tensor)

                val out = ir.run(input).await()
                val outTensor = out["output"] as Tensor
                val outData = outTensor.data as Float32Array

                for (i in 0 until outData.length) {
                    outData[i] = min(outData[i], 1f) * 255f
                }
                val intOut = float32ToUInt8Clamped(outData)

                ImageData(intOut, srcImg.domNode.naturalWidth, srcImg.domNode.naturalHeight)
            } handledBy { imageData -> imgContext.putImageData(imageData, 0.0, 0.0) }

        file {
            accept("image/*")
            button { text("Single select") }
        }.map { file -> "data:${file.type};base64,${file.content}" } handledBy flow.update
    }
}

Together with the accompanying bindings for ONNX.
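
A note on fallback chains like the WebGL-to-WASM one in the listing above: on a Kotlin Result, getOrDefault evaluates its argument eagerly, while getOrElse only runs its lambda on failure. The sketch below shows the difference in plain Kotlin; `Counter` and the two functions are hypothetical, and the strings merely stand in for sessions.

```kotlin
// Counts how often the fallback value is actually built.
class Counter { var built = 0 }

// Eager: the fallback expression runs before getOrDefault is even called,
// so it is built although the primary attempt succeeded.
fun primaryOrEagerFallback(counter: Counter): String =
    runCatching { "webgl" }
        .getOrDefault(run { counter.built++; "wasm" })

// Lazy: the lambda only runs when the primary attempt failed.
fun primaryOrLazyFallback(counter: Counter): String =
    runCatching { "webgl" }
        .getOrElse { counter.built++; "wasm" }
```

Both functions return "webgl", but only the eager one increments the counter. When the fallback is something expensive, such as creating a second inference session, the lazy form avoids doing that work on every successful run.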

Thoughts

Wrapping it all up, I want to leave you with the sentiment that KotlinJS is a real contender, ONNX Runtime Web is certainly capable, and I'll continue creating small MVPs and demos using this setup!

KotlinJS vs TypeScript

Regarding KotlinJS I believe it's still behind TypeScript in terms of compatibility. I need to do more plumbing than someone using TS would, especially as dukat doesn't magically solve all my problems. Luckily it's very easy to write those bindings!

In terms of usability I find it much better than TypeScript. The experience of working with KotlinJS code (e.g. interfacing with the std-lib, pure Kotlin code or bindings) is so much better than TypeScript - it's just like when I write my good ol' JVM applications. I'm not sure if I'm missing something, but TypeScript's type system has always felt a bit choppy, just like Python's. Sometimes I don't get the IntelliSense I'm expecting.

ONNX in the web

The performance when using WebGL is definitely better than I expected, but not as good as using the usual runtime. Something I did notice during my testing is that it scales badly with size, using a high-res image (3000x4000) ends up slowing my whole computer. I know I'm not really working on a separate thread or anything, but it's too bad it doesn't scale well. Further there's an internal max-limit somewhere around the same dimensions, which I hit once with another image.
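
Some back-of-the-envelope arithmetic hints at why large images hurt. This is not a measurement, and it ignores the model's intermediate activations and any WebGL overhead; `TensorFootprint` and `footprint` are names made up for the sketch.

```kotlin
// Rough memory cost of a single image in the pipeline described in this post.
data class TensorFootprint(val pixels: Long, val rgbaBytes: Long, val float32Bytes: Long)

fun footprint(width: Int, height: Int): TensorFootprint {
    val pixels = width.toLong() * height
    return TensorFootprint(
        pixels = pixels,
        rgbaBytes = pixels * 4,        // ImageData: 4 bytes (RGBA) per pixel
        float32Bytes = pixels * 3 * 4, // tensor: 3 channels x 4 bytes per float
    )
}
```

For a 3000x4000 image, `footprint(3000, 4000)` gives 12,000,000 pixels, 48,000,000 bytes of RGBA data and 144,000,000 bytes for the float32 input tensor. The output tensor has the same size, so the input and output alone come to roughly 288 MB.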

Personally, even with these issues, I'm left impressed by how easy it is to set up a completely custom model to run inside the browser ("on the edge"), where we don't have to care about architecture, OS or anything else, and by the fact that it works efficiently enough to use.

I can see this as a key tool for start-ups and larger companies to reduce costs & inference time (as computation happens on the edge). On top of the $'s I see a win for privacy, as the data never leaves the user's device, which in turn simplifies GDPR compliance and much more!

Even with inference moving to the edge through a common, simple interface like the browser, we'll still have plenty of need for servers, for example to serve larger models for complex problems, to support old devices and to run batch inference over larger amounts of data.

The future is indeed still moving fast for Deep Learning and I can't wait to see where we're moving!
My own prediction: Deep Learning will simply ignore serverless computing and jump straight to edge computing in an effort to reduce costs.

The Combination

Combining ONNX & KotlinJS (perhaps testing Compose rather than fritz2) is something I'm going to keep doing to deploy demos. Whether I deploy through GitHub Pages or my own Raspberry Pi, this will be a piece of cake as my devices don't have to do the inference, keeping my costs down for fun demos.

Demo: A live demo can be found here.

And the code can be found on github.com/londogard/photo-fritz2, but be careful - it's not that beautiful right now 😰

That's it for now.. 🥳
~Hampus Londögård