Autograd for ndarray
A collection of automatic differentiation tools for operations on ndarrays.
First, install the library using npm:
npm install autograd
Then import the library:
const autograd = require("autograd");
Then you can use the functions as in the following example:
const autograd = require("autograd");
const ndarray = require("ndarray");
const ops = require("ndarray-ops");
// Re-create multiplication as a differentiable function.
const mul = autograd.op({
  name: '*',

  // Forward pass: returns the element-wise product of `x1` and `x2`.
  forward(x1, x2) {
    // clone x1 to create a new `autograd.variable` as the result of the operation `*`.
    // if x1 is not an `autograd.variable` but an `ndarray`, this is the same as
    // calling `require('ndarray-scratch').clone(x1)`.
    const y = autograd.clone(x1);
    // multiply in place so the clone holds x1 * x2.
    ops.muleq(y, x2);
    return y;
  },

  // Backward pass: writes the gradient of each variable input, given `y.grad`.
  backward(y, x1, x2) {
    // only differentiate if the input `x1` or `x2` is an `autograd.variable`;
    // the two checks are independent, so either, both, or neither may run.
    if (autograd.isvariable(x1)) {
      // dy/dx1 = x2, scaled by the incoming gradient.
      ops.mul(x1.grad, y.grad, x2);
    }
    if (autograd.isvariable(x2)) {
      // dy/dx2 = x1, scaled by the incoming gradient.
      ops.mul(x2.grad, y.grad, x1);
    }
  },
});
// Usage example: call the op first with plain `ndarray`s, then with
// `autograd.variable`s, and finally run the backward pass.
// call the op with constants
const x1 = ndarray(new Float32Array(128*128).fill(2));
const x2 = ndarray(new Float32Array(128*128).fill(4));
let y = mul(x1, x2);
// y = [8 8 8 ...]
// ...
// call the op with variables!
// create differentiable variables to track gradients.
const varx1 = autograd.variable(x1, "x1");
const varx2 = autograd.variable(x2, "x2");
y = mul(varx1, varx2);
// y = [8 8 8 ...]
// with variables, access to gradients via `<variable>.grad` is available.
// call backward to compute the partial derivatives of `y` w.r.t. each of its inputs.
// backward also accepts input gradient information from external sources.
// by default, this makes `y.grad` === the gradient input.
// without a gradient input, `y.grad` === dy/dy === [1 1 1 ...].
autograd.backward(y, ndarray(new Float32Array(128*128).fill(1)));
// ** pulls out Calculus book **
// At a glance, `autograd.backward(y)` does the following:
// y = x1 * x2
// dy/dx1 = d(y) / dx1
// dy/dx1 = d(x1 * x2) / dx1 (x2 treated as constant)
// dy/dx1 = (dx1 * x2) / dx1 (take the derivative w.r.t x1)
// dy/dx1 = x2 (dx1/dx1 = 1)
// asserting the identity of our partial derivatives,
// (dy/dy) * (dy/dx1) = (dy/dx1)