【numpy】特定の範囲だけにseedを設定する
概要
最近、numpyの配列をとある方法でシャッフルする関数を実装していたのですが、
与えたseedの値によってシャッフルの仕方が固定されるようにしたいと考えていました。
np.random.seed
をシャッフルの処理の前に挟んでやれば、必ず同じシャッフルの仕方にはなるのですが、関数の外のseedも変更してしまうことになります。
そこで今回は、関数の外のseedを変更せずに、特定の範囲だけにseedを設定する方法について調べたので、こちらにメモを残しておきます。
実装
実装はシンプルで、seedを設定する前の状態を保持しておき、特定の処理が終わったらseedの状態を元に戻すcontext managerを実装するだけです。
import contextlib import numpy as np @contextlib.contextmanager def set_temporary_seed(seed: int): """Set temporary seed. Args: seed (int): Seed value. """ state = np.random.get_state() np.random.seed(seed) try: yield finally: np.random.set_state(state)
使い方と挙動の確認
def main(): np.random.seed(0) print(np.random.randn(2)) # [1.76405235 0.40015721] print(np.random.randn(2)) # [0.97873798 2.2408932 ] np.random.seed(0) print(np.random.randn(2)) # [1.76405235 0.40015721] ← seed=0の時の初回の乱数生成結果になっている with set_temporary_seed(0): print(np.random.randn(2)) # [1.76405235 0.40015721] ← with文内で、seedの状態が変更され、seed=0の時の初回の乱数生成結果になっている print(np.random.randn(2)) # [0.97873798 2.2408932 ] ← with文から出たので、seed=0の時の2回目の乱数生成結果になっている if __name__ == "__main__": main()