Friday, 1 July 2022

Celery: how to get queue size in a reliable and testable way

I'm losing my mind trying to find a reliable and testable way to get the number of tasks contained in a given Celery queue.

I've already read the two related discussions I could find on this topic, but I have not been able to solve my issue using the methods described in those threads.

I'm using Redis as the broker, but I would like a broker-independent and flexible solution, especially for tests.

This is my current situation: I've defined an EnhancedCelery class which inherits from Celery and adds a couple of helper methods; get_queue_size() is the one I'm trying to implement and test properly.

The following is the code in my test case:

from unittest import TestCase  # or your framework's TestCase

from celery.contrib.testing.worker import start_worker
from kombu import Queue

celery_test_app = EnhancedCelery(__name__)

# This is needed to avoid an exception for the ping task,
# which the worker triggers automatically once started.
celery_test_app.loader.import_module('celery.contrib.testing.tasks')

# in-memory broker and result backend
celery_test_app.conf.broker_url = 'memory://'
celery_test_app.conf.result_backend = 'cache+memory://'

# We have to set up queues manually,
# since auto queue creation doesn't seem to work in tests :(
celery_test_app.conf.task_create_missing_queues = False
celery_test_app.conf.task_default_queue = 'default'
celery_test_app.conf.task_queues = (
    Queue('default', routing_key='task.#'),
    Queue('queue_1', routing_key='q1'),
    Queue('queue_2', routing_key='q2'),
    Queue('queue_3', routing_key='q3'),
)
celery_test_app.conf.task_default_exchange = 'tasks'
celery_test_app.conf.task_default_exchange_type = 'topic'
celery_test_app.conf.task_default_routing_key = 'task.default'
celery_test_app.conf.task_routes = {
    'sample_task': {
        'queue': 'default',
        'routing_key': 'task.default',
    },
    'sample_task_in_queue_1': {
        'queue': 'queue_1',
        'routing_key': 'q1',
    },
    'sample_task_in_queue_2': {
        'queue': 'queue_2',
        'routing_key': 'q2',
    },
    'sample_task_in_queue_3': {
        'queue': 'queue_3',
        'routing_key': 'q3',
    },
}
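A quick note on those routing keys: with a topic exchange, '*' matches exactly one dot-separated word and '#' matches several, which is why 'task.default' is routed to the default queue bound with 'task.#'. A toy matcher illustrating the semantics (my own simplified illustration, not kombu's implementation; real AMQP also lets '#' match zero words):

```python
import re


def topic_matches(pattern: str, routing_key: str) -> bool:
    """Toy AMQP topic matcher: '*' = one word, '#' = one or more words."""
    parts = []
    for word in pattern.split('.'):
        if word == '#':
            parts.append(r'[^.]+(?:\.[^.]+)*')  # one or more words (simplified)
        elif word == '*':
            parts.append(r'[^.]+')              # exactly one word
        else:
            parts.append(re.escape(word))       # literal word
    return re.fullmatch(r'\.'.join(parts), routing_key) is not None


# With the bindings above:
#   'task.default' matches 'task.#'  -> default queue
#   'q1' matches 'q1'                -> queue_1
print(topic_matches('task.#', 'task.default'))  # True
print(topic_matches('task.#', 'q1'))            # False
```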


@celery_test_app.task()
def sample_task():
    return 'sample_task_result'


@celery_test_app.task(queue='queue_1')
def sample_task_in_queue_1():
    return 'sample_task_in_queue_1_result'


@celery_test_app.task(queue='queue_2')
def sample_task_in_queue_2():
    return 'sample_task_in_queue_2_result'


@celery_test_app.task(queue='queue_3')
def sample_task_in_queue_3():
    return 'sample_task_in_queue_3_result'


class EnhancedCeleryTest(TestCase):
    def test_get_queue_size_returns_expected_value(self):
        def add_task(task):
            task.apply_async()

        with start_worker(celery_test_app):
            for _ in range(7):
                add_task(sample_task_in_queue_1)

            for _ in range(4):
                add_task(sample_task_in_queue_2)

            for _ in range(2):
                add_task(sample_task_in_queue_3)

            self.assertEqual(celery_test_app.get_queue_size('queue_1'), 7)
            self.assertEqual(celery_test_app.get_queue_size('queue_2'), 4)
            self.assertEqual(celery_test_app.get_queue_size('queue_3'), 2)

Here are my attempts to implement get_queue_size():

  1. This always returns zero (jobs == 0), most likely because the worker started by start_worker consumes messages as soon as they are published, so the passive queue_declare sees an empty queue:

    from typing import Optional

    from amqp.exceptions import ChannelError, NotFound

    def get_queue_size(self, queue_name: str) -> Optional[int]:
        with self.connection_or_acquire() as connection:
            channel = connection.default_channel

            try:
                name, jobs, consumers = channel.queue_declare(queue=queue_name, passive=True)
                return jobs
            except (ChannelError, NotFound):
                return None
    
  2. This also comes back empty: the inspect API reports tasks that workers have already reserved, scheduled, or are executing, grouped per worker, not messages waiting in the broker, so it cannot measure the queue backlog:

    def get_queue_size(self, queue_name: str) -> Optional[int]:
        inspection = self.control.inspect()
    
        return inspection.active() # zero!
    
        # or:
        return inspection.scheduled() # zero!
    
        # or:
        return inspection.reserved() # zero!
    
  3. This works (it returns the expected number for each queue), but only in the test environment, because the channel.queues property exists only on the in-memory transport, not on the Redis one:

    def get_queue_size(self, queue_name: str) -> Optional[int]:
        with self.connection_or_acquire() as connection:
            channel = connection.default_channel
    
            if hasattr(channel, 'queues'):
                queue = channel.queues.get(queue_name)
    
                if queue is not None:
                    return queue.unfinished_tasks
    
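What I would try next, to unify attempts 1 and 3, is to branch on what the transport channel actually exposes, with the dispatch logic pulled out into a standalone function that is easy to unit test. This is a sketch of the idea, not a verified solution: queue_size_from_channel is my own name, and the attribute checks (queues on the in-memory channel, client on the Redis channel, where each queue is a plain Redis list under default transport options) are assumptions about kombu's transports:

```python
from typing import Optional


def queue_size_from_channel(channel, queue_name: str) -> Optional[int]:
    """Best-effort queue length from a kombu transport channel."""
    # In-memory transport (tests): channel.queues maps queue names to
    # queue objects exposing an unfinished_tasks counter.
    if hasattr(channel, 'queues'):
        queue = channel.queues.get(queue_name)
        return queue.unfinished_tasks if queue is not None else None

    # Redis transport: each queue is a Redis list named after the
    # queue, so its length is simply LLEN <queue_name>.
    if hasattr(channel, 'client'):
        return channel.client.llen(queue_name)

    # AMQP-style transports: a passive declare reports the message
    # count without creating the queue.
    try:
        _, message_count, _ = channel.queue_declare(
            queue=queue_name, passive=True)
        return message_count
    except Exception:  # queue does not exist
        return None


# The method on EnhancedCelery would then be a thin wrapper:
#
#     def get_queue_size(self, queue_name):
#         with self.connection_or_acquire() as connection:
#             return queue_size_from_channel(
#                 connection.default_channel, queue_name)
```

Keeping the branching in a plain function means each transport case can be exercised with a stub channel, without starting a worker at all.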

