Comments (7)

MohamedLEGH commented on May 27, 2024

For reference, these are the model parameters before training (if you know how to print all the values instead of "37 more", please tell me):

weight: (1, 57) cpu() float32 hasGradient
[[-0.0309, -0.0797,  0.3092, -0.1011,  0.2231,  0.118 ,  0.1488,  0.0543, -0.5098, -0.4196, -0.1896,  0.0717,  0.0586, -0.0586,  0.0395,  0.0914,  0.3283, -0.1463, -0.3702,  0.1664, ... 37 more],
]

bias: (1) cpu() float32 hasGradient
[0.]

And this is after training:

weight: (1, 57) cpu() float32 hasGradient
[[ -0.3021,  -3.2209,  -0.8564,   0.5633,   0.0698,   0.1584,   1.1403,   0.4022,  -0.3137,  -1.5199,  -0.0627,  -5.3316,  -0.1716,  -0.3107,   0.3686,   1.119 ,   0.8939,  -0.0675,  -8.1434,   0.9069, ... 37 more],
]

bias: (1) cpu() float32 hasGradient
[-9.5678]

I tried using a Predictor to check the predicted values, with this code:

Translator<NDList, NDList> translator = new NoopTranslator();
Predictor<NDList, NDList> predictor = model.newPredictor(translator);

for (Batch b : validationSet.getData(manager)) {
    NDList data = b.getData();
    // NoopTranslator returns the block output unchanged, i.e. the raw Linear output
    NDArray prediction = predictor.predict(data).singletonOrThrow();
    NDArray truelabel = b.getLabels().singletonOrThrow();

    System.out.println("Predicted is: " + prediction.toString());
    System.out.println("True value is: " + truelabel.toString());
    b.close(); // free the NDArrays held by this batch
}

And this is the result (one batch shown; the others are similar):

Predicted is: ND: (10, 1) cpu() float32
[[-10483.418 ],
 [ -1452.9528],
 [ -7738.9502],
 [  -709.6656],
 [ -1241.8141],
 [ -3426.0046],
 [  -188.6846],
 [  -580.9608],
 [   -98.5921],
 [ -1608.1211],
]

True value is: ND: (10, 1) cpu() int32
[[ 1],
 [ 0],
 [ 1],
 [ 0],
 [ 0],
 [ 1],
 [ 1],
 [ 0],
 [ 0],
 [ 1],
]

Then I tried applying a sigmoid to the predicted values, like this:

for(Batch b: validationSet.getData(manager)) {
    NDList data = b.getData();
    NDArray prediction = ((NDList) predictor.predict(data)).singletonOrThrow();
    NDArray prediction_binary = Activation.sigmoid(prediction);
    NDArray truelabel = b.getLabels().singletonOrThrow();

    System.out.println("Predicted is: " + prediction_binary.toString());
    System.out.println("True value is: " + truelabel.toString());
}

and the results look like this:

Predicted is: ND: (10, 1) cpu() float32
[[1.    ],
 [1.    ],
 [1.    ],
 [1.    ],
 [1.    ],
 [1.    ],
 [1.    ],
 [0.0012],
 [1.    ],
 [1.    ],
]

True value is: ND: (10, 1) cpu() int32
[[ 0],
 [ 0],
 [ 0],
 [ 1],
 [ 0],
 [ 0],
 [ 0],
 [ 0],
 [ 1],
 [ 1],
]

Any idea where the problem is?
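A note on the two printouts above: the model is a single Linear unit, and SigmoidBinaryCrossEntropyLoss applies the sigmoid inside the loss (by default it expects logits), so a predictor built with NoopTranslator returns raw logits rather than probabilities. Turning a logit into a class means applying a sigmoid and thresholding at 0.5, which is the same as checking whether the logit is at least 0. The plain-Java sketch below (no DJL; the class and method names are hypothetical) shows that step on three logits copied from the first printout.

```java
public class LogitToLabel {

    // Numerically stable logistic sigmoid: 1 / (1 + e^-z)
    static double sigmoid(double z) {
        if (z >= 0) {
            return 1.0 / (1.0 + Math.exp(-z));
        }
        double e = Math.exp(z);
        return e / (1.0 + e);
    }

    // Class 1 when the predicted probability is at least 0.5, i.e. when the logit is >= 0
    static int toLabel(double logit) {
        return sigmoid(logit) >= 0.5 ? 1 : 0;
    }

    public static void main(String[] args) {
        // Raw logits copied from the first prediction printout
        double[] logits = {-10483.418, -1452.9528, -98.5921};
        for (double z : logits) {
            System.out.printf("logit=%.4f  p=%.4f  label=%d%n", z, sigmoid(z), toLabel(z));
        }
    }
}
```

Logits in the thousands saturate the sigmoid at 0 or 1, which matches the all-or-nothing values in the second printout; in this thread it was feature scaling that brought the outputs back into a usable range.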

from djl.

adinath1233 commented on May 27, 2024

Hey @MohamedLEGH,
I have a question: is NoopTranslator() only for classification, or can it also be used for regression?

Also, I guess this is because of the weights; try standard scaling or min-max scaling on the input data.
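Min-max scaling, one of the two options suggested above, maps each feature column to [0, 1] via (v - min) / (max - min). The sketch below uses plain Java arrays rather than DJL NDArrays so it runs on its own; the class name and the toy values are made up for illustration, and mapping constant columns to 0 is a choice of this sketch, not something from the thread.

```java
import java.util.Arrays;

public class MinMaxScaling {

    // Rescales every column of x to [0, 1] using that column's min and max.
    // Constant columns (max == min) are mapped to 0 to avoid dividing by zero.
    static double[][] minMaxScale(double[][] x) {
        int rows = x.length;
        int cols = x[0].length;
        double[][] out = new double[rows][cols];
        for (int j = 0; j < cols; j++) {
            double min = Double.POSITIVE_INFINITY;
            double max = Double.NEGATIVE_INFINITY;
            for (double[] row : x) {
                min = Math.min(min, row[j]);
                max = Math.max(max, row[j]);
            }
            double range = max - min;
            for (int i = 0; i < rows; i++) {
                out[i][j] = range == 0 ? 0.0 : (x[i][j] - min) / range;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Toy data: two features on very different scales
        double[][] x = {{0, 10}, {5, 1000}, {10, 100}};
        System.out.println(Arrays.deepToString(minMaxScale(x)));
    }
}
```

Note that min-max scaling is sensitive to outliers (one extreme value compresses the rest of the column), which is one reason standardization is often preferred for features with long tails.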

MohamedLEGH commented on May 27, 2024

OK, I have updated the code with the following modifications:

  • The data are normalized (subtract the mean and divide by the standard deviation)
  • BinaryAccuracy is used instead of Accuracy

Here is the updated code:

package machine_learning;

import java.io.IOException;
import java.util.Random;

import tech.tablesaw.api.Table;
import ai.djl.Model;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.Parameter;
import ai.djl.nn.Activation;
import ai.djl.training.dataset.ArrayDataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.loss.Loss;
import ai.djl.training.loss.SigmoidBinaryCrossEntropyLoss;
import ai.djl.training.tracker.Tracker;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.evaluator.BinaryAccuracy;
import ai.djl.training.TrainingResult;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.dataset.Batch;
import ai.djl.training.EasyTrain;
import ai.djl.training.initializer.ConstantInitializer;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.NoopTranslator;
import ai.djl.inference.Predictor;
import ai.djl.util.Pair;

public class LogisticRegression {

    public static void main(String[] args) throws IOException, TranslateException {
        Table spambase = Table.read().csv("spambase.csv");
        Table inputs = spambase.copy().removeColumns("is_spam");
        Table outputs = spambase.copy().retainColumns("is_spam");
        NDManager manager = NDManager.newBaseManager();
        NDArray x = manager.create(inputs.as().floatMatrix());
        NDArray scaled_x = Utils.normalize(x);

        NDArray y = manager.create(outputs.as().intMatrix());
        int batchSize = inputs.rowCount();
        ArrayDataset dataset = Utils.loadArray(scaled_x, y, batchSize, true);
        RandomAccessDataset[] datasets_split = dataset.randomSplit(80, 20);
        ArrayDataset trainingSet = (ArrayDataset) datasets_split[0];
        ArrayDataset testingSet = (ArrayDataset) datasets_split[1];

        Model model = Model.newInstance("logistic");
        SequentialBlock net = new SequentialBlock();
        Linear linearBlock = Linear.builder().optBias(true).setUnits(1).build();
        net.add(linearBlock);
        //net.setInitializer(new ConstantInitializer(0), Parameter.Type.WEIGHT);
        //net.initialize(manager, DataType.FLOAT32, x.getShape());

        model.setBlock(net);
        Loss loss = new SigmoidBinaryCrossEntropyLoss();

        float lr = 0.01f;

        Tracker lrt = Tracker.fixed(lr);
        Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

        TrainingConfig config = new DefaultTrainingConfig(loss)
            .optOptimizer(sgd) // Optimizer
            .optDevices(manager.getEngine().getDevices(0)) // CPU
            .addEvaluator(new BinaryAccuracy()) // Model Accuracy
            .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

        Trainer trainer = model.newTrainer(config);

        trainer.initialize(new Shape(1, inputs.columnCount())); 
        Metrics metrics = new Metrics();
        trainer.setMetrics(metrics);

        int numEpochs = 1000; // only 10 epochs are needed when the weights are initialized to 0
        EasyTrain.fit(trainer, numEpochs, trainingSet, testingSet);
    }
}

And the results are now close to 90% accuracy.

mai 14, 2024 5:17:01 PM ai.djl.training.listener.LoggingTrainingListener onEpoch
INFOS: Train: BinaryAccuracy: 0,91, SigmoidBinaryCrossEntropyLoss: 0,29
mai 14, 2024 5:17:01 PM ai.djl.training.listener.LoggingTrainingListener onEpoch
INFOS: Validate: BinaryAccuracy: 0,89, SigmoidBinaryCrossEntropyLoss: 0,31

If the weights are initialized to 0, convergence is much faster. I have solved my issue, but I think the Accuracy() metric should handle the binary case instead of requiring a separate BinaryAccuracy() metric.
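Both observations above have a simple explanation for a single-unit model. Logistic regression with sigmoid binary cross-entropy is convex in its parameters, and with only one output unit there is no symmetry to break, so starting from all-zero weights is harmless. And with one logit per row, a multi-class accuracy that takes the argmax over the class axis would always pick index 0, so a binary metric has to threshold the logit instead. The plain-Java sketch below roughly mirrors the DJL setup (full-batch updates, since batchSize equals the row count) and then computes binary accuracy by thresholding at logit 0. It is not DJL's implementation; the class, methods and toy data are hypothetical.

```java
public class ZeroInitLogReg {

    // Numerically stable logistic sigmoid
    static double sigmoid(double z) {
        if (z >= 0) {
            return 1.0 / (1.0 + Math.exp(-z));
        }
        double e = Math.exp(z);
        return e / (1.0 + e);
    }

    // Linear score w . x + b; the last entry of w holds the bias
    static double logit(double[] w, double[] row) {
        double z = w[row.length];
        for (int j = 0; j < row.length; j++) {
            z += w[j] * row[j];
        }
        return z;
    }

    // Full-batch gradient descent on the mean sigmoid binary cross-entropy.
    // d(loss)/d(w_j) = mean_i (sigmoid(z_i) - y_i) * x_ij; the bias gradient drops the x_ij factor.
    static double[] train(double[][] x, int[] y, double lr, int epochs) {
        int n = x.length;
        int d = x[0].length;
        double[] w = new double[d + 1]; // Java fills new arrays with 0: zero-initialized weights and bias
        for (int epoch = 0; epoch < epochs; epoch++) {
            double[] grad = new double[d + 1];
            for (int i = 0; i < n; i++) {
                double err = sigmoid(logit(w, x[i])) - y[i];
                for (int j = 0; j < d; j++) {
                    grad[j] += err * x[i][j];
                }
                grad[d] += err;
            }
            for (int j = 0; j <= d; j++) {
                w[j] -= lr * grad[j] / n; // one update per epoch, like a single full-size batch
            }
        }
        return w;
    }

    // Binary accuracy: predict 1 when the logit is >= 0 (probability >= 0.5)
    static double binaryAccuracy(double[] w, double[][] x, int[] y) {
        int correct = 0;
        for (int i = 0; i < x.length; i++) {
            int pred = logit(w, x[i]) >= 0 ? 1 : 0;
            if (pred == y[i]) {
                correct++;
            }
        }
        return (double) correct / x.length;
    }

    public static void main(String[] args) {
        // Toy, already-standardized data (made up): the label is 1 when the feature is positive
        double[][] x = {{-1.5}, {-0.5}, {0.5}, {1.5}};
        int[] y = {0, 0, 1, 1};
        double[] w = train(x, y, 0.1, 100);
        System.out.println("weight = " + w[0] + ", bias = " + w[1]);
        System.out.println("binary accuracy = " + binaryAccuracy(w, x, y));
    }
}
```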

MohamedLEGH commented on May 27, 2024

> Hey @MohamedLEGH, I have a question: is NoopTranslator() only for classification, or can it also be used for regression?
>
> Also, I guess this is because of the weights; try standard scaling or min-max scaling on the input data.

I suppose you can also use it for regression. I can't say for sure, since I'm only doing classification right now, but it should work.

adinath1233 commented on May 27, 2024

> > Hey @MohamedLEGH, I have a question: is NoopTranslator() only for classification, or can it also be used for regression?
> >
> > Also, I guess this is because of the weights; try standard scaling or min-max scaling on the input data.
>
> I suppose you can also use it for regression. I can't say for sure, since I'm only doing classification right now, but it should work.

Yes, it worked.

adinath1233 commented on May 27, 2024

> the updated code:

Hey @MohamedLEGH,
You are using Utils.normalize; is Utils your own custom class? How are you performing the normalization? Could you share the code snippet? I am trying to normalize the data for a classification task, but I am not getting proper output.

MohamedLEGH commented on May 27, 2024

> > the updated code:
>
> Hey @MohamedLEGH, You are using Utils.normalize; is Utils your own custom class? How are you performing the normalization? Could you share the code snippet? I am trying to normalize the data for a classification task, but I am not getting proper output.

Here is my Utils.java code; hope it helps:

package machine_learning;

import ai.djl.training.dataset.Record;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.ArrayDataset;

public class Utils {
    public static ArrayDataset loadArray(NDArray features, NDArray labels, int batchSize, boolean shuffle) {
        return new ArrayDataset.Builder()
                    .setData(features) // set the features
                    .optLabels(labels) // set the labels
                    .setSampling(batchSize, shuffle) // set the batch size and random sampling
                    .build();
    }

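    // Column-wise mean (axis 0 = over the rows)
    public static NDArray mean(NDArray X) {
        return X.mean(new int[]{0});
    }

    // Column-wise population standard deviation (divides by N, not N - 1)
    public static NDArray std(NDArray X) {
        NDArray mean = mean(X);
        NDArray squaredDiff = X.sub(mean).pow(2);
        NDArray variance = squaredDiff.mean(new int[]{0});
        return variance.sqrt();
    }

    // Standardizes each column to zero mean and unit variance (z-score)
    public static NDArray normalize(NDArray X) {
        return X.sub(mean(X)).div(std(X));
    }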
}
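Two details about Utils.normalize as it is called in LogisticRegression: the mean and standard deviation are computed over the full dataset before randomSplit, so the validation rows influence the statistics, and a constant column would give a zero standard deviation and therefore NaN values. A common pattern is to fit the statistics on the training rows only and reuse them for the validation rows. The sketch below shows that pattern with plain Java arrays instead of DJL NDArrays; the class name and toy data are hypothetical, and replacing a zero standard deviation with 1 is only one possible guard.

```java
import java.util.Arrays;

public class TrainStatsScaler {
    final double[] mean;
    final double[] std;

    // Fits per-column mean and population standard deviation (divides by N, as Utils.std does)
    TrainStatsScaler(double[][] train) {
        int n = train.length;
        int d = train[0].length;
        mean = new double[d];
        std = new double[d];
        for (double[] row : train) {
            for (int j = 0; j < d; j++) {
                mean[j] += row[j];
            }
        }
        for (int j = 0; j < d; j++) {
            mean[j] /= n;
        }
        for (double[] row : train) {
            for (int j = 0; j < d; j++) {
                double diff = row[j] - mean[j];
                std[j] += diff * diff;
            }
        }
        for (int j = 0; j < d; j++) {
            std[j] = Math.sqrt(std[j] / n);
            if (std[j] == 0.0) {
                std[j] = 1.0; // constant column: only center it instead of computing 0 / 0 (NaN)
            }
        }
    }

    // Applies the training-set statistics to any split (training or validation)
    double[][] transform(double[][] x) {
        double[][] out = new double[x.length][];
        for (int i = 0; i < x.length; i++) {
            out[i] = new double[x[i].length];
            for (int j = 0; j < x[i].length; j++) {
                out[i][j] = (x[i][j] - mean[j]) / std[j];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] train = {{1, 7}, {3, 7}}; // toy data; the second column is constant
        double[][] validation = {{5, 7}};
        TrainStatsScaler scaler = new TrainStatsScaler(train);
        System.out.println(Arrays.deepToString(scaler.transform(train)));      // [[-1.0, 0.0], [1.0, 0.0]]
        System.out.println(Arrays.deepToString(scaler.transform(validation))); // [[3.0, 0.0]]
    }
}
```

Applied to the DJL code, this would mean splitting first, computing the statistics from the training features only, and normalizing both splits with those statistics before building the datasets.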
