統計分析：使用 MongoDB 儲存 S&P 500 指數資料並繪圖
資料來源：S&P 500 指數歷史資料 CSV 檔（程式預設讀取 SP500.csv，可由命令列第一個參數指定；第一列為標題列，欄位依序為日期、開盤、最高、最低、收盤、交易量、盤後）
程式碼:
# -*- coding: utf-8 -*-
# Load S&P 500 daily index data from a CSV file into MongoDB, then list the
# stored rows, print per-column mean / standard deviation, or plot them.
# print() is always called with a single argument so the script behaves the
# same under Python 2 and Python 3.
import csv
from pymongo import MongoClient
import sys
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.font_manager import fontManager


class MongoDB_INF:
    """Small wrapper around one MongoDB database holding daily index rows.

    Documents written by importFromCSV have the fields StockDate (str),
    OpenIndex, HighIndex, LowIndex, CloseIndex, AdjIndex (float) and
    StockVol (float, trading volume divided by 1,000,000).
    """

    dbHost = 'localhost'
    dbPort = 27017
    dbClient = None  # MongoClient, set by a successful openDB()
    dbName = 'emprogria'
    dbConn = None  # Database handle, set by a successful openDB()

    # Column order of the CSV read by importFromCSV (its first row is a header).
    _CSV_FIELDS = ["StockDate", "OpenIndex", "HighIndex", "LowIndex",
                   "CloseIndex", "StockVol", "AdjIndex"]

    # (document field, label) pairs reported by getStats, in print order.
    # Keeping field and label together prevents labels drifting off columns.
    _STAT_FIELDS = (
        ('OpenIndex', u'開盤'),    # open
        ('CloseIndex', u'收盤'),   # close
        ('AdjIndex', u'盤後'),     # adjusted close
        ('HighIndex', u'最高'),    # high
        ('LowIndex', u'最低'),     # low
        ('StockVol', u'交易量'),   # volume (millions)
    )

    # (document field, legend label) pairs drawn by Plot, in draw order.
    _PLOT_FIELDS = (
        ('OpenIndex', u'開盤'),
        ('CloseIndex', u'收盤'),
        ('HighIndex', u'最高'),
        ('LowIndex', u'最低'),
    )

    def __init__(self, dbHost='localhost', dbPort=27017):
        """Remember the server address; no connection is made until openDB()."""
        self.dbHost = dbHost
        self.dbPort = dbPort

    def openDB(self, dbName='emprogria'):
        """Connect to the server and select database *dbName*.

        Returns True on success. On failure any client is closed,
        dbClient/dbConn are reset to None and False is returned.
        """
        self.dbName = dbName
        try:
            self.dbClient = MongoClient(
                'mongodb://%s:%d/' % (self.dbHost, self.dbPort))
            # MongoClient connects lazily; force a round-trip so an
            # unreachable server is reported here, not at the first query.
            self.dbClient.admin.command('ping')
        except Exception:
            self.closeDB()
            self.dbClient = None
            self.dbConn = None
            return False
        self.dbConn = self.dbClient[self.dbName]
        return True

    def closeDB(self):
        """Close the client if one was created."""
        if self.dbClient is not None:
            self.dbClient.close()

    def removeDocs(self, criteria=None, nameCollection='SP500'):
        """Delete every document in *nameCollection* matching *criteria*.

        With the default (no criteria) the whole collection is emptied.
        Returns the number of deleted documents.
        """
        if criteria is None:
            criteria = {}
        return self.dbConn[nameCollection].delete_many(criteria).deleted_count

    def listDocs(self, criteria=None, nameCollection='SP500'):
        """Print each document in *nameCollection* matching *criteria*."""
        if criteria is None:
            criteria = {}
        for doc in self.dbConn[nameCollection].find(criteria):
            # NOTE(review): slicing assumes StockDate is 'YYYYMMDD'; a dashed
            # 'YYYY-MM-DD' value would print garbled - confirm against the CSV.
            print('%s-%s-%s' % (doc['StockDate'][0:4], doc['StockDate'][4:6],
                                doc['StockDate'][6:8]))
            print(u'\t開盤=%.2f' % (doc['OpenIndex']))
            print(u'\t收盤=%.2f' % (doc['CloseIndex']))
            print(u'\t盤後=%.2f' % (doc['AdjIndex']))
            print(u'\t最高=%.2f' % (doc['HighIndex']))
            print(u'\t最低=%.2f' % (doc['LowIndex']))
            print(u'\t交易量=%d (M)' % (doc['StockVol']))

    def importFromCSV(self, csvFile, nameCollection='SP500'):
        """Insert every data row of *csvFile* into *nameCollection*.

        The first row is treated as a header and skipped; columns are read in
        the _CSV_FIELDS order. Prices are stored as floats and volume is
        divided by 1,000,000. Returns the number of inserted documents.
        """
        recCount = 0
        with open(csvFile, 'r') as csvF:
            reader = csv.DictReader(csvF, self._CSV_FIELDS)
            next(reader, None)  # skip the header row
            collection = self.dbConn[nameCollection]
            for rowDB in reader:
                stockData = {
                    'StockDate': rowDB['StockDate'],
                    'OpenIndex': float(rowDB['OpenIndex']),
                    'HighIndex': float(rowDB['HighIndex']),
                    'LowIndex': float(rowDB['LowIndex']),
                    'CloseIndex': float(rowDB['CloseIndex']),
                    'StockVol': float(rowDB['StockVol']) / 1000000.0,
                    'AdjIndex': float(rowDB['AdjIndex']),
                }
                collection.insert_one(stockData)
                recCount += 1
        return recCount

    def _fetchColumns(self, fields, criteria=None, nameCollection='SP500'):
        """Return a float array with one row per matching document and one
        column per name in *fields* (same order), or None if nothing matches."""
        if criteria is None:
            criteria = {}
        rows = [[doc[name] for name in fields]
                for doc in self.dbConn[nameCollection].find(criteria)]
        if not rows:
            return None
        return np.array(rows, dtype=float)

    def getStats(self, criteria=None, nameCollection='SP500'):
        """Print mean and population standard deviation of each numeric field.

        Returns a list of (label, mean, std) tuples in print order, or None
        (after printing a notice) when no document matches *criteria*.
        """
        fields = [name for name, _ in self._STAT_FIELDS]
        data = self._fetchColumns(fields, criteria, nameCollection)
        if data is None:
            print(u'無符合條件的資料')
            return None
        stats = []
        for col, (_, label) in enumerate(self._STAT_FIELDS):
            mean, std = data[:, col].mean(), data[:, col].std()
            print("%s\t: %8.2f\t%8.2f" % (label, mean, std))
            stats.append((label, mean, std))
        return stats

    def Plot(self, criteria=None, nameCollection='SP500'):
        """Plot open/close/high/low of the matching documents as line series.

        The x axis is the position of each document in query order, not the
        StockDate value. Prints a notice and returns if nothing matches.
        """
        fields = [name for name, _ in self._PLOT_FIELDS]
        data = self._fetchColumns(fields, criteria, nameCollection)
        if data is None:
            print(u'無符合條件的資料')
            return
        xs = range(0, len(data))
        # NOTE(review): Chinese labels need this font installed - confirm it
        # exists on the target machine.
        plt.rcParams["font.family"] = 'Microsoft MHei'
        plt.title(u"S&P500 指數線圖")
        plt.xlabel(u"日期")
        plt.ylabel(u"指數")
        for col, (_, label) in enumerate(self._PLOT_FIELDS):
            plt.plot(xs, data[:, col], label=label)
        plt.legend(loc='lower left')
        plt.show()


if __name__ == '__main__':
    csvFile = 'SP500.csv'
    if len(sys.argv) > 1:
        csvFile = sys.argv[1]

    # Jobs to run: [0] wipe + re-import CSV, [1] list documents,
    # [2] print statistics, [3] plot.
    jobTask = [False, False, True, False]

    # Query used by the listing job, selected by queryCat.
    queryCat = 2
    queryCriteria = {
        # all documents
        0: {},
        # trading volume > 4000 (millions)
        1: {'StockVol': {'$gt': 4000}},
        # 2040 < adjusted close < 2050
        2: {'$and': [{'AdjIndex': {'$gt': 2040}},
                     {'AdjIndex': {'$lt': 2050}}]},
    }

    mongoDB_Inf = MongoDB_INF()
    if mongoDB_Inf.openDB():
        try:
            if jobTask[0]:
                mongoDB_Inf.removeDocs()
                print(u'筆數: %d' % (mongoDB_Inf.importFromCSV(csvFile)))
            if jobTask[1]:
                # An unknown queryCat raises KeyError instead of NameError.
                mongoDB_Inf.listDocs(queryCriteria[queryCat])
            if jobTask[2]:
                mongoDB_Inf.getStats({})
            if jobTask[3]:
                mongoDB_Inf.Plot({})
        finally:
            # Always release the client, even if a job raised.
            mongoDB_Inf.closeDB()
    else:
        print(u'資料庫錯誤')