Utilisation de JAX pour accélérer notre recherche

Les ingénieurs de DeepMind accélèrent nos recherches en créant des outils, en étendant des algorithmes et en créant des mondes virtuels et physiques stimulants pour entraîner et tester des systèmes d’intelligence artificielle (IA). Dans le cadre de ce travail, nous évaluons constamment de nouveaux frameworks de bibliothèque et d’apprentissage automatique.

Nous avons récemment constaté qu’un nombre croissant de projets sont bien rendus par JAX, un cadre d’apprentissage automatique développé par les équipes de recherche de Google. JAX s’aligne bien avec notre philosophie d’ingénierie et a été largement adopté par notre communauté de recherche au cours de l’année écoulée. Nous partageons ici notre expérience de travail avec JAX, expliquons pourquoi nous le trouvons bénéfique pour notre recherche en IA et donnons un aperçu de l’écosystème que nous construisons pour soutenir les chercheurs du monde entier.

Pourquoi JAX ?

JAX est une bibliothèque Python conçue pour le calcul numérique haute performance, en particulier la recherche en apprentissage automatique. Son API pour les fonctions numériques est basée sur NumPy, un ensemble de fonctions utilisées en calcul scientifique. Python et NumPy sont à la fois largement utilisés et familiers, ce qui rend JAX simple, flexible et facile à utiliser.

En plus de l’API NumPy, JAX inclut un système extensible de Conversions de tâches configurables qui aident à soutenir la recherche sur l’apprentissage automatique, notamment :

  • différenciation: L’optimisation basée sur les gradients est au cœur du ML. JAX prend en charge à l’origine la différenciation automatique vers l’avant et vers l’arrière des fonctions scalaires arbitraires, via des conversions de fonctions telles que grad, hessian, jacfwd et jacrev.
  • Orientation: Dans la recherche ML, nous appliquons souvent une seule fonction à un grand nombre de données, par exemple en calculant la perte dans un groupe ou en évaluant les gradients pour chaque exemple d’apprentissage privé différentiel. JAX fournit un routage automatique via la transformation vmap qui simplifie cette forme de programmation. Par exemple, les chercheurs n’ont pas besoin de considérer le clustering lors de la mise en œuvre de nouveaux algorithmes. JAX prend également en charge le parallélisme des données à grande échelle via la transformation pmap associée, analysant élégamment les données trop volumineuses pour une seule mémoire d’accélérateur.
  • Compilation juste à temps : XLA est utilisé pour compiler et exécuter des programmes JAX juste-à-temps (JIT) sur des accélérateurs GPU et Cloud TPU. La compilation JIT, combinée à l’API compatible JAX de NumPy, permet aux chercheurs sans expérience HPC préalable d’étendre facilement à un ou plusieurs accélérateurs.

Nous avons constaté que JAX a permis une expérimentation rapide avec de nouveaux algorithmes et architectures et prend désormais en charge bon nombre de nos publications récentes. Pour en savoir plus, veuillez envisager de vous joindre à la table ronde JAX, le mercredi 9 décembre à 19 h 00 GMT, pour la conférence virtuelle NeurIPS.

JAX chez DeepMind

Soutenir la recherche moderne sur l’IA signifie équilibrer le prototypage rapide et l’itération rapide avec la capacité de déployer des expériences à l’échelle traditionnellement associée aux systèmes de production. Ce qui rend ces types de projets particulièrement difficiles, c’est que le paysage de la recherche évolue rapidement et est difficile à prévoir. A tout moment, une nouvelle avancée de la recherche pourrait, et régulièrement, modifier la trajectoire et les besoins d’équipes entières. Dans ce paysage en constante évolution, la principale responsabilité de notre équipe d’ingénieurs est de s’assurer que les leçons apprises et le code écrit pour un projet de recherche sont effectivement réutilisés dans le suivant.

La seule méthode éprouvée est le système modulaire : nous extrayons les blocs de construction les plus importants et les plus critiques développés dans chaque projet de recherche en modules bien testés et efficaces articles. Cela permet aux chercheurs de se concentrer sur leurs recherches tout en bénéficiant de la réutilisation du code, des corrections de bogues et des améliorations des performances des composants algorithmiques mis en œuvre par nos bibliothèques principales. Nous avons également trouvé important de s’assurer que chaque bibliothèque a une portée clairement définie et de s’assurer qu’elles sont interopérables mais indépendantes. Achat progressifLa capacité de sélectionner et de choisir des fonctionnalités sans se refermer sur les autres est essentielle pour offrir aux chercheurs une flexibilité maximale et toujours les aider à choisir le bon outil pour le travail.

D’autres considérations qui ont été prises en compte dans le développement de l’écosystème JAX incluent la garantie qu’il reste cohérent (dans la mesure du possible) avec la conception des bibliothèques TensorFlow existantes (telles que Sonnet et TRFL). Nous avons également cherché à créer des composants (le cas échéant) qui correspondent le plus possible à leurs mathématiques sous-jacentes, à être auto-descriptifs et à réduire les sauts mentaux « du papier au code ». Enfin, nous avons choisi d’ouvrir nos bibliothèques en open source pour faciliter le partage des résultats de la recherche et encourager la communauté au sens large à explorer l’écosystème JAX.

notre écosystème aujourd’hui

haïku

Le paradigme de programmation JAX des transformations de fonctions configurables peut rendre la gestion d’objets complexes avec état, tels que les réseaux de neurones avec des paramètres entraînables. Haiku est une bibliothèque de réseaux de neurones qui permet aux utilisateurs d’utiliser des modèles de programmation orientés objet familiers tout en exploitant la puissance et la simplicité du modèle fonctionnel pur de JAX.

Haiku est activement utilisé par des centaines de chercheurs de DeepMind et de Google, et a déjà été adopté dans plusieurs projets tiers (tels que Coax, DeepChem et NumPyro). Il s’appuie sur l’API de Sonnet, notre modèle de programmation basé sur des modules pour les réseaux de neurones dans TensorFlow, et nous visons à rendre le portage de Sonnet vers Haiku aussi simple que possible.

En savoir plus sur github

optaxe

L’optimisation basée sur les gradients est au cœur du ML. Optax fournit une bibliothèque de transformations de gradient, ainsi que des opérateurs de composition (tels que String) qui permettent d’exécuter plusieurs optimiseurs standard (tels que RMSProp ou Adam) en une seule ligne de code.

La nature synthétique d’Optax prend naturellement en charge la réincorporation des mêmes ingrédients de base dans des exhausteurs personnalisés. En outre, il fournit un certain nombre d’utilitaires pour l’estimation de gradient stochastique et l’optimisation de second ordre.

De nombreux utilisateurs d’Optax ont adopté Haiku, mais conformément à notre philosophie d’achat incrémentiel, toute bibliothèque représentant des paramètres tels que les arborescences JAX (par exemple Elegy, Flax et Stax) est prise en charge. Veuillez cliquer ici pour plus d’informations sur ce riche écosystème de bibliothèques JAX.

En savoir plus sur github

RLax

Bon nombre de nos projets les plus réussis se situent à l’intersection de l’apprentissage en profondeur et de l’apprentissage par renforcement (RL), également connu sous le nom d’apprentissage par renforcement en profondeur. RLax est une bibliothèque qui fournit des blocs de construction utiles pour la création d’agents RL.

Les composants de RLax couvrent un large éventail d’algorithmes et d’idées : apprentissage TD, gradients de politique, critiques d’acteurs, MAP, optimisation proche de la politique, transformation de valeur non linéaire, fonctions de valeur génériques et un certain nombre de méthodes d’exploration.

Bien que quelques exemples d’introduction d’agents soient fournis, RLax n’est pas conçu comme un cadre pour la construction et le déploiement de systèmes d’agents RL complets. Acme est un exemple de cadre de proxy complet qui s’appuie sur des composants RLax.

En savoir plus sur github

Chex

Les tests sont essentiels à la fiabilité des logiciels et Research Code ne fait pas exception. Pour tirer des conclusions scientifiques à partir d’expériences de recherche, vous devez être sûr que votre code est correct. Chex est une collection d’utilitaires de test utilisés par les auteurs de bibliothèques pour vérifier que les blocs de construction courants sont corrects et robustes et par les utilisateurs finaux pour vérifier leur propre code expérimental.

Chex fournit une variété d’utilitaires, notamment des tests unitaires compatibles JAX, des assertions sur les propriétés des types de données JAX, le gauchissement et le gauchissement, et des environnements de test multi-machines. Chex est utilisé dans l’écosystème JAX de DeepMind et dans des projets externes tels que Coax et MineRL.

En savoir plus sur github

bulldozer

Les réseaux de neurones graphiques (GNN) sont un domaine de recherche passionnant avec de nombreuses applications prometteuses. Voir, par exemple, nos travaux récents sur la prévision du trafic dans Google Maps et nos travaux sur les simulations physiques. Graff (prononcé « girafe ») est une bibliothèque légère pour prendre en charge le travail avec les GNN dans JAX.

Jraph fournit une structure de données unifiée pour les graphes, un ensemble d’utilitaires pour travailler avec des graphes et un “zoo” de modèles de réseaux neuronaux de graphes qui sont facilement bifurquables et extensibles. Les autres fonctionnalités clés incluent : la compilation de GraphTuples qui utilisent efficacement les accélérateurs matériels, la prise en charge de la compilation JIT de graphiques de forme variable via le remplissage et le masquage, et la définition des pertes sur les partitions d’entrée. Comme Optax et nos autres bibliothèques, Jraph n’impose aucune restriction sur le choix de l’utilisateur de la bibliothèque de réseau neuronal.

En savoir plus sur l’utilisation de la bibliothèque avec notre riche collection d’exemples.

En savoir plus sur github

L’écosystème JAX est en constante évolution et nous encourageons la communauté de recherche ML à explorer nos bibliothèques et le potentiel de JAX pour accélérer leurs recherches.

Enregistrer un commentaire

Plus récente Plus ancienne

نموذج الاتصال