Code Monkey home page Code Monkey logo

Comments (2)

Craigacp avatar Craigacp commented on June 16, 2024

Covariance matrices need to be positive definite, and one of the necessary properties is symmetry, so you need to at least make sure the matrix is symmetric. I'm assuming you're getting the error from distribution2, whose covariance matrix, unlike distribution1's, is not made symmetric.

from tribuo.

Mohammed-Ryiad-Eiadeh avatar Mohammed-Ryiad-Eiadeh commented on June 16, 2024

Thanks a lot, now it works correctly.

package GeneratingData.org;

import org.tribuo.math.distributions.MultivariateNormalDistribution;
import tech.tablesaw.api.DoubleColumn;
import tech.tablesaw.api.Table;

import java.nio.file.Path;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Generates a synthetic two-class dataset for binary classification and writes it to a CSV file.
 *
 * <p>Each class is sampled from its own multivariate normal distribution. The class-0 mean
 * vector is drawn at random and the class-1 mean vector is its negation. Every class gets an
 * independently generated symmetric positive definite covariance matrix.
 */
public class GenerateDataSet {
    private final Path path;
    private final int numOfInstances;
    private final int numOfVariables;
    private final double priorC1;
    private final Random random;

    /**
     * Creates a generator for binary data.
     *
     * @param path The file the generated data is written to (passed to the CSV writer as-is)
     * @param numOfInstances The total number of samples across both classes
     * @param numOfVariables The number of features of each sample
     * @param priorC1 The prior for class C1, i.e. the fraction of samples with label 0; in [0, 1]
     * @throws NullPointerException if {@code path} is null
     * @throws IllegalArgumentException if a count is negative or the prior is outside [0, 1]
     */
    public GenerateDataSet(Path path, int numOfInstances, int numOfVariables, double priorC1) {
        this.path = Objects.requireNonNull(path, "path");
        if (numOfInstances < 0) {
            throw new IllegalArgumentException("numOfInstances must be >= 0, got " + numOfInstances);
        }
        if (numOfVariables < 0) {
            throw new IllegalArgumentException("numOfVariables must be >= 0, got " + numOfVariables);
        }
        // Written as a negated conjunction so that NaN is rejected as well.
        if (!(priorC1 >= 0.0 && priorC1 <= 1.0)) {
            throw new IllegalArgumentException("priorC1 must be in [0, 1], got " + priorC1);
        }
        this.numOfInstances = numOfInstances;
        this.numOfVariables = numOfVariables;
        this.priorC1 = priorC1;
        this.random = new Random();
    }

    /**
     * Generates the normally distributed data for the two classes and stores it as CSV.
     *
     * <p>The output has one column per feature, named {@code F0 .. F(n-1)}, followed by a
     * {@code Label} column holding 0 for the first class and 1 for the second. Rows of the
     * first class come first. The generated covariance matrices are printed to standard output.
     */
    public void generateDataSet() {
        // Split the total so that the two class sizes always add up to numOfInstances;
        // truncating both products independently could leave unfilled all-zero rows.
        int numOfSamplesInC1 = (int) (numOfInstances * priorC1);
        int numOfSamplesInC2 = numOfInstances - numOfSamplesInC1;

        // The vectors of mean values: integers in [1, 4] for the first class,
        // mirrored around the origin for the second class.
        double[] meanVecC1 = new double[numOfVariables];
        double[] meanVecC2 = new double[numOfVariables];
        for (int i = 0; i < numOfVariables; i++) {
            meanVecC1[i] = 1 + random.nextInt(4);
            meanVecC2[i] = -1 * meanVecC1[i];
        }

        // The covariance matrices
        double[][] coVarMatC1 = generateSymmetricPositiveDefiniteMatrix(numOfVariables);
        Arrays.stream(coVarMatC1).forEach(i -> System.out.println(Arrays.toString(i)));

        // NOTE(review): both distributions use the same seed, so they presumably consume the
        // same underlying random stream - confirm this is intended.
        MultivariateNormalDistribution distribution1 = new MultivariateNormalDistribution(meanVecC1, coVarMatC1, 12345);

        System.out.println();

        double[][] coVarMatC2 = generateSymmetricPositiveDefiniteMatrix(numOfVariables);
        Arrays.stream(coVarMatC2).forEach(i -> System.out.println(Arrays.toString(i)));

        MultivariateNormalDistribution distribution2 = new MultivariateNormalDistribution(meanVecC2, coVarMatC2, 12345);

        // The matrix to hold data for both classes: features first, label in the last column
        double[][] columns = new double[numOfInstances][numOfVariables + 1];

        // Class 0 occupies rows [0, numOfSamplesInC1); class 1 occupies the remaining rows.
        fillClassRows(columns, 0, numOfSamplesInC1, distribution1, 0);
        fillClassRows(columns, numOfSamplesInC1, numOfSamplesInC2, distribution2, 1);

        // Transpose the row-major data into one DoubleColumn per feature plus the label column
        DoubleColumn[] tableColumns = new DoubleColumn[numOfVariables + 1];
        for (int c = 0; c <= numOfVariables; c++) {
            double[] values = new double[numOfInstances];
            for (int r = 0; r < numOfInstances; r++) {
                values[r] = columns[r][c];
            }
            String name = (c < numOfVariables) ? "F" + c : "Label";
            tableColumns[c] = DoubleColumn.create(name, values);
        }

        // Store this table as CSV file at the configured path
        Table table = Table.create("Data", tableColumns);
        table.write().csv(path.toString());
    }

    /**
     * Fills a contiguous range of rows with samples from one class.
     *
     * <p>Samples are drawn sequentially so that the row a given draw lands in does not depend
     * on thread scheduling.
     *
     * @param data The dataset matrix; each row has {@code numOfVariables + 1} entries
     * @param firstRow The index of the first row to fill
     * @param count The number of rows to fill
     * @param distribution The distribution to sample the feature values from
     * @param label The class label stored in the last column of every filled row
     */
    private void fillClassRows(double[][] data, int firstRow, int count,
                               MultivariateNormalDistribution distribution, double label) {
        for (int r = firstRow; r < firstRow + count; r++) {
            double[] sample = distribution.sampleVector().toArray();
            System.arraycopy(sample, 0, data[r], 0, numOfVariables);
            data[r][numOfVariables] = label;
        }
    }

    /**
     * Generates a symmetric positive definite matrix.
     *
     * <p>The off-diagonal entries are random integers in [1, 10], mirrored across the diagonal.
     * Each diagonal entry is then raised above the sum of all off-diagonal magnitudes, which
     * makes the matrix strictly diagonally dominant with a positive diagonal and therefore
     * positive definite.
     *
     * @param size The number of the features we want to generate
     * @return A symmetric positive definite matrix of dimension {@code size x size}
     */
    private double[][] generateSymmetricPositiveDefiniteMatrix(int size) {
        // Generate symmetric matrix: draw the upper triangle (with diagonal) and mirror it
        double[][] matrix = new double[size][size];
        for (int i = 0; i < size; i++) {
            for (int j = i; j < size; j++) {
                double value = 1 + random.nextInt(10);
                matrix[i][j] = value;
                matrix[j][i] = value;
            }
        }
        // Calculate the sum of the absolute of all non-diagonal values (column 0 included)
        double sum = 0;
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                if (i != j) {
                    sum += Math.abs(matrix[i][j]);
                }
            }
        }
        // Make the matrix positive definite by dominating every row with its diagonal entry
        for (int i = 0; i < size; i++) {
            matrix[i][i] = matrix[i][i] + 0.1 + sum;
        }
        return matrix;
    }
}

from tribuo.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.