Was ist JAX (Python)?

JAX (Python) steht für „Just After eXecution“ und ist eine neu entwickelte Bibliothek von DeepMind, bei der das tiefe Lernen sowie das maschinelle Lernen im Fokus steht. JAX ist im Gegensatz zum bekannten Tensorflow kein offizielles Produkt von Google und erfreut sich aufgrund der NumPy-ähnlichen Syntax immer mehr Beliebtheit, insbesondere im Forschungsbereich.

JAX nutzt die maximale Anzahl von FLOPs, um daraus einen optimalen Code zu generieren. Genau betrachtet ist diese Bibliothek also eigentlich ein Just-In-Time-Compiler oder auch kurz JIT-Compiler genannt. Durch das Prinzip Just-in-Time kann JAX nicht nur auf der CPU, sondern eben auch auf GPU oder TPU eingesetzt werden, woraus sich ein weiterer wesentlicher Vorteil und zugleich das Hauptmerkmal von JAX bildet. JAX kann zwischen nativem Python- und NumPy-Code automatisch unterscheiden, sowie Teilmengen differenzieren als auch mehrfach Ableitungen nehmen.

In welchem Bereich wird JAX eingesetzt?

Der Einsatz einer solchen Software findet sich verstärkt in der Forschung und Entwicklung wieder. Vorwiegend durch die komplexe Arbeitsstruktur dieser Software können andere und teils neue Lösungswege errechnet werden. Die Umstellung auf dieses neue Prinzip ist jedoch nicht ganz einfach, was auch später im Vergleich zu PyTorch nochmal deutlich wird.

Prinzipiell kann man hier von einem ähnlichen Einsatzgebiet, wie bei PyTorch sprechen, aufgrund der komplexen Form wird jedoch nicht in jedem Unternehmen auf JAX zurückgegriffen. Zumindest momentan noch nicht.

Wenig überraschend ist es daher, dass der größte Einsatz aus dem Bereich der selbstständigen IT-Techniker kommt. Insbesondere für die Programmierung und Strukturierung oder die Berechnung von Grafikkarten kann JAX den Unterschied machen. Warum gerade in dieser Branche das Aufkommen so hoch ist, ist einfach zu erklären. Neue Unternehmen beginnen direkt mit dieser Software und nicht mit den Vorgängern, was eine spätere Umschulung oder Umgewöhnung entfallen lässt. Hier wird also direkt auf die nicht offiziell von Google unterstützte Software gesetzt. Mit Blick auf die Zukunft lässt sich hier also ableiten, dass die Nutzung und somit die Verbreitung dieser Software in Zukunft noch weiter zunehmen wird.

Was unterscheidet JAX von PyTorch?

Der wesentliche Unterschied zwischen beiden Programmen besteht in dem grundlegenden Aufbau von Berechnungen, die für ein Projekt genutzt werden sollen. Hieraus ergibt sich eine komplett andere Struktur, die somit auch Auswirkungen auf die Zurückverfolgung der errechneten Daten hat. Dies lässt sich wie folgt erklären:

PyTorch baut während des Durchlaufs einen Graphen auf. Hierbei kann ein Aufruf von backward() auf ein Ergebnis-Knoten durchgeführt werden. Daraus folgt eine Erweiterung auf jeden Zwischenknoten im Graphen, um den Gradienten des Ergebnisknotens, der bezogen ist auf diesen Zwischenknoten. Somit bietet sich hier die Möglichkeit, auf verschiedene Einzelelemente direkt Einfluss zu nehmen, die während des Durchlaufs entstehen / berechnet werden.

Bei Jax hingegen wird die Berechnung als Python-Funktion ausgedrückt, die durch die Umwandlung grad() zu einer Gradientenfunktion wird, die dann, wie die Berechnungsfunktion ausgewertet werden kann. Zudem wird die Ausgabe nicht als typische Ausgabe, sondern als Gradienten der Ausgabe zur Verfügung gestellt, der sich auf den ersten Parameter der Funktion bezieht.

Speziell dieser Unterschied macht deutlich, dass in beiden Varianten völlig unterschiedliche Codes geschrieben und Modelle aufgebaut werden müssen. Daher spricht man auch von einer großen Umstellung, wenn man von einem auf das andere Programm wechselt.

Die größte Gemeinsamkeit der beiden Softwaretypen ist ihre Flexibilität und das damit verbundene große Einsatzgebiet, wo sie sich wiederfinden können. Zudem arbeiten beide im Low- und High-Level API. Beide Typen sind Researchers, auch das kann man natürlich als Gemeinsamkeit ansehen.