255 lines
9.4 KiB
Python
255 lines
9.4 KiB
Python
#
|
|
# Copyright (c) 2010-2024 Antmicro
|
|
#
|
|
# This file is licensed under the MIT License.
|
|
# Full license text is available in 'licenses/MIT.txt'.
|
|
#
|
|
|
|
import math
|
|
import random
|
|
import time
|
|
from typing import List
|
|
|
|
|
|
class CacheLine:
    """
    Represents a cache line in a cache set.

    tag (int): The tag of the cache line.
    use_count (int): Used for replacement policies.
    insertion_time (float): The time when the cache line was inserted.
    last_access_time (float): The time when the cache line was last accessed.
    free (bool): True when the line holds no valid data (an unused slot);
        False once a block has been loaded into it.
    """

    def __init__(self):
        # A new line starts out free (no valid data) with tag 0.
        self.init()

    def init(self, tag: int = 0, free: bool = True):
        """
        (Re)initialize the line and reset its replacement-policy statistics.

        tag (int): Tag to store in the line (defaults to 0).
        free (bool): Whether the line is unused (defaults to True).
        """
        # Read the clock once so a freshly (re)initialized line has identical
        # insertion and last-access timestamps.
        now = time.time()
        self.tag = tag
        self.free = free
        self.use_count: int = 0
        self.insertion_time: float = now
        self.last_access_time: float = now

    def __str__(self) -> str:
        return f"[CacheLine]: tag: {self.tag:b}, free: {self.free}, use: {self.use_count}, insertion: {self.insertion_time}, last access: {self.last_access_time}"
|
|
|
|
|
|
class Cache:
    """
    Cache memory model.

    name (str): Cache name, used in the `printd` debug helpers.
    cache_width (int): log2(cache_size).
    block_width (int): log2(cache_block_size).
    memory_width (int): log2(memory_size).

    lines_per_set (int): cache mapping policy selection:
        * -1 for fully associative
        * 1 for direct mapping
        * 2^n for 2^n-way set associativity

    replacement_policy (str | None): Selected line eviction policy (defaults to None):
        * FIFO: first in first out
        * LRU: least recently used
        * LFU: least frequently used
        * None: random

    debug (bool): print debug messages (defaults to False).

    An address is split, from the least significant bit upward, into the
    block offset (block_width bits), the set index (log2(number of sets)
    bits) and the tag (the remaining bits below memory_width).

    Raises:
        ValueError: if lines_per_set is not a power of two, or is larger
            than the number of lines in the cache.
    """

    def __init__(
        self,
        name: str,
        cache_width: int,
        block_width: int,
        memory_width: int,
        lines_per_set: int,
        replacement_policy: str | None = None,
        debug: bool = False
    ):
        self.name = name
        self.debug = debug

        # Width of the memories (log2 of their sizes in bytes)
        self._cache_width = cache_width
        self._block_width = block_width
        self._memory_width = memory_width

        # Convert width to size in bytes
        self._cache_size = 2 ** self._cache_width
        self._block_size = 2 ** self._block_width
        self._memory_size = 2 ** self._memory_width

        self._num_lines = self._cache_size // self._block_size
        self._lines = [CacheLine() for _ in range(self._num_lines)]

        if lines_per_set == -1:
            # special configuration case for fully associative mapping
            lines_per_set = self._num_lines

        # x & (x - 1) clears the lowest set bit, so it is 0 only for powers of two.
        if lines_per_set <= 0 or lines_per_set & (lines_per_set - 1) != 0:
            raise ValueError('Lines per set must be a power of two (1, 2, 4, 8, ...)')

        # More lines per set than lines in the cache would leave zero sets.
        if lines_per_set > self._num_lines:
            raise ValueError(
                f'Lines per set ({lines_per_set}) cannot exceed the number of cache lines ({self._num_lines})'
            )

        self._lines_per_set = lines_per_set
        self._sets = self._num_lines // lines_per_set
        # _sets is a power of two here, so this is its exact integer log2
        # (avoids floating-point rounding from math.log).
        self._set_width = self._sets.bit_length() - 1

        self._replacement_policy = replacement_policy if replacement_policy is not None else 'RAND'

        # Statistics
        self.misses = 0
        self.hits = 0
        self.invalidations = 0
        self.flushes = 0

    def read(self, addr: int) -> None:
        """Simulate a read of `addr`: count a hit, or count a miss and load the block."""
        sset = self._addr_get_set(addr)
        line = self._line_lookup(addr)
        self.printd(f'[read] attempt to fetch {hex(addr)} (set {sset})')

        if line is not None:
            self.printd('[read] rhit')
            self.hits += 1
            line.use_count += 1
            line.last_access_time = time.time()
        else:
            self.printd('[read] rmiss')
            self.misses += 1
            self._load(addr)

    def write(self, addr: int) -> None:
        """Simulate a write to `addr`: count a hit, or count a miss and load the block."""
        sset = self._addr_get_set(addr)
        line = self._line_lookup(addr)
        self.printd(f'[write] attempted write to {hex(addr)} (set {sset})')

        if line is not None:
            self.printd('[write] whit')
            self.hits += 1
            # NOTE(review): unlike read(), a write hit does not bump use_count,
            # so LFU ranks lines by read hits only - confirm this is intended.
            line.last_access_time = time.time()
        else:
            self.printd('[write] wmiss')
            self.misses += 1
            self._load(addr)

    def flush(self) -> None:
        """Mark every line free by replacing them all; hit/miss statistics are kept."""
        self.printd('[flush] flushing all lines!')
        self.flushes += 1
        self._lines = [CacheLine() for _ in range(self._num_lines)]

    def _select_evicted_index(self, lines_in_set: list) -> int:
        """
        Pick the index (within `lines_in_set`) of the line to evict.

        For LFU/FIFO/LRU, ties resolve to the lowest index because min() returns
        the first minimum. Timestamps come from time.time(), so accesses within
        the clock's resolution can tie.

        Raises:
            ValueError: for an unknown replacement policy.
        """
        if self._replacement_policy == 'RAND':
            return random.randint(0, len(lines_in_set) - 1)
        elif self._replacement_policy == 'LFU':
            return min(range(len(lines_in_set)), key=lambda i: lines_in_set[i].use_count)
        elif self._replacement_policy == 'FIFO':
            return min(range(len(lines_in_set)), key=lambda i: lines_in_set[i].insertion_time)
        elif self._replacement_policy == 'LRU':
            return min(range(len(lines_in_set)), key=lambda i: lines_in_set[i].last_access_time)
        else:
            raise ValueError(f"Unknown replacement policy: {self._replacement_policy}! Exiting!")

    def _load(self, addr: int) -> None:
        """
        Load the block containing `addr` into its set.

        Uses the first free line in the set if there is one. Otherwise it evicts
        the line chosen by the replacement policy and counts an invalidation.
        """
        self.printd(f'[load] loading @ {hex(addr)} to cache from Main Memory')
        tag = self._addr_get_tag(addr)
        set_index = self._addr_get_set(addr)
        lines_in_set = self._get_lines_in_set(set_index)

        # Determine the index of the cache line to load into
        free_line_index = next((index for index, obj in enumerate(lines_in_set) if obj.free), None)
        if free_line_index is not None:
            index = free_line_index
            self.printd(f'[load] loaded new cache index: {free_line_index} in the set {set_index}')
        else:
            if self.debug:
                # Only build the (potentially long) dump when it will be printed.
                self.printd(f"[load] lines in set {set_index}:")
                self.printd(' selecting a line to invalidate:\n', '\n'.join(f'{index}: {line}' for index, line in enumerate(lines_in_set)), sep='')
            index = self._select_evicted_index(lines_in_set)
            self.printd(f'[load] invalidated index: {index} in the set {set_index}')
            self.invalidations += 1

        lines_in_set[index].init(tag, False)

    @staticmethod
    def _extract_bits(value: int, start_bit: int, end_bit: int) -> int:
        """Return bits start_bit..end_bit (inclusive) of `value`; 0 if end_bit < start_bit."""
        num_bits = end_bit - start_bit + 1
        mask = ((1 << num_bits) - 1) << start_bit
        return (value & mask) >> start_bit

    def _addr_get_tag(self, addr: int) -> int:
        """Return the tag: the bits above the set index, up to bit memory_width - 1."""
        start = self._block_width + self._set_width
        end = self._memory_width - 1
        return self._extract_bits(addr, start, end)

    def _addr_get_set(self, addr: int) -> int:
        """Return the set index: the _set_width bits just above the block offset."""
        start = self._block_width
        end = self._block_width + self._set_width - 1
        return self._extract_bits(addr, start, end)

    def _addr_get_offset(self, addr: int) -> int:
        """Return the byte offset within the block (the low block_width bits)."""
        start = 0
        end = self._block_width - 1
        return self._extract_bits(addr, start, end)

    def _get_lines_in_set(self, set_index: int) -> List[CacheLine]:
        """Return the lines of set `set_index` (a slice: the CacheLine objects are shared)."""
        line_index = set_index * self._lines_per_set
        return self._lines[line_index:line_index + self._lines_per_set]

    def _line_lookup(self, addr: int) -> CacheLine | None:
        """Return the valid (non-free) line holding the tag of `addr`, or None."""
        tag = self._addr_get_tag(addr)
        lines_in_set = self._get_lines_in_set(self._addr_get_set(addr))
        # Free lines carry a default tag of 0, so they must never match.
        return next((line for line in lines_in_set if not line.free and line.tag == tag), None)

    def printd(self, *args, **kwargs):
        """Print `args` prefixed with the cache name, only when debug is enabled."""
        if self.debug:
            print(f'[{self.name}]', *args, **kwargs)

    def print_addr_info(self, addr: int, format: str = 'hex') -> None:
        """Print how `addr` splits into tag/set/offset; format is 'bin', 'hex' or 'dec' (unknown -> hex)."""
        convop = {'bin': bin, 'hex': hex, 'dec': int}.get(format, hex)
        print(f'addr: {convop(addr)}')
        print(f'tag : {convop(self._addr_get_tag(addr))}')
        print(f'set : {convop(self._addr_get_set(addr))}')
        print(f'off : {convop(self._addr_get_offset(addr))}')

    def print_cache_info(self) -> None:
        """Print the cache geometry; addressing details are printed only in debug mode."""
        print(f'{self.name} configuration:')
        print(f'Cache size: {self._cache_size} bytes')
        print(f'Block size: {self._block_size} bytes')
        print(f'Number of lines: {self._num_lines}')
        print(f'Number of sets: {self._sets} ({self._lines_per_set} lines per set)')
        # __init__ maps a None policy to 'RAND', so the stored value is never None.
        print(f'Replacement policy: {self._replacement_policy}')

        if self.debug:
            print(f'Cache block width: {self._block_width} bits')
            print(f'Addressable memory: {self._memory_size} bytes')
            tag_width = self._memory_width - self._block_width - self._set_width
            print('Addressing parameters:')
            print(f'Tag: {tag_width} bits')
            print(f'Set: {self._set_width} bits')
            print(f'Block: {self._block_width} bits\n')

        print()

    def print_hmr(self) -> None:
        """Print miss/hit/invalidation counts and the hit ratio (0% when there were no accesses)."""
        total = self.hits + self.misses
        # Divide by the number of accesses; guard only the empty case.
        ratio = (self.hits / total) * 100 if total else 0.0
        print(f'Misses: {self.misses}')
        print(f'Hits: {self.hits}')
        print(f'Invalidations: {self.invalidations}')
        print(f'Hit ratio: {round(ratio, 2)}%')

    def print_debug_lines(self, include_empty_tags: bool = False) -> None:
        """
        Print the address bit layout followed by the cache lines, with a blank line after each set.

        include_empty_tags (bool): also print free (unused) lines (defaults to False).
        """
        tag_width = self._memory_width - self._block_width - self._set_width
        print(f'tag: {tag_width} bits')
        print(f'set: {self._set_width} bits')
        print(f'block: {self._block_width} bits')

        for index, line in enumerate(self._lines):
            # Key on validity rather than the tag value, so valid lines with tag 0 are shown.
            if not line.free or include_empty_tags:
                print(line)
            if (index + 1) % self._lines_per_set == 0:
                print()
|