Skip to content

Commit

Permalink
Fix for the GitHub workflow and its requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
Getty committed Jan 9, 2025
1 parent 043cdc6 commit 03be021
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 10 deletions.
8 changes: 6 additions & 2 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,32 @@ jobs:
build:

runs-on: ubuntu-latest

strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]

steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
pip install -e .
pip install -e .[dev]
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ license = {text = "MIT"}
[project.optional-dependencies]
dev = [
"pytest>=7.0.0",
"flake8>=6.0.0",
"black>=22.0.0",
"isort>=5.0.0",
"mypy>=1.0.0",
"mock>=5.0.0",
"requests",
"openai",
"ollama",
"transformers",
Expand Down
14 changes: 10 additions & 4 deletions tackleberry/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@ class TBEngine:
def __init__(self):
pass

def model(self, model: str):
def model(self,
model: str,
**kwargs,
):
from ..model import TBModel
return TBModel(self, model)
return TBModel(self, model, **kwargs)

def chat(self, model: str):
def chat(self,
model: str,
**kwargs,
):
from ..model import TBModelChat
return TBModelChat(self, model)
return TBModelChat(self, model, **kwargs)
6 changes: 6 additions & 0 deletions tackleberry/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def model(self,
raise Exception(f"Can't find engine for engine class '{engine_class}'")
return engine.model(model)

def chat(self,
model: str,
**kwargs,
):
return self.model(model).chat(**kwargs)

def engine(self,
engine_class: str,
**kwargs,
Expand Down
6 changes: 4 additions & 2 deletions tackleberry/model/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Any, Dict, Optional

from ..engine import TBEngine

class TBModel:

def __init__(self, engine: TBEngine, name: str):
self.engine = engine
self.name = name

def chat(self, **kwargs):
from .chat import TBModelChat
return TBModelChat(self.engine, self, **kwargs)
2 changes: 1 addition & 1 deletion tackleberry/model/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ class TBModelChat(TBModel):
def __init__(self,
engine: TBEngine,
model_name_or_model: Union[str|TBModel],
system_prompt: str = None,
context: TBContext = None,
system_prompt: str = None,
**kwargs,
):
self.engine = engine
Expand Down
13 changes: 12 additions & 1 deletion tests/test_tackleberry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ def test_000_unknown(self):
"""Test not existing Model and Engine"""
with self.assertRaises(ModuleNotFoundError):
engine = TB.engine('xxxxx')
with self.assertRaises(KeyError):
with self.assertRaises(KeyError):
model = TB.model('xxxxx')
with self.assertRaises(KeyError):
modelchat = TB.chat('xxxxx')

def test_010_openai(self):
"""Test OpenAI"""
Expand All @@ -35,6 +37,9 @@ def test_010_openai(self):
self.assertEqual(type(model).__name__, "TBModel")
self.assertIsInstance(model.engine, TBEngine)
self.assertEqual(type(model.engine).__name__, "TBEngineOpenai")
modelchat = TB.chat('gpt-4o')
self.assertIsInstance(modelchat, TBModel)
self.assertEqual(type(modelchat).__name__, "TBModelChat")
models = engine.get_models()
self.assertTrue(len(models) > 20)
else:
Expand All @@ -57,6 +62,9 @@ def test_020_anthropic(self):
self.assertEqual(type(model).__name__, "TBModel")
self.assertIsInstance(model.engine, TBEngine)
self.assertEqual(type(model.engine).__name__, "TBEngineAnthropic")
modelchat = TB.chat('claude-2.1')
self.assertIsInstance(modelchat, TBModel)
self.assertEqual(type(modelchat).__name__, "TBModelChat")
models = engine.get_models()
self.assertTrue(len(models) > 3)
else:
Expand All @@ -79,6 +87,9 @@ def test_030_groq(self):
self.assertEqual(type(model).__name__, "TBModel")
self.assertIsInstance(model.engine, TBEngine)
self.assertEqual(type(model.engine).__name__, "TBEngineGroq")
modelchat = TB.chat('llama3-8b-8192')
self.assertIsInstance(modelchat, TBModel)
self.assertEqual(type(modelchat).__name__, "TBModelChat")
models = engine.get_models()
self.assertTrue(len(models) > 10)
else:
Expand Down

0 comments on commit 03be021

Please sign in to comment.