KotlinJS, ONNX and Deep Learning in the browser
One day I had the crazy idea to try two non-mainstream things out at the same time. On top of that I figured I’d combine them in the same project, imagine that!
Preview of final result running model inference in the browser using KotlinJS, ONNX & fritz2:
Quick Kotlin JS
KotlinJS resembles TypeScript (TS) in the sense that it’s typed and transpiles into JavaScript (JS) at the end of the day. The final JS code runs directly in the browser or through Node.js.
What makes KotlinJS stand out? In my opinion it picks up where TS leaves off. By adding (almost) the entire Kotlin ecosystem we get a superb toolbox out of the box - much more than just types. Some of the awesome perks are coroutines and collections.
As someone who has done a lot of backend development in Scala, with some Java, it feels like home because of the familiar appearance and interaction.
With its sweet syntax and superb typing I have a strong preference for KotlinJS, even though TS is closer to JS, which makes the transpiled code easier to reason about.
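To make that concrete, here's a tiny sketch of the kind of code that transpiles straight to JS. It isn't part of this project and assumes the kotlinx-coroutines dependency:

```kotlin
import kotlinx.coroutines.MainScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch

// Illustrative only: stdlib collections plus a coroutine, running in the browser.
fun demo() {
    val bigScores = listOf(3, 7, 10)
        .filter { it > 5 }   // the full collection API transpiles straight to JS
        .map { it * 2 }

    MainScope().launch {     // structured async without callback pyramids
        delay(1_000)
        console.log("Doubled scores above five: $bigScores")
    }
}
```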
Quick ONNX
Open Neural Network Exchange (ONNX) is an open format created by Facebook, Microsoft & others, and is part of Linux Foundation AI.
ONNX is an open polyglot format, meaning that you can run Neural Networks from multiple coding languages. This in turn promotes innovation and collaboration, especially through the fact that you can run your State-of-the-Art model almost everywhere, including C# and Java.
ONNX is an impressive feat that allows companies to reduce their inference time by orders of magnitude; as a cherry on top it reduces code complexity when models are deployed directly in the original backend.
Recently ONNX gained a new runtime, ONNX Runtime Web, which enables ONNX models to run directly inside the browser. Simply take your PyTorch/TensorFlow model, convert it to ONNX and run! Incredible! 🎉
ONNX Runtime Web leverages WebGL for GPU execution and WASM (with SIMD) for CPU execution.
Simple edge deployment is here!
The Set Up
The setup is simple:

- Kotlin JS project
- fritz2 as web framework
- onnxruntime-web as deep learning inference tool
For this demo I could've used raw HTML elements in the Kotlin JS code, but it's more fun to use something enjoyable, so I chose fritz2, which I introduce below 👇.
Introducing fritz2
Introducing fritz2, a small but impressive framework.
Because of its small size you can understand the full idea and implementation, which is something you cannot say about React. Through simple DSLs and superb usage of `Flow<T>` you end up with a simple yet powerful model that maps perfectly onto how I think.
In my opinion fritz2 feels less magical while staying powerful and simple. Everything works with full typing and no hacks. Cherry on top? No virtual DOM!
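As a taste of that model, here's a minimal sketch. It only uses constructs that appear later in this post (`RootStore`, `render`, tag functions and `Flow`-backed attributes), so treat it as an illustration of the idea rather than polished fritz2 code:

```kotlin
// Sketch of the fritz2 model: state lives in a store exposed as a Flow,
// the DOM subscribes to it, and other Flows are piped back into the store.
val imgUrl = RootStore("placeholder.png")

fun sketch() = render {
    img(id = "preview") {
        src(imgUrl.data)   // the attribute re-renders whenever the store's Flow emits
    }
    // elsewhere: someFlowOfUrls handledBy imgUrl.update
}
```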
fritz2 has an extra components library which you can install additionally. It contains simple components that make your development much faster, with things like file inputs, data tables and much more!
Personally I even built my own wedding website using fritz2, and it ended up pretty great!

> My personal wedding site created in fritz2
ONNX typing in KotlinJS
Using dukat (included by default in Kotlin > 1.6, or perhaps earlier) it's possible to generate external types/bindings for any TypeScript project.
Guess what, ONNX Runtime Web is all TypeScript - awesome!
Unfortunately ONNX has a rather unusual structure which I'd call non-standard, and it doesn't play well with dukat generation…
Luckily it's easy to write your own bindings. Hold your breath for now, I'll share them later in this post; for the moment, think of them as the Kotlin equivalent of a `.d.ts` file.
The Implementation
We need to create our project. I usually do it from scratch, but if you want to keep it easy, use the fritz2 template project to set up the MPP project. Make sure to include the fritz2 components library, as it'll be used in the implementation.
Please note that the focus will be ONNX, so I'll save some fritz2 details for another post.
Basic UI
Getting the skeleton UI up
In the "main" file of the js folder (but not JS code 😉) you'll need to set up a file input and an image element.
```kotlin
fun main() {
    val imgSrc = RootStore("")

    render {
        val srcImg = img(id = "img-from") {
            src(imgSrc.data)
        }

        file {
            accept("image/*")
            button { text("Single select") }
        }.map { file -> "data:${file.type};base64,${file.content}" } handledBy imgSrc.update
    }
}
```
Breaking down what’s done
- A `RootStore` is an abstraction on top of a `(Mutable)StateFlow`, which is a `Flow` with a state.
  - In simple terms, a `Flow` is a collection of asynchronously computed values, just like `Sequence` and `List` are collections of synchronously computed values (see the sketch after this list).
- A `Store` is a reactive component that contains our app's state; it can do bidirectional communication with the DOM/GUI.
  - We update `imgSrc` through the `file` component, whenever `file` is updated.
  - `<img>` listens to changes from `imgSrc`, hence it's updated as `imgSrc` is updated.
  - All in all we get typed, no-magic dynamic updates in our GUI. This is something I love, compared to `react` and `svelte` where it feels more magical.
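To make the List/Sequence vs Flow analogy concrete, here's a minimal sketch. It's plain kotlinx-coroutines, nothing fritz2-specific, and not part of the demo itself:

```kotlin
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.*

// Synchronous collections: values are computed eagerly (List) or lazily on pull (Sequence).
val eager: List<Int> = listOf(1, 2, 3)
val lazyInts: Sequence<Int> = sequenceOf(1, 2, 3)

// A Flow computes its values asynchronously; nothing runs until someone collects it.
fun ticks(): Flow<Int> = flow {
    repeat(3) { i ->
        delay(100)   // suspends instead of blocking the browser's single thread
        emit(i)
    }
}

// Collecting is a suspending operation - fritz2 does this for us when we bind
// a store's `data` Flow to the DOM.
suspend fun printTicks() {
    ticks().collect { tick -> console.log("tick $tick") }
}
```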
The connector between `file` and `imgSrc` is dirty. I hoped to be able to load the `b64` content directly into a `UInt8ClampedArray` for optimal performance, but because the `b64` string actually contains PNG/JPEG headers and other things, the performance gain isn't worth the loss in simplicity. Hence I transform the `b64` string (`data:image/pdf;base64,<content>`) into an image and then extract `ImageData` - annoying but clean.

The detail that `<content>` in the `b64` string isn't just raw pixel data haunted me for a long time… I couldn't figure out why my arrays had the wrong dimensions! 😅
The next step: transfer image from this component to another, while allowing a transformation (neural network inference) in-between.
```kotlin
// highlight-start
fun loadImgToCanvas(img: Image, canvas: Canvas, context: CanvasRenderingContext2D) {
    if (img.domNode.src.isNotEmpty()) {
        canvas.width = img.domNode.naturalWidth
        canvas.height = img.domNode.naturalHeight
        context.drawImage(img, 0.0, 0.0)
    }
}
// highlight-end

fun main() {
    val imgSrc = RootStore("")

    render {
        val srcImg = img(id = "img-from") {
            src(imgSrc.data)
        }
        val targetCanvas = canvas(id = "img-to") { }
        val imgContext = targetCanvas.domNode.getContext("2d") as CanvasRenderingContext2D

        file { /** same as before ... */ }
        // highlight-next-line
        srcImg.domNode.onload = { loadImgToCanvas(srcImg, targetCanvas, imgContext) }
    }
}
```
Whenever the `srcImg.onload` event fires we call `loadImgToCanvas`, which draws the `img` onto our canvas.

Why did I choose not to add a new `<img>`? Because we later need `ImageData`, and this is the way to get the minimum number of data transitions, trust me 😉.
Let’s start adding bindings for ONNX!
Binding ONNX and webruntime
Binding TS/JS is as simple as writing a `.d.ts` file in TS. You define the component to bind and declare the types, e.g. function name, input and output. Simple as that!
```kotlin
@file:JsModule("onnxruntime-web") // npm-package
@file:JsNonModule

import kotlin.js.Promise

external abstract class InferenceSession {
    fun run(feeds: FeedsType): Promise<ReturnType> // FeedsType / ReturnType separately defined the same way as InferenceSession & run.
}
```
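For context, here's a sketch of how the remaining declarations (`Tensor`, `FeedsType`, `ReturnType`) could look. The names and shapes are my own assumptions inferred from how they're used further down, not the library's official Kotlin bindings; they'd live in the same `@file:JsModule("onnxruntime-web")` file:

```kotlin
// Sketch only - inferred from usage in this post, not an official binding.
external class Tensor(type: String, data: dynamic, dims: Array<Int>) {
    val data: dynamic     // the backing typed array; cast (e.g. to Float32Array) when reading output
    val dims: Array<Int>
}

// FeedsType maps input names to Tensors (built via tensorToInput below) and
// ReturnType maps output names to Tensors. They can start out as opaque markers
// and be indexed via asDynamic(), or be given dukat-style indexers later.
external interface FeedsType
external interface ReturnType
```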
Moving further, we'll add a method to extract the `ImageData`'s `UInt8ClampedArray` from an `img` element using a `canvas` element and its `CanvasRenderingContext2D` (lots of JS/web words, the most I'll have in one sentence, peeew! 😅).
```kotlin
fun imgToUInt8ClampedArray(img: HTMLImageElement, ctx: CanvasRenderingContext2D): Uint8ClampedArray {
    val canvas = ctx.canvas
    canvas.width = img.naturalWidth
    canvas.height = img.naturalHeight
    ctx.drawImage(img, 0.0, 0.0)

    return ctx.getImageData(0.0, 0.0, img.naturalWidth.toDouble(), img.naturalHeight.toDouble()).data // extract data from ImageData
}
```
The `UInt8ClampedArray` then has to be transformed into the `Float32Array` that the model expects.
Sounds easy? Think again!
Because JS is not a data science language, it's no surprise that the data is "incorrectly" ordered. The model expects the data shaped as `[3, width, height]`, where 3 is the number of channels (RGB in our case), while in JS the channel values come last and are interleaved per pixel. On top of the ordering, JS adds a fourth channel, the alpha (transparency). With all that knowledge we can transform the array.
```kotlin
fun uInt8ClampedToFloat32Array(data: Uint8ClampedArray): Float32Array {
    val floats = Float32Array(data.length / 4 * 3)
    val rgb = listOf(0, data.length / 4, data.length / 4 * 2)
    for (i in 0 until data.length step 4) {
        floats[rgb[0] + i / 4] = data[i + 0] / 255f
        floats[rgb[1] + i / 4] = data[i + 1] / 255f
        floats[rgb[2] + i / 4] = data[i + 2] / 255f // Skip i+3 as that's ALPHA
    }
    return floats
}
```
As ONNX expects a `Tensor`, we need to transform the `Float32Array` into a `Tensor` and then into a `FeedsType`, which is an `Object` holding the data; luckily that's very easy once our binding is in place.
```kotlin
fun tensorToInput(tensor: Tensor, inputName: String = "input"): FeedsType {
    val input: dynamic = object {} // To hack JS Objects
    input[inputName] = tensor

    return input.unsafeCast<FeedsType>()
}
```
```kotlin
val tensor = Tensor("float32", floats, arrayOf(1, 3, srcImg.domNode.naturalWidth, srcImg.domNode.naturalHeight))
val input = tensorToInput(tensor)
```
…and it’s time to run the model! 🥳
```kotlin
val ir = InferenceSession.create("./dce2.onnx").await()
val out = ir.run(input).await()
val outTensor = out["output"] as Tensor
val outData = outTensor.data as Float32Array
```
The output then needs the reverse transform applied to be viewable in the browser. That is: reverse the axes, add the fourth channel and cast to `int`.
```kotlin
// Calling on the output data, before converting to UInt8Clamped..
for (i in 0 until outData.length) {
    outData[i] = min(outData[i], 1f) * 255f // `min` to not go above 255
}
```
```kotlin
fun float32ToUInt8Clamped(data: Float32Array): Uint8ClampedArray {
    val rgb = arrayOf(0, data.length / 3, data.length / 3 * 2)
    val intOut = Uint8ClampedArray(data.length / 3 * 4)
    for (i in 0 until intOut.length / 4) {
        intOut.asDynamic()[i * 4 + 0] = data[rgb[0] + i].toInt()
        intOut.asDynamic()[i * 4 + 1] = data[rgb[1] + i].toInt()
        intOut.asDynamic()[i * 4 + 2] = data[rgb[2] + i].toInt()
        intOut.asDynamic()[i * 4 + 3] = 255
    }
    return intOut
}
```
As you might notice we cast a lot with `asDynamic()`; this is because of a current bug in Kotlin JS where it sends a signed `Byte` when it should be an unsigned byte.
See the current issue at youtrack.jetbrains.com.
We finally got all the pieces, how about gluing it all together? 😄
Solving the puzzle
The model I use has dynamic input/output sizes (the image dimensions), so I need to recreate the session for every image - otherwise ONNX throws, because it expects new runs to use the same shape as the previous one, which doesn't hold when images have different sizes.
One solution would be to preprocess every image to a fixed size, but for this use case I prefer to return the image in its original dimensions.
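For completeness, here's a sketch of what that fixed-size alternative could look like. The 512x512 target is an arbitrary example, and this is not what the demo below does:

```kotlin
import org.khronos.webgl.Uint8ClampedArray
import org.w3c.dom.CanvasRenderingContext2D
import org.w3c.dom.HTMLImageElement

// Hypothetical alternative: scale every image onto a fixed-size canvas so the
// InferenceSession could be created once and reused across runs.
fun imgToFixedSizeArray(
    img: HTMLImageElement,
    ctx: CanvasRenderingContext2D,
    size: Int = 512 // arbitrary example size
): Uint8ClampedArray {
    ctx.canvas.width = size
    ctx.canvas.height = size
    ctx.drawImage(img, 0.0, 0.0, size.toDouble(), size.toDouble()) // drawImage scales to size x size

    return ctx.getImageData(0.0, 0.0, size.toDouble(), size.toDouble()).data
}
```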
View code!
```kotlin
fun imgToUInt8ClampedArray(img: HTMLImageElement, ctx: CanvasRenderingContext2D): Uint8ClampedArray {
    /** same code as previously */
}

fun float32ToUInt8Clamped(data: Float32Array): Uint8ClampedArray {
    /** same code as previously */
}

fun tensorToInput(tensor: Tensor, inputName: String = "input"): FeedsType {
    /** same code as previously */
}

fun uInt8ClampedToFloat32Array(data: Uint8ClampedArray): Float32Array {
    /** same code as previously */
}

@OptIn(ExperimentalTime::class, ExperimentalCoroutinesApi::class)
suspend fun main() {
    val flow = RootStore("")
    val isLoaded = RootStore("")
    val webgl: dynamic = object {}
    webgl["executionProviders"] = arrayOf("webgl") // want that WebGL GPU power

    render {
        val srcImg = img(id = "img-from") {
            src(flow.data)
            domNode.onload = { isLoaded.update(domNode.src) }
        }
        val targetCanvas = canvas(id = "img-to") {}
        val imgContext = targetCanvas.domNode.getContext("2d") as CanvasRenderingContext2D

        isLoaded.data
            .distinctUntilChanged()
            .filter { b64 -> b64.isNotEmpty() }
            .map {
                val ir = runCatching { InferenceSession.create("./dce2.onnx", webgl).await() }
                    .onFailure { showAlertToast { alert { title("Could not load WebGL, using WASM.") } } }
                    .getOrDefault(InferenceSession.create("./dce2.onnx").await())

                val intData = imgToUInt8ClampedArray(srcImg.domNode, imgContext)
                val floats = uInt8ClampedToFloat32Array(intData)
                val tensor = Tensor("float32", floats, arrayOf(1, 3, srcImg.domNode.naturalWidth, srcImg.domNode.naturalHeight))
                val input = tensorToInput(tensor)

                val out = ir.run(input).await()
                val outTensor = out["output"] as Tensor
                val outData = outTensor.data as Float32Array

                for (i in 0 until outData.length) {
                    outData[i] = min(outData[i], 1f) * 255f
                }
                val intOut = float32ToUInt8Clamped(outData)

                ImageData(intOut, srcImg.domNode.naturalWidth, srcImg.domNode.naturalHeight)
            } handledBy { imageData -> imgContext.putImageData(imageData, 0.0, 0.0) }

        file {
            accept("image/*")
            button { text("Single select") }
        }.map { file -> "data:${file.type};base64,${file.content}" } handledBy flow.update
    }
}
```
Together with the accompanying bindings for ONNX.
Thoughts
Wrapping it all up, I want to leave with the sentiment that KotlinJS is a real contender, ONNX Runtime Web is certainly capable, and I'll continue creating small MVPs and demos using this setup!
KotlinJS vs TypeScript
Regarding KotlinJS, I believe it's still behind TypeScript in terms of compatibility. I need to do more plumbing than someone using TS would, especially as dukat doesn't solve all my problems magically. Luckily it's very easy to write those bindings yourself!
In terms of usability I find it much better than TypeScript: the experience when working with KotlinJS code (e.g. interfacing the std-lib, pure Kotlin code or bindings) is so much better - it's just like writing my good ol' JVM applications. I'm not sure if I'm missing something, but TypeScript's type system has always felt a bit choppy to me, just like Python's. Sometimes I don't get the IntelliSense I'm expecting.
ONNX in the web
The performance when using WebGL is definitely better than I expected, but not as good as the regular runtime. Something I did notice during testing is that it scales badly with size: a high-res image (3000x4000) ends up slowing down my whole computer. I know I'm not really working on a separate thread or anything, but it's a shame it doesn't scale well. Further, there's an internal max limit somewhere around those dimensions, which I hit once with another image.
Personally, even with these issues, I'm left impressed by how easy it is to set up a completely custom model to run inside the browser ("on the edge"), where we don't have to care about architecture, OS or anything else, and it still works efficiently enough to be useful.
I can see this becoming a key tool for start-ups and larger companies to reduce costs & inference time (as computation happens on the edge). On top of the $'s I see a win for privacy, as the data never leaves the user's device, which in turn simplifies GDPR compliance and much more!
Even when moving inference to the edge through a common, simple interface - the browser - we'll still have plenty of need for servers: serving larger models for complex problems, supporting old devices, and batch inference over larger amounts of data.
The future is indeed moving fast for Deep Learning and I can't wait to see where we're heading!
My own prediction: Deep Learning will simply ignore serverless computing and jump straight to edge computing in an effort to reduce costs.
The Combination
Combining ONNX & KotlinJS (perhaps testing Compose rather than fritz2) is something I'm going to keep doing to deploy demos. Whether deploying through GitHub Pages or my own Raspberry Pi, this will be a piece of cake since my devices don't have to do the inference, keeping my costs down for fun demos.
Demo: A live demo can be found here.
And the code can be found on github.com/londogard/photo-fritz2, but be careful - it’s not that beautiful right now 😰
That’s it for now.. 🥳
~Hampus Londögård