@teachablemachine/image
v0.8.5
A support library for use with mobilenet-based models generated by Teachable Machine (https://g.co/teachablemachine)
Teachable Machine Library - Image
Library for using image models created with Teachable Machine.
Model checkpoints
Teachable Machine provides one link for your model:
https://teachablemachine.withgoogle.com/models/MODEL_ID/
You can use it to access:
- The model topology:
https://teachablemachine.withgoogle.com/models/MODEL_ID/model.json
- The model metadata:
https://teachablemachine.withgoogle.com/models/MODEL_ID/metadata.json
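As a small illustration, the two file URLs are the model link with the file name appended. modelFileUrls is a hypothetical helper, not part of @teachablemachine/image; this sketch assumes only that the link may or may not end in a slash.

```javascript
// Hypothetical helper (not part of @teachablemachine/image): derive the two
// file URLs from the shareable model link, with or without a trailing slash.
function modelFileUrls(modelLink) {
  const base = modelLink.endsWith('/') ? modelLink : modelLink + '/';
  return {
    modelURL: base + 'model.json',       // model topology (references the .bin weights)
    metadataURL: base + 'metadata.json', // class labels and other metadata
  };
}

const { modelURL, metadataURL } =
  modelFileUrls('https://teachablemachine.withgoogle.com/models/MODEL_ID');
// modelURL    → 'https://teachablemachine.withgoogle.com/models/MODEL_ID/model.json'
// metadataURL → 'https://teachablemachine.withgoogle.com/models/MODEL_ID/metadata.json'
```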
Usage
There are two easy ways to use the model provided by Teachable Machine in your JavaScript project: load this library via script tags, or install it from NPM (and use a build tool like Parcel, webpack, or Rollup).
via Script Tag
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@teachablemachine/[email protected]/dist/teachablemachine-image.min.js"></script>
via NPM
npm i @tensorflow/tfjs
npm i @teachablemachine/image
import * as tf from '@tensorflow/tfjs';
import * as tmImage from '@teachablemachine/image';
Sample snippet
<div>Teachable Machine Image Model</div>
<button type='button' onclick='init()'>Start</button>
<div id='webcam-container'></div>
<div id='label-container'></div>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@teachablemachine/[email protected]/dist/teachablemachine-image.min.js"></script>
<script type="text/javascript">
// More API functions here:
// https://github.com/googlecreativelab/teachablemachine-community/tree/master/libraries/image

// the link to your model provided by the Teachable Machine export panel
const URL = '{{URL}}';

let model, webcam, labelContainer, maxPredictions;

// Workaround for iOS, where the webcam video otherwise freezes: the iOS
// branches below use the <video> element (webcam.webcam) instead of the canvas.
let isIos = false;
if (window.navigator.userAgent.indexOf('iPhone') > -1 || window.navigator.userAgent.indexOf('iPad') > -1) {
    isIos = true;
}

// Load the image model and set up the webcam
async function init() {
    const modelURL = URL + 'model.json';
    const metadataURL = URL + 'metadata.json';

    // load the model and metadata
    // Refer to tmImage.loadFromFiles() in the API to support files from a file picker
    // or files from your local hard drive
    model = await tmImage.load(modelURL, metadataURL);
    maxPredictions = model.getTotalClasses();

    // Convenience function to set up a webcam
    const flip = true; // whether to flip the webcam
    const width = 200;
    const height = 200;
    webcam = new tmImage.Webcam(width, height, flip);
    await webcam.setup(); // request access to the webcam

    if (isIos) {
        // the video element must be added to the page for this to work on iOS
        document.getElementById('webcam-container').appendChild(webcam.webcam);
        // grab the video element in any way you want and set its attributes
        const webCamVideo = document.getElementsByTagName('video')[0];
        webCamVideo.setAttribute('playsinline', true); // set via setAttribute because iOS misbehaves otherwise
        webCamVideo.muted = true;
        webCamVideo.style.width = width + 'px';
        webCamVideo.style.height = height + 'px';
    } else {
        document.getElementById('webcam-container').appendChild(webcam.canvas);
    }

    // append one element per class label to the DOM
    labelContainer = document.getElementById('label-container');
    for (let i = 0; i < maxPredictions; i++) {
        labelContainer.appendChild(document.createElement('div'));
    }

    await webcam.play();
    window.requestAnimationFrame(loop);
}

async function loop() {
    webcam.update(); // update the webcam frame
    await predict();
    window.requestAnimationFrame(loop);
}

// run the webcam image through the image model
async function predict() {
    // predict can take in an image, video or canvas HTML element
    const input = isIos ? webcam.webcam : webcam.canvas;
    const prediction = await model.predict(input);
    for (let i = 0; i < maxPredictions; i++) {
        const classPrediction =
            prediction[i].className + ': ' + prediction[i].probability.toFixed(2);
        labelContainer.childNodes[i].textContent = classPrediction;
    }
}
</script>
API
Loading the model - URL checkpoints
tmImage is the module name. It is included automatically when you use the <script src> method: it is added as an object on window, so you can access it via window.tmImage or simply tmImage.
tmImage.load(
checkpoint: string,
metadata?: string | Metadata
)
Args:
- checkpoint: a URL to a json file that contains the model topology and a reference to a bin file (model weights)
- metadata: a URL to a json file that contains the text labels of your model and additional information
Usage:
await tmImage.load(checkpointURL, metadataURL);
Loading the model - browser files
You can load your model files from your local hard drive by using a file picker and the File interface.
tmImage.loadFromFiles(
model: File,
weights: File,
metadata: File
)
Args:
- model: a File object that contains the model topology (.json)
- weights: a File object with the model weights (.bin)
- metadata: a File object that contains the text labels of your model and additional information (.json)
Usage:
// you need File objects, e.g. from file input elements (<input type="file" ...>)
const uploadModel = document.getElementById('upload-model');
const uploadWeights = document.getElementById('upload-weights');
const uploadMetadata = document.getElementById('upload-metadata');
const model = await tmImage.loadFromFiles(uploadModel.files[0], uploadWeights.files[0], uploadMetadata.files[0]);
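If you use a single file input that accepts multiple files, the three files have to be told apart before calling tmImage.loadFromFiles. The sketch below does this by file name. sortModelFiles and the 'upload-all' element id are hypothetical, and the helper assumes the export uses a .bin weights file, a file named metadata.json, and one other .json file for the topology.

```javascript
// Hypothetical helper (not part of the library): sort the files chosen in one
// <input type="file" multiple> into the three arguments of tmImage.loadFromFiles.
// Assumption: weights end in ".bin", the metadata file is "metadata.json",
// and any other ".json" file is the model topology.
function sortModelFiles(files) {
  let model = null, weights = null, metadata = null;
  for (const file of Array.from(files)) { // works for a FileList or an array
    const name = file.name.toLowerCase();
    if (name.endsWith('.bin')) weights = file;
    else if (name === 'metadata.json') metadata = file;
    else if (name.endsWith('.json')) model = file;
  }
  if (!model || !weights || !metadata) {
    throw new Error('Expected a model .json file, a weights .bin file and metadata.json');
  }
  return { model, weights, metadata };
}

// Usage in the browser, inside an async function:
// const picked = sortModelFiles(document.getElementById('upload-all').files);
// const model = await tmImage.loadFromFiles(picked.model, picked.weights, picked.metadata);
```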
Model - get total classes
Once you have loaded a model, you can obtain the total number of classes in the model.
This method exists on the model that is loaded from tmImage.load.
model.getTotalClasses()
Returns a number representing the total number of classes
Model - get class labels
Once you have loaded a model, you can obtain the class labels (i.e. the name of each category the model was trained on).
This method exists on the model that is loaded from tmImage.load.
model.getClassLabels()
Returns an array with class names as strings.
Model - predict
Once you have loaded a model, you can make a classification with a couple of different input options.
This method exists on the model that is loaded from tmImage.load.
model.predict(
image: HTMLImageElement | HTMLCanvasElement | HTMLVideoElement | ImageBitmap,
flipped = false
)
Args:
- image: an image, canvas, or video element to make a classification on
- flipped: a boolean that controls whether the input image is flipped horizontally (on the X axis) before classification
Usage:
// predict can take in an image, video or canvas html element
// if using the webcam utility, we set flip to true since the webcam was only
// flipped in CSS
const flip = true;
const allPredictions = await model.predict(webcamElement, flip);
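predict() resolves to one entry per class. A common next step is picking the single most likely class, optionally only when it clears a confidence threshold. bestPrediction is a hypothetical helper, sketched under the assumption that each entry has className and probability fields, as in the sample snippet.

```javascript
// Hypothetical helper (not a library function): return the most likely entry
// from model.predict() output, or null if the list is empty or the best
// probability is below minProbability.
function bestPrediction(predictions, minProbability = 0) {
  if (predictions.length === 0) return null;
  const best = predictions.reduce((a, b) => (b.probability > a.probability ? b : a));
  return best.probability >= minProbability ? best : null;
}

// Usage, inside an async function:
// const top = bestPrediction(await model.predict(webcam.canvas), 0.8);
// if (top) console.log(top.className);
```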
Model - predictTopK
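This is an alternative to predict(). Whereas predict() returns the probability for every class, predictTopK() returns only the maxPredictions most likely classes, ordered from most to least likely.
This method exists on the model that is loaded from tmImage.load.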
model.predictTopK(
image: HTMLImageElement | HTMLCanvasElement | HTMLVideoElement | ImageBitmap,
maxPredictions = 10,
flipped = false
)
Args:
- image: an image, canvas, or video element to make a classification on
- maxPredictions: the maximum number of predictions to return
- flipped: a boolean that controls whether the input image is flipped horizontally (on the X axis) before classification
Usage:
// predictTopK can take in an image, video or canvas html element
// if using the webcam utility, we set flip to true since the webcam was only
// flipped in CSS
const flip = true;
const maxPredictions = model.getTotalClasses();
const prediction = await model.predictTopK(webcamElement, maxPredictions, flip);
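To make "top K" concrete, the same kind of result can be computed in plain JavaScript from predict() output by sorting on probability and keeping the first k entries. This is only an illustration under the same className/probability assumption; it is not the library's internal implementation of predictTopK, and topK is a hypothetical name.

```javascript
// Illustration only: top-k selection over model.predict() output.
function topK(predictions, k) {
  return predictions
    .slice()                                        // don't mutate the caller's array
    .sort((a, b) => b.probability - a.probability) // highest probability first
    .slice(0, k);                                   // keep at most k entries
}
```

Sorting a copy keeps the original per-class order of predict() output intact, which matters if other code indexes it by class position (as the sample snippet's label loop does).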
Webcam
You can optionally use the webcam class that comes with the library, or set up your own webcam. This class exists on the tmImage module.
Please note that the default webcam used in Teachable Machine was flipped on X, so you should probably set flip = true if creating your own webcam, unless you flipped it manually in Teachable Machine.
new tmImage.Webcam(
width = 400,
height = 400,
flip = false,
)
Args:
- width: width of the webcam. Width and height should ideally be equal (a square), since that's how the model was trained with Teachable Machine.
- height: height of the webcam. Width and height should ideally be equal (a square), since that's how the model was trained with Teachable Machine.
- flip: boolean to signal whether the webcam should be flipped on X. Please note this flip is applied only in CSS.
Usage:
// webcam has a square ratio and is flipped by default to match training
const webcam = new tmImage.Webcam(200, 200, true);
await webcam.setup();
webcam.play();
document.body.appendChild(webcam.canvas);
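Flipping on X means mirroring the image horizontally: the pixel in column x ends up in column width - 1 - x. The hypothetical flipX function below shows that on an RGBA buffer laid out like ImageData.data (4 bytes per pixel, row by row). It only illustrates the idea; it is not how tmImage.Webcam applies its flip.

```javascript
// Illustration of a horizontal mirror ("flip on X") on an RGBA pixel buffer,
// row-major with 4 bytes per pixel, as in ImageData.data.
function flipX(data, width, height) {
  const out = new Uint8ClampedArray(data.length);
  for (let y = 0; y < height; y++) {
    for (let x = 0; x < width; x++) {
      const src = (y * width + x) * 4;                 // pixel (x, y)
      const dst = (y * width + (width - 1 - x)) * 4;   // mirrored column
      for (let c = 0; c < 4; c++) out[dst + c] = data[src + c]; // copy R, G, B, A
    }
  }
  return out;
}
```

On a browser canvas, the same mirror is usually drawn by calling ctx.translate(width, 0) and then ctx.scale(-1, 1) before drawImage.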
Webcam - setup
After creating a Webcam object, call setup() once to request access to the webcam and set it up.
webcam.setup(
options: MediaTrackConstraints = {}
)
Args:
- options: optional media track constraints for the webcam
Usage:
await webcam.setup();
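The options argument is typed as a standard MediaTrackConstraints object, so you can, for example, ask for the rear camera on a phone. cameraConstraints is a hypothetical helper. facingMode and width/height with ideal values are standard getUserMedia constraint keys, but how setup() merges them with the Webcam's own width and height is an assumption here, not documented behavior.

```javascript
// Hypothetical helper: build a MediaTrackConstraints object for webcam.setup().
// Assumption: setup() forwards these constraints to getUserMedia.
function cameraConstraints(size, facingMode = 'user') {
  return {
    facingMode,              // 'user' = front camera, 'environment' = rear camera
    width: { ideal: size },  // request a square frame, matching how
    height: { ideal: size }, // Teachable Machine models are trained
  };
}

// Usage, inside an async function:
// const webcam = new tmImage.Webcam(200, 200, true);
// await webcam.setup(cameraConstraints(200, 'environment'));
```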
Webcam - play, pause, stop
webcam.play();
webcam.pause();
webcam.stop();
play() loads and starts playback of the webcam video and returns a promise. pause() pauses playback, and stop() stops the webcam.
Webcam - update
Call update() to draw the current webcam frame. Call it once per animation frame, before predicting, as loop() does in the sample snippet.
webcam.update();
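The sample snippet's loop() calls update() and predict() on every animation frame but never stops. A minimal, stoppable variant is sketched below. startPredictionLoop and its render callback are hypothetical names. The sketch assumes a tmImage.Webcam and a model from tmImage.load, predicts on webcam.canvas (the non-iOS path of the sample), and accepts an optional schedule function so it can also run outside a browser.

```javascript
// Sketch of a stoppable update/predict loop (not part of the library).
// webcam: a tmImage.Webcam; model: returned by tmImage.load;
// render: hypothetical callback receiving each predictions array;
// schedule: defaults to window.requestAnimationFrame.
function startPredictionLoop(webcam, model, render, schedule) {
  const nextFrame = schedule || ((cb) => window.requestAnimationFrame(cb));
  let running = true;

  async function frame() {
    if (!running) return;  // stop() was called: do not reschedule
    webcam.update();       // draw the latest webcam frame
    const predictions = await model.predict(webcam.canvas);
    if (!running) return;  // stopped while the prediction was in flight
    render(predictions);
    nextFrame(frame);
  }

  nextFrame(frame);
  return { stop() { running = false; } };
}

// Usage, after webcam.setup() and webcam.play():
// const loop = startPredictionLoop(webcam, model, (p) => console.log(p));
// later: loop.stop(); webcam.stop();
```

Calling stop() lets a prediction already in flight finish but skips rendering it, and no further frames are scheduled; call webcam.stop() as well to release the camera.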