/dev/null

脳みそのL1キャッシュ

Pythonでサブモジュールを自動でインポートする

はじめに

CTF用のライブラリを書いてるときに生じた問題について書きます。

問題

Pythonでライブラリを書くとき、そのライブラリの使い心地を考えると思います。例えば、自作ライブラリでutilモジュールを作っていたとしましょう。 utilモジュール以下には複数のサブモジュールがあり、これらのサブモジュールはわかりやすさのためにファイルが別れているとします。こんな感じです。

└── util
    ├── __init__.py
    ├── enc.py    # エンコーディング関連
    └── pack.py   # パッキング関連

すると、ライブラリを使う側は以下のようにユーティリティを複数のファイルからimportする必要があります。

from tools.util.enc import b64enc, b64dec
from tools.util.pack import p64, u64

しかし、使う側からしたら全部ユーティリティ関数なんだし、utilから直接importしたいと思うかもしれません。

from tools.util import b64enc, b64dec, p64, u64

こういうリクエストに対して、どう対処すればいいのかについて書きます。

解決法1: __init__.py内にすべて書く

一番簡単な方法ですね。

# in util/__init__.py

from .enc import *
from .pack import *

しかし、これには以下の問題があります。

  1. サブモジュールが増えるたびに一々__init__.pyを更新する必要がある
  2. サブモジュール内に隠蔽しておきたい関数、クラスも読み出される

2.に関しては以下のようにして解決可能ですが、書くのが面倒くさいです。

from .enc import b64enc, b64dec
from .pack import p64, u64

解決法2: 自動インポート用の関数を用意する

以下のようなファイルを作っておきます。

# in util/module.py

from importlib import import_module
from inspect import isclass, isfunction
from pkgutil import iter_modules


# サブモジュール内のすべての attribute を現在の名前空間に追加
def import_submodules(path: str, name: str, namespace: dict) -> None:
    for modinfo in iter_modules(path):
        module = import_module(f"{name}.{modinfo.name}")

        # module.__all__ があれば、__all__ に含まれているモジュールのみを読み込む
        if hasattr(module, "__all__"):
            for attribute_name in getattr(module, "__all__"):
                attribute = getattr(module, attribute_name)
                namespace[attribute_name] = attribute
        
        # module.__all__ がなければ、サブモジュール内のすべての関数・クラスを読み込む
        else:
            for attribute_name in dir(module):
                attribute = getattr(module, attribute_name)
                if isclass(attribute) or isfunction(attribute):
                    namespace[attribute_name] = attribute

そして、util/__init__.pyに以下を書いておけばOKです。

from ctftools.util.module import import_submodules

import_submodules(__path__, __name__, globals())

デフォルトの動作では、サブモジュール内すべての関数・クラスを読み込みますが、__all__配列がある場合は、これに含まれているもののみを読み込みます。

おわりに

なんとか解決