Stock price forecasting with LSTM in Java
-
It's just a bit complicated. The accompanying example uses 22 time frames, because that is how many working days a month has. On top of that, the thing needs open, close, high, low and volume.
At the moment, though, I only have rate, vol and cap. Cap is dropped (that is, market cap is not a feature).
I ran it on the CPU once earlier... But I can't manage more than 50 time frames and more than 10 epochs, because otherwise the training aborts. I think maybe because of overfitness or because I don't have enough RAM.
I'll try it again later, I'm on the road right now... but as I said, I don't want to predict the trend all the way to Christmas; 5 days would already be enough.
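A minimal sketch of how a one-day-ahead predictor could be stretched to 5 days: predict one step, append the prediction to the input window, and repeat. The class name `RecursiveForecast`, the method `forecast` and the stand-in predictor (mean of the window) are assumptions for illustration only; a trained network would take the predictor's place.

```java
import java.util.Arrays;
import java.util.function.ToDoubleFunction;

/** Sketch: multi-day forecast by feeding one-step predictions back into the window. */
final class RecursiveForecast {

  /**
   * @param history observed values, oldest first (at least windowSize entries)
   * @param windowSize number of past values the predictor looks at
   * @param steps how many days ahead to forecast
   * @param oneStep hypothetical one-step predictor: window -> next value
   */
  static double[] forecast(
      double[] history, int windowSize, int steps, ToDoubleFunction<double[]> oneStep) {
    if (windowSize < 1 || history.length < windowSize) {
      throw new IllegalArgumentException("history must contain at least windowSize values");
    }
    // Start with the most recent windowSize observations.
    double[] window =
        Arrays.copyOfRange(history, history.length - windowSize, history.length);
    double[] out = new double[steps];
    for (int s = 0; s < steps; s++) {
      double next = oneStep.applyAsDouble(window);
      out[s] = next;
      // Slide the window: drop the oldest value, append the prediction.
      System.arraycopy(window, 1, window, 0, windowSize - 1);
      window[windowSize - 1] = next;
    }
    return out;
  }

  public static void main(String[] args) {
    // Stand-in predictor (assumption): the mean of the window.
    ToDoubleFunction<double[]> meanOfWindow =
        w -> Arrays.stream(w).average().orElse(Double.NaN);
    double[] closes = {100.0, 102.0, 101.0, 103.0, 104.0}; // illustrative values
    System.out.println(Arrays.toString(forecast(closes, 3, 5, meanOfWindow)));
  }
}
```

Each step builds on earlier predictions, so errors can compound over the 5 days. With a five-feature input and a single predicted value, the next input row cannot be built from the prediction alone, so all features would have to be predicted for this to work.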
-
If you show us your source code, we can analyze this better.
I'm looking for something like this too. I want to calculate the lottery numbers based on the ones drawn so far.
-
Get lost.
-
Well, I thought that one was good...
-
You could also ask directly in the Java forum.
Oh, wait, no. Somehow you always end up taking off from there.
Oh well. There are probably just too few theoretical algorithm people in this forum.
-
This post has been deleted!
-
What is this supposed to be? Ciao.
-
Could you take a look at what is wrong/illogical? It only trains for about 15 epochs instead of 100. As a result, the predictions are of course rubbish...
I took over most of the code and adapted it, but did not refactor it.
Gradle:
dependencies {
    // https://mvnrepository.com/artifact/org.json/json
    implementation 'org.json:json:20240205'
    // https://mvnrepository.com/artifact/com.konghq/unirest-java
    implementation 'com.konghq:unirest-java:3.14.5'

    // Stock predictions:
    // https://mvnrepository.com/artifact/org.nd4j/nd4j-api
    implementation 'org.nd4j:nd4j-api:0.9.1'
    // https://mvnrepository.com/artifact/org.nd4j/nd4j-native
    implementation 'org.nd4j:nd4j-native:0.9.1'
    // https://mvnrepository.com/artifact/org.nd4j/nd4j-native-platform
    implementation 'org.nd4j:nd4j-native-platform:0.9.1'
    // https://mvnrepository.com/artifact/org.deeplearning4j/deeplearning4j-core
    implementation 'org.deeplearning4j:deeplearning4j-core:0.9.1'
    // https://mvnrepository.com/artifact/org.datavec/datavec-api
    implementation 'org.datavec:datavec-api:0.9.1'
    // https://mvnrepository.com/artifact/org.datavec/datavec-dataframe
    implementation 'org.datavec:datavec-dataframe:0.9.1'
    // https://mvnrepository.com/artifact/org.slf4j/slf4j-simple
    implementation 'org.slf4j:slf4j-simple:2.0.12'
    // https://mvnrepository.com/artifact/org.jfree/jfreechart
    implementation 'org.jfree:jfreechart:1.5.4'
}
Classes (in structural order):
package history;

import kong.unirest.Unirest;
import org.json.JSONArray;
import org.json.JSONObject;

public class History {
  public static JSONArray getHistory(String code, long days) {
    long mod = 1000 * 60;
    long time = System.currentTimeMillis() / mod * mod;
    long start = time - (days * 24 * 60 * 60 * 1000);
    JSONObject jo = new JSONObject();
    jo.put("currency", "USD");
    jo.put("code", code);
    jo.put("meta", false);
    jo.put("start", start);
    jo.put("end", time);
    String body =
        Unirest.post("https://api.livecoinwatch.com/coins/single/history")
            .header("content-type", "application/json")
            .header("x-api-key", "insert your key here")
            .body(jo.toString())
            .asString()
            .getBody();
    return new JSONObject(body).getJSONArray("history");
  }
}

package model;

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * Created by zhanghao on 26/7/17.
 *
 * @author ZHANG HAO
 */
public class RecurrentNets {

  private static final double learningRate = 0.05;
  private static final int iterations = 1;
  private static final int seed = 12345;

  private static final int lstmLayer1Size = 256;
  private static final int lstmLayer2Size = 256;
  private static final int denseLayerSize = 32;
  private static final double dropoutRatio = 0.2;
  private static final int truncatedBPTTLength = 22;

  public static MultiLayerNetwork buildLstmNetworks(int nIn, int nOut) {
    MultiLayerConfiguration conf =
        new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .learningRate(learningRate)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .weightInit(WeightInit.XAVIER)
            .updater(Updater.RMSPROP)
            .regularization(true)
            .l2(1e-4)
            .list()
            .layer(
                0,
                new GravesLSTM.Builder()
                    .nIn(nIn)
                    .nOut(lstmLayer1Size)
                    .activation(Activation.TANH)
                    .gateActivationFunction(Activation.HARDSIGMOID)
                    .dropOut(dropoutRatio)
                    .build())
            .layer(
                1,
                new GravesLSTM.Builder()
                    .nIn(lstmLayer1Size)
                    .nOut(lstmLayer2Size)
                    .activation(Activation.TANH)
                    .gateActivationFunction(Activation.HARDSIGMOID)
                    .dropOut(dropoutRatio)
                    .build())
            .layer(
                2,
                new DenseLayer.Builder()
                    .nIn(lstmLayer2Size)
                    .nOut(denseLayerSize)
                    .activation(Activation.RELU)
                    .build())
            .layer(
                3,
                new RnnOutputLayer.Builder()
                    .nIn(denseLayerSize)
                    .nOut(nOut)
                    .activation(Activation.IDENTITY)
                    .lossFunction(LossFunctions.LossFunction.MSE)
                    .build())
            .backpropType(BackpropType.TruncatedBPTT)
            .tBPTTForwardLength(truncatedBPTTLength)
            .tBPTTBackwardLength(truncatedBPTTLength)
            .pretrain(false)
            .backprop(true)
            .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.setListeners(new ScoreIterationListener(100));
    return net;
  }
}

package predict;

import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.NoSuchElementException;
import model.RecurrentNets;
import org.apache.commons.lang3.tuple.Pair;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import representation.PriceCategory;
import representation.StockDataSetIterator;
import utils.PlotUtil;

/**
 * Created by zhanghao on 26/7/17. Modified by zhanghao on 28/9/17.
 *
 * @author ZHANG HAO
 */
public class StockPricePrediction {

  private static final Logger log = LoggerFactory.getLogger(StockPricePrediction.class);

  private static final int exampleLength = 15; // time series length

  public static void main(String[] args) throws IOException {
    String file = "dummy-prices-split-adjusted.csv";
    String symbol = "SOL"; // stock name
    int batchSize = 64; // mini-batch size
    double splitRatio = 0.8; // 80% for training, 20% for testing
    int epochs = 100; // training epochs

    log.info("Create dataSet iterator...");
    PriceCategory category = PriceCategory.CLOSE; // CLOSE: predict close price
    StockDataSetIterator iterator =
        new StockDataSetIterator(file, symbol, batchSize, exampleLength, splitRatio, category);
    log.info("Load test dataset...");
    List<Pair<INDArray, INDArray>> test = iterator.getTestDataSet();

    log.info("Build lstm networks...");
    MultiLayerNetwork net =
        RecurrentNets.buildLstmNetworks(iterator.inputColumns(), iterator.totalOutcomes());

    log.info("Training...");
    for (int i = 0; i < epochs; i++) {
      while (iterator.hasNext()) net.fit(iterator.next()); // fit model using mini-batch data
      iterator.reset(); // reset iterator
      net.rnnClearPreviousState(); // clear previous state
    }

    log.info("Saving model...");
    File locationToSave =
        new File("StockPriceLSTM_".concat(String.valueOf(category)).concat(".zip"));
    // saveUpdater: i.e., the state for Momentum, RMSProp, Adagrad etc. Save this to train your
    // network more in the future
    ModelSerializer.writeModel(net, locationToSave, true);

    log.info("Load model...");
    net = ModelSerializer.restoreMultiLayerNetwork(locationToSave);

    log.info("Testing...");
    if (category.equals(PriceCategory.ALL)) {
      INDArray max = Nd4j.create(iterator.getMaxArray());
      INDArray min = Nd4j.create(iterator.getMinArray());
      predictAllCategories(net, test, max, min);
    } else {
      double max = iterator.getMaxNum(category);
      double min = iterator.getMinNum(category);
      predictPriceOneAhead(net, test, max, min, category);
    }
    log.info("Done...");
  }

  /** Predict one feature of a stock one-day ahead */
  private static void predictPriceOneAhead(
      MultiLayerNetwork net,
      List<Pair<INDArray, INDArray>> testData,
      double max,
      double min,
      PriceCategory category) {
    double[] predicts = new double[testData.size()];
    double[] actuals = new double[testData.size()];
    for (int i = 0; i < testData.size(); i++) {
      predicts[i] =
          net.rnnTimeStep(testData.get(i).getKey()).getDouble(exampleLength - 1) * (max - min)
              + min;
      actuals[i] = testData.get(i).getValue().getDouble(0);
    }
    log.info("Print out Predictions and Actual Values...");
    log.info("Predict,Actual");
    for (int i = 0; i < predicts.length; i++) log.info(predicts[i] + "," + actuals[i]);
    log.info("Plot...");
    PlotUtil.plot(predicts, actuals, String.valueOf(category));
  }

  private static void predictPriceMultiple(
      MultiLayerNetwork net, List<Pair<INDArray, INDArray>> testData, double max, double min) {
    // TODO
  }

  /**
   * Predict all the features (open, close, low, high prices and volume) of a stock one-day ahead
   */
  private static void predictAllCategories(
      MultiLayerNetwork net, List<Pair<INDArray, INDArray>> testData, INDArray max, INDArray min) {
    INDArray[] predicts = new INDArray[testData.size()];
    INDArray[] actuals = new INDArray[testData.size()];
    for (int i = 0; i < testData.size(); i++) {
      predicts[i] =
          net.rnnTimeStep(testData.get(i).getKey())
              .getRow(exampleLength - 1)
              .mul(max.sub(min))
              .add(min);
      actuals[i] = testData.get(i).getValue();
    }
    log.info("Print out Predictions and Actual Values...");
    log.info("Predict\tActual");
    for (int i = 0; i < predicts.length; i++) log.info(predicts[i] + "\t" + actuals[i]);
    log.info("Plot...");
    for (int n = 0; n < 5; n++) {
      double[] pred = new double[predicts.length];
      double[] actu = new double[actuals.length];
      for (int i = 0; i < predicts.length; i++) {
        pred[i] = predicts[i].getDouble(n);
        actu[i] = actuals[i].getDouble(n);
      }
      String name;
      switch (n) {
        case 0:
          name = "Stock OPEN Price";
          break;
        case 1:
          name = "Stock CLOSE Price";
          break;
        case 2:
          name = "Stock LOW Price";
          break;
        case 3:
          name = "Stock HIGH Price";
          break;
        case 4:
          name = "Stock VOLUME Amount";
          break;
        default:
          throw new NoSuchElementException();
      }
      PlotUtil.plot(pred, actu, name);
    }
  }
}

package representation;

/**
 * Created by zhanghao on 28/9/17.
 *
 * @author ZHANG HAO
 */
public enum PriceCategory {
  OPEN,
  CLOSE,
  LOW,
  HIGH,
  VOLUME,
  ALL
}

package representation;

/**
 * Created by zhanghao on 26/7/17.
 *
 * @author ZHANG HAO
 */
public class StockData {
  private String date; // date
  private String symbol; // stock name

  private double open; // open price
  private double close; // close price
  private double low; // low price
  private double high; // high price
  private double volume; // volume

  public StockData() {}

  public StockData(
      String date,
      String symbol,
      double open,
      double close,
      double low,
      double high,
      double volume) {
    this.date = date;
    this.symbol = symbol;
    this.open = open;
    this.close = close;
    this.low = low;
    this.high = high;
    this.volume = volume;
  }

  public String getDate() { return date; }

  public void setDate(String date) { this.date = date; }

  public String getSymbol() { return symbol; }

  public void setSymbol(String symbol) { this.symbol = symbol; }

  public double getOpen() { return open; }

  public void setOpen(double open) { this.open = open; }

  public double getClose() { return close; }

  public void setClose(double close) { this.close = close; }

  public double getLow() { return low; }

  public void setLow(double low) { this.low = low; }

  public double getHigh() { return high; }

  public void setHigh(double high) { this.high = high; }

  public double getVolume() { return volume; }

  public void setVolume(double volume) { this.volume = volume; }
}

package representation;

import static history.History.getHistory;

import com.google.common.collect.ImmutableMap;
import java.util.*;
import org.apache.commons.lang3.tuple.Pair;
import org.json.JSONArray;
import org.json.JSONObject;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Created by zhanghao on 26/7/17. Modified by zhanghao on 28/9/17.
 *
 * @author ZHANG HAO
 */
public class StockDataSetIterator implements DataSetIterator {

  /** category and its index */
  private final Map<PriceCategory, Integer> featureMapIndex =
      ImmutableMap.of(
          PriceCategory.OPEN, 0,
          PriceCategory.CLOSE, 1,
          PriceCategory.LOW, 2,
          PriceCategory.HIGH, 3,
          PriceCategory.VOLUME, 4);

  private final int VECTOR_SIZE = 5; // number of features for a stock data
  private int miniBatchSize; // mini-batch size
  private int exampleLength; // default 22, say, 22 working days per month
  private int predictLength = 1; // default 1, say, one day ahead prediction

  /** minimal values of each feature in stock dataset */
  private double[] minArray = new double[VECTOR_SIZE];

  /** maximal values of each feature in stock dataset */
  private double[] maxArray = new double[VECTOR_SIZE];

  /** feature to be selected as a training target */
  private PriceCategory category;

  /** mini-batch offset */
  private LinkedList<Integer> exampleStartOffsets = new LinkedList<>();

  /** stock dataset for training */
  private List<StockData> train;

  /** adjusted stock dataset for testing */
  private List<Pair<INDArray, INDArray>> test;

  public StockDataSetIterator(
      String filename,
      String symbol,
      int miniBatchSize,
      int exampleLength,
      double splitRatio,
      PriceCategory category) {
    this.miniBatchSize = miniBatchSize;
    this.exampleLength = exampleLength;
    this.category = category;
    List<StockData> stockDataList = readStockDataFromFile(filename, symbol);

    // something is probably going wrong here, start
    int split = (int) Math.round(exampleLength * splitRatio);
    train = stockDataList.subList(stockDataList.size() - split, stockDataList.size());
    System.out.println("train.size() = " + train.size());
    test = generateTestDataSet(stockDataList);
    System.out.println("test.size() = " + test.size());
    initializeOffsets();
    // something is probably going wrong here, end
  }

  /** initialize the mini-batch offsets */
  private void initializeOffsets() {
    exampleStartOffsets.clear();
    int window = exampleLength + predictLength;
    for (int i = 0; i < train.size() - window; i++) {
      exampleStartOffsets.add(i);
    }
  }

  public List<Pair<INDArray, INDArray>> getTestDataSet() {
    return test;
  }

  public double[] getMaxArray() {
    return maxArray;
  }

  public double[] getMinArray() {
    return minArray;
  }

  public double getMaxNum(PriceCategory category) {
    return maxArray[featureMapIndex.get(category)];
  }

  public double getMinNum(PriceCategory category) {
    return minArray[featureMapIndex.get(category)];
  }

  @Override
  public DataSet next(int num) {
    if (exampleStartOffsets.size() == 0) throw new NoSuchElementException();
    int actualMiniBatchSize = Math.min(num, exampleStartOffsets.size());
    INDArray input = Nd4j.create(new int[] {actualMiniBatchSize, VECTOR_SIZE, exampleLength}, 'f');
    INDArray label;
    if (category.equals(PriceCategory.ALL))
      label = Nd4j.create(new int[] {actualMiniBatchSize, VECTOR_SIZE, exampleLength}, 'f');
    else label = Nd4j.create(new int[] {actualMiniBatchSize, predictLength, exampleLength}, 'f');
    for (int index = 0; index < actualMiniBatchSize; index++) {
      int startIdx = exampleStartOffsets.removeFirst();
      int endIdx = startIdx + exampleLength;
      StockData curData = train.get(startIdx);
      StockData nextData;
      for (int i = startIdx; i < endIdx; i++) {
        int c = i - startIdx;
        input.putScalar(
            new int[] {index, 0, c},
            (curData.getOpen() - minArray[0]) / (maxArray[0] - minArray[0]));
        input.putScalar(
            new int[] {index, 1, c},
            (curData.getClose() - minArray[1]) / (maxArray[1] - minArray[1]));
        input.putScalar(
            new int[] {index, 2, c},
            (curData.getLow() - minArray[2]) / (maxArray[2] - minArray[2]));
        input.putScalar(
            new int[] {index, 3, c},
            (curData.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]));
        input.putScalar(
            new int[] {index, 4, c},
            (curData.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]));
        nextData = train.get(i + 1);
        if (category.equals(PriceCategory.ALL)) {
          label.putScalar(
              new int[] {index, 0, c},
              (nextData.getOpen() - minArray[1]) / (maxArray[1] - minArray[1]));
          label.putScalar(
              new int[] {index, 1, c},
              (nextData.getClose() - minArray[1]) / (maxArray[1] - minArray[1]));
          label.putScalar(
              new int[] {index, 2, c},
              (nextData.getLow() - minArray[2]) / (maxArray[2] - minArray[2]));
          label.putScalar(
              new int[] {index, 3, c},
              (nextData.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]));
          label.putScalar(
              new int[] {index, 4, c},
              (nextData.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]));
        } else {
          label.putScalar(new int[] {index, 0, c}, feedLabel(nextData));
        }
        curData = nextData;
      }
      if (exampleStartOffsets.size() == 0) break;
    }
    return new DataSet(input, label);
  }

  private double feedLabel(StockData data) {
    double value;
    switch (category) {
      case OPEN:
        value = (data.getOpen() - minArray[0]) / (maxArray[0] - minArray[0]);
        break;
      case CLOSE:
        value = (data.getClose() - minArray[1]) / (maxArray[1] - minArray[1]);
        break;
      case LOW:
        value = (data.getLow() - minArray[2]) / (maxArray[2] - minArray[2]);
        break;
      case HIGH:
        value = (data.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]);
        break;
      case VOLUME:
        value = (data.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]);
        break;
      default:
        throw new NoSuchElementException();
    }
    return value;
  }

  @Override
  public int totalExamples() {
    return train.size() - exampleLength - predictLength;
  }

  @Override
  public int inputColumns() {
    return VECTOR_SIZE;
  }

  @Override
  public int totalOutcomes() {
    if (this.category.equals(PriceCategory.ALL)) return VECTOR_SIZE;
    else return predictLength;
  }

  @Override
  public boolean resetSupported() {
    return false;
  }

  @Override
  public boolean asyncSupported() {
    return false;
  }

  @Override
  public void reset() {
    initializeOffsets();
  }

  @Override
  public int batch() {
    return miniBatchSize;
  }

  @Override
  public int cursor() {
    return totalExamples() - exampleStartOffsets.size();
  }

  @Override
  public int numExamples() {
    return totalExamples();
  }

  @Override
  public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
    throw new UnsupportedOperationException("Not Implemented");
  }

  @Override
  public DataSetPreProcessor getPreProcessor() {
    throw new UnsupportedOperationException("Not Implemented");
  }

  @Override
  public List<String> getLabels() {
    throw new UnsupportedOperationException("Not Implemented");
  }

  @Override
  public boolean hasNext() {
    return exampleStartOffsets.size() > 0;
  }

  @Override
  public DataSet next() {
    return next(miniBatchSize);
  }

  private List<Pair<INDArray, INDArray>> generateTestDataSet(List<StockData> stockDataList) {
    int window = exampleLength + predictLength;
    List<Pair<INDArray, INDArray>> test = new ArrayList<>();
    for (int i = 0; i < stockDataList.size() - window; i++) {
      INDArray input = Nd4j.create(new int[] {exampleLength, VECTOR_SIZE}, 'f');
      for (int j = i; j < i + exampleLength; j++) {
        StockData stock = stockDataList.get(j);
        input.putScalar(
            new int[] {j - i, 0}, (stock.getOpen() - minArray[0]) / (maxArray[0] - minArray[0]));
        input.putScalar(
            new int[] {j - i, 1}, (stock.getClose() - minArray[1]) / (maxArray[1] - minArray[1]));
        input.putScalar(
            new int[] {j - i, 2}, (stock.getLow() - minArray[2]) / (maxArray[2] - minArray[2]));
        input.putScalar(
            new int[] {j - i, 3}, (stock.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]));
        input.putScalar(
            new int[] {j - i, 4}, (stock.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]));
      }
      StockData stock = stockDataList.get(i + exampleLength);
      INDArray label;
      if (category.equals(PriceCategory.ALL)) {
        label = Nd4j.create(new int[] {VECTOR_SIZE}, 'f'); // ordering is set as 'f', faster construct
        label.putScalar(new int[] {0}, stock.getOpen());
        label.putScalar(new int[] {1}, stock.getClose());
        label.putScalar(new int[] {2}, stock.getLow());
        label.putScalar(new int[] {3}, stock.getHigh());
        label.putScalar(new int[] {4}, stock.getVolume());
      } else {
        label = Nd4j.create(new int[] {1}, 'f');
        switch (category) {
          case OPEN:
            label.putScalar(new int[] {0}, stock.getOpen());
            break;
          case CLOSE:
            label.putScalar(new int[] {0}, stock.getClose());
            break;
          case LOW:
            label.putScalar(new int[] {0}, stock.getLow());
            break;
          case HIGH:
            label.putScalar(new int[] {0}, stock.getHigh());
            break;
          case VOLUME:
            label.putScalar(new int[] {0}, stock.getVolume());
            break;
          default:
            throw new NoSuchElementException();
        }
      }
      test.add(Pair.of(input, label));
    }
    return test;
  }

  private List<StockData> readStockDataFromFile(String filename, String symbol) {
    List<StockData> stockDataList = new ArrayList<>();
    for (int i = 0; i < maxArray.length; i++) { // initialize max and min arrays
      maxArray[i] = Double.MIN_VALUE;
      minArray[i] = Double.MAX_VALUE;
    }
    // load all elements in a list
    JSONArray ja = getHistory(symbol, 30);
    System.out.println("ja.length() = " + ja.length());
    List<List<Object>> list = new ArrayList<>();
    for (int i = 0; i < ja.length(); i++) {
      JSONObject jo = ja.getJSONObject(i);
      list.add(
          new ArrayList<>(
              List.of(
                  jo.getLong("date"),
                  jo.getDouble("rate"),
                  Math.round(jo.getDouble("volume") / 10000.0) / 100.0,
                  jo.getDouble("cap"))));
    }
    list.sort(Comparator.comparing(l -> (Long) l.getFirst()));
    list.forEach(
        l -> {
          int daysOfEpoch = (int) ((long) l.getFirst() / 1000L / 60L / 60L / 24L);
          l.set(0, daysOfEpoch);
        });
    System.out.println("list.size() = " + list.size());
    List<List<Object>> list2 = new ArrayList<>();
    for (int i = 0; i < list.size(); i++) {
      List<Object> l = list.get(i);
      int daysOfEpoch = (int) l.getFirst();
      double open = (double) l.get(1);
      double close = (double) l.get(1);
      double low = (double) l.get(1);
      double high = (double) l.get(1);
      double volume = (double) l.get(2);
      for (int j = i + 1; j < list.size(); j++) {
        List<Object> l2 = list.get(j);
        int daysOfEpoch2 = (int) l2.getFirst();
        if (daysOfEpoch == daysOfEpoch2) {
          close = (double) l2.get(1);
          low = Math.min(low, (double) l2.get(1));
          high = Math.max(high, (double) l2.get(1));
          volume += (double) l2.get(2);
        } else {
          i = j - 1;
          break;
        }
      }
      list2.add(List.of(daysOfEpoch, open, close, low, high, volume));
    }
    System.out.println("list2.size() = " + list2.size());
    list2.forEach(l -> System.out.println("l = " + l));
    for (int j = 0; j < list2.size(); j++) {
      List<Object> l = list2.get(j);
      double[] nums = {
        (double) l.get(1),
        (double) l.get(2),
        (double) l.get(3),
        (double) l.get(4),
        (double) l.get(5),
      };
      for (int i = 0; i < VECTOR_SIZE; i++) {
        if (nums[i] > maxArray[i]) maxArray[i] = nums[i];
        if (nums[i] < minArray[i]) minArray[i] = nums[i];
      }
      stockDataList.add(
          new StockData(l.get(0) + "", symbol, nums[0], nums[1], nums[2], nums[3], nums[4]));
    }
    return stockDataList;
  }
}

package utils;

import javax.swing.*;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.axis.NumberTickUnit;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;

/**
 * Created by zhanghao on 26/7/17. Modified by zhanghao on 28/9/17.
 *
 * @author ZHANG HAO
 */
public class PlotUtil {

  public static void plot(double[] predicts, double[] actuals, String name) {
    double[] index = new double[predicts.length];
    for (int i = 0; i < predicts.length; i++) index[i] = i;
    int min = minValue(predicts, actuals);
    int max = maxValue(predicts, actuals);
    final XYSeriesCollection dataSet = new XYSeriesCollection();
    addSeries(dataSet, index, predicts, "Predicts");
    addSeries(dataSet, index, actuals, "Actuals");
    final JFreeChart chart =
        ChartFactory.createXYLineChart(
            "Prediction Result", // chart title
            "Index", // x axis label
            name, // y axis label
            dataSet, // data
            PlotOrientation.VERTICAL,
            true, // include legend
            true, // tooltips
            false // urls
            );
    XYPlot xyPlot = chart.getXYPlot();
    // X-axis
    final NumberAxis domainAxis = (NumberAxis) xyPlot.getDomainAxis();
    domainAxis.setRange((int) index[0], (int) (index[index.length - 1] + 2));
    domainAxis.setTickUnit(new NumberTickUnit(20));
    domainAxis.setVerticalTickLabels(true);
    // Y-axis
    final NumberAxis rangeAxis = (NumberAxis) xyPlot.getRangeAxis();
    rangeAxis.setRange(min, max);
    rangeAxis.setTickUnit(new NumberTickUnit(50));
    final ChartPanel panel = new ChartPanel(chart);
    final JFrame f = new JFrame();
    f.add(panel);
    f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
    f.pack();
    f.setVisible(true);
  }

  private static void addSeries(
      final XYSeriesCollection dataSet, double[] x, double[] y, final String label) {
    final XYSeries s = new XYSeries(label);
    for (int j = 0; j < x.length; j++) s.add(x[j], y[j]);
    dataSet.addSeries(s);
  }

  private static int minValue(double[] predicts, double[] actuals) {
    double min = Integer.MAX_VALUE;
    for (int i = 0; i < predicts.length; i++) {
      if (min > predicts[i]) min = predicts[i];
      if (min > actuals[i]) min = actuals[i];
    }
    return (int) (min * 0.98);
  }

  private static int maxValue(double[] predicts, double[] actuals) {
    double max = Integer.MIN_VALUE;
    for (int i = 0; i < predicts.length; i++) {
      if (max < predicts[i]) max = predicts[i];
      if (max < actuals[i]) max = actuals[i];
    }
    return (int) (max * 1.02);
  }
}
I've marked the spot in question for you with two comments in the StockDataSetIterator constructor; in my source file that is lines 461 to 468.
Thanks
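For what it's worth, the arithmetic behind the marked lines can be sketched roughly as follows. The helper `windows` mirrors the loop bound `i < size - window` used in `initializeOffsets()` and `generateTestDataSet()`; the class name `WindowCount` and the alternative split based on the number of rows are assumptions for illustration, not a confirmed fix.

```java
/** Sketch: how many sliding windows a daily series yields for a given split. */
final class WindowCount {

  /** Counts windows the same way as the loop `for (i = 0; i < rows - window; i++)`. */
  static int windows(int rows, int exampleLength, int predictLength) {
    return Math.max(0, rows - (exampleLength + predictLength));
  }

  public static void main(String[] args) {
    int rows = 32; // list2.size() from the posted run
    int exampleLength = 15;
    int predictLength = 1;
    double splitRatio = 0.8;

    // As posted: the split is derived from exampleLength, not from the data size.
    int postedTrainRows = (int) Math.round(exampleLength * splitRatio); // 12
    System.out.println("posted: train rows = " + postedTrainRows
        + ", train windows = " + windows(postedTrainRows, exampleLength, predictLength));

    // Assumed alternative: derive the split from the number of rows.
    int altTrainRows = (int) Math.round(rows * splitRatio); // 26
    int altTestRows = rows - altTrainRows; // 6
    System.out.println("alternative: train rows = " + altTrainRows
        + ", train windows = " + windows(altTrainRows, exampleLength, predictLength)
        + ", test windows = " + windows(altTestRows, exampleLength, predictLength));
  }
}
```

If this reading of the posted constructor is right, the 12 training rows are fewer than one window of 16, so `hasNext()` would be false from the start and `net.fit` would never run. Even with a row-based split, 32 daily rows leave only 10 training windows and no separate test windows, so more than 30 days of history would probably be needed either way.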
-
Edit: I had tested it with the following data (the columns should be self-explanatory):
list2.size() = 32
l = [19748, 91.83797335102334, 92.71648526759283, 91.83797335102334, 92.71648526759283, 4203.3099999999995]
l = [19749, 93.36752755612166, 92.64628278182262, 92.12689371773429, 93.36752755612166, 6161.77]
l = [19750, 94.40827278049834, 95.56986201017443, 94.40827278049834, 97.80647678983507, 7970.26]
l = [19751, 97.37665025155518, 100.77814246091854, 97.37665025155518, 100.77814246091854, 7992.23]
l = [19752, 102.63094180671153, 101.43243634953313, 101.43243634953313, 104.02560679484226, 13268.67]
l = [19753, 100.50742797179204, 98.1066613081412, 98.1066613081412, 100.50742797179204, 11700.29]
l = [19754, 93.45949987216143, 96.30517749954835, 93.45949987216143, 96.30517749954835, 12483.83]
l = [19755, 98.63787508229797, 100.9428572386196, 98.63787508229797, 101.32282620549877, 12444.04]
l = [19756, 99.78646788861874, 98.14028160990891, 97.96491587753393, 99.78646788861874, 6038.719999999999]
l = [19757, 97.91699797883597, 96.90941297944825, 96.90941297944825, 97.91699797883597, 4360.49]
l = [19758, 95.2105059295542, 94.88591585705699, 94.88591585705699, 97.40977369541514, 6934.799999999999]
l = [19759, 95.25615067490419, 97.6627038684511, 94.57842668292786, 97.6627038684511, 5556.82]
l = [19760, 95.81869065389131, 96.94537440993797, 94.96985657563032, 96.94537440993797, 5433.09]
l = [19761, 101.49823001006955, 102.58697481001379, 101.25393508360612, 104.48851802118729, 10427.73]
l = [19762, 104.83512403285263, 105.10605161573433, 104.83512403285263, 105.6318898617301, 8543.07]
l = [19763, 109.85718169398818, 108.06725132732612, 108.06725132732612, 109.85718169398818, 9098.49]
l = [19764, 108.67799649709828, 107.64252733687948, 107.64252733687948, 109.39178024684149, 7010.89]
l = [19765, 105.81123324621163, 110.20579564774938, 104.42941559194774, 110.20579564774938, 6380.76]
l = [19766, 113.87616473303284, 110.1654238619581, 110.1654238619581, 114.22238899033734, 10161.42]
l = [19767, 111.32926743082966, 116.79505031796758, 111.32926743082966, 118.05399765058999, 11254.34]
l = [19768, 115.91907847345627, 113.72627278052416, 113.72627278052416, 116.02751475453104, 8238.42]
l = [19769, 113.95029261214435, 109.06913856356621, 109.06913856356621, 113.95029261214435, 7930.49]
l = [19770, 111.15067034817159, 109.65974259654443, 106.33753287839521, 111.15067034817159, 7950.01]
l = [19771, 109.56562807356273, 112.91557017125702, 109.56562807356273, 112.91557017125702, 5157.08]
l = [19772, 112.0306760091593, 113.09392950614297, 112.0306760091593, 113.09392950614297, 5933.93]
l = [19773, 112.02534274914494, 108.24036074103542, 108.24036074103542, 112.02534274914494, 9767.079999999998]
l = [19774, 106.3463777107441, 101.78040058120601, 101.78040058120601, 106.3463777107441, 8264.92]
l = [19775, 102.87331537357632, 101.77337611631232, 101.77337611631232, 106.04341760593877, 9739.18]
l = [19776, 103.3423316637421, 101.5471202445096, 101.5471202445096, 103.3423316637421, 5677.18]
l = [19777, 100.8444063780233, 103.89123602922389, 100.8444063780233, 103.89123602922389, 4982.28]
l = [19778, 103.58599623283995, 102.5102713232609, 102.5102713232609, 103.58599623283995, 2408.31]
l = [19778, 102.5102713232609, 102.5102713232609, 102.5102713232609, 102.5102713232609, 1178.1]
-
@oxide said in Stock price forecasting with LSTM in Java:
I'm looking for something like this too. I want to calculate the lottery numbers based on the ones drawn so far.
Then I would recommend Python with the following modules:
- Scikit-learn
- Matplotlib
- NumPy
- SciPy
- Pandas
The modules take a whole lot of programming work off your hands, and when you first enter the AI world you will have plenty to deal with in the new terminology (e.g. high bias, high variance). If you want to go deeper into the subject, I would also recommend a book.
However, I don't know whether predicting lottery numbers is a good example, since I suspect that this could be a chaotic system. Small changes in the initial conditions could lead to completely different results. Predictability can break down because of this behavior. A similar problem is dynamical billiards.
-
@Quiche-Lorraine said in Stock price forecasting with LSTM in Java:
However, I don't know whether predicting lottery numbers is a good example,
Let's be honest and cut to the chase: that was just a trolling attempt by @oxide to discredit me ...
In reputable forums that would simply be moderated away, but here the mods are not willing to step in.
But thanks for the suggestion to do this with Python.
-
I'm slowly starting to >believe< that the MultiLayerNetwork above is wrong and that the library linked above is therefore complete rubbish too, so it's not just the StockDataSetIterator that is "acting up"... Or I know too little about the subject. But 9.5 out of 10 blogs you find on the internet on this topic are rubbish as well... But never mind... Bitcoin is on its way to a new all-time high, and that is enough of a forecast for me.
-
And there we have the mess already:
https://www.cnbc.com/2024/03/04/crypto-market-today.html
The new all-time high is coming/came faster than I expected! For you too?
-
@omggg said in Stock price forecasting with LSTM in Java:
because of overfitness
I meant overfitting ...
Unfortunately I have to correct this here, because otherwise I fear the trolls will show up again.
Do you still feel like answering my question(s) a bit more precisely? Or is there no way around Python anymore?
-
@omggg said in Stock price forecasting with LSTM in Java:
Do you still feel like answering my question(s) a bit more precisely? Or is there no way around Python anymore?
A little backstory: for some time now I have been browsing, in my spare time, through the book "Machine Learning mit Python und Scikit-Learn und TensorFlow". And this book showed me a somewhat different view of machine learning. Before, I thought "lots of data in, train, and that's it". Now, however, I pay attention to other things such as correlations, overfitting, high bias, etc.
What I'm getting at is this: I only know AI in connection with Python. There are, however, some modules and libraries that I would not want to do without.
And that starts with scikit-learn. This module is large and consists of various classifiers, preprocessing steps (e.g. sklearn.preprocessing.StandardScaler), batches, ... It simply has everything, which cannot generally be said of other GitHub projects. So for machine learning I would make sure to use either scikit-learn or TensorFlow. I don't know of any other libs of this magnitude. Furthermore, I would recommend the book mentioned above.
Python is not a must, but it is very practical. One call of "pip install scikit-learn" and within a short time scikit-learn is installed. And that is a nice thing about Python.
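Since the thread is about Java, here is a rough sketch of what standardization in the spirit of StandardScaler boils down to for a single feature: subtract the mean and divide by the (population) standard deviation. The class `Standardizer` and its methods are made up for this illustration and are not the scikit-learn API; the Java code posted earlier uses min-max scaling instead.

```java
import java.util.Arrays;

/** Sketch: z-score standardization of one feature, with an inverse transform. */
final class Standardizer {
  private final double mean;
  private final double std;

  private Standardizer(double mean, double std) {
    this.mean = mean;
    this.std = std;
  }

  /** Learns mean and population standard deviation from the training values. */
  static Standardizer fit(double[] values) {
    if (values.length == 0) {
      throw new IllegalArgumentException("need at least one value");
    }
    double mean = Arrays.stream(values).average().orElseThrow();
    double variance =
        Arrays.stream(values).map(v -> (v - mean) * (v - mean)).average().orElseThrow();
    double std = Math.sqrt(variance);
    // Avoid division by zero for constant input.
    return new Standardizer(mean, std == 0.0 ? 1.0 : std);
  }

  /** Maps a raw value to its z-score. */
  double transform(double value) {
    return (value - mean) / std;
  }

  /** Maps a z-score (e.g. a model output) back to the original scale. */
  double inverse(double z) {
    return z * std + mean;
  }

  public static void main(String[] args) {
    double[] train = {2, 4, 4, 4, 5, 5, 7, 9}; // illustrative values: mean 5, std 2
    Standardizer s = Standardizer.fit(train);
    System.out.println(s.transform(9)); // 2.0
    System.out.println(s.inverse(s.transform(9))); // 9.0
  }
}
```

The scaler is fitted on the training values only and then reused for test data and for mapping predictions back, which is the same pattern as the min/max arrays in the posted iterator.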