diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9dad4ca43..a72246821 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,29 +12,31 @@ jobs: notify-about-push: runs-on: ubuntu-latest steps: - - name: Send msg on push - uses: appleboy/telegram-action@master - with: - to: ${{ secrets.TELEGRAM_TO }} - token: ${{ secrets.TELEGRAM_TOKEN }} - format: html - args: | + - name: Send msg on push + uses: appleboy/telegram-action@master + with: + to: ${{ secrets.TELEGRAM_TO }} + token: ${{ secrets.TELEGRAM_TOKEN }} + format: html + args: | Repository: ${{ github.event.repository.full_name }} Ref: ${{ github.ref }} Event: ${{ github.event_name }} Info: ${{ github.event.pull_request.title }} ${{ github.event.pull_request.html_url }} - nitta: + nitta-haskell-deps: runs-on: ubuntu-latest - timeout-minutes: 60 + timeout-minutes: 45 steps: - uses: actions/checkout@v3 - name: Cache haskell-stack uses: actions/cache@v3.2.5 with: - path: ~/.stack + path: | + ~/.stack + .stack-work key: ${{ runner.os }}-haskell-stack-${{ hashFiles('**/stack.yaml', '**/package.yaml') }} - name: Install haskell-stack @@ -42,13 +44,34 @@ jobs: with: enable-stack: true stack-no-global: true - stack-version: 'latest' + stack-version: "latest" - name: Build nitta backend dependencies and doctest run: | stack build --haddock --test --only-dependencies stack install doctest + nitta: + runs-on: ubuntu-latest + needs: nitta-haskell-deps + timeout-minutes: 45 + steps: + - uses: actions/checkout@v3 + + - name: Cache haskell-stack + uses: actions/cache@v3.2.5 + with: + path: | + ~/.stack + .stack-work + key: ${{ runner.os }}-haskell-stack-${{ hashFiles('**/stack.yaml', '**/package.yaml') }} + + - name: Install haskell-stack + uses: haskell/actions/setup@v2.3.3 + with: + enable-stack: true + stack-no-global: true + stack-version: "latest" - name: Install Icarus Verilog run: sudo apt-get install iverilog @@ -59,9 +82,7 @@ jobs: run: stack hpc report nitta - name: Check examples by doctest - run: | - stack build - find src -name '*.hs' -exec grep -l '>>>' {} \; | xargs -t -L 1 -P 4 stack exec doctest + run: find src -name '*.hs' -exec grep -l '>>>' {} \; | xargs -t -L 1 -P 4 stack exec doctest - name: Generate backend API run: stack exec nitta-api-gen -- -v @@ -69,7 +90,7 @@ jobs: - name: Cache node_modules uses: actions/cache@v3.2.5 with: - path: '**/node_modules' + path: "**/node_modules" key: ${{ runner.os }}-modules-${{ hashFiles('**/yarn.lock') }} - name: Build nitta frontend dependencies @@ -88,7 +109,7 @@ jobs: - name: Copy test coverage to GH_PAGES_DIR run: cp -r $(stack path --local-hpc-root)/combined/custom ${{ env.GH_PAGES_DIR }}/hpc - + - name: Copy API doc to GH_PAGES_DIR run: | mkdir -p "${{ env.GH_PAGES_DIR }}/rest-api/" @@ -106,7 +127,6 @@ jobs: verilog-formatting: runs-on: ubuntu-latest - needs: nitta container: ryukzak/alpine-iverilog defaults: run: @@ -120,43 +140,58 @@ jobs: haskell-lint: runs-on: ubuntu-latest - needs: nitta steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v3 - - name: 'Set up HLint' - uses: haskell/actions/hlint-setup@v2.3.3 - with: - version: '3.5' + - name: "Set up HLint" + uses: haskell/actions/hlint-setup@v2.3.3 + with: + version: "3.5" - - name: 'Run HLint' - uses: haskell/actions/hlint-run@v2.3.3 - with: - path: . - fail-on: suggestion + - name: "Run HLint" + uses: haskell/actions/hlint-run@v2.3.3 + with: + path: . 
+ fail-on: suggestion haskell-formatting: runs-on: ubuntu-latest - needs: nitta steps: - uses: actions/checkout@v3 - - name: Cache haskell-stack and fourmolu - uses: actions/cache@v3.2.5 - with: - path: ~/.stack - key: ${{ runner.os }}-haskell-stack-${{ hashFiles('**/stack.yaml', '**/package.yaml') }}-fourmolu - - name: Check formatting - uses: fourmolu/fourmolu-action@v6 + uses: fourmolu/fourmolu-action@v8 # fourmolu-0.12.0.0 typescript-formatting: runs-on: ubuntu-latest - needs: nitta steps: - uses: actions/checkout@v3 + - name: Check ts and tsx code style by prettier working-directory: ./web run: | yarn add -s prettier yarn exec -s prettier -- --check src/**/*.{ts,tsx} + + python-formatting: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: "3.x" + + - name: Install dependencies + run: | + pip install black isort flake8 + + - name: Check code style with black + run: black --check ml + + - name: Sort imports with isort + run: isort --recursive ml + + - name: Check code style with flake8 + run: flake8 ml diff --git a/README.md b/README.md index 087363145..735de3b49 100644 --- a/README.md +++ b/README.md @@ -44,12 +44,12 @@ See: [CONTRIBUTING.md](CONTRIBUTING.md) Install [Stack](https://github.com/commercialhaskell/stack) and required developer tools for Haskell. ``` console -$ brew install stack -$ stack install hlint fourmolu +$ brew install ghcup +$ ghcup install stack ``` > Make sure that PATH contains $HOME/.local/bin. -> Make sure that you have up to date version of hlint and fourmolu (as on CI)! +> Make sure that you have up to date version of hlint and fourmolu (as on CI)! Install [icarus-verilog](https://github.com/steveicarus/iverilog/) and [gtkwave](https://github.com/gtkwave/gtkwave). ``` console @@ -73,7 +73,7 @@ $ stack install hlint fourmolu ``` > Make sure that PATH contains $HOME/.local/bin. -> Make sure that you have up to date version of hlint and fourmolu (as on CI)! +> Make sure that you have up to date version of hlint and fourmolu (as on CI)! Install [icarus-verilog](https://github.com/steveicarus/iverilog/) and [gtkwave](https://github.com/gtkwave/gtkwave). ``` console @@ -257,4 +257,3 @@ $ stack exec nitta -- examples/teacup.lua -v --lsim -t=fx24.32 $ stack exec nitta -- examples/teacup.lua -p=8080 Running NITTA server at http://localhost:8080 ... 
``` - diff --git a/app/APIGen.hs b/app/APIGen.hs index f9b38ed02..4ff31cb9a 100644 --- a/app/APIGen.hs +++ b/app/APIGen.hs @@ -76,8 +76,8 @@ $(deriveTypeScript defaultOptions ''OptimizeAccumMetrics) $(deriveTypeScript defaultOptions ''ResolveDeadlockMetrics) $(deriveTypeScript defaultOptions ''ViewPointID) -$(deriveTypeScript defaultOptions ''TimelinePoint) $(deriveTypeScript defaultOptions ''Interval) +$(deriveTypeScript defaultOptions ''TimelinePoint) $(deriveTypeScript defaultOptions ''TimeConstraint) $(deriveTypeScript defaultOptions ''TimelineWithViewPoint) $(deriveTypeScript defaultOptions ''ProcessTimelines) @@ -109,8 +109,8 @@ $(deriveTypeScript defaultOptions ''TestbenchReport) -- Microarchitecture $(deriveTypeScript defaultOptions ''IOSynchronization) -$(deriveTypeScript defaultOptions ''NetworkDesc) $(deriveTypeScript defaultOptions ''UnitDesc) +$(deriveTypeScript defaultOptions ''NetworkDesc) $(deriveTypeScript defaultOptions ''MicroarchitectureDesc) main = do @@ -186,8 +186,9 @@ main = do (ts ++ "\n" ++ "type NId = string\n") [ ("type ", "export type ") -- export all types , ("interface ", "export interface ") -- export all interfaces - , ("[k: T1]", "[k: string]") -- dirty hack for fixing map types for TestbenchReport - , ("[k: T2]", "[k: string]") -- dirty hack for fixing map types for TestbenchReport + , ("[k in T1]?", "[k: string]") -- dirty hack for fixing map types for TestbenchReport + , ("[k in T2]?", "[k: string]") -- dirty hack for fixing map types for TestbenchReport + , ("[k in number]?: number", "[k: number]: number") -- dirty hack for fixing map types for TreeInfo ] infoM "NITTA.APIGen" $ "Generate typescript interface " <> output_path <> "/types.ts...OK" diff --git a/app/Main.hs b/app/Main.hs index 80d49d020..3f11fd791 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -23,6 +23,7 @@ import Data.ByteString.Lazy.Char8 qualified as BS import Data.Default (def) import Data.Maybe import Data.Proxy + import Data.String.Utils qualified as S import Data.Text qualified as T import Data.Text.IO qualified as T @@ -171,6 +172,8 @@ getNittaArgs = do exitWith exitCode catch (cmdArgs nittaArgs) handleError +fromConf toml s = getFromTomlSection s =<< toml + main = do ( Nitta filename @@ -197,7 +200,6 @@ main = do Nothing -> return Nothing Just path -> Just . getToml <$> T.readFile path - let fromConf s = getFromTomlSection s =<< toml let exactFrontendType = identifyFrontendType filename frontend_language src <- readSourceCode filename @@ -207,14 +209,15 @@ main = do -- FIXME: https://nitta.io/nitta-corp/nitta/-/issues/50 -- data for sin_ident received = [("u#0", map (\i -> read $ show $ sin ((2 :: Double) * 3.14 * 50 * 0.001 * i)) [0 .. toEnum n])] - ioSync = fromJust $ io_sync <|> fromConf "ioSync" <|> Just Sync + ioSync = fromJust $ io_sync <|> fromConf toml "ioSync" <|> Just Sync confMa = toml >>= Just . mkMicroarchitecture ioSync + ma :: BusNetwork T.Text T.Text (Attr (FX m b)) Int ma | auto_uarch && isJust confMa = error $ "auto_uarch flag means that an empty uarch with default prototypes will be used. " <> "Remove uarch flag or specify prototypes list in config file and remove auto_uarch." 
-                | auto_uarch = microarchWithProtos ioSync :: BusNetwork T.Text T.Text (Attr (FX m b)) Int
+                | auto_uarch = microarchWithProtos ioSync
                 | isJust confMa = fromJust confMa
                 | otherwise = defMicroarch ioSync
@@ -235,28 +238,30 @@ main = do
             prj <-
                 synthesizeTargetSystem
-                    def
+                    (def :: TargetSynthesis T.Text T.Text (Attr (FX m b)) Int)
                         { tName = "main"
                         , tPath = output_path
                         , tMicroArch = ma
                         , tDFG = frDataFlow
                         , tReceivedValues = received
                         , tTemplates = S.split ":" templates
+                        , tSynthesisMethod = stateOfTheArtSynthesisIO ()
                         , tSimulationCycleN = n
                         , tSourceCodeType = exactFrontendType
                         }
-                    >>= \case
-                        Left msg -> error msg
-                        Right p -> return p
+                    >>= either error return

             when lsim $ logicalSimulation format frPrettyLog prj
         )
         $ parseFX . fromJust
-        $ type_ <|> fromConf "type" <|> Just "fx32.32"
+        $ type_ <|> fromConf toml "type" <|> Just "fx32.32"

 parseFX input =
     let typePattern = mkRegex "fx([0-9]+).([0-9]+)"
-        [m, b] = fromMaybe (error "incorrect Bus type input") $ matchRegex typePattern input
+        (m, b) = case fromMaybe (error "incorrect Bus type input") $ matchRegex typePattern input of
+            [m_, b_] -> (m_, b_)
+            _ -> error "parseFX: impossible"
+
         convert = fromJust . someNatVal . read
      in (convert m, convert b)
@@ -281,6 +286,7 @@ readSourceCode filename = do
     return src

 -- | Simulation on intermediate level (data-flow graph)
+functionalSimulation :: (Val x, Var v) => Int -> [(v, [x])] -> [Char] -> FrontendResult v x -> IO ()
 functionalSimulation n received format FrontendResult{frDataFlow, frPrettyLog} = do
     let cntx = simulateDataFlowGraph n def received frDataFlow
     infoM "NITTA" "run functional simulation..."
diff --git a/fourmolu.yaml b/fourmolu.yaml
index 544b7216b..19345d625 100644
--- a/fourmolu.yaml
+++ b/fourmolu.yaml
@@ -6,3 +6,6 @@ diff-friendly-import-export: true # 'false' uses Ormolu-style lists
 respectful: true # don't be too opinionated about newlines etc.
 haddock-style: multi-line # '--' vs. '{-'
 newlines-between-decls: 1 # number of newlines between top-level declarations
+
+# this will be available only since fourmolu-0.12.0.0, see https://github.com/ryukzak/nitta/issues/242
+single-constraint-parens: never # always | never | auto (preserve as is) - whether to style optional parentheses around single constraints
diff --git a/ml/synthesis/README.md b/ml/synthesis/README.md
index 0e1240b75..9109fb544 100644
--- a/ml/synthesis/README.md
+++ b/ml/synthesis/README.md
@@ -147,6 +147,29 @@ docker run \
   -it \
   nitta-dev
 ```
+### Script evaluation guide
+
+The `evaluation.py` script supports various command-line arguments that can be used to customize and control the evaluation process. This guide provides an overview of the available arguments and their usage.
+
+- `example_paths` (required): Paths to the example files, separated by spaces.
+- `--evaluator` (optional): Evaluation methods to use. You can specify one or multiple methods, separated by spaces. Allowed values: nitta, ml.
+- `--nitta_args` (optional): Arguments passed to Nitta. Enter the arguments in the format `--nitta_args=""`.
+- `--help`: Prints help information about the available arguments.
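The argument list above mirrors the `argparse` setup that `evaluation.py` defines in `parse_args()` (added further down in this diff). As a rough, non-authoritative sketch of that interface, condensed from the script itself:

```python
# Condensed sketch of the evaluation.py command-line interface described above.
# The authoritative definition is parse_args() in ml/synthesis/src/scripts/evaluation.py.
import argparse


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    # one or more example files, given as positional arguments
    parser.add_argument("example_paths", type=str, nargs="+", help="Paths to the example files")
    # one or more evaluators: "nitta" uses the built-in node scores, "ml" uses the trained model
    parser.add_argument("--evaluator", type=str, nargs="+", choices=["nitta", "ml"], help="Evaluator to use")
    # extra arguments forwarded to the NITTA process
    parser.add_argument("--nitta_args", type=str, default="", help="Additional arguments for NITTA")
    return parser.parse_args()


if __name__ == "__main__":
    print(parse_args())
```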
+
+Example usage:
+```bash
+python3 ml/synthesis/src/scripts/evaluation.py examples/fibonacci.lua examples/counter.lua --evaluator nitta ml --nitta_args="--format=csv"
+
+Algorithm: examples/fibonacci.lua
+          duration  depth  evaluator_calls      time
+nitta            5      8                9  0.906887
+ml               5      9               10  1.902110
+
+Algorithm: examples/counter.lua
+          duration  depth  evaluator_calls      time
+nitta            5      8                9  5.687186
+ml               5      8                9  7.122119
+```

 #### Linux

diff --git a/ml/synthesis/pyproject.toml b/ml/synthesis/pyproject.toml
new file mode 100644
index 000000000..9793af723
--- /dev/null
+++ b/ml/synthesis/pyproject.toml
@@ -0,0 +1,8 @@
+[tool.flake8]
+max-line-length = 125
+
+[tool.isort]
+profile = "black"
+
+[tool.black]
+line-length = 12
\ No newline at end of file
diff --git a/ml/synthesis/src/components/common/data_loading.py b/ml/synthesis/src/components/common/data_loading.py
index 36f6650aa..bfd9c4652 100644
--- a/ml/synthesis/src/components/common/data_loading.py
+++ b/ml/synthesis/src/components/common/data_loading.py
@@ -3,9 +3,8 @@ from typing import List
 
 import pandas as pd
-from pandas import DataFrame
-
 from consts import DATA_DIR
+from pandas import DataFrame
 
 
 def load_all_existing_training_data(data_dir: Path = DATA_DIR) -> DataFrame:
diff --git a/ml/synthesis/src/components/common/logging.py b/ml/synthesis/src/components/common/logging.py
index 50f567170..99d289083 100644
--- a/ml/synthesis/src/components/common/logging.py
+++ b/ml/synthesis/src/components/common/logging.py
@@ -3,7 +3,7 @@
 
 
 def get_logger(module_name: str) -> Logger:
-    """ Fixes stdlib's logging function case and create a unified logger factory. """
+    """Fixes stdlib's logging function case and create a unified logger factory."""
     if module_name == "__main__":
         return logging.getLogger()  # root logger
     return logging.getLogger(module_name)
diff --git a/ml/synthesis/src/components/common/model_loading.py b/ml/synthesis/src/components/common/model_loading.py
index 34bfb62cd..5d910e1b9 100644
--- a/ml/synthesis/src/components/common/model_loading.py
+++ b/ml/synthesis/src/components/common/model_loading.py
@@ -3,7 +3,6 @@ from typing import Tuple
 
 import tensorflow as tf
-
 from components.common.logging import get_logger
 from components.model_generation.model_metainfo import ModelMetainfo
diff --git a/ml/synthesis/src/components/common/port_management.py b/ml/synthesis/src/components/common/port_management.py
index 490f189c4..6871ed30b 100644
--- a/ml/synthesis/src/components/common/port_management.py
+++ b/ml/synthesis/src/components/common/port_management.py
@@ -12,7 +12,7 @@ def is_port_in_use(port):
     logger.debug(f"Finding out if port {port} is in use...")
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        result = s.connect_ex(('localhost', port)) == 0
+        result = s.connect_ex(("localhost", port)) == 0
         logger.debug(f"Port {port} {'is in use' if result else 'is free'}")
         return result
diff --git a/ml/synthesis/src/components/data_crawling/example_running.py b/ml/synthesis/src/components/data_crawling/example_running.py
index f22a11747..488cbdd27 100644
--- a/ml/synthesis/src/components/data_crawling/example_running.py
+++ b/ml/synthesis/src/components/data_crawling/example_running.py
@@ -1,23 +1,20 @@
 import asyncio
 import os
 import pickle
-import random
 import signal
 import sys
 from asyncio import sleep
 from contextlib import asynccontextmanager
 from pathlib import Path
-from typing import AsyncGenerator
-from typing import Tuple, List
+from typing import AsyncGenerator, List, Tuple
 
 import pandas as pd
-from joblib import Parallel, delayed
-
from components.common.logging import get_logger -from components.common.port_management import is_port_in_use, find_random_free_port -from components.data_crawling.tree_processing import assemble_tree_dataframe +from components.common.port_management import find_random_free_port from components.data_crawling.tree_retrieving import retrieve_whole_nitta_tree +from components.data_crawling.tree_processing import assemble_tree_dataframe from consts import DATA_DIR, ROOT_DIR +from joblib import Parallel, delayed logger = get_logger(__name__) @@ -25,12 +22,13 @@ @asynccontextmanager -async def run_nitta(example: Path, - nitta_exe_path: str = "stack exec nitta -- ", - nitta_args: str = "", - nitta_env: dict = None, - port: int = None - ) -> AsyncGenerator[Tuple[asyncio.subprocess.Process, str], None]: +async def run_nitta( + example: Path, + nitta_exe_path: str = "stack exec nitta -- ", + nitta_args: str = "", + nitta_env: dict = None, + port: int = None, +) -> AsyncGenerator[Tuple[asyncio.subprocess.Process, str], None]: if port is None: port = find_random_free_port() @@ -46,8 +44,13 @@ async def run_nitta(example: Path, try: preexec_fn = None if os.name == "nt" else os.setsid # see https://stackoverflow.com/a/4791612 proc = await asyncio.create_subprocess_shell( - cmd, cwd=str(ROOT_DIR), stdout=sys.stdout, stderr=sys.stderr, shell=True, - preexec_fn=preexec_fn, env=env + cmd, + cwd=str(ROOT_DIR), + stdout=sys.stdout, + stderr=sys.stderr, + shell=True, + preexec_fn=preexec_fn, + env=env, ) logger.info(f"NITTA has been launched, PID {proc.pid}. Waiting for {_NITTA_START_WAIT_DELAY_S} secs.") @@ -63,12 +66,14 @@ async def run_nitta(example: Path, await proc.wait() -async def run_example_and_retrieve_tree_data(example: Path, - data_dir: Path = DATA_DIR, - nitta_exe_path: str = "stack exec nitta -- ") -> pd.DataFrame: +async def run_example_and_retrieve_tree_data( + example: Path, + data_dir: Path = DATA_DIR, + nitta_exe_path: str = "stack exec nitta -- ", +) -> pd.DataFrame: example_name = os.path.basename(example) async with run_nitta(example, nitta_exe_path) as (proc, nitta_baseurl): - logger.info(f"Retrieving tree.") + logger.info("Retrieving tree.") tree = await retrieve_whole_nitta_tree(nitta_baseurl) data_dir.mkdir(exist_ok=True) diff --git a/ml/synthesis/src/components/data_crawling/nitta_node.py b/ml/synthesis/src/components/data_crawling/nitta_node.py index 6ddcf8805..7dbd4c657 100644 --- a/ml/synthesis/src/components/data_crawling/nitta_node.py +++ b/ml/synthesis/src/components/data_crawling/nitta_node.py @@ -1,6 +1,6 @@ from collections import deque from dataclasses import dataclass, field -from typing import Optional, Any, List, Tuple, Deque +from typing import Any, Deque, List, Optional, Tuple import numpy as np import pandas as pd @@ -41,8 +41,8 @@ class NittaNode: duration: Optional[int] sid: str - children: Optional[List['NittaNode']] = field(default=None, repr=False) - parent: Optional['NittaNode'] = field(default=None, repr=False) + children: Optional[List["NittaNode"]] = field(default=None, repr=False) + parent: Optional["NittaNode"] = field(default=None, repr=False) def __hash__(self): return hash(self.sid) @@ -58,18 +58,19 @@ def subtree_size(self): @cached_property def depth(self) -> int: - return self.sid.count('-') if self.sid != '-' else 0 + return self.sid.count("-") if self.sid != "-" else 0 @cached_property def subtree_leafs_metrics(self) -> Optional[Deque[Tuple[int, int]]]: - """ :returns: deque(tuple(duration, depth)) or None if node is a failed leaf """ + 
""":returns: deque(tuple(duration, depth)) or None if node is a failed leaf""" if self.is_leaf: if not self.is_finish: return None return deque(((self.duration, self.depth),)) else: - children_metrics = \ - (child.subtree_leafs_metrics for child in self.children if child.subtree_leafs_metrics is not None) + children_metrics = ( + child.subtree_leafs_metrics for child in self.children if child.subtree_leafs_metrics is not None + ) return sum(children_metrics, deque()) @cached_node_method @@ -77,7 +78,10 @@ def get_subtree_leafs_labels(self, metrics_distrib: np.ndarray) -> deque: if self.is_leaf: return deque((self.compute_label(metrics_distrib),)) else: - return sum((child.get_subtree_leafs_labels(metrics_distrib) for child in self.children), deque()) + return sum( + (child.get_subtree_leafs_labels(metrics_distrib) for child in self.children), + deque(), + ) @cached_node_method def compute_label(self, metrics_distrib: np.ndarray) -> float: diff --git a/ml/synthesis/src/components/data_crawling/tree_processing.py b/ml/synthesis/src/components/data_crawling/tree_processing.py index cba0a6bb4..734bc72aa 100644 --- a/ml/synthesis/src/components/data_crawling/tree_processing.py +++ b/ml/synthesis/src/components/data_crawling/tree_processing.py @@ -1,13 +1,13 @@ from __future__ import annotations + from collections import deque from typing import Deque, Optional import numpy as np import pandas as pd -from cachetools import cached, Cache -from joblib import Parallel, delayed - +from cachetools import Cache, cached from components.data_crawling.nitta_node import NittaNode +from joblib import Parallel, delayed def _extract_params_dict(node: NittaNode) -> dict: @@ -36,13 +36,19 @@ def _extract_alternative_siblings_dict(node: NittaNode, siblings: tuple[NittaNod else: refactorings += 1 - return dict(alt_bindings=bindings, - alt_refactorings=refactorings, - alt_dataflows=dataflows) + return dict( + alt_bindings=bindings, + alt_refactorings=refactorings, + alt_dataflows=dataflows, + ) @cached(cache=Cache(10000)) -def nitta_node_to_df_dict(node: NittaNode, siblings: tuple[NittaNode], example: str = None, ) -> dict: +def nitta_node_to_df_dict( + node: NittaNode, + siblings: tuple[NittaNode], + example: str = None, +) -> dict: return dict( example=example, sid=node.sid, @@ -54,8 +60,14 @@ def nitta_node_to_df_dict(node: NittaNode, siblings: tuple[NittaNode], example: ) -def assemble_tree_dataframe(example: str, node: NittaNode, metrics_distrib=None, include_label=True, - levels_left=None, n_workers: int = 1) -> pd.DataFrame: +def assemble_tree_dataframe( + example: str, + node: NittaNode, + metrics_distrib=None, + include_label=True, + levels_left=None, + n_workers: int = 1, +) -> pd.DataFrame: if include_label and metrics_distrib is None: metrics_distrib = np.array(node.subtree_leafs_metrics) @@ -72,8 +84,14 @@ def child_process_job(node: NittaNode): return pd.DataFrame(sum(deques, deque())) -def _assemble_tree_dataframe_recursion(accum: Deque[dict], example: str, node: NittaNode, metrics_distrib: np.ndarray, - include_label: bool, levels_left: Optional[int]): +def _assemble_tree_dataframe_recursion( + accum: Deque[dict], + example: str, + node: NittaNode, + metrics_distrib: np.ndarray, + include_label: bool, + levels_left: Optional[int], +): siblings = (node.parent.children or []) if node.parent else [] self_dict = nitta_node_to_df_dict(node, tuple(siblings), example) @@ -85,7 +103,13 @@ def _assemble_tree_dataframe_recursion(accum: Deque[dict], example: str, node: N else: levels_left_for_child = None if 
levels_left is None else levels_left - 1 for child in node.children: - _assemble_tree_dataframe_recursion(accum, example, child, metrics_distrib, include_label, - levels_left_for_child) + _assemble_tree_dataframe_recursion( + accum, + example, + child, + metrics_distrib, + include_label, + levels_left_for_child, + ) if node.sid != "-": accum.appendleft(self_dict) # so it's from roots to leaves diff --git a/ml/synthesis/src/components/data_crawling/tree_retrieving.py b/ml/synthesis/src/components/data_crawling/tree_retrieving.py index ed4ec9a41..cc7169334 100644 --- a/ml/synthesis/src/components/data_crawling/tree_retrieving.py +++ b/ml/synthesis/src/components/data_crawling/tree_retrieving.py @@ -2,7 +2,6 @@ import time from aiohttp import ClientSession - from components.common.logging import get_logger from components.common.utils import debounce from components.data_crawling.nitta_node import NittaNode @@ -11,7 +10,12 @@ logger_debug_debounced = debounce(1)(logger.debug) -async def retrieve_subforest(node: NittaNode, session: ClientSession, nitta_baseurl: str, levels_left=None): +async def retrieve_subforest( + node: NittaNode, + session: ClientSession, + nitta_baseurl: str, + levels_left=None, +): node.children = [] if node.is_leaf or levels_left == -1: return @@ -35,7 +39,7 @@ async def retrieve_subforest(node: NittaNode, session: ClientSession, nitta_base async def retrieve_whole_nitta_tree(nitta_baseurl: str, max_depth=None) -> NittaNode: start_time = time.perf_counter() async with ClientSession() as session: - async with session.get(nitta_baseurl + f"/node/-") as resp: + async with session.get(nitta_baseurl + "/node/-") as resp: root_raw = await resp.json() root = NittaNode.from_dict(root_raw) await retrieve_subforest(root, session, nitta_baseurl, max_depth) diff --git a/ml/synthesis/src/components/data_processing/dataset_creation.py b/ml/synthesis/src/components/data_processing/dataset_creation.py index 367c5ff03..f84a14a67 100644 --- a/ml/synthesis/src/components/data_processing/dataset_creation.py +++ b/ml/synthesis/src/components/data_processing/dataset_creation.py @@ -1,10 +1,9 @@ from typing import Tuple +from components.common.logging import get_logger from sklearn.model_selection import train_test_split from tensorflow.python.data import Dataset -from components.common.logging import get_logger - logger = get_logger(__name__) TARGET_COLUMNS = ["label"] diff --git a/ml/synthesis/src/components/data_processing/feature_engineering.py b/ml/synthesis/src/components/data_processing/feature_engineering.py index fadea6ecd..a00050b83 100644 --- a/ml/synthesis/src/components/data_processing/feature_engineering.py +++ b/ml/synthesis/src/components/data_processing/feature_engineering.py @@ -1,9 +1,8 @@ from __future__ import annotations -import pandas as pd - -from pandas import DataFrame +import pandas as pd from components.common.logging import get_logger +from pandas import DataFrame logger = get_logger(__name__) @@ -19,25 +18,56 @@ def _map_categorical(df, c): def preprocess_df(df: DataFrame) -> DataFrame: df: DataFrame = df.copy() - for bool_column in ["is_leaf", "pCritical", "pPossibleDeadlock", "pRestrictedTime"]: + for bool_column in [ + "is_leaf", + "pCritical", + "pPossibleDeadlock", + "pRestrictedTime", + ]: if bool_column in df.columns: df[bool_column] = _map_bool(df[bool_column]) else: logger.warning(f"Column/parameter {bool_column} not found in provided node info.") df = _map_categorical(df, df.tag) - df = df.drop(["pWave", "example", "sid", "old_score", "is_leaf", 
"pRefactoringType"], axis="columns", errors="ignore") + df = df.drop( + [ + "pWave", + "example", + "sid", + "old_score", + "is_leaf", + "pRefactoringType", + ], + axis="columns", + errors="ignore", + ) df = df.fillna(0) return df # TODO: move that to metainfo of the model, find a way to make input building model-dependent # (pickled module? function name?) -_BASELINE_MODEL_COLUMNS = \ - ["alt_bindings", "alt_refactorings", "alt_dataflows", "pAllowDataFlow", "pAlternative", "pCritical", - "pNumberOfBindedFunctions", "pOutputNumber", "pPercentOfBindedInputs", "pPossibleDeadlock", "pRestless", - "pFirstWaveOfTargetUse", "pNotTransferableInputs", "pRestrictedTime", "pWaitTime", "tag_BindDecisionView", - "tag_BreakLoopView", "tag_DataflowDecisionView"] +_BASELINE_MODEL_COLUMNS = [ + "alt_bindings", + "alt_refactorings", + "alt_dataflows", + "pAllowDataFlow", + "pAlternative", + "pCritical", + "pNumberOfBindedFunctions", + "pOutputNumber", + "pPercentOfBindedInputs", + "pPossibleDeadlock", + "pRestless", + "pFirstWaveOfTargetUse", + "pNotTransferableInputs", + "pRestrictedTime", + "pWaitTime", + "tag_BindDecisionView", + "tag_BreakLoopView", + "tag_DataflowDecisionView", +] def df_to_model_columns(df: DataFrame, model_columns: list[str] = None) -> DataFrame: diff --git a/ml/synthesis/src/components/model_generation/models.py b/ml/synthesis/src/components/model_generation/models.py index 4e1c016e2..cc88b8ddc 100644 --- a/ml/synthesis/src/components/model_generation/models.py +++ b/ml/synthesis/src/components/model_generation/models.py @@ -3,15 +3,17 @@ def create_baseline_model(input_shape) -> tf.keras.Model: - model = tf.keras.Sequential([ - layers.InputLayer(input_shape=input_shape), - layers.Dense(128, activation="relu", kernel_regularizer="l2"), - layers.Dense(128, activation="relu", kernel_regularizer="l2"), - layers.Dense(64, activation="relu", kernel_regularizer="l2"), - layers.Dense(64, activation="relu", kernel_regularizer="l2"), - layers.Dense(32, activation="relu"), - layers.Dense(1) - ]) + model = tf.keras.Sequential( + [ + layers.InputLayer(input_shape=input_shape), + layers.Dense(128, activation="relu", kernel_regularizer="l2"), + layers.Dense(128, activation="relu", kernel_regularizer="l2"), + layers.Dense(64, activation="relu", kernel_regularizer="l2"), + layers.Dense(64, activation="relu", kernel_regularizer="l2"), + layers.Dense(32, activation="relu"), + layers.Dense(1), + ] + ) model.compile( optimizer=tf.keras.optimizers.Adam(lr=3e-4), diff --git a/ml/synthesis/src/components/model_generation/training.py b/ml/synthesis/src/components/model_generation/training.py index dd6c3c26f..ba04bf8de 100644 --- a/ml/synthesis/src/components/model_generation/training.py +++ b/ml/synthesis/src/components/model_generation/training.py @@ -2,20 +2,23 @@ from time import strftime from typing import Tuple -from tensorflow.python.data import Dataset -from tensorflow.python.keras.models import Model - from components.common.logging import get_logger from components.model_generation.model_metainfo import ModelMetainfo from components.model_generation.models import create_baseline_model from consts import MODELS_DIR +from tensorflow.python.data import Dataset +from tensorflow.python.keras.models import Model logger = get_logger(__name__) -def train_and_save_baseline_model(train_ds: Dataset, val_ds: Dataset, fitting_kwargs: dict = None, - output_model_name: str = None, models_dir: Path = MODELS_DIR) \ - -> Tuple[Model, ModelMetainfo]: +def train_and_save_baseline_model( + train_ds: Dataset, + 
val_ds: Dataset, + fitting_kwargs: dict = None, + output_model_name: str = None, + models_dir: Path = MODELS_DIR, +) -> Tuple[Model, ModelMetainfo]: models_dir.mkdir(exist_ok=True) sample = next(iter(val_ds))[0][0] @@ -31,7 +34,10 @@ def train_and_save_baseline_model(train_ds: Dataset, val_ds: Dataset, fitting_kw results = model.fit(x=train_ds, validation_data=val_ds, **effective_fitting_kwargs) # TODO: proper model evaluation on an independent dataset - metainfo = ModelMetainfo(train_mae=results.history["mae"][-1], validation_mae=results.history["val_mae"][-1]) + metainfo = ModelMetainfo( + train_mae=results.history["mae"][-1], + validation_mae=results.history["val_mae"][-1], + ) if not output_model_name: output_model_name = f"model-{strftime('%Y%m%d-%H%M%S')}" diff --git a/ml/synthesis/src/components/utils/string.py b/ml/synthesis/src/components/utils/string.py index 5a1739305..651c58260 100644 --- a/ml/synthesis/src/components/utils/string.py +++ b/ml/synthesis/src/components/utils/string.py @@ -1,5 +1,5 @@ def snake_to_lower_camel_case(snake_str: str): - components = [c for c in snake_str.strip().split('_') if c] + components = [c for c in snake_str.strip().split("_") if c] if not components: return "" return components[0] + "".join(x[0].title() + x[1:] for x in components[1:]) diff --git a/ml/synthesis/src/mlbackend/__main__.py b/ml/synthesis/src/mlbackend/__main__.py index a7f49f305..b04f19dc3 100644 --- a/ml/synthesis/src/mlbackend/__main__.py +++ b/ml/synthesis/src/mlbackend/__main__.py @@ -1,14 +1,19 @@ import uvicorn - -from components.common.logging import get_logger, configure_logging +from components.common.logging import configure_logging, get_logger from consts import ML_BACKEND_BASE_URL_FILEPATH -from mlbackend.app import app from mlbackend.backend_base_url_file import BackendBaseUrlFile logger = get_logger(__name__) configure_logging() -with BackendBaseUrlFile(filepath=ML_BACKEND_BASE_URL_FILEPATH, - base_url_fmt="http://127.0.0.1:{port}") as base_url_file: +with BackendBaseUrlFile( + filepath=ML_BACKEND_BASE_URL_FILEPATH, + base_url_fmt="http://127.0.0.1:{port}", +) as base_url_file: logger.info(f"Starting ML backend server on port {base_url_file.port}") - uvicorn.run("mlbackend.app:app", host="127.0.0.1", port=base_url_file.port, workers=4) + uvicorn.run( + "mlbackend.app:app", + host="127.0.0.1", + port=base_url_file.port, + workers=4, + ) diff --git a/ml/synthesis/src/mlbackend/app.py b/ml/synthesis/src/mlbackend/app.py index 39c3422ba..95c8cfc1a 100644 --- a/ml/synthesis/src/mlbackend/app.py +++ b/ml/synthesis/src/mlbackend/app.py @@ -1,18 +1,17 @@ from __future__ import annotations + from http import HTTPStatus -import numpy as np import pandas as pd +from components.data_crawling.nitta_node import NittaNode +from components.data_crawling.tree_processing import nitta_node_to_df_dict +from components.data_processing.feature_engineering import df_to_model_columns, preprocess_df from fastapi import FastAPI, HTTPException from fastapi.exception_handlers import http_exception_handler +from mlbackend.dtos import ModelInfo, PostScoreRequestBody, PostScoreResponseData, Response +from mlbackend.models_store import ModelNotFoundError, models from starlette.responses import HTMLResponse -from components.data_crawling.nitta_node import NittaNode -from components.data_crawling.tree_processing import nitta_node_to_df_dict -from components.data_processing.feature_engineering import preprocess_df, df_to_model_columns -from mlbackend.dtos import Response, ModelInfo, 
PostScoreRequestBody, PostScoreResponseData -from mlbackend.models_store import models, ModelNotFoundError - app = FastAPI( title="NITTA ML Backend", version="0.0.0", @@ -25,12 +24,18 @@ @app.get("/models/{model_name}") def get_model_info(model_name: str) -> Response[ModelInfo]: model, meta = models[model_name] - return Response(data=ModelInfo(name=model_name, train_mae=meta.train_mae, validation_mae=meta.validation_mae)) + return Response( + data=ModelInfo( + name=model_name, + train_mae=meta.train_mae, + validation_mae=meta.validation_mae, + ) + ) @app.post("/models/{model_name}/score") def score_with_model(model_name: str, body: PostScoreRequestBody) -> Response[PostScoreResponseData]: - """ Runs score prediction with model of given name for each input in a given list of inputs. """ + """Runs score prediction with model of given name for each input in a given list of inputs.""" model, meta = models[model_name] scores = [] @@ -45,21 +50,29 @@ def score_with_model(model_name: str, body: PostScoreRequestBody) -> Response[Po siblings.append(node) if not target_nodes: - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail="No target node(s) were found") + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="No target node(s) were found", + ) df = pd.DataFrame([nitta_node_to_df_dict(target_node, siblings=tuple(nodes)) for target_node in target_nodes]) df = preprocess_df(df) df = df_to_model_columns(df) scores.append(model.predict(df.values).reshape(-1).tolist()) - return Response(data=PostScoreResponseData( - scores=scores, - )) + return Response( + data=PostScoreResponseData( + scores=scores, + ) + ) @app.exception_handler(ModelNotFoundError) async def model_not_found_exception_handler(request, exc): - return await http_exception_handler(request, HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(exc))) + return await http_exception_handler( + request, + HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(exc)), + ) @app.get("/docs", response_class=HTMLResponse, include_in_schema=False) @@ -69,19 +82,19 @@ def get_rapidoc_docs(): - - - + """ diff --git a/ml/synthesis/src/mlbackend/backend_base_url_file.py b/ml/synthesis/src/mlbackend/backend_base_url_file.py index 30c285656..d91689a56 100644 --- a/ml/synthesis/src/mlbackend/backend_base_url_file.py +++ b/ml/synthesis/src/mlbackend/backend_base_url_file.py @@ -28,6 +28,7 @@ class BackendBaseUrlFile(AbstractContextManager): with BackendBaseUrlFile(".ml_backend_base_url", "http://localhost:{port}/mlbackend") as base_url_file: server.start(port=base_url_file.port) """ + filepath: Path port: int | None _base_url_fmt: str @@ -37,7 +38,7 @@ def __init__(self, filepath: str | PathLike, base_url_fmt: str): self.port = None self._base_url_fmt = base_url_fmt - def __enter__(self) -> 'BackendBaseUrlFile': + def __enter__(self) -> "BackendBaseUrlFile": if self.filepath.exists(): with self.filepath.open("r") as f: old_base_url = f.read() @@ -55,12 +56,16 @@ def __enter__(self) -> 'BackendBaseUrlFile': f.write(base_url + "\n") return self - def __exit__(self, __exc_type: Type[BaseException] | None, - __exc_value: BaseException | None, - __traceback: TracebackType | None) -> bool | None: + def __exit__( + self, + __exc_type: Type[BaseException] | None, + __exc_value: BaseException | None, + __traceback: TracebackType | None, + ) -> bool | None: if not self.filepath.exists(): - logger.warning(f"Exiting, so wanted to remove {self.filepath.absolute()}, " - f"but it does not exist. 
Doing nothing.") + logger.warning( + f"Exiting, so wanted to remove {self.filepath.absolute()}, " f"but it does not exist. Doing nothing." + ) return logger.info(f"Exiting, so removing {self.filepath.absolute()}.") diff --git a/ml/synthesis/src/mlbackend/dtos.py b/ml/synthesis/src/mlbackend/dtos.py index 733ba209d..01fff1c56 100644 --- a/ml/synthesis/src/mlbackend/dtos.py +++ b/ml/synthesis/src/mlbackend/dtos.py @@ -1,10 +1,9 @@ -from typing import TypeVar, Generic, Dict, List, Union +from typing import Generic, List, TypeVar +from components.utils.string import snake_to_lower_camel_case from pydantic import BaseModel, Field from pydantic.generics import GenericModel -from components.utils.string import snake_to_lower_camel_case - TData = TypeVar("TData") @@ -36,7 +35,8 @@ class ModelInfo(CustomizedBaseModel): # related to corresponding NITTA REST API DTO # TODO: link those types to NITTA Haskell source code via code generation like it's done with with TypeScript? class NittaNodeView(CustomizedBaseModel): - """ `NodeView` from NITTA Haskell sources. """ + """`NodeView` from NITTA Haskell sources.""" + sid: str = Field(example="-0-4-7-3-4-1-1-0") is_terminal: bool is_finish: bool @@ -52,18 +52,19 @@ class NittaNodeView(CustomizedBaseModel): "pWave": 4, "pOutputNumber": 1, "pAlternative": 2, - "pCritical": False - }) + "pCritical": False, + }, + ) decision: dict = Field( description="SynthesisDecision.decision from NITTA Haskell sources.", example={ "tag": "BindDecisionView", "function": { "fvFun": "loop(0.000000, res^0#0) = i^0#0", - "fvHistory": [] + "fvHistory": [], }, - "pu": "fram1" - } + "pu": "fram1", + }, ) score: int # compatibility @@ -85,23 +86,22 @@ class ScoringInput(CustomizedBaseModel): to ignore that for now since optimizations will definitely be possible when they become needed. For example, we can always add mentioned data gathering branching based on currently used model at the cost of reduced flexibility. """ + scoring_target: str = Field( description="SID of a node which we need to predict the score for. This node must be in `nodes` list. " - "You can also pass the value `all` to get scores for all nodes in the `nodes` list.", - example="-0-4-7-3-4-1-1-0" + "You can also pass the value `all` to get scores for all nodes in the `nodes` list.", + example="-0-4-7-3-4-1-1-0", ) nodes: List[NittaNodeView] = Field( description="`NodeView`s of scoring target node and all its siblings (all possible synthesis tree choices " - "from current parent node)." + "from current parent node)." ) # parents? # decision history? (can differ from parents!) class PostScoreRequestBody(CustomizedBaseModel): - inputs: List[ScoringInput] = Field( - description="List of inputs to get score predictions for. " - ) + inputs: List[ScoringInput] = Field(description="List of inputs to get score predictions for. 
") # data, not whole body (which can include "data" field and metadata) diff --git a/ml/synthesis/src/mlbackend/models_store.py b/ml/synthesis/src/mlbackend/models_store.py index fee30e2e6..42c5626c6 100644 --- a/ml/synthesis/src/mlbackend/models_store.py +++ b/ml/synthesis/src/mlbackend/models_store.py @@ -3,18 +3,18 @@ from os import PathLike from pathlib import Path -from tensorflow.python.keras import Model - from components.common.logging import get_logger from components.common.model_loading import load_model from components.model_generation.model_metainfo import ModelMetainfo from consts import MODELS_DIR +from tensorflow.python.keras import Model logger = get_logger(__name__) class ModelsStore: - """ Loads models on-demand and caches them between calls. """ + """Loads models on-demand and caches them between calls.""" + model_dir: Path models: dict[str, (Model, ModelMetainfo)] = {} diff --git a/ml/synthesis/src/scripts/evaluation.py b/ml/synthesis/src/scripts/evaluation.py new file mode 100644 index 000000000..601863077 --- /dev/null +++ b/ml/synthesis/src/scripts/evaluation.py @@ -0,0 +1,310 @@ +import argparse +import asyncio +from collections import defaultdict +from pathlib import Path +from time import perf_counter + +import pandas as pd +import tensorflow as tf +from aiohttp import ClientSession, ServerDisconnectedError +from components.common.logging import configure_logging, get_logger +from components.data_crawling.example_running import run_nitta +from components.data_crawling.nitta_node import NittaNode +from components.data_crawling.tree_retrieving import retrieve_subforest, retrieve_whole_nitta_tree +from components.data_processing.feature_engineering import df_to_model_columns +from consts import MODELS_DIR +from IPython.display import display + +model = tf.keras.models.load_model(MODELS_DIR) + + +def preprocess_df(df: pd.DataFrame) -> pd.DataFrame: + def map_bool(c): + return c.apply(lambda v: 1 if v is True else (0 if v is False else v)) + + def map_categorical(df, c, options=None): + return pd.concat( + [ + df.drop([c.name], axis=1), + pd.get_dummies(c, prefix=c.name, columns=options), + ], + axis=1, + ) + + df = df.copy() + df.is_leaf = map_bool(df.is_leaf) + df.pCritical = map_bool(df.pCritical) + df.pPossibleDeadlock = map_bool(df.pPossibleDeadlock) + df.pRestrictedTime = map_bool(df.pRestrictedTime) + df = map_categorical( + df, + df.tag, + [ + "tag_BindDecisionView", + "tag_BreakLoopView", + "tag_ConstantFoldingView", + "tag_DataflowDecisionView", + "tag_OptimizeAccumView", + "tag_ResolveDeadlockView", + ], + ) + df = df.drop( + [ + "pWave", + "example", + "sid", + "old_score", + "is_leaf", + "pRefactoringType", + ], + axis="columns", + ) + + df = df.fillna(0) + return df + + +def _extract_params_dict(node: NittaNode) -> dict: + if node.decision.tag in ["BindDecisionView", "DataflowDecisionView"]: + result = node.parameters.copy() + if node.decision.tag == "DataflowDecisionView": + result["pNotTransferableInputs"] = sum(result["pNotTransferableInputs"]) + return result + elif node.decision.tag == "RootView": + return {} + else: + # refactorings + return {"pRefactoringType": node.decision.tag} + + +def assemble_tree_dataframe( + example: str, + node: NittaNode, + metrics_distrib=None, + include_label=True, + levels_left=None, +) -> pd.DataFrame: + if include_label and metrics_distrib is None: + metrics_distrib = node.subtree_leafs_metrics + + self_df = pd.DataFrame( + dict( + example=example, + sid=node.sid, + tag=node.decision.tag, + old_score=node.score, + 
is_leaf=node.is_leaf, + **_extract_params_dict(node), + ), + index=[0], + ) + if include_label: + self_df["label"] = node.compute_label(metrics_distrib) + + levels_left_for_child = None if levels_left is None else levels_left - 1 + if node.is_leaf or levels_left == -1: + return self_df + else: + result = [ + assemble_tree_dataframe( + example, + child, + metrics_distrib, + include_label, + levels_left_for_child, + ) + for child in node.children + ] + if node.sid != "-": + result.insert(0, self_df) + return pd.concat(result) + + +async def select_best_by_evaluator(session, evaluator, node, nitta_baseurl, counters, children_limit=None): + counters[evaluator.__name__] += 1 + + if node.is_leaf: + if not node.is_finish: + return None + + return node + + try: + await retrieve_subforest(node, session, nitta_baseurl) + except ServerDisconnectedError: + # print(f"Invalid node with NITTA exception: {node}") + return None + + children = node.children + + if children_limit: + children = children[:children_limit] + + children = [(evaluator(child), child) for child in node.children] + children.sort(key=lambda v: v[0], reverse=True) + # print(f"children: {[d[0] for d in children]}") + + while children: + next_best_child = children.pop(0)[1] + # print(f"next best: {next_best_child}") + result = await select_best_by_evaluator( + session, + evaluator, + next_best_child, + nitta_baseurl, + counters, + children_limit, + ) + if result is not None: + return result + + return None + + +def old_evaluator(node: NittaNode): + return node.score + + +def new_evaluator(node: NittaNode): + final_columns = [ + "alt_bindings", + "alt_refactorings", + "alt_dataflows", + "pOutputNumber", + "pAlternative", + "pAllowDataFlow", + "pCritical", + "pPercentOfBindedInputs", + "pPossibleDeadlock", + "pNumberOfBindedFunctions", + "pRestless", + "pNotTransferableInputs", + "pRestrictedTime", + "pWaitTime", + "tag_BindDecisionView", + "tag_BreakLoopView", + "tag_ConstantFoldingView", + "tag_DataflowDecisionView", + "tag_OptimizeAccumView", + "tag_ResolveDeadlockView", + ] + metrics_columns = [cn for cn in final_columns if cn.startswith("p")] + [ + "pRefactoringType", + "pWave", + ] + + node_df = assemble_tree_dataframe("", node, include_label=False, levels_left=-1) + filled_metrics_df = pd.concat([pd.DataFrame(columns=metrics_columns), node_df]) + preprocessed_df = preprocess_df(filled_metrics_df) + + final_df = df_to_model_columns(preprocessed_df, model_columns=final_columns) + + return model.predict(final_df.values)[0][0] + + +def reset_counters(): + global counters + counters = defaultdict(lambda: 0) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("example_paths", type=str, nargs="+", help="Paths to the example files") + parser.add_argument( + "--evaluator", + type=str, + nargs="+", + choices=["nitta", "ml"], + help="Evaluator to use", + ) + parser.add_argument( + "--nitta_args", + type=str, + default="", + help="Additional arguments for Nitta", + ) + return parser.parse_args() + + +async def main(args): + examples = args.example_paths + evaluator_choices = args.evaluator + + evaluator_dict = {"nitta": old_evaluator, "ml": new_evaluator} + + evaluator_choices = [choice for choice in evaluator_choices if choice in evaluator_dict] + if not evaluator_choices: + print("Invalid evaluator choices. 
Using the new evaluator as default.")
+        evaluator_choices = ["nitta"]
+
+    results = []
+
+    for example in examples:
+        example = Path(example)
+
+        logger = get_logger(__name__)
+        configure_logging()
+        reset_counters()
+
+        logger.info(f"Selected algorithm: {example}")
+
+        nitta_tree = None
+        async with run_nitta(example, nitta_args=args.nitta_args) as (
+            proc,
+            nitta_baseurl,
+        ):
+            nitta_tree = await retrieve_whole_nitta_tree(nitta_baseurl)
+            new_evaluator(nitta_tree.children[0])
+            reset_counters()
+            root = await retrieve_whole_nitta_tree(nitta_baseurl)
+
+            async with ClientSession() as session:
+                result_dict = {"example": example, "evaluators": {}}
+                for evaluator_choice in evaluator_choices:
+                    evaluator = evaluator_dict[evaluator_choice]
+                    start_time = perf_counter()
+                    best = await select_best_by_evaluator(session, evaluator, root, nitta_baseurl, counters, 2)
+                    end_time = perf_counter() - start_time
+                    result_dict["evaluators"][evaluator_choice] = {
+                        "best": best,
+                        "duration": best.duration,
+                        "depth": best.depth,
+                        "evaluator_calls": counters[evaluator_choice + "_evaluator"],
+                        "time": end_time,
+                    }
+                    logger.info(f"{evaluator_choice.upper()} DONE %s", best)
+                    logger.info(f"Finished {evaluator_choice} in {end_time:.2f} s")
+                results.append(result_dict)
+
+    for result in results:
+        print(f"\nAlgorithm: {result['example']}")
+        dfs = []
+        for evaluator, evaluator_result in result["evaluators"].items():
+            df = pd.DataFrame(
+                dict(
+                    duration=[evaluator_result["duration"]],
+                    depth=[evaluator_result["depth"]],
+                    evaluator_calls=[evaluator_result["evaluator_calls"]],
+                    time=[evaluator_result["time"]],
+                ),
+                index=[evaluator],
+            )
+            dfs.append(df)
+        result_df = pd.concat(dfs)
+        display(result_df)
+
+
+async def test_script():
+    class Args:
+        def __init__(self, example_path):
+            self.example_path = example_path
+
+    args = Args("examples/fibonacci.lua")
+    await main(args)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    asyncio.run(main(args))
+# Uncomment the following line to run the test function
+# asyncio.run(test_script())
diff --git a/ml/synthesis/src/scripts/measure_data_crawling_speed.py b/ml/synthesis/src/scripts/measure_data_crawling_speed.py
index 90ce1876a..cc823d54d 100644
--- a/ml/synthesis/src/scripts/measure_data_crawling_speed.py
+++ b/ml/synthesis/src/scripts/measure_data_crawling_speed.py
@@ -3,10 +3,10 @@
 from pathlib import Path
 from time import perf_counter
 
+from components.common.logging import configure_logging, get_logger
 from components.data_crawling.example_running import run_example_and_retrieve_tree_data
-from components.common.logging import get_logger, configure_logging
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     logger = get_logger(__name__)
     configure_logging()
diff --git a/ml/synthesis/src/scripts/train_evaluate_in_ci.py b/ml/synthesis/src/scripts/train_evaluate_in_ci.py
index fc7622f84..b8a394d3f 100644
--- a/ml/synthesis/src/scripts/train_evaluate_in_ci.py
+++ b/ml/synthesis/src/scripts/train_evaluate_in_ci.py
@@ -2,7 +2,7 @@
 from pathlib import Path
 
 from components.common.data_loading import load_all_existing_training_data
-from components.common.logging import get_logger, configure_logging
+from components.common.logging import configure_logging, get_logger
 from components.common.model_loading import load_model
 from components.data_crawling.example_running import get_data_for_many_examples_parallel
 from components.data_processing.dataset_creation import create_datasets
@@ -10,7 +10,7 @@ from
components.model_generation.training import train_and_save_baseline_model from consts import MODELS_DIR -if __name__ == '__main__': +if __name__ == "__main__": logger = get_logger(__name__) configure_logging() diff --git a/ml/synthesis/src/tests/test_data_processing.py b/ml/synthesis/src/tests/test_data_processing.py index 1faaff880..1192e5f72 100644 --- a/ml/synthesis/src/tests/test_data_processing.py +++ b/ml/synthesis/src/tests/test_data_processing.py @@ -1,27 +1,26 @@ import pandas as pd - from components.data_processing.feature_engineering import df_to_model_columns def test_df_to_model_columns_conversion_introduces_columns(): - df_inp = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) - df_out_expected = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [0, 0, 0]}) + df_inp = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df_out_expected = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0, 0, 0]}) assert df_to_model_columns(df_inp, model_columns=["a", "b", "c"]).equals(df_out_expected) def test_df_to_model_columns_conversion_removes_columns(): - df_inp = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [0, 0, 0]}) - df_out_expected = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + df_inp = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0, 0, 0]}) + df_out_expected = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) assert df_to_model_columns(df_inp, model_columns=["a", "b"]).equals(df_out_expected) def test_df_to_model_columns_conversion_reorders_columns(): - df_inp = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [0, 0, 0]}) - df_out_expected = pd.DataFrame({'a': [1, 2, 3], 'c': [0, 0, 0], 'b': [4, 5, 6]}) + df_inp = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0, 0, 0]}) + df_out_expected = pd.DataFrame({"a": [1, 2, 3], "c": [0, 0, 0], "b": [4, 5, 6]}) assert df_to_model_columns(df_inp, model_columns=["a", "c", "b"]).equals(df_out_expected) def test_df_to_model_columns_conversion_fills_missing_data(): - df_inp = pd.DataFrame({'a': [1, 2, 3], 'b': [4, pd.NA, pd.NA], 'c': [pd.NA, pd.NA, 1]}) - df_out_expected = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 0, 0]}) + df_inp = pd.DataFrame({"a": [1, 2, 3], "b": [4, pd.NA, pd.NA], "c": [pd.NA, pd.NA, 1]}) + df_out_expected = pd.DataFrame({"a": [1, 2, 3], "b": [4, 0, 0]}) assert df_to_model_columns(df_inp, model_columns=["a", "b"]).equals(df_out_expected) diff --git a/ml/synthesis/src/tests/test_smoke.py b/ml/synthesis/src/tests/test_smoke.py index 1e778a7f7..efe6cdf25 100644 --- a/ml/synthesis/src/tests/test_smoke.py +++ b/ml/synthesis/src/tests/test_smoke.py @@ -2,13 +2,12 @@ from tempfile import TemporaryDirectory import numpy as np - from components.common.data_loading import load_all_existing_training_data from components.common.model_loading import load_model from components.common.utils import strip_none_from_tensor_shape from components.data_crawling.example_running import run_example_and_retrieve_tree_data, run_nitta from components.data_crawling.tree_retrieving import retrieve_whole_nitta_tree -from components.data_processing.dataset_creation import create_datasets, TARGET_COLUMNS +from components.data_processing.dataset_creation import TARGET_COLUMNS, create_datasets from components.data_processing.feature_engineering import preprocess_df from components.model_generation.training import train_and_save_baseline_model from consts import EXAMPLES_DIR, EnvVarNames @@ -29,10 +28,16 @@ async def test_smoke(): tds, vds = create_datasets(pdf) model_name = "test-model" - train_and_save_baseline_model(tds, vds, 
fitting_kwargs=dict( - epochs=1, - steps_per_epoch=1, - ), output_model_name=model_name, models_dir=tmp_models_dir) + train_and_save_baseline_model( + tds, + vds, + fitting_kwargs=dict( + epochs=1, + steps_per_epoch=1, + ), + output_model_name=model_name, + models_dir=tmp_models_dir, + ) model, metainfo = load_model(tmp_models_dir / model_name) inp_shape = strip_none_from_tensor_shape(model.input_shape) @@ -47,18 +52,23 @@ async def _get_scores(baseurl: str): tree = await retrieve_whole_nitta_tree(baseurl, max_depth=0) return [c.score for c in tree.children] - async with run_nitta(EXAMPLES_DIR / "fibonacci.lua") as (_, nitta_baseurl): + async with run_nitta(EXAMPLES_DIR / "fibonacci.lua") as ( + _, + nitta_baseurl, + ): non_ml_scores = await _get_scores(nitta_baseurl) - async with run_nitta(EXAMPLES_DIR / "fibonacci.lua", - nitta_args=f"--ml-scoring-model=does_not_exist") \ - as (_, nitta_baseurl): + async with run_nitta( + EXAMPLES_DIR / "fibonacci.lua", + nitta_args="--ml-scoring-model=does_not_exist", + ) as (_, nitta_baseurl): fallback_non_ml_scores = await _get_scores(nitta_baseurl) assert non_ml_scores == fallback_non_ml_scores - async with run_nitta(EXAMPLES_DIR / "fibonacci.lua", - nitta_args=f"--ml-scoring-model=\"{model_name}\"", - nitta_env={EnvVarNames.MODELS_DIR: tmp_model_dir_name}) \ - as (_, nitta_baseurl): + async with run_nitta( + EXAMPLES_DIR / "fibonacci.lua", + nitta_args=f'--ml-scoring-model="{model_name}"', + nitta_env={EnvVarNames.MODELS_DIR: tmp_model_dir_name}, + ) as (_, nitta_baseurl): ml_scores = await _get_scores(nitta_baseurl) assert non_ml_scores != ml_scores diff --git a/ml/synthesis/src/tests/test_utils.py b/ml/synthesis/src/tests/test_utils.py index 0bffdcda5..eebe312ca 100644 --- a/ml/synthesis/src/tests/test_utils.py +++ b/ml/synthesis/src/tests/test_utils.py @@ -1,32 +1,32 @@ -import unittest - import pytest - from components.utils.string import snake_to_lower_camel_case -@pytest.mark.parametrize("inp, out_expected", [ - ("", ""), - ("test", "test"), - ("test_case", "testCase"), - ("test_case_1", "testCase1"), - ("test_case_1_2", "testCase12"), - ("_test_case_1_2", "testCase12"), - ("__test_case_1_2", "testCase12"), - ("__test_case_1_2_", "testCase12"), - ("__test_case_1_2__", "testCase12"), - ("test_multi_word_test_test", "testMultiWordTestTest"), - ("test_multi_word_test_test_", "testMultiWordTestTest"), - ("test_multi_word_test_test__", "testMultiWordTestTest"), - ("test_multi_word_test_test__12", "testMultiWordTestTest12"), - ("test_multi_word_test_test__1_2", "testMultiWordTestTest12"), - ("test_multi_word_test_test__1__2", "testMultiWordTestTest12"), - ("test_multi_word_test_test__1__2_", "testMultiWordTestTest12"), - ("test_multi_word_test_test__1__2__", "testMultiWordTestTest12"), - ("test18", "test18"), - ("test_multi18", "testMulti18"), - ("_", ""), - ("__", "") -]) +@pytest.mark.parametrize( + "inp, out_expected", + [ + ("", ""), + ("test", "test"), + ("test_case", "testCase"), + ("test_case_1", "testCase1"), + ("test_case_1_2", "testCase12"), + ("_test_case_1_2", "testCase12"), + ("__test_case_1_2", "testCase12"), + ("__test_case_1_2_", "testCase12"), + ("__test_case_1_2__", "testCase12"), + ("test_multi_word_test_test", "testMultiWordTestTest"), + ("test_multi_word_test_test_", "testMultiWordTestTest"), + ("test_multi_word_test_test__", "testMultiWordTestTest"), + ("test_multi_word_test_test__12", "testMultiWordTestTest12"), + ("test_multi_word_test_test__1_2", "testMultiWordTestTest12"), + ("test_multi_word_test_test__1__2", 
"testMultiWordTestTest12"), + ("test_multi_word_test_test__1__2_", "testMultiWordTestTest12"), + ("test_multi_word_test_test__1__2__", "testMultiWordTestTest12"), + ("test18", "test18"), + ("test_multi18", "testMulti18"), + ("_", ""), + ("__", ""), + ], +) def test_snake_to_lower_camel_case(inp, out_expected): assert snake_to_lower_camel_case(inp) == out_expected diff --git a/package.yaml b/package.yaml index 480a6190b..7423385f7 100644 --- a/package.yaml +++ b/package.yaml @@ -20,23 +20,23 @@ description: |- Page: -category: 'CGRA, ASIP, CAD, hardware' +category: "CGRA, ASIP, CAD, hardware" author: Aleksandr Penskoi maintainer: aleksandr.penskoi@gmail.com copyright: 2021 Aleksandr Penskoi license: BSD3 -homepage: 'https://ryukzak.github.io/projects/nitta/' +homepage: "https://ryukzak.github.io/projects/nitta/" extra-doc-files: - README.md ghc-options: - - '-j' - - '-Wall' - - '-Werror' - - '-Wcompat' - - '-Wredundant-constraints' - - '-fno-warn-missing-signatures' - - '-optP-Wno-nonportable-include-path' + - "-j" + - "-Wall" + - "-Werror" + - "-Wcompat" + - "-Wredundant-constraints" + - "-fno-warn-missing-signatures" + - "-optP-Wno-nonportable-include-path" default-extensions: - DeriveDataTypeable @@ -104,7 +104,6 @@ library: - parsec - http-conduit - executables: nitta: main: Main @@ -112,23 +111,23 @@ executables: - Paths_nitta source-dirs: app ghc-options: - - '-threaded' - - '-rtsopts' - - '-with-rtsopts=-N' + - "-threaded" + - "-rtsopts" + - "-with-rtsopts=-N" dependencies: - cmdargs - hslogger - nitta - + nitta-api-gen: main: APIGen source-dirs: app other-modules: - Paths_nitta ghc-options: - - '-threaded' - - '-rtsopts' - - '-with-rtsopts=-N' + - "-threaded" + - "-rtsopts" + - "-with-rtsopts=-N" dependencies: - aeson - aeson-typescript @@ -137,16 +136,15 @@ executables: - hslogger - nitta - tests: nitta-test: main: Spec source-dirs: test ghc-options: - - '-threaded' - - '-rtsopts' - - '-with-rtsopts=-N' - - '-j' + - "-threaded" + - "-rtsopts" + - "-with-rtsopts=-N" + - "-j" dependencies: - QuickCheck - atomic-primops diff --git a/src/NITTA/Frontends/Lua.hs b/src/NITTA/Frontends/Lua.hs index c4ff86494..a6c6245a0 100644 --- a/src/NITTA/Frontends/Lua.hs +++ b/src/NITTA/Frontends/Lua.hs @@ -3,7 +3,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} -{-# OPTIONS_GHC -Wno-type-defaults #-} +{-# OPTIONS_GHC -Wno-type-defaults -Wno-incomplete-uni-patterns #-} {- | Module : NITTA.Frontends.Lua @@ -229,7 +229,12 @@ getNextTmpVarName fOut return $ "_0#" <> fOut addStartupFuncArgs (FunCall (NormalFunCall _ (Args exps))) (FunAssign _ (FunBody names _ _)) = do - mapM_ (\(Name name, Number _ valueString, serialNumber) -> addToBuffer name valueString serialNumber) $ zip3 names exps [0 ..] + mapM_ + ( \case + (Name name, Number _ valueString, serialNumber) -> addToBuffer name valueString serialNumber + _ -> error "addStartupFuncArgs: internal error" + ) + $ zip3 names exps [0 ..] 
return "" where addToBuffer name valueString serialNumber = do @@ -251,7 +256,12 @@ processStatement fn (Assign lexps@[_] [Unop Neg (Number ntype ntext)]) = processStatement _ (Assign lexp [rexp]) = do parseRightExp (map parseLeftExp lexp) rexp processStatement startupFunctionName (Assign vars exps) | length vars == length exps = do - mapM_ (\(VarName (Name name), expr) -> processStatement startupFunctionName (Assign [VarName (Name (getTempAlias name))] [expr])) $ zip vars exps + mapM_ + ( \case + (VarName (Name name), expr) -> processStatement startupFunctionName (Assign [VarName (Name (getTempAlias name))] [expr]) + _ -> error "processStatement: internal error" + ) + $ zip vars exps mapM_ (\(VarName (Name name)) -> addAlias name (getTempAlias name)) vars where getTempAlias name = name <> "&" @@ -259,7 +269,7 @@ processStatement startupFunctionName (Assign vars exps) | length vars == length processStatement fn (FunCall (NormalFunCall (PEVar (VarName (Name fName))) (Args args))) | fn == fName = do LuaAlgBuilder{algStartupArgs} <- get - let startupVarsNames = map ((\(Just x) -> x) . (`HM.lookup` algStartupArgs)) [0 .. (HM.size algStartupArgs)] + let startupVarsNames = map (fromMaybe (error "processStatement: internal error") . (`HM.lookup` algStartupArgs)) [0 .. (HM.size algStartupArgs)] let startupVarsVersions = map (\x -> LuaValueInstance{lviName = fst x, lviAssignCount = 0, lviIsConstant = False}) startupVarsNames mapM_ parseStartupArg $ zip3 args startupVarsVersions (map (readText . snd) startupVarsNames) where diff --git a/src/NITTA/Intermediate/DataFlow.hs b/src/NITTA/Intermediate/DataFlow.hs index eec23f666..cb5ddf1ac 100644 --- a/src/NITTA/Intermediate/DataFlow.hs +++ b/src/NITTA/Intermediate/DataFlow.hs @@ -41,7 +41,7 @@ instance Eq (DataFlowGraph v x) where (DFLeaf f1) == (DFLeaf f2) = f1 == f2 _ == _ = False -instance (Var v) => Variables (DataFlowGraph v x) v where +instance Var v => Variables (DataFlowGraph v x) v where variables (DFLeaf fb) = variables fb variables (DFCluster g) = unionsMap variables g @@ -71,7 +71,7 @@ instance (Var v, Val x) => OptimizeAccumProblem (DataFlowGraph v x) v x where optimizeAccumDecision dfg ref@OptimizeAccum{} = fsToDataFlowGraph $ optimizeAccumDecision (functions dfg) ref -instance (Var v) => ResolveDeadlockProblem (DataFlowGraph v x) v x where +instance Var v => ResolveDeadlockProblem (DataFlowGraph v x) v x where resolveDeadlockOptions _dfg = [] resolveDeadlockDecision dfg ResolveDeadlock{newBuffer, changeset} = diff --git a/src/NITTA/Intermediate/Functions.hs b/src/NITTA/Intermediate/Functions.hs index 5226a08e6..49575e5af 100644 --- a/src/NITTA/Intermediate/Functions.hs +++ b/src/NITTA/Intermediate/Functions.hs @@ -156,11 +156,11 @@ instance Function (Loop v x) v where isInternalLockPossible _ = True inputs (Loop _ _a b) = variables b outputs (Loop _ a _b) = variables a -instance (Var v) => Patch (Loop v x) (v, v) where +instance Var v => Patch (Loop v x) (v, v) where patch diff (Loop x a b) = Loop x (patch diff a) (patch diff b) -instance (Var v) => Locks (Loop v x) v where +instance Var v => Locks (Loop v x) v where locks (Loop _ (O as) (I b)) = [Lock{locked = b, lockBy = a} | a <- elems as] -instance (Var v) => FunctionSimulation (Loop v x) v x where +instance Var v => FunctionSimulation (Loop v x) v x where simulate CycleCntx{cycleCntx} (Loop (X x) (O vs) (I _)) = case oneOf vs `HM.lookup` cycleCntx of -- if output variables are defined - nothing to do (values thrown on upper level) @@ -170,52 +170,52 @@ instance (Var v) => 
FunctionSimulation (Loop v x) v x where data LoopBegin v x = LoopBegin (Loop v x) (O v) deriving (Typeable, Eq) instance (Var v, Show x) => Show (LoopBegin v x) where show = label -instance (Var v) => Label (LoopBegin v x) where +instance Var v => Label (LoopBegin v x) where label (LoopBegin _ os) = "LoopBegin() = " <> show os -instance (Var v) => Function (LoopBegin v x) v where +instance Var v => Function (LoopBegin v x) v where outputs (LoopBegin _ o) = variables o isInternalLockPossible _ = True -instance (Var v) => Patch (LoopBegin v x) (v, v) where +instance Var v => Patch (LoopBegin v x) (v, v) where patch diff (LoopBegin l a) = LoopBegin (patch diff l) $ patch diff a -instance (Var v) => Locks (LoopBegin v x) v where +instance Var v => Locks (LoopBegin v x) v where locks _ = [] -instance (Var v) => FunctionSimulation (LoopBegin v x) v x where +instance Var v => FunctionSimulation (LoopBegin v x) v x where simulate cntx (LoopBegin l _) = simulate cntx l data LoopEnd v x = LoopEnd (Loop v x) (I v) deriving (Typeable, Eq) instance (Var v, Show x) => Show (LoopEnd v x) where show = label -instance (Var v) => Label (LoopEnd v x) where +instance Var v => Label (LoopEnd v x) where label (LoopEnd (Loop _ os _) i) = "LoopEnd(" <> show i <> ") pair out: " <> show os -instance (Var v) => Function (LoopEnd v x) v where +instance Var v => Function (LoopEnd v x) v where inputs (LoopEnd _ o) = variables o isInternalLockPossible _ = True -instance (Var v) => Patch (LoopEnd v x) (v, v) where +instance Var v => Patch (LoopEnd v x) (v, v) where patch diff (LoopEnd l a) = LoopEnd (patch diff l) $ patch diff a -instance (Var v) => Locks (LoopEnd v x) v where locks (LoopEnd l _) = locks l -instance (Var v) => FunctionSimulation (LoopEnd v x) v x where +instance Var v => Locks (LoopEnd v x) v where locks (LoopEnd l _) = locks l +instance Var v => FunctionSimulation (LoopEnd v x) v x where simulate cntx (LoopEnd l _) = simulate cntx l data Buffer v x = Buffer (I v) (O v) deriving (Typeable, Eq) instance Label (Buffer v x) where label Buffer{} = "buf" -instance (Var v) => Show (Buffer v x) where +instance Var v => Show (Buffer v x) where show (Buffer i os) = "buffer(" <> show i <> ")" <> " = " <> show os buffer :: (Var v, Val x) => v -> [v] -> F v x buffer a b = packF $ Buffer (I a) (O $ fromList b) -instance (Var v) => Function (Buffer v x) v where +instance Var v => Function (Buffer v x) v where inputs (Buffer a _b) = variables a outputs (Buffer _a b) = variables b -instance (Var v) => Patch (Buffer v x) (v, v) where +instance Var v => Patch (Buffer v x) (v, v) where patch diff (Buffer a b) = Buffer (patch diff a) (patch diff b) -instance (Var v) => Locks (Buffer v x) v where +instance Var v => Locks (Buffer v x) v where locks = inputsLockOutputs -instance (Var v) => FunctionSimulation (Buffer v x) v x where +instance Var v => FunctionSimulation (Buffer v x) v x where simulate cntx (Buffer (I a) (O vs)) = [(v, cntx `getCntx` a) | v <- elems vs] data Add v x = Add (I v) (I v) (O v) deriving (Typeable, Eq) instance Label (Add v x) where label Add{} = "+" -instance (Var v) => Show (Add v x) where +instance Var v => Show (Add v x) where show (Add a b c) = let lexp = show a <> " + " <> show b rexp = show c @@ -223,12 +223,12 @@ instance (Var v) => Show (Add v x) where add :: (Var v, Val x) => v -> v -> [v] -> F v x add a b c = packF $ Add (I a) (I b) $ O $ fromList c -instance (Var v) => Function (Add v x) v where +instance Var v => Function (Add v x) v where inputs (Add a b _c) = variables a `union` 
variables b outputs (Add _a _b c) = variables c -instance (Var v) => Patch (Add v x) (v, v) where +instance Var v => Patch (Add v x) (v, v) where patch diff (Add a b c) = Add (patch diff a) (patch diff b) (patch diff c) -instance (Var v) => Locks (Add v x) v where +instance Var v => Locks (Add v x) v where locks = inputsLockOutputs instance (Var v, Num x) => FunctionSimulation (Add v x) v x where simulate cntx (Add (I v1) (I v2) (O vs)) = @@ -239,7 +239,7 @@ instance (Var v, Num x) => FunctionSimulation (Add v x) v x where data Sub v x = Sub (I v) (I v) (O v) deriving (Typeable, Eq) instance Label (Sub v x) where label Sub{} = "-" -instance (Var v) => Show (Sub v x) where +instance Var v => Show (Sub v x) where show (Sub a b c) = let lexp = show a <> " - " <> show b rexp = show c @@ -247,12 +247,12 @@ instance (Var v) => Show (Sub v x) where sub :: (Var v, Val x) => v -> v -> [v] -> F v x sub a b c = packF $ Sub (I a) (I b) $ O $ fromList c -instance (Var v) => Function (Sub v x) v where +instance Var v => Function (Sub v x) v where inputs (Sub a b _c) = variables a `union` variables b outputs (Sub _a _b c) = variables c -instance (Var v) => Patch (Sub v x) (v, v) where +instance Var v => Patch (Sub v x) (v, v) where patch diff (Sub a b c) = Sub (patch diff a) (patch diff b) (patch diff c) -instance (Var v) => Locks (Sub v x) v where +instance Var v => Locks (Sub v x) v where locks = inputsLockOutputs instance (Var v, Num x) => FunctionSimulation (Sub v x) v x where simulate cntx (Sub (I v1) (I v2) (O vs)) = @@ -263,18 +263,18 @@ instance (Var v, Num x) => FunctionSimulation (Sub v x) v x where data Multiply v x = Multiply (I v) (I v) (O v) deriving (Typeable, Eq) instance Label (Multiply v x) where label Multiply{} = "*" -instance (Var v) => Show (Multiply v x) where +instance Var v => Show (Multiply v x) where show (Multiply a b c) = show a <> " * " <> show b <> " = " <> show c multiply :: (Var v, Val x) => v -> v -> [v] -> F v x multiply a b c = packF $ Multiply (I a) (I b) $ O $ fromList c -instance (Var v) => Function (Multiply v x) v where +instance Var v => Function (Multiply v x) v where inputs (Multiply a b _c) = variables a `union` variables b outputs (Multiply _a _b c) = variables c -instance (Var v) => Patch (Multiply v x) (v, v) where +instance Var v => Patch (Multiply v x) (v, v) where patch diff (Multiply a b c) = Multiply (patch diff a) (patch diff b) (patch diff c) -instance (Var v) => Locks (Multiply v x) v where +instance Var v => Locks (Multiply v x) v where locks = inputsLockOutputs instance (Var v, Num x) => FunctionSimulation (Multiply v x) v x where simulate cntx (Multiply (I v1) (I v2) (O vs)) = @@ -289,7 +289,7 @@ data Division v x = Division } deriving (Typeable, Eq) instance Label (Division v x) where label Division{} = "/" -instance (Var v) => Show (Division v x) where +instance Var v => Show (Division v x) where show Division{denom, numer, quotient, remain} = let q = show numer <> " / " <> show denom <> " = " <> show quotient r = show numer <> " mod " <> show denom <> " = " <> show remain @@ -304,12 +304,12 @@ division d n q r = , remain = O $ fromList r } -instance (Var v) => Function (Division v x) v where +instance Var v => Function (Division v x) v where inputs Division{denom, numer} = variables denom `union` variables numer outputs Division{quotient, remain} = variables quotient `union` variables remain -instance (Var v) => Patch (Division v x) (v, v) where +instance Var v => Patch (Division v x) (v, v) where patch diff (Division a b c d) = Division 
(patch diff a) (patch diff b) (patch diff c) (patch diff d) -instance (Var v) => Locks (Division v x) v where +instance Var v => Locks (Division v x) v where locks = inputsLockOutputs instance (Var v, Integral x) => FunctionSimulation (Division v x) v x where simulate cntx Division{denom = I d, numer = I n, quotient = O qs, remain = O rs} = @@ -320,18 +320,18 @@ instance (Var v, Integral x) => FunctionSimulation (Division v x) v x where data Neg v x = Neg (I v) (O v) deriving (Typeable, Eq) instance Label (Neg v x) where label Neg{} = "neg" -instance (Var v) => Show (Neg v x) where +instance Var v => Show (Neg v x) where show (Neg i o) = "-" <> show i <> " = " <> show o neg :: (Var v, Val x) => v -> [v] -> F v x neg i o = packF $ Neg (I i) $ O $ fromList o -instance (Ord v) => Function (Neg v x) v where +instance Ord v => Function (Neg v x) v where inputs (Neg i _) = variables i outputs (Neg _ o) = variables o -instance (Ord v) => Patch (Neg v x) (v, v) where +instance Ord v => Patch (Neg v x) (v, v) where patch diff (Neg i o) = Neg (patch diff i) (patch diff o) -instance (Var v) => Locks (Neg v x) v where +instance Var v => Locks (Neg v x) v where locks = inputsLockOutputs instance (Var v, Num x) => FunctionSimulation (Neg v x) v x where simulate cntx (Neg (I i) (O o)) = @@ -340,7 +340,7 @@ instance (Var v, Num x) => FunctionSimulation (Neg v x) v x where in [(v, y) | v <- elems o] data Constant v x = Constant (X x) (O v) deriving (Typeable, Eq) -instance (Show x) => Label (Constant v x) where label (Constant (X x) _) = show x +instance Show x => Label (Constant v x) where label (Constant (X x) _) = show x instance (Var v, Show x) => Show (Constant v x) where show (Constant (X x) os) = "const(" <> show x <> ") = " <> show os constant :: (Var v, Val x) => x -> [v] -> F v x @@ -351,9 +351,9 @@ isConst f instance (Show x, Eq x, Typeable x) => Function (Constant v x) v where outputs (Constant _ o) = variables o -instance (Var v) => Patch (Constant v x) (v, v) where +instance Var v => Patch (Constant v x) (v, v) where patch diff (Constant x a) = Constant x (patch diff a) -instance (Var v) => Locks (Constant v x) v where locks _ = [] +instance Var v => Locks (Constant v x) v where locks _ = [] instance FunctionSimulation (Constant v x) v x where simulate _cntx (Constant (X x) (O vs)) = [(v, x) | v <- elems vs] @@ -365,25 +365,25 @@ data ShiftLR v x | ShiftR Int (I v) (O v) deriving (Typeable, Eq) -instance (Var v) => Show (ShiftLR v x) where +instance Var v => Show (ShiftLR v x) where show (ShiftL s i os) = show i <> " << " <> show s <> " = " <> show os show (ShiftR s i os) = show i <> " >> " <> show s <> " = " <> show os -instance (Var v) => Label (ShiftLR v x) where label = show +instance Var v => Label (ShiftLR v x) where label = show shiftL :: (Var v, Val x) => Int -> v -> [v] -> F v x shiftL s i o = packF $ ShiftL s (I i) $ O $ fromList o shiftR :: (Var v, Val x) => Int -> v -> [v] -> F v x shiftR s i o = packF $ ShiftR s (I i) $ O $ fromList o -instance (Var v) => Function (ShiftLR v x) v where +instance Var v => Function (ShiftLR v x) v where inputs (ShiftL _ i _) = variables i inputs (ShiftR _ i _) = variables i outputs (ShiftL _ _ o) = variables o outputs (ShiftR _ _ o) = variables o -instance (Var v) => Patch (ShiftLR v x) (v, v) where +instance Var v => Patch (ShiftLR v x) (v, v) where patch diff (ShiftL s i o) = ShiftL s (patch diff i) (patch diff o) patch diff (ShiftR s i o) = ShiftR s (patch diff i) (patch diff o) -instance (Var v) => Locks (ShiftLR v x) v where +instance Var v => 
Locks (ShiftLR v x) v where locks = inputsLockOutputs instance (Var v, B.Bits x) => FunctionSimulation (ShiftLR v x) v x where simulate cntx (ShiftL s (I i) (O os)) = do @@ -392,30 +392,30 @@ instance (Var v, B.Bits x) => FunctionSimulation (ShiftLR v x) v x where [(o, getCntx cntx i `B.shiftR` s) | o <- elems os] newtype Send v x = Send (I v) deriving (Typeable, Eq) -instance (Var v) => Show (Send v x) where +instance Var v => Show (Send v x) where show (Send i) = "send(" <> show i <> ")" instance Label (Send v x) where label Send{} = "send" send :: (Var v, Val x) => v -> F v x send a = packF $ Send $ I a -instance (Var v) => Function (Send v x) v where +instance Var v => Function (Send v x) v where inputs (Send i) = variables i -instance (Var v) => Patch (Send v x) (v, v) where +instance Var v => Patch (Send v x) (v, v) where patch diff (Send a) = Send (patch diff a) -instance (Var v) => Locks (Send v x) v where locks _ = [] +instance Var v => Locks (Send v x) v where locks _ = [] instance FunctionSimulation (Send v x) v x where simulate _cntx Send{} = [] newtype Receive v x = Receive (O v) deriving (Typeable, Eq) -instance (Var v) => Show (Receive v x) where +instance Var v => Show (Receive v x) where show (Receive os) = "receive() = " <> show os instance Label (Receive v x) where label Receive{} = "receive" receive :: (Var v, Val x) => [v] -> F v x receive a = packF $ Receive $ O $ fromList a -instance (Var v) => Function (Receive v x) v where +instance Var v => Function (Receive v x) v where outputs (Receive o) = variables o -instance (Var v) => Patch (Receive v x) (v, v) where +instance Var v => Patch (Receive v x) (v, v) where patch diff (Receive a) = Receive (patch diff a) -instance (Var v) => Locks (Receive v x) v where locks _ = [] +instance Var v => Locks (Receive v x) v where locks _ = [] instance (Var v, Val x) => FunctionSimulation (Receive v x) v x where simulate CycleCntx{cycleCntx} (Receive (O vs)) = case oneOf vs `HM.lookup` cycleCntx of @@ -428,17 +428,17 @@ instance (Var v, Val x) => FunctionSimulation (Receive v x) v x where data BrokenBuffer v x = BrokenBuffer (I v) (O v) deriving (Typeable, Eq) instance Label (BrokenBuffer v x) where label BrokenBuffer{} = "broken" -instance (Var v) => Show (BrokenBuffer v x) where +instance Var v => Show (BrokenBuffer v x) where show (BrokenBuffer i os) = "brokenBuffer(" <> show i <> ")" <> " = " <> show os brokenBuffer :: (Var v, Val x) => v -> [v] -> F v x brokenBuffer a b = packF $ BrokenBuffer (I a) (O $ fromList b) -instance (Var v) => Function (BrokenBuffer v x) v where +instance Var v => Function (BrokenBuffer v x) v where inputs (BrokenBuffer a _b) = variables a outputs (BrokenBuffer _a b) = variables b -instance (Var v) => Patch (BrokenBuffer v x) (v, v) where +instance Var v => Patch (BrokenBuffer v x) (v, v) where patch diff (BrokenBuffer a b) = BrokenBuffer (patch diff a) (patch diff b) -instance (Var v) => Locks (BrokenBuffer v x) v where +instance Var v => Locks (BrokenBuffer v x) v where locks = inputsLockOutputs -instance (Var v) => FunctionSimulation (BrokenBuffer v x) v x where +instance Var v => FunctionSimulation (BrokenBuffer v x) v x where simulate cntx (BrokenBuffer (I a) (O vs)) = [(v, cntx `getCntx` a) | v <- elems vs] diff --git a/src/NITTA/Intermediate/Functions/Accum.hs b/src/NITTA/Intermediate/Functions/Accum.hs index 193650b5a..d880fa69b 100644 --- a/src/NITTA/Intermediate/Functions/Accum.hs +++ b/src/NITTA/Intermediate/Functions/Accum.hs @@ -43,7 +43,7 @@ instance Show Sign where data Action v = Push 
Sign (I v) | Pull (O v) deriving (Typeable, Eq) -instance (Var v) => Show (Action v) where +instance Var v => Show (Action v) where show (Push s (I v)) = show s <> toString v show (Pull (O vs)) = S.join " " $ map ("= " <>) $ vsToStringList vs @@ -53,7 +53,7 @@ instance Variables (Action v) v where newtype Acc v x = Acc {actions :: [Action v]} deriving (Typeable, Eq) -instance (Var v) => Show (Acc v x) where +instance Var v => Show (Acc v x) where show (Acc acts) = let lastElement = last acts initElements = init acts @@ -82,11 +82,11 @@ fromPush _ = error "Error in fromPush function in acc" fromPull (Pull (O vs)) = vs fromPull _ = error "Error in fromPull function in acc" -instance (Ord v) => Function (Acc v x) v where +instance Ord v => Function (Acc v x) v where inputs (Acc lst) = S.fromList $ map fromPush $ filter isPush lst outputs (Acc lst) = unionsMap fromPull $ filter isPull lst -instance (Ord v) => Patch (Acc v x) (v, v) where +instance Ord v => Patch (Acc v x) (v, v) where patch diff (Acc lst) = Acc $ nub $ @@ -108,15 +108,30 @@ toBlocksSplit exprInput = in splitBySemicolon $ matchAll exprPattern filtered [] accGen blocks = - let partedExpr = map (partition (\(x : _) -> x /= '=')) + let partedExpr = + map + ( partition $ \case + (x : _) -> x /= '=' + x -> error $ "error in accGen: " <> show x + ) signPush ('+' : name) = Push Plus (I $ T.pack name) signPush ('-' : name) = Push Minus (I $ T.pack name) signPush _ = error "Error in matching + and -" pushCreate lst = map signPush lst - pullCreate lst = Pull $ O $ S.fromList $ foldl (\buff (_ : name) -> T.pack name : buff) [] lst + pullCreate lst = + Pull $ + O $ + S.fromList $ + foldl + ( \buff -> \case + (_ : name) -> T.pack name : buff + _ -> error "accGen internal error" + ) + [] + lst in Acc $ concatMap (\(push, pull) -> pushCreate push ++ [pullCreate pull]) $ partedExpr blocks -instance (Var v) => Locks (Acc v x) v where +instance Var v => Locks (Acc v x) v where locks (Acc actions) = let (lockByActions, lockedActions) = span isPush actions in [ Lock{locked, lockBy} diff --git a/src/NITTA/Intermediate/Types.hs b/src/NITTA/Intermediate/Types.hs index 36002b151..d4af7217d 100644 --- a/src/NITTA/Intermediate/Types.hs +++ b/src/NITTA/Intermediate/Types.hs @@ -82,9 +82,9 @@ import Text.PrettyPrint.Boxes hiding ((<>)) newtype I v = I v deriving (Eq, Ord) -instance (ToString v) => Show (I v) where show (I v) = toString v +instance ToString v => Show (I v) where show (I v) = toString v -instance (Eq v) => Patch (I v) (v, v) where +instance Eq v => Patch (I v) (v, v) where patch (v, v') i@(I v0) | v0 == v = I v' | otherwise = i @@ -96,10 +96,10 @@ instance Variables (I v) v where newtype O v = O (S.Set v) deriving (Eq, Ord) -instance (Ord v) => Patch (O v) (v, v) where +instance Ord v => Patch (O v) (v, v) where patch (v, v') (O vs) = O $ S.fromList $ map (\e -> if e == v then v' else e) $ S.elems vs -instance (ToString v) => Show (O v) where +instance ToString v => Show (O v) where show (O vs) | S.null vs = "_" | otherwise = S.join " = " $ vsToStringList vs @@ -119,7 +119,7 @@ For example: > c := a + b > [ Lock{ locked=c, lockBy=a }, Lock{ locked=c, lockBy=b } ] -} -class (Var v) => Locks x v | x -> v where +class Var v => Locks x v | x -> v where locks :: x -> [Lock v] -- | Variable casuality. 
@@ -129,11 +129,11 @@ data Lock v = Lock } deriving (Eq, Ord, Generic) -instance (ToString v) => Show (Lock v) where +instance ToString v => Show (Lock v) where show Lock{locked, lockBy} = "Lock{locked=" <> toString locked <> ", lockBy=" <> toString lockBy <> "}" -instance (ToJSON v) => ToJSON (Lock v) +instance ToJSON v => ToJSON (Lock v) -- | All input variables locks all output variables. inputsLockOutputs f = @@ -210,7 +210,7 @@ instance FunctionSimulation (F v x) v x where instance Label (F v x) where label F{fun} = label fun -instance (Var v) => Locks (F v x) v where +instance Var v => Locks (F v x) v where locks F{fun} = locks fun instance Ord (F v x) where @@ -223,7 +223,7 @@ instance Patch (F v x) (v, v) where , funHistory = fun0 : funHistory } -instance (Ord v) => Patch (F v x) (Changeset v) where +instance Ord v => Patch (F v x) (Changeset v) where patch Changeset{changeI, changeO} f0 = let changeI' = mapMaybe @@ -244,13 +244,13 @@ instance (Ord v) => Patch (F v x) (Changeset v) where $ outputs f0 in foldl (\f diff -> patch diff f) f0 $ changeI' ++ changeO' -instance (Patch b v) => Patch [b] v where +instance Patch b v => Patch [b] v where patch diff fs = map (patch diff) fs instance Show (F v x) where show F{fun} = show fun -instance (Var v) => Variables (F v x) v where +instance Var v => Variables (F v x) v where variables F{fun} = inputs fun `S.union` outputs fun -- | Helper for extraction function from existential container 'F'. @@ -301,7 +301,7 @@ data Cntx v x = Cntx , cntxCycleNumber :: Int } -instance (Show x) => Show (Cntx String x) where +instance Show x => Show (Cntx String x) where show Cntx{cntxProcess} = log2md $ map (HM.map show . cycleCntx) cntxProcess log2list cntxProcess0 = @@ -325,7 +325,9 @@ log2md records = let n = length records cntx2listCycle = ("Cycle" : map show [1 .. n]) : log2list records maxLength t = length $ foldr1 (\x y -> if length x >= length y then x else y) t - cycleFormattedTable = map ((\x@(x1 : x2 : xs) -> x1 : ("|:" ++ replicate (maxLength x) '-') : x2 : xs) . map ("| " ++)) cntx2listCycle ++ [replicate (n + 2) "|"] + formatCell x@(x1 : x2 : xs) = x1 : ("|:" ++ replicate (maxLength x) '-') : x2 : xs + formatCell x = error $ "formatCell: unexpected sequence:" <> show x + cycleFormattedTable = map (formatCell . map ("| " ++)) cntx2listCycle ++ [replicate (n + 2) "|"] in render ( hsep 0 left $ map (vcat left . 
map text) cycleFormattedTable @@ -347,8 +349,11 @@ log2md records = ] -} log2json records = - let listHashMap = transpose $ map (\(k : vs) -> map (\v -> (k, read v :: Double)) vs) $ log2list records + let listHashMap = transpose $ map varAndValues $ log2list records in encodePretty $ map HM.fromList listHashMap + where + varAndValues (k : vs) = map (\v -> (k, read v :: Double)) vs + varAndValues x = error $ "varAndValues: unexpected sequence:" <> show x {- | >>> import qualified Data.ByteString.Lazy.Char8 as BS @@ -369,7 +374,7 @@ instance Default (Cntx v x) where } -- | Make sequence of received values '[ Map v x ]' -cntxReceivedBySlice :: (Ord v) => Cntx v x -> [M.Map v x] +cntxReceivedBySlice :: Ord v => Cntx v x -> [M.Map v x] cntxReceivedBySlice Cntx{cntxReceived} = cntxReceivedBySlice' $ M.assocs cntxReceived cntxReceivedBySlice' received @@ -410,7 +415,7 @@ data Changeset v = Changeset } deriving (Eq) -instance (Var v) => Show (Changeset v) where +instance Var v => Show (Changeset v) where show Changeset{changeI, changeO} = let changeI' = S.join ", " $ map (\(a, b) -> "(" <> toString a <> ", " <> toString b <> ")") $ M.assocs changeI changeO' = S.join ", " $ map (\(a, bs) -> "(" <> toString a <> ", [" <> S.join ", " (vsToStringList bs) <> "])") $ M.assocs changeO diff --git a/src/NITTA/Intermediate/Value.hs b/src/NITTA/Intermediate/Value.hs index 5ca71753a..0d96aef6b 100644 --- a/src/NITTA/Intermediate/Value.hs +++ b/src/NITTA/Intermediate/Value.hs @@ -157,7 +157,7 @@ valueMask :: Val x => x -> x valueMask x = fromRaw (setBit (0 :: Integer) (dataWidth x - 1) - 1) 0 -- TODO: try to avoid this class -class (Default x) => DefaultX u x | u -> x where +class Default x => DefaultX u x | u -> x where defX :: u -> x defX _ = def @@ -173,7 +173,7 @@ scalingFactor x = 2 ** fromIntegral (scalingFactorPower x) -- | All values with attributes. data Attr x = Attr {value :: x, invalid :: Bool} deriving (Eq, Ord) -instance (Validity x) => Validity (Attr x) where +instance Validity x => Validity (Attr x) where validate Attr{value} = validate value setInvalidAttr Attr{value, invalid} = Attr{value, invalid = invalid || isInvalid value} @@ -188,19 +188,19 @@ instance Applicative Attr where let value = f x y in Attr{value, invalid = x' || y'} -instance (Show x) => Show (Attr x) where +instance Show x => Show (Attr x) where show Attr{invalid = True} = "NaN" show Attr{value, invalid = False} = show value -instance (Read x) => Read (Attr x) where +instance Read x => Read (Attr x) where readsPrec d r = case readsPrec d r of [(x, r')] -> [(pure x, r')] _ -> error $ "can not read IntX from: " ++ r -instance (PrintfArg x) => PrintfArg (Attr x) where +instance PrintfArg x => PrintfArg (Attr x) where formatArg Attr{value} = formatArg value -instance (Default x) => Default (Attr x) where +instance Default x => Default (Attr x) where def = pure def instance (Enum x, Validity x) => Enum (Attr x) where @@ -215,10 +215,10 @@ instance (Num x, Validity x) => Num (Attr x) where fromInteger = setInvalidAttr . pure . fromInteger negate = setInvalidAttr . 
fmap negate -instance (Ord x, Real x, Validity x) => Real (Attr x) where +instance (Real x, Validity x) => Real (Attr x) where toRational Attr{value} = toRational value -instance (Integral x, Validity x, Val x) => Integral (Attr x) where +instance Val x => Integral (Attr x) where toInteger Attr{value} = toInteger value Attr{value = a} `quotRem` Attr{value = b} = let (minB, maxB) = minMaxRaw' (dataWidth b `shiftR` 1) @@ -242,7 +242,7 @@ instance (Bits x, Validity x) => Bits (Attr x) where bit ix = pure $ bit ix popCount Attr{value} = popCount value -instance (Val x, Integral x) => Val (Attr x) where +instance Val x => Val (Attr x) where dataWidth Attr{value} = dataWidth value rawData Attr{value} = rawData value @@ -259,11 +259,11 @@ instance (Val x, Integral x) => Val (Attr x) where verilogHelper Attr{value} = verilogHelper value verilogAssertRE Attr{value} = verilogAssertRE value -instance (FixedPointCompatible x) => FixedPointCompatible (Attr x) where +instance FixedPointCompatible x => FixedPointCompatible (Attr x) where scalingFactorPower Attr{value} = scalingFactorPower value fractionalBitSize Attr{value} = fractionalBitSize value -instance (ToJSON x) => ToJSON (Attr x) where +instance ToJSON x => ToJSON (Attr x) where toJSON Attr{value} = toJSON value -- * Integer @@ -285,7 +285,7 @@ instance Val Int where newtype IntX (w :: Nat) = IntX {intX :: Integer} deriving (Show, Eq, Ord) -instance (KnownNat m) => Validity (IntX m) where +instance KnownNat m => Validity (IntX m) where validate x@(IntX raw) = let (minRaw, maxRaw) = minMaxRaw x in check (minRaw <= raw && raw <= maxRaw) "value is not out of range" @@ -322,7 +322,7 @@ instance Integral (IntX w) where let (a', b') = a `quotRem` b in (IntX a', IntX b') -instance (KnownNat w) => Bits (IntX w) where +instance KnownNat w => Bits (IntX w) where (IntX a) .&. (IntX b) = IntX (a .&. b) (IntX a) .|. (IntX b) = IntX (a .|. b) (IntX a) `xor` (IntX b) = IntX (a `xor` b) @@ -337,7 +337,7 @@ instance (KnownNat w) => Bits (IntX w) where bit ix = IntX $ bit ix popCount (IntX a) = popCount a -instance (KnownNat w) => Val (IntX w) where +instance KnownNat w => Val (IntX w) where dataWidth _ = fromInteger $ natVal (Proxy :: Proxy w) rawData (IntX x) = fromIntegral x @@ -375,7 +375,9 @@ instance (KnownNat b, KnownNat m) => Validity (FX m b) where instance (KnownNat m, KnownNat b) => Read (FX m b) where readsPrec d r = - let [(x :: Double, "")] = readsPrec d r + let x = case readsPrec d r of + [(x' :: Double, "")] -> x' + _ -> error $ "can not read FX from: " ++ r result = FX $ round (x * scalingFactor result) in [(result, "")] diff --git a/src/NITTA/Model/Microarchitecture/Types.hs b/src/NITTA/Model/Microarchitecture/Types.hs index 1df23d7e8..c9448a317 100644 --- a/src/NITTA/Model/Microarchitecture/Types.hs +++ b/src/NITTA/Model/Microarchitecture/Types.hs @@ -31,7 +31,7 @@ data MicroarchitectureDesc tag = MicroarchitectureDesc } deriving (Generic) -instance (ToJSON tag) => ToJSON (MicroarchitectureDesc tag) +instance ToJSON tag => ToJSON (MicroarchitectureDesc tag) data NetworkDesc tag = NetworkDesc { networkTag :: tag @@ -40,7 +40,7 @@ data NetworkDesc tag = NetworkDesc } deriving (Generic) -instance (ToJSON tag) => ToJSON (NetworkDesc tag) +instance ToJSON tag => ToJSON (NetworkDesc tag) data UnitDesc tag = UnitDesc { unitTag :: tag @@ -48,9 +48,9 @@ data UnitDesc tag = UnitDesc } deriving (Generic) -instance (ToJSON tag) => ToJSON (UnitDesc tag) +instance ToJSON tag => ToJSON (UnitDesc tag) -microarchitectureDesc :: forall tag v x t. 
(Typeable x) => BusNetwork tag v x t -> MicroarchitectureDesc tag +microarchitectureDesc :: forall tag v x t. Typeable x => BusNetwork tag v x t -> MicroarchitectureDesc tag microarchitectureDesc BusNetwork{bnName, bnPus, ioSync} = MicroarchitectureDesc { networks = diff --git a/src/NITTA/Model/Networks/Bus.hs b/src/NITTA/Model/Networks/Bus.hs index 7d80b7df4..ccbedba98 100644 --- a/src/NITTA/Model/Networks/Bus.hs +++ b/src/NITTA/Model/Networks/Bus.hs @@ -94,14 +94,14 @@ busNetwork name iosync = instance (Default t, IsString tag) => Default (BusNetwork tag v x t) where def = busNetwork "defaultBus" ASync -instance (Var v) => Variables (BusNetwork tag v x t) v where +instance Var v => Variables (BusNetwork tag v x t) v where variables BusNetwork{bnBinded} = unionsMap variables $ concat $ M.elems bnBinded bindedFunctions puTitle BusNetwork{bnBinded} | puTitle `M.member` bnBinded = bnBinded M.! puTitle | otherwise = [] -instance (Default x) => DefaultX (BusNetwork tag v x t) x +instance Default x => DefaultX (BusNetwork tag v x t) x instance WithFunctions (BusNetwork tag v x t) (F v x) where functions BusNetwork{bnRemains, bnBinded} = bnRemains ++ concat (M.elems bnBinded) @@ -232,13 +232,15 @@ instance (UnitTag tag, VarValTime v x t) => ProcessorUnit (BusNetwork tag v x t) -- Vertical relations between FB and Transport mapM_ - ( \Step{pID, pDesc = NestedStep{nStep = Step{pDesc = IntermediateStep f}}} -> - mapM_ - ( \v -> - when (v `M.member` v2transportStepKey) $ - establishVerticalRelations [pID] [v2transportStepKey M.! v] - ) - $ variables f + ( \case + Step{pID, pDesc = NestedStep{nStep = Step{pDesc = IntermediateStep f}}} -> + mapM_ + ( \v -> + when (v `M.member` v2transportStepKey) $ + establishVerticalRelations [pID] [v2transportStepKey M.! v] + ) + $ variables f + _ -> error "Bus: process: internal error" ) $ filter isIntermediate steps in wholeProcess @@ -321,7 +323,7 @@ instance (UnitTag tag, VarValTime v x t) => BreakLoopProblem (BusNetwork tag v x breakLoopOptions BusNetwork{bnPus} = concatMap breakLoopOptions $ M.elems bnPus breakLoopDecision bn@BusNetwork{bnBinded, bnPus} bl@BreakLoop{} = - let Just (puTag, bindedToPU) = L.find (elem (recLoop bl) . snd) $ M.assocs bnBinded + let (puTag, bindedToPU) = fromJust $ L.find (elem (recLoop bl) . 
snd) $ M.assocs bnBinded bindedToPU' = recLoopIn bl : recLoopOut bl : (bindedToPU L.\\ [recLoop bl]) in bn { bnPus = M.adjust (`breakLoopDecision` bl) puTag bnPus @@ -407,9 +409,10 @@ instance (UnitTag tag, VarValTime v x t) => ResolveDeadlockProblem (BusNetwork t resolveDeadlockDecision bn@BusNetwork{bnRemains, bnBinded, bnPus, bnProcess} ref@ResolveDeadlock{newBuffer, changeset} = - let Just (tag, _) = - L.find - (\(_, f) -> not $ null $ S.intersection (outputs newBuffer) $ unionsMap outputs f) + let (tag, _) = + fromJust + $ L.find + (\(_, f) -> not $ null $ S.intersection (outputs newBuffer) $ unionsMap outputs f) $ M.assocs bnBinded in bn { bnRemains = newBuffer : patch changeset bnRemains @@ -419,7 +422,7 @@ instance (UnitTag tag, VarValTime v x t) => ResolveDeadlockProblem (BusNetwork t scheduleRefactoring (I.singleton $ nextTick bn) ref } -instance (UnitTag tag) => AllocationProblem (BusNetwork tag v x t) tag where +instance UnitTag tag => AllocationProblem (BusNetwork tag v x t) tag where allocationOptions BusNetwork{bnName, bnRemains, bnPUPrototypes} = map toOptions $ M.keys $ M.filter (\PUPrototype{pProto} -> any (`allowToProcess` pProto) bnRemains) bnPUPrototypes where @@ -574,7 +577,7 @@ instance (UnitTag tag, VarValTime v x t) => TargetSystemComponent (BusNetwork ta values (BusNetworkMC arr) = reverse $ map snd $ - L.sortOn ((\(Just [ix]) -> read ix :: Int) . matchRegex (mkRegex "([[:digit:]]+)") . T.unpack . signalTag . fst) $ + L.sortOn ((\ix -> read ix :: Int) . head . fromJust . matchRegex (mkRegex "([[:digit:]]+)") . T.unpack . signalTag . fst) $ M.assocs arr hardwareInstance tag BusNetwork{} UnitEnv{sigRst, sigClk, ioPorts = Just ioPorts} diff --git a/src/NITTA/Model/Networks/Types.hs b/src/NITTA/Model/Networks/Types.hs index 06b6cfcc4..32000a304 100644 --- a/src/NITTA/Model/Networks/Types.hs +++ b/src/NITTA/Model/Networks/Types.hs @@ -55,14 +55,14 @@ type PUClasses pu v x t = -- | Existential container for a processor unit . 
data PU v x t where PU :: - (PUClasses pu v x t) => + PUClasses pu v x t => { unit :: pu , diff :: Changeset v , uEnv :: UnitEnv pu } -> PU v x t -instance (Ord v) => EndpointProblem (PU v x t) v t where +instance Ord v => EndpointProblem (PU v x t) v t where endpointOptions PU{diff, unit} = map (patch diff) $ endpointOptions unit @@ -88,7 +88,7 @@ instance ResolveDeadlockProblem (PU v x t) v x where resolveDeadlockDecision PU{diff, unit, uEnv} d = PU{unit = resolveDeadlockDecision unit d, diff, uEnv} -instance (VarValTime v x t) => ProcessorUnit (PU v x t) v x t where +instance VarValTime v x t => ProcessorUnit (PU v x t) v x t where tryBind fb PU{diff, unit, uEnv} = case tryBind fb unit of Right unit' -> Right PU{unit = unit', diff, uEnv} @@ -98,7 +98,7 @@ instance (VarValTime v x t) => ProcessorUnit (PU v x t) v x t where in p{steps = map (patch diff) $ steps p} parallelismType PU{unit} = parallelismType unit -instance (Ord v) => Patch (PU v x t) (Changeset v) where +instance Ord v => Patch (PU v x t) (Changeset v) where patch diff' PU{unit, diff, uEnv} = PU { unit @@ -110,10 +110,10 @@ instance (Ord v) => Patch (PU v x t) (Changeset v) where , uEnv } -instance (Ord v) => Patch (PU v x t) (I v, I v) where +instance Ord v => Patch (PU v x t) (I v, I v) where patch (I v, I v') pu@PU{diff = diff@Changeset{changeI}} = pu{diff = diff{changeI = M.insert v v' changeI}} -instance (Ord v) => Patch (PU v x t) (O v, O v) where +instance Ord v => Patch (PU v x t) (O v, O v) where patch (O vs, O vs') pu@PU{diff = diff@Changeset{changeO}} = pu { diff = @@ -126,7 +126,7 @@ instance (Ord v) => Patch (PU v x t) (O v, O v) where } } -instance (Var v) => Locks (PU v x t) v where +instance Var v => Locks (PU v x t) v where locks PU{unit, diff = diff@Changeset{changeI, changeO}} | not $ M.null changeI = error $ "Locks (PU v x t) with non empty changeI: " <> show diff | otherwise = diff --git a/src/NITTA/Model/Problems/Allocation.hs b/src/NITTA/Model/Problems/Allocation.hs index 9ac355c73..c5e6fbd63 100644 --- a/src/NITTA/Model/Problems/Allocation.hs +++ b/src/NITTA/Model/Problems/Allocation.hs @@ -27,7 +27,7 @@ data Allocation tag = Allocation } deriving (Generic, Eq) -instance (ToString tag) => Show (Allocation tag) where +instance ToString tag => Show (Allocation tag) where show Allocation{networkTag, processUnitTag} = "Allocation of " <> toString processUnitTag <> " on " <> toString networkTag class AllocationProblem u tag | u -> tag where diff --git a/src/NITTA/Model/Problems/Bind.hs b/src/NITTA/Model/Problems/Bind.hs index a3f940e8f..f5a734bd6 100644 --- a/src/NITTA/Model/Problems/Bind.hs +++ b/src/NITTA/Model/Problems/Bind.hs @@ -22,12 +22,12 @@ data Bind tag v x = Bind (F v x) tag deriving (Generic, Eq) -instance (ToString tag) => Show (Bind tag v x) where +instance ToString tag => Show (Bind tag v x) where show (Bind f tag) = "Bind " <> show f <> " " <> toString tag class BindProblem u tag v x | u -> tag v x where bindOptions :: u -> [Bind tag v x] bindDecision :: u -> Bind tag v x -> u -instance (Var v) => Variables (Bind tab v x) v where +instance Var v => Variables (Bind tab v x) v where variables (Bind f _tag) = variables f diff --git a/src/NITTA/Model/Problems/Dataflow.hs b/src/NITTA/Model/Problems/Dataflow.hs index 340ccf59d..ff46fce60 100644 --- a/src/NITTA/Model/Problems/Dataflow.hs +++ b/src/NITTA/Model/Problems/Dataflow.hs @@ -44,7 +44,7 @@ instance (ToString tag, Show (EndpointSt v tp)) => Show (DataflowSt tag v tp) wh where show' (tag, ep) = "(" <> toString tag <> ", " <> show ep <> ")" 
-instance (Ord v) => Variables (DataflowSt tag v tp) v where +instance Ord v => Variables (DataflowSt tag v tp) v where variables DataflowSt{dfTargets} = unionsMap (variables . snd) dfTargets {- | Implemented for any things, which can send data between processor units over @@ -55,7 +55,7 @@ class DataflowProblem u tag v t | u -> tag v t where dataflowDecision :: u -> DataflowSt tag v (Interval t) -> u -- | Convert dataflow option to decision. -dataflowOption2decision :: (Time t) => DataflowSt tag v (TimeConstraint t) -> DataflowSt tag v (Interval t) +dataflowOption2decision :: Time t => DataflowSt tag v (TimeConstraint t) -> DataflowSt tag v (Interval t) dataflowOption2decision (DataflowSt (srcTag, srcEp) trgs) = let targetsAt = map (epAt . snd) trgs diff --git a/src/NITTA/Model/Problems/Endpoint.hs b/src/NITTA/Model/Problems/Endpoint.hs index e54a01ed9..e7a0a1943 100644 --- a/src/NITTA/Model/Problems/Endpoint.hs +++ b/src/NITTA/Model/Problems/Endpoint.hs @@ -51,7 +51,7 @@ instance (ToString v, Time t) => Show (EndpointSt v (TimeConstraint t)) where instance (ToString v, Time t) => Show (EndpointSt v (Interval t)) where show EndpointSt{epRole, epAt} = "!" <> show epRole <> "@(" <> show epAt <> ")" -instance (Ord v) => Patch (EndpointSt v tp) (Changeset v) where +instance Ord v => Patch (EndpointSt v tp) (Changeset v) where patch diff ep@EndpointSt{epRole} = ep{epRole = patch diff epRole} instance (ToJSON v, ToJSON tp) => ToJSON (EndpointSt v tp) @@ -76,11 +76,11 @@ data EndpointRole v Target v deriving (Eq, Ord, Generic) -instance (ToString v) => Show (EndpointRole v) where +instance ToString v => Show (EndpointRole v) where show (Source vs) = "Source " <> S.join "," (vsToStringList vs) show (Target v) = "Target " <> toString v -instance (Ord v) => Patch (EndpointRole v) (Changeset v) where +instance Ord v => Patch (EndpointRole v) (Changeset v) where patch Changeset{changeI} (Target v) = Target $ fromMaybe v $ changeI M.!? v patch Changeset{changeO} (Source vs) = Source $ S.unions $ map (\v -> fromMaybe (S.singleton v) $ changeO M.!? v) $ S.elems vs @@ -89,7 +89,7 @@ instance Variables (EndpointRole v) v where variables (Source vs) = vs variables (Target v) = S.singleton v -instance (ToJSON v) => ToJSON (EndpointRole v) +instance ToJSON v => ToJSON (EndpointRole v) isSubroleOf (Target a) (Target b) = a == b isSubroleOf (Source as) (Source bs) = as `S.isSubsetOf` bs diff --git a/src/NITTA/Model/Problems/Refactor/ConstantFolding.hs b/src/NITTA/Model/Problems/Refactor/ConstantFolding.hs index cde8bc04a..b673e0c5e 100644 --- a/src/NITTA/Model/Problems/Refactor/ConstantFolding.hs +++ b/src/NITTA/Model/Problems/Refactor/ConstantFolding.hs @@ -109,7 +109,10 @@ selectClusters fs = evalCluster [f] = [f] evalCluster fs = outputResult where - (consts, [f]) = L.partition isConst fs + (consts, fSingleton) = L.partition isConst fs + f = case fSingleton of + [f'] -> f' + _ -> error "evalCluster: internal error" cntx = CycleCntx $ HM.fromList $ concatMap (simulate def) consts outputResult | null $ outputs f = fs diff --git a/src/NITTA/Model/Problems/ViewHelper.hs b/src/NITTA/Model/Problems/ViewHelper.hs index a7429cc9b..f63b25985 100644 --- a/src/NITTA/Model/Problems/ViewHelper.hs +++ b/src/NITTA/Model/Problems/ViewHelper.hs @@ -24,7 +24,7 @@ import Numeric.Interval.NonEmpty newtype IntervalView = IntervalView T.Text deriving (Generic) -instance (Time t) => Viewable (Interval t) IntervalView where +instance Time t => Viewable (Interval t) IntervalView where view = IntervalView . 
T.replace (showText (maxBound :: t)) "INF" . showText instance ToJSON IntervalView @@ -62,14 +62,14 @@ data DecisionView } deriving (Generic) -instance (UnitTag tag) => Viewable (Bind tag v x) DecisionView where +instance UnitTag tag => Viewable (Bind tag v x) DecisionView where view (Bind f pu) = BindDecisionView { function = view f , pu = toText pu } -instance (UnitTag tag) => Viewable (Allocation tag) DecisionView where +instance UnitTag tag => Viewable (Allocation tag) DecisionView where view Allocation{networkTag, processUnitTag} = AllocationView { networkTag = toText networkTag @@ -114,7 +114,7 @@ instance Viewable (OptimizeAccum v x) DecisionView where , new = map view refNew } -instance (Var v) => Viewable (ResolveDeadlock v x) DecisionView where +instance Var v => Viewable (ResolveDeadlock v x) DecisionView where view ResolveDeadlock{newBuffer, changeset} = ResolveDeadlockView { newBuffer = showText newBuffer diff --git a/src/NITTA/Model/ProcessIntegrity.hs b/src/NITTA/Model/ProcessIntegrity.hs index 3d2aee86f..3366f5791 100644 --- a/src/NITTA/Model/ProcessIntegrity.hs +++ b/src/NITTA/Model/ProcessIntegrity.hs @@ -29,7 +29,7 @@ class ProcessIntegrity u where isProcessIntegrity u = isRight $ checkProcessIntegrity u -instance (ProcessorUnit (pu v x t) v x t) => ProcessIntegrity (pu v x t) where +instance ProcessorUnit (pu v x t) v x t => ProcessIntegrity (pu v x t) where checkProcessIntegrity pu = collectChecks [ checkVerticalRelations (up2down pu) (pid2intermediate pu) (pid2endpoint pu) "intermediate not related to endpoint" diff --git a/src/NITTA/Model/ProcessorUnits/Accum.hs b/src/NITTA/Model/ProcessorUnits/Accum.hs index 3f93ac6f1..88fd87845 100644 --- a/src/NITTA/Model/ProcessorUnits/Accum.hs +++ b/src/NITTA/Model/ProcessorUnits/Accum.hs @@ -66,7 +66,7 @@ data JobState taskVars lst = S.fromList $ map snd lst -instance (Var v) => Show (Job v x) where +instance Var v => Show (Job v x) where show Job{tasks, func, state} = [i|Job{tasks=#{ show' tasks }, func=#{ func }, state=#{ state }}|] where @@ -81,7 +81,7 @@ data Accum v x t = Accum -- ^ Process } -instance (VarValTime v x t) => Pretty (Accum v x t) where +instance VarValTime v x t => Pretty (Accum v x t) where pretty a = [__i| Accum: @@ -90,10 +90,10 @@ instance (VarValTime v x t) => Pretty (Accum v x t) where #{ indent 4 $ pretty $ process_ a } |] -instance (VarValTime v x t) => Show (Accum v x t) where +instance VarValTime v x t => Show (Accum v x t) where show = show . 
pretty -instance (VarValTime v x t) => Default (Accum v x t) where +instance VarValTime v x t => Default (Accum v x t) where def = Accum { remainJobs = [] @@ -118,8 +118,18 @@ actionGroups [] = [] actionGroups as = let (pushs, as') = span F.isPush as (pulls, as'') = span F.isPull as' - in [ map (\(F.Push sign (I v)) -> (sign == F.Minus, v)) pushs - , concatMap (\(F.Pull (O vs)) -> map (True,) $ S.elems vs) pulls + in [ map + ( \case + (F.Push sign (I v)) -> (sign == F.Minus, v) + _ -> error "actionGroups: internal error" + ) + pushs + , concatMap + ( \case + (F.Pull (O vs)) -> map (True,) $ S.elems vs + _ -> error "actionGroups: internal error" + ) + pulls ] : actionGroups as'' @@ -131,7 +141,7 @@ sourceTask tasks | odd $ length tasks = Just $ head tasks | otherwise = Nothing -instance (VarValTime v x t, Num x) => ProcessorUnit (Accum v x t) v x t where +instance VarValTime v x t => ProcessorUnit (Accum v x t) v x t where tryBind f pu | Just (F.Add a b c) <- castF f = Right $ registerAcc (F.Acc [F.Push F.Plus a, F.Push F.Plus b, F.Pull c]) pu @@ -145,7 +155,7 @@ instance (VarValTime v x t, Num x) => ProcessorUnit (Accum v x t) v x t where process = process_ -instance (VarValTime v x t, Num x) => EndpointProblem (Accum v x t) v t where +instance VarValTime v x t => EndpointProblem (Accum v x t) v t where endpointOptions pu@Accum{currentJob = Just Job{tasks, state}} | Just task <- targetTask tasks = let from = case state of @@ -178,7 +188,9 @@ instance (VarValTime v x t, Num x) => EndpointProblem (Accum v x t) v t where pu@Accum{currentJob = Just job@Job{tasks, state}} d@EndpointSt{epRole = Target v, epAt} | Just task <- targetTask tasks = - let ([(neg, _v)], task') = L.partition ((== v) . snd) task + let ((neg, _v), task') = case L.partition ((== v) . snd) task of + ([negAndVar], ts) -> (negAndVar, ts) + _ -> error "Accum: endpointDecision: internal error" instr = case state of Initialize -> ResetAndLoad neg _ -> Load neg @@ -257,7 +269,7 @@ instance UnambiguouslyDecode (Accum v x t) where decodeInstruction (Load neg) = def{resetAccSignal = False, loadSignal = True, negSignal = Just neg} decodeInstruction Out = def{oeSignal = True} -instance (Var v) => Locks (Accum v x t) v where +instance Var v => Locks (Accum v x t) v where locks Accum{currentJob = Nothing, remainJobs} = concatMap (locks . 
func) remainJobs locks Accum{currentJob = Just Job{tasks = []}} = error "Accum locks: internal error" locks Accum{currentJob = Just Job{tasks = t : ts}, remainJobs} = @@ -273,7 +285,7 @@ instance (Var v) => Locks (Accum v x t) v where ] in current ++ remain -instance (VarValTime v x t) => TargetSystemComponent (Accum v x t) where +instance VarValTime v x t => TargetSystemComponent (Accum v x t) where moduleName _ _ = "pu_accum" hardware _tag _pu = FromLibrary "pu_accum.v" software _ _ = Empty @@ -306,11 +318,11 @@ instance (VarValTime v x t) => TargetSystemComponent (Accum v x t) where |] hardwareInstance _title _pu _env = error "internal error" -instance (Ord t) => WithFunctions (Accum v x t) (F v x) where +instance Ord t => WithFunctions (Accum v x t) (F v x) where functions Accum{process_, remainJobs} = functions process_ ++ map func remainJobs -instance (VarValTime v x t) => Testable (Accum v x t) v x where +instance VarValTime v x t => Testable (Accum v x t) v x where testBenchImplementation prj@Project{pName, pUnit} = let tbcSignalsConst = ["resetAcc", "load", "oe", "neg"] diff --git a/src/NITTA/Model/ProcessorUnits/Broken.hs b/src/NITTA/Model/ProcessorUnits/Broken.hs index f5f75352a..1bafa8676 100644 --- a/src/NITTA/Model/ProcessorUnits/Broken.hs +++ b/src/NITTA/Model/ProcessorUnits/Broken.hs @@ -65,7 +65,7 @@ data Broken v x t = Broken , unknownDataOut :: Bool } -instance (VarValTime v x t) => Pretty (Broken v x t) where +instance VarValTime v x t => Pretty (Broken v x t) where pretty Broken{..} = [__i| Broken: @@ -85,7 +85,7 @@ instance (VarValTime v x t) => Pretty (Broken v x t) where #{ indent 4 $ pretty $ process_ } |] -instance (Var v) => Locks (Broken v x t) v where +instance Var v => Locks (Broken v x t) v where locks Broken{remain, sources, targets} = [ Lock{lockBy, locked} | locked <- sources @@ -101,7 +101,7 @@ instance ConstantFoldingProblem (Broken v x t) v x instance OptimizeAccumProblem (Broken v x t) v x instance ResolveDeadlockProblem (Broken v x t) v x -instance (VarValTime v x t) => ProcessorUnit (Broken v x t) v x t where +instance VarValTime v x t => ProcessorUnit (Broken v x t) v x t where tryBind f pu@Broken{remain} | Just F.BrokenBuffer{} <- castF f = Right pu{remain = f : remain} | otherwise = Left $ "The function is unsupported by Broken: " ++ show f @@ -117,7 +117,7 @@ execution pu@Broken{targets = [], sources = [], remain, process_} f } execution _ _ = error "Broken: internal execution error." -instance (VarValTime v x t) => EndpointProblem (Broken v x t) v t where +instance VarValTime v x t => EndpointProblem (Broken v x t) v t where endpointOptions Broken{targets = [_], lostEndpointTarget = True} = [] endpointOptions pu@Broken{targets = [v]} = let start = nextTick pu `withShift` 1 ... 
maxBound @@ -244,7 +244,7 @@ instance Default (Microcode (Broken v x t)) where , oeSignal = False } -instance (Time t) => Default (Broken v x t) where +instance Time t => Default (Broken v x t) where def = Broken { remain = [] @@ -284,7 +284,7 @@ instance IOConnected (Broken v x t) where data IOPorts (Broken v x t) = BrokenIO deriving (Show) -instance (VarValTime v x t) => TargetSystemComponent (Broken v x t) where +instance VarValTime v x t => TargetSystemComponent (Broken v x t) where moduleName _title _pu = "pu_broken" software _ _ = Empty hardware _tag _pu = Aggregate Nothing [FromLibrary "pu_broken.v"] @@ -323,7 +323,7 @@ instance (VarValTime v x t) => TargetSystemComponent (Broken v x t) where instance IOTestBench (Broken v x t) v x -instance (Ord t) => WithFunctions (Broken v x t) (F v x) where +instance Ord t => WithFunctions (Broken v x t) (F v x) where functions Broken{process_, remain, currentWork} = functions process_ ++ remain @@ -331,7 +331,7 @@ instance (Ord t) => WithFunctions (Broken v x t) (F v x) where Just (_, f) -> [f] Nothing -> [] -instance (VarValTime v x t) => Testable (Broken v x t) v x where +instance VarValTime v x t => Testable (Broken v x t) v x where testBenchImplementation prj@Project{pName, pUnit} = Immediate (toString $ moduleName pName pUnit <> "_tb.v") $ snippetTestBench diff --git a/src/NITTA/Model/ProcessorUnits/Divider.hs b/src/NITTA/Model/ProcessorUnits/Divider.hs index e4ad1bcea..68369f7ef 100644 --- a/src/NITTA/Model/ProcessorUnits/Divider.hs +++ b/src/NITTA/Model/ProcessorUnits/Divider.hs @@ -67,12 +67,12 @@ divider pipeline mock = , mock } -instance (Time t) => Default (Divider v x t) where +instance Time t => Default (Divider v x t) where def = divider 4 True instance Default x => DefaultX (Divider v x t) x -instance (Ord t) => WithFunctions (Divider v x t) (F v x) where +instance Ord t => WithFunctions (Divider v x t) (F v x) where functions Divider{process_, remains, jobs} = functions process_ ++ remains @@ -91,7 +91,7 @@ data Job v x t } deriving (Eq, Show) -instance (Ord v) => Variables (Job v x t) v where +instance Ord v => Variables (Job v x t) v where variables WaitArguments{arguments} = S.fromList $ map snd arguments variables WaitResults{results} = S.unions $ map snd results @@ -101,7 +101,7 @@ isWaitArguments _ = False isWaitResults WaitResults{} = True isWaitResults _ = False -instance (VarValTime v x t) => ProcessorUnit (Divider v x t) v x t where +instance VarValTime v x t => ProcessorUnit (Divider v x t) v x t where tryBind f pu@Divider{remains} | Just (F.Division (I _n) (I _d) (O _q) (O _r)) <- castF f = Right pu{remains = f : remains} @@ -163,7 +163,7 @@ firstWaitResults jobs = then Nothing else Just $ minimumOn readyAt jobs' -instance (VarValTime v x t) => EndpointProblem (Divider v x t) v t where +instance VarValTime v x t => EndpointProblem (Divider v x t) v t where endpointOptions pu@Divider{remains, jobs} = let executeNewFunction | any isWaitArguments jobs = [] @@ -189,7 +189,9 @@ instance (VarValTime v x t) => EndpointProblem (Divider v x t) v t where } in endpointDecision pu' d | ([WaitArguments{function, arguments}], jobs') <- partition (S.member v . variables) jobs = - let ([(tag, _v)], arguments') = partition ((== v) . snd) arguments + let (tag, arguments') = case partition ((== v) . 
snd) arguments of + ([(tag', _v)], other) -> (tag', other) + _ -> error "Divider: endpointDecision: internal error" nextTick' = sup epAt + 1 in case arguments' of [] -> @@ -214,7 +216,9 @@ instance (VarValTime v x t) => EndpointProblem (Divider v x t) v t where } endpointDecision pu@Divider{jobs} d@EndpointSt{epRole = Source vs, epAt} | ([job@WaitResults{results, function}], jobs') <- partition ((vs `S.isSubsetOf`) . variables) jobs = - let ([(tag, allVs)], results') = partition ((vs `S.isSubsetOf`) . snd) results + let ((tag, allVs), results') = case partition ((vs `S.isSubsetOf`) . snd) results of + ([(tag_, allVs_)], other) -> ((tag_, allVs_), other) + _ -> error "Divider: endpointDecision: internal error" allVs' = allVs S.\\ vs results'' = filterEmptyResults $ (tag, allVs') : results' jobs'' = @@ -322,7 +326,7 @@ instance (Val x, Show t) => TargetSystemComponent (Divider v x t) where instance IOTestBench (Divider v x t) v x -instance (VarValTime v x t) => Testable (Divider v x t) v x where +instance VarValTime v x t => Testable (Divider v x t) v x where testBenchImplementation prj@Project{pName, pUnit} = Immediate (toString $ moduleName pName pUnit <> "_tb.v") $ snippetTestBench diff --git a/src/NITTA/Model/ProcessorUnits/Fram.hs b/src/NITTA/Model/ProcessorUnits/Fram.hs index 5f37d7818..38a979af0 100644 --- a/src/NITTA/Model/ProcessorUnits/Fram.hs +++ b/src/NITTA/Model/ProcessorUnits/Fram.hs @@ -57,7 +57,7 @@ framWithSize size = , process_ = def } -instance (VarValTime v x t) => Pretty (Fram v x t) where +instance VarValTime v x t => Pretty (Fram v x t) where pretty Fram{memory} = [__i| Fram: @@ -77,11 +77,11 @@ instance (Default t, Default x) => Default (Fram v x t) where instance Default x => DefaultX (Fram v x t) x -instance (VarValTime v x t) => WithFunctions (Fram v x t) (F v x) where +instance VarValTime v x t => WithFunctions (Fram v x t) (F v x) where functions Fram{remainBuffers, memory} = map (packF . fst) remainBuffers ++ concatMap functions (A.elems memory) -instance (VarValTime v x t) => Variables (Fram v x t) v where +instance VarValTime v x t => Variables (Fram v x t) v where variables fram = S.unions $ map variables $ functions fram -- | Memory cell @@ -112,7 +112,7 @@ instance WithFunctions (Cell v x t) (F v x) where functions Cell{history, job = Just Job{function}} = function : history functions Cell{history} = history -instance (Default x) => Default (Cell v x t) where +instance Default x => Default (Cell v x t) where def = Cell { state = NotUsed @@ -158,7 +158,7 @@ data CellState v x t | DoLoopTarget v deriving (Eq) -instance (VarValTime v x t) => Pretty (CellState v x t) where +instance VarValTime v x t => Pretty (CellState v x t) where pretty NotUsed = "NotUsed" pretty Done = "Done" pretty (DoConstant vs) = "DoConstant " <> viaShow (map toString vs) @@ -195,7 +195,7 @@ addrWidth Fram{memory} = log2 $ numElements memory where log2 = ceiling . (logBase 2 :: Double -> Double) . 
fromIntegral -instance (VarValTime v x t) => ProcessorUnit (Fram v x t) v x t where +instance VarValTime v x t => ProcessorUnit (Fram v x t) v x t where tryBind f fram | not $ null (variables f `S.intersection` variables fram) = Left "can not bind (self transaction)" @@ -246,24 +246,26 @@ instance (VarValTime v x t) => ProcessorUnit (Fram v x t) v x t where process Fram{process_} = process_ parallelismType _ = Full -instance (Var v) => Locks (Fram v x t) v where +instance Var v => Locks (Fram v x t) v where -- FIXME: locks _ = [] -instance (VarValTime v x t) => BreakLoopProblem (Fram v x t) v x where +instance VarValTime v x t => BreakLoopProblem (Fram v x t) v x where breakLoopOptions Fram{memory} = [ BreakLoop x o i_ | (_, Cell{state = NotBrokenLoop, job = Just Job{function}}) <- A.assocs memory - , let Just (Loop (X x) (O o) (I i_)) = castF function + , let (Loop (X x) (O o) (I i_)) = fromJust $ castF function ] breakLoopDecision fram@Fram{memory} bl@BreakLoop{loopO} = - let Just (addr, cell@Cell{history, job = Just Job{binds}}) = - L.find - ( \case - (_, Cell{job = Just Job{function}}) -> function == recLoop bl - _ -> False - ) + let (addr, cell@Cell{history, job}) = + fromJust + $ L.find + ( \case + (_, Cell{job = Just Job{function}}) -> function == recLoop bl + _ -> False + ) $ A.assocs memory + Job{binds} = fromJust job ((iPid, oPid), process_) = runSchedule fram $ do revoke <- scheduleFunctionRevoke $ recLoop bl f1 <- scheduleFunctionBind $ recLoopOut bl @@ -289,7 +291,7 @@ instance ConstantFoldingProblem (Fram v x t) v x instance OptimizeAccumProblem (Fram v x t) v x instance ResolveDeadlockProblem (Fram v x t) v x -instance (VarValTime v x t) => EndpointProblem (Fram v x t) v t where +instance VarValTime v x t => EndpointProblem (Fram v x t) v t where endpointOptions pu@Fram{remainBuffers, memory} = let target v = EndpointSt (Target v) $ TimeConstraint (a ... maxBound) (1 ... maxBound) where @@ -490,7 +492,7 @@ instance UnambiguouslyDecode (Fram v x t) where decodeInstruction (PrepareRead addr) = Microcode True False $ Just addr decodeInstruction (Write addr) = Microcode False True $ Just addr -instance (VarValTime v x t) => Testable (Fram v x t) v x where +instance VarValTime v x t => Testable (Fram v x t) v x where testBenchImplementation prj@Project{pName, pUnit} = let tbcSignalsConst = ["oe", "wr", "[3:0] addr"] showMicrocode Microcode{oeSignal, wrSignal, addrSignal} = @@ -513,7 +515,7 @@ instance (VarValTime v x t) => Testable (Fram v x t) v x where softwareFile tag pu = moduleName tag pu <> "." 
<> tag <> ".dump" -instance (VarValTime v x t) => TargetSystemComponent (Fram v x t) where +instance VarValTime v x t => TargetSystemComponent (Fram v x t) where moduleName _ _ = "pu_fram" hardware _tag _pu = FromLibrary "pu_fram.v" software tag fram@Fram{memory} = diff --git a/src/NITTA/Model/ProcessorUnits/IO/I2C.hs b/src/NITTA/Model/ProcessorUnits/IO/I2C.hs index e04cf82d5..44ce4c47b 100644 --- a/src/NITTA/Model/ProcessorUnits/IO/I2C.hs +++ b/src/NITTA/Model/ProcessorUnits/IO/I2C.hs @@ -37,7 +37,7 @@ instance SimpleIOInterface I2Cinterface type I2C v x t = SimpleIO I2Cinterface v x t -i2cUnit :: (Time t) => Int -> I2C v x t +i2cUnit :: Time t => Int -> I2C v x t i2cUnit bounceFilter = SimpleIO { bounceFilter diff --git a/src/NITTA/Model/ProcessorUnits/IO/SPI.hs b/src/NITTA/Model/ProcessorUnits/IO/SPI.hs index a4664adda..7d8c0fabc 100644 --- a/src/NITTA/Model/ProcessorUnits/IO/SPI.hs +++ b/src/NITTA/Model/ProcessorUnits/IO/SPI.hs @@ -44,7 +44,7 @@ instance SimpleIOInterface SPIinterface type SPI v x t = SimpleIO SPIinterface v x t -anySPI :: (Time t) => Int -> Maybe Int -> SPI v x t +anySPI :: Time t => Int -> Maybe Int -> SPI v x t anySPI bounceFilter bufferSize = SimpleIO { bounceFilter @@ -95,7 +95,7 @@ spiSlavePorts tag = , slave_cs = InputPortTag $ tag <> "_cs" } -instance (Time t) => Default (SPI v x t) where +instance Time t => Default (SPI v x t) where def = anySPI 0 $ Just 6 instance (ToJSON v, VarValTime v x t) => TargetSystemComponent (SPI v x t) where @@ -170,7 +170,7 @@ instance (ToJSON v, VarValTime v x t) => TargetSystemComponent (SPI v x t) where |] hardwareInstance _title _pu _env = error "internal error" -instance (VarValTime v x t, Num x) => IOTestBench (SPI v x t) v x where +instance VarValTime v x t => IOTestBench (SPI v x t) v x where testEnvironmentInitFlag tag _pu = Just $ tag <> "_env_init_flag" testEnvironment @@ -224,7 +224,9 @@ instance (VarValTime v x t, Num x) => IOTestBench (SPI v x t) v x where end |] - Just envInitFlagName = testEnvironmentInitFlag tag sio + envInitFlagName = + fromMaybe (error "SPI: testEnvironment: internal error") $ + testEnvironmentInitFlag tag sio in case ioPorts of SPISlave{..} -> let receiveCycle transmit = diff --git a/src/NITTA/Model/ProcessorUnits/IO/SimpleIO.hs b/src/NITTA/Model/ProcessorUnits/IO/SimpleIO.hs index 623a1f74c..4b24ecfd9 100644 --- a/src/NITTA/Model/ProcessorUnits/IO/SimpleIO.hs +++ b/src/NITTA/Model/ProcessorUnits/IO/SimpleIO.hs @@ -45,7 +45,7 @@ import Numeric.Interval.NonEmpty ((...)) import Numeric.Interval.NonEmpty qualified as I import Prettyprinter -class (Typeable i) => SimpleIOInterface i +class Typeable i => SimpleIOInterface i data SimpleIO i v x t = SimpleIO { bounceFilter :: Int @@ -223,7 +223,7 @@ instance Connected (SimpleIO i v x t) where } deriving (Show) -instance (Var v) => Locks (SimpleIO i v x t) v where +instance Var v => Locks (SimpleIO i v x t) v where locks SimpleIO{} = [] data ProtocolDescription v = ProtocolDescription @@ -235,7 +235,7 @@ data ProtocolDescription v = ProtocolDescription } deriving (Generic) -instance (ToJSON v) => ToJSON (ProtocolDescription v) +instance ToJSON v => ToJSON (ProtocolDescription v) protocolDescription :: forall i v x t. diff --git a/src/NITTA/Model/ProcessorUnits/Multiplier.hs b/src/NITTA/Model/ProcessorUnits/Multiplier.hs index a6b2e776b..ca469398c 100644 --- a/src/NITTA/Model/ProcessorUnits/Multiplier.hs +++ b/src/NITTA/Model/ProcessorUnits/Multiplier.hs @@ -383,7 +383,7 @@ data Multiplier v x t = Multiplier -- IP-core. 
} -instance (VarValTime v x t) => Pretty (Multiplier v x t) where +instance VarValTime v x t => Pretty (Multiplier v x t) where pretty Multiplier{remain, targets, sources, currentWork, process_, isMocked} = [__i| Multiplier: @@ -410,7 +410,7 @@ multiplier mock = } -- | Default initial state of multiplier PU model. -instance (Time t) => Default (Multiplier v x t) where +instance Time t => Default (Multiplier v x t) where def = multiplier True instance Default x => DefaultX (Multiplier v x t) x @@ -419,7 +419,7 @@ instance Default x => DefaultX (Multiplier v x t) x implementation: we take process description (all planned functions), and function in progress, if it is. -} -instance (Ord t) => WithFunctions (Multiplier v x t) (F v x) where +instance Ord t => WithFunctions (Multiplier v x t) (F v x) where functions Multiplier{process_, remain, currentWork} = functions process_ ++ remain @@ -432,7 +432,7 @@ instance (Ord t) => WithFunctions (Multiplier v x t) (F v x) where - dependencies of all remain functions from the currently evaluated function (if it is). -} -instance (Var v) => Locks (Multiplier v x t) v where +instance Var v => Locks (Multiplier v x t) v where locks Multiplier{remain, sources, targets} = [ Lock{lockBy, locked} | locked <- sources @@ -470,7 +470,7 @@ From the CAD point of view, bind looks like: Binding can be done either gradually due synthesis process at the start. -} -instance (VarValTime v x t) => ProcessorUnit (Multiplier v x t) v x t where +instance VarValTime v x t => ProcessorUnit (Multiplier v x t) v x t where tryBind f pu@Multiplier{remain} | Just F.Multiply{} <- castF f = Right pu{remain = f : remain} | otherwise = Left $ "The function is unsupported by Multiplier: " ++ show f @@ -528,7 +528,7 @@ It includes three cases: find the selected function, 'execute' it, and do a recursive call with the same decision. -} -instance (VarValTime v x t) => EndpointProblem (Multiplier v x t) v t where +instance VarValTime v x t => EndpointProblem (Multiplier v x t) v t where endpointOptions pu@Multiplier{targets} | not $ null targets = let at = nextTick pu ... maxBound @@ -667,7 +667,7 @@ instance IOConnected (Multiplier v x t) where - Hardware instance in the upper structure element. -} -instance (VarValTime v x t) => TargetSystemComponent (Multiplier v x t) where +instance VarValTime v x t => TargetSystemComponent (Multiplier v x t) where moduleName _title _pu = "pu_multiplier" hardware _tag Multiplier{isMocked} = @@ -724,7 +724,7 @@ process. You can see tests in @test/Spec.hs@. Testbench contains: - The sequence of bus state checks in which we compare actual values with the results of the functional simulation. 
-} -instance (VarValTime v x t) => Testable (Multiplier v x t) v x where +instance VarValTime v x t => Testable (Multiplier v x t) v x where testBenchImplementation prj@Project{pName, pUnit} = Immediate (toString $ moduleName pName pUnit <> "_tb.v") $ snippetTestBench diff --git a/src/NITTA/Model/ProcessorUnits/Shift.hs b/src/NITTA/Model/ProcessorUnits/Shift.hs index 77510c809..b72ffbf89 100644 --- a/src/NITTA/Model/ProcessorUnits/Shift.hs +++ b/src/NITTA/Model/ProcessorUnits/Shift.hs @@ -54,7 +54,7 @@ data Shift v x t = Shift -- ^ description of target computation process } -instance (Var v) => Locks (Shift v x t) v where +instance Var v => Locks (Shift v x t) v where locks Shift{sources, target = Just t} = [ Lock{lockBy = t, locked} | locked <- sources @@ -81,7 +81,7 @@ instance ConstantFoldingProblem (Shift v x t) v x instance OptimizeAccumProblem (Shift v x t) v x instance ResolveDeadlockProblem (Shift v x t) v x -instance (VarValTime v x t) => ProcessorUnit (Shift v x t) v x t where +instance VarValTime v x t => ProcessorUnit (Shift v x t) v x t where tryBind f pu@Shift{remain} | Just f' <- castF f = case f' of @@ -109,7 +109,7 @@ execution pu@Shift{target = Nothing, sources = [], remain} f } execution _ _ = error "Not right arguments in execution function in shift module" -instance (VarValTime v x t) => EndpointProblem (Shift v x t) v t where +instance VarValTime v x t => EndpointProblem (Shift v x t) v t where endpointOptions pu@Shift{target = Just t} = [EndpointSt (Target t) $ TimeConstraint (nextTick pu ... maxBound) (singleton 1)] endpointOptions pu@Shift{sources, byteShiftDiv, byteShiftMod} @@ -264,7 +264,7 @@ instance Connected (Shift v x t) where instance IOConnected (Shift v x t) where data IOPorts (Shift v x t) = ShiftIO -instance (Val x) => TargetSystemComponent (Shift v x t) where +instance Val x => TargetSystemComponent (Shift v x t) where moduleName _ _ = "pu_shift" hardware _tag _pu = FromLibrary "pu_shift.v" software _ _ = Empty diff --git a/src/NITTA/Model/ProcessorUnits/Types.hs b/src/NITTA/Model/ProcessorUnits/Types.hs index da99c7285..f9fbe0edc 100644 --- a/src/NITTA/Model/ProcessorUnits/Types.hs +++ b/src/NITTA/Model/ProcessorUnits/Types.hs @@ -112,7 +112,7 @@ intermediate representation: 3. other features implemented by different type classes (see above and in "NITTA.Model.Problems"). -} -class (VarValTime v x t) => ProcessorUnit u v x t | u -> v x t where +class VarValTime v x t => ProcessorUnit u v x t | u -> v x t where -- If the processor unit can execute a function, then it will return the PU -- model with already bound function (only registeration, actual scheduling -- will be happening later). If not, it will return @Left@ value with a @@ -143,7 +143,7 @@ allowToProcess f pu = isRight $ tryBind f pu class NextTick u t | u -> t where nextTick :: u -> t -instance (ProcessorUnit u v x t) => NextTick u t where +instance ProcessorUnit u v x t => NextTick u t where nextTick = nextTick . 
process --------------------------------------------------------------------- @@ -182,13 +182,13 @@ instance (Time t, Show i) => Pretty (Process t i) where instance (ToJSON t, ToJSON i) => ToJSON (Process t i) -instance (Default t) => Default (Process t i) where +instance Default t => Default (Process t i) where def = Process{steps = [], relations = [], nextTick_ = def, nextUid = def} instance {-# OVERLAPS #-} NextTick (Process t si) t where nextTick = nextTick_ -instance (Ord t) => WithFunctions (Process t (StepInfo v x t)) (F v x) where +instance Ord t => WithFunctions (Process t (StepInfo v x t)) (F v x) where functions Process{steps} = mapMaybe get $ L.sortOn (I.inf . pInterval) steps where get Step{pDesc} | IntermediateStep f <- descent pDesc = Just f @@ -210,7 +210,7 @@ data Step t i = Step instance (ToJSON t, ToJSON i) => ToJSON (Step t i) -instance (Ord v) => Patch (Step t (StepInfo v x t)) (Changeset v) where +instance Ord v => Patch (Step t (StepInfo v x t)) (Changeset v) where patch diff step@Step{pDesc} = step{pDesc = patch diff pDesc} -- | Informative process step description at a specific process level. @@ -229,7 +229,7 @@ data StepInfo v x t where Instruction pu -> StepInfo v x t -- | wrapper for nested process unit step (used for networks) - NestedStep :: (UnitTag tag) => {nTitle :: tag, nStep :: Step t (StepInfo v x t)} -> StepInfo v x t + NestedStep :: UnitTag tag => {nTitle :: tag, nStep :: Step t (StepInfo v x t)} -> StepInfo v x t -- | Process unit allocation step AllocationStep :: (Typeable a, Show a, Eq a) => a -> StepInfo v x t @@ -251,7 +251,7 @@ instance (Var v, Show (Step t (StepInfo v x t))) => Show (StepInfo v x t) where show (InstructionStep instr) = "Instruction: " <> show instr show NestedStep{nTitle, nStep = Step{pDesc}} = "@" <> toString nTitle <> " " <> show pDesc -instance (Ord v) => Patch (StepInfo v x t) (Changeset v) where +instance Ord v => Patch (StepInfo v x t) (Changeset v) where patch diff (IntermediateStep f) = IntermediateStep $ patch diff f patch diff (EndpointRoleStep ep) = EndpointRoleStep $ patch diff ep patch diff (NestedStep tag nStep) = NestedStep tag $ patch diff nStep @@ -283,7 +283,7 @@ whatsHappen t Process{steps} = filter (atSameTime t . 
pInterval) steps extractInstructionAt pu t = mapMaybe (inst pu) $ whatsHappen t $ process pu where - inst :: (Typeable (Instruction pu)) => pu -> Step t (StepInfo v x t) -> Maybe (Instruction pu) + inst :: Typeable (Instruction pu) => pu -> Step t (StepInfo v x t) -> Maybe (Instruction pu) inst _ Step{pDesc = InstructionStep instr} = cast instr inst _ _ = Nothing @@ -341,7 +341,6 @@ instance , Default (Microcode pu) , ProcessorUnit pu v x t , UnambiguouslyDecode pu - , Time t , Typeable pu ) => ByTime pu t diff --git a/src/NITTA/Model/TargetSystem.hs b/src/NITTA/Model/TargetSystem.hs index 58e9b10a8..a18339af1 100644 --- a/src/NITTA/Model/TargetSystem.hs +++ b/src/NITTA/Model/TargetSystem.hs @@ -35,37 +35,32 @@ data TargetSystem u tag v x t = TargetSystem } deriving (Generic) -instance (Default u) => Default (TargetSystem u tag v x t) where +instance Default u => Default (TargetSystem u tag v x t) where def = TargetSystem def def -instance (WithFunctions u (F v x)) => WithFunctions (TargetSystem u tag v x t) (F v x) where +instance WithFunctions u (F v x) => WithFunctions (TargetSystem u tag v x t) (F v x) where functions TargetSystem{mUnit, mDataFlowGraph} = assert (S.fromList (functions mUnit) == S.fromList (functions mDataFlowGraph)) $ -- inconsistent TargetSystem functions mUnit processDuration TargetSystem{mUnit} = nextTick mUnit - 1 -isSynthesisComplete :: (ProcessorUnit u v x t) => TargetSystem u tag v x t -> Bool +isSynthesisComplete :: ProcessorUnit u v x t => TargetSystem u tag v x t -> Bool isSynthesisComplete TargetSystem{mUnit, mDataFlowGraph} = transferred mUnit == variables mDataFlowGraph -instance - ( VarValTime v x t - , ProcessorUnit u v x t - ) => - ProcessorUnit (TargetSystem u tag v x t) v x t - where +instance ProcessorUnit u v x t => ProcessorUnit (TargetSystem u tag v x t) v x t where tryBind f ts@TargetSystem{mUnit} = (\u -> ts{mUnit = u}) <$> tryBind f mUnit process TargetSystem{mUnit} = process mUnit parallelismType TargetSystem{mUnit} = parallelismType mUnit puSize TargetSystem{mUnit} = puSize mUnit -instance (BindProblem u tag v x) => BindProblem (TargetSystem u tag v x t) tag v x where +instance BindProblem u tag v x => BindProblem (TargetSystem u tag v x t) tag v x where bindOptions TargetSystem{mUnit} = bindOptions mUnit bindDecision ts@TargetSystem{mUnit} d = ts{mUnit = bindDecision mUnit d} -instance (DataflowProblem u tag v t) => DataflowProblem (TargetSystem u tag v x t) tag v t where +instance DataflowProblem u tag v t => DataflowProblem (TargetSystem u tag v x t) tag v t where dataflowOptions TargetSystem{mUnit} = dataflowOptions mUnit dataflowDecision f@TargetSystem{mUnit} d = f{mUnit = dataflowDecision mUnit d} @@ -112,7 +107,7 @@ instance (Var v, ResolveDeadlockProblem u v x) => ResolveDeadlockProblem (Target , mUnit = resolveDeadlockDecision mUnit d } -instance (AllocationProblem u tag) => AllocationProblem (TargetSystem u tag v x t) tag where +instance AllocationProblem u tag => AllocationProblem (TargetSystem u tag v x t) tag where allocationOptions TargetSystem{mUnit} = allocationOptions mUnit allocationDecision f@TargetSystem{mUnit} d = f{mUnit = allocationDecision mUnit d} diff --git a/src/NITTA/Model/Time.hs b/src/NITTA/Model/Time.hs index f8630b855..52fd8574d 100644 --- a/src/NITTA/Model/Time.hs +++ b/src/NITTA/Model/Time.hs @@ -31,7 +31,7 @@ type VarValTime v x t = (Var v, Val x, Time t) -- | Shortcut for time type constrain. 
type Time t = (Default t, Num t, Bounded t, Ord t, Show t, Typeable t, Enum t, Integral t) -instance (ToJSON t) => ToJSON (Interval t) +instance ToJSON t => ToJSON (Interval t) -- | Time constrain for processor activity. data TimeConstraint t = TimeConstraint @@ -52,7 +52,7 @@ instance (Show t, Eq t, Bounded t) => Show (TimeConstraint t) where then show a ++ "..INF" else show a ++ ".." ++ show b -instance (ToJSON tp) => ToJSON (TimeConstraint tp) +instance ToJSON tp => ToJSON (TimeConstraint tp) -- | Forgoten implementation of tagged time for speculative if statement. Current - dead code. data TaggedTime tag t = TaggedTime @@ -61,26 +61,26 @@ data TaggedTime tag t = TaggedTime } deriving (Typeable, Generic) -instance (Default t) => Default (TaggedTime tag t) where +instance Default t => Default (TaggedTime tag t) where def = TaggedTime Nothing def instance (Time t, Show tag) => Show (TaggedTime tag t) where show (TaggedTime tag t) = show t ++ maybe "" (("!" ++) . show) tag -instance {-# OVERLAPS #-} (Time t) => Show (TaggedTime String t) where +instance {-# OVERLAPS #-} Time t => Show (TaggedTime String t) where show (TaggedTime tag t) = show t ++ maybe "" ("!" ++) tag -instance (Eq t) => Eq (TaggedTime tag t) where +instance Eq t => Eq (TaggedTime tag t) where (TaggedTime _ a) == (TaggedTime _ b) = a == b -instance (Ord t) => Ord (TaggedTime tag t) where +instance Ord t => Ord (TaggedTime tag t) where (TaggedTime _ a) `compare` (TaggedTime _ b) = a `compare` b -instance (Enum t) => Enum (TaggedTime tag t) where +instance Enum t => Enum (TaggedTime tag t) where toEnum i = TaggedTime Nothing $ toEnum i fromEnum (TaggedTime _ i) = fromEnum i -instance (Num t) => Bounded (TaggedTime tag t) where +instance Num t => Bounded (TaggedTime tag t) where minBound = TaggedTime Nothing 0 maxBound = TaggedTime Nothing 1000 diff --git a/src/NITTA/Project/Template.hs b/src/NITTA/Project/Template.hs index a86a7fb98..f7311341c 100644 --- a/src/NITTA/Project/Template.hs +++ b/src/NITTA/Project/Template.hs @@ -63,7 +63,7 @@ instance Default TemplateConf where , ignore = Just [templateConfFileName] } -instance (Eq k, Hashable k) => Default (M.HashMap k v) where +instance Hashable k => Default (M.HashMap k v) where def = M.fromList [] instance FromJSON TemplateConf @@ -94,15 +94,16 @@ readTemplateConfDef fn = do let conf = either (error . show) id $ parseTomlDoc (fn <> ": parse error: ") text return Conf - { template = confLookup "template" conf - , signals = confLookup "signals" conf + { template = confLookup fn "template" conf + , signals = confLookup fn "signals" conf } + +confLookup fn sec conf = + maybe + def + (unwrap (fn <> " in section [" <> T.unpack sec <> "]: ") . fromJSON . toJSON) + $ M.lookup sec conf where - confLookup sec conf = - maybe - def - (unwrap (fn <> " in section [" <> T.unpack sec <> "]: ") . fromJSON . 
toJSON) - $ M.lookup sec conf unwrap _prefix (Success a) = a unwrap prefix (Error msg) = error $ prefix <> msg diff --git a/src/NITTA/Project/TestBench.hs b/src/NITTA/Project/TestBench.hs index f12af7125..f62d1f23c 100644 --- a/src/NITTA/Project/TestBench.hs +++ b/src/NITTA/Project/TestBench.hs @@ -1,6 +1,7 @@ {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TypeFamilies #-} {- | Module : NITTA.Project.TestBench @@ -33,7 +34,7 @@ import GHC.Generics (Generic) import NITTA.Intermediate.Types import NITTA.Model.Problems import NITTA.Model.ProcessorUnits.Types -import NITTA.Model.Time + import NITTA.Project.Types import NITTA.Project.VerilogSnippets import NITTA.Utils @@ -131,8 +132,7 @@ data SnippetTestBenchConf m = SnippetTestBenchConf -- | Function for testBench PU test snippetTestBench :: forall m v x t. - ( VarValTime v x t - , WithFunctions m (F v x) + ( WithFunctions m (F v x) , ProcessorUnit m v x t , TargetSystemComponent m , UnambiguouslyDecode m @@ -146,7 +146,7 @@ snippetTestBench :: snippetTestBench Project{pName, pUnit, pTestCntx = Cntx{cntxProcess}, pUnitEnv} SnippetTestBenchConf{tbcSignals, tbcPorts, tbcMC2verilogLiteral} = - let cycleCntx : _ = cntxProcess + let cycleCntx = head cntxProcess name = moduleName pName pUnit p@Process{steps} = process pUnit fs = functions pUnit diff --git a/src/NITTA/Project/Types.hs b/src/NITTA/Project/Types.hs index d98cbb2c7..735ef9a3f 100644 --- a/src/NITTA/Project/Types.hs +++ b/src/NITTA/Project/Types.hs @@ -64,7 +64,7 @@ defProjectTemplates = , "templates/DE0-Nano" ] -instance (Default x) => DefaultX (Project m v x) x +instance Default x => DefaultX (Project m v x) x -- | Type class for target components. Target -- a target system project or a testbench. 
class TargetSystemComponent pu where diff --git a/src/NITTA/Synthesis.hs b/src/NITTA/Synthesis.hs index b341f50f8..6574fbcdb 100644 --- a/src/NITTA/Synthesis.hs +++ b/src/NITTA/Synthesis.hs @@ -96,7 +96,6 @@ import NITTA.Model.ProcessorUnits.Types import NITTA.Model.TargetSystem import NITTA.Model.Time import NITTA.Project (Project (..), collectNittaPath, defProjectTemplates, runTestbench, writeProject) -import NITTA.Synthesis.Analysis import NITTA.Synthesis.Explore import NITTA.Synthesis.Method import NITTA.Synthesis.Steps @@ -133,7 +132,7 @@ data TargetSynthesis tag v x t = TargetSynthesis -- ^ source code format type } -instance (SynthesisMethodConstraints tag v x t) => Default (TargetSynthesis tag v x t) where +instance (UnitTag tag, VarValTime v x t) => Default (TargetSynthesis tag v x t) where def = TargetSynthesis { tName = undefined @@ -141,7 +140,7 @@ instance (SynthesisMethodConstraints tag v x t) => Default (TargetSynthesis tag , tSourceCode = Nothing , tDFG = undefined , tReceivedValues = def - , tSynthesisMethod = stateOfTheArtSynthesisIO def + , tSynthesisMethod = stateOfTheArtSynthesisIO () , tLibPath = "hdl" , tTemplates = defProjectTemplates , tPath = joinPath ["gen"] diff --git a/src/NITTA/Synthesis/Explore.hs b/src/NITTA/Synthesis/Explore.hs index a48dc733a..d4c862804 100644 --- a/src/NITTA/Synthesis/Explore.hs +++ b/src/NITTA/Synthesis/Explore.hs @@ -1,4 +1,5 @@ {-# LANGUAGE GADTs #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE NoMonomorphismRestriction #-} @@ -16,15 +17,15 @@ module NITTA.Synthesis.Explore ( getTreePathIO, subForestIO, positiveSubForestIO, + isComplete, + isLeaf, ) where import Control.Concurrent.STM -import Control.Exception import Control.Monad (forM, unless, when) import Data.Default import Data.Map.Strict qualified as M import Data.Set qualified as S -import Data.Text qualified as T import NITTA.Intermediate.Analysis (buildProcessWaves, estimateVarWaves) import NITTA.Intermediate.Types import NITTA.Model.Networks.Bus @@ -33,13 +34,9 @@ import NITTA.Model.Problems.Bind import NITTA.Model.Problems.Dataflow import NITTA.Model.Problems.Refactor import NITTA.Model.TargetSystem -import NITTA.Synthesis.MlBackend.Api -import NITTA.Synthesis.MlBackend.ServerInstance +import NITTA.Synthesis.Steps () import NITTA.Synthesis.Types -import NITTA.UIBackend.Types -import NITTA.UIBackend.ViewHelper import NITTA.Utils -import Network.HTTP.Simple import System.Log.Logger -- | Make synthesis tree @@ -56,87 +53,74 @@ rootSynthesisTreeSTM model = do } -- | Get specific by @nId@ node from a synthesis tree. -getTreeIO _ tree (Sid []) = return tree -getTreeIO ctx tree (Sid (i : is)) = do - subForest <- subForestIO ctx tree +getTreeIO tree (Sid []) = return tree +getTreeIO tree (Sid (i : is)) = do + subForest <- subForestIO tree unless (i < length subForest) $ error "getTreeIO - wrong Sid" - getTreeIO ctx (subForest !! i) (Sid is) + getTreeIO (subForest !! i) (Sid is) -- | Get list of all nodes from root to selected. -getTreePathIO _ _ (Sid []) = return [] -getTreePathIO ctx tree (Sid (i : is)) = do - h <- getTreeIO ctx tree $ Sid [i] - t <- getTreePathIO ctx h $ Sid is +getTreePathIO _ (Sid []) = return [] +getTreePathIO tree (Sid (i : is)) = do + h <- getTreeIO tree $ Sid [i] + t <- getTreePathIO h $ Sid is return $ h : t {- | Get all available edges for the node. Edges calculated only for the first call. 
-} subForestIO - ctx tree@Tree { sSubForestVar , sID , sDecision - } = - do - (firstTime, subForest) <- - atomically $ - tryReadTMVar sSubForestVar >>= \case - Just subForest -> return (False, subForest) - Nothing -> do - subForest <- exploreSubForestVar tree - putTMVar sSubForestVar subForest - return (True, subForest) - when firstTime $ do - debugM "NITTA.Synthesis" $ - "explore: " - <> show sID - <> " score: " - <> ( case sDecision of - SynthesisDecision{score} -> show score - _ -> "-" - ) - <> " decision: " - <> ( case sDecision of - SynthesisDecision{decision} -> show decision - _ -> "-" - ) + } = do + (firstTime, subForest) <- + atomically $ + tryReadTMVar sSubForestVar >>= \case + Just subForest -> return (False, subForest) + Nothing -> do + subForest <- exploreSubForestVar tree + putTMVar sSubForestVar subForest + return (True, subForest) + when firstTime $ do + debugM "NITTA.Synthesis" $ + "explore: " + <> show sID + <> " score: " + <> ( case sDecision of + SynthesisDecision{scores} -> show scores + _ -> "-" + ) + <> " decision: " + <> ( case sDecision of + SynthesisDecision{decision} -> show decision + _ -> "-" + ) - if null subForest - then return subForest - else - ( case mlScoringModel ctx of - Nothing -> return subForest - Just modelNameStr -> do - mlBackend <- mlBackendGetter ctx - let mlBackendBaseUrl = baseUrl mlBackend - case mlBackendBaseUrl of - Nothing -> return subForest - Just onlineUrl -> do - let modelName = T.pack modelNameStr - mapSubforestScoreViaMlBackendIO subForest onlineUrl modelName - `catch` \e -> do - errorM "NITTA.Synthesis" $ - "ML backend error: " - <> ( case e of - JSONConversionException _ resp _ -> show resp - _ -> show e - ) - return subForest - ) - -mapSubforestScoreViaMlBackendIO subForest mlBackendBaseUrl modelName = do - let input = ScoringInput{scoringTarget = ScoringTargetAll, nodes = [view node | node <- subForest]} - allInputsScores <- predictScoresIO modelName mlBackendBaseUrl [input] - let scores = map (+ 20) $ head allInputsScores -- +20 shifts "useless node" threshold, since model outputs negative values much more often - return $ map (\(node@Tree{sDecision = sDes}, score) -> node{sDecision = sDes{score}}) (zip subForest scores) + return subForest {- | For synthesis method is more usefull, because throw away all useless trees in subForest (objective function value less than zero). -} -positiveSubForestIO ctx tree = filter ((> 0) . score . sDecision) <$> subForestIO ctx tree +positiveSubForestIO tree = filter ((> 0) . defScore . sDecision) <$> subForestIO tree + +isLeaf + Tree + { sState = + SynthesisState + { sAllocationOptions = [] + , sBindOptions = [] + , sDataflowOptions = [] + , sBreakLoopOptions = [] + , sResolveDeadlockOptions = [] + , sOptimizeAccumOptions = [] + , sConstantFoldingOptions = [] + } + } = True +isLeaf _ = False +isComplete = isSynthesisComplete . sTarget . 
sState -- * Internal exploreSubForestVar parent@Tree{sID, sState} = @@ -164,7 +148,7 @@ decisionAndContext parent@Tree{sState = ctx} o = [ (SynthesisDecision o d p e, nodeCtx (Just parent) model) | (d, model) <- decisions ctx o , let p = parameters ctx o d - e = estimate ctx o d p + e = M.singleton "default" $ estimate ctx o d p ] nodeCtx parent nModel = diff --git a/src/NITTA/Synthesis/Method.hs b/src/NITTA/Synthesis/Method.hs index 51eb1a944..cdc1f6fa8 100644 --- a/src/NITTA/Synthesis/Method.hs +++ b/src/NITTA/Synthesis/Method.hs @@ -20,21 +20,16 @@ module NITTA.Synthesis.Method ( stateOfTheArtSynthesisIO, allBindsAndRefsIO, bestStepIO, - SynthesisMethodConstraints, ) where -import Data.Aeson (ToJSON) import Data.List qualified as L import Data.Typeable import Debug.Trace import NITTA.Model.ProcessorUnits import NITTA.Model.TargetSystem -import NITTA.Synthesis.Analysis import NITTA.Synthesis.Explore import NITTA.Synthesis.Steps import NITTA.Synthesis.Types -import NITTA.UIBackend.Types -import NITTA.UIBackend.ViewHelper import NITTA.Utils (maximumOn, minimumOn) import Safe import System.Log.Logger @@ -45,47 +40,47 @@ the endless synthesis process. stepLimit = 750 :: Int -- | The most complex synthesis method, which embedded all another. That all. -stateOfTheArtSynthesisIO :: (SynthesisMethodConstraints tag v x t) => BackendCtx tag v x t -> SynthesisMethod tag v x t -stateOfTheArtSynthesisIO ctx tree = do +stateOfTheArtSynthesisIO :: (VarValTime v x t, UnitTag tag) => () -> SynthesisMethod tag v x t +stateOfTheArtSynthesisIO () tree = do infoM "NITTA.Synthesis" $ "stateOfTheArtSynthesisIO: " <> show (sID tree) - l1 <- simpleSynthesisIO ctx tree - l2 <- smartBindSynthesisIO ctx tree - l3 <- bestThreadIO ctx stepLimit tree - l4 <- bestThreadIO ctx stepLimit =<< allBindsAndRefsIO ctx tree + l1 <- simpleSynthesisIO tree + l2 <- smartBindSynthesisIO tree + l3 <- bestThreadIO stepLimit tree + l4 <- bestThreadIO stepLimit =<< allBindsAndRefsIO tree return $ bestLeaf tree [l1, l2, l3, l4] -- | Schedule process by simple synthesis. -simpleSynthesisIO :: (SynthesisMethodConstraints tag v x t) => BackendCtx tag v x t -> SynthesisMethod tag v x t -simpleSynthesisIO ctx root = do +simpleSynthesisIO :: (VarValTime v x t, UnitTag tag) => SynthesisMethod tag v x t +simpleSynthesisIO root = do infoM "NITTA.Synthesis" $ "simpleSynthesisIO: " <> show (sID root) - lastObliviousNode <- obviousBindThreadIO ctx root - allBestThreadIO ctx 1 lastObliviousNode + lastObliviousNode <- obviousBindThreadIO root + allBestThreadIO 1 lastObliviousNode -smartBindSynthesisIO :: (SynthesisMethodConstraints tag v x t) => BackendCtx tag v x t -> SynthesisMethod tag v x t -smartBindSynthesisIO ctx tree = do +smartBindSynthesisIO :: (VarValTime v x t, UnitTag tag) => SynthesisMethod tag v x t +smartBindSynthesisIO tree = do infoM "NITTA.Synthesis" $ "smartBindSynthesisIO: " <> show (sID tree) - tree' <- smartBindThreadIO ctx tree - allBestThreadIO ctx 1 tree' + tree' <- smartBindThreadIO tree + allBestThreadIO 1 tree' -bestThreadIO :: (SynthesisMethodConstraints tag v x t) => BackendCtx tag v x t -> Int -> SynthesisMethod tag v x t -bestThreadIO _ 0 node = return $ trace "bestThreadIO reach step limit!" node -bestThreadIO ctx limit tree = do - subForest <- positiveSubForestIO ctx tree +bestThreadIO :: (VarValTime v x t, UnitTag tag) => Int -> SynthesisMethod tag v x t +bestThreadIO 0 node = return $ trace "bestThreadIO reach step limit!" 
node +bestThreadIO limit tree = do + subForest <- positiveSubForestIO tree case subForest of [] -> return tree - _ -> bestThreadIO ctx (limit - 1) $ maximumOn (score . sDecision) subForest + _ -> bestThreadIO (limit - 1) $ maximumOn (defScore . sDecision) subForest -bestStepIO :: (SynthesisMethodConstraints tag v x t) => BackendCtx tag v x t -> SynthesisMethod tag v x t -bestStepIO ctx tree = do - subForest <- positiveSubForestIO ctx tree +bestStepIO :: (VarValTime v x t, UnitTag tag) => SynthesisMethod tag v x t +bestStepIO tree = do + subForest <- positiveSubForestIO tree case subForest of [] -> error "all step is over" - _ -> return $ maximumOn (score . sDecision) subForest + _ -> return $ maximumOn (defScore . sDecision) subForest -obviousBindThreadIO :: (SynthesisMethodConstraints tag v x t) => BackendCtx tag v x t -> SynthesisMethod tag v x t -obviousBindThreadIO ctx tree = do - subForest <- positiveSubForestIO ctx tree - maybe (return tree) (obviousBindThreadIO ctx) $ +obviousBindThreadIO :: (VarValTime v x t, UnitTag tag) => SynthesisMethod tag v x t +obviousBindThreadIO tree = do + subForest <- positiveSubForestIO tree + maybe (return tree) obviousBindThreadIO $ L.find ( ( \case Just BindMetrics{pPossibleDeadlock = True} -> False @@ -97,37 +92,37 @@ obviousBindThreadIO ctx tree = do ) subForest -allBindsAndRefsIO :: (SynthesisMethodConstraints tag v x t) => BackendCtx tag v x t -> SynthesisMethod tag v x t -allBindsAndRefsIO ctx tree = do +allBindsAndRefsIO :: (VarValTime v x t, UnitTag tag) => SynthesisMethod tag v x t +allBindsAndRefsIO tree = do subForest <- filter ((\d -> isBind d || isRefactor d) . sDecision) - <$> positiveSubForestIO ctx tree + <$> positiveSubForestIO tree case subForest of [] -> return tree - _ -> allBindsAndRefsIO ctx $ minimumOn (score . sDecision) subForest + _ -> allBindsAndRefsIO $ minimumOn (defScore . sDecision) subForest -refactorThreadIO ctx tree = do - subForest <- positiveSubForestIO ctx tree - maybe (return tree) (refactorThreadIO ctx) $ +refactorThreadIO tree = do + subForest <- positiveSubForestIO tree + maybe (return tree) refactorThreadIO $ L.find (isRefactor . sDecision) subForest -smartBindThreadIO :: (SynthesisMethodConstraints tag v x t) => BackendCtx tag v x t -> SynthesisMethod tag v x t -smartBindThreadIO ctx tree = do +smartBindThreadIO :: (VarValTime v x t, UnitTag tag) => SynthesisMethod tag v x t +smartBindThreadIO tree = do subForest <- filter ((\d -> isBind d || isRefactor d) . sDecision) - <$> (positiveSubForestIO ctx =<< refactorThreadIO ctx tree) + <$> (positiveSubForestIO =<< refactorThreadIO tree) case subForest of [] -> return tree - _ -> smartBindThreadIO ctx $ maximumOn (score . sDecision) subForest + _ -> smartBindThreadIO $ maximumOn (defScore . 
sDecision) subForest -allBestThreadIO :: (SynthesisMethodConstraints tag v x t) => BackendCtx tag v x t -> Int -> SynthesisMethod tag v x t -allBestThreadIO ctx (0 :: Int) tree = bestThreadIO ctx stepLimit tree -allBestThreadIO ctx n tree = do - subForest <- positiveSubForestIO ctx tree - leafs <- mapM (allBestThreadIO ctx (n - 1)) subForest +allBestThreadIO :: (VarValTime v x t, UnitTag tag) => Int -> SynthesisMethod tag v x t +allBestThreadIO (0 :: Int) tree = bestThreadIO stepLimit tree +allBestThreadIO n tree = do + subForest <- positiveSubForestIO tree + leafs <- mapM (allBestThreadIO (n - 1)) subForest return $ bestLeaf tree leafs -bestLeaf :: (SynthesisMethodConstraints tag v x t) => DefTree tag v x t -> [DefTree tag v x t] -> DefTree tag v x t +bestLeaf :: (VarValTime v x t, UnitTag tag) => DefTree tag v x t -> [DefTree tag v x t] -> DefTree tag v x t bestLeaf tree leafs = let successLeafs = filter (\node -> isComplete node && isLeaf node) leafs in case successLeafs of @@ -136,28 +131,3 @@ bestLeaf tree leafs = minimumOn (\Tree{sState = SynthesisState{sTarget}} -> (processDuration sTarget, puSize sTarget)) successLeafs - -{- | Shortcut for constraints in signatures of synthesis method functions. -This used to be (VarValTime v x t, UnitTag tag). See below for more info. --} -type SynthesisMethodConstraints tag v x t = (VarValTimeJSON v x t, ToJSON tag, UnitTag tag) - --- FIXME: Validate the type above, its usages and meaning in the context of changes described below. --- --- Ilya Burakov is not sure why signatures of synthesis method functions were explicitly defined --- (not inferred) and why they are what they are, but introduction of JSON body formatting --- for ML backend node scoring requests in NITTA.Synthesis.Explore module forced to add JSON-related --- constraints to them. --- --- Also, it has spilled to Default interface in NITTA.Synthesis. See usages of --- SynthesisMethodConstraints for all related changes. --- --- Effectvely, those constraints were added: --- - ToJSONKey v, ToJSON v, ToJSON x, ToJSON t (via ValValTime -> ValValTimeJSON) --- - ToJSON tag (explicitly) --- --- Related chain of dependencies: --- stateOfTheArtSynthesisIO -> bestThreadIO (or others) -> positiveSubForestIO -> subForestIO -> --- predictScoresIO -> ScoringInput -> NodeView --- --- Not sure if it's the right way to do it, but it works for now. Please, validate and fix if needed. 
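Editor's note: the Explore.hs and Method.hs hunks above, together with the Synthesis/Types.hs hunk below, replace the single score :: Float field of SynthesisDecision with scores :: Map Text Float (named scores, read back via the new defScore helper; Explore.hs stores the estimator's value under the "default" key). The following is a minimal sketch of that shape only, assuming nothing beyond what the hunks show; the module and helper names (ScoreSketch, mkDefaultScores, lookupScore) are hypothetical and exist purely for illustration.

    {-# LANGUAGE ImportQualifiedPost #-}
    {-# LANGUAGE OverloadedStrings #-}

    -- Hypothetical module name; an illustration only, not part of the PR.
    module ScoreSketch where

    import Data.Map.Strict (Map)
    import Data.Map.Strict qualified as M
    import Data.Text (Text)

    -- Each SynthesisDecision now carries a map of named scores; the estimator's
    -- value goes under the "default" key (see decisionAndContext in Explore.hs).
    mkDefaultScores :: Float -> Map Text Float
    mkDefaultScores estimated = M.singleton "default" estimated

    -- defScore in NITTA.Synthesis.Types is (M.! "default") . scores; a total
    -- variant that falls back to 0 when the named score is missing:
    lookupScore :: Text -> Map Text Float -> Float
    lookupScore name = M.findWithDefault 0 name
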
diff --git a/src/NITTA/Synthesis/Steps/Allocation.hs b/src/NITTA/Synthesis/Steps/Allocation.hs index 7113597a3..ad0cc43d9 100644 --- a/src/NITTA/Synthesis/Steps/Allocation.hs +++ b/src/NITTA/Synthesis/Steps/Allocation.hs @@ -55,7 +55,7 @@ data AllocationMetrics = AllocationMetrics instance ToJSON AllocationMetrics instance - (UnitTag tag) => + UnitTag tag => SynthesisDecisionCls (SynthesisState (TargetSystem (BusNetwork tag v x t) tag v x t) tag v x t) (TargetSystem (BusNetwork tag v x t) tag v x t) diff --git a/src/NITTA/Synthesis/Types.hs b/src/NITTA/Synthesis/Types.hs index 18c3da3c2..52a27a63a 100644 --- a/src/NITTA/Synthesis/Types.hs +++ b/src/NITTA/Synthesis/Types.hs @@ -1,6 +1,7 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE NoMonomorphismRestriction #-} {- | @@ -35,14 +36,17 @@ module NITTA.Synthesis.Types ( (), targetUnit, targetDFG, + defScore, ) where import Control.Concurrent.STM (TMVar) import Data.Aeson (ToJSON, toJSON) import Data.Default import Data.List.Split +import Data.Map.Strict (Map) import Data.Map.Strict qualified as M import Data.Set qualified as S +import Data.Text (Text) import Data.Typeable import NITTA.Intermediate.Analysis (ProcessWave) import NITTA.Intermediate.Types @@ -123,9 +127,11 @@ data SynthesisDecision ctx m where Root :: SynthesisDecision ctx m SynthesisDecision :: (Typeable p, SynthesisDecisionCls ctx m o d p, Show d, ToJSON p, Viewable d DecisionView) => - {option :: o, decision :: d, metrics :: p, score :: Float} -> + {option :: o, decision :: d, metrics :: p, scores :: Map Text Float} -> SynthesisDecision ctx m +defScore = (M.! "default") . scores + class SynthesisDecisionCls ctx m o d p | ctx o -> m d p where decisions :: ctx -> o -> [(d, m)] parameters :: ctx -> o -> d -> p diff --git a/src/NITTA/UIBackend/REST.hs b/src/NITTA/UIBackend/REST.hs index 14bb816e8..8091de17d 100644 --- a/src/NITTA/UIBackend/REST.hs +++ b/src/NITTA/UIBackend/REST.hs @@ -41,7 +41,6 @@ import NITTA.Project.TestBench import NITTA.Synthesis import NITTA.Synthesis.Analysis import NITTA.UIBackend.Timeline -import NITTA.UIBackend.Types import NITTA.UIBackend.ViewHelper import NITTA.UIBackend.VisJS (VisJS, algToVizJS) import NITTA.Utils @@ -51,6 +50,14 @@ import Servant.Docs import System.Directory import System.FilePath +data BackendCtx tag v x t = BackendCtx + { root :: DefTree tag v x t + -- ^ root synthesis node + , receivedValues :: [(v, [x])] + -- ^ lists of received by IO values + , outputPath :: String + } + type SynthesisAPI tag v x t = ( Description "Get whole synthesis tree" :> "synthesisTree" @@ -96,10 +103,10 @@ type SynthesisTreeNavigationAPI tag v x t = ) ) -synthesisTreeNavigation ctx@BackendCtx{root} sid = - liftIO (map view <$> getTreePathIO ctx root sid) - :<|> liftIO (fmap view . sParent . sState <$> getTreeIO ctx root sid) - :<|> liftIO (map view <$> (subForestIO ctx =<< getTreeIO ctx root sid)) +synthesisTreeNavigation BackendCtx{root} sid = + liftIO (map view <$> getTreePathIO root sid) + :<|> liftIO (fmap view . sParent . sState <$> getTreeIO root sid) + :<|> liftIO (map view <$> (subForestIO =<< getTreeIO root sid)) type NodeInspectionAPI tag v x t = Summary "Synthesis node inspection" @@ -137,13 +144,13 @@ type NodeInspectionAPI tag v x t = ) nodeInspection ctx@BackendCtx{root} sid = - liftIO (view <$> getTreeIO ctx root sid) - :<|> liftIO (algToVizJS . functions . targetDFG <$> getTreeIO ctx root sid) - :<|> liftIO (processTimelines . process . 
targetUnit <$> getTreeIO ctx root sid) - :<|> liftIO (view . process . targetUnit <$> getTreeIO ctx root sid) - :<|> (\tag -> liftIO (view . process . (M.! tag) . bnPus . targetUnit <$> getTreeIO ctx root sid)) + liftIO (view <$> getTreeIO root sid) + :<|> liftIO (algToVizJS . functions . targetDFG <$> getTreeIO root sid) + :<|> liftIO (processTimelines . process . targetUnit <$> getTreeIO root sid) + :<|> liftIO (view . process . targetUnit <$> getTreeIO root sid) + :<|> (\tag -> liftIO (view . process . (M.! tag) . bnPus . targetUnit <$> getTreeIO root sid)) :<|> liftIO (dbgEndpointOptions <$> debug ctx sid) - :<|> liftIO (microarchitectureDesc . targetUnit <$> getTreeIO ctx root sid) + :<|> liftIO (microarchitectureDesc . targetUnit <$> getTreeIO root sid) :<|> debug ctx sid type SynthesisMethodsAPI tag v x t = @@ -159,10 +166,10 @@ type SynthesisMethodsAPI tag v x t = :<|> "smartBindSynthesisIO" :> Post '[JSON] Sid ) -synthesisMethods ctx@BackendCtx{root} sid = - liftIO (sID <$> (stateOfTheArtSynthesisIO ctx =<< getTreeIO ctx root sid)) - :<|> liftIO (sID <$> (simpleSynthesisIO ctx =<< getTreeIO ctx root sid)) - :<|> liftIO (sID <$> (smartBindSynthesisIO ctx =<< getTreeIO ctx root sid)) +synthesisMethods BackendCtx{root} sid = + liftIO (sID <$> (stateOfTheArtSynthesisIO () =<< getTreeIO root sid)) + :<|> liftIO (sID <$> (simpleSynthesisIO =<< getTreeIO root sid)) + :<|> liftIO (sID <$> (smartBindSynthesisIO =<< getTreeIO root sid)) type SynthesisPracticesAPI tag v x t = Summary "SynthesisPractice is a set of small elements of the synthesis process." @@ -187,11 +194,11 @@ type SynthesisPracticesAPI tag v x t = ) ) -synthesisPractices ctx@BackendCtx{root} sid = - liftIO (sID <$> (bestStepIO ctx =<< getTreeIO ctx root sid)) - :<|> liftIO (sID <$> (obviousBindThreadIO ctx =<< getTreeIO ctx root sid)) - :<|> liftIO (sID <$> (allBindsAndRefsIO ctx =<< getTreeIO ctx root sid)) - :<|> (\deep -> liftIO (sID <$> (allBestThreadIO ctx deep =<< getTreeIO ctx root sid))) +synthesisPractices BackendCtx{root} sid = + liftIO (sID <$> (bestStepIO =<< getTreeIO root sid)) + :<|> liftIO (sID <$> (obviousBindThreadIO =<< getTreeIO root sid)) + :<|> liftIO (sID <$> (allBindsAndRefsIO =<< getTreeIO root sid)) + :<|> (\deep -> liftIO (sID <$> (allBestThreadIO deep =<< getTreeIO root sid))) type TestBenchAPI v x = Summary "Get the report of testbench execution for the current node." @@ -200,8 +207,8 @@ type TestBenchAPI v x = :> QueryParam' '[Required] "loopsNumber" Int :> Post '[JSON] (TestbenchReport v x) -testBench ctx@BackendCtx{root, receivedValues, outputPath} sid pName loopsNumber = liftIO $ do - tree <- getTreeIO ctx root sid +testBench BackendCtx{root, receivedValues, outputPath} sid pName loopsNumber = liftIO $ do + tree <- getTreeIO root sid pInProjectNittaPath <- either (error . 
T.unpack) id <$> collectNittaPath defProjectTemplates unless (isComplete tree) $ error "test bench not allow for non complete synthesis" pwd <- getCurrentDirectory @@ -257,10 +264,10 @@ instance ToSample (EndpointSt String (TimeConstraint Int)) where toSamples _ = n instance ToSample Char where toSamples _ = noSamples -instance (UnitTag tag) => ToSample (Lock tag) where +instance UnitTag tag => ToSample (Lock tag) where toSamples _ = singleSample Lock{locked = "b", lockBy = "a"} -instance {-# OVERLAPS #-} (UnitTag tag) => ToSample [(T.Text, [Lock tag])] where +instance {-# OVERLAPS #-} UnitTag tag => ToSample [(T.Text, [Lock tag])] where toSamples _ = singleSample [("PU or function tag", [Lock{locked = "b", lockBy = "a"}])] type DebugAPI tag v t = @@ -269,8 +276,8 @@ type DebugAPI tag v t = \(see NITTA.UIBackend.REST.Debug)" :> Get '[JSON] (Debug tag v t) -debug ctx@BackendCtx{root} sid = liftIO $ do - tree <- getTreeIO ctx root sid +debug BackendCtx{root} sid = liftIO $ do + tree <- getTreeIO root sid let dbgFunctionLocks = map (\f -> (f, locks f)) $ functions $ targetUnit tree already = transferred $ targetUnit tree return @@ -316,7 +323,7 @@ instance ToParam (QueryParam' mods "loopsNumber" Int) where instance ToSample Sid where toSamples _ = [("The synthesis node path from the root by edge indexes.", Sid [1, 1, 3])] -instance (Time t) => ToSample (Process t StepInfoView) where +instance Time t => ToSample (Process t StepInfoView) where toSamples _ = [ ( "for process unit" @@ -344,7 +351,7 @@ instance (Time t) => ToSample (Process t StepInfoView) where ) ] -instance (UnitTag tag) => ToSample (MicroarchitectureDesc tag) where +instance UnitTag tag => ToSample (MicroarchitectureDesc tag) where toSamples _ = let bn :: BusNetwork tag String (IntX 32) Int = defineNetwork "net1" Sync $ do addCustom "fram1" (framWithSize 16) FramIO diff --git a/src/NITTA/UIBackend/Timeline.hs b/src/NITTA/UIBackend/Timeline.hs index 1c4353651..26c3aae99 100644 --- a/src/NITTA/UIBackend/Timeline.hs +++ b/src/NITTA/UIBackend/Timeline.hs @@ -41,7 +41,7 @@ data TimelineWithViewPoint t = TimelineWithViewPoint } deriving (Generic) -instance (Time t) => Show (ProcessTimelines t) where +instance Time t => Show (ProcessTimelines t) where show ProcessTimelines{timelines} = let vpLength = maximum $ map (length . show . timelineViewpoint) timelines normalizeVP s = s ++ replicate (vpLength - length s) ' ' @@ -64,7 +64,7 @@ data TimelinePoint t = TimelinePoint } deriving (Generic) -instance {-# OVERLAPS #-} (Time t) => Show [TimelinePoint t] where +instance {-# OVERLAPS #-} Time t => Show [TimelinePoint t] where show [] = "." 
-- show [TimelinePoint{ pInfo }] | EndpointRoleStep Source{} <- descent pDesc = "^" -- show ( Single Step{ pDesc } ) | EndpointRoleStep Target{} <- descent pDesc = "v" diff --git a/src/NITTA/UIBackend/ViewHelper.hs b/src/NITTA/UIBackend/ViewHelper.hs index 6db04742a..1d08250fc 100644 --- a/src/NITTA/UIBackend/ViewHelper.hs +++ b/src/NITTA/UIBackend/ViewHelper.hs @@ -64,7 +64,7 @@ data TreeView a = TreeNodeView } deriving (Generic, Show) -instance (ToJSON a) => ToJSON (TreeView a) +instance ToJSON a => ToJSON (TreeView a) instance ToSample (TreeView ShortNodeView) where toSamples _ = @@ -202,7 +202,8 @@ instance (UnitTag tag, VarValTimeJSON v x t) => Viewable (DefTree tag v x t) (No sDecision , score = ( \case - SynthesisDecision{score} -> score + -- TODO: show all avaialable scores + sd@SynthesisDecision{} -> defScore sd _ -> 0 ) sDecision diff --git a/stack.yaml b/stack.yaml index 0fa4581b7..8aac7e95d 100644 --- a/stack.yaml +++ b/stack.yaml @@ -15,7 +15,7 @@ # resolver: # name: custom-snapshot # location: "./custom-snapshot.yaml" -resolver: lts-19.18 +resolver: lts-20.21 # User packages to be built. # Various formats can be used as shown in the example below. @@ -36,16 +36,17 @@ resolver: lts-19.18 # non-dependency (i.e. a user package), and its test suites and benchmarks # will not be run. This is useful for tweaking upstream packages. packages: -- '.' + - "." # Dependency packages to be pulled from upstream that are not in the resolver # (e.g., acme-missiles-0.3) extra-deps: -- alex-tools-0.4@sha256:7f24cb60ba88b04196965e78d7944d638b1a6034f0c9284bdf7d95e05c7be7c3 -- git: https://github.com/mirokuratczyk/htoml.git - commit: 33971287445c5e2531d9605a287486dfc3cbe1da -- language-lua-0.11.0.1@sha256:b3bdf864965279f830323d5ed7d38073b2687b3a07b7fa73aa13782560f48f89 -- servant-js-0.9.4.2@sha256:b2d973b43bfa69f35900bf46e989348b7300c398b599ddf74b90a06173c023f2 + - alex-tools-0.4@sha256:7f24cb60ba88b04196965e78d7944d638b1a6034f0c9284bdf7d95e05c7be7c3 + - git: https://github.com/mirokuratczyk/htoml.git + commit: 33971287445c5e2531d9605a287486dfc3cbe1da + - language-lua-0.11.0.1@sha256:b3bdf864965279f830323d5ed7d38073b2687b3a07b7fa73aa13782560f48f89 + - servant-js-0.9.4.2@sha256:b2d973b43bfa69f35900bf46e989348b7300c398b599ddf74b90a06173c023f2 + - ginger-0.10.4.0 allow-newer: true diff --git a/stack.yaml.lock b/stack.yaml.lock index 0f66c038d..f958211e9 100644 --- a/stack.yaml.lock +++ b/stack.yaml.lock @@ -7,38 +7,45 @@ packages: - completed: hackage: alex-tools-0.4@sha256:7f24cb60ba88b04196965e78d7944d638b1a6034f0c9284bdf7d95e05c7be7c3,995 pantry-tree: - size: 262 sha256: 98715aed391e9e76b61cc0491843b2a860d0aba0d062a4a696574632b8ad6753 + size: 262 original: hackage: alex-tools-0.4@sha256:7f24cb60ba88b04196965e78d7944d638b1a6034f0c9284bdf7d95e05c7be7c3 - completed: - name: htoml - version: 1.0.0.3 + commit: 33971287445c5e2531d9605a287486dfc3cbe1da git: https://github.com/mirokuratczyk/htoml.git + name: htoml pantry-tree: - size: 10869 sha256: 7f4578a5e8c97ff32f6e136750447b2769e80e8ccc2eb9b92bb2b3a02acccf9d - commit: 33971287445c5e2531d9605a287486dfc3cbe1da + size: 10869 + version: 1.0.0.3 original: - git: https://github.com/mirokuratczyk/htoml.git commit: 33971287445c5e2531d9605a287486dfc3cbe1da + git: https://github.com/mirokuratczyk/htoml.git - completed: hackage: language-lua-0.11.0.1@sha256:b3bdf864965279f830323d5ed7d38073b2687b3a07b7fa73aa13782560f48f89,2919 pantry-tree: - size: 3473 sha256: e4f6fcc91fbde9c6a6e3c9281a51dee49706d5868866695cc09f9bdf031799d4 + size: 3473 original: hackage: 
language-lua-0.11.0.1@sha256:b3bdf864965279f830323d5ed7d38073b2687b3a07b7fa73aa13782560f48f89 - completed: hackage: servant-js-0.9.4.2@sha256:b2d973b43bfa69f35900bf46e989348b7300c398b599ddf74b90a06173c023f2,3522 pantry-tree: - size: 961 sha256: ba201c0f2aa274afa40dfd3c31abe6e5b0791f14320f7167e750c57508dc70f6 + size: 961 original: hackage: servant-js-0.9.4.2@sha256:b2d973b43bfa69f35900bf46e989348b7300c398b599ddf74b90a06173c023f2 +- completed: + hackage: ginger-0.10.4.0@sha256:7ca26896395fda951443ed3bba5f24fd6036c44167d424e744d1e0c8c6b138fd,3855 + pantry-tree: + sha256: c3d93b51ba0e76853b217debb662cdcf5eca78288d49c809f823be27a072bfbf + size: 1430 + original: + hackage: ginger-0.10.4.0 snapshots: - completed: - size: 619164 - url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/19/18.yaml - sha256: 65b9809265860e085b4f61d4eb00d5d73e41190693620385a69cc9d9df7a901d - original: lts-19.18 + sha256: 401a0e813162ba62f04517f60c7d25e93a0f867f94a902421ebf07d1fb5a8c46 + size: 650044 + url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/21.yaml + original: lts-20.21 diff --git a/test/NITTA/Frontends/Lua/Tests.hs b/test/NITTA/Frontends/Lua/Tests.hs index 577482f73..490f6ffe1 100644 --- a/test/NITTA/Frontends/Lua/Tests.hs +++ b/test/NITTA/Frontends/Lua/Tests.hs @@ -1,10 +1,12 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PartialTypeSignatures #-} {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE NoMonomorphismRestriction #-} +{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-} {- | Module : NITTA.Frontends.Lua.Tests @@ -583,9 +585,10 @@ test_trace_features = | 4 | 4 | | 5 | 5 |\n |] - -- TODO: traceLuaSimulationTestCase pInt "variable before and after changing" ] +-- TODO: traceLuaSimulationTestCase pInt "variable before and after changing" + test_examples = [ unitTestCase "teacup io wait" def $ do setNetwork $ microarch Sync SlaveSPI diff --git a/test/NITTA/Intermediate/Tests/Functions.hs b/test/NITTA/Intermediate/Tests/Functions.hs index 2025a9f54..caa8d6c83 100644 --- a/test/NITTA/Intermediate/Tests/Functions.hs +++ b/test/NITTA/Intermediate/Tests/Functions.hs @@ -30,7 +30,7 @@ inputVarGen = I . T.pack <$> vectorOf varNameSize (elements ['a' .. 'z']) -- singleton. 
uniqueVars fb = S.null (inputs fb `intersection` outputs fb) -instance (Arbitrary x) => Arbitrary (Loop T.Text x) where +instance Arbitrary x => Arbitrary (Loop T.Text x) where arbitrary = suchThat (Loop <$> (X <$> arbitrary) <*> outputVarsGen <*> inputVarGen) uniqueVars instance Arbitrary (Buffer T.Text x) where @@ -57,5 +57,5 @@ instance Arbitrary (Acc T.Text x) where instance Arbitrary (IntX m) where arbitrary = IntX <$> choose (0, 256) -instance (Arbitrary x) => Arbitrary (Attr x) where +instance Arbitrary x => Arbitrary (Attr x) where arbitrary = Attr <$> arbitrary <*> arbitrary diff --git a/test/NITTA/Model/ProcessorUnits/Broken/Tests.hs b/test/NITTA/Model/ProcessorUnits/Broken/Tests.hs index 2f9f1a8b0..380d571b0 100644 --- a/test/NITTA/Model/ProcessorUnits/Broken/Tests.hs +++ b/test/NITTA/Model/ProcessorUnits/Broken/Tests.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PartialTypeSignatures #-} {-# LANGUAGE QuasiQuotes #-} diff --git a/test/NITTA/Model/ProcessorUnits/Divider/Tests.hs b/test/NITTA/Model/ProcessorUnits/Divider/Tests.hs index a8b49f9dc..7d09ae603 100644 --- a/test/NITTA/Model/ProcessorUnits/Divider/Tests.hs +++ b/test/NITTA/Model/ProcessorUnits/Divider/Tests.hs @@ -185,12 +185,13 @@ tests = end f(1024, 1024) |] - -- FIXME: Auto text can't work correctly, because processGen don't take into account the - -- facts that some variables may go out. - -- , testProperty "isUnitSynthesisFinish" $ isUnitSynthesisFinish <$> dividerGen - -- , testProperty "coSimulation" $ fmap (coSimulation "prop_simulation_divider") $ initialCycleCntxGen =<< dividerGen ] where + -- FIXME: Auto text can't work correctly, because processGen don't take into account the + -- facts that some variables may go out. + -- , testProperty "isUnitSynthesisFinish" $ isUnitSynthesisFinish <$> dividerGen + -- , testProperty "coSimulation" $ fmap (coSimulation "prop_simulation_divider") $ initialCycleCntxGen =<< dividerGen + u2 = def :: Divider String (Attr (IntX 16)) Int -- where diff --git a/test/NITTA/Model/ProcessorUnits/Tests/DSL.hs b/test/NITTA/Model/ProcessorUnits/Tests/DSL.hs index 702ad58b7..19e978cb0 100644 --- a/test/NITTA/Model/ProcessorUnits/Tests/DSL.hs +++ b/test/NITTA/Model/ProcessorUnits/Tests/DSL.hs @@ -314,7 +314,7 @@ assertBindFullness = do fInps = unionsMap inputs fs show' = show . S.map toString -assertAllEndpointRoles :: (Var v) => [EndpointRole v] -> PUStatement pu v x t () +assertAllEndpointRoles :: Var v => [EndpointRole v] -> PUStatement pu v x t () assertAllEndpointRoles roles = do UnitTestState{unit} <- get let opts = S.fromList $ map epRole $ endpointOptions unit @@ -329,7 +329,7 @@ assertEndpoint a b role = do Nothing -> lift $ assertFailure $ "assertEndpoint: '" <> show ep <> "' not defined in: " <> show opts Just _ -> return () -assertLocks :: (Locks pu v) => [Lock v] -> PUStatement pu v x t () +assertLocks :: Locks pu v => [Lock v] -> PUStatement pu v x t () assertLocks expectLocks = do UnitTestState{unit} <- get let actualLocks0 = locks unit @@ -515,7 +515,12 @@ mkConstantFolding old new = return $ ConstantFolding old new assertRefactor :: (Typeable ref, Eq ref, Show ref) => ref -> TSStatement x () assertRefactor ref = do refactors <- filter isRefactorStep . map (descent . pDesc) . steps . process . 
unit <$> get - case L.find (\(RefactorStep r) -> Just ref == cast r) refactors of + case L.find + ( \case + (RefactorStep r) -> Just ref == cast r + _ -> error "assertRefactor: impossible" + ) + refactors of Nothing -> lift $ assertFailure $ "Refactor not present: " <> show ref <> " in " <> show refactors Just _ -> return () @@ -535,7 +540,14 @@ mkAllocationOptions networkTag puTags = assertAllocation :: (Typeable a, Eq a, Show a) => Int -> a -> TSStatement x () assertAllocation number alloc = do allocations <- filter isAllocationStep . map (descent . pDesc) . steps . process . unit <$> get - let matched = length $ filter (\(AllocationStep a) -> Just alloc == cast a) allocations + let matched = + length $ + filter + ( \case + (AllocationStep a) -> Just alloc == cast a + _ -> error "assertAllocation: internal error" + ) + allocations when (matched /= number) $ lift $ assertFailure @@ -564,7 +576,7 @@ assertAllocationOptions options = do #{ showArray actual } |] -assertPU :: (Typeable a) => T.Text -> Proxy a -> TSStatement x () +assertPU :: Typeable a => T.Text -> Proxy a -> TSStatement x () assertPU tag puProxy = do UnitTestState{unit = TargetSystem{mUnit}} <- get let pu = M.lookup tag $ bnPus mUnit @@ -620,7 +632,7 @@ assertSuccessReport report@TestbenchReport{tbStatus} = synthesizeAndCoSim :: TSStatement x () synthesizeAndCoSim = do - synthesis $ stateOfTheArtSynthesisIO def + synthesis $ stateOfTheArtSynthesisIO () assertSynthesisComplete assertTargetSystemCoSimulation @@ -639,7 +651,7 @@ traceEndpoints = do UnitTestState{unit} <- get lift $ putListLn "Endpoints: " $ endpointOptions unit -traceProcess :: (ProcessorUnit u v x Int) => Statement u v x () +traceProcess :: ProcessorUnit u v x Int => Statement u v x () traceProcess = do UnitTestState{unit} <- get lift $ putStrLn $ "Process: " <> show (pretty $ process unit) diff --git a/test/NITTA/Model/Tests/Microarchitecture.hs b/test/NITTA/Model/Tests/Microarchitecture.hs index 90e3d2639..cff73c1c9 100644 --- a/test/NITTA/Model/Tests/Microarchitecture.hs +++ b/test/NITTA/Model/Tests/Microarchitecture.hs @@ -63,7 +63,7 @@ pFX42_64 = Proxy :: Proxy (FX 42 64) pFX48_64 = Proxy :: Proxy (FX 48 64) -basic :: (Integral x, Val x) => Proxy x -> BusNetwork T.Text T.Text x Int +basic :: Val x => Proxy x -> BusNetwork T.Text T.Text x Int basic _proxy = defineNetwork "net1" ASync $ do add "fram1" FramIO add "fram2" FramIO @@ -75,7 +75,7 @@ basic _proxy = defineNetwork "net1" ASync $ do march = basic pInt -- | Simple microarchitecture with broken PU for negative tests -maBroken :: (Integral x, Val x) => Broken T.Text x Int -> BusNetwork T.Text T.Text x Int +maBroken :: Val x => Broken T.Text x Int -> BusNetwork T.Text T.Text x Int maBroken brokenPU = defineNetwork "net1" ASync $ do add "fram1" FramIO add "fram2" FramIO diff --git a/test/NITTA/Tests.hs b/test/NITTA/Tests.hs index 703536142..56b542b5c 100644 --- a/test/NITTA/Tests.hs +++ b/test/NITTA/Tests.hs @@ -137,16 +137,16 @@ test_manual = f1 = F.add "a" "b" ["c", "d"] :: F T.Text Int -patchP :: (Patch a (T.Text, T.Text)) => (T.Text, T.Text) -> a -> a +patchP :: Patch a (T.Text, T.Text) => (T.Text, T.Text) -> a -> a patchP = patch -patchI :: (Patch a (I T.Text, I T.Text)) => (I T.Text, I T.Text) -> a -> a +patchI :: Patch a (I T.Text, I T.Text) => (I T.Text, I T.Text) -> a -> a patchI = patch -patchO :: (Patch a (O T.Text, O T.Text)) => (O T.Text, O T.Text) -> a -> a +patchO :: Patch a (O T.Text, O T.Text) => (O T.Text, O T.Text) -> a -> a patchO = patch -patchC :: (Patch a (Changeset T.Text)) => 
Changeset T.Text -> a -> a +patchC :: Patch a (Changeset T.Text) => Changeset T.Text -> a -> a patchC = patch test_patchFunction = @@ -176,16 +176,15 @@ test_patchFunction = @?= "a + b' = c = d'" ] -pu = - let Right pu' = - tryBind - f1 - PU - { diff = def - , unit = def :: Accum T.Text Int Int - , uEnv = undefined - } - in pu' +pu = case tryBind + f1 + PU + { diff = def + , unit = def :: Accum T.Text Int Int + , uEnv = undefined + } of + Right pu_ -> pu_ + Left err -> error $ show err test_patchEndpointOptions = [ testCase "non-patched function options" $