Files
TBgen_App/utils/utils.py
2026-03-30 16:46:48 +08:00

269 lines
8.2 KiB
Python

"""
Description : some utils for project
Author : Ruidi Qiu (ruidi.qiu@tum.de)
Time : 2022/11/08 13:00:00
LastEdited : 2024/9/3 17:33:26
"""
import time
import datetime
import collections
import os
import tiktoken
import threading
from functools import wraps
from itertools import repeat
from datetime import datetime, timedelta
def str_list(list, precision=4) -> str:
    """
    convert a list of string/number to a string;
    to show the list in the way what we see it in the code
    Args:
    - list (iterable): the items to render (the parameter name shadows the
      builtin ``list``; it is kept so keyword callers keep working)
    - precision (int): number of decimals used for non-integer numbers; default: 4
    Rendering rules:
    - str: wrapped in single quotes (embedded quotes are not escaped)
    - bool: 'True' / 'False'
    - int: printed as is
    - other numbers: fixed-point with ``precision`` decimals
    - anything that cannot be formatted as a float: its repr()
    Example:
    ::
        str_list(['a', 2, '3']) -> "['a', 2, '3']"
        str_list([1.5, None])   -> "[1.5000, None]"
    """
    rendered = []
    for item in list:
        if isinstance(item, str):
            rendered.append("'%s'" % item)
        elif isinstance(item, bool):
            # bool is a subclass of int; check it first so True is not shown as 1
            rendered.append(str(item))
        elif isinstance(item, int):
            rendered.append("%d" % item)
        else:
            try:
                rendered.append("%.*f" % (precision, item))
            except TypeError:
                # not float-compatible (None, nested containers, ...)
                rendered.append(repr(item))
    return "[" + ", ".join(rendered) + "]"
###################### decorators ######################
def print_time(en=True):
    """
    print the running time of a function
    Args:
    - en (bool): when False the decorated function is called directly and
      nothing is printed; default: True
    For example:
    ::
        @print_time()
        def forward(self, input):
            return self.top_level(input)
    The printed line is '<name> use time: <seconds>s', where <name> is the
    callable's ``__qualname__`` (or its str() when it has none).
    """
    def decorator_nopara(func):
        @wraps(func)  # keep the wrapped function's name/docstring
        def wrapper(*args, **kwargs):
            if not en:
                return func(*args, **kwargs)
            # perf_counter is the clock intended for measuring short intervals
            started = time.perf_counter()
            result = func(*args, **kwargs)
            run_time = time.perf_counter() - started
            func_name = getattr(func, '__qualname__', None) or str(func)
            print('{} use time: {}s'.format(func_name, run_time))
            return result
        return wrapper
    return decorator_nopara
def raise_error(func):
    """
    decorator
    raise error after a function
    The wrapped function is executed with the given arguments, its return
    value is discarded, and then an Exception is always raised.
    """
    @wraps(func)  # keep the wrapped function's name/docstring
    def wrapper(*args, **kwargs):
        func(*args, **kwargs)
        raise Exception('this error is raised by debug decorator "raise_error"')
    return wrapper
##########################################################
class Timer:
    """
    context manager that measures (and optionally prints) how long a code block takes
    Args:
    - code_name (str): label put in front of the report line; default: None
    - print_en (bool): print the report line when the block exits; default: True
    Example 1 (print time on the console):
    ::
        with Timer('test') as t:
            loss.backward() # some code
        # prints e.g. 'test: time cost = 0.1000s'
    Example 2 (get time of a code block):
    ::
        with Timer(print_en=False) as t:
            loss.backward() # some code
        time_cost = t.interval # seconds as float
    Note: interval / info / message are only set once the block has exited.
    """

    def __init__(self, code_name=None, print_en=True):
        self.print_en = print_en
        self.code_name = code_name

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args):
        self.end = time.time()
        self.interval_time = self.end - self.start
        # the label is prepended only when a code_name was given
        label = '' if self.code_name is None else self.code_name + ': '
        report = label + 'time cost = %.4fs' % (self.interval_time)
        if self.print_en:
            print(report)
        self.print_line = report

    @property
    def interval(self):
        # elapsed seconds between __enter__ and __exit__
        return self.interval_time

    @property
    def name(self):
        return self.code_name

    @property
    def info(self):
        # the report line built in __exit__
        return self.print_line

    @property
    def message(self):
        # same value as info
        return self.print_line
def get_time(compact=False):
    """
    get the string of current local time
    - compact=False (default): format '%H:%M:%S %Y-%m-%d'
    - compact=True: return the result of get_time_compact() instead
    """
    if not compact:
        return time.strftime('%H:%M:%S %Y-%m-%d', time.localtime(time.time()))
    return get_time_compact()
def get_time_compact():
    """
    get the string of current time in the compact format "%Y%m%d_%H%M%S"
    """
    return datetime.now().strftime("%Y%m%d_%H%M%S")
class run_in_dir:
    """
    context manager: switch the working directory for the duration of a code block,
    then switch back to the directory that was current on entry
    Args:
        dir (str): the directory to enter; it is joined onto the current working directory
    """

    def __init__(self, dir):
        self.new_dir_relative = dir

    def __enter__(self):
        origin = os.getcwd()
        # remember where we came from so __exit__ can restore it
        self.old_dir = origin
        self.new_dir = os.path.join(origin, self.new_dir_relative)
        os.chdir(self.new_dir)

    def __exit__(self, *args):
        os.chdir(self.old_dir)
################# utils from pytorch ###############
def _ntuple(n, name="parse"):
    """
    build a converter that turns its argument into a tuple
    - an iterable argument is converted with tuple() as is (its length is not checked)
    - any other argument is repeated n times
    Args:
    - n (int): how many times a non-iterable argument is repeated
    - name (str): value assigned to the returned function's __name__; default: "parse"
    """
    def parse(x):
        return tuple(x) if isinstance(x, collections.abc.Iterable) else tuple(repeat(x, n))

    parse.__name__ = name
    return parse
# converters for 1 to 4 repetitions; each one carries its own name as __name__
_single, _pair, _triple, _quadruple = (
    _ntuple(count, label)
    for count, label in enumerate(("_single", "_pair", "_triple", "_quadruple"), start=1)
)
################# some tools #################
def clean_wave_vcd(clean_dir, cnt_en=False):
    """
    remove all the "wave.vcd" files in the directory (walked recursively)
    Args:
    - clean_dir (str): root directory handed to os.walk
    - cnt_en (bool): count the removed files and print a progress line
      after every 100 removals; default: False
    """
    removed = 0
    for folder, _subdirs, filenames in os.walk(clean_dir):
        for filename in filenames:
            # must be strictly equal to "wave.vcd"
            if filename != "wave.vcd":
                continue
            os.remove(os.path.join(folder, filename))
            if not cnt_en:
                continue
            removed += 1
            if removed % 100 == 0:
                print("%d files cleaned" % (removed))
def get_week_range(start_day:str|int="Monday", today=None)->str:
"""
- function:
- return the week range of the current week, the start day can be any day of the week
- for example, if today is 20240807, which is wednesday, if the start_day is "Monday", the output will be "0805~0811"; if the start day is "Tuesday", the output will be "0806~0812"; if the start day is "Thursday", the output will be "0801~0807"
- input:
- start_day: the start day of the week, can be a string or an integer
- string: the name of the day, for example, "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"
- integer: the index of the day, 0 is Monday, 1 is Tuesday, 2 is Wednesday, 3 is Thursday, 4 is Friday, 5 is Saturday, 6 is Sunday, invalid index will be mod 7
- today: the date of the day, if None, the current date will be used;
- formart: "%Y%m%d", e.g. "20240807"
"""
weekday_map = {"Monday": 0, "Tuesday": 1, "Wednesday": 2, "Thursday": 3, "Friday": 4, "Saturday": 5, "Sunday": 6}
start_day = weekday_map[start_day] if isinstance(start_day, str) else start_day % 7
# Get the current date
# today = datetime.today()
if today is None:
today = datetime.today()
else:
today = datetime.strptime(today, "%Y%m%d")
# Calculate the current day of the week (0 is Monday, 6 is Sunday)
current_weekday = today.weekday()
# Calculate the number of days to subtract to get to the start day
days_to_subtract = (current_weekday - start_day) % 7
# Subtract the days to get to the start day
start = today - timedelta(days=days_to_subtract)
# the output format is like "0805~0811"
end = start + timedelta(days=6)
return start.strftime("%m%d") + "~" + end.strftime("%m%d")
def run_with_timeout(timeout):
    """
    decorator factory: run the decorated function in a separate thread and
    wait at most ``timeout`` seconds for it
    Args:
    - timeout (float): seconds passed to Thread.join
    Behavior of the decorated function:
    - returns whatever the function returns (including an exception *object*
      that the function chose to return)
    - re-raises, in the calling thread, whatever the function raised
    - raises TimeoutError when the function has not finished after ``timeout``
      seconds; the worker thread is NOT stopped and keeps running in the background
    Example:
    ::
        @run_with_timeout(5)
        def simulate(path):
            ...
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # "value" and "error" are kept apart so a returned exception
            # instance is never mistaken for a raised one
            outcome = {}

            def target():
                try:
                    outcome["value"] = func(*args, **kwargs)
                except BaseException as exc:
                    # forwarded to the caller instead of dying silently in the worker
                    outcome["error"] = exc

            worker = threading.Thread(target=target)
            worker.start()
            # Wait for the thread to complete, with a specified timeout
            worker.join(timeout)
            if worker.is_alive():
                raise TimeoutError(f"Function call timed out after {timeout} seconds")
            if "error" in outcome:
                raise outcome["error"]
            return outcome["value"]
        return wrapper
    return decorator