from nbdev.showdoc import *
from fastcore.test import *
from fastcore.nb_imports import *
from fastcore.basics import *
from fastcore.dispatch import *
import numbers, typing
Type dispatch¶
Basic single and dual parameter dispatch
Helpers¶
assert not lenient_issubclass(typing.Collection, list)
assert lenient_issubclass(list, typing.Collection)
assert lenient_issubclass(typing.Collection, object)
assert lenient_issubclass(typing.List, typing.Collection)
assert not lenient_issubclass(typing.Collection, typing.List)
assert not lenient_issubclass(object, typing.Callable)
td = [3, 1, 2, 5]
test_eq(sorted_topologically(td), [1, 2, 3, 5])
test_eq(sorted_topologically(td, reverse=True), [5, 3, 2, 1])
td = {int:1, numbers.Number:2, numbers.Integral:3}
test_eq(sorted_topologically(td, cmp=lenient_issubclass), [int, numbers.Integral, numbers.Number])
td = [numbers.Integral, tuple, list, int, dict]
td = sorted_topologically(td, cmp=lenient_issubclass)
assert td.index(int) < td.index(numbers.Integral)
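The helpers above are only exercised through tests here. Below is a rough, pure-Python sketch of the two ideas, not fastcore's code. lenient_subclass and topo_sort are hypothetical stand-ins: the first treats a failing issubclass call as False, and the second repeatedly selects an item that no remaining item precedes under cmp. As long as cmp is a transitive strict ordering, such as issubclass between distinct types, this selection always produces a valid topological order, so more specific types come first.

```python
import numbers
import operator
import typing

def lenient_subclass(cls, types):
    """Hypothetical helper: issubclass() that returns False instead of raising."""
    if cls is object and types is not object:
        return False  # treat `object` as the most general type
    try:
        return issubclass(cls, types)
    except TypeError:  # e.g. `cls` is a typing alias rather than a class
        return False

def topo_sort(items, cmp=operator.lt):
    """Hypothetical helper: repeatedly take an item that no remaining item precedes."""
    remaining, result = list(items), []
    while remaining:
        best = remaining[0]
        for x in remaining[1:]:
            if cmp(x, best):  # x must come before the current candidate
                best = x
        result.append(best)
        remaining.remove(best)
    return result

assert lenient_subclass(list, typing.Collection)
print(topo_sort([3, 1, 2, 5]))  # [1, 2, 3, 5]
# int, Integral, Number: most specific first
print(topo_sort([int, numbers.Number, numbers.Integral], cmp=lenient_subclass))
```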
TypeDispatch¶
Type dispatch, or multiple dispatch, allows you to change the way a function behaves based on the types of the inputs it receives. This is a prominent feature in some programming languages like Julia. For example, the following conceptual Julia snippet returns different values depending on the types of x and y:
collide_with(x::Asteroid, y::Asteroid) = ...   # deal with asteroid hitting asteroid
collide_with(x::Asteroid, y::Spaceship) = ...  # deal with asteroid hitting spaceship
collide_with(x::Spaceship, y::Asteroid) = ...  # deal with spaceship hitting asteroid
collide_with(x::Spaceship, y::Spaceship) = ... # deal with spaceship hitting spaceship
Type dispatch can be especially useful in data science, where you might want a function that processes data to accept different input types (e.g. NumPy arrays and pandas DataFrames). Type dispatch allows you to have a common API for functions that do similar tasks.
The TypeDispatch class allows us to achieve type dispatch in Python. It contains a dictionary that maps types from type annotations to functions, which ensures that the proper function is called for the inputs it is passed.
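To make the dictionary idea concrete, here is a minimal single-argument sketch in plain Python. MiniDispatch is hypothetical, not fastcore's TypeDispatch: it reads only the first parameter's annotation, has no bases and no __call__, and chooses the most specific match with pairwise issubclass checks rather than a topologically sorted table.

```python
import inspect
import numbers

class MiniDispatch:
    """Hypothetical single-argument dispatcher: a dict from annotated type to function."""
    def __init__(self, funcs=()):
        self.table = {}
        for f in funcs:
            self.add(f)

    def add(self, f):
        first = next(iter(inspect.signature(f).parameters.values()))
        ann = first.annotation
        if ann is inspect.Parameter.empty:
            ann = object  # an unannotated parameter matches anything
        for typ in (ann if isinstance(ann, tuple) else (ann,)):  # (bool, list) means either
            self.table[typ] = f  # re-registering a type overwrites it: latest wins

    def __getitem__(self, cls):
        matches = [typ for typ in self.table if issubclass(cls, typ)]
        if not matches:
            return None
        best = matches[0]
        for typ in matches[1:]:
            if issubclass(typ, best):  # prefer the more specific annotation
                best = typ
        return self.table[best]

def f_nin(x: numbers.Integral): return x + 1
def f_ni2(x: int): return x
def f_num(x: numbers.Number): return x
def f_bll(x: (bool, list)): return x

t = MiniDispatch([f_nin, f_ni2, f_num, f_bll])
print(t[float].__name__)  # f_num
print(t[bool].__name__)   # f_bll
```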
To demonstrate how TypeDispatch works, we define a set of functions that accept a variety of input types, specified with different type annotations:
def f2(x:int, y:float): return x+y #int for 1st arg, float for 2nd arg
def f_nin(x:numbers.Integral)->int: return x+1 #integral numeric
def f_ni2(x:int): return x #integer
def f_bll(x:(bool,list)): return x #bool or list
def f_num(x:numbers.Number): return x #Number (root of numerics)
We can optionally initialize TypeDispatch with a list of functions we want to search. Printing an instance of TypeDispatch displays a convenient mapping of types -> functions:
t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])
t
(bool,object) -> f_bll
(int,object) -> f_ni2
(Integral,object) -> f_nin
(Number,object) -> f_num
(list,object) -> f_bll
(object,object) -> NoneType
Note that only the first two arguments are used for TypeDispatch. If your function only contains one argument, the second parameter will be shown as object. If you pass None into TypeDispatch, it will be displayed as (object,object) -> NoneType.
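A sketch of how a key for the first two parameters could be derived with inspect.signature; dispatch_key is a hypothetical helper, not part of fastcore. Unannotated or missing parameters become object and everything after the second parameter is dropped, which is also why two functions that differ only in later parameters end up with the same key.

```python
import inspect
import numbers

def dispatch_key(f):
    """Hypothetical helper: the (type, type) key for f's first two parameters."""
    params = list(inspect.signature(f).parameters.values())[:2]  # later params ignored
    key = [object if p.annotation is inspect.Parameter.empty else p.annotation
           for p in params]
    key += [object] * (2 - len(key))  # pad one-argument functions with object
    return tuple(key)

def f_nin(x: numbers.Integral) -> int: return x + 1
def f1(a: str, b: int, c: list): return a
def f2(a: str, b: int): return b

print(dispatch_key(f_nin))  # (<class 'numbers.Integral'>, <class 'object'>)
print(dispatch_key(f1) == dispatch_key(f2))  # True
```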
TypeDispatch is a dictionary-like object, which means that you can retrieve a function by the associated type annotation. For example, the statement:
t[float]
will return f_num, because that is the matching function whose type annotation, numbers.Number, is a superclass of float:
assert issubclass(float, numbers.Number)
test_eq(t[float], f_num)
The same is true for other types as well:
test_eq(t[np.int32], f_nin)
test_eq(t[bool], f_bll)
test_eq(t[list], f_bll)
If you try to get a type that doesn’t match, TypeDispatch will return None:
test_eq(t[str], None)
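Why not simply walk a class's MRO to find the matching annotation? The numeric abstract base classes are attached to built-in types by registration rather than inheritance, so they never show up in the MRO. This standard-library check shows the difference, and why matching relies on issubclass:

```python
import numbers

print(float.__mro__)  # (<class 'float'>, <class 'object'>)
assert numbers.Number not in float.__mro__  # Number is not an ancestor in the MRO...
assert issubclass(float, numbers.Number)    # ...but float is a registered virtual subclass
assert issubclass(bool, int) and issubclass(bool, numbers.Integral)
```

The last line is also why f_bll needs its explicit bool annotation to win for t[bool]: bool already matches both int and numbers.Integral.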
The add method allows you to add an additional function to an existing TypeDispatch instance:
def f_col(x:typing.Collection): return x
t.add(f_col)
test_eq(t[str], f_col)
t
(bool,object) -> f_bll
(int,object) -> f_ni2
(Integral,object) -> f_nin
(Number,object) -> f_num
(list,object) -> f_bll
(typing.Collection,object) -> f_col
(object,object) -> NoneType
If you accidentally add the same function more than once, things will still work as expected:
t.add(f_ni2)
test_eq(t[int], f_ni2)
However, if you add a function whose type annotations collide with those of an existing function, the ambiguity is automatically resolved in favor of the most recently added function:
def f_ni3(z:int): return z # collides with f_ni2 with same type annotations
t.add(f_ni3)
test_eq(t[int], f_ni3)
Using bases¶
The argument bases can optionally accept a single instance of TypeDispatch or a collection (e.g. a tuple or list) of TypeDispatch objects. This can provide functionality similar to multiple inheritance.
These are searched for matching functions if no match is found in your list of functions:
def f_str(x:str): return x+'1'
t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])
t2 = TypeDispatch(f_str, bases=t) # you can optionally supply a list of TypeDispatch objects for `bases`.
t2
(str,object) -> f_str
(bool,object) -> f_bll
(int,object) -> f_ni2
(Integral,object) -> f_nin
(Number,object) -> f_num
(list,object) -> f_bll
(object,object) -> NoneType
test_eq(t2[int], f_ni2) # searches `t` b/c not found in `t2`
test_eq(t2[np.int32], f_nin) # searches `t` b/c not found in `t2`
test_eq(t2[float], f_num) # searches `t` b/c not found in `t2`
test_eq(t2[bool], f_bll) # searches `t` b/c not found in `t2`
test_eq(t2[str], f_str) # found in `t2`!
test_eq(t2('a'), 'a1') # found in `t2`!, and uses __call__
o = np.int32(1)
test_eq(t2(o), 2) # searches `t` b/c not found in `t2`, and uses __call__
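The fallback order for bases can be sketched as follows. Registry is a hypothetical class that uses exact-type lookup (no subclass matching) to keep the focus on the search order: its own table first, then each base in turn, returning None only if nothing matches.

```python
class Registry:
    """Hypothetical exact-type lookup table that falls back to its bases."""
    def __init__(self, table, bases=()):
        self.table = dict(table)
        self.bases = list(bases) if isinstance(bases, (list, tuple)) else [bases]

    def __getitem__(self, cls):
        if cls in self.table:  # own functions are searched first
            return self.table[cls]
        for base in self.bases:  # then each base, in order
            found = base[cls]
            if found is not None:
                return found
        return None

def f_int(x): return x + 1
def f_str(x): return x + '1'

base = Registry({int: f_int})
child = Registry({str: f_str}, bases=base)  # a single base or a list/tuple of them
print(child[str] is f_str, child[int] is f_int, child[float])  # True True None
```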
Up To Two Arguments¶
TypeDispatch supports up to two arguments when searching for the appropriate function. The following functions f1 and f2 both have two parameters:
def f1(x:numbers.Integral, y): return x+1 #Integral is a numeric type
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])
t
(int,float) -> f2
(Integral,object) -> f1
You can look up functions from a TypeDispatch instance with two parameters like this:
test_eq(t[np.int32], f1)
test_eq(t[int,float], f2)
Keep in mind that anything beyond the first two parameters is ignored, and any collisions will be resolved in favor of the most recently added function. In the example below, f1 is ignored in favor of f2 because the first two parameters have identical type hints:
def f1(a:str, b:int, c:list): return a
def f2(a: str, b:int): return b
t = TypeDispatch([f1,f2])
test_eq(t[str, int], f2)
t
(str,int) -> f2
Matching¶
TypeDispatch matches types with functions according to whether the supplied class is the same class as, or a subclass of, the type annotation(s) of the associated functions.
Let’s consider an example where we try to retrieve the function corresponding to the types [np.int32, float].
In this scenario, f2 will not be matched. This is because the first type annotation of f2, int, is not a superclass (or the same class) of np.int32:
def f1(x:numbers.Integral, y): return x+1
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])
assert not issubclass(np.int32, int)
Instead, f1 is a valid match, as its first argument is annotated with the type numbers.Integral, of which np.int32 is a subclass:
assert issubclass(np.int32, numbers.Integral)
test_eq(t[np.int32,float], f1)
In f1, the second parameter y is not annotated, so TypeDispatch treats it as object. As a result, f1 matches any call whose first argument is a subclass of numbers.Integral, such as int, unless a more specific function matches:
assert issubclass(int, numbers.Integral) # int is a subclass of numbers.Integral
test_eq(t[int], f1)
test_eq(t[int,int], f1)
If no match is possible, None is returned:
test_eq(t[float,float], None)
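The two-argument matching rule can be sketched in plain Python: walk the candidates from most to least specific and return the first whose annotations are superclasses of both argument types. The hand-written candidates list and the lookup function are hypothetical; fastcore derives and orders its own table from the annotations.

```python
import numbers

def f1(x: numbers.Integral, y): return x + 1
def f2(x: int, y: float): return x + y

# Hypothetical table, written most-specific first as a topological sort would order it.
candidates = [((int, float), f2), ((numbers.Integral, object), f1)]

def lookup(t1, t2=object):
    """Return the first candidate whose annotations are superclasses of both types."""
    for (a, b), f in candidates:
        if issubclass(t1, a) and issubclass(t2, b):
            return f
    return None

print(lookup(int, float).__name__)  # f2
print(lookup(int, int).__name__)    # f1 (int is not a subclass of float)
print(lookup(float, float))         # None (float is not Integral)
```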
TypeDispatch is also callable. When you call an instance of TypeDispatch, it will execute the relevant function:
def f_arr(x:np.ndarray): return x.sum()
def f_int(x:np.int32): return x+1
t = TypeDispatch([f_arr, f_int])
arr = np.array([5,4,3,2,1])
test_eq(t(arr), 15) # dispatches to f_arr
o = np.int32(1)
test_eq(t(o), 2) # dispatches to f_int
assert t.first() is not None
You can also call an instance of TypeDispatch when there are two parameters:
def f1(x:numbers.Integral, y): return x+1
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])
test_eq(t(3,2.0), 5)
test_eq(t(3,2), 4)
When no match is found, calling a TypeDispatch instance acts as an identity function and returns its input unchanged. This default behavior is leveraged by fastai for data transformations to provide a sensible default when a matching function cannot be found.
test_eq(t('a'), 'a')
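The fallback can be sketched as a call that returns its input unchanged when no entry matches. call_or_identity is a hypothetical helper that checks a plain dict in insertion order rather than by specificity; the point is only the final return statement.

```python
def call_or_identity(table, x):
    """Hypothetical: call the first function whose type matches x, else return x."""
    for typ, f in table.items():
        if isinstance(x, typ):
            return f(x)
    return x  # no match: behave like the identity function

table = {int: lambda x: x * 2, list: len}
print(call_or_identity(table, 3))       # 6
print(call_or_identity(table, [1, 2]))  # 2
print(call_or_identity(table, 'a'))     # a
```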
You can optionally pass an object to TypeDispatch.returns and get the return type annotation back:
def f1(x:int) -> np.ndarray: return np.array(x)
def f2(x:str) -> float: return List
def f3(x:float): return List # f3 has no return type annotation
t = TypeDispatch([f1, f2, f3])
test_eq(t.returns(1), np.ndarray) # dispatched to f1
test_eq(t.returns('Hello'), float) # dispatched to f2
test_eq(t.returns(1.0), None) # dispatched to f3
class _Test: pass
_test = _Test()
test_eq(t.returns(_test), None) # type `_Test` not found, so None returned
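The second half of returns, reading the return annotation of the dispatched function, can be reproduced with the standard library. return_type is a hypothetical helper; in the examples above, the function would first be selected by dispatching on the type of the argument.

```python
import typing

def return_type(f):
    """Hypothetical helper: f's return annotation, or None if it has none."""
    return typing.get_type_hints(f).get('return')

def f1(x: int) -> float: return float(x)
def f3(x: float): return x  # no return annotation

print(return_type(f1))  # <class 'float'>
print(return_type(f3))  # None
```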
Using TypeDispatch With Methods¶
You can use TypeDispatch when defining methods as well:
def m_nin(self, x:(str,numbers.Integral)): return str(x)+'1'
def m_bll(self, x:bool): self.foo='a'
def m_num(self, x:numbers.Number): return x*2
t = TypeDispatch([m_nin,m_num,m_bll])
class A: f = t # set class attribute `f` equal to a TypeDispatch instance
a = A()
test_eq(a.f(1), '11') #dispatch to m_nin
test_eq(a.f(1.), 2.) #dispatch to m_num
test_is(a.f.inst, a)
a.f(False) # dispatches to m_bll, which sets self.foo to 'a'
test_eq(a.foo, 'a')
As discussed in TypeDispatch.__call__, when there is no match, TypeDispatch.__call__ becomes an identity function. In the example below, a tuple does not match any type annotation, so the tuple is returned unchanged:
test_eq(a.f(()), ())
We extend the previous example by using bases to add an additional method that supports tuples:
def m_tup(self, x:tuple): return x+(1,)
t2 = TypeDispatch(m_tup, bases=t)
class A2: f = t2
a2 = A2()
test_eq(a2.f(1), '11')
test_eq(a2.f(1.), 2.)
test_is(a2.f.inst, a2)
a2.f(False)
test_eq(a2.foo, 'a')
test_eq(a2.f(()), (1,))
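Method support depends on Python's descriptor protocol: accessing a.f calls __get__, which is where the instance can be captured and later passed as self. BoundDispatch is a hypothetical sketch of that mechanism, using a plain dict checked in insertion order instead of real type matching. It returns a fresh object on each attribute access rather than storing the instance on the shared class attribute, so two instances cannot overwrite each other's inst.

```python
class BoundDispatch:
    """Hypothetical descriptor: dispatch on the argument after `self`."""
    def __init__(self, table, inst=None):
        self.table = table  # maps a type to a function taking (self, x)
        self.inst = inst

    def __get__(self, inst, owner=None):
        return BoundDispatch(self.table, inst)  # remember the instance, like a.f.inst

    def __call__(self, x):
        for typ, f in self.table.items():
            if isinstance(x, typ):
                return f(self.inst, x)
        return x  # identity fallback when nothing matches

def m_bll(self, x): self.foo = 'a'
def m_num(self, x): return x * 2

class A:
    f = BoundDispatch({bool: m_bll, float: m_num})

a = A()
a.f(False)       # dispatches to m_bll, which sets a.foo
print(a.foo)     # a
print(a.f(1.5))  # 3.0
print(a.f(()))   # ()
```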
Using TypeDispatch With Class Methods¶
You can use TypeDispatch when defining class methods too:
def m_nin(cls, x:(str,numbers.Integral)): return str(x)+'1'
def m_bll(cls, x:bool): cls.foo='a'
def m_num(cls, x:numbers.Number): return x*2
t = TypeDispatch([m_nin,m_num,m_bll])
class A: f = t # set class attribute `f` equal to a TypeDispatch instance
test_eq(A.f(1), '11') #dispatch to m_nin
test_eq(A.f(1.), 2.) #dispatch to m_num
test_is(A.f.owner, A)
A.f(False) # dispatches to m_bll, which sets A.foo to 'a'
test_eq(A.foo, 'a')
typedispatch Decorator¶
@typedispatch
def f_td_test(x, y): return f'{x}{y}'
@typedispatch
def f_td_test(x:numbers.Integral, y): return x+1
@typedispatch
def f_td_test(x:int, y:float): return x+y
@typedispatch
def f_td_test(x:int, y:int): return x*y
test_eq(f_td_test(3,2.0), 5)
assert issubclass(int, numbers.Integral)
test_eq(f_td_test(3,2), 6)
test_eq(f_td_test('a','b'), 'ab')
Using typedispatch With other decorators¶
You can use typedispatch with the classmethod and staticmethod decorators:
class A:
    @typedispatch
    def f_td_test(self, x:numbers.Integral, y): return x+1
    @typedispatch
    @classmethod
    def f_td_test(cls, x:int, y:float): return x+y
    @typedispatch
    @staticmethod
    def f_td_test(x:int, y:int): return x*y
test_eq(A.f_td_test(3,2), 6)
test_eq(A.f_td_test(3,2.0), 5)
test_eq(A().f_td_test(3,'2.0'), 4)
Casting¶
Now that we can dispatch on types, let’s make it easier to cast objects to a different type.
This works both for plain Python classes:…
mk_class('_T1', 'a') # mk_class is a fastcore utility that creates the class _T1.
class _T2(_T1): pass
t = _T1(a=1)
t2 = cast(t, _T2)
assert t2 is t # t2 refers to the same object as t
assert isinstance(t, _T2) # t also changed in-place
assert isinstance(t2, _T2)
test_eq_type(_T2(a=1), t2)
…as well as for arrays and tensors.
class _T1(ndarray): pass
t = array([1])
t2 = cast(t, _T1)
test_eq(array([1]), t2)
test_eq(_T1, type(t2))
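The in-place behaviour for plain classes shown above, where t2 is t and t itself becomes a _T2, can be reproduced by assigning to __class__. cast_in_place is a hypothetical helper, not fastcore's cast. Note that __class__ assignment only works between classes with compatible instance layouts; for NumPy arrays, a view such as arr.view(SubClass) is the usual way to obtain a subclass instance.

```python
class T1:
    def __init__(self, a): self.a = a

class T2(T1): pass

def cast_in_place(obj, typ):
    """Hypothetical sketch: change obj's class without copying or re-initialising it."""
    obj.__class__ = typ
    return obj

t = T1(a=1)
t2 = cast_in_place(t, T2)
print(t2 is t, type(t).__name__, t.a)  # True T2 1
```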
To customize casting for other types, define a separate cast function with typedispatch for your type.
The retain_type function casts new to the type of old, or to typ if it is given:
class _T(tuple): pass
a = _T((1,2))
b = tuple((1,2))
c = retain_type(b, typ=_T)
test_eq_type(c, a)
If old has a _meta attribute, its content is passed when casting new to the type of old. In the example below, only the attribute a is kept, but not other_attr, because other_attr is not in _meta:
class _A():
    set_meta = default_set_meta
    def __init__(self, t): self.t=t
class _B1(_A):
    def __init__(self, t, a=1):
        super().__init__(t)
        self._meta = {'a':a}
        self.other_attr = 'Hello' # will not be kept after casting.
x = _B1(1, a=2)
b = _A(1)
c = retain_type(b, old=x)
test_eq(c._meta, {'a': 2})
assert not getattr(c, 'other_attr', None)
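The reason other_attr disappears is that casting changes the class of the existing object without running the target class's __init__, so only state that is copied explicitly survives. retain_like is a hypothetical sketch of that idea, not fastcore's retain_type or set_meta.

```python
class A:
    def __init__(self, t): self.t = t

class B1(A):
    def __init__(self, t, a=1):
        super().__init__(t)
        self._meta = {'a': a}
        self.other_attr = 'Hello'

def retain_like(new, old):
    """Hypothetical: cast `new` to type(old) in place, copying only old._meta."""
    new.__class__ = type(old)  # B1.__init__ is NOT run on `new`
    if hasattr(old, '_meta'):
        new._meta = dict(old._meta)
    return new

x = B1(1, a=2)
c = retain_like(A(1), x)
print(type(c).__name__, c._meta, hasattr(c, 'other_attr'))  # B1 {'a': 2} False
```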
class T(tuple): pass
t1,t2 = retain_types((1,(1,(1,1))), (2,T((2,T((3,4))))))
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))
t1,t2 = retain_types((1,(1,(1,1))), typs = {tuple: [int, {T: [int, {T: [int,int]}]}]})
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))
test_eq(explode_types((2,T((2,T((3,4)))))), {tuple: [int, {T: [int, {T: [int,int]}]}]})
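The nested description passed as typs above, and returned by explode_types, can be sketched as a short recursive function. explode is hypothetical and treats only tuples and lists as containers: each container becomes a one-key dict from its type to the descriptions of its children, and each leaf becomes its type.

```python
def explode(o):
    """Hypothetical: describe nested tuples/lists as {container_type: [child, ...]}."""
    if not isinstance(o, (tuple, list)):
        return type(o)  # leaves are described by their type
    return {type(o): [explode(x) for x in o]}

class T(tuple): pass

print(explode((2, T((2, T((3, 4)))))) == {tuple: [int, {T: [int, {T: [int, int]}]}]})  # True
```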