Code Monkey home page Code Monkey logo

Comments (2)

xiaohk commented on May 3, 2024

Hello, this is related to #2.

Currently, CNN Explainer only supports the Tiny-VGG architecture described in our manuscript. To use a different CNN architecture, you would need to modify the code. To use another pre-trained Tiny-VGG model, see the following functions in CNN Explainer, which show how we use TensorFlow.js to load the model:

// Example usage: load the converted model, then build the CNN node/link
// structure for one input image.
// NOTE(review): `PUBLIC_URL` sits inside the string literals, so it is NOT
// interpolated here; presumably it is a placeholder replaced at build time
// (or by the reader) — confirm against the project's build configuration.
// `model`, `cnn` and `selectedImage` are assumed to be declared by the
// surrounding code, and these `await`s must run inside an async function
// (or a module that allows top-level await).
model = await loadTrainedModel('PUBLIC_URL/assets/data/model.json');
cnn = await constructCNN(`PUBLIC_URL/assets/img/${selectedImage}`, model);

/**
 * Load a tf.js Layers model from its converted model.json file.
 *
 * @param {string} modelFile Path or URL of the model json file produced by
 *   the tensorflowjs converter (tensorflowjs.py).
 * @returns {Promise<LayersModel>} Whatever tf.loadLayersModel() returns for
 *   that file (a promise resolving to the loaded model).
 */
export const loadTrainedModel = (modelFile) => tf.loadLayersModel(modelFile);

/**
 * Construct a CNN with given extracted outputs from every layer.
 *
 * @param {Tensor[]} allOutputs Output tensor of every model layer, in layer
 *   order (allOutputs[l] belongs to model.layers[l]). constructCNN() passes
 *   squeezed outputs whose 3-D entries are already channel-first, so
 *   allOutputs[l].arraySync()[j] is the output of node j in layer l.
 * @param {Model} model Loaded tf.js model.
 * @param {Tensor} inputImageTensor Loaded input image tensor with the channel
 *   as its last axis (it is transposed with [2, 0, 1] below).
 * @returns {Node[][]} cnn[0] is the input layer (one node per image channel)
 *   and cnn[l + 1] holds the nodes built for model.layers[l].
 */
const constructCNNFromOutputs = (allOutputs, model, inputImageTensor) => {
  const cnn = [];

  // Add the first layer (input layer).
  // Drop the batch axis; the last remaining entry is the channel count.
  const inputShape = model.layers[0].batchInputShape.slice(1);
  // Move channels first so inputImageArray[c] is the 2-D array of channel c.
  const inputImageArray = inputImageTensor.transpose([2, 0, 1]).arraySync();
  const inputLayer = [];
  for (let i = 0; i < inputShape[2]; i++) {
    inputLayer.push(new Node('input', i, nodeType.INPUT, 0, inputImageArray[i]));
  }
  cnn.push(inputLayer);

  let curLayerIndex = 1;
  for (let l = 0; l < model.layers.length; l++) {
    const layer = model.layers[l];
    const preLayer = cnn[curLayerIndex - 1];
    // NOTE(review): tensors created by squeeze()/transpose() in this function
    // are never disposed; on tf.js backends that require manual memory
    // management this may leak — consider tf.tidy() or explicit dispose().
    const outputs = allOutputs[l].squeeze().arraySync();
    const curLayerNodes = [];

    // Identify layer type based on the layer name
    let curLayerType;
    if (layer.name.includes('conv')) {
      curLayerType = nodeType.CONV;
    } else if (layer.name.includes('pool')) {
      curLayerType = nodeType.POOL;
    } else if (layer.name.includes('relu')) {
      curLayerType = nodeType.RELU;
    } else if (layer.name.includes('output')) {
      curLayerType = nodeType.FC;
    } else if (layer.name.includes('flatten')) {
      curLayerType = nodeType.FLATTEN;
    } else {
      console.log('Find unknown type');
    }

    // Construct this layer based on its layer type
    switch (curLayerType) {
      case nodeType.CONV: {
        const biases = layer.bias.val.arraySync();
        // Reorder the kernel to [output_depth, input_depth, height, width]
        const weights = layer.kernel.val.transpose([3, 2, 0, 1]).arraySync();
        for (let i = 0; i < outputs.length; i++) {
          const node = new Node(layer.name, i, curLayerType, biases[i],
            outputs[i]);
          // CONV links carry the kernel slice as their weight; every node
          // connects to every node of the previous layer.
          for (let j = 0; j < preLayer.length; j++) {
            const preNode = preLayer[j];
            const curLink = new Link(preNode, node, weights[i][j]);
            preNode.outputLinks.push(curLink);
            node.inputLinks.push(curLink);
          }
          curLayerNodes.push(node);
        }
        break;
      }
      case nodeType.FC: {
        const biases = layer.bias.val.arraySync();
        // Reorder the kernel to [output_depth, input_depth]
        const weights = layer.kernel.val.transpose([1, 0]).arraySync();
        for (let i = 0; i < outputs.length; i++) {
          const node = new Node(layer.name, i, curLayerType, biases[i],
            outputs[i]);
          // FC links carry a scalar weight; every node connects to every
          // node of the previous layer. The raw (pre-softmax) logit is
          // accumulated here because it is what gets visualized.
          let curLogit = 0;
          for (let j = 0; j < preLayer.length; j++) {
            const preNode = preLayer[j];
            const curLink = new Link(preNode, node, weights[i][j]);
            preNode.outputLinks.push(curLink);
            node.inputLinks.push(curLink);
            curLogit += preNode.output * weights[i][j];
          }
          curLogit += biases[i];
          node.logit = curLogit;
          curLayerNodes.push(node);
        }
        // The logits above needed the previous (flatten) layer in TF index
        // order; now reorder that layer by its real (channel-first) index.
        preLayer.sort((a, b) => a.realIndex - b.realIndex);
        break;
      }
      case nodeType.RELU:
      case nodeType.POOL: {
        // RELU and POOL have no bias nor weight; links are one-to-one.
        const bias = 0;
        const weight = null;
        for (let i = 0; i < outputs.length; i++) {
          const node = new Node(layer.name, i, curLayerType, bias, outputs[i]);
          const preNode = preLayer[i];
          const link = new Link(preNode, node, weight);
          preNode.outputLinks.push(link);
          node.inputLinks.push(link);
          curLayerNodes.push(node);
        }
        break;
      }
      case nodeType.FLATTEN: {
        // Flatten layer has no bias nor weights. Links are multiple-to-one;
        // each link's weight slot stores the (row, column) it reads from in
        // its source node.
        const bias = 0;
        // Each previous node holds one channel as a 2-D [row][column] array.
        // Height and width are read separately so non-square feature maps
        // are mapped correctly.
        const preNodeNum = preLayer.length;
        const preNodeHeight = preLayer[0].output.length;
        const preNodeWidth = preLayer[0].output[0].length;
        for (let i = 0; i < outputs.length; i++) {
          // TF's flatten index i walks row -> column -> channel with the
          // channel varying fastest: i % preNodeNum is the channel and
          // floor(i / preNodeNum) is the row-major pixel position.
          const preNodeIndex = i % preNodeNum;
          const pixelIndex = Math.floor(i / preNodeNum);
          const preNodeRow = Math.floor(pixelIndex / preNodeWidth);
          const preNodeCol = pixelIndex % preNodeWidth;
          // Position of the same entry in channel -> row -> column order.
          const curNodeRealIndex = preNodeIndex * (preNodeHeight * preNodeWidth) +
            preNodeRow * preNodeWidth + preNodeCol;
          const node = new Node(layer.name, i, curLayerType, bias, outputs[i]);
          // TF uses index i for computation, but the display order is
          // curNodeRealIndex. The nodes are re-sorted by realIndex after the
          // output layer has computed its logits.
          node.realIndex = curNodeRealIndex;
          const preNode = preLayer[preNodeIndex];
          const link = new Link(preNode, node, [preNodeRow, preNodeCol]);
          preNode.outputLinks.push(link);
          node.inputLinks.push(link);
          curLayerNodes.push(node);
        }
        // Keep the flatten layer in TF index order for now
        curLayerNodes.sort((a, b) => a.index - b.index);
        break;
      }
      default:
        console.error('Encounter unknown layer type');
        break;
    }

    // Add current layer to the NN
    cnn.push(curLayerNodes);
    curLayerIndex++;
  }
  return cnn;
};

/**
 * Construct a CNN with given model and input.
 *
 * The image is pushed through the model one layer at a time, so each layer's
 * intermediate output can be recorded. The recorded outputs are then turned
 * into the node/link structure by constructCNNFromOutputs().
 *
 * @param {string} inputImageFile filename of input image.
 * @param {Model} model Loaded tf.js model.
 * @returns {Promise<Array>} Resolves to the array built by
 *   constructCNNFromOutputs() for this image.
 */
export const constructCNN = async (inputImageFile, model) => {
  const imageTensor = await getInputImageArray(inputImageFile, true);

  // The layers expect a batch, so wrap the single image in a batch of one.
  let activation = tf.stack([imageTensor]);
  const layerOutputs = [];

  // Apply the layers in sequence: each layer receives the previous layer's
  // output, and every intermediate result is recorded.
  for (const layer of model.layers) {
    activation = layer.apply(activation);
    // squeeze() drops size-1 axes (including the single-element batch axis);
    // 3-D results are stored channel-first.
    const squeezed = activation.squeeze();
    layerOutputs.push(
      squeezed.shape.length === 3 ? squeezed.transpose([2, 0, 1]) : squeezed,
    );
  }

  return constructCNNFromOutputs(layerOutputs, model, imageTensor);
};

Let us know if you have more questions. I will close the issue for now :P

from cnn-explainer.

EricCousineau-TRI commented on May 3, 2024

FYI, here are the docs for converting pre-trained models to TensorFlow.js:
https://www.tensorflow.org/js/tutorials#convert_pretrained_models_to_tensorflowjs
Relates #12.

from cnn-explainer.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.