blog_20160823_1_4915636 211行 Java
Raw
   1
   2
   3
   4
   5
   6
   7
   8
   9
  10
  11
  12
  13
  14
  15
  16
  17
  18
  19
  20
  21
  22
  23
  24
  25
  26
  27
  28
  29
  30
  31
  32
  33
  34
  35
  36
  37
  38
  39
  40
  41
  42
  43
  44
  45
  46
  47
  48
  49
  50
  51
  52
  53
  54
  55
  56
  57
  58
  59
  60
  61
  62
  63
  64
  65
  66
  67
  68
  69
  70
  71
  72
  73
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
package edu.zju.cst.krselee.example.stock;

import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.NoSuchElementException;

/**
 * Serves daily stock records, loaded from a comma-separated file, as batches of
 * fixed-length sequences for a recurrent network.
 *
 * <p>Each line of the file must have at least {@code VECTOR_SIZE + 2} comma-separated
 * fields; fields 2..7 are parsed as open price, close price, max price, min price,
 * turnover and volume (in that order). Every feature is divided by the largest value
 * seen in its column. The label at each time step is the normalised close price of
 * the following record.
 *
 * <p>Created by kexi.lkx on 2016/8/23.
 */
public class StockDataIterator implements DataSetIterator {

    /** Number of numeric features taken from each record. */
    private static final int VECTOR_SIZE = 6;

    /** Index of the first numeric field in a CSV line (fields 0 and 1 are skipped). */
    private static final int FIRST_VALUE_COLUMN = 2;

    // Number of examples served per call to next().
    private int batchNum;

    // Length of one example, counted in DailyData records.
    private int exampleLength;

    // The loaded records, in the order produced by readDataFromFile (file order reversed).
    private List<DailyData> dataList;

    // Start indices into dataList of the examples that have not been served yet.
    private List<Integer> dataRecord;

    // Per-column maximum seen while loading; used as the normalisation divisor.
    private double[] maxNum;

    /**
     * Creates an empty iterator. Call {@link #loadData(String, int, int)} before use.
     */
    public StockDataIterator() {
        dataList = new ArrayList<>();
        dataRecord = new ArrayList<>();
        maxNum = new double[VECTOR_SIZE];
    }

    /**
     * Loads the stock data and (re)initialises the iterator.
     *
     * <p>If reading or parsing fails, the stack trace is printed, {@code false} is
     * returned and the iterator keeps the state it had before the call.
     *
     * @param fileName      path of the comma-separated data file
     * @param batchNum      number of examples per batch; must be positive
     * @param exampleLength number of records per example; must be positive
     * @return {@code true} when the file was read successfully, {@code false} otherwise
     * @throws IllegalArgumentException if {@code batchNum} or {@code exampleLength} is not positive
     */
    public boolean loadData(String fileName, int batchNum, int exampleLength) {
        if (batchNum <= 0) {
            throw new IllegalArgumentException("batchNum must be positive: " + batchNum);
        }
        if (exampleLength <= 0) {
            throw new IllegalArgumentException("exampleLength must be positive: " + exampleLength);
        }
        try {
            readDataFromFile(fileName);
        } catch (IOException | RuntimeException e) {
            // RuntimeException covers malformed numbers (NumberFormatException).
            e.printStackTrace();
            return false;
        }
        // Only commit the new settings once the data is known to be good.
        this.batchNum = batchNum;
        this.exampleLength = exampleLength;
        resetDataRecord();
        return true;
    }

    /**
     * Rebuilds the list of pending example start indices.
     *
     * <p>A start index is only added when at least one record follows it, because
     * every time step needs the next record for its label.
     */
    private void resetDataRecord() {
        dataRecord.clear();
        if (exampleLength <= 0) {
            // Not configured yet (loadData has not succeeded); nothing can be served.
            return;
        }
        for (int start = 0; start < dataList.size() - 1; start += exampleLength) {
            dataRecord.add(start);
        }
    }

    /**
     * Reads the stock records from a file, replacing any previously loaded data.
     *
     * <p>Lines with fewer than {@code VECTOR_SIZE + 2} fields are skipped. The
     * resulting list is reversed relative to file order. The per-column maxima
     * returned by {@link #getMaxArr()} are recomputed, starting from 0.
     *
     * @param fileName path of the comma-separated data file, read as UTF-8
     * @return the loaded records (the iterator's internal list)
     * @throws IOException           if the file cannot be opened or read
     * @throws NumberFormatException if a numeric field cannot be parsed
     */
    public List<DailyData> readDataFromFile(String fileName) throws IOException {
        List<DailyData> loaded = new ArrayList<>();
        double[] columnMax = new double[VECTOR_SIZE];
        // try-with-resources closes both streams even when parsing throws.
        try (FileInputStream fis = new FileInputStream(fileName);
             BufferedReader in = new BufferedReader(new InputStreamReader(fis, "UTF-8"))) {
            System.out.println("读取数据..");
            String line;
            while ((line = in.readLine()) != null) {
                String[] strArr = line.split(",");
                // The last field read is index FIRST_VALUE_COLUMN + VECTOR_SIZE - 1.
                if (strArr.length < FIRST_VALUE_COLUMN + VECTOR_SIZE) {
                    continue;
                }
                // Track the column maxima; they are the normalisation divisors.
                double[] nums = new double[VECTOR_SIZE];
                for (int j = 0; j < VECTOR_SIZE; j++) {
                    nums[j] = Double.parseDouble(strArr[j + FIRST_VALUE_COLUMN]);
                    if (nums[j] > columnMax[j]) {
                        columnMax[j] = nums[j];
                    }
                }
                DailyData data = new DailyData();
                data.setOpenPrice(nums[0]);
                data.setCloseprice(nums[1]);
                data.setMaxPrice(nums[2]);
                data.setMinPrice(nums[3]);
                data.setTurnover(nums[4]);
                data.setVolume(nums[5]);
                loaded.add(data);
            }
        }
        System.out.println("反转list...");
        Collections.reverse(loaded);
        // Publish the new data only after the whole file was parsed successfully.
        dataList = loaded;
        maxNum = columnMax;
        // Pending start indices referred to the old list; rebuild them.
        resetDataRecord();
        return dataList;
    }

    /**
     * Returns the per-column maxima used for normalisation, in feature order
     * (open, close, max, min, turnover, volume). The internal array is returned, not a copy.
     */
    public double[] getMaxArr() {
        return this.maxNum;
    }

    /** Makes every example available again. */
    public void reset() {
        resetDataRecord();
    }

    /** Returns whether at least one example is still pending. */
    public boolean hasNext() {
        return !dataRecord.isEmpty();
    }

    /** Returns the next batch using the batch size given to {@code loadData}. */
    public DataSet next() {
        return next(batchNum);
    }

    /**
     * Builds the next batch of at most {@code num} examples.
     *
     * <p>Features have shape {@code [batch, VECTOR_SIZE, length]} and labels
     * {@code [batch, 1, length]}. {@code length} is taken from the first pending
     * example; a shorter example later in the same batch does not overwrite its
     * trailing time steps, and no mask arrays are produced.
     *
     * @param num maximum number of examples in the batch; must be positive
     * @return the features and labels of the batch
     * @throws NoSuchElementException   if no example is pending
     * @throws IllegalArgumentException if {@code num} is not positive
     */
    public DataSet next(int num) {
        if (dataRecord.isEmpty()) {
            throw new NoSuchElementException();
        }
        if (num <= 0) {
            throw new IllegalArgumentException("num must be positive: " + num);
        }
        int actualBatchSize = Math.min(num, dataRecord.size());
        // resetDataRecord guarantees start <= size - 2, so this is at least 1.
        int actualLength = Math.min(exampleLength, dataList.size() - dataRecord.get(0) - 1);
        INDArray input = Nd4j.create(new int[]{actualBatchSize, VECTOR_SIZE, actualLength}, 'f');
        INDArray label = Nd4j.create(new int[]{actualBatchSize, 1, actualLength}, 'f');
        for (int i = 0; i < actualBatchSize; i++) {
            int index = dataRecord.remove(0);
            // Stop one short of the end: the last record only ever serves as a label.
            int endIndex = Math.min(index + exampleLength, dataList.size() - 1);
            DailyData curData = dataList.get(index);
            for (int j = index; j < endIndex; j++) {
                DailyData nextData = dataList.get(j + 1);
                // NOTE(review): c runs backwards, so record j = index lands on the LAST
                // time step and the record used as the label at step c was the input at
                // step c - 1. Kept as in the original; confirm this ordering is intended.
                int c = endIndex - j - 1;
                input.putScalar(new int[]{i, 0, c}, curData.getOpenPrice() / maxNum[0]);
                input.putScalar(new int[]{i, 1, c}, curData.getCloseprice() / maxNum[1]);
                input.putScalar(new int[]{i, 2, c}, curData.getMaxPrice() / maxNum[2]);
                input.putScalar(new int[]{i, 3, c}, curData.getMinPrice() / maxNum[3]);
                input.putScalar(new int[]{i, 4, c}, curData.getTurnover() / maxNum[4]);
                input.putScalar(new int[]{i, 5, c}, curData.getVolume() / maxNum[5]);
                // Label: normalised close price of the following record.
                label.putScalar(new int[]{i, 0, c}, nextData.getCloseprice() / maxNum[1]);
                curData = nextData;
            }
        }

        return new DataSet(input, label);
    }

    /** Returns the configured batch size. */
    public int batch() {
        return batchNum;
    }

    /** Returns the number of examples already served since the last reset. */
    public int cursor() {
        return totalExamples() - dataRecord.size();
    }

    /** Returns the same value as {@link #totalExamples()}. */
    public int numExamples() {
        return totalExamples();
    }

    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Returns the number of examples a full pass serves; this matches the number of
     * start indices created by a reset.
     */
    public int totalExamples() {
        if (exampleLength <= 0 || dataList.size() < 2) {
            return 0;
        }
        // Valid starts are 0, L, 2L, ... up to size - 2 inclusive.
        return (dataList.size() - 2) / exampleLength + 1;
    }

    /** Returns the number of input features per time step. */
    public int inputColumns() {
        return VECTOR_SIZE;
    }

    /** Returns the number of label values per time step. */
    public int totalOutcomes() {
        return 1;
    }

    @Override
    public List<String> getLabels() {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException();
    }
}
blog_20160823_2_4326992 86行 Java
Raw
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
package edu.zju.cst.krselee.example.stock;

/**
 * Mutable holder for the six figures of one trading day.
 *
 * <p>Created by kexi.lkx on 2016/8/23.
 */
public class DailyData {

    // Opening price.
    private double openPrice;
    // Closing price.
    private double closeprice;
    // Highest price of the day.
    private double maxPrice;
    // Lowest price of the day.
    private double minPrice;
    // Printed under the label "成交量" (trading volume) by toString().
    private double turnover;
    // Printed under the label "成交额" (traded amount) by toString().
    private double volume;

    /** Creates a record with every figure set to 0.0. */
    public DailyData() {
    }

    public double getOpenPrice() {
        return this.openPrice;
    }

    public void setOpenPrice(double openPrice) {
        this.openPrice = openPrice;
    }

    public double getCloseprice() {
        return this.closeprice;
    }

    public void setCloseprice(double closeprice) {
        this.closeprice = closeprice;
    }

    public double getMaxPrice() {
        return this.maxPrice;
    }

    public void setMaxPrice(double maxPrice) {
        this.maxPrice = maxPrice;
    }

    public double getMinPrice() {
        return this.minPrice;
    }

    public void setMinPrice(double minPrice) {
        this.minPrice = minPrice;
    }

    public double getTurnover() {
        return this.turnover;
    }

    public void setTurnover(double turnover) {
        this.turnover = turnover;
    }

    public double getVolume() {
        return this.volume;
    }

    public void setVolume(double volume) {
        this.volume = volume;
    }

    /**
     * Renders the six figures as {@code label=value} pairs separated by ", ",
     * in the order open, close, max, min, turnover, volume.
     */
    @Override
    public String toString() {
        return "开盘价=" + openPrice + ", "
                + "收盘价=" + closeprice + ", "
                + "最高价=" + maxPrice + ", "
                + "最低价=" + minPrice + ", "
                + "成交量=" + turnover + ", "
                + "成交额=" + volume;
    }
}
blog_20160823_3_5897155 33行 Java
Raw
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
// Number of input features per time step; main() passes it to getNetModel as nIn.
private static final int IN_NUM = 6;
// Number of output values per time step; main() passes it to getNetModel as nOut.
private static final int OUT_NUM = 1;
// Number of full passes over the iterator that train() performs.
private static final int Epochs = 100;

// Output size of the first GravesLSTM layer built in getNetModel().
private static final int lstmLayer1Size = 50;
// Output size of the second GravesLSTM layer built in getNetModel().
private static final int lstmLayer2Size = 100;

/**
 * Builds and initialises the recurrent network used for prediction.
 *
 * <p>The network stacks two GravesLSTM layers (tanh activation, sizes
 * {@code lstmLayer1Size} and {@code lstmLayer2Size}) followed by an RnnOutputLayer
 * with MSE loss and identity activation. A {@code ScoreIterationListener(1)} is
 * attached before the network is returned.
 *
 * @param nIn  number of input features per time step
 * @param nOut number of output values per time step
 * @return the initialised network
 */
public static MultiLayerNetwork getNetModel(int nIn, int nOut) {
    // Describe the three layers first so the configuration chain below stays short.
    GravesLSTM firstLstm = new GravesLSTM.Builder()
            .nIn(nIn)
            .nOut(lstmLayer1Size)
            .activation("tanh")
            .build();
    GravesLSTM secondLstm = new GravesLSTM.Builder()
            .nIn(lstmLayer1Size)
            .nOut(lstmLayer2Size)
            .activation("tanh")
            .build();
    RnnOutputLayer outputLayer = new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
            .activation("identity")
            .nIn(lstmLayer2Size)
            .nOut(nOut)
            .build();

    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .learningRate(0.1)
            .rmsDecay(0.5)
            .seed(12345)
            .regularization(true)
            .l2(0.001)
            .weightInit(WeightInit.XAVIER)
            .updater(Updater.RMSPROP)
            .list()
            .layer(0, firstLstm)
            .layer(1, secondLstm)
            .layer(2, outputLayer)
            .pretrain(false)
            .backprop(true)
            .build();

    MultiLayerNetwork network = new MultiLayerNetwork(configuration);
    network.init();
    network.setListeners(new ScoreIterationListener(1));
    return network;
}
blog_20160823_4_9626909 34行 Java
Raw
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
/**
 * Trains the network for {@code Epochs} passes over the iterator and, after each
 * pass, prints 20 predictions obtained by repeatedly stepping the network with the
 * array from {@code getInitArray}.
 *
 * @param net      the network to fit
 * @param iterator source of training batches; it is reset after every pass
 */
public static void train(MultiLayerNetwork net, StockDataIterator iterator) {
    int epoch = 0;
    while (epoch < Epochs) {
        // One full pass over the pending batches.
        while (iterator.hasNext()) {
            DataSet batch = iterator.next();
            net.fit(batch);
        }
        iterator.reset();

        System.out.println();
        System.out.println("=================>完成第" + epoch + "次完整训练");

        INDArray seed = getInitArray(iterator);
        System.out.println("预测结果:");
        for (int remaining = 20; remaining > 0; remaining--) {
            INDArray prediction = net.rnnTimeStep(seed);
            // Scale back with the close-price maximum (index 1 of the max array).
            double closePrice = prediction.getDouble(0) * iterator.getMaxArr()[1];
            System.out.print(closePrice + " ");
        }
        System.out.println();

        // Drop the state stored by rnnTimeStep before the next pass.
        net.rnnClearPreviousState();
        epoch++;
    }
}

/**
 * Builds the single-step input used to seed prediction: six hard-coded raw values,
 * each divided by the matching entry of the iterator's max array.
 *
 * @param iter iterator whose {@code getMaxArr()} supplies the divisors
 * @return an array of shape [1, 6, 1] holding the normalised values
 */
private static INDArray getInitArray(StockDataIterator iter) {
    double[] divisors = iter.getMaxArr();
    INDArray seed = Nd4j.zeros(1, 6, 1);
    // Raw values in feature order 0..5.
    // NOTE(review): the value at index 2 (3327.81) is lower than the one at index 3
    // (3470.37), while StockDataIterator fills feature 2 from getMaxPrice() and
    // feature 3 from getMinPrice() - confirm these two are not swapped.
    double[] rawValues = {
            3433.85,
            3445.41,
            3327.81,
            3470.37,
            304197903.0,
            3.8750365e+11
    };
    for (int feature = 0; feature < rawValues.length; feature++) {
        seed.putScalar(new int[]{0, feature, 0}, rawValues[feature] / divisors[feature]);
    }
    return seed;
}
blog_20160823_5_3880051 1行 Text
Raw
 1
3489.9679512619973 3516.991701169014 3510.4443733012677 3490.410951650143 3476.138713735342 3469.275475754738 3466.278687063456 3464.9017547094822 3464.2161934530736 3463.8574357616903 3463.670068384409 3463.582194536925 3463.5545977914335 3463.5658543586733 3463.6010765206815 3463.650460170508 3463.7067430067063 3463.764115188122 3463.8196717941764 3463.8705079042916
blog_20160823_5_2505236 11行 Java
Raw
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
/**
 * Entry point: loads {@code stock/sh000001.csv} from the classpath, builds the
 * network and trains it.
 *
 * @param args unused
 * @throws IllegalStateException if the data file is not on the classpath or cannot be loaded
 */
public static void main(String[] args) {
    String resourceName = "stock/sh000001.csv";
    // getResource returns null when the file is missing; fail with a clear message
    // instead of a bare NullPointerException.
    java.net.URL resource = StockRnnPredict.class.getClassLoader().getResource(resourceName);
    if (resource == null) {
        throw new IllegalStateException("Resource not found on classpath: " + resourceName);
    }
    String inputFile = resource.getPath();
    int batchSize = 1;
    int exampleLength = 30;

    // Load the training data; without it training would run on an empty iterator.
    StockDataIterator iterator = new StockDataIterator();
    if (!iterator.loadData(inputFile, batchSize, exampleLength)) {
        throw new IllegalStateException("Failed to load stock data from " + inputFile);
    }

    // Build the network and train it.
    MultiLayerNetwork net = getNetModel(IN_NUM, OUT_NUM);
    train(net, iterator);
}
blog_20160827_1_1029212 1行 Text
Raw
 1
股票代码 日期 开盘价 收盘价 最高价 最低价 成交量 成交额 涨跌幅