
Comments (4)

shivakharbanda commented on June 3, 2024

Also, there is a lack of documentation on the event handler. It would be a huge help if the documentation had more examples.

from taskweaver.

shivakharbanda commented on June 3, 2024

After a ton of debugging,

class SessionEventHandlerBase(SessionEventHandler):
    def handle(self, event: TaskWeaverEvent):
        if event.scope == EventScope.session:
            assert isinstance(event.t, SessionEventType)
            session_event_type: SessionEventType = event.t
            self.handle_session(
                session_event_type,
                event.msg,
                event.extra,
            )
        elif event.scope == EventScope.round:
            assert isinstance(event.t, RoundEventType)
            assert event.round_id is not None
            round_event_type: RoundEventType = event.t
            self.handle_round(
                round_event_type,
                event.msg,
                event.extra,
                event.round_id,
            )

        elif event.scope == EventScope.post:
            assert isinstance(event.t, PostEventType)
            assert event.post_id is not None
            assert event.round_id is not None
            post_event_type: PostEventType = event.t
            self.handle_post(
                post_event_type,
                event.msg,
                event.extra,
                event.post_id,
                event.round_id,
            )

    def handle_session(
        self,
        type: SessionEventType,
        msg: str,
        extra: Any,
        **kwargs: Any,
    ):
        pass

    def handle_round(
        self,
        type: RoundEventType,
        msg: str,
        extra: Any,
        round_id: str,
        **kwargs: Any,
    ):
        pass

    def handle_post(
        self,
        type: PostEventType,
        msg: str,
        extra: Any,
        post_id: str,
        round_id: str,
        **kwargs: Any,
    ):
        pass

I came across this. In this part of the code:

        elif event.scope == EventScope.post:
            assert isinstance(event.t, PostEventType)
            assert event.post_id is not None
            assert event.round_id is not None
            post_event_type: PostEventType = event.t
            self.handle_post(
                post_event_type,
                event.msg,
                event.extra,
                event.post_id,
                event.round_id,
            )
            

When I print

self.handle_post

it gives this:

<bound method CustomSessionEventHandler.handle_post of <metadata.consumers.CustomSessionEventHandler object at 0x7b0cd2b41510>>

So the method is indeed bound to my handler. But I cannot step into it, which means my class's methods are not being triggered.
I have added a ton of breakpoints, but none of them is hit.
I don't know what I am doing wrong.


shivakharbanda commented on June 3, 2024

After more debugging I came across this:

# Private utility function called by _PyErr_WarnUnawaitedCoroutine
def _warn_unawaited_coroutine(coro):
    msg_lines = [
        f"coroutine '{coro.__qualname__}' was never awaited\n"
    ]
    if coro.cr_origin is not None:
        import linecache, traceback
        def extract():
            for filename, lineno, funcname in reversed(coro.cr_origin):
                line = linecache.getline(filename, lineno)
                yield (filename, lineno, funcname, line)
        msg_lines.append("Coroutine created at (most recent call last)\n")
        msg_lines += traceback.format_list(list(extract()))
    msg = "".join(msg_lines).rstrip("\n")
    # Passing source= here means that if the user happens to have tracemalloc
    # enabled and tracking where the coroutine was created, the warning will
    # contain that traceback. This does mean that if they have *both*
    # coroutine origin tracking *and* tracemalloc enabled, they'll get two
    # partially-redundant tracebacks. If we wanted to be clever we could
    # probably detect this case and avoid it, but for now we don't bother.
    warn(msg, category=RuntimeWarning, stacklevel=2, source=coro)
    

And the message is:

"coroutine 'CustomSessionEventHandler.handle_post' was never awaited"

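That warning suggests `handle_post` in the custom handler was declared with `async def`, while the base class's `handle` calls it as a plain function. Below is a minimal sketch of the effect, using stand-in classes (`Base`, `AsyncHandler` and `SyncHandler` are hypothetical, not taskweaver classes): calling an `async def` method without `await` only creates a coroutine object, so the body and any breakpoint in it never run.

```python
import warnings


class Base:
    """Stand-in for a dispatcher that calls its hook synchronously."""

    def handle(self, msg: str) -> None:
        # No `await` here: an `async def` override only returns a coroutine object.
        self.handle_post(msg)

    def handle_post(self, msg: str) -> None:
        pass


class AsyncHandler(Base):
    def __init__(self) -> None:
        self.seen = []

    async def handle_post(self, msg: str) -> None:  # body never executes
        self.seen.append(msg)


class SyncHandler(Base):
    def __init__(self) -> None:
        self.seen = []

    def handle_post(self, msg: str) -> None:  # runs on every call
        self.seen.append(msg)


with warnings.catch_warnings():
    # Silence the "was never awaited" RuntimeWarning for this demo only.
    warnings.simplefilter("ignore", RuntimeWarning)
    async_handler = AsyncHandler()
    async_handler.handle("hello")

sync_handler = SyncHandler()
sync_handler.handle("hello")

print(async_handler.seen)  # the async override recorded nothing
print(sync_handler.seen)   # the plain override recorded the message
```

Declaring the handler methods with plain `def`, as the solution further down does, avoids the problem.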

shivakharbanda commented on June 3, 2024

Okay, I solved it.
If anyone is curious, I came up with this solution.
This is my Django Channels consumers.py:

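import asyncio
import json
import logging

from channels.generic.websocket import AsyncWebsocketConsumer

logger = logging.getLogger(__name__)

# `app`, `executor`, `user_sessions` and `UserSession` are assumed to be
# defined elsewhere in the project (not shown here).

class ChatAIConsumer(AsyncWebsocketConsumer):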
    async def connect(self):
        # Extract session_id and datasource_id from the URL path
        self.session_id = self.scope['url_route']['kwargs']['session_id']
        self.datasource_id = self.scope['url_route']['kwargs']['datasource_id']

        logger.info(f"Attempting to connect: session_id={self.session_id}, datasource_id={self.datasource_id}")
        # Accept the WebSocket connection
        await self.accept()



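        # Create the event handler and start the task that forwards queued events to the client.
        # Keep a reference to the task so it is not garbage collected while running.
        self.event_handler = CustomSessionEventHandler(self)
        self.queue_task = asyncio.create_task(self.process_message_queue())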

        # Asynchronously create an AI session to avoid blocking the WebSocket connection
        ai_client = await asyncio.get_running_loop().run_in_executor(executor, app.get_session)
        
        user_sessions[self.session_id] = UserSession(
            session_id=self.session_id, 
            auth_token=None,  # Token will be set after authentication
            datasource_id=self.datasource_id, 
            ai_client=ai_client
        )
        
        user_sessions[self.session_id].ai_client.update_session_var(variables = {"datasource_id": self.datasource_id})

        logger.info(f"WebSocket connection accepted and AI session created for session_id={self.session_id}")


    async def disconnect(self, close_code):
        # Handle cleanup on disconnect
        session = user_sessions.pop(self.session_id, None)
        if session:
            session.ai_client.stop()  # Ensure session cleanup
            logger.info(f"Session {self.session_id} disconnected and cleaned up")
        
    async def receive(self, text_data):

        # Process incoming messages from the user
        text_data_json = json.loads(text_data)

        logger.info(f"Received message: {text_data_json}")

        # Check for the expected message type
        if text_data_json.get('type') == 'authenticate':
            # Handle authentication
            auth_token = text_data_json.get('token')
            session_id = self.scope['url_route']['kwargs']['session_id']
            datasource_id = self.scope['url_route']['kwargs']['datasource_id']
            if not await self.authenticate_token(auth_token):
                # Close the connection if the token is invalid
                # Do not write the token itself to the logs
                logger.error(f"Authentication failed for session_id={session_id}")
                await self.close(code=4001)
                return
            
            # Update the auth token in the user session
            user_sessions[session_id].auth_token = auth_token
            

            
            user_sessions[session_id].ai_client.update_session_var(variables = {"auth_token": auth_token})


            logger.info(f"User authenticated successfully for session_id={session_id} and ds id {datasource_id}")

            await self.send(text_data=json.dumps({"message": "Authenticated successfully"}))
        else:

            
            # Handle other message types, such as AI chat messages
            message = text_data_json.get("message")
            session = user_sessions.get(self.session_id)
            if session and session.ai_client:
                # Use the session's AI client to handle the message and get a response
                await self.handle_ai_response(message, session.ai_client)
                logger.info(f"Message processed and response sent for session_id={self.session_id}")
            else:
                # Session not found or message received before authentication
                await self.send(text_data=json.dumps({"error": "Unauthorized"}))
                logger.warning(f"Unauthorized access attempt or session not found for session_id={self.session_id}")

    async def handle_ai_response(self, message, ai_client):
        # send_message is blocking, so run it in the executor to keep the event loop free
        await asyncio.get_running_loop().run_in_executor(
            executor, ai_client.send_message, message, self.event_handler
        )
        logger.info(f"Message processed and response sent for session_id={self.session_id}")

    async def process_message_queue(self):
        while True:
            event = await self.event_handler.message_queue.get()
            # Serialize and send the event as JSON
            await self.send(text_data=json.dumps(event))
            self.event_handler.message_queue.task_done()
        

    async def authenticate_token(self, token):
        # Do not write the token itself to the logs
        logger.info("Authenticating token")
        # Implement actual token authentication logic here
        # For now, assuming all tokens are valid
        return True
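
One thing to watch in the consumer above: `process_message_queue` loops forever, so the task started in `connect` keeps running after the socket closes unless it is cancelled. The asyncio documentation also recommends keeping a reference to tasks created with `asyncio.create_task`. Below is a minimal sketch of that lifecycle, assuming a hypothetical `Connection` class that stands in for the consumer (it is not a Channels class):

```python
import asyncio
from typing import List, Optional


class Connection:
    """Hypothetical stand-in for the consumer; not a Channels class."""

    def __init__(self) -> None:
        self.queue: asyncio.Queue = asyncio.Queue()
        self.sent: List[str] = []
        self.sender_task: Optional[asyncio.Task] = None

    async def connect(self) -> None:
        # Keep a reference so the task is not garbage collected mid-flight.
        self.sender_task = asyncio.create_task(self._drain())

    async def _drain(self) -> None:
        # Endless forwarding loop, like process_message_queue above.
        while True:
            item = await self.queue.get()
            self.sent.append(item)
            self.queue.task_done()

    async def disconnect(self) -> None:
        # Cancel the endless loop and wait until it has actually stopped.
        if self.sender_task is not None:
            self.sender_task.cancel()
            try:
                await self.sender_task
            except asyncio.CancelledError:
                pass


async def demo() -> Connection:
    conn = Connection()
    await conn.connect()
    conn.queue.put_nowait("event-1")
    await conn.queue.join()  # wait until the sender has handled the event
    await conn.disconnect()
    return conn


conn = asyncio.run(demo())
print(conn.sent, conn.sender_task.cancelled())
```

In the real consumer, the cancel step would go in `disconnect`, next to the session cleanup.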

And this is my event handler:

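import asyncio
import functools
import re
from enum import Enum
from typing import Dict, List, Optional, Tuple

# SessionEventHandlerBase, PostEventType, AttachmentType and RoleName come from
# taskweaver; import them from wherever your installed version defines them.

def elem(name: str, cls: str = "", attr: Optional[Dict[str, str]] = None, **attr_dic: str):
    # Avoid a mutable default argument; treat None as "no extra attributes"
    all_attr = {**(attr or {}), **attr_dic}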
    if cls:
        all_attr.update({"class": cls})

    attr_str = ""
    if len(all_attr) > 0:
        attr_str += "".join(f' {k}="{v}"' for k, v in all_attr.items())

    def inner(*children: str):
        children_str = "".join(children)
        return f"<{name}{attr_str}>{children_str}</{name}>"

    return inner

def txt(content: str, br: bool = True):
    content = content.replace("<", "&lt;").replace(">", "&gt;")
    if br:
        content = content.replace("\n", "<br>")
    else:
        content = content.replace("\n", "&#10;")
    return content

div = functools.partial(elem, "div")
span = functools.partial(elem, "span")
blinking_cursor = span("tw-end-cursor")()

class CustomSessionEventHandler(SessionEventHandlerBase):
    def __init__(self, websocket):
        self.websocket = websocket
        self.message_queue = asyncio.Queue()
        self.reset_current_state()

    def reset_current_state(self):
        self.cur_attachment_list: List[Tuple[str, AttachmentType, str, bool]] = []
        self.cur_post_status: str = "Updating"
        self.cur_send_to: RoleName = "Unknown"
        self.cur_message: str = ""
        self.cur_message_is_end: bool = False
        self.cur_message_sent: bool = False

    def handle_session(self, type, msg, extra, **kwargs):
        self.queue_message("session", type, msg, extra)

    def handle_round(self, type, msg, extra, round_id, **kwargs):
        self.current_round_id = round_id
        self.queue_message("round", type, msg, extra)

    def handle_post(self, type, msg, extra, post_id, round_id, **kwargs):
        if type == PostEventType.post_start:
            self.reset_current_state()
        elif type == PostEventType.post_end:
            self.cur_message += msg
            self.queue_message("post", type, self.format_post_body(True), extra)
            self.reset_current_state()
        elif type == PostEventType.post_attachment_update:
            id: str = extra["id"]
            a_type: AttachmentType = extra["type"]
            is_end: bool = extra["is_end"]
            if len(self.cur_attachment_list) == 0 or id != self.cur_attachment_list[-1][0]:
                self.cur_attachment_list.append((id, a_type, msg, is_end))
            else:
                prev_msg = self.cur_attachment_list[-1][2]
                self.cur_attachment_list[-1] = (id, a_type, prev_msg + msg, is_end)
        elif type == PostEventType.post_send_to_update:
            self.cur_send_to = extra["role"]
        elif type == PostEventType.post_message_update:
            self.cur_message += msg
            if extra["is_end"]:
                self.cur_message_is_end = True
        elif type == PostEventType.post_status_update:
            self.cur_post_status = msg

        if not self.cur_message_is_end or not self.cur_message_sent:
            self.queue_message("post", type, self.format_post_body(False), extra)
            if self.cur_message_is_end and not self.cur_message_sent:
                self.cur_message_sent = True
                formatted_message = self.format_message(self.cur_message, True)
                self.queue_message("post", PostEventType.post_message_update, formatted_message, {"is_end": True})

    def format_post_body(self, is_end: bool) -> str:
        content_chunks: List[str] = []

        for attachment in self.cur_attachment_list:
            a_type = attachment[1]

            # skip artifact paths always
            if a_type in [AttachmentType.artifact_paths]:
                continue

            # skip Python in final result
            if is_end and a_type in [AttachmentType.python]:
                continue

            content_chunks.append(self.format_attachment(attachment))

        if self.cur_message != "":
            if self.cur_send_to == "Unknown":
                content_chunks.append("**Message**:")
            else:
                content_chunks.append(f"**Message To {self.cur_send_to}**:")

            if not self.cur_message_sent:
                content_chunks.append(
                    self.format_message(self.cur_message, self.cur_message_is_end),
                )

        if not is_end:
            content_chunks.append(
                div("tw-status")(
                    span("tw-status-updating")(
                        elem("svg", viewBox="22 22 44 44")(elem("circle")()),
                    ),
                    span("tw-status-msg")(txt(self.cur_post_status + "...")),
                ),
            )

        return "\n\n".join(content_chunks)

    
    def format_attachment(
        self,
        attachment: Tuple[str, AttachmentType, str, bool],
    ) -> str:
        id, a_type, msg, is_end = attachment
        header = div("tw-atta-header")(
            div("tw-atta-key")(
                " ".join([item.capitalize() for item in a_type.value.split("_")]),
            ),
            div("tw-atta-id")(id),
        )
        atta_cnt: List[str] = []

        if a_type in [AttachmentType.plan, AttachmentType.init_plan]:
            items: List[str] = []
            lines = msg.split("\n")
            for idx, row in enumerate(lines):
                item = row
                if "." in row and row.split(".")[0].isdigit():
                    item = row.split(".", 1)[1].strip()
                items.append(
                    div("tw-plan-item")(
                        div("tw-plan-idx")(str(idx + 1)),
                        div("tw-plan-cnt")(
                            txt(item),
                            blinking_cursor if not is_end and idx == len(lines) - 1 else "",
                        ),
                    ),
                )
            atta_cnt.append(div("tw-plan")(*items))
        elif a_type in [AttachmentType.execution_result]:
            atta_cnt.append(
                elem("pre", "tw-execution-result")(
                    elem("code")(txt(msg)),
                ),
            )
        elif a_type in [AttachmentType.python, AttachmentType.sample]:
            atta_cnt.append(
                elem("pre", "tw-python", {"data-lang": "python"})(
                    elem("code", "language-python")(txt(msg, br=False)),
                ),
            )
        else:
            atta_cnt.append(txt(msg))
            if not is_end:
                atta_cnt.append(blinking_cursor)

        return div("tw-atta")(
            header,
            div("tw-atta-cnt")(*atta_cnt),
        )
    
    def format_message(self, message: str, is_end: bool) -> str:
        content = txt(message, br=False)
        begin_regex = re.compile(r"^```(\w*)$\n", re.MULTILINE)
        end_regex = re.compile(r"^```$\n?", re.MULTILINE)

        if not is_end:
            end_tag = " " + blinking_cursor
        else:
            end_tag = ""

        while True:
            start_label = begin_regex.search(content)
            if not start_label:
                break
            start_pos = content.index(start_label[0])
            lang_tag = start_label[1]
            content = "".join(
                [
                    content[:start_pos],
                    f'<pre data-lang="{lang_tag}"><code class="language-{lang_tag}">',
                    content[start_pos + len(start_label[0]) :],
                ],
            )

            end_pos = end_regex.search(content)
            if not end_pos:
                content += end_tag + "</code></pre>"
                end_tag = ""
                break
            end_pos_pos = content.index(end_pos[0])
            content = f"{content[:end_pos_pos]}</code></pre>{content[end_pos_pos + len(end_pos[0]):]}"

        content += end_tag
        return content

    def queue_message(self, event_category, event_type, message, details):
        # Convert event_type and other non-serializable objects
        event = {
            "type": "chat_message",
            "event_category": event_category,
            "event_type": self.serialize_event_type(event_type),
            "message": message,
            "details": self.serialize_details(details)
        }
        self.message_queue.put_nowait(event)

    def serialize_event_type(self, event_type):
        # Assuming event_type is an enum or has a similar interface
        if isinstance(event_type, Enum):
            return {"name": event_type.name, "value": event_type.value}
        return {attr: self.serialize_value(getattr(event_type, attr)) for attr in dir(event_type) if not attr.startswith('_')}

    def serialize_details(self, details):
        # Similar to how Chainlit handles attachments and complex structures
        if isinstance(details, dict):
            return {k: self.serialize_value(v) for k, v in details.items()}
        return details

    def serialize_value(self, value):
        if isinstance(value, Enum):
            return value.name  # or value.value based on your needs
        if isinstance(value, dict):
            return {k: self.serialize_value(v) for k, v in value.items()}
        if isinstance(value, list):
            return [self.serialize_value(v) for v in value]
        if hasattr(value, '__dict__'):
            return {k: self.serialize_value(v) for k, v in value.__dict__.items() if not callable(v) and not k.startswith('_')}
        return value  # Fallback for basic types

I reverse-engineered the Chainlit implementation, reused most of it, and changed it to fit my use case.
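
A caveat for anyone copying this: `send_message` runs inside `run_in_executor`, so the handler callbacks (and therefore `queue_message`) presumably run on a worker thread, and `asyncio.Queue` is documented as not thread-safe. A hedged sketch of one way to hand events over safely, using `loop.call_soon_threadsafe`; the `ThreadSafeEmitter` class and its `emit` method are hypothetical names for illustration, not part of taskweaver or Channels:

```python
import asyncio
import threading
from typing import List


class ThreadSafeEmitter:
    """Hypothetical helper: lets a worker thread feed an asyncio.Queue."""

    def __init__(self) -> None:
        # Must be created inside a running event loop (e.g. in `connect`).
        self.loop = asyncio.get_running_loop()
        self.queue: asyncio.Queue = asyncio.Queue()

    def emit(self, event: dict) -> None:
        # Safe to call from any thread: the put is scheduled on the loop thread.
        self.loop.call_soon_threadsafe(self.queue.put_nowait, event)


async def demo() -> List[dict]:
    emitter = ThreadSafeEmitter()

    def worker() -> None:
        # Plays the role of the blocking send_message call in the executor.
        for i in range(3):
            emitter.emit({"n": i})

    thread = threading.Thread(target=worker)
    thread.start()
    received = [await emitter.queue.get() for _ in range(3)]
    thread.join()
    return received


received = asyncio.run(demo())
print(received)
```

Applied to the handler above, this would mean capturing the running loop in `__init__` and replacing the direct `put_nowait` call in `queue_message` with the `call_soon_threadsafe` form.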
