私は次のようなパンダのデータフレームを持っています:
ticket date close
0 AAA 2018-01-12 176.16
1 AAA 2018-01-13 176.49
3 AAA 2018-01-14 176.00
4 BBB 2018-01-12 78.19
5 BBB 2018-01-13 79.90
6 BBB 2018-01-14 78.10
私は機能を持っています:
def rsi(dataframe, period, column = 'close'):
delta = dataframe[column].diff()
up, down = delta.copy(), delta.copy()
up[up < 0] = 0
down[down > 0] = 0
rolling_up = up.ewm(com=period - 1, adjust=False).mean()
rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
rsi = 100 - 100 / (1 + rolling_up / rolling_down)
dataframe['rsi'] = rsi
return dataframe
必要なのは、この関数を各groupby( 'ticket')のデータフレームに適用することです。これを試しましたが、機能しません。アドバイスをお願いします。
print(dataframe.groupby('ticket').apply(rsi, 2))
エラーが発生します:
重複する軸からインデックスを再作成することはできません
# -*- coding: utf-8 -*-
import json
import pandas
import requests
import datetime
def get_historical_prices(tickets, range):
request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params = request_params).json()
united_dataframe = pandas.DataFrame()
for symbol in json:
ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
ticket_dataframe.insert(0, 'ticket', symbol)
united_dataframe = united_dataframe.append(ticket_dataframe)
return united_dataframe[['ticket', 'date', 'close']]
def rsi(dataframe, period, column = 'close'):
delta = all_prices[column].diff()
up, down = delta.copy(), delta.copy()
up[up < 0] = 0
down[down > 0] = 0
rolling_up = up.ewm(com=period - 1, adjust=False).mean()
rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
rsi = 100 - 100 / (1 + rolling_up / rolling_down)
dataframe['rsi'] = rsi
return dataframe
# Get the data
tickets = ['AAPL', 'FB', 'TSLA']
all_prices = get_historical_prices(tickets, '1m')
print(all_prices.groupby('ticket').apply(rsi, 2))
ソースコードに問題があります。この線
delta = all_prices[column].diff()
する必要があります
delta = dataframe[column].diff()
それを修正することも問題なく実行されます。再割り当てにより、列rsi
がall_prices
ieに追加されます
all_prices = all_prices.groupby('ticket').apply(rsi, 2)
最終的なタラと結果を以下に示します
In [20]: # -*- coding: utf-8 -*-
...:
...: import json
...: import pandas
...: import requests
...: import datetime
...:
...: def get_historical_prices(tickets, range):
...: request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
...: json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params = request_params).json()
...: united_dataframe = pandas.DataFrame()
...: for symbol in json:
...: ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
...: ticket_dataframe.insert(0, 'ticket', symbol)
...: united_dataframe = united_dataframe.append(ticket_dataframe)
...: return united_dataframe[['ticket', 'date', 'close']]
...:
...: def rsi(dataframe, period, column = 'close'):
...: delta = dataframe[column].diff()
...: up, down = delta.copy(), delta.copy()
...: up[up < 0] = 0
...: down[down > 0] = 0
...: rolling_up = up.ewm(com=period - 1, adjust=False).mean()
...: rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
...: rsi = 100 - 100 / (1 + rolling_up / rolling_down)
...: dataframe['rsi'] = rsi
...: return dataframe
...:
...: # Get the data
...: tickets = ['AAPL', 'FB', 'TSLA']
...: all_prices = get_historical_prices(tickets, '1m')
...:
...: all_prices = all_prices.groupby('ticket').apply(rsi, 2)
...: print(all_prices.head())
...:
...:
ticket date close rsi
0 AAPL 2018-01-12 177.09 NaN
1 AAPL 2018-01-16 176.19 0.000000
2 AAPL 2018-01-17 179.10 76.377953
3 AAPL 2018-01-18 179.26 78.208232
4 AAPL 2018-01-19 178.46 44.065484
この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。
侵害の場合は、連絡してください[email protected]
コメントを追加