使用HMM建模股票价格波动

来自集智百科 - 复杂系统|人工智能|复杂科学|复杂网络|自组织
跳到导航 跳到搜索

原理

HMM的原理参考这里

使用HMM拟合股票市场数据的实验可参考这篇论文

准备数据

    import datetime
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.finance import quotes_historical_yahoo
    from matplotlib.dates import YearLocator, MonthLocator, DateFormatter
    from sklearn.hmm import GaussianHMM
    
    # prepare data
    date1 = datetime.date(1995, 1, 1)  # start date
    date2 = datetime.date(2012, 1, 6)  # end date
    # Get daily quotes for ticker "INTC" from Yahoo Finance.
    # Each quote row is laid out as: Date, Open, Close, High, Low, Volume.
    # NOTE(review): matplotlib.finance / quotes_historical_yahoo are no longer
    # shipped with current matplotlib releases; on a modern install `quotes`
    # must come from another source with the same per-row layout.
    quotes = quotes_historical_yahoo("INTC", date1, date2)
    # At least two days are needed to form a single price difference; with
    # fewer, X would be empty and the HMM fit below would fail obscurely.
    if len(quotes) < 2:
        raise ValueError("need at least two quotes to compute price differences")
    # unpack quotes (column indices follow the layout above)
    dates = np.array([q[0] for q in quotes], dtype=int)
    close_v = np.array([q[2] for q in quotes])
    volume = np.array([q[5] for q in quotes])
    # Day-over-day change of the closing price. np.diff returns one element
    # fewer than its input, so the first day is dropped from every other
    # per-day series to keep all rows aligned: row t pairs the move into
    # day t with day t's date, close and volume.
    diff = np.diff(close_v)
    dates = dates[1:]
    close_v = close_v[1:]
    volume = volume[1:]
    # pack diff and volume for training: one row per day, two feature columns
    X = np.column_stack([diff, volume])

拟合模型

    # Fit a Gaussian HMM to the (price change, volume) observations in X.
    n_components = 5  # number of hidden states to learn
    # NOTE(review): sklearn.hmm was removed from scikit-learn (0.17); its
    # successor is the separate hmmlearn package, whose fit() takes one
    # stacked array (plus optional `lengths`) instead of a list of
    # sequences — confirm against the installed version.
    model = GaussianHMM(
        n_components=n_components,
        covariance_type="diag",  # per-state diagonal covariance
        n_iter=1000,             # cap on training iterations
    )
    # This API is given a list of observation sequences; X is the only one.
    model.fit([X])
    # Decode the most likely hidden-state label for every row of X.
    hidden_states = model.predict(X)

打印模型训练的结果并绘图

    # print trained parameters and plot
    # Learned state-transition matrix, rounded to two decimals.
    print(np.round(model.transmat_, 2))
    print("means and vars of each hidden state")
    for i in range(n_components):
        print("%dth hidden state" % i)
        print("mean = ", model.means_[i])
        # covars_[i] is reduced with np.diag, so each entry is presumably a
        # square covariance matrix whose diagonal holds the per-feature
        # variances — confirm for the installed HMM library version.
        print("var = ", np.diag(model.covars_[i]))
        print()

    # Scatter the closing price over time, one series per hidden state.
    years = YearLocator()    # major ticks: every year
    months = MonthLocator()  # minor ticks: every month
    yearsFmt = DateFormatter('%Y')
    fig = plt.figure()
    ax = fig.add_subplot(111)
    for i in range(n_components):
        # boolean mask selecting the days assigned to hidden state i
        idx = (hidden_states == i)
        ax.plot_date(dates[idx], close_v[idx], 'o', label="%dth hidden state" % i)

    # format the figure
    ax.legend()
    ax.xaxis.set_major_locator(years)
    ax.xaxis.set_major_formatter(yearsFmt)
    ax.xaxis.set_minor_locator(months)
    ax.autoscale_view()
    # Formatters for the x/y coordinate read-out shown for the cursor.
    ax.fmt_xdata = DateFormatter('%Y-%m-%d')
    ax.fmt_ydata = lambda x: '$%1.2f' % x
    ax.grid(True)
    fig.autofmt_xdate()

    # Display the figure. Without this call a plain (non-interactive) script
    # run builds the plot but never shows it.
    plt.show()

HMM模型拟合

Stockmarket HMM 1.png