Sunday, 29 January 2023

Discriminated union in Python

Imagine I have a base class and two derived classes. I also have a factory function that returns an instance of one of those classes. The problem is that mypy and IntelliJ can't work out which concrete type the returned object has: they know it can be either one, but not which one exactly. Is there any way I can help mypy/IntelliJ figure this out WITHOUT putting a type annotation on the conn variable?

import abc
import enum
import typing


class BaseConnection(abc.ABC):
    @abc.abstractmethod
    def sql(self, query: str) -> typing.List[typing.Any]:
        ...


class PostgresConnection(BaseConnection):

    def sql(self, query: str) -> typing.List[typing.Any]:
        return "This is a postgres result".split()

    def only_postgres_things(self) -> None:
        pass


class MySQLConnection(BaseConnection):

    def sql(self, query: str) -> typing.List[typing.Any]:
        return "This is a mysql result".split()

    def only_mysql_things(self) -> None:
        pass


class ConnectionType(enum.Enum):
    POSTGRES = 1
    MYSQL = 2


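def connect(conn_type: ConnectionType) -> typing.Union[PostgresConnection, MySQLConnection]:
    if conn_type is ConnectionType.POSTGRES:
        return PostgresConnection()
    if conn_type is ConnectionType.MYSQL:
        return MySQLConnection()
    # Without this, the function would implicitly return None for any other value.
    raise ValueError(f"Unsupported connection type: {conn_type!r}")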


conn = connect(ConnectionType.POSTGRES)
conn.only_postgres_things()

Look at how IntelliJ handles this: [screenshot: IntelliJ's autocompletion suggestions for conn]

As you can see, both only_postgres_things and only_mysql_things are suggested, whereas I'd like IntelliJ/mypy to infer the concrete type from the argument I pass to the connect function.
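One approach that may fit this (a sketch, not something from the question) is to declare one typing.overload of connect per enum member, using typing.Literal on the member so that a type checker selects the matching return type from the argument alone. This assumes Python 3.8+ (for typing.Literal) and a checker such as mypy that supports enum members inside Literal; the final overload taking a plain ConnectionType is an added fallback for values not known statically.

```python
import abc
import enum
import typing


class BaseConnection(abc.ABC):
    @abc.abstractmethod
    def sql(self, query: str) -> typing.List[typing.Any]:
        ...


class PostgresConnection(BaseConnection):
    def sql(self, query: str) -> typing.List[typing.Any]:
        return "This is a postgres result".split()

    def only_postgres_things(self) -> None:
        pass


class MySQLConnection(BaseConnection):
    def sql(self, query: str) -> typing.List[typing.Any]:
        return "This is a mysql result".split()

    def only_mysql_things(self) -> None:
        pass


class ConnectionType(enum.Enum):
    POSTGRES = 1
    MYSQL = 2


# One overload per enum member: the checker matches the Literal against the
# argument and uses the corresponding return type.
@typing.overload
def connect(conn_type: typing.Literal[ConnectionType.POSTGRES]) -> PostgresConnection: ...
@typing.overload
def connect(conn_type: typing.Literal[ConnectionType.MYSQL]) -> MySQLConnection: ...
# Fallback for a value only known to be "some ConnectionType" at check time.
@typing.overload
def connect(conn_type: ConnectionType) -> typing.Union[PostgresConnection, MySQLConnection]: ...


# The single runtime implementation; the overloads above exist only for type checkers.
def connect(conn_type: ConnectionType) -> typing.Union[PostgresConnection, MySQLConnection]:
    if conn_type is ConnectionType.POSTGRES:
        return PostgresConnection()
    if conn_type is ConnectionType.MYSQL:
        return MySQLConnection()
    raise ValueError(f"Unsupported connection type: {conn_type!r}")


conn = connect(ConnectionType.POSTGRES)  # inferred as PostgresConnection by mypy
conn.only_postgres_things()
```

With this in place, adding reveal_type(conn) and running mypy should report PostgresConnection, with no annotation on conn. Whether IntelliJ's own inference resolves Literal-based overloads the same way may depend on the IDE version.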


