Skip to content

Commit b54df75

Browse files
authored
Merge pull request #45 from serengil/feat-task-2512-more-type-hinting-and-docstrings
type hinting
2 parents 0ea6e22 + 5b4a039 commit b54df75

16 files changed

+364
-124
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ __pycache__/*
33
commons/__pycache__/*
44
training/__pycache__/*
55
tuning/__pycache__/*
6+
tests/__pycache__/*
67
build/
78
dist/
89
Pipfile
@@ -18,4 +19,5 @@ chefboost/tuning/__pycache__/*
1819
.DS_Store
1920
chefboost/.DS_Store
2021
tests/.DS_Store
21-
.pytest_cache
22+
.pytest_cache
23+
*.pyc

chefboost/commons/daemon.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import multiprocessing
2+
import multiprocessing.pool
3+
4+
class NoDaemonProcess(multiprocessing.Process):
5+
"""
6+
NoDaemonProcess class for recursive parallel runs
7+
"""
8+
def _get_daemon(self):
9+
# make 'daemon' attribute always return False
10+
return False
11+
12+
def _set_daemon(self, value):
13+
pass
14+
15+
daemon = property(_get_daemon, _set_daemon)
16+
17+
18+
class NoDaemonContext(type(multiprocessing.get_context())):
19+
"""
20+
NoDaemonContext class for recursive parallel runs
21+
"""
22+
# pylint: disable=too-few-public-methods
23+
Process = NoDaemonProcess
24+
25+
26+
class CustomPool(multiprocessing.pool.Pool):
27+
"""
28+
CustomPool class for recursive parallel runs
29+
"""
30+
# pylint: disable=too-few-public-methods, abstract-method, super-with-arguments
31+
def __init__(self, *args, **kwargs):
32+
kwargs["context"] = NoDaemonContext()
33+
super(CustomPool, self).__init__(*args, **kwargs)

chefboost/commons/functions.py

Lines changed: 72 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import pathlib
22
import os
3+
import sys
34
from os import path
5+
from types import ModuleType
46
import multiprocessing
5-
from typing import Optional
7+
from typing import Optional, Union
68
import numpy as np
9+
import pandas as pd
710
from chefboost import Chefboost as cb
811
from chefboost.commons.logger import Logger
912
from chefboost.commons.module import load_module
@@ -13,7 +16,15 @@
1316
logger = Logger(module="chefboost/commons/functions.py")
1417

1518

16-
def bulk_prediction(df, model):
19+
def bulk_prediction(df: pd.DataFrame, model: dict) -> None:
20+
"""
21+
Perform a bulk prediction on given dataframe
22+
Args:
23+
df (pd.DataFrame): input data frame
24+
model (dict): built model
25+
Returns:
26+
None
27+
"""
1728
predictions = []
1829
for _, instance in df.iterrows():
1930
features = instance.values[0:-1]
@@ -23,17 +34,35 @@ def bulk_prediction(df, model):
2334
df["Prediction"] = predictions
2435

2536

26-
def restoreTree(module_name):
37+
def restoreTree(module_name: str) -> ModuleType:
38+
"""
39+
Restores a built tree
40+
"""
2741
return load_module(module_name)
2842

2943

30-
def softmax(w):
44+
def softmax(w: list) -> np.ndarray:
45+
"""
46+
Softmax function
47+
Args:
48+
w (list): input scores to be normalized
49+
Returns:
50+
result (numpy.ndarray): softmax of inputs
51+
"""
3152
e = np.exp(np.array(w, dtype=np.float32))
3253
dist = e / np.sum(e)
3354
return dist
3455

3556

36-
def sign(x):
57+
def sign(x: Union[int, float]) -> int:
58+
"""
59+
Sign function
60+
Args:
61+
x (int or float): input
62+
Returns:
63+
result (int): 1 for positive inputs, -1 for negative
64+
inputs, 0 for zero input
65+
"""
3766
if x > 0:
3867
return 1
3968
elif x < 0:
@@ -42,7 +71,14 @@ def sign(x):
4271
return 0
4372

4473

45-
def formatRule(root):
74+
def formatRule(root: int) -> str:
75+
"""
76+
Format a rule in the output file (tree)
77+
Args:
78+
root (int): degree of current rule
79+
Returns:
80+
formatted rule (str)
81+
"""
4682
resp = ""
4783

4884
for _ in range(0, root):
@@ -51,20 +87,37 @@ def formatRule(root):
5187
return resp
5288

5389

54-
def storeRule(file, content):
90+
def storeRule(file: str, content: str) -> None:
91+
"""
92+
Store a custom rule
93+
Args:
94+
file (str): target file
95+
content (str): content to store
96+
Returns:
97+
None
98+
"""
5599
with open(file, "a+", encoding="UTF-8") as f:
56100
f.writelines(content)
57101
f.writelines("\n")
58102

59103

60-
def createFile(file, content):
104+
def createFile(file: str, content: str) -> None:
105+
"""
106+
Create a file with given content
107+
Args:
108+
file (str): target file
109+
content (str): content to store
110+
Returns:
111+
None
112+
"""
61113
with open(file, "w", encoding="UTF-8") as f:
62114
f.write(content)
63115

64116

65-
def initializeFolders():
66-
import sys
67-
117+
def initializeFolders() -> None:
118+
"""
119+
Initialize required folders
120+
"""
68121
sys.path.append("..")
69122
pathlib.Path("outputs").mkdir(parents=True, exist_ok=True)
70123
pathlib.Path("outputs/data").mkdir(parents=True, exist_ok=True)
@@ -97,8 +150,14 @@ def initializeFolders():
97150
# ------------------------------------
98151

99152

100-
def initializeParams(config: Optional[dict] = None):
101-
153+
def initializeParams(config: Optional[dict] = None) -> dict:
154+
"""
155+
Arrange a chefboost configuration
156+
Args:
157+
config (dict): initial configuration
158+
Returns:
159+
config (dict): final configuration
160+
"""
102161
if config == None:
103162
config = {}
104163

chefboost/training/Preprocess.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,26 @@
11
import math
22
import numpy as np
3+
import pandas as pd
34
from chefboost.training import Training
45
from chefboost.commons.logger import Logger
56

67
logger = Logger(module="chefboost/training/Preprocess.py")
78

89

9-
def processContinuousFeatures(algorithm, df, column_name, entropy, config):
10+
def processContinuousFeatures(
11+
algorithm: str, df: pd.DataFrame, column_name: str, entropy: float, config: dict
12+
) -> pd.DataFrame:
13+
"""
14+
Find the best split point for numeric features
15+
Args:
16+
df (pd.DataFrame): (sub) training dataframe
17+
column_name (str): current column to process
18+
entropy (float): calculated entropy
19+
config (dict): training configuration
20+
Returns:
21+
df (pd.DataFrame): dataframe with numeric columns updated
22+
to nominal (e.g. age >40 or <=40 instead of continuous age)
23+
"""
1024
# if True:
1125
if df[column_name].nunique() <= 20:
1226
unique_values = sorted(df[column_name].unique())

0 commit comments

Comments
 (0)