From db019be00e50a17dde93b8837268ad0c878b0f0f Mon Sep 17 00:00:00 2001 From: Riddhi Shivhare Date: Thu, 26 Feb 2026 23:06:23 +0530 Subject: [PATCH] feat: Full implementation of Cloud Tasks routing for PushQueue --- .../appengine/api/taskqueue/taskqueue.py | 333 +++++++++++++----- 1 file changed, 251 insertions(+), 82 deletions(-) diff --git a/src/google/appengine/api/taskqueue/taskqueue.py b/src/google/appengine/api/taskqueue/taskqueue.py index 20c5c53..d664b5e 100755 --- a/src/google/appengine/api/taskqueue/taskqueue.py +++ b/src/google/appengine/api/taskqueue/taskqueue.py @@ -15,7 +15,6 @@ # limitations under the License. # - """Task Queue API. Enables an application to queue background work for itself. Work is done through @@ -38,6 +37,7 @@ import os import re import time +import base64 from google.appengine.api import apiproxy_stub_map from google.appengine.api import app_identity @@ -51,20 +51,137 @@ from six.moves import urllib import six.moves.urllib.parse - - - - - - - - - - - +try: + from google.cloud import tasks_v2beta2 + from google.api_core import exceptions as api_core_exceptions + from google.protobuf import field_mask_pb2 + import google.auth + google_auth_present = True +except ImportError: + tasks_v2beta2 = None + api_core_exceptions = None + field_mask_pb2 = None + google_auth_present = False + logging.warning("google-auth or google-cloud-tasks not found. Cloud Tasks redirection will not work.") + +_CT_SYNC_CLIENT = None + +def _get_ct_sync_client(): + global _CT_SYNC_CLIENT + if _CT_SYNC_CLIENT is None: + _CT_SYNC_CLIENT = tasks_v2beta2.CloudTasksClient() + return _CT_SYNC_CLIENT + +def _ShouldUseCloudTasks(): + """Helper to check if Cloud Tasks redirection should be used.""" + return os.environ.get('GAE_USE_CLOUDTASKS_PATH') == 'true' and google_auth_present + +# Variable indicating which backend path is currently active +backend_used = "Cloud Tasks" if _ShouldUseCloudTasks() else "Legacy TaskQueue" + +class _SimpleSyncRPC(object): + """A simple RPC-like object to wrap a synchronous result.""" + def __init__(self, result): + self._result = result + + def get_result(self): + return self._result + + def check_success(self): + pass + + def wait(self): + pass + +def _GetTaskForCloudTasks(add_req, project, location): + """Translates an App Engine TaskQueueAddRequest into a Cloud Tasks Task. + + Args: + add_req: The App Engine TaskQueueAddRequest to translate. + project: The Google Cloud project ID. + location: The Google Cloud location. + + Returns: + A tuple containing the parent resource name and the Cloud Tasks Task. + """ + queue_name_str = add_req.queue_name.decode('utf-8') if isinstance(add_req.queue_name, bytes) else add_req.queue_name + parent = 'projects/%s/locations/%s/queues/%s' % (project, location, queue_name_str) + + _METHOD_MAP_INT_TO_STR = {1: 'GET', 2: 'POST', 3: 'HEAD', 4: 'PUT', 5: 'DELETE'} + method_str = _METHOD_MAP_INT_TO_STR.get(add_req.method, 'POST') + + app_engine_http_request = { + 'http_method': method_str, + 'relative_url': add_req.url.decode('utf-8') if isinstance(add_req.url, bytes) else add_req.url, + 'headers': { + (h.key.decode('utf-8') if isinstance(h.key, bytes) else h.key): + (h.value.decode('utf-8') if isinstance(h.value, bytes) else h.value) + for h in add_req.header + } + } + + if add_req.body: + app_engine_http_request['payload'] = add_req.body + + task_payload = {'app_engine_http_request': app_engine_http_request} + task = tasks_v2beta2.Task(**task_payload) + + if add_req.eta_usec and add_req.eta_usec > 0: + task.schedule_time = datetime.datetime.fromtimestamp( + float(add_req.eta_usec) / 1e6, tz=datetime.timezone.utc + ) + + if add_req.task_name: + task_name_str = add_req.task_name.decode('utf-8') if isinstance(add_req.task_name, bytes) else add_req.task_name + task.name = '%s/tasks/%s' % (parent, task_name_str) + + return parent, task + +def _execute_ct_sync_call(method, path, body=None, **kwargs): + """Executes a synchronous Cloud Tasks API call. + + Args: + method: The HTTP method to use. + path: The path to use. + body: The body to use. + **kwargs: Additional keyword arguments to pass to the Cloud Tasks API. + + Returns: + The result of the Cloud Tasks API call. + """ + if not google_auth_present: + raise TransientError("google-cloud-tasks not installed.") + + client = _get_ct_sync_client() + try: + if method == 'POST' and '/tasks' in path and not ':' in path: # Create Task + parent = path.replace('/tasks', '') + # Pass the task inside the request payload dict + return client.create_task(request={'parent': parent, 'task': body.get('task')}) + elif method == 'DELETE' and '/tasks/' in path: # Delete Task + return client.delete_task(request={'name': path}) + elif method == 'POST' and ':purge' in path: # Purge Queue + queue_name = path.replace(':purge', '') + return client.purge_queue(request={'name': queue_name}) + elif method == 'GET' and '/queues/' in path: # Get Queue / Stats + request_dict = {'name': path} + request_dict.update(kwargs) + return client.get_queue(request=request_dict) + else: + raise NotImplementedError("Cloud Tasks method %s for path %s not implemented." % (method, path)) + + except api_core_exceptions.NotFound as e: + raise UnknownQueueError("Queue or Task not found: %s" % e) + except api_core_exceptions.AlreadyExists as e: + raise TaskAlreadyExistsError("Task already exists: %s" % e) + except api_core_exceptions.PermissionDenied as e: + raise PermissionDeniedError("Permission denied for Cloud Tasks: %s" % e) + except api_core_exceptions.GoogleAPIError as e: + raise TransientError("Cloud Tasks API call failed: %s" % e) __all__ = [ + 'backend_used', 'BadTaskStateError', 'BadTransactionState', @@ -375,7 +492,6 @@ def __repr__(self): } - class _UTCTimeZone(datetime.tzinfo): """UTC time zone.""" @@ -445,10 +561,6 @@ def get_string(value): elif isinstance(value, six.binary_type): return value else: - - - - return six.ensure_binary(str(value)) param_list = [] @@ -721,18 +833,15 @@ class Task(object): be inserted into one queue only. """ - __CONSTRUCTOR_KWARGS = frozenset([ 'countdown', 'eta', 'headers', 'method', 'name', 'params', 'retry_options', 'tag', 'target', 'url', '_size_check', 'dispatch_deadline_usec' ]) - __eta_posix = None __target = None - def __init__(self, payload=None, **kwargs): """Initializer. @@ -825,11 +934,9 @@ def __init__(self, payload=None, **kwargs): self.__queue_name = None self.__dispatch_deadline_usec = kwargs.get('dispatch_deadline_usec') - size_check = kwargs.get('_size_check', True) params = kwargs.get('params', {}) - apps_namespace = namespace_manager.google_apps_namespace() if apps_namespace is not None: self.__headers.setdefault('X-AppEngine-Default-Namespace', apps_namespace) @@ -930,10 +1037,6 @@ def __resolve_hostname_and_target(self): InvalidTaskError: If the task is invalid. """ - - - - if context.get('HTTP_HOST', None) is None: logging.warning( 'The HTTP_HOST environment variable was not set, but is required ' @@ -947,8 +1050,6 @@ def __resolve_hostname_and_target(self): elif self.__target is not None: host = self.__host_from_target(self.__target) if host: - - self.__headers['Host'] = host elif 'Host' in self.__headers: self.__target = self.__target_from_host(self.__headers['Host']) @@ -957,8 +1058,6 @@ def __resolve_hostname_and_target(self): self.__headers['Host'] = context.get('HTTP_HOST') self.__target = self.__target_from_host(self.__headers['Host']) else: - - self.__target = _UNKNOWN_APP_VERSION @staticmethod @@ -978,21 +1077,13 @@ def __target_from_host(host): """ default_hostname = app_identity.get_default_version_hostname() if default_hostname is None: - - - return _UNKNOWN_APP_VERSION if host.endswith(default_hostname): - version_name = host[:-(len(default_hostname) + 1)] if version_name: return version_name - - - - return DEFAULT_APP_VERSION @staticmethod @@ -1009,9 +1100,6 @@ def __host_from_target(target): """ default_hostname = app_identity.get_default_version_hostname() if default_hostname is None: - - - return None server_software = os.environ.get('SERVER_SOFTWARE', '') @@ -1019,7 +1107,6 @@ def __host_from_target(target): return default_hostname elif server_software.startswith( 'Dev') and server_software != 'Development/1.0 (testbed)': - target_components = target.rsplit('.', 3) module = target_components[-1] version = len(target_components) > 1 and target_components[-2] or None @@ -1028,8 +1115,6 @@ def __host_from_target(target): return modules.get_hostname(module=module, version=version, instance=instance) except modules.InvalidModuleError as e: - - if not version: return modules.get_hostname(module='default', version=module, instance=instance) @@ -1101,10 +1186,8 @@ def __determine_eta_posix(eta=None, countdown=None, current_time=None): if not isinstance(eta, datetime.datetime): raise InvalidTaskError('ETA must be a datetime.datetime instance') elif eta.tzinfo is None: - return time.mktime(eta.timetuple()) + eta.microsecond*1e-6 else: - return calendar.timegm(eta.utctimetuple()) + eta.microsecond*1e-6 elif countdown is not None: try: @@ -1167,9 +1250,6 @@ def _FromQueryAndOwnResponseTask(cls, queue_name, response_task): kwargs['tag'] = six.ensure_text(response_task.tag) self = cls(**kwargs) - - - self.__eta_posix = response_task.eta_usec * 1e-6 self.__retry_count = response_task.retry_count @@ -1186,7 +1266,6 @@ def dispatch_deadline_usec(self): def eta_posix(self): """Returns a POSIX timestamp of when this task will run or be leased.""" if self.__eta_posix is None and self.__eta is not None: - self.__eta_posix = Task.__determine_eta_posix(self.__eta) return self.__eta_posix @@ -1200,14 +1279,6 @@ def eta(self): @property def _eta_usec(self): """Returns a int microseconds timestamp when this task will run.""" - - - - - - - - return int(round(self.eta_posix * 1e6)) @property @@ -1315,7 +1386,6 @@ def extract_params(self): tasks) or the URL does not contain a valid query (all other requests). """ if self.__method in ('PULL', 'POST'): - query = self.__payload else: query = six.moves.urllib.parse.urlparse(self.__relative_url).query @@ -1490,7 +1560,6 @@ def fetch_async(cls, queue_or_queues, rpc=None): """ wants_list = True - if isinstance(queue_or_queues, six.string_types): queue_or_queues = [queue_or_queues] wants_list = False @@ -1554,7 +1623,6 @@ def fetch(cls, queue_or_queues, deadline=10): _ValidateDeadline(deadline) if not queue_or_queues: - return [] rpc = create_rpc(deadline) @@ -1565,6 +1633,53 @@ def fetch(cls, queue_or_queues, deadline=10): def _FetchMultipleQueues(cls, queues, multiple, rpc=None): """Internal implementation of fetch stats where queues must be a list.""" + if _ShouldUseCloudTasks(): + project = os.environ.get('GOOGLE_CLOUD_PROJECT') + location = os.environ.get('CLOUD_TASKS_LOCATION', 'us-central1') + results = [] + for queue in queues: + path = 'projects/%s/locations/%s/queues/%s' % (project, location, queue.name) + try: + # Request the stats field mask + mask = field_mask_pb2.FieldMask(paths=['stats']) + ct_queue = _execute_ct_sync_call('GET', path, read_mask=mask) + + if ct_queue.stats: + # Convert oldest_estimated_arrival_time (datetime) to microseconds + oldest_eta_usec = -1 + if ct_queue.stats.oldest_estimated_arrival_time: + # timestamp() gives seconds as a float, multiply by 1e6 for usec + oldest_eta_usec = int(ct_queue.stats.oldest_estimated_arrival_time.timestamp() * 1e6) + + stats = QueueStatistics( + queue=queue, + tasks=ct_queue.stats.tasks_count, + oldest_eta_usec=oldest_eta_usec, + executed_last_minute=ct_queue.stats.executed_last_minute_count, + in_flight=ct_queue.stats.concurrent_dispatches_count, + enforced_rate=ct_queue.stats.effective_execution_rate + ) + else: + # Fallback if the queue has no stats populated + stats = QueueStatistics(queue=queue, tasks=0, oldest_eta_usec=-1) + + except UnknownQueueError: + stats = QueueStatistics(queue=queue, tasks=0, oldest_eta_usec=-1) + except Exception as e: + raise TransientError("Failed to fetch stats: %s" % e) + stats._backend_used = 'Cloud Tasks' + results.append(stats) + + res = results if multiple else results[0] + + if rpc is not None: + # Patch the UserRPC object so its get_result returns our CT stats + rpc.get_result = lambda: res + rpc.check_success = lambda: None + rpc.wait = lambda: None + return rpc + return _SimpleSyncRPC(res) + def ResultHook(rpc): """Processes the TaskQueueFetchQueueStatsResponse.""" try: @@ -1578,6 +1693,9 @@ def ResultHook(rpc): queue_stats = [cls._ConstructFromFetchQueueStatsResponse(queue, response) for queue, response in zip(queues, rpc.response.queuestats)] + for stats in queue_stats: + stats._backend_used = 'Legacy TaskQueue' + if multiple: return queue_stats else: @@ -1630,7 +1748,6 @@ def __init__(self, name=_DEFAULT_QUEUE): InvalidQueueNameError: If the queue name is invalid. """ - if not _QUEUE_NAME_RE.match(name): raise InvalidQueueNameError( 'Queue name does not match pattern "%s"; found %s' % @@ -1638,10 +1755,6 @@ def __init__(self, name=_DEFAULT_QUEUE): self.__name = name self.__url = '%s/%s' % (_DEFAULT_QUEUE_PATH, self.__name) - - - - self._app = None def purge(self): @@ -1654,6 +1767,13 @@ def purge(self): Raises: Error-subclass on application errors. """ + if _ShouldUseCloudTasks(): + project = os.environ.get('GOOGLE_CLOUD_PROJECT') + location = os.environ.get('CLOUD_TASKS_LOCATION', 'us-central1') + path = 'projects/%s/locations/%s/queues/%s:purge' % (project, location, self.__name) + _execute_ct_sync_call('POST', path) + return + request = taskqueue_service_pb2.TaskQueuePurgeQueueRequest() response = taskqueue_service_pb2.TaskQueuePurgeQueueResponse() @@ -1789,6 +1909,32 @@ def delete_tasks(self, task): def __DeleteTasks(self, tasks, multiple, rpc=None): """Internal implementation of delete_tasks_async(), tasks must be a list.""" + + if _ShouldUseCloudTasks(): + if len(tasks) > 1: + raise NotImplementedError("Batch delete is not supported in Cloud Tasks path yet.") + + task = tasks[0] + if not task.name: + raise BadTaskStateError('A task name must be specified for a task') + if task.was_deleted: + raise BadTaskStateError('The task %s has already been deleted' % task.name) + + project = os.environ.get('GOOGLE_CLOUD_PROJECT') + location = os.environ.get('CLOUD_TASKS_LOCATION', 'us-central1') + path = 'projects/%s/locations/%s/queues/%s/tasks/%s' % (project, location, self.__name, task.name) + _execute_ct_sync_call('DELETE', path) + + task._Task__deleted = True + task._backend_used = 'Cloud Tasks' + + res = tasks if multiple else tasks[0] + if rpc is not None: + rpc.get_result = lambda: res + rpc.check_success = lambda: None + rpc.wait = lambda: None + return rpc + return _SimpleSyncRPC(res) def ResultHook(rpc): """Processes the TaskQueueDeleteResponse.""" @@ -1808,11 +1954,11 @@ def ResultHook(rpc): exception = None for task, result in zip(tasks, rpc.response.result): + task._backend_used = 'Legacy TaskQueue' + if result == taskqueue_service_pb2.TaskQueueServiceError.OK: - task._Task__deleted = True elif result in IGNORED_STATES: - task._Task__deleted = False elif exception is None: exception = _TranslateError(result) @@ -2118,9 +2264,6 @@ def add_async(self, task, transactional=False, rpc=None): else: multiple = True - - - has_push_task = False has_pull_task = False for task in tasks: @@ -2214,6 +2357,41 @@ def add(self, task, transactional=False): def __AddTasks(self, tasks, transactional, fill_request, multiple, rpc=None): """Internal implementation of adding tasks where tasks must be a list.""" + if _ShouldUseCloudTasks(): + if len(tasks) > 1: + raise NotImplementedError("Batch operations are not supported in Cloud Tasks path yet.") + if transactional: + raise NotImplementedError("Transactional tasks are not supported in Cloud Tasks path currently.") + + task = tasks[0] + if task.was_enqueued: + raise BadTaskStateError('The task has already been enqueued.') + + request = taskqueue_service_pb2.TaskQueueBulkAddRequest() + fill_request(task, request.add_request.add(), transactional) + add_req = request.add_request[0] + + project = os.environ.get('GOOGLE_CLOUD_PROJECT') + location = os.environ.get('CLOUD_TASKS_LOCATION', 'us-central1') + parent, ct_task = _GetTaskForCloudTasks(add_req, project, location) + + ct_result = _execute_ct_sync_call('POST', '%s/tasks' % parent, body={'task': ct_task}) + + if ct_result.name: + task._Task__name = ct_result.name.split('/')[-1] + task._Task__queue_name = self.__name + task._Task__enqueued = True + task._backend_used = 'Cloud Tasks' + + res = tasks if multiple else tasks[0] + + if rpc is not None: + rpc.get_result = lambda: res + rpc.check_success = lambda: None + rpc.wait = lambda: None + return rpc + return _SimpleSyncRPC(res) + def ResultHook(rpc): """Processes the TaskQueueBulkAddResponse.""" try: @@ -2227,6 +2405,8 @@ def ResultHook(rpc): exception = None for task, task_result in zip(tasks, rpc.response.taskresult): + task._backend_used = 'Legacy TaskQueue' + if (task_result.result == taskqueue_service_pb2.TaskQueueServiceError.OK ): if task_result.HasField('chosen_task_name'): @@ -2332,13 +2512,6 @@ def __FillAddPushTasksRequest(self, task, task_request, transactional): if task.on_queue_url: adjusted_url = self.__url + task.url - - - - - - - task_request.method = _METHOD_MAP.get(task.method) task_request.url = six.ensure_binary(adjusted_url) @@ -2394,8 +2567,6 @@ def __FillTaskCommon(self, task, task_request, transactional): if task.tag: task_request.tag = six.ensure_binary(task.tag) - - if transactional: from google.appengine.api import datastore if not datastore._MaybeSetupTransaction(task_request, []): @@ -2494,8 +2665,6 @@ def __ne__(self, o): return self.name != o.name or self._app != o._app - - def add(*args, **kwargs): """Convenience method that creates a task and adds it to a queue.