Friday, 10 September 2021

Get Type Argument of Arbitrarily High Generic Parent Class at Runtime

Given this:

from typing import Generic, TypeVar

T = TypeVar('T')

class Parent(Generic[T]):
    pass

I can get int from Parent[int] using typing.get_args(Parent[int])[0].

The problem becomes a bit more complicated with the following:

class Child1(Parent[int]):
    pass

class Child2(Child1):
    pass
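This is why `typing.get_args` alone is not enough here. `Child1` and `Child2` are ordinary classes, so the subscripted base survives only in the `__orig_bases__` attribute that a `class` statement records when a base like `Parent[int]` is resolved. The checks below are a sketch that assumes CPython 3.8+. They also show that `Child2` has no `__orig_bases__` of its own and only sees `Child1`'s through ordinary attribute inheritance.

```python
import typing
from typing import Generic, TypeVar

T = TypeVar("T")

class Parent(Generic[T]):
    pass

class Child1(Parent[int]):
    pass

class Child2(Child1):
    pass

# A subscripted alias carries its origin and its arguments.
assert typing.get_origin(Parent[int]) is Parent
assert typing.get_args(Parent[int]) == (int,)

# A subclass is a plain class, so get_args finds nothing on it...
assert typing.get_args(Child1) == ()

# ...but the class statement recorded the original, subscripted base.
assert Child1.__dict__["__orig_bases__"] == (Parent[int],)

# Child2's only base is a plain class, so nothing is recorded for Child2 itself;
# getattr still finds Child1's tuple through normal attribute lookup.
assert "__orig_bases__" not in Child2.__dict__
assert getattr(Child2, "__orig_bases__") == (Parent[int],)
```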

To support an arbitrarily deep inheritance hierarchy, I wrote the solution below:

import typing
from dataclasses import dataclass
from typing import Any, Optional

@dataclass(frozen=True)
class Found:
    value: Any

def get_parent_type_parameter(child: type) -> Optional[Found]:
    for base in child.mro():
        # If no base classes of `base` are generic, then `__orig_bases__` is nonexistent causing an `AttributeError`.
        # Instead, we want to skip iteration.
        for generic_base in getattr(base, "__orig_bases__", ()):
            if typing.get_origin(generic_base) is Parent:
                [type_argument] = typing.get_args(generic_base)

                # Return `Found(type_argument)` instead of `type_argument` to differentiate between `Parent[None]` 
                # as a base class and `Parent` not appearing as a base class.
                return Found(type_argument)

    return None

such that get_parent_type_parameter(Child2) returns Found(int), that is, int wrapped in Found. I am only interested in the type argument of one particular base class (Parent), so I've hardcoded that class into get_parent_type_parameter and ignore all other base classes.

But this solution breaks down with chains like this one:

class Child3(Parent[T], Generic[T]):
    pass

where get_parent_type_parameter(Child3[int]) returns T instead of int.
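The information needed to resolve `T` is present on the objects themselves. `Child3[int]` is an alias whose origin is `Child3` and whose argument is `int`, and `Child3.__parameters__` lists the class's type variables in declaration order. The recorded base `Parent[T]` still refers to the unbound `T`, so pairing parameters with arguments gives the substitution that would turn it into `Parent[int]`. A small sketch of that pairing, assuming Python 3.8+:

```python
import typing
from typing import Generic, TypeVar

T = TypeVar("T")

class Parent(Generic[T]):
    pass

class Child3(Parent[T], Generic[T]):
    pass

alias = Child3[int]
origin = typing.get_origin(alias)  # the class Child3
args = typing.get_args(alias)      # the arguments, here (int,)

# The class's own type variables, in declaration order.
assert origin is Child3
assert origin.__parameters__ == (T,)

# Pair each type variable with the argument supplied in the alias.
mapping = dict(zip(origin.__parameters__, args))
assert mapping == {T: int}

# The recorded base still mentions the unbound T.
assert Child3.__orig_bases__[0] == Parent[T]
```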

Any answer that handles Child3 would already be great, but handling cases like Child4 would be even better:

from typing import Sequence

class Child4(Parent[Sequence[T]], Generic[T]):
    pass

so get_parent_type_parameter(Child4[int]) returns Sequence[int].
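For nested arguments such as `Sequence[T]`, the generic aliases in `typing` can perform the substitution themselves. Subscripting an alias that still contains free type variables (listed in its `__parameters__`) binds them, including inside nested arguments. The sketch below relies only on that subscription behaviour and assumes Python 3.8+:

```python
from typing import Generic, Sequence, TypeVar

T = TypeVar("T")

class Parent(Generic[T]):
    pass

class Child4(Parent[Sequence[T]], Generic[T]):
    pass

base = Child4.__orig_bases__[0]     # Parent[Sequence[T]]
assert base.__parameters__ == (T,)  # T is still free inside the nested argument

# Subscripting binds the free type variables in the order of __parameters__.
bound = base[int]
assert bound == Parent[Sequence[int]]
```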

Is there a more robust way of accessing the type argument of a class X at runtime given an annotation A where issubclass(typing.get_origin(A), X) is True?
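One possible approach, sketched here under a few assumptions (Python 3.8+, type variables bound positionally through `Generic[...]`): walk each class's own `__orig_bases__` recursively instead of the MRO. At each step, substitute the type arguments known so far into the recorded base before descending into it. The `target` parameter is a hypothetical generalisation of the hardcoded `Parent`, and `_substitute` is a helper introduced only for this sketch. `Generic[...]` and `Protocol[...]` bases are skipped because they only declare parameters and cannot be re-subscripted.

```python
import typing
from dataclasses import dataclass
from typing import Any, Dict, Generic, Optional, Protocol, Sequence, TypeVar

T = TypeVar("T")

class Parent(Generic[T]):
    pass

class Child1(Parent[int]):
    pass

class Child2(Child1):
    pass

class Child3(Parent[T], Generic[T]):
    pass

class Child4(Parent[Sequence[T]], Generic[T]):
    pass

@dataclass(frozen=True)
class Found:
    value: Any

def _substitute(tp: Any, mapping: Dict[Any, Any]) -> Any:
    """Bind the free type variables in `tp` using `mapping` (hypothetical helper)."""
    if isinstance(tp, TypeVar):
        return mapping.get(tp, tp)
    params = getattr(tp, "__parameters__", ())
    # Only subscripted aliases that still have free type variables are re-subscripted;
    # bare classes (get_origin returns None) are returned unchanged.
    if typing.get_origin(tp) is None or not params or not mapping:
        return tp
    return tp[tuple(mapping.get(p, p) for p in params)]

def get_parent_type_parameter(annotation: Any, target: type = Parent) -> Optional[Found]:
    # Split e.g. Child3[int] into (Child3, (int,)); a bare class has no arguments.
    origin = typing.get_origin(annotation) or annotation
    args = typing.get_args(annotation)
    if not isinstance(origin, type):
        return None  # e.g. a Union or a TypeVar: there is no class to walk
    if origin is target and args:
        return Found(args[0])

    # Map this class's own type variables to the arguments it was given.
    mapping = dict(zip(getattr(origin, "__parameters__", ()), args))

    # Use the class's *own* recorded bases, falling back to plain __bases__
    # for classes like Child2 whose bases were not subscripted.
    bases = origin.__dict__.get("__orig_bases__", origin.__bases__)
    for base in bases:
        if typing.get_origin(base) in (Generic, Protocol):
            continue
        found = get_parent_type_parameter(_substitute(base, mapping), target)
        if found is not None:
            return found
    return None
```

With these definitions, `get_parent_type_parameter(Child3[int])` should give `Found(int)` and `get_parent_type_parameter(Child4[int])` should give `Found(Sequence[int])`. Because the walk is depth-first over bases in declaration order, a class that reaches `Parent` along two differently parameterised paths gets the first one found; that is a choice made in this sketch, not a rule of `typing`.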

Why I need this:

Some recent Python HTTP frameworks generate endpoint documentation (including the response schema) from the handler function's annotated return type. For example:

app = ...

@dataclass
class Data:
    hello: str

@app.get("/")
def hello() -> Data:
    return Data(hello="world")

I am trying to extend this to account for status codes and other non-body components:

import random
from typing import ClassVar, Union

@dataclass
class Error:
    detail: str

# @dataclass gives ClientResponse an __init__ that takes `body`;
# ClassVar annotations such as status_code are not turned into fields.
@dataclass
class ClientResponse(Generic[T]):
    status_code: ClassVar[int]
    body: T

class OkResponse(ClientResponse[Data]):
    status_code: ClassVar[int] = 200

class BadResponse(ClientResponse[Error]):
    status_code: ClassVar[int] = 400

@app.get("/")
def hello() -> Union[OkResponse, BadResponse]:
    if random.randint(1, 2) == 1:
        return OkResponse(Data(hello="world"))

    return BadResponse(Error(detail="a_custom_error_label"))

To generate the OpenAPI documentation, my framework would inspect the annotated return type of the function passed to app.get and evaluate get_parent_type_parameter(E), with ClientResponse hardcoded as the parent, for each E in the Union. E would first be OkResponse, resulting in Data, and then BadResponse, resulting in Error. The framework then iterates through the __annotations__ of each body type and generates the response schema in the documentation for the client.
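The schema step described above could look roughly like the sketch below. `find_type_argument` and `response_schemas` are hypothetical names. `find_type_argument` is the question's MRO-based lookup with the target class passed in, which is enough for non-generic subclasses like `OkResponse`. `typing.get_type_hints` stands in for reading `__annotations__` directly so that string annotations are resolved as well. The handler is a plain function here because `app` is not defined.

```python
import typing
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, Generic, Optional, TypeVar, Union

T = TypeVar("T")

@dataclass
class Data:
    hello: str

@dataclass
class Error:
    detail: str

class ClientResponse(Generic[T]):
    status_code: ClassVar[int]
    body: T

class OkResponse(ClientResponse[Data]):
    status_code: ClassVar[int] = 200

class BadResponse(ClientResponse[Error]):
    status_code: ClassVar[int] = 400

@dataclass(frozen=True)
class Found:
    value: Any

def find_type_argument(child: type, target: type) -> Optional[Found]:
    # Same MRO walk as in the question, with the target class passed in.
    for base in child.mro():
        for generic_base in getattr(base, "__orig_bases__", ()):
            if typing.get_origin(generic_base) is target:
                return Found(typing.get_args(generic_base)[0])
    return None

def response_schemas(handler: Callable[..., Any]) -> Dict[int, Dict[str, Any]]:
    """Map each declared status code to the field types of its body (hypothetical helper)."""
    returned = typing.get_type_hints(handler)["return"]
    # A single response class is treated like a one-member Union.
    is_union = typing.get_origin(returned) is Union
    variants = typing.get_args(returned) if is_union else (returned,)
    schemas: Dict[int, Dict[str, Any]] = {}
    for response_cls in variants:
        found = find_type_argument(response_cls, ClientResponse)
        if found is not None:
            schemas[response_cls.status_code] = typing.get_type_hints(found.value)
    return schemas

def hello() -> Union[OkResponse, BadResponse]:
    ...  # body omitted: only the return annotation is inspected
```

Calling `response_schemas(hello)` on these definitions should produce `{200: {"hello": str}, 400: {"detail": str}}`, which is the raw material for the per-status-code response schemas.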