diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 9dad4ca43..a72246821 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -12,29 +12,31 @@ jobs:
notify-about-push:
runs-on: ubuntu-latest
steps:
- - name: Send msg on push
- uses: appleboy/telegram-action@master
- with:
- to: ${{ secrets.TELEGRAM_TO }}
- token: ${{ secrets.TELEGRAM_TOKEN }}
- format: html
- args: |
+ - name: Send msg on push
+ uses: appleboy/telegram-action@master
+ with:
+ to: ${{ secrets.TELEGRAM_TO }}
+ token: ${{ secrets.TELEGRAM_TOKEN }}
+ format: html
+ args: |
Repository: ${{ github.event.repository.full_name }}
Ref: ${{ github.ref }}
Event: ${{ github.event_name }}
Info: ${{ github.event.pull_request.title }}
${{ github.event.pull_request.html_url }}
- nitta:
+ nitta-haskell-deps:
runs-on: ubuntu-latest
- timeout-minutes: 60
+ timeout-minutes: 45
steps:
- uses: actions/checkout@v3
- name: Cache haskell-stack
uses: actions/cache@v3.2.5
with:
- path: ~/.stack
+ path: |
+ ~/.stack
+ .stack-work
key: ${{ runner.os }}-haskell-stack-${{ hashFiles('**/stack.yaml', '**/package.yaml') }}
- name: Install haskell-stack
@@ -42,13 +44,34 @@ jobs:
with:
enable-stack: true
stack-no-global: true
- stack-version: 'latest'
+ stack-version: "latest"
- name: Build nitta backend dependencies and doctest
run: |
stack build --haddock --test --only-dependencies
stack install doctest
+ nitta:
+ runs-on: ubuntu-latest
+ needs: nitta-haskell-deps
+ timeout-minutes: 45
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Cache haskell-stack
+ uses: actions/cache@v3.2.5
+ with:
+ path: |
+ ~/.stack
+ .stack-work
+ key: ${{ runner.os }}-haskell-stack-${{ hashFiles('**/stack.yaml', '**/package.yaml') }}
+
+ - name: Install haskell-stack
+ uses: haskell/actions/setup@v2.3.3
+ with:
+ enable-stack: true
+ stack-no-global: true
+ stack-version: "latest"
- name: Install Icarus Verilog
run: sudo apt-get install iverilog
@@ -59,9 +82,7 @@ jobs:
run: stack hpc report nitta
- name: Check examples by doctest
- run: |
- stack build
- find src -name '*.hs' -exec grep -l '>>>' {} \; | xargs -t -L 1 -P 4 stack exec doctest
+ run: find src -name '*.hs' -exec grep -l '>>>' {} \; | xargs -t -L 1 -P 4 stack exec doctest
- name: Generate backend API
run: stack exec nitta-api-gen -- -v
@@ -69,7 +90,7 @@ jobs:
- name: Cache node_modules
uses: actions/cache@v3.2.5
with:
- path: '**/node_modules'
+ path: "**/node_modules"
key: ${{ runner.os }}-modules-${{ hashFiles('**/yarn.lock') }}
- name: Build nitta frontend dependencies
@@ -88,7 +109,7 @@ jobs:
- name: Copy test coverage to GH_PAGES_DIR
run: cp -r $(stack path --local-hpc-root)/combined/custom ${{ env.GH_PAGES_DIR }}/hpc
-
+
- name: Copy API doc to GH_PAGES_DIR
run: |
mkdir -p "${{ env.GH_PAGES_DIR }}/rest-api/"
@@ -106,7 +127,6 @@ jobs:
verilog-formatting:
runs-on: ubuntu-latest
- needs: nitta
container: ryukzak/alpine-iverilog
defaults:
run:
@@ -120,43 +140,58 @@ jobs:
haskell-lint:
runs-on: ubuntu-latest
- needs: nitta
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v3
- - name: 'Set up HLint'
- uses: haskell/actions/hlint-setup@v2.3.3
- with:
- version: '3.5'
+ - name: "Set up HLint"
+ uses: haskell/actions/hlint-setup@v2.3.3
+ with:
+ version: "3.5"
- - name: 'Run HLint'
- uses: haskell/actions/hlint-run@v2.3.3
- with:
- path: .
- fail-on: suggestion
+ - name: "Run HLint"
+ uses: haskell/actions/hlint-run@v2.3.3
+ with:
+ path: .
+ fail-on: suggestion
haskell-formatting:
runs-on: ubuntu-latest
- needs: nitta
steps:
- uses: actions/checkout@v3
- - name: Cache haskell-stack and fourmolu
- uses: actions/cache@v3.2.5
- with:
- path: ~/.stack
- key: ${{ runner.os }}-haskell-stack-${{ hashFiles('**/stack.yaml', '**/package.yaml') }}-fourmolu
-
- name: Check formatting
- uses: fourmolu/fourmolu-action@v6
+ uses: fourmolu/fourmolu-action@v8 # fourmolu-0.12.0.0
typescript-formatting:
runs-on: ubuntu-latest
- needs: nitta
steps:
- uses: actions/checkout@v3
+
- name: Check ts and tsx code style by prettier
working-directory: ./web
run: |
yarn add -s prettier
yarn exec -s prettier -- --check src/**/*.{ts,tsx}
+
+ python-formatting:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.x"
+
+ - name: Install dependencies
+ run: |
+ # Flake8-pyproject lets flake8 read its settings from [tool.flake8] in pyproject.toml
+ pip install black isort flake8 Flake8-pyproject
+
+ - name: Check code style with black
+ run: black --check ml
+
+ - name: Check import order with isort
+ run: isort --check-only ml
+
+ - name: Check code style with flake8
+ run: flake8 ml
diff --git a/README.md b/README.md
index 087363145..735de3b49 100644
--- a/README.md
+++ b/README.md
@@ -44,12 +44,12 @@ See: [CONTRIBUTING.md](CONTRIBUTING.md)
Install [Stack](https://github.com/commercialhaskell/stack) and required developer tools for Haskell.
``` console
-$ brew install stack
-$ stack install hlint fourmolu
+$ brew install ghcup
+$ ghcup install stack
```
> Make sure that PATH contains $HOME/.local/bin.
-> Make sure that you have up to date version of hlint and fourmolu (as on CI)!
+> Make sure that you have up-to-date versions of hlint and fourmolu (as on CI)!
Install [icarus-verilog](https://github.com/steveicarus/iverilog/) and [gtkwave](https://github.com/gtkwave/gtkwave).
``` console
@@ -73,7 +73,7 @@ $ stack install hlint fourmolu
```
> Make sure that PATH contains $HOME/.local/bin.
-> Make sure that you have up to date version of hlint and fourmolu (as on CI)!
+> Make sure that you have up-to-date versions of hlint and fourmolu (as on CI)!
Install [icarus-verilog](https://github.com/steveicarus/iverilog/) and [gtkwave](https://github.com/gtkwave/gtkwave).
``` console
@@ -257,4 +257,3 @@ $ stack exec nitta -- examples/teacup.lua -v --lsim -t=fx24.32
$ stack exec nitta -- examples/teacup.lua -p=8080
Running NITTA server at http://localhost:8080 ...
```
-
diff --git a/app/APIGen.hs b/app/APIGen.hs
index f9b38ed02..4ff31cb9a 100644
--- a/app/APIGen.hs
+++ b/app/APIGen.hs
@@ -76,8 +76,8 @@ $(deriveTypeScript defaultOptions ''OptimizeAccumMetrics)
$(deriveTypeScript defaultOptions ''ResolveDeadlockMetrics)
$(deriveTypeScript defaultOptions ''ViewPointID)
-$(deriveTypeScript defaultOptions ''TimelinePoint)
$(deriveTypeScript defaultOptions ''Interval)
+$(deriveTypeScript defaultOptions ''TimelinePoint)
$(deriveTypeScript defaultOptions ''TimeConstraint)
$(deriveTypeScript defaultOptions ''TimelineWithViewPoint)
$(deriveTypeScript defaultOptions ''ProcessTimelines)
@@ -109,8 +109,8 @@ $(deriveTypeScript defaultOptions ''TestbenchReport)
-- Microarchitecture
$(deriveTypeScript defaultOptions ''IOSynchronization)
-$(deriveTypeScript defaultOptions ''NetworkDesc)
$(deriveTypeScript defaultOptions ''UnitDesc)
+$(deriveTypeScript defaultOptions ''NetworkDesc)
$(deriveTypeScript defaultOptions ''MicroarchitectureDesc)
main = do
@@ -186,8 +186,9 @@ main = do
(ts ++ "\n" ++ "type NId = string\n")
[ ("type ", "export type ") -- export all types
, ("interface ", "export interface ") -- export all interfaces
- , ("[k: T1]", "[k: string]") -- dirty hack for fixing map types for TestbenchReport
- , ("[k: T2]", "[k: string]") -- dirty hack for fixing map types for TestbenchReport
+ , ("[k in T1]?", "[k: string]") -- dirty hack for fixing map types for TestbenchReport
+ , ("[k in T2]?", "[k: string]") -- dirty hack for fixing map types for TestbenchReport
+ , ("[k in number]?: number", "[k: number]: number") -- dirty hack for fixing map types for TreeInfo
]
infoM "NITTA.APIGen" $ "Generate typescript interface " <> output_path <> "/types.ts...OK"
diff --git a/app/Main.hs b/app/Main.hs
index 80d49d020..3f11fd791 100644
--- a/app/Main.hs
+++ b/app/Main.hs
@@ -23,6 +23,7 @@ import Data.ByteString.Lazy.Char8 qualified as BS
import Data.Default (def)
import Data.Maybe
import Data.Proxy
+
import Data.String.Utils qualified as S
import Data.Text qualified as T
import Data.Text.IO qualified as T
@@ -171,6 +172,8 @@ getNittaArgs = do
exitWith exitCode
catch (cmdArgs nittaArgs) handleError
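+-- | Lookup a key in the parsed TOML config (if any); 'Nothing' if the config is absent or the key is missing.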
+fromConf toml s = getFromTomlSection s =<< toml
+
main = do
( Nitta
filename
@@ -197,7 +200,6 @@ main = do
Nothing -> return Nothing
Just path -> Just . getToml <$> T.readFile path
- let fromConf s = getFromTomlSection s =<< toml
let exactFrontendType = identifyFrontendType filename frontend_language
src <- readSourceCode filename
@@ -207,14 +209,15 @@ main = do
-- FIXME: https://nitta.io/nitta-corp/nitta/-/issues/50
-- data for sin_ident
received = [("u#0", map (\i -> read $ show $ sin ((2 :: Double) * 3.14 * 50 * 0.001 * i)) [0 .. toEnum n])]
- ioSync = fromJust $ io_sync <|> fromConf "ioSync" <|> Just Sync
+ ioSync = fromJust $ io_sync <|> fromConf toml "ioSync" <|> Just Sync
confMa = toml >>= Just . mkMicroarchitecture ioSync
+ ma :: BusNetwork T.Text T.Text (Attr (FX m b)) Int
ma
| auto_uarch && isJust confMa =
error $
"auto_uarch flag means that an empty uarch with default prototypes will be used. "
<> "Remove uarch flag or specify prototypes list in config file and remove auto_uarch."
- | auto_uarch = microarchWithProtos ioSync :: BusNetwork T.Text T.Text (Attr (FX m b)) Int
+ | auto_uarch = microarchWithProtos ioSync
| isJust confMa = fromJust confMa
| otherwise = defMicroarch ioSync
@@ -235,28 +238,30 @@ main = do
prj <-
synthesizeTargetSystem
- def
+ (def :: TargetSynthesis T.Text T.Text (Attr (FX m b)) Int)
{ tName = "main"
, tPath = output_path
, tMicroArch = ma
, tDFG = frDataFlow
, tReceivedValues = received
, tTemplates = S.split ":" templates
+ , tSynthesisMethod = stateOfTheArtSynthesisIO ()
, tSimulationCycleN = n
, tSourceCodeType = exactFrontendType
}
- >>= \case
- Left msg -> error msg
- Right p -> return p
+ >>= either error return
when lsim $ logicalSimulation format frPrettyLog prj
)
$ parseFX . fromJust
- $ type_ <|> fromConf "type" <|> Just "fx32.32"
+ $ type_ <|> fromConf toml "type" <|> Just "fx32.32"
parseFX input =
let typePattern = mkRegex "fx([0-9]+).([0-9]+)"
- [m, b] = fromMaybe (error "incorrect Bus type input") $ matchRegex typePattern input
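+ -- typePattern captures exactly two groups, so any other successful match is impossible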
+ (m, b) = case fromMaybe (error "incorrect Bus type input") $ matchRegex typePattern input of
+ [m_, b_] -> (m_, b_)
+ _ -> error "parseFX: impossible"
+
convert = fromJust . someNatVal . read
in (convert m, convert b)
@@ -281,6 +286,7 @@ readSourceCode filename = do
return src
-- | Simulation on intermediate level (data-flow graph)
+functionalSimulation :: (Val x, Var v) => Int -> [(v, [x])] -> String -> FrontendResult v x -> IO ()
functionalSimulation n received format FrontendResult{frDataFlow, frPrettyLog} = do
let cntx = simulateDataFlowGraph n def received frDataFlow
infoM "NITTA" "run functional simulation..."
diff --git a/fourmolu.yaml b/fourmolu.yaml
index 544b7216b..19345d625 100644
--- a/fourmolu.yaml
+++ b/fourmolu.yaml
@@ -6,3 +6,6 @@ diff-friendly-import-export: true # 'false' uses Ormolu-style lists
respectful: true # don't be too opinionated about newlines etc.
haddock-style: multi-line # '--' vs. '{-'
newlines-between-decls: 1 # number of newlines between top-level declarations
+
+# this option is available only since fourmolu-0.12.0.0, see https://github.com/ryukzak/nitta/issues/242
+single-constraint-parens: never # always | never | auto (preserve as is) - how to handle optional parentheses around a single constraint
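+# for example, with "never" a signature like `f :: (Var v) => v -> v` is rendered as `f :: Var v => v -> v`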
diff --git a/ml/synthesis/README.md b/ml/synthesis/README.md
index 0e1240b75..9109fb544 100644
--- a/ml/synthesis/README.md
+++ b/ml/synthesis/README.md
@@ -147,6 +147,29 @@ docker run \
-it \
nitta-dev
```
+### Script evaluation guide
+
+The `evaluation.py` script accepts several command-line arguments that customize and control the evaluation process:
+
+- `example_paths` (required): paths to the example files, separated by spaces.
+- `--evaluator` (optional): evaluation methods to use; specify one or more, separated by spaces. Allowed values: `nitta`, `ml`.
+- `--nitta_args` (optional): arguments passed to NITTA, e.g. `--nitta_args="--format=csv"`.
+- `--help`: print help about the available arguments.
+
+Example usage:
+```bash
+python3 ml/synthesis/src/scripts/evaluation.py examples/fibonacci.lua examples/counter.lua --evaluator nitta ml --nitta_args="--format=csv"
+
+Algorithm: examples/fibonacci.lua
+ duration depth evaluator_calls time
+nitta 5 8 9 0.906887
+ml 5 9 10 1.902110
+
+Algorithm: examples/counter.lua
+ duration depth evaluator_calls time
+nitta 5 8 9 5.687186
+ml 5 8 9 7.122119
+```
#### Linux
diff --git a/ml/synthesis/pyproject.toml b/ml/synthesis/pyproject.toml
new file mode 100644
index 000000000..9793af723
--- /dev/null
+++ b/ml/synthesis/pyproject.toml
@@ -0,0 +1,8 @@
+[tool.flake8]
+max-line-length = 125
+
+[tool.isort]
+profile = "black"
+
+[tool.black]
+line-length = 125
\ No newline at end of file
diff --git a/ml/synthesis/src/components/common/data_loading.py b/ml/synthesis/src/components/common/data_loading.py
index 36f6650aa..bfd9c4652 100644
--- a/ml/synthesis/src/components/common/data_loading.py
+++ b/ml/synthesis/src/components/common/data_loading.py
@@ -3,9 +3,8 @@
from typing import List
import pandas as pd
-from pandas import DataFrame
-
from consts import DATA_DIR
+from pandas import DataFrame
def load_all_existing_training_data(data_dir: Path = DATA_DIR) -> DataFrame:
diff --git a/ml/synthesis/src/components/common/logging.py b/ml/synthesis/src/components/common/logging.py
index 50f567170..99d289083 100644
--- a/ml/synthesis/src/components/common/logging.py
+++ b/ml/synthesis/src/components/common/logging.py
@@ -3,7 +3,7 @@
def get_logger(module_name: str) -> Logger:
- """ Fixes stdlib's logging function case and create a unified logger factory. """
+ """Fixes stdlib's logging function case and create a unified logger factory."""
if module_name == "__main__":
return logging.getLogger() # root logger
return logging.getLogger(module_name)
diff --git a/ml/synthesis/src/components/common/model_loading.py b/ml/synthesis/src/components/common/model_loading.py
index 34bfb62cd..5d910e1b9 100644
--- a/ml/synthesis/src/components/common/model_loading.py
+++ b/ml/synthesis/src/components/common/model_loading.py
@@ -3,7 +3,6 @@
from typing import Tuple
import tensorflow as tf
-
from components.common.logging import get_logger
from components.model_generation.model_metainfo import ModelMetainfo
diff --git a/ml/synthesis/src/components/common/port_management.py b/ml/synthesis/src/components/common/port_management.py
index 490f189c4..6871ed30b 100644
--- a/ml/synthesis/src/components/common/port_management.py
+++ b/ml/synthesis/src/components/common/port_management.py
@@ -12,7 +12,7 @@
def is_port_in_use(port):
logger.debug(f"Finding out if port {port} is in use...")
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- result = s.connect_ex(('localhost', port)) == 0
+ result = s.connect_ex(("localhost", port)) == 0
logger.debug(f"Port {port} {'is in use' if result else 'is free'}")
return result
diff --git a/ml/synthesis/src/components/data_crawling/example_running.py b/ml/synthesis/src/components/data_crawling/example_running.py
index f22a11747..488cbdd27 100644
--- a/ml/synthesis/src/components/data_crawling/example_running.py
+++ b/ml/synthesis/src/components/data_crawling/example_running.py
@@ -1,23 +1,20 @@
import asyncio
import os
import pickle
-import random
import signal
import sys
from asyncio import sleep
from contextlib import asynccontextmanager
from pathlib import Path
-from typing import AsyncGenerator
-from typing import Tuple, List
+from typing import AsyncGenerator, List, Tuple
import pandas as pd
-from joblib import Parallel, delayed
-
from components.common.logging import get_logger
-from components.common.port_management import is_port_in_use, find_random_free_port
-from components.data_crawling.tree_processing import assemble_tree_dataframe
+from components.common.port_management import find_random_free_port
from components.data_crawling.tree_retrieving import retrieve_whole_nitta_tree
+from components.data_crawling.tree_processing import assemble_tree_dataframe
from consts import DATA_DIR, ROOT_DIR
+from joblib import Parallel, delayed
logger = get_logger(__name__)
@@ -25,12 +22,13 @@
@asynccontextmanager
-async def run_nitta(example: Path,
- nitta_exe_path: str = "stack exec nitta -- ",
- nitta_args: str = "",
- nitta_env: dict = None,
- port: int = None
- ) -> AsyncGenerator[Tuple[asyncio.subprocess.Process, str], None]:
+async def run_nitta(
+ example: Path,
+ nitta_exe_path: str = "stack exec nitta -- ",
+ nitta_args: str = "",
+ nitta_env: dict = None,
+ port: int = None,
+) -> AsyncGenerator[Tuple[asyncio.subprocess.Process, str], None]:
if port is None:
port = find_random_free_port()
@@ -46,8 +44,13 @@ async def run_nitta(example: Path,
try:
preexec_fn = None if os.name == "nt" else os.setsid # see https://stackoverflow.com/a/4791612
proc = await asyncio.create_subprocess_shell(
- cmd, cwd=str(ROOT_DIR), stdout=sys.stdout, stderr=sys.stderr, shell=True,
- preexec_fn=preexec_fn, env=env
+ cmd,
+ cwd=str(ROOT_DIR),
+ stdout=sys.stdout,
+ stderr=sys.stderr,
+ shell=True,
+ preexec_fn=preexec_fn,
+ env=env,
)
logger.info(f"NITTA has been launched, PID {proc.pid}. Waiting for {_NITTA_START_WAIT_DELAY_S} secs.")
@@ -63,12 +66,14 @@ async def run_nitta(example: Path,
await proc.wait()
-async def run_example_and_retrieve_tree_data(example: Path,
- data_dir: Path = DATA_DIR,
- nitta_exe_path: str = "stack exec nitta -- ") -> pd.DataFrame:
+async def run_example_and_retrieve_tree_data(
+ example: Path,
+ data_dir: Path = DATA_DIR,
+ nitta_exe_path: str = "stack exec nitta -- ",
+) -> pd.DataFrame:
example_name = os.path.basename(example)
async with run_nitta(example, nitta_exe_path) as (proc, nitta_baseurl):
- logger.info(f"Retrieving tree.")
+ logger.info("Retrieving tree.")
tree = await retrieve_whole_nitta_tree(nitta_baseurl)
data_dir.mkdir(exist_ok=True)
diff --git a/ml/synthesis/src/components/data_crawling/nitta_node.py b/ml/synthesis/src/components/data_crawling/nitta_node.py
index 6ddcf8805..7dbd4c657 100644
--- a/ml/synthesis/src/components/data_crawling/nitta_node.py
+++ b/ml/synthesis/src/components/data_crawling/nitta_node.py
@@ -1,6 +1,6 @@
from collections import deque
from dataclasses import dataclass, field
-from typing import Optional, Any, List, Tuple, Deque
+from typing import Any, Deque, List, Optional, Tuple
import numpy as np
import pandas as pd
@@ -41,8 +41,8 @@ class NittaNode:
duration: Optional[int]
sid: str
- children: Optional[List['NittaNode']] = field(default=None, repr=False)
- parent: Optional['NittaNode'] = field(default=None, repr=False)
+ children: Optional[List["NittaNode"]] = field(default=None, repr=False)
+ parent: Optional["NittaNode"] = field(default=None, repr=False)
def __hash__(self):
return hash(self.sid)
@@ -58,18 +58,19 @@ def subtree_size(self):
@cached_property
def depth(self) -> int:
- return self.sid.count('-') if self.sid != '-' else 0
+ return self.sid.count("-") if self.sid != "-" else 0
@cached_property
def subtree_leafs_metrics(self) -> Optional[Deque[Tuple[int, int]]]:
- """ :returns: deque(tuple(duration, depth)) or None if node is a failed leaf """
+ """:returns: deque(tuple(duration, depth)) or None if node is a failed leaf"""
if self.is_leaf:
if not self.is_finish:
return None
return deque(((self.duration, self.depth),))
else:
- children_metrics = \
- (child.subtree_leafs_metrics for child in self.children if child.subtree_leafs_metrics is not None)
+ children_metrics = (
+ child.subtree_leafs_metrics for child in self.children if child.subtree_leafs_metrics is not None
+ )
return sum(children_metrics, deque())
@cached_node_method
@@ -77,7 +78,10 @@ def get_subtree_leafs_labels(self, metrics_distrib: np.ndarray) -> deque:
if self.is_leaf:
return deque((self.compute_label(metrics_distrib),))
else:
- return sum((child.get_subtree_leafs_labels(metrics_distrib) for child in self.children), deque())
+ return sum(
+ (child.get_subtree_leafs_labels(metrics_distrib) for child in self.children),
+ deque(),
+ )
@cached_node_method
def compute_label(self, metrics_distrib: np.ndarray) -> float:
diff --git a/ml/synthesis/src/components/data_crawling/tree_processing.py b/ml/synthesis/src/components/data_crawling/tree_processing.py
index cba0a6bb4..734bc72aa 100644
--- a/ml/synthesis/src/components/data_crawling/tree_processing.py
+++ b/ml/synthesis/src/components/data_crawling/tree_processing.py
@@ -1,13 +1,13 @@
from __future__ import annotations
+
from collections import deque
from typing import Deque, Optional
import numpy as np
import pandas as pd
-from cachetools import cached, Cache
-from joblib import Parallel, delayed
-
+from cachetools import Cache, cached
from components.data_crawling.nitta_node import NittaNode
+from joblib import Parallel, delayed
def _extract_params_dict(node: NittaNode) -> dict:
@@ -36,13 +36,19 @@ def _extract_alternative_siblings_dict(node: NittaNode, siblings: tuple[NittaNod
else:
refactorings += 1
- return dict(alt_bindings=bindings,
- alt_refactorings=refactorings,
- alt_dataflows=dataflows)
+ return dict(
+ alt_bindings=bindings,
+ alt_refactorings=refactorings,
+ alt_dataflows=dataflows,
+ )
@cached(cache=Cache(10000))
-def nitta_node_to_df_dict(node: NittaNode, siblings: tuple[NittaNode], example: str = None, ) -> dict:
+def nitta_node_to_df_dict(
+ node: NittaNode,
+ siblings: tuple[NittaNode],
+ example: str = None,
+) -> dict:
return dict(
example=example,
sid=node.sid,
@@ -54,8 +60,14 @@ def nitta_node_to_df_dict(node: NittaNode, siblings: tuple[NittaNode], example:
)
-def assemble_tree_dataframe(example: str, node: NittaNode, metrics_distrib=None, include_label=True,
- levels_left=None, n_workers: int = 1) -> pd.DataFrame:
+def assemble_tree_dataframe(
+ example: str,
+ node: NittaNode,
+ metrics_distrib=None,
+ include_label=True,
+ levels_left=None,
+ n_workers: int = 1,
+) -> pd.DataFrame:
if include_label and metrics_distrib is None:
metrics_distrib = np.array(node.subtree_leafs_metrics)
@@ -72,8 +84,14 @@ def child_process_job(node: NittaNode):
return pd.DataFrame(sum(deques, deque()))
-def _assemble_tree_dataframe_recursion(accum: Deque[dict], example: str, node: NittaNode, metrics_distrib: np.ndarray,
- include_label: bool, levels_left: Optional[int]):
+def _assemble_tree_dataframe_recursion(
+ accum: Deque[dict],
+ example: str,
+ node: NittaNode,
+ metrics_distrib: np.ndarray,
+ include_label: bool,
+ levels_left: Optional[int],
+):
siblings = (node.parent.children or []) if node.parent else []
self_dict = nitta_node_to_df_dict(node, tuple(siblings), example)
@@ -85,7 +103,13 @@ def _assemble_tree_dataframe_recursion(accum: Deque[dict], example: str, node: N
else:
levels_left_for_child = None if levels_left is None else levels_left - 1
for child in node.children:
- _assemble_tree_dataframe_recursion(accum, example, child, metrics_distrib, include_label,
- levels_left_for_child)
+ _assemble_tree_dataframe_recursion(
+ accum,
+ example,
+ child,
+ metrics_distrib,
+ include_label,
+ levels_left_for_child,
+ )
if node.sid != "-":
accum.appendleft(self_dict) # so it's from roots to leaves
diff --git a/ml/synthesis/src/components/data_crawling/tree_retrieving.py b/ml/synthesis/src/components/data_crawling/tree_retrieving.py
index ed4ec9a41..cc7169334 100644
--- a/ml/synthesis/src/components/data_crawling/tree_retrieving.py
+++ b/ml/synthesis/src/components/data_crawling/tree_retrieving.py
@@ -2,7 +2,6 @@
import time
from aiohttp import ClientSession
-
from components.common.logging import get_logger
from components.common.utils import debounce
from components.data_crawling.nitta_node import NittaNode
@@ -11,7 +10,12 @@
logger_debug_debounced = debounce(1)(logger.debug)
-async def retrieve_subforest(node: NittaNode, session: ClientSession, nitta_baseurl: str, levels_left=None):
+async def retrieve_subforest(
+ node: NittaNode,
+ session: ClientSession,
+ nitta_baseurl: str,
+ levels_left=None,
+):
node.children = []
if node.is_leaf or levels_left == -1:
return
@@ -35,7 +39,7 @@ async def retrieve_subforest(node: NittaNode, session: ClientSession, nitta_base
async def retrieve_whole_nitta_tree(nitta_baseurl: str, max_depth=None) -> NittaNode:
start_time = time.perf_counter()
async with ClientSession() as session:
- async with session.get(nitta_baseurl + f"/node/-") as resp:
+ async with session.get(nitta_baseurl + "/node/-") as resp:
root_raw = await resp.json()
root = NittaNode.from_dict(root_raw)
await retrieve_subforest(root, session, nitta_baseurl, max_depth)
diff --git a/ml/synthesis/src/components/data_processing/dataset_creation.py b/ml/synthesis/src/components/data_processing/dataset_creation.py
index 367c5ff03..f84a14a67 100644
--- a/ml/synthesis/src/components/data_processing/dataset_creation.py
+++ b/ml/synthesis/src/components/data_processing/dataset_creation.py
@@ -1,10 +1,9 @@
from typing import Tuple
+from components.common.logging import get_logger
from sklearn.model_selection import train_test_split
from tensorflow.python.data import Dataset
-from components.common.logging import get_logger
-
logger = get_logger(__name__)
TARGET_COLUMNS = ["label"]
diff --git a/ml/synthesis/src/components/data_processing/feature_engineering.py b/ml/synthesis/src/components/data_processing/feature_engineering.py
index fadea6ecd..a00050b83 100644
--- a/ml/synthesis/src/components/data_processing/feature_engineering.py
+++ b/ml/synthesis/src/components/data_processing/feature_engineering.py
@@ -1,9 +1,8 @@
from __future__ import annotations
-import pandas as pd
-
-from pandas import DataFrame
+import pandas as pd
from components.common.logging import get_logger
+from pandas import DataFrame
logger = get_logger(__name__)
@@ -19,25 +18,56 @@ def _map_categorical(df, c):
def preprocess_df(df: DataFrame) -> DataFrame:
df: DataFrame = df.copy()
- for bool_column in ["is_leaf", "pCritical", "pPossibleDeadlock", "pRestrictedTime"]:
+ for bool_column in [
+ "is_leaf",
+ "pCritical",
+ "pPossibleDeadlock",
+ "pRestrictedTime",
+ ]:
if bool_column in df.columns:
df[bool_column] = _map_bool(df[bool_column])
else:
logger.warning(f"Column/parameter {bool_column} not found in provided node info.")
df = _map_categorical(df, df.tag)
- df = df.drop(["pWave", "example", "sid", "old_score", "is_leaf", "pRefactoringType"], axis="columns", errors="ignore")
+ df = df.drop(
+ [
+ "pWave",
+ "example",
+ "sid",
+ "old_score",
+ "is_leaf",
+ "pRefactoringType",
+ ],
+ axis="columns",
+ errors="ignore",
+ )
df = df.fillna(0)
return df
# TODO: move that to metainfo of the model, find a way to make input building model-dependent
# (pickled module? function name?)
-_BASELINE_MODEL_COLUMNS = \
- ["alt_bindings", "alt_refactorings", "alt_dataflows", "pAllowDataFlow", "pAlternative", "pCritical",
- "pNumberOfBindedFunctions", "pOutputNumber", "pPercentOfBindedInputs", "pPossibleDeadlock", "pRestless",
- "pFirstWaveOfTargetUse", "pNotTransferableInputs", "pRestrictedTime", "pWaitTime", "tag_BindDecisionView",
- "tag_BreakLoopView", "tag_DataflowDecisionView"]
+_BASELINE_MODEL_COLUMNS = [
+ "alt_bindings",
+ "alt_refactorings",
+ "alt_dataflows",
+ "pAllowDataFlow",
+ "pAlternative",
+ "pCritical",
+ "pNumberOfBindedFunctions",
+ "pOutputNumber",
+ "pPercentOfBindedInputs",
+ "pPossibleDeadlock",
+ "pRestless",
+ "pFirstWaveOfTargetUse",
+ "pNotTransferableInputs",
+ "pRestrictedTime",
+ "pWaitTime",
+ "tag_BindDecisionView",
+ "tag_BreakLoopView",
+ "tag_DataflowDecisionView",
+]
def df_to_model_columns(df: DataFrame, model_columns: list[str] = None) -> DataFrame:
diff --git a/ml/synthesis/src/components/model_generation/models.py b/ml/synthesis/src/components/model_generation/models.py
index 4e1c016e2..cc88b8ddc 100644
--- a/ml/synthesis/src/components/model_generation/models.py
+++ b/ml/synthesis/src/components/model_generation/models.py
@@ -3,15 +3,17 @@
def create_baseline_model(input_shape) -> tf.keras.Model:
- model = tf.keras.Sequential([
- layers.InputLayer(input_shape=input_shape),
- layers.Dense(128, activation="relu", kernel_regularizer="l2"),
- layers.Dense(128, activation="relu", kernel_regularizer="l2"),
- layers.Dense(64, activation="relu", kernel_regularizer="l2"),
- layers.Dense(64, activation="relu", kernel_regularizer="l2"),
- layers.Dense(32, activation="relu"),
- layers.Dense(1)
- ])
+ model = tf.keras.Sequential(
+ [
+ layers.InputLayer(input_shape=input_shape),
+ layers.Dense(128, activation="relu", kernel_regularizer="l2"),
+ layers.Dense(128, activation="relu", kernel_regularizer="l2"),
+ layers.Dense(64, activation="relu", kernel_regularizer="l2"),
+ layers.Dense(64, activation="relu", kernel_regularizer="l2"),
+ layers.Dense(32, activation="relu"),
+ layers.Dense(1),
+ ]
+ )
model.compile(
optimizer=tf.keras.optimizers.Adam(lr=3e-4),
diff --git a/ml/synthesis/src/components/model_generation/training.py b/ml/synthesis/src/components/model_generation/training.py
index dd6c3c26f..ba04bf8de 100644
--- a/ml/synthesis/src/components/model_generation/training.py
+++ b/ml/synthesis/src/components/model_generation/training.py
@@ -2,20 +2,23 @@
from time import strftime
from typing import Tuple
-from tensorflow.python.data import Dataset
-from tensorflow.python.keras.models import Model
-
from components.common.logging import get_logger
from components.model_generation.model_metainfo import ModelMetainfo
from components.model_generation.models import create_baseline_model
from consts import MODELS_DIR
+from tensorflow.python.data import Dataset
+from tensorflow.python.keras.models import Model
logger = get_logger(__name__)
-def train_and_save_baseline_model(train_ds: Dataset, val_ds: Dataset, fitting_kwargs: dict = None,
- output_model_name: str = None, models_dir: Path = MODELS_DIR) \
- -> Tuple[Model, ModelMetainfo]:
+def train_and_save_baseline_model(
+ train_ds: Dataset,
+ val_ds: Dataset,
+ fitting_kwargs: dict = None,
+ output_model_name: str = None,
+ models_dir: Path = MODELS_DIR,
+) -> Tuple[Model, ModelMetainfo]:
models_dir.mkdir(exist_ok=True)
sample = next(iter(val_ds))[0][0]
@@ -31,7 +34,10 @@ def train_and_save_baseline_model(train_ds: Dataset, val_ds: Dataset, fitting_kw
results = model.fit(x=train_ds, validation_data=val_ds, **effective_fitting_kwargs)
# TODO: proper model evaluation on an independent dataset
- metainfo = ModelMetainfo(train_mae=results.history["mae"][-1], validation_mae=results.history["val_mae"][-1])
+ metainfo = ModelMetainfo(
+ train_mae=results.history["mae"][-1],
+ validation_mae=results.history["val_mae"][-1],
+ )
if not output_model_name:
output_model_name = f"model-{strftime('%Y%m%d-%H%M%S')}"
diff --git a/ml/synthesis/src/components/utils/string.py b/ml/synthesis/src/components/utils/string.py
index 5a1739305..651c58260 100644
--- a/ml/synthesis/src/components/utils/string.py
+++ b/ml/synthesis/src/components/utils/string.py
@@ -1,5 +1,5 @@
def snake_to_lower_camel_case(snake_str: str):
- components = [c for c in snake_str.strip().split('_') if c]
+ components = [c for c in snake_str.strip().split("_") if c]
if not components:
return ""
return components[0] + "".join(x[0].title() + x[1:] for x in components[1:])
diff --git a/ml/synthesis/src/mlbackend/__main__.py b/ml/synthesis/src/mlbackend/__main__.py
index a7f49f305..b04f19dc3 100644
--- a/ml/synthesis/src/mlbackend/__main__.py
+++ b/ml/synthesis/src/mlbackend/__main__.py
@@ -1,14 +1,19 @@
import uvicorn
-
-from components.common.logging import get_logger, configure_logging
+from components.common.logging import configure_logging, get_logger
from consts import ML_BACKEND_BASE_URL_FILEPATH
-from mlbackend.app import app
from mlbackend.backend_base_url_file import BackendBaseUrlFile
logger = get_logger(__name__)
configure_logging()
-with BackendBaseUrlFile(filepath=ML_BACKEND_BASE_URL_FILEPATH,
- base_url_fmt="http://127.0.0.1:{port}") as base_url_file:
+with BackendBaseUrlFile(
+ filepath=ML_BACKEND_BASE_URL_FILEPATH,
+ base_url_fmt="http://127.0.0.1:{port}",
+) as base_url_file:
logger.info(f"Starting ML backend server on port {base_url_file.port}")
- uvicorn.run("mlbackend.app:app", host="127.0.0.1", port=base_url_file.port, workers=4)
+ uvicorn.run(
+ "mlbackend.app:app",
+ host="127.0.0.1",
+ port=base_url_file.port,
+ workers=4,
+ )
diff --git a/ml/synthesis/src/mlbackend/app.py b/ml/synthesis/src/mlbackend/app.py
index 39c3422ba..95c8cfc1a 100644
--- a/ml/synthesis/src/mlbackend/app.py
+++ b/ml/synthesis/src/mlbackend/app.py
@@ -1,18 +1,17 @@
from __future__ import annotations
+
from http import HTTPStatus
-import numpy as np
import pandas as pd
+from components.data_crawling.nitta_node import NittaNode
+from components.data_crawling.tree_processing import nitta_node_to_df_dict
+from components.data_processing.feature_engineering import df_to_model_columns, preprocess_df
from fastapi import FastAPI, HTTPException
from fastapi.exception_handlers import http_exception_handler
+from mlbackend.dtos import ModelInfo, PostScoreRequestBody, PostScoreResponseData, Response
+from mlbackend.models_store import ModelNotFoundError, models
from starlette.responses import HTMLResponse
-from components.data_crawling.nitta_node import NittaNode
-from components.data_crawling.tree_processing import nitta_node_to_df_dict
-from components.data_processing.feature_engineering import preprocess_df, df_to_model_columns
-from mlbackend.dtos import Response, ModelInfo, PostScoreRequestBody, PostScoreResponseData
-from mlbackend.models_store import models, ModelNotFoundError
-
app = FastAPI(
title="NITTA ML Backend",
version="0.0.0",
@@ -25,12 +24,18 @@
@app.get("/models/{model_name}")
def get_model_info(model_name: str) -> Response[ModelInfo]:
model, meta = models[model_name]
- return Response(data=ModelInfo(name=model_name, train_mae=meta.train_mae, validation_mae=meta.validation_mae))
+ return Response(
+ data=ModelInfo(
+ name=model_name,
+ train_mae=meta.train_mae,
+ validation_mae=meta.validation_mae,
+ )
+ )
@app.post("/models/{model_name}/score")
def score_with_model(model_name: str, body: PostScoreRequestBody) -> Response[PostScoreResponseData]:
- """ Runs score prediction with model of given name for each input in a given list of inputs. """
+ """Runs score prediction with model of given name for each input in a given list of inputs."""
model, meta = models[model_name]
scores = []
@@ -45,21 +50,29 @@ def score_with_model(model_name: str, body: PostScoreRequestBody) -> Response[Po
siblings.append(node)
if not target_nodes:
- raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail="No target node(s) were found")
+ raise HTTPException(
+ status_code=HTTPStatus.BAD_REQUEST,
+ detail="No target node(s) were found",
+ )
df = pd.DataFrame([nitta_node_to_df_dict(target_node, siblings=tuple(nodes)) for target_node in target_nodes])
df = preprocess_df(df)
df = df_to_model_columns(df)
scores.append(model.predict(df.values).reshape(-1).tolist())
- return Response(data=PostScoreResponseData(
- scores=scores,
- ))
+ return Response(
+ data=PostScoreResponseData(
+ scores=scores,
+ )
+ )
@app.exception_handler(ModelNotFoundError)
async def model_not_found_exception_handler(request, exc):
- return await http_exception_handler(request, HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(exc)))
+ return await http_exception_handler(
+ request,
+ HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(exc)),
+ )
@app.get("/docs", response_class=HTMLResponse, include_in_schema=False)
@@ -69,19 +82,19 @@ def get_rapidoc_docs():
-
-
-
+