federated-learning-client
v0.0.5
Published
Experimental client library for federated learning in Tensorflow.js
Downloads
6
Readme
Tensorflow.js Federated Learning Client
This library sets up a simple websocket-based client for transmitting and receiving Tensorflow.js model weights.
Usage
Basic
import * as federated from 'federated-learning-client';
const SERVER_URL = 'https://federated.learning.server';
const INIT_MODEL = 'https://initial.com/tf-model.json';
const client = new federated.Client(SERVER_URL, INIT_MODEL);
client.setup().then(() => {
const yhat = client.predict(x); // make predictions!
client.federatedUpdate(x, y); // train (and asynchronously update the server)!
});
After connecting to a server and downloading a model, the federated client
object exposes a model
that allows you to call predict
, evaluate
, and
federatedUpdate
on new data. Calling federatedUpdate
will asynchronously
send updated weights to the server. If the server receives enough updates
from different clients, it will update its canonical copy of the model and
send a new version down, which will transparently replace the current one.
Setting the Initial Model
During federated updates, the client and server will only communicate changed weights, which for applications like transfer learning may be fairly small in size compared to the full model. However, to begin the process, the client must first obtain a copy of the full model.
This can be done in one of the following ways:
new federated.Client(SERVER_URL, tfModel); // tf.Model instance
new federated.Client(SERVER_URL, 'https://initial.com/model.json'); // url pointing to one
new federated.Client(SERVER_URL, async () => { // async function returning one
const model = await tf.loadModel('https://initial.com/model.json')
return someCustomTransformationOfThe(model);
});
new federated.Client(SERVER_URL, federatedClientModel); // if you need custom behavior
You can pass a FederatedClientModel
if you need custom behavior not supported by tf.Model
s. See its documentation for the methods you will need to implement.
Currently, we do not support loading the model automatically from the server (e.g. just doing federated.Client(SERVER_URL)
). If the model you want to load client-side is only hosted on your federated learning server, but is a tf.Model
, you can implement this fairly simply by statically serving the model files, which are conveniently symlinked. Here is an example of how to do this with express
:
// server
const app = express();
const dir = '/models';
app.use(express.static(dir));
const webServer = http.createServer(app);
const fedServer = new federated.Server(webServer, initModel, {
modelDir: dir // current model will be symlinked to ${dir}/current
});
await fedServer.setup();
webServer.listen(80);
// client
const url = 'http://my.server';
const fedClient = new federated.Client(url, `${url}/current/model.json`);
await fedClient.setup();
The reason we recommend hosting the model elsewhere is to reduce unnecessary load on the server, which is not optimized to serve large static files.
Federated Learning
Simply call:
await client.federatedUpdate(inputTensor, targetTensor);
The federatedUpdate
method will:
- verify
inputTensor
andtargetTensor
have the correct shapes (can pass inputs either with or without the batch dimension) - check if it has enough examples to train (this is a hyperparameter sent by the server, which can be inspected on the client using
client.numExamplesPerUpdate()
) - if there are enough examples, it will fit the model, send updated weights to the server, and then revert back to its previous weights
- if there aren't, it will store the examples in memory (copying the tensors) and immediately resolve.
Setting the Loss Function
For tf.Model
-based federated learning, we assume by default that we are minimizing
the categoricalCrossEntropy
when users call client.fit(x, y)
. If this assumption is incorrect, you can do the following:
federated.Client(SERVER_URL, 'https://initial.com/model.json', {
modelCompileConfig: {
loss: 'meanSquaredError'
}
});
This dictionary will be passed to tf.Model.compile
, so you can also pass custom metrics
.
However, we will always set the optimizer
to tf.SGDOptimizer
to properly adopt a learning rate based on hyperparameters sent by the server. If you need more flexibility than this, you can define
a custom FederatedClientModel
.
Client IDs
By default, each federated learning client is assigned a random, anonymous identifier which is persisted in cookie storage. This ID is sent to the server alongside weight updates to enable features that prevent individual clients from dominating the aggregation process. It can also be helpful for anonymized performance metric analysis, e.g. determining if the gains of federated learning are going to one group of users at the expense of a minority.
If you would like to assign your own custom client identifier not stored in a cookie, you can do so as follows:
new federated.Client(SERVER_URL, MODEL_URL, {
clientId: 'custom-id'
});
In the future, we hope to support client identifiers tied to rate-limited authentication strategies to help prevent bot users from flooding the server with spurious updates.
Getting Stats
For printing and debugging purposes, you can get stats about the client using:
client.inputShape // shape of model inputs (without batch dimension)
client.outputShape // shape of model outputs (without batch dimension)
client.evaluate(x, y) // metrics for the current model on the data
client.predict(x) // predictions of current model on the data
client.modelVersion() // identifier of the current model
client.numUpdates() // how many updates client has contributed
client.numVersions() // how many versions client has downloaded
client.numExamples() // how many examples currently stored in client's buffer
client.numExamplesPerUpdate() // how many examples used in each update
client.numExamplesRemaining() // how many more examples needed before updating
You can also listen for changes in the model version (or successful uploads) by doing:
client.onNewVersion((oldVersion, newVersion) => {
console.log(`you've seen ${client.numVersions()} versions now!`);
$('#model-version').text(`Model Version #${newVersion}`);
});
client.onUpload(uploadMessage => {
console.log(`uploaded another version to the server!`);
});
Sending Metrics
By default, federated clients will only send updated weights to the server. However, you can configure them to also send the results of client.evaluate(x, y)
by setting the sendMetrics
config flag to true
:
new federated.Client(SERVER_URL, MODEL_URL, {
sendMetrics: true
})
If evaluate
returns the exact training loss, then sending this information could make the federated learning process slightly less private, but it is helpful for server-side performance monitoring.
TODO
- Support initialization of clients without initial models (i.e. autoloading them from the server)
- Support customizable client authentication (e.g. google oauth, captchas, etc)
- Collect a local dataset of examples we store even after they're used in
federatedUpdate
. Change the behavior offederatedUpdate
to compute weight updates relative to the original model but then retrain on the full local dataset -- so clients can see the benefits of training even before the server model adapts.