from __future__ import annotations
from dataclasses import dataclass
from decimal import Decimal
from typing import TYPE_CHECKING, TypeAlias, TypeVar, Union, cast
import pandas as pd
from mypy_extensions import mypyc_attr
# Public API for `from <module> import *`: only `OHLCAsset` is exported.
# Other module-level names (e.g. `Asset`, the `ADFI_*` flags) remain
# importable by explicit name.
__all__ = ["OHLCAsset"]
# use ints until mypyc supports IntFlag
# https://github.com/mypyc/mypyc/issues/1022
AssetDataFieldInfo = int
ADFI_AVAILABLE_AT_CLOSE: int = 1
ADFI_AVAILABLE_AT_OPEN: int = 2
ADFI_REQUIRED: int = 4
T = TypeVar("T", bound=Union[pd.Series, pd.DataFrame])
AssetName: TypeAlias = str
"""Asset name string."""
@mypyc_attr(allow_interpreted_subclasses=True)
@dataclass(kw_only=True)
class Asset:
    """Anything that has a price.

    Base class for assets. The pricing/data hooks (`intraday_traded_price`,
    `end_of_day_price`, `check_and_fix_data`, `data_fields`) raise
    `NotImplementedError` here and are meant to be overridden by subclasses.
    """

    name: AssetName
    """Name string."""
    denom: str = "USD"
    """Denominated currency."""
    price_round_dp: int = 2
    """Number of decimal places to round prices to."""
    quantity_round_dp: int = 2
    """Number of decimal places to round quantities to."""
    data_label: str | None = None
    """`StrategyRunner.data` column index 1st level label.

    Defaults to `name`.
    """

    def __post_init__(self) -> None:
        # Fall back to the asset name when no explicit data label is given.
        if self.data_label is None:
            self.data_label = self.name

    def round_quantity(self, quantity) -> Decimal:
        """Round `quantity` to `quantity_round_dp` decimal places.

        Delegates to the built-in `round`, so a `Decimal` input yields a
        `Decimal` (rounded with the current decimal context's rounding mode).
        NOTE(review): a float input returns a float despite the `Decimal`
        annotation; callers presumably pass `Decimal` — confirm.
        """
        return round(quantity, self.quantity_round_dp)

    def intraday_traded_price(
        self, asset_day_data: pd.Series, size: Decimal | None = None
    ) -> Decimal:
        """Calculate the price during market hours from a row of `asset_day_data`
        and the order `size`.

        Subclasses may use `size` to determine a price from, e.g., bid/ask
        spreads.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            "The intraday_traded_price method needs to be implemented."
        )

    def end_of_day_price(self, asset_day_data: pd.Series) -> Decimal:
        """Calculate the end-of-day price from a row of `asset_day_data`.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            "The end_of_day_price method needs to be implemented."
        )

    def check_and_fix_data(self, data: pd.DataFrame) -> pd.DataFrame:
        """Check that `data` has the correct fields and fix columns where
        necessary.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            "The check_and_fix_data method needs to be implemented."
        )

    def data_fields(self) -> list[tuple[str, AssetDataFieldInfo]]:
        """List of `(field name, availability flags)` pairs for this asset.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("The data_fields method needs to be implemented.")

    def _get_fields(self, field_info: AssetDataFieldInfo) -> list[str]:
        """Return the names from `data_fields()` whose flags share any bit with
        `field_info`, in `data_fields()` order."""
        return [f for f, fi in self.data_fields() if fi & field_info]

    def _filter_data(self, data: T) -> T:
        """Return `data[self.data_label]`, i.e. only the part of `data` relevant
        to pricing this asset.

        Raises:
            TypeError: if `data_label` is not a string (e.g. it was reset to
                `None` after construction).
        """
        label = self.data_label
        # Explicit check rather than `assert`: it narrows the type for
        # mypy/mypyc the same way but is not stripped under `python -O`.
        if not isinstance(label, str):
            raise TypeError(
                f"data_label must be a str, got {type(label).__name__}"
            )
        return cast(T, data[label])
@mypyc_attr(allow_interpreted_subclasses=True)
@dataclass(kw_only=True)
class OHLCAsset(Asset):
    """Assets whose price history is represented by High, Low, Open, Close and
    Volume fields."""

    def data_fields(self) -> list[tuple[str, AssetDataFieldInfo]]:
        """Return the OHLCV fields and their availability flags.

        Every field is flagged available at close; `Open` is also flagged
        available at open, and `Close` is the only required field.
        """
        return [
            ("High", ADFI_AVAILABLE_AT_CLOSE),
            ("Low", ADFI_AVAILABLE_AT_CLOSE),
            ("Open", ADFI_AVAILABLE_AT_CLOSE | ADFI_AVAILABLE_AT_OPEN),
            ("Close", ADFI_AVAILABLE_AT_CLOSE | ADFI_REQUIRED),
            ("Volume", ADFI_AVAILABLE_AT_CLOSE),
        ]

    def intraday_traded_price(
        self, asset_day_data: pd.Series, size: Decimal | None = None
    ) -> Decimal:
        """Price an intraday trade as the midpoint of `Low` and `High`.

        Falls back to `Close` when either `Low` or `High` is null. `size` is
        not used. The result is rounded to `price_round_dp` decimal places.
        """
        if pd.notnull(asset_day_data.Low) and pd.notnull(asset_day_data.High):
            # Decimal(float) captures the float's exact binary value; the
            # final `round` then fixes the number of decimal places.
            p = Decimal((asset_day_data.Low + asset_day_data.High) / 2)
        else:
            p = Decimal(asset_day_data.Close)
        return round(p, self.price_round_dp)

    def end_of_day_price(self, asset_day_data: pd.Series) -> Decimal:
        """Return `Close` rounded to `price_round_dp` decimal places."""
        return round(Decimal(asset_day_data.Close), self.price_round_dp)

    def check_and_fix_data(self, data: pd.DataFrame) -> pd.DataFrame:
        """Validate the OHLCV columns of `data` and normalise their order.

        Returns a frame whose columns are the expected OHLCV fields first (in
        `data_fields()` order; any that are absent are inserted as all-NaN
        columns by `reindex`), followed by any extra columns of `data` in
        their original order.

        Raises:
            ValueError: if a required field (per `data_fields()`) is missing.
        """
        # TODO: check low <= open, high, close & high >= open, low, close
        # TODO: check volume >= 0
        # check each asset has required fields
        required_fields = self._get_fields(ADFI_REQUIRED)
        missing_req_fields = set(required_fields) - set(data.columns)
        if missing_req_fields:
            raise ValueError(
                f"data columns index requires fields {required_fields}"
                f" and missing {missing_req_fields}"
            )
        # reindex columns with expected fields + additional fields.
        # Extra columns keep their input order: iterating a set here would make
        # the column layout depend on string hash randomization (i.e. vary
        # between runs).
        expected_fields = self._get_fields(ADFI_AVAILABLE_AT_CLOSE)
        expected_set = set(expected_fields)
        other_fields = [c for c in data.columns if c not in expected_set]
        return data.reindex(expected_fields + other_fields, axis=1)