程式碼:
# -*- coding: utf-8 -*-
"""Load S&P 500 daily index data from a CSV file into MongoDB, then list,
summarize, or plot it (plots are drawn with R's ggplot2 through rpy2).

All print calls take a single argument, so they behave the same under the
Python 2 interpreter this script was written for and under Python 3.
"""
import csv
import sys

import numpy as np
from pymongo import MongoClient
from pymongo.errors import PyMongoError
import rpy2.robjects as robjects
from rpy2.robjects import Formula, Environment
from rpy2.robjects.vectors import IntVector, FloatVector
from rpy2.robjects.packages import importr
import rpy2.robjects.lib.ggplot2 as ggplot2

# CSV columns, read positionally (the file's own header row is skipped).
_CSV_FIELDS = ["StockDate", "OpenIndex", "HighIndex", "LowIndex",
               "CloseIndex", "StockVol", "AdjIndex"]

# (document key, printed label) pairs reported by getStats, in print order.
_STAT_FIELDS = (
    ('OpenIndex', u'開盤'),
    ('CloseIndex', u'收盤'),
    ('AdjIndex', u'盤後'),
    ('HighIndex', u'最高'),
    ('LowIndex', u'最低'),
    ('StockVol', u'交易量'),
)


class MongoDB_INF:
    """Thin wrapper around one MongoDB database holding daily index rows."""

    dbHost = 'localhost'
    dbPort = 27017
    dbClient = None      # MongoClient, set by openDB()
    dbName = 'emprogria'
    dbConn = None        # Database handle, set by openDB()

    def __init__(self, dbHost='localhost', dbPort=27017):
        self.dbHost = dbHost
        self.dbPort = dbPort

    def openDB(self, dbName='emprogria'):
        """Connect to the server and select database *dbName*.

        Returns True when the server answered a ping. Returns False when
        connecting failed; dbClient and dbConn are then reset to None.
        """
        # The original ignored this argument and always used the class default.
        self.dbName = dbName
        client = None
        try:
            client = MongoClient('mongodb://%s:%d/' % (self.dbHost, self.dbPort))
            # MongoClient connects lazily; ping so an unreachable server is
            # reported here instead of on the first query.
            client.admin.command('ping')
        except PyMongoError:
            if client is not None:
                client.close()
            self.dbClient = None
            self.dbConn = None
            return False
        self.dbClient = client
        self.dbConn = client[self.dbName]
        return True

    def closeDB(self):
        """Close the client connection if one is open."""
        if self.dbClient is not None:
            self.dbClient.close()
            self.dbClient = None
            self.dbConn = None

    def removeDocs(self, criteria=None, nameCollection='SP500'):
        """Delete the documents in *nameCollection* matching *criteria*.

        With no criteria every document is deleted. Returns the number of
        documents removed.
        """
        # Collection.remove() was removed in PyMongo 4; delete_many replaces it.
        result = self.dbConn[nameCollection].delete_many(criteria or {})
        return result.deleted_count

    def listDocs(self, criteria=None, nameCollection='SP500'):
        """Print every document in *nameCollection* matching *criteria*."""
        for doc in self.dbConn[nameCollection].find(criteria or {}):
            stockDate = doc['StockDate']
            # NOTE(review): the slicing assumes StockDate is 'YYYYMMDD'; a
            # 'YYYY-MM-DD' value would print garbled -- confirm CSV format.
            print('%s-%s-%s' % (stockDate[0:4], stockDate[4:6], stockDate[6:8]))
            print(u'\t開盤=%.2f' % (doc['OpenIndex']))
            print(u'\t收盤=%.2f' % (doc['CloseIndex']))
            print(u'\t盤後=%.2f' % (doc['AdjIndex']))
            print(u'\t最高=%.2f' % (doc['HighIndex']))
            print(u'\t最低=%.2f' % (doc['LowIndex']))
            print(u'\t交易量=%d (M)' % (doc['StockVol']))

    def importFromCSV(self, csvFile, nameCollection='SP500'):
        """Insert each data row of *csvFile* into *nameCollection*.

        The first row is treated as a header and skipped. Columns are read
        positionally in _CSV_FIELDS order. Prices are stored as floats and
        the volume is stored in millions. Returns the number of rows
        inserted.
        """
        collection = self.dbConn[nameCollection]
        recCount = 0
        # 'with' guarantees the file is closed (the original leaked it).
        with open(csvFile, 'r') as csvF:
            reader = csv.DictReader(csvF, _CSV_FIELDS)
            next(reader, None)  # skip the header row
            for rowDB in reader:
                stockData = {
                    'StockDate': rowDB['StockDate'],
                    'OpenIndex': float(rowDB['OpenIndex']),
                    'HighIndex': float(rowDB['HighIndex']),
                    'LowIndex': float(rowDB['LowIndex']),
                    'CloseIndex': float(rowDB['CloseIndex']),
                    'StockVol': float(rowDB['StockVol']) / 1000000.0,
                    'AdjIndex': float(rowDB['AdjIndex']),
                }
                # Collection.insert() was removed in PyMongo 4.
                collection.insert_one(stockData)
                recCount += 1
        return recCount

    def getStats(self, criteria=None, nameCollection='SP500'):
        """Print the mean and population standard deviation of each numeric
        field over the documents matching *criteria*."""
        columns = {key: [] for key, _ in _STAT_FIELDS}
        for doc in self.dbConn[nameCollection].find(criteria or {}):
            for key, _ in _STAT_FIELDS:
                columns[key].append(doc[key])
        for key, label in _STAT_FIELDS:
            values = np.array(columns[key])
            # ndarray.std() defaults to ddof=0 (population std), as before.
            print("%s\t: %8.2f\t%8.2f" % (label, values.mean(), values.std()))

    def rPlot(self, criteria=None, nameCollection='SP500'):
        """Plot open/close/high/low as line series with ggplot2.

        The figure is saved to '<nameCollection>.png' at 512x512. The x axis
        is the row position in query order, not the stored StockDate.
        """
        OpenIndex = []
        CloseIndex = []
        HighIndex = []
        LowIndex = []
        # NOTE(review): find() has no sort, so row order is whatever the
        # server returns -- confirm it is chronological for this data.
        for doc in self.dbConn[nameCollection].find(criteria or {}):
            OpenIndex.append(doc['OpenIndex'])
            CloseIndex.append(doc['CloseIndex'])
            HighIndex.append(doc['HighIndex'])
            LowIndex.append(doc['LowIndex'])
        recCount = len(OpenIndex)

        rIndices = robjects.DataFrame(
            {
                '日期': IntVector(range(0, recCount)),
                '開盤': FloatVector(OpenIndex),
                '收盤': FloatVector(CloseIndex),
                '最高': FloatVector(HighIndex),
                '最低': FloatVector(LowIndex),
            }
        )
        grdevices = importr('grDevices')
        gpIndices = ggplot2.ggplot(rIndices)
        pp = (gpIndices
              + ggplot2.ggtitle(u'S&P500 指數線圖')
              + ggplot2.geom_line(ggplot2.aes_string(x='日期', y='開盤'),
                                  colour="blue")
              + ggplot2.geom_line(ggplot2.aes_string(x='日期', y='收盤'),
                                  colour="green")
              + ggplot2.geom_line(ggplot2.aes_string(x='日期', y='最高'),
                                  colour="red")
              + ggplot2.geom_line(ggplot2.aes_string(x='日期', y='最低'),
                                  colour="yellow"))
        grdevices.png(file="%s.png" % (nameCollection), width=512, height=512)
        try:
            pp.plot()
        finally:
            # Always close the device so a failed plot does not leave it open.
            grdevices.dev_off()


if __name__ == '__main__':
    csvFile = 'SP500.csv'
    if len(sys.argv) > 1:
        csvFile = sys.argv[1]

    # Steps to run: [clear + re-import CSV, list documents, print stats, plot].
    jobTask = [False, False, False, True]
    # Which query the listing step uses; a key of `queries` below.
    queryCat = 2

    mongoDB_Inf = MongoDB_INF(dbHost='192.168.171.1')
    if mongoDB_Inf.openDB():
        try:
            if jobTask[0]:
                mongoDB_Inf.removeDocs()
                print(u'筆數: %d' % (mongoDB_Inf.importFromCSV(csvFile)))
            if jobTask[1]:
                queries = {
                    # every document
                    0: {},
                    # volume > 4000 (StockVol is stored in millions)
                    1: {'StockVol': {'$gt': 4000}},
                    # 2040 < adjusted close < 2050. Both operators must share
                    # one sub-document: repeating the 'AdjIndex' key in a
                    # dict silently keeps only the last condition.
                    2: {'AdjIndex': {'$gt': 2040, '$lt': 2050}},
                }
                mongoDB_Inf.listDocs(queries[queryCat])
            if jobTask[2]:
                mongoDB_Inf.getStats({})
            if jobTask[3]:
                mongoDB_Inf.rPlot({})
        finally:
            mongoDB_Inf.closeDB()
    else:
        print(u'資料庫錯誤')