Motivimi

Kur përcaktojmë një Pytorch Module, ne shpesh përcaktojmë një funksion forward që përcakton logjikën për kalimin përpara. Sidoqoftë, gjatë trajnimit / konkluzionit, ne përdorim model(...) në vend të model.forward(...)` pasi Pytorch mbështjell model.forward duke përdorur model.__call__ dhe shton funksionalitete të tilla si grepa.

Megjithatë, kjo sjellje mbështjellëse është shumë jomiqësore ndaj programuesit pasi IDE nuk është në gjendje të tregojë se model.__call__ pranon të njëjtat argumente si model.forward. Prandaj, ne nuk marrim asnjë sugjerim/përfundim automatik kur shtypim model(… në redaktues. Në vend të kësaj, do të na duhet të kontrollojmë kodin burimor të model.forward ose të shkruajmë model.foward( për të kontrolluar argumentet e pranuara dhe llojin e kthimit, gjë që është shumë e mundimshme.

Qëllimi

Për të kapërcyer kufizime të tilla, do të na duhet të gjejmë një zgjidhje që:

  1. Kontrolluesit e tipit mund të kuptojnë.
  2. Ruan sjelljen e model.__call__ .
  3. Përputhet me nënshkrimin e model.__call__ me model.forward.
  4. Mbështet typing.overload. Sepse programuesit e vëmendshëm si unë do të jepnin sugjerime për llojin për çdo mbingarkesë prej model.forward. (Ka shumë shembuj për mbingarkesën forward. P.sh. flamuri CNN classify për të ndryshuar ndërmjet klasifikimit dhe nxjerrjes së veçorive, flamuri StyleGAN return_latents për të marrë vektorë latente të shkëputur etj.)
  5. Lehtë për t'u përdorur. Sigurisht, ne thjesht mund të hartojmë __call__super().__call__ dhe të sigurojmë një typing.overload për çdo mbingarkesë prej forward. Por fat të mirë ta bësh këtë për çdo Module që e përcaktoni dhe ta ruani atë.

Zgjidhje

Zgjidhja me të cilën dola (pasi humba disa netë gjumë duke testuar inteligjencën e kontrolluesit tim të tipit VSCode Pyright):

Gëzojeni inteligjencën tuaj! ❤️

Shënime

Një version i përgjithshëm i këtij funksioni proxy është:

def proxy(f: C, attr: str) -> C:
    return cast(
        C,
        lambda self, *x, **y: getattr(
            cast(Any, super(self.__class__, self)), attr)(
            *x, **y
        ),
    )

Sigurisht, me Python 3.10, mund të shtoni/zbrisni argumente nga funksioni i mbështjellë duke përdorur ParamSpec. Por kjo është për një ditë tjetër 😃.