Stock price forecasting with LSTM in Java



  • Well, I liked that one...



  • @omggg

    You could also just ask in the Java forum directly.

    Nah. Somehow you always end up getting kicked out over there.

    Oh well. There are probably just too few algorithm theorists in this forum.


  • Banned

    This post was deleted!

  • Banned

    @oxide

    What is this supposed to be? Ciao.


  • Banned

    Could you take a look at what is wrong or illogical here? It only trains for about 15 epochs instead of 100, so of course the predictions are garbage...

    I took most of the code from elsewhere and adapted it, but didn't refactor it.

    Gradle:

    dependencies {
        // https://mvnrepository.com/artifact/org.json/json
        implementation 'org.json:json:20240205'
        // https://mvnrepository.com/artifact/com.konghq/unirest-java
        implementation 'com.konghq:unirest-java:3.14.5'
    
        // Stock predictions:
        // https://mvnrepository.com/artifact/org.nd4j/nd4j-api
        implementation 'org.nd4j:nd4j-api:0.9.1'
        // https://mvnrepository.com/artifact/org.nd4j/nd4j-native
        implementation 'org.nd4j:nd4j-native:0.9.1'
        // https://mvnrepository.com/artifact/org.nd4j/nd4j-native-platform
        implementation 'org.nd4j:nd4j-native-platform:0.9.1'
        // https://mvnrepository.com/artifact/org.deeplearning4j/deeplearning4j-core
        implementation 'org.deeplearning4j:deeplearning4j-core:0.9.1'
        // https://mvnrepository.com/artifact/org.datavec/datavec-api
        implementation 'org.datavec:datavec-api:0.9.1'
        // https://mvnrepository.com/artifact/org.datavec/datavec-dataframe
        implementation 'org.datavec:datavec-dataframe:0.9.1'
    
        // https://mvnrepository.com/artifact/org.slf4j/slf4j-simple
        implementation 'org.slf4j:slf4j-simple:2.0.12'
        // https://mvnrepository.com/artifact/org.jfree/jfreechart
        implementation 'org.jfree:jfreechart:1.5.4'
    }
    

    Classes (in structural order):

    package history;
    
    import kong.unirest.Unirest;
    import org.json.JSONArray;
    import org.json.JSONObject;
    
    public class History {
      public static JSONArray getHistory(String code, long days) {
        long mod = 1000 * 60;
        long time = System.currentTimeMillis() / mod * mod;
        long start = time - (days * 24 * 60 * 60 * 1000);
        JSONObject jo = new JSONObject();
        jo.put("currency", "USD");
        jo.put("code", code);
        jo.put("meta", false);
        jo.put("start", start);
        jo.put("end", time);
        String body =
            Unirest.post("https://api.livecoinwatch.com/coins/single/history")
                .header("content-type", "application/json")
                .header("x-api-key", "insert your key here")
                .body(jo.toString())
                .asString()
                .getBody();
        return new JSONObject(body).getJSONArray("history");
      }
    }
    
    
    package model;
    
    import org.deeplearning4j.nn.api.OptimizationAlgorithm;
    import org.deeplearning4j.nn.conf.BackpropType;
    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.Updater;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.GravesLSTM;
    import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.nn.weights.WeightInit;
    import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.lossfunctions.LossFunctions;
    
    /**
     * Created by zhanghao on 26/7/17.
     *
     * @author ZHANG HAO
     */
    public class RecurrentNets {
    
      private static final double learningRate = 0.05;
      private static final int iterations = 1;
      private static final int seed = 12345;
    
      private static final int lstmLayer1Size = 256;
      private static final int lstmLayer2Size = 256;
      private static final int denseLayerSize = 32;
      private static final double dropoutRatio = 0.2;
      private static final int truncatedBPTTLength = 22;
    
      public static MultiLayerNetwork buildLstmNetworks(int nIn, int nOut) {
        MultiLayerConfiguration conf =
            new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .learningRate(learningRate)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .weightInit(WeightInit.XAVIER)
                .updater(Updater.RMSPROP)
                .regularization(true)
                .l2(1e-4)
                .list()
                .layer(
                    0,
                    new GravesLSTM.Builder()
                        .nIn(nIn)
                        .nOut(lstmLayer1Size)
                        .activation(Activation.TANH)
                        .gateActivationFunction(Activation.HARDSIGMOID)
                        .dropOut(dropoutRatio)
                        .build())
                .layer(
                    1,
                    new GravesLSTM.Builder()
                        .nIn(lstmLayer1Size)
                        .nOut(lstmLayer2Size)
                        .activation(Activation.TANH)
                        .gateActivationFunction(Activation.HARDSIGMOID)
                        .dropOut(dropoutRatio)
                        .build())
                .layer(
                    2,
                    new DenseLayer.Builder()
                        .nIn(lstmLayer2Size)
                        .nOut(denseLayerSize)
                        .activation(Activation.RELU)
                        .build())
                .layer(
                    3,
                    new RnnOutputLayer.Builder()
                        .nIn(denseLayerSize)
                        .nOut(nOut)
                        .activation(Activation.IDENTITY)
                        .lossFunction(LossFunctions.LossFunction.MSE)
                        .build())
                .backpropType(BackpropType.TruncatedBPTT)
                .tBPTTForwardLength(truncatedBPTTLength)
                .tBPTTBackwardLength(truncatedBPTTLength)
                .pretrain(false)
                .backprop(true)
                .build();
    
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(100));
        return net;
      }
    }
    
    
    package predict;
    
    import java.io.File;
    import java.io.IOException;
    import java.util.List;
    import java.util.NoSuchElementException;
    import model.RecurrentNets;
    import org.apache.commons.lang3.tuple.Pair;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.util.ModelSerializer;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    import representation.PriceCategory;
    import representation.StockDataSetIterator;
    import utils.PlotUtil;
    
    /**
     * Created by zhanghao on 26/7/17. Modified by zhanghao on 28/9/17.
     *
     * @author ZHANG HAO
     */
    public class StockPricePrediction {
    
      private static final Logger log = LoggerFactory.getLogger(StockPricePrediction.class);
    
      private static final int exampleLength = 15; // time series length
    
      public static void main(String[] args) throws IOException {
        String file = "dummy-prices-split-adjusted.csv";
        String symbol = "SOL"; // stock name
        int batchSize = 64; // mini-batch size
        double splitRatio = 0.8; // 80% for training, 20% for testing
        int epochs = 100; // training epochs
    
        log.info("Create dataSet iterator...");
        PriceCategory category = PriceCategory.CLOSE; // CLOSE: predict close price
        StockDataSetIterator iterator =
            new StockDataSetIterator(file, symbol, batchSize, exampleLength, splitRatio, category);
        log.info("Load test dataset...");
        List<Pair<INDArray, INDArray>> test = iterator.getTestDataSet();
    
        log.info("Build lstm networks...");
        MultiLayerNetwork net =
            RecurrentNets.buildLstmNetworks(iterator.inputColumns(), iterator.totalOutcomes());
    
        log.info("Training...");
        for (int i = 0; i < epochs; i++) {
          while (iterator.hasNext()) net.fit(iterator.next()); // fit model using mini-batch data
          iterator.reset(); // reset iterator
          net.rnnClearPreviousState(); // clear previous state
        }
    
        log.info("Saving model...");
        File locationToSave =
            new File("StockPriceLSTM_".concat(String.valueOf(category)).concat(".zip"));
        // saveUpdater: i.e., the state for Momentum, RMSProp, Adagrad etc. Save this to train your
        // network more in the future
        ModelSerializer.writeModel(net, locationToSave, true);
    
        log.info("Load model...");
        net = ModelSerializer.restoreMultiLayerNetwork(locationToSave);
    
        log.info("Testing...");
        if (category.equals(PriceCategory.ALL)) {
          INDArray max = Nd4j.create(iterator.getMaxArray());
          INDArray min = Nd4j.create(iterator.getMinArray());
          predictAllCategories(net, test, max, min);
        } else {
          double max = iterator.getMaxNum(category);
          double min = iterator.getMinNum(category);
          predictPriceOneAhead(net, test, max, min, category);
        }
        log.info("Done...");
      }
    
      /** Predict one feature of a stock one-day ahead */
      private static void predictPriceOneAhead(
          MultiLayerNetwork net,
          List<Pair<INDArray, INDArray>> testData,
          double max,
          double min,
          PriceCategory category) {
        double[] predicts = new double[testData.size()];
        double[] actuals = new double[testData.size()];
        for (int i = 0; i < testData.size(); i++) {
          predicts[i] =
              net.rnnTimeStep(testData.get(i).getKey()).getDouble(exampleLength - 1) * (max - min)
                  + min;
          actuals[i] = testData.get(i).getValue().getDouble(0);
        }
        log.info("Print out Predictions and Actual Values...");
        log.info("Predict,Actual");
        for (int i = 0; i < predicts.length; i++) log.info(predicts[i] + "," + actuals[i]);
        log.info("Plot...");
        PlotUtil.plot(predicts, actuals, String.valueOf(category));
      }
    
      private static void predictPriceMultiple(
          MultiLayerNetwork net, List<Pair<INDArray, INDArray>> testData, double max, double min) {
        // TODO
      }
    
      /**
       * Predict all the features (open, close, low, high prices and volume) of a stock one-day ahead
       */
      private static void predictAllCategories(
          MultiLayerNetwork net, List<Pair<INDArray, INDArray>> testData, INDArray max, INDArray min) {
        INDArray[] predicts = new INDArray[testData.size()];
        INDArray[] actuals = new INDArray[testData.size()];
        for (int i = 0; i < testData.size(); i++) {
          predicts[i] =
              net.rnnTimeStep(testData.get(i).getKey())
                  .getRow(exampleLength - 1)
                  .mul(max.sub(min))
                  .add(min);
          actuals[i] = testData.get(i).getValue();
        }
        log.info("Print out Predictions and Actual Values...");
        log.info("Predict\tActual");
        for (int i = 0; i < predicts.length; i++) log.info(predicts[i] + "\t" + actuals[i]);
        log.info("Plot...");
        for (int n = 0; n < 5; n++) {
          double[] pred = new double[predicts.length];
          double[] actu = new double[actuals.length];
          for (int i = 0; i < predicts.length; i++) {
            pred[i] = predicts[i].getDouble(n);
            actu[i] = actuals[i].getDouble(n);
          }
          String name;
          switch (n) {
            case 0:
              name = "Stock OPEN Price";
              break;
            case 1:
              name = "Stock CLOSE Price";
              break;
            case 2:
              name = "Stock LOW Price";
              break;
            case 3:
              name = "Stock HIGH Price";
              break;
            case 4:
              name = "Stock VOLUME Amount";
              break;
            default:
              throw new NoSuchElementException();
          }
          PlotUtil.plot(pred, actu, name);
        }
      }
    }
    
    
    package representation;
    
    /**
     * Created by zhanghao on 28/9/17.
     *
     * @author ZHANG HAO
     */
    public enum PriceCategory {
      OPEN,
      CLOSE,
      LOW,
      HIGH,
      VOLUME,
      ALL
    }
    
    
    package representation;
    
    /**
     * Created by zhanghao on 26/7/17.
     *
     * @author ZHANG HAO
     */
    public class StockData {
      private String date; // date
      private String symbol; // stock name
    
      private double open; // open price
      private double close; // close price
      private double low; // low price
      private double high; // high price
      private double volume; // volume
    
      public StockData() {}
    
      public StockData(
          String date,
          String symbol,
          double open,
          double close,
          double low,
          double high,
          double volume) {
        this.date = date;
        this.symbol = symbol;
        this.open = open;
        this.close = close;
        this.low = low;
        this.high = high;
        this.volume = volume;
      }
    
      public String getDate() {
        return date;
      }
    
      public void setDate(String date) {
        this.date = date;
      }
    
      public String getSymbol() {
        return symbol;
      }
    
      public void setSymbol(String symbol) {
        this.symbol = symbol;
      }
    
      public double getOpen() {
        return open;
      }
    
      public void setOpen(double open) {
        this.open = open;
      }
    
      public double getClose() {
        return close;
      }
    
      public void setClose(double close) {
        this.close = close;
      }
    
      public double getLow() {
        return low;
      }
    
      public void setLow(double low) {
        this.low = low;
      }
    
      public double getHigh() {
        return high;
      }
    
      public void setHigh(double high) {
        this.high = high;
      }
    
      public double getVolume() {
        return volume;
      }
    
      public void setVolume(double volume) {
        this.volume = volume;
      }
    }
    
    
    package representation;
    
    import static history.History.getHistory;
    
    import com.google.common.collect.ImmutableMap;
    import java.util.*;
    import org.apache.commons.lang3.tuple.Pair;
    import org.json.JSONArray;
    import org.json.JSONObject;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
    import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
    import org.nd4j.linalg.factory.Nd4j;
    
    /**
     * Created by zhanghao on 26/7/17. Modified by zhanghao on 28/9/17.
     *
     * @author ZHANG HAO
     */
    public class StockDataSetIterator implements DataSetIterator {
    
      /** category and its index */
      private final Map<PriceCategory, Integer> featureMapIndex =
          ImmutableMap.of(
              PriceCategory.OPEN,
              0,
              PriceCategory.CLOSE,
              1,
              PriceCategory.LOW,
              2,
              PriceCategory.HIGH,
              3,
              PriceCategory.VOLUME,
              4);
    
      private final int VECTOR_SIZE = 5; // number of features for a stock data
      private int miniBatchSize; // mini-batch size
      private int exampleLength; // default 22, say, 22 working days per month
      private int predictLength = 1; // default 1, say, one day ahead prediction
    
      /** minimal values of each feature in stock dataset */
      private double[] minArray = new double[VECTOR_SIZE];
    
      /** maximal values of each feature in stock dataset */
      private double[] maxArray = new double[VECTOR_SIZE];
    
      /** feature to be selected as a training target */
      private PriceCategory category;
    
      /** mini-batch offset */
      private LinkedList<Integer> exampleStartOffsets = new LinkedList<>();
    
      /** stock dataset for training */
      private List<StockData> train;
    
      /** adjusted stock dataset for testing */
      private List<Pair<INDArray, INDArray>> test;
    
      public StockDataSetIterator(
          String filename,
          String symbol,
          int miniBatchSize,
          int exampleLength,
          double splitRatio,
          PriceCategory category) {
        this.miniBatchSize = miniBatchSize;
        this.exampleLength = exampleLength;
        this.category = category;
        List<StockData> stockDataList = readStockDataFromFile(filename, symbol);
    
        // something probably goes wrong here: start
        int split = (int) Math.round(exampleLength * splitRatio);
        train = stockDataList.subList(stockDataList.size() - split, stockDataList.size());
        System.out.println("train.size() = " + train.size());
        test = generateTestDataSet(stockDataList);
        System.out.println("test.size() = " + test.size());
        initializeOffsets();
        // something probably goes wrong here: end
      }
    
      /** initialize the mini-batch offsets */
      private void initializeOffsets() {
        exampleStartOffsets.clear();
        int window = exampleLength + predictLength;
        for (int i = 0; i < train.size() - window; i++) {
          exampleStartOffsets.add(i);
        }
      }
    
      public List<Pair<INDArray, INDArray>> getTestDataSet() {
        return test;
      }
    
      public double[] getMaxArray() {
        return maxArray;
      }
    
      public double[] getMinArray() {
        return minArray;
      }
    
      public double getMaxNum(PriceCategory category) {
        return maxArray[featureMapIndex.get(category)];
      }
    
      public double getMinNum(PriceCategory category) {
        return minArray[featureMapIndex.get(category)];
      }
    
      @Override
      public DataSet next(int num) {
        if (exampleStartOffsets.size() == 0) throw new NoSuchElementException();
        int actualMiniBatchSize = Math.min(num, exampleStartOffsets.size());
        INDArray input = Nd4j.create(new int[] {actualMiniBatchSize, VECTOR_SIZE, exampleLength}, 'f');
        INDArray label;
        if (category.equals(PriceCategory.ALL))
          label = Nd4j.create(new int[] {actualMiniBatchSize, VECTOR_SIZE, exampleLength}, 'f');
        else label = Nd4j.create(new int[] {actualMiniBatchSize, predictLength, exampleLength}, 'f');
        for (int index = 0; index < actualMiniBatchSize; index++) {
          int startIdx = exampleStartOffsets.removeFirst();
          int endIdx = startIdx + exampleLength;
          StockData curData = train.get(startIdx);
          StockData nextData;
          for (int i = startIdx; i < endIdx; i++) {
            int c = i - startIdx;
            input.putScalar(
                new int[] {index, 0, c},
                (curData.getOpen() - minArray[0]) / (maxArray[0] - minArray[0]));
            input.putScalar(
                new int[] {index, 1, c},
                (curData.getClose() - minArray[1]) / (maxArray[1] - minArray[1]));
            input.putScalar(
                new int[] {index, 2, c},
                (curData.getLow() - minArray[2]) / (maxArray[2] - minArray[2]));
            input.putScalar(
                new int[] {index, 3, c},
                (curData.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]));
            input.putScalar(
                new int[] {index, 4, c},
                (curData.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]));
            nextData = train.get(i + 1);
            if (category.equals(PriceCategory.ALL)) {
              label.putScalar(
                  new int[] {index, 0, c},
                  (nextData.getOpen() - minArray[0]) / (maxArray[0] - minArray[0]));
              label.putScalar(
                  new int[] {index, 1, c},
                  (nextData.getClose() - minArray[1]) / (maxArray[1] - minArray[1]));
              label.putScalar(
                  new int[] {index, 2, c},
                  (nextData.getLow() - minArray[2]) / (maxArray[2] - minArray[2]));
              label.putScalar(
                  new int[] {index, 3, c},
                  (nextData.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]));
              label.putScalar(
                  new int[] {index, 4, c},
                  (nextData.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]));
            } else {
              label.putScalar(new int[] {index, 0, c}, feedLabel(nextData));
            }
            curData = nextData;
          }
          if (exampleStartOffsets.size() == 0) break;
        }
        return new DataSet(input, label);
      }
    
      private double feedLabel(StockData data) {
        double value;
        switch (category) {
          case OPEN:
            value = (data.getOpen() - minArray[0]) / (maxArray[0] - minArray[0]);
            break;
          case CLOSE:
            value = (data.getClose() - minArray[1]) / (maxArray[1] - minArray[1]);
            break;
          case LOW:
            value = (data.getLow() - minArray[2]) / (maxArray[2] - minArray[2]);
            break;
          case HIGH:
            value = (data.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]);
            break;
          case VOLUME:
            value = (data.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]);
            break;
          default:
            throw new NoSuchElementException();
        }
        return value;
      }
    
      @Override
      public int totalExamples() {
        return train.size() - exampleLength - predictLength;
      }
    
      @Override
      public int inputColumns() {
        return VECTOR_SIZE;
      }
    
      @Override
      public int totalOutcomes() {
        if (this.category.equals(PriceCategory.ALL)) return VECTOR_SIZE;
        else return predictLength;
      }
    
      @Override
      public boolean resetSupported() {
        return true; // reset() is implemented via initializeOffsets()
      }
    
      @Override
      public boolean asyncSupported() {
        return false;
      }
    
      @Override
      public void reset() {
        initializeOffsets();
      }
    
      @Override
      public int batch() {
        return miniBatchSize;
      }
    
      @Override
      public int cursor() {
        return totalExamples() - exampleStartOffsets.size();
      }
    
      @Override
      public int numExamples() {
        return totalExamples();
      }
    
      @Override
      public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
        throw new UnsupportedOperationException("Not Implemented");
      }
    
      @Override
      public DataSetPreProcessor getPreProcessor() {
        throw new UnsupportedOperationException("Not Implemented");
      }
    
      @Override
      public List<String> getLabels() {
        throw new UnsupportedOperationException("Not Implemented");
      }
    
      @Override
      public boolean hasNext() {
        return exampleStartOffsets.size() > 0;
      }
    
      @Override
      public DataSet next() {
        return next(miniBatchSize);
      }
    
      private List<Pair<INDArray, INDArray>> generateTestDataSet(List<StockData> stockDataList) {
        int window = exampleLength + predictLength;
        List<Pair<INDArray, INDArray>> test = new ArrayList<>();
        for (int i = 0; i < stockDataList.size() - window; i++) {
          INDArray input = Nd4j.create(new int[] {exampleLength, VECTOR_SIZE}, 'f');
          for (int j = i; j < i + exampleLength; j++) {
            StockData stock = stockDataList.get(j);
            input.putScalar(
                new int[] {j - i, 0}, (stock.getOpen() - minArray[0]) / (maxArray[0] - minArray[0]));
            input.putScalar(
                new int[] {j - i, 1}, (stock.getClose() - minArray[1]) / (maxArray[1] - minArray[1]));
            input.putScalar(
                new int[] {j - i, 2}, (stock.getLow() - minArray[2]) / (maxArray[2] - minArray[2]));
            input.putScalar(
                new int[] {j - i, 3}, (stock.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]));
            input.putScalar(
                new int[] {j - i, 4}, (stock.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]));
          }
          StockData stock = stockDataList.get(i + exampleLength);
          INDArray label;
          if (category.equals(PriceCategory.ALL)) {
            label =
                Nd4j.create(new int[] {VECTOR_SIZE}, 'f'); // ordering is set as 'f', faster construct
            label.putScalar(new int[] {0}, stock.getOpen());
            label.putScalar(new int[] {1}, stock.getClose());
            label.putScalar(new int[] {2}, stock.getLow());
            label.putScalar(new int[] {3}, stock.getHigh());
            label.putScalar(new int[] {4}, stock.getVolume());
          } else {
            label = Nd4j.create(new int[] {1}, 'f');
            switch (category) {
              case OPEN:
                label.putScalar(new int[] {0}, stock.getOpen());
                break;
              case CLOSE:
                label.putScalar(new int[] {0}, stock.getClose());
                break;
              case LOW:
                label.putScalar(new int[] {0}, stock.getLow());
                break;
              case HIGH:
                label.putScalar(new int[] {0}, stock.getHigh());
                break;
              case VOLUME:
                label.putScalar(new int[] {0}, stock.getVolume());
                break;
              default:
                throw new NoSuchElementException();
            }
          }
          test.add(Pair.of(input, label));
        }
        return test;
      }
    
      private List<StockData> readStockDataFromFile(String filename, String symbol) {
        List<StockData> stockDataList = new ArrayList<>();
        for (int i = 0; i < maxArray.length; i++) { // initialize max and min arrays
          maxArray[i] = -Double.MAX_VALUE; // Double.MIN_VALUE would be the smallest positive value
          minArray[i] = Double.MAX_VALUE;
        }
    
        // load all elements in a list
        JSONArray ja = getHistory(symbol, 30);
        System.out.println("ja.length() = " + ja.length());
        List<List<Object>> list = new ArrayList<>();
        for (int i = 0; i < ja.length(); i++) {
          JSONObject jo = ja.getJSONObject(i);
          list.add(
              new ArrayList<>(
                  List.of(
                      jo.getLong("date"),
                      jo.getDouble("rate"),
                      Math.round(jo.getDouble("volume") / 10000.0) / 100.0,
                      jo.getDouble("cap"))));
        }
        list.sort(Comparator.comparing(l -> (Long) l.getFirst()));
        list.forEach(
            l -> {
              int daysOfEpoch = (int) ((long) l.getFirst() / 1000L / 60L / 60L / 24L);
              l.set(0, daysOfEpoch);
            });
        System.out.println("list.size() = " + list.size());
    
        List<List<Object>> list2 = new ArrayList<>();
        // aggregate the intraday samples into one open/close/low/high/volume row per day
        for (int i = 0; i < list.size(); ) {
          List<Object> l = list.get(i);
          int daysOfEpoch = (int) l.getFirst();
          double open = (double) l.get(1);
          double close = open;
          double low = open;
          double high = open;
          double volume = (double) l.get(2);
          int j = i + 1;
          // consume all following samples that belong to the same day
          while (j < list.size() && (int) list.get(j).getFirst() == daysOfEpoch) {
            List<Object> l2 = list.get(j);
            close = (double) l2.get(1);
            low = Math.min(low, close);
            high = Math.max(high, close);
            volume += (double) l2.get(2);
            j++;
          }
          list2.add(List.of(daysOfEpoch, open, close, low, high, volume));
          i = j; // jump to the first sample of the next day (also when the list ends)
        }
        System.out.println("list2.size() = " + list2.size());
        list2.forEach(l -> System.out.println("l = " + l));
    
        for (int j = 0; j < list2.size(); j++) {
          List<Object> l = list2.get(j);
          double[] nums = {
            (double) l.get(1),
            (double) l.get(2),
            (double) l.get(3),
            (double) l.get(4),
            (double) l.get(5),
          };
          for (int i = 0; i < VECTOR_SIZE; i++) {
            if (nums[i] > maxArray[i]) maxArray[i] = nums[i];
            if (nums[i] < minArray[i]) minArray[i] = nums[i];
          }
          stockDataList.add(
              new StockData(l.get(0) + "", symbol, nums[0], nums[1], nums[2], nums[3], nums[4]));
        }
    
        return stockDataList;
      }
    }
    
    
    package utils;
    
    import javax.swing.*;
    import org.jfree.chart.ChartFactory;
    import org.jfree.chart.ChartPanel;
    import org.jfree.chart.JFreeChart;
    import org.jfree.chart.axis.NumberAxis;
    import org.jfree.chart.axis.NumberTickUnit;
    import org.jfree.chart.plot.PlotOrientation;
    import org.jfree.chart.plot.XYPlot;
    import org.jfree.data.xy.XYSeries;
    import org.jfree.data.xy.XYSeriesCollection;
    
    /**
     * Created by zhanghao on 26/7/17. Modified by zhanghao on 28/9/17.
     *
     * @author ZHANG HAO
     */
    public class PlotUtil {
    
      public static void plot(double[] predicts, double[] actuals, String name) {
        double[] index = new double[predicts.length];
        for (int i = 0; i < predicts.length; i++) index[i] = i;
        int min = minValue(predicts, actuals);
        int max = maxValue(predicts, actuals);
        final XYSeriesCollection dataSet = new XYSeriesCollection();
        addSeries(dataSet, index, predicts, "Predicts");
        addSeries(dataSet, index, actuals, "Actuals");
        final JFreeChart chart =
            ChartFactory.createXYLineChart(
                "Prediction Result", // chart title
                "Index", // x axis label
                name, // y axis label
                dataSet, // data
                PlotOrientation.VERTICAL,
                true, // include legend
                true, // tooltips
                false // urls
                );
        XYPlot xyPlot = chart.getXYPlot();
        // X-axis
        final NumberAxis domainAxis = (NumberAxis) xyPlot.getDomainAxis();
        domainAxis.setRange((int) index[0], (int) (index[index.length - 1] + 2));
        domainAxis.setTickUnit(new NumberTickUnit(20));
        domainAxis.setVerticalTickLabels(true);
        // Y-axis
        final NumberAxis rangeAxis = (NumberAxis) xyPlot.getRangeAxis();
        rangeAxis.setRange(min, max);
        rangeAxis.setTickUnit(new NumberTickUnit(50));
        final ChartPanel panel = new ChartPanel(chart);
        final JFrame f = new JFrame();
        f.add(panel);
        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
        f.pack();
        f.setVisible(true);
      }
    
      private static void addSeries(
          final XYSeriesCollection dataSet, double[] x, double[] y, final String label) {
        final XYSeries s = new XYSeries(label);
        for (int j = 0; j < x.length; j++) s.add(x[j], y[j]);
        dataSet.addSeries(s);
      }
    
      private static int minValue(double[] predicts, double[] actuals) {
        double min = Double.POSITIVE_INFINITY; // start above any possible value
        for (int i = 0; i < predicts.length; i++) {
          if (min > predicts[i]) min = predicts[i];
          if (min > actuals[i]) min = actuals[i];
        }
        return (int) (min * 0.98);
      }
    
      private static int maxValue(double[] predicts, double[] actuals) {
        double max = Double.NEGATIVE_INFINITY; // start below any possible value
        for (int i = 0; i < predicts.length; i++) {
          if (max < predicts[i]) max = predicts[i];
          if (max < actuals[i]) max = actuals[i];
        }
        return (int) (max * 1.02);
      }
    }
    
    

    I've marked the part in question for you: it's lines 461 to 468.

    Thanks


  • Banned

    Edit: I tested it with the following data (the columns should be self-explanatory):

    list2.size() = 32
    l = [19748, 91.83797335102334, 92.71648526759283, 91.83797335102334, 92.71648526759283, 4203.3099999999995]
    l = [19749, 93.36752755612166, 92.64628278182262, 92.12689371773429, 93.36752755612166, 6161.77]
    l = [19750, 94.40827278049834, 95.56986201017443, 94.40827278049834, 97.80647678983507, 7970.26]
    l = [19751, 97.37665025155518, 100.77814246091854, 97.37665025155518, 100.77814246091854, 7992.23]
    l = [19752, 102.63094180671153, 101.43243634953313, 101.43243634953313, 104.02560679484226, 13268.67]
    l = [19753, 100.50742797179204, 98.1066613081412, 98.1066613081412, 100.50742797179204, 11700.29]
    l = [19754, 93.45949987216143, 96.30517749954835, 93.45949987216143, 96.30517749954835, 12483.83]
    l = [19755, 98.63787508229797, 100.9428572386196, 98.63787508229797, 101.32282620549877, 12444.04]
    l = [19756, 99.78646788861874, 98.14028160990891, 97.96491587753393, 99.78646788861874, 6038.719999999999]
    l = [19757, 97.91699797883597, 96.90941297944825, 96.90941297944825, 97.91699797883597, 4360.49]
    l = [19758, 95.2105059295542, 94.88591585705699, 94.88591585705699, 97.40977369541514, 6934.799999999999]
    l = [19759, 95.25615067490419, 97.6627038684511, 94.57842668292786, 97.6627038684511, 5556.82]
    l = [19760, 95.81869065389131, 96.94537440993797, 94.96985657563032, 96.94537440993797, 5433.09]
    l = [19761, 101.49823001006955, 102.58697481001379, 101.25393508360612, 104.48851802118729, 10427.73]
    l = [19762, 104.83512403285263, 105.10605161573433, 104.83512403285263, 105.6318898617301, 8543.07]
    l = [19763, 109.85718169398818, 108.06725132732612, 108.06725132732612, 109.85718169398818, 9098.49]
    l = [19764, 108.67799649709828, 107.64252733687948, 107.64252733687948, 109.39178024684149, 7010.89]
    l = [19765, 105.81123324621163, 110.20579564774938, 104.42941559194774, 110.20579564774938, 6380.76]
    l = [19766, 113.87616473303284, 110.1654238619581, 110.1654238619581, 114.22238899033734, 10161.42]
    l = [19767, 111.32926743082966, 116.79505031796758, 111.32926743082966, 118.05399765058999, 11254.34]
    l = [19768, 115.91907847345627, 113.72627278052416, 113.72627278052416, 116.02751475453104, 8238.42]
    l = [19769, 113.95029261214435, 109.06913856356621, 109.06913856356621, 113.95029261214435, 7930.49]
    l = [19770, 111.15067034817159, 109.65974259654443, 106.33753287839521, 111.15067034817159, 7950.01]
    l = [19771, 109.56562807356273, 112.91557017125702, 109.56562807356273, 112.91557017125702, 5157.08]
    l = [19772, 112.0306760091593, 113.09392950614297, 112.0306760091593, 113.09392950614297, 5933.93]
    l = [19773, 112.02534274914494, 108.24036074103542, 108.24036074103542, 112.02534274914494, 9767.079999999998]
    l = [19774, 106.3463777107441, 101.78040058120601, 101.78040058120601, 106.3463777107441, 8264.92]
    l = [19775, 102.87331537357632, 101.77337611631232, 101.77337611631232, 106.04341760593877, 9739.18]
    l = [19776, 103.3423316637421, 101.5471202445096, 101.5471202445096, 103.3423316637421, 5677.18]
    l = [19777, 100.8444063780233, 103.89123602922389, 100.8444063780233, 103.89123602922389, 4982.28]
    l = [19778, 103.58599623283995, 102.5102713232609, 102.5102713232609, 103.58599623283995, 2408.31]
    l = [19778, 102.5102713232609, 102.5102713232609, 102.5102713232609, 102.5102713232609, 1178.1]
    
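    To put the 32 rows into perspective: a common way to feed such a series to an LSTM is to cut it into overlapping windows of fixed length and use the value right after each window as the target. Below is a minimal plain-Java sketch of that idea, assuming a single price column and a window length of 10. `SlidingWindow` and `makeWindows` are hypothetical names; whether the StockDataSetIterator used above slices its data exactly this way is not shown in this thread.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper that turns one price series into (window, next value) training pairs. */
public class SlidingWindow {

  /** One training example: `length` consecutive values and the value that follows them. */
  public static final class Example {
    public final double[] inputs;
    public final double target;

    Example(double[] inputs, double target) {
      this.inputs = inputs;
      this.target = target;
    }
  }

  /** Builds all overlapping windows of the given length; yields series.length - length examples. */
  public static List<Example> makeWindows(double[] series, int length) {
    if (length <= 0 || length >= series.length) {
      throw new IllegalArgumentException("window length must be in [1, series.length - 1]");
    }
    List<Example> examples = new ArrayList<>();
    for (int start = 0; start + length < series.length; start++) {
      double[] inputs = new double[length];
      System.arraycopy(series, start, inputs, 0, length);
      examples.add(new Example(inputs, series[start + length])); // target = next value
    }
    return examples;
  }

  public static void main(String[] args) {
    double[] series = new double[32]; // 32 rows, as in list2 above (dummy values here)
    for (int i = 0; i < series.length; i++) series[i] = 100 + i;
    System.out.println(makeWindows(series, 10).size()); // 22
  }
}
```

    With 32 rows and a window length of 10, only 22 overlapping examples remain, and holding some of them out for testing leaves even fewer for training. That is very little data for an LSTM.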


  • @oxide sagte in Stock price forecasting mit LSTM in Java:

    I'm actually looking for something like that too. I want to calculate the lottery numbers based on the ones drawn so far.

    Then I would recommend Python with the following modules:

    • Scikit-learn
    • Matplotlib
    • NumPy
    • SciPy
    • Pandas

    These modules take a lot of programming work off your hands, and when you first enter the world of AI you'll have plenty to do with the new terminology (e.g. high bias, high variance). If you want to dig deeper into the subject, I'd also recommend a book.
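    To make "high bias" and "high variance" a bit more concrete, here is a small plain-Java sketch (using none of the libraries listed above) that compares two extreme models on noisy data: always predicting the training mean (too simple, high bias) and a 1-nearest-neighbour lookup that memorizes the training set (too flexible, high variance). The data generator (sin(x) plus Gaussian noise), the seed and all names are illustrative assumptions.

```java
import java.util.Arrays;
import java.util.Random;

public class BiasVarianceDemo {

  /** Mean squared error between predictions and targets. */
  public static double mse(double[] pred, double[] y) {
    double sum = 0;
    for (int i = 0; i < y.length; i++) sum += (pred[i] - y[i]) * (pred[i] - y[i]);
    return sum / y.length;
  }

  /** High-bias model: ignores x and always predicts the mean of the training targets. */
  public static double[] predictMean(double[] trainY, int n) {
    double mean = 0;
    for (double v : trainY) mean += v;
    mean /= trainY.length;
    double[] pred = new double[n];
    Arrays.fill(pred, mean);
    return pred;
  }

  /** High-variance model: returns the target of the closest training point (1-NN). */
  public static double[] predictNearest(double[] trainX, double[] trainY, double[] queryX) {
    double[] pred = new double[queryX.length];
    for (int q = 0; q < queryX.length; q++) {
      int best = 0;
      for (int i = 1; i < trainX.length; i++) {
        if (Math.abs(trainX[i] - queryX[q]) < Math.abs(trainX[best] - queryX[q])) best = i;
      }
      pred[q] = trainY[best];
    }
    return pred;
  }

  /** Samples y = sin(x) + noise for x uniform in [0, 2*pi); returns {x, y}. */
  public static double[][] sample(Random rnd, int n) {
    double[] x = new double[n];
    double[] y = new double[n];
    for (int i = 0; i < n; i++) {
      x[i] = rnd.nextDouble() * 2 * Math.PI;
      y[i] = Math.sin(x[i]) + 0.3 * rnd.nextGaussian();
    }
    return new double[][] {x, y};
  }

  public static void main(String[] args) {
    Random rnd = new Random(42);
    double[][] train = sample(rnd, 50);
    double[][] test = sample(rnd, 50);
    System.out.printf("mean model: train MSE %.3f, test MSE %.3f%n",
        mse(predictMean(train[1], 50), train[1]), mse(predictMean(train[1], 50), test[1]));
    System.out.printf("1-NN model: train MSE %.3f, test MSE %.3f%n",
        mse(predictNearest(train[0], train[1], train[0]), train[1]),
        mse(predictNearest(train[0], train[1], test[0]), test[1]));
  }
}
```

    The mean model has a similarly high error on training and test data (it underfits). The 1-NN model has zero training error but a larger test error (it overfits). The gap between training and test error is the thing to watch.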


    But I don't know whether predicting lottery numbers is a good example, since I suspect this could be a chaotic system: small changes in the initial conditions can lead to completely different results, and predictability can break down because of this behavior. A similar problem is dynamical billiards.
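    Sensitivity to initial conditions can be illustrated with the logistic map x_{n+1} = r·x_n·(1 − x_n), a textbook example that is chaotic for r = 4 (it is only an illustration of chaos, not a model of lottery draws). The sketch starts two trajectories that differ by 1e-10 and prints how far apart they are after a given number of steps; the starting value 0.2 is an arbitrary choice.

```java
public class LogisticMapDemo {

  /** Iterates x -> r * x * (1 - x) the given number of times. */
  public static double iterate(double x, double r, int steps) {
    for (int i = 0; i < steps; i++) x = r * x * (1 - x);
    return x;
  }

  public static void main(String[] args) {
    double r = 4.0;
    double a = 0.2;
    double b = 0.2 + 1e-10; // almost the same starting value
    for (int steps : new int[] {10, 20, 30, 40, 50}) {
      double diff = Math.abs(iterate(a, r, steps) - iterate(b, r, steps));
      System.out.printf("after %2d steps: |difference| = %.3e%n", steps, diff);
    }
  }
}
```

    For r = 4 the difference roughly doubles per step on average, so a deviation of 1e-10 grows to the order of the values themselves after a few dozen steps; from then on the two trajectories are effectively unrelated.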

    https://www.youtube.com/watch?v=svV1MsUdInE


  • Banned

    @Quiche-Lorraine said in Stock price forecasting mit LSTM in Java:

    But I don't know whether predicting lottery numbers is a good example,

    Let's be honest and get to the point: that was just a trolling attempt by @oxide to discredit me...

    In serious forums this would simply be moderated away, but here the mods are unwilling to step in.

    But thanks for the suggestion to do this in Python. 🙂


  • Banned

    I'm slowly starting to >believe< that the MultiLayerNetwork above is wrong, and therefore the library linked above is complete junk too, i.e. it's not just the StockDataSetIterator that's "acting up"... Or I just don't know enough about the subject. But 9.5 out of 10 blogs you find on the internet about this topic are junk as well...

    But whatever... Bitcoin is on its way to a new all-time high, and that's enough of a forecast for me. 🙂


  • Banned

    And there's the mess already:

    https://www.cnbc.com/2024/03/04/crypto-market-today.html

    The new all-time high is coming/came faster than I expected! For you too? 😅


  • Banned

    @omggg said in Stock price forecasting mit LSTM in Java:

    due to overfitness

    I meant overfitting...

    Unfortunately I have to correct this here, because I fear the trolls would show up again otherwise.

    @Quiche-Lorraine

    Would you still be up for answering my question(s) a bit more precisely? Or is there no way around Python anymore?



  • @omggg said in Stock price forecasting mit LSTM in Java:

    Would you still be up for answering my question(s) a bit more precisely? Or is there no way around Python anymore?

    A bit of background: for some time now I've been browsing, in my spare time, through the book "Machine Learning mit Python und Scikit-Learn und TensorFlow". This book showed me a somewhat different perspective on machine learning. Before, I thought "lots of data in, train, and that's it". Now, however, I pay attention to other things such as correlations, overfitting, high bias, etc.
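    As an example of checking correlations before training: the Pearson correlation coefficient between two feature columns (for instance two of the price columns in the data posted above, which are typically strongly correlated and therefore partly redundant) takes only a few lines. A plain-Java sketch; the class name and sample values are illustrative.

```java
public class Correlation {

  /** Pearson correlation coefficient; returns NaN if either series is constant. */
  public static double pearson(double[] x, double[] y) {
    if (x.length != y.length || x.length < 2) {
      throw new IllegalArgumentException("need two series of equal length >= 2");
    }
    double meanX = 0, meanY = 0;
    for (int i = 0; i < x.length; i++) {
      meanX += x[i];
      meanY += y[i];
    }
    meanX /= x.length;
    meanY /= y.length;
    double cov = 0, varX = 0, varY = 0;
    for (int i = 0; i < x.length; i++) {
      double dx = x[i] - meanX, dy = y[i] - meanY;
      cov += dx * dy;
      varX += dx * dx;
      varY += dy * dy;
    }
    return cov / Math.sqrt(varX * varY); // the 1/n factors cancel out
  }

  public static void main(String[] args) {
    double[] a = {1, 2, 3, 4, 5};
    System.out.println(pearson(a, new double[] {2, 4, 6, 8, 10})); // 1.0
    System.out.println(pearson(a, new double[] {5, 4, 3, 2, 1})); // -1.0
  }
}
```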

    What I'm getting at is this: I only know AI in connection with Python. There are, however, some modules or libraries I wouldn't want to do without.

    And that starts with scikit-learn. This module is large and includes various classifiers, preprocessing tools (e.g. sklearn.preprocessing.StandardScaler), batching, ... It simply has everything, which you can't generally say about other GitHub projects.
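    For anyone staying in Java: what `sklearn.preprocessing.StandardScaler` does per feature is subtract the mean and divide by the population standard deviation (dividing by n, not n − 1), with both statistics fitted on the training data only and then reused for the test data; a feature with zero variance gets a scale factor of 1. A plain-Java sketch of that idea for a single column; `ZScoreScaler` is a hypothetical class, not part of DL4J/ND4J or scikit-learn.

```java
import java.util.Arrays;

/** Hypothetical z-score scaler for one feature column, similar in spirit to StandardScaler. */
public class ZScoreScaler {
  private double mean;
  private double std;

  /** Learns mean and population standard deviation from the training data. */
  public ZScoreScaler fit(double[] train) {
    double sum = 0;
    for (double v : train) sum += v;
    mean = sum / train.length;
    double sq = 0;
    for (double v : train) sq += (v - mean) * (v - mean);
    std = Math.sqrt(sq / train.length); // divide by n (ddof = 0), not n - 1
    if (std == 0) std = 1; // constant feature: only centred, not scaled
    return this;
  }

  /** Applies (x - mean) / std using the statistics learned in fit(). */
  public double[] transform(double[] data) {
    double[] out = new double[data.length];
    for (int i = 0; i < data.length; i++) out[i] = (data[i] - mean) / std;
    return out;
  }

  public static void main(String[] args) {
    double[] train = {2, 4, 4, 4, 5, 5, 7, 9}; // mean 5, population std 2
    ZScoreScaler scaler = new ZScoreScaler().fit(train);
    System.out.println(Arrays.toString(scaler.transform(new double[] {5, 7, 1}))); // [0.0, 1.0, -2.0]
  }
}
```

    Fitting on the training split only matters for time series in particular: computing mean and standard deviation over the whole series would leak information about future prices into training.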

    So for machine learning I would make sure to use either scikit-learn or TensorFlow. I don't know any other libraries of that size. And I would also recommend the book mentioned above.

    Python isn't a must, but it's very practical. One call to "pip install scikit-learn" and within a short time scikit-learn is installed. That's one of the nice things about Python.

