platform/helpers/tools/converter/test/fixtures/sample.py
2024-11-11 12:40:51 -07:00

46 lines
1.2 KiB
Python

from typing import List, Optional, Protocol
from abc import ABC, abstractmethod
class LLMProtocol(Protocol):
    """Structural interface that a language-model backend must satisfy.

    A conforming object exposes a ``model_name`` property and an async
    ``generate`` coroutine that turns a batch of prompt strings into a list
    of completion strings.
    """

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Return the name identifying the underlying model."""
        ...

    @abstractmethod
    async def generate(self, prompts: List[str], **kwargs) -> List[str]:
        """Produce completions for ``prompts``.

        Args:
            prompts: Batch of prompt strings.
            **kwargs: Implementation-specific generation options.

        Returns:
            The generated completion strings.
        """
        ...
class BaseChain(ABC):
    """Abstract base for all chains.

    Concrete chains must implement the ``run`` coroutine. Two class-level
    defaults apply until an instance assigns its own value: ``memory``
    (``None``) and ``verbose`` (``False``).
    """

    # Class-level defaults; assigning on an instance shadows them.
    memory: Optional[dict] = None
    verbose: bool = False

    def set_memory(self, memory: dict) -> None:
        """Attach ``memory`` to this instance, replacing any previous value."""
        self.memory = memory

    @abstractmethod
    async def run(self, inputs: dict) -> dict:
        """Execute the chain against ``inputs`` and return its outputs."""
        ...
class SimpleChain(BaseChain):
    """Minimal chain: forward a single prompt to an LLM and record the reply.

    Each ``run`` call sends ``inputs["prompt"]`` (or ``""`` when the key is
    absent) to the wrapped model as a one-element batch. It appends the
    first completion to ``history`` and returns it under the ``"output"``
    key.
    """

    def __init__(self, llm: LLMProtocol):
        """Initialize the chain.

        Args:
            llm: Model used to generate completions; any object satisfying
                ``LLMProtocol`` can be passed.
        """
        # Keep initialisation cooperative so any __init__ further along the
        # MRO (mixins, a future BaseChain.__init__) still runs.
        super().__init__()
        self.llm = llm
        # Completions produced by this instance, oldest first.
        self.history: List[str] = []

    async def run(self, inputs: dict) -> dict:
        """Generate a completion for ``inputs["prompt"]``.

        Args:
            inputs: Mapping that may contain a ``"prompt"`` entry; a missing
                key is treated as the empty prompt.

        Returns:
            ``{"output": completion}``, where ``completion`` is the first
            element returned by ``self.llm.generate``.

        Raises:
            IndexError: If the model returns no completions. IndexError is
                kept so callers that caught the former implicit
                ``result[0]`` failure continue to work.
        """
        prompt = inputs.get("prompt", "")
        completions = await self.llm.generate([prompt])
        try:
            output = completions[0]
        except IndexError:
            # Replace the opaque "list index out of range" with a message
            # that names the failing model class.
            raise IndexError(
                f"{type(self.llm).__name__}.generate returned no completions"
            ) from None
        self.history.append(output)
        return {"output": output}