Source code for fedflow.core.task

"""
Task
=====

The basic class of task.

User define task by inherit the ``Task`` class and overwrite ``load`` and ``train`` methods.
"""

__all__ = [
    "Task",
    "TaskStatus"
]


import abc
import enum
import logging
import multiprocessing
import os
import threading
import time
import traceback
import uuid
from typing import Union

from fedflow.core.message import Message, MessageListener


[docs]class TaskStatus(enum.Enum): """ The task status enum. """ UNKNOWN = 0 #: the error status INIT = 1 #: construct a task instance and hasn't started scheduling AVAILABLE = 2 #: start task subprocess LOADING = 3 #: start loading WAITING = 4 #: loaded successfully, and waiting for training TRAINING = 5 #: start training FINISHED = 6 #: training successfully EXITED = 7 #: subprocess exited EXCEPTION = 8 #: caught some exception while running INTERRUPT = 9 #: caught OOM(or cuda OOM) exception while running
[docs]class Task(object): """ the basic class of all user task """ main_logger = logging.getLogger("fedflow.task.main") sub_logger = logging.getLogger("fedflow.task.sub")
[docs] def __init__(self, task_id: Union[int, str] = None, *, estimate_memory: Union[int, str] = None, estimate_cuda_memory: Union[int, str] = None, device=None): """ Construct an instance of task :param task_id: task unique id, default is uuid string. :param estimate_memory: maximum memory expected to be used. :param estimate_cuda_memory: maximum cuda memory expected to be used. :param device: specify device the task used, if it's None, the device will be decided by scheduler. """ super(Task, self).__init__() self.task_id = task_id if task_id is not None else str(uuid.uuid4()) self.estimate_memory = estimate_memory self.estimate_cuda_memory = estimate_cuda_memory self.device = device self.load_numbers = 0 self.train_numbers = 0 self.__workdir = None self.load_time = -1 self.train_time = -1 self.result = {} self.__process = None self.__pipe = None self.__mq = None self.__status = TaskStatus.INIT
@property def workdir(self) -> str: """ The workdir of task. This property only can be used after task process was started. :return: """ if self.__workdir is None: raise ValueError("The workdir field is not available.") return self.__workdir @property def status(self) -> TaskStatus: """ The status of task. :return: """ return self.__status @status.setter def status(self, value: Union[int, str, TaskStatus]) -> None: """ status setter :param value: a int/str/TaskStatus value represents the status :return: """ try: if isinstance(value, int): s = TaskStatus(value) elif isinstance(value, str): s = TaskStatus[value] elif isinstance(value, TaskStatus): s = value else: s = TaskStatus.UNKNOWN except: s = TaskStatus.UNKNOWN self.__status = s # ====================================================================== # ------------------------ main process methods ------------------------ # --- The following methods will only be used in the main process. --- # ======================================================================
[docs] def start(self) -> None: """ Start task process *This method cannot be called by user.* :return: """ self.main_logger.info("{%s} start.", self.task_id) self.__workdir = os.path.join(os.curdir, str(self.task_id)) self.__workdir = os.path.abspath(self.__workdir) pipe = multiprocessing.Pipe() self.__pipe = pipe[0] self.__process = multiprocessing.Process(target=self.run, args=(pipe[1], MessageListener.mq())) self.__process.start()
[docs] def start_load(self) -> None: """ Start loading. *This method cannot be called by user.* :return: """ self.load_numbers += 1 self.main_logger.info("{%s} start load. retry time: %d", self.task_id, self.load_numbers) msg = Message(source="", cmd="LOAD", data={}) self.__pipe.send(msg)
[docs] def start_train(self, device: str) -> None: """ Start training. *This method cannot be called by user.* :param device: the device this task will use. :return: """ self.train_numbers += 1 self.main_logger.info("{%s} start train. retry time: %d", self.task_id, self.train_numbers) msg = Message(source="", cmd="TRAIN", data={ "device": device }) self.__pipe.send(msg)
[docs] def exit(self) -> None: """ Exit task process. *This method cannot be called by user.* :return: """ if self.__pipe is None or self.__pipe.closed: self.main_logger.warning("{%s} Try to exit a closed process.", self.task_id) return msg = Message(source="", cmd="EXIT", data={}) self.__pipe.send(msg) self.__pipe.close() self.main_logger.info("{%s} exit.", self.task_id)
[docs] def is_alive(self) -> bool: """ If the task process is alive. :return: a bool value """ return self.__process is not None and self.__process.is_alive()
# ====================================================================== # ------------------------- subprocess methods ------------------------- # --- The following methods will be only be used in the subprocess. --- # --- It means that the following method and the above method run --- # --- in different process spaces. --- # ======================================================================
[docs] def run(self, pipe, mq) -> None: """ subprocess code entry :param pipe: connection pipe between main process task and subprocess task :param mq: connection queue between main process scheduler and subprocess tasks :return: """ self.sub_logger.info("{%s} run.", self.task_id) self.__pipe = pipe self.__mq = mq self.__workdir = os.path.join(os.curdir, str(self.task_id)) self.__workdir = os.path.abspath(self.__workdir) os.makedirs(self.__workdir, exist_ok=True) os.chdir(self.__workdir) self.__listen()
def __listen(self) -> None: """ listen command from main process :return: """ self.__update_status(TaskStatus.AVAILABLE) while True: msg: Message = self.__pipe.recv() if msg.cmd == "EXIT": self.sub_logger.info("{%s} receive EXIT signal", self.task_id) break elif msg.cmd == "LOAD": self.sub_logger.info("{%s} receive LOAD signal", self.task_id) t = threading.Thread(target=self.__load) t.start() elif msg.cmd == "TRAIN": self.device = msg.data["device"] self.sub_logger.info("{%s} receive TRAIN[%s] signal", self.task_id, self.device) t = threading.Thread(target=self.__train) t.start() self.__pipe.close()
[docs] @abc.abstractmethod def load(self) -> None: """ User must overwrite this method in subclass. When implement subclass, user should put all loading action(such as load datasets) in this method. :return: """ raise NotImplementedError()
[docs] @abc.abstractmethod def train(self, device: str) -> dict: """ User must overwrite this method in subclass. When implement subclass, user should put all computer action(such as train or predict) in this method. :param device: the device this task will use. :return: a dict represent some properties used for reporting. """ raise NotImplementedError()
def __load(self): try: self.__update_status(TaskStatus.LOADING) start_time = time.time() self.load() self.load_time = int(1000 * (time.time() - start_time)) self.__update_status(TaskStatus.WAITING) self.sub_logger.info("{%s} load successful, used %dms", self.task_id, self.load_time) except Exception as e: if type(e) == MemoryError: self.sub_logger.error("{%s} OOM", self.task_id) self.__update_status(TaskStatus.INTERRUPT, { "stage": "LOAD" }) else: self.sub_logger.error("{%s} an error occurred during loading.", self.task_id, exc_info=True, stack_info=True) self.__update_status(TaskStatus.EXCEPTION, { "message": traceback.format_exc(), "stage": "LOAD" }) def __train(self): try: self.__update_status(TaskStatus.TRAINING) start_time = time.time() data = self.train(self.device) self.train_time = int(1000 * (time.time() - start_time)) if type(data) != dict: data = {} self.__send_message("set_result", data) data["load_time"] = self.load_time data["train_time"] = self.train_time self.__update_status(TaskStatus.FINISHED, data) self.sub_logger.info("{%s} train successful, used %dms", self.task_id, self.train_time) except Exception as e: if type(e) == RuntimeError and len(e.args) > 0 and "CUDA out of memory" in e.args[0]: self.sub_logger.error("{%s} cuda OOM", self.task_id) self.__update_status(TaskStatus.INTERRUPT, { "stage": "TRAIN" }) else: self.sub_logger.error("{%s} an error occurred during training.", self.task_id, exc_info=True, stack_info=True) self.__update_status(TaskStatus.EXCEPTION, { "message": traceback.format_exc(), "stage": "TRAIN" }) def __update_status(self, v, data=None): self.__status = v if data is None: data = {} data["status"] = v self.sub_logger.info("{%s} update status to %s", self.task_id, v.name) self.__send_message("update_status", data) def __send_message(self, cmd, data=None): if data is None: data = {} msg = Message(source=self.task_id, cmd=cmd, data=data) self.__mq.put(msg)