← 返回模块
3.2.4.4beta 可读 · 未来付费校验通过内容版本 2026-05-24

可模拟数据提供者与依赖注入

3.2.4 · 合成数据与 API · 编程

某私募的量化基础设施工程师把一个棘手问题摆到桌上:回测代码一份要在 CI 上跑(必须 deterministic、必须秒级、必须无网络),另一份要在研究 notebook 里跑(必须真接口、必须有缓存),两边的调用点不能动。本课把前三节的全部产物——L1 的 simulate_basket、L2 的 make_cohort,L3 的 fetch_yield_curve——统统塞到一个 DataProvider Protocol 背面,再做三个具体实现:SyntheticProvider(deterministic 合成)、LiveProvider(真 HTTP)、CachedProvider(装饰器包装)。然后用 pytest 配合 responses 把 HTTP mock 测试一次性教完——CI 不要真打公开接口、不要被 429 卡、不要因为网络抖动 flake。这是本模块的 capstone:这套 Protocol 就是 Track 4 的回测引擎挂载的缝隙,挂上之后回测代码不需要知道自己跑在合成还是真数据上。

第一步:定义 DataProvider Protocol

typing.Protocol 是 PEP 544 的结构子类型机制:一个类只要实现了 Protocol 声明的所有方法,就自动满足这个 Protocol,完全不需要写 class Foo(MyProtocol) 继承——隐式的鸭子类型,类型检查器静态可查。@runtime_checkableisinstance(p, DataProvider) 在运行时也能查;静态分析以外多一层兜底。

from typing import Protocol, runtime_checkable
from datetime import date
import pandas as pd

@runtime_checkable
class DataProvider(Protocol):
    def get_returns(self, tickers: list[str], start: date, end: date) -> pd.DataFrame: ...
    def get_yield_curve(self, date: date) -> pd.Series: ...

get_returns 返回 (T, N) 收益矩阵,行索引是日期,列是 ticker;get_yield_curve 返回某天的全期限收益率曲线,索引是期限(年),值是年化收益率。两个方法签名固定;以后想接更多数据(波动率曲面、tick 流),按 Protocol 加方法,旧的实现可以渐进迁移。

第二步:SyntheticProvider

SyntheticProvider(seed: int) 包装 L1 的 GBM 工厂与 L2 的 cohort 工厂,所有数据都从 seed 派生:

from dataclasses import dataclass

@dataclass
class SyntheticProvider:
    seed: int
    def get_returns(self, tickers, start, end):
        rng = np.random.default_rng(self.seed)
        N = len(tickers)
        mu = np.full(N, 0.06); sigma = np.full(N, 0.22)
        R = np.full((N, N), 0.30); np.fill_diagonal(R, 1.0)
        n_days = int(np.busday_count(start, end))
        return simulate_basket(n_days=n_days, tickers=tickers, mu=mu, sigma=sigma, R=R, seed=self.seed, start=start.isoformat())
    def get_yield_curve(self, date):
        rng = np.random.default_rng(self.seed)
        tenors = [1, 2, 3, 5, 7, 10, 20, 30]
        base_rate = 0.025  # 国债 10Y 年化大致中枢
        return pd.Series([base_rate + 0.001 * np.log(t) for t in tenors], index=tenors)

np.busday_count(start, end) 算出区间内的工作日数,直接喂给 L1 的 simulate_basket。同一个 seed 进去,必然出来一只逐字节相同的 DataFrame,这是 CI 测试与 reproducibility 演示的根基。

第三步:LiveProvider

LiveProvider(api_key, base_url) 包装 L3 的 fetch_yield_curve 与一个新的 _fetch_returns_one(按单 ticker 抓收益),全套重试 / 节流 / Pydantic / 缓存继续起作用。requests.Sessionfield(default_factory=requests.Session) 注入,这样每个 LiveProvider 实例都有自己的 session 但配置可以独立。

from dataclasses import dataclass, field
import requests

@dataclass
class LiveProvider:
    api_key: str | None
    base_url: str
    _session: requests.Session = field(default_factory=requests.Session)
    def get_returns(self, tickers, start, end):
        frames = [self._fetch_returns_one(t, start, end) for t in tickers]
        return pd.concat(frames, axis=1, keys=tickers)
    def get_yield_curve(self, date):
        df = fetch_yield_curve(date, api_key=self.api_key)  # 来自 L3
        return df.set_index("date")["yield"]
    def _fetch_returns_one(self, ticker, start, end):
        url = f"{self.base_url}/akshare/stock_zh_a_hist"
        params = {"symbol": ticker, "start_date": start.isoformat(), "end_date": end.isoformat()}
        # 五件套:复用 L3 的重试 / 节流 / Pydantic / 缓存
        ...

LiveProvider 的 worked example 指向 AKShare(无 token)与 Tushare 免费版(走 token):LiveProvider(api_key='<TUSHARE_TOKEN>', base_url='https://api.tushare.pro') 是规范的实例化形态。responses 的测试拦截到 https://api.tushare.pro/... 的 GET,完全不打真实网络。

第四步:CachedProvider

装饰器模式:CachedProvider(inner: DataProvider, cache_dir: Path, max_age: timedelta) 包住任何 DataProvider,把命中缓存的调用短路掉。键的构造遵循"哈希 method + args + kwargs":

from datetime import timedelta
from pathlib import Path
import hashlib, time

@dataclass
class CachedProvider:
    inner: DataProvider
    cache_dir: Path
    max_age: timedelta
    def __post_init__(self):
        self.cache_dir.mkdir(parents=True, exist_ok=True)
    def _cache_key(self, method, *args, **kwargs):
        key_str = method + repr(args) + repr(sorted(kwargs.items()))
        return hashlib.sha256(key_str.encode()).hexdigest()[:16]
    def _is_fresh(self, path):
        return path.exists() and (time.time() - path.stat().st_mtime < self.max_age.total_seconds())
    def get_returns(self, tickers, start, end):
        path = self.cache_dir / f"{self._cache_key('get_returns', tickers, start, end)}.parquet"
        if self._is_fresh(path):
            return pd.read_parquet(path)
        df = self.inner.get_returns(tickers, start, end)
        df.to_parquet(path)
        return df
    def get_yield_curve(self, date):
        path = self.cache_dir / f"{self._cache_key('get_yield_curve', date)}.parquet"
        if self._is_fresh(path):
            return pd.read_parquet(path)["yield"]
        s = self.inner.get_yield_curve(date)
        s.to_frame("yield").to_parquet(path)
        return s

哈希 method + repr(args) + repr(sorted(kwargs.items())),用 hashlib.sha256(...).hexdigest()[:16];对 kwargs 的 sorted(...)、对 date / 元组参数的 deterministic repr() 是关键——没有它们,两次逻辑等价的调用可能产出两套缓存键。优先用 typing.Protocol 在缝是 shape-only、继承只是偶然(incidental)的场合;真正的 is-a 层级、需要共享默认实现的场合,用 abc.ABC

第五步:pytest 三连击

SyntheticProvider 的确定性测试:

def test_synthetic_provider_is_deterministic():
    p = SyntheticProvider(seed=42)
    tickers = ['A', 'B', 'C']
    start, end = date(2024, 1, 1), date(2024, 12, 31)
    df1 = p.get_returns(tickers, start, end)
    df2 = p.get_returns(tickers, start, end)
    pd.testing.assert_frame_equal(df1, df2)

LiveProvider 的 429 重试测试,用 responses 注册一段 (429, 429, 200) 序列。responses.add 按顺序匹配同一 URL,第三次给合法 payload;测试断言重试循环最终一共发出 3 次调用:

import responses
@responses.activate
def test_live_provider_retries_on_429():
    responses.add(responses.GET, URL, status=429)
    responses.add(responses.GET, URL, status=429)
    responses.add(responses.GET, URL, status=200, json=PAYLOAD_OK)
    p = LiveProvider(api_key='test', base_url=BASE_URL)
    s = p.get_yield_curve(date(2024, 1, 2))
    assert not s.empty
    assert len(responses.calls) == 3

CachedProvider 的命中短路测试,用 MagicMock(spec=DataProvider) 做 spy:第一次调用进 inner,第二次命中缓存,inner 只被调一次。MagicMock(spec=DataProvider) 限定 mock 只暴露 Protocol 上的方法——测试代码里打错方法名会在测试时报 AttributeError,不会偷偷溜进生产:

from unittest.mock import MagicMock
def test_cached_provider_short_circuits_on_hit(tmp_path):
    inner = MagicMock(spec=DataProvider)
    inner.get_returns.return_value = pd.DataFrame({'A': [0.01, -0.02]}, index=pd.bdate_range('2024-01-02', periods=2))
    p = CachedProvider(inner=inner, cache_dir=tmp_path, max_age=timedelta(hours=1))
    tickers, start, end = ['A'], date(2024, 1, 2), date(2024, 1, 3)
    p.get_returns(tickers, start, end)
    p.get_returns(tickers, start, end)
    assert inner.get_returns.call_count == 1

Formula Explorer

\\text{cache\\_key} = \\text{sha256}(\\text{method} + \\text{repr(args)} + \\text{repr(sorted(kwargs))})

responsesrequests 的 urllib3 层拦截,respxhttpx 的 httpcore 层拦截,两套机制等价。沪深300 ETF (300ETF) 跟踪指数的真实代码 510300.SH 在 LiveProvider 默认 worked example 里出现一次,与 50ETF 一起说明 A 股 ETF 的 CCDC 数据接入路径。三种 Provider 都满足同一个 Protocol,回测代码写成 def run_backtest(provider: DataProvider): ...,在 CI 里注入 SyntheticProvider(seed=42),在 notebook 里注入 CachedProvider(LiveProvider(api_key=..., base_url=...), cache_dir, timedelta(hours=6))——同一份 run_backtest 拿到两套数据源,两套都能跑、两套都不需要改一行代码。

Exercise

实现 `DataProvider` Protocol 与三个 Provider(`SyntheticProvider`、`LiveProvider`、`CachedProvider`),按本课规范。再写三个 `pytest` 测试:(1) `test_synthetic_provider_is_deterministic` 断言两次同 seed 调用产出 `pd.testing.assert_frame_equal` 相等的 `pd.DataFrame`;(2) `test_live_provider_retries_on_429` 在 `@responses.activate` 下注册 `(429, 429, 200)` 三段响应,断言 `len(responses.calls) == 3`;(3) `test_cached_provider_short_circuits_on_hit(tmp_path)` 用 `MagicMock(spec=DataProvider)` 做 inner,同参数调两次,断言 `inner.get_returns.call_count == 1`。三个测试必须全部在没有真实网络的前提下通过。

提示
Hint 1: 类满足 typing.Protocol 不需要 class Foo(MyProtocol) 继承,只要实现对应方法即可(可以选择继承提升可读性,但不是必需的)。
提示
Hint 2: MagicMock(spec=DataProvider) 让 mock 只暴露 Protocol 上的方法;测试代码里写错方法名会在测试时报 AttributeError,而不是偷偷溜进生产。

参考阅读:PEP 544(Protocol);pytest 官方文档;responses 文档 (https://github.com/getsentry/responses);`respx` 文档 (https://lundberg.github.io/respx/);《Effective Python》(Brett Slatkin 中译,第二版)关于 Protocol 与依赖注入的章节;国内常用 pip 镜像(清华 / 阿里)安装 responses / respx 的实务提示。

为什么 Protocol 比抽象基类更合适这层抽象

很多读者会问:为什么不用 abc.ABC@abstractmethod 强制子类实现?答案要分两层。第一层是技术层面:abc.ABC 强制继承链,这意味着 LiveProvider 必须显式写 class LiveProvider(DataProvider):,而我们其实没有任何复用 DataProvider 默认实现的需求——三个 Provider 的方法体没有任何共用代码,继承关系完全是为了"满足类型检查器"而存在的。typing.Protocol 把这一层结构验证从运行时挪到静态检查器,既保留了类型安全又拿掉了不必要的继承绑定。第二层是设计哲学:Protocol 是 Python 对结构子类型(structural subtyping)的官方支持,与 Go 的接口、TypeScript 的接口、Haskell 的 type class 同源,本质上是"鸭子类型加类型注解"。在量化系统这种缝多、第三方接入多、需要快速试新数据源的场景里,Protocol 永远是首选;abc.ABC 适合的是真正的 is-a 层级——例如 class GBMProcess(StochasticProcess): 这种"GBM 是一种随机过程"的关系。

三种 Provider 的 staging 部署模式

国内私募的 staging 部署里,这套 Protocol 的典型组合是双层洋葱:外层 CachedProvider(cache_dir=/data/cache, max_age=timedelta(hours=6)) 把六小时内的重复查询都吃掉;内层根据环境变量切换——在 staging 是 LiveProvider(api_key=os.environ['TUSHARE_TOKEN'], base_url='https://api.tushare.pro'),在 CI 是 SyntheticProvider(seed=42)。这样 staging 与 CI 共用一套回测脚本、共用一套缓存协议、共用一套测试断言;切换 Provider 只是把工厂函数里的 if env == 'ci' 那一行翻一下。沪深300 ETF (300ETF) 与 50ETF 的日度收益走 Tushare 的 pro.fund_daily,与本课 LiveProvider 默认走的 fetch_yield_curve 共用同一份 session 与同一份 token,配额上互相通融。研究员手动写 notebook 时则用第三种组合:CachedProvider(LiveProvider, max_age=timedelta(days=1)),把一天内的重复查询吃掉,既照顾交互的延迟又不浪费 Tushare 的免费配额。

与 Track 4 的具体衔接

Track 4 的回测引擎(BacktestRunnerPortfolioConstructorExecutionSimulatorTearsheetBuilderPerformanceAttributorRiskAttributor)的所有入口都是 provider: DataProvider,内部不知道也不该知道数据来源。在 staging 上做 7 天压力测试时,我们把 SyntheticProvider(seed=k) 跑 100 个种子,产出 100 套等价回测,衡量 Sharpe 的标准误差;在 production 复盘时,我们把 CachedProvider(LiveProvider, max_age=timedelta(days=30)) 接上历史 Tushare token 跑真数据,与合成数据对账。BoCom、ICBC、CMB 的研究所内网与 Wind / iFinD / 优矿企业级接口都是 LiveProvider 的生产同行,base_url 不同、auth header 不同、payload schema 不同,但都满足同一个 DataProvider Protocol——一份代码,五种数据源都能跑。

至此 Subject 3.2(Python for Data & Quant)闭环:你能在 CI 里离线跑全套合成回测,在研究 notebook 里走真接口带缓存跑;调用点写 provider: DataProvider,具体注入由上层决定。这套 Protocol 就是 Track 4 回测引擎挂上来的缝隙,Track 4 的所有内容都默认这一层已经稳了——这是合成与真实数据之间最后那道接缝,合上它,模块就交付了。下一阶段读者可以直接进入 Track 4 的回测引擎章节,Track 4 第一课的入口签名就是 def run_backtest(provider: DataProvider, strategy: Strategy) -> BacktestResult:,本课的所有产物在那里被一次性消化掉。