yiskw note

機械学習やプログラミングについて気まぐれで書きます

【numpy】特定の範囲だけにseedを設定する


概要

最近、numpyの配列をとある方法でシャッフルする関数を実装していたのですが、
与えたseedの値によってシャッフルの仕方が固定されるようにしたいと考えていました。

np.random.seedをシャッフルの処理の前に挟んでやれば、必ず同じシャッフルの仕方にはなるのですが、関数の外のseedも変更してしまうことになります。
そこで今回は、関数の外のseedを変更せずに、特定の範囲だけにseedを設定する方法について調べたので、こちらにメモを残しておきます。

実装

実装はシンプルで、seedを設定する前の状態を保持しておき、特定の処理が終わったらseedの状態を元に戻すcontext managerを実装するだけです。

import contextlib

import numpy as np


@contextlib.contextmanager
def set_temporary_seed(seed: int):
    """Set temporary seed.

    Args:
        seed (int): Seed value.
    """
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)

使い方と挙動の確認

def main():
    np.random.seed(0)
    print(np.random.randn(2))  # [1.76405235 0.40015721]
    print(np.random.randn(2))  # [0.97873798 2.2408932 ]

    np.random.seed(0)
    print(np.random.randn(2))  # [1.76405235 0.40015721] ← seed=0の時の初回の乱数生成結果になっている

    with set_temporary_seed(0):
        print(np.random.randn(2))  # [1.76405235 0.40015721] ← with文内で、seedの状態が変更され、seed=0の時の初回の乱数生成結果になっている

    print(np.random.randn(2))  # [0.97873798 2.2408932 ] ← with文から出たので、seed=0の時の2回目の乱数生成結果になっている


if __name__ == "__main__":
    main()

Reference