269 lines
8.2 KiB
Python
269 lines
8.2 KiB
Python
"""
|
|
Description : some utils for project
|
|
Author : Ruidi Qiu (ruidi.qiu@tum.de)
|
|
Time : 2022/11/08 13:00:00
|
|
LastEdited : 2024/9/3 17:33:26
|
|
"""
|
|
import time
|
|
import datetime
|
|
import collections
|
|
import os
|
|
import tiktoken
|
|
import threading
|
|
from functools import wraps
|
|
from itertools import repeat
|
|
from datetime import datetime, timedelta
|
|
|
|
def str_list(list, precision=4) -> str:
|
|
"""
|
|
convert a list of string/number to a string;
|
|
|
|
to show the list in the way what we see it in the code
|
|
|
|
if string, add '' around it; if number, do nothing
|
|
|
|
Example:
|
|
::
|
|
|
|
str_list(['a', 2, '3']) -> "['a', 2, '3']"
|
|
"""
|
|
if len(list) == 0:
|
|
return '[]'
|
|
str_list = '['
|
|
for i in list:
|
|
if isinstance(i, str):
|
|
str_list += "'%s', " % (i)
|
|
elif isinstance(i, int): # number
|
|
str_list += "%d, " % (i)
|
|
else: # number
|
|
str_list += "%.*f, " % (precision, i)
|
|
str_list = str_list[:-2] + ']'
|
|
return str_list
|
|
|
|
|
|
###################### decorators ######################
|
|
def print_time(en=True):
|
|
"""
|
|
print the running time of a function
|
|
|
|
For example:
|
|
::
|
|
|
|
@print_time()
|
|
def forward(self, input):
|
|
return self.top_level(input)
|
|
"""
|
|
def decorator_nopara(func):
|
|
def wrapper(*args, **kwargs):
|
|
if en:
|
|
old_time = time.time()
|
|
result = func(*args, **kwargs)
|
|
func_name = str(func).split(' ')[1]
|
|
run_time = time.time() - old_time
|
|
print('{} use time: {}s'.format(func_name, run_time))
|
|
else:
|
|
result = func(*args, **kwargs)
|
|
return result
|
|
return wrapper
|
|
return decorator_nopara
|
|
|
|
|
|
def raise_error(func):
|
|
"""
|
|
decorator
|
|
|
|
raise error after a function
|
|
"""
|
|
def wrapper(*args, **kwargs):
|
|
func(*args, **kwargs)
|
|
raise Exception('this error is raised by debug decorator "raise_error"')
|
|
return wrapper
|
|
|
|
##########################################################
|
|
class Timer:
|
|
"""
|
|
print the running time of a code block
|
|
|
|
Args:
|
|
|
|
- code_name (str): the name of the code block; default: None
|
|
|
|
- print_en (bool): whether to print the running time; default: True
|
|
|
|
Example 1 (print time on the console):
|
|
::
|
|
|
|
with Timer('test') as t:
|
|
loss.backward() # some code
|
|
# this will print 'test: time cost = 0.1s' on the console
|
|
|
|
Example 2 (get time of a code block):
|
|
::
|
|
|
|
with Timer(print_en=False) as t:
|
|
loss.backward() # some code
|
|
time_cost = t.interval # time_cost = 0.1
|
|
"""
|
|
def __init__(self, code_name=None, print_en=True):
|
|
self.code_name = code_name
|
|
self.print_en = print_en
|
|
|
|
def __enter__(self):
|
|
self.start = time.time()
|
|
return self
|
|
|
|
def __exit__(self , *args):
|
|
self.end = time.time()
|
|
self.interval_time = self.end - self.start
|
|
print_line = 'time cost = %.4fs'%(self.interval_time)
|
|
if self.code_name is not None:
|
|
print_line = self.code_name + ': ' + print_line
|
|
if self.print_en:
|
|
print(print_line)
|
|
self.print_line = print_line
|
|
|
|
@property
|
|
def interval(self):
|
|
return self.interval_time
|
|
|
|
@property
|
|
def name(self):
|
|
return self.code_name
|
|
|
|
@property
|
|
def info(self):
|
|
return self.print_line
|
|
|
|
@property
|
|
def message(self):
|
|
return self.print_line
|
|
|
|
def get_time(compact=False):
|
|
"""
|
|
get the string of current time, format: '%H:%M:%S %Y-%m-%d'
|
|
"""
|
|
if compact:
|
|
return get_time_compact()
|
|
else:
|
|
return time.strftime('%H:%M:%S %Y-%m-%d', time.localtime(time.time()))
|
|
|
|
def get_time_compact():
|
|
now = datetime.now()
|
|
time_str = now.strftime("%Y%m%d_%H%M%S")
|
|
return time_str
|
|
|
|
class run_in_dir:
|
|
"""
|
|
change the current directory to a new directory, and then change it back after the code block
|
|
|
|
Args:
|
|
dir (str): the new directory (relative path to the current directory)
|
|
"""
|
|
def __init__(self, dir):
|
|
self.new_dir_relative = dir
|
|
def __enter__(self):
|
|
self.old_dir = os.getcwd()
|
|
self.new_dir = os.path.join(self.old_dir, self.new_dir_relative)
|
|
os.chdir(self.new_dir)
|
|
def __exit__(self, *args):
|
|
os.chdir(self.old_dir)
|
|
|
|
################# utils from pytorch ###############
|
|
def _ntuple(n, name="parse"):
|
|
def parse(x):
|
|
if isinstance(x, collections.abc.Iterable):
|
|
return tuple(x)
|
|
return tuple(repeat(x, n))
|
|
|
|
parse.__name__ = name
|
|
return parse
|
|
|
|
_single = _ntuple(1, "_single")
|
|
_pair = _ntuple(2, "_pair")
|
|
_triple = _ntuple(3, "_triple")
|
|
_quadruple = _ntuple(4, "_quadruple")
|
|
|
|
|
|
################# some tools #################
|
|
def clean_wave_vcd(clean_dir, cnt_en=False):
|
|
"""
|
|
remove all the "wave.vcd" files in the directory
|
|
"""
|
|
cnt = 0
|
|
for root, dirs, files in os.walk(clean_dir):
|
|
for file in files:
|
|
# must be strictly equal to "wave.vcd"
|
|
if file == "wave.vcd":
|
|
os.remove(os.path.join(root, file))
|
|
if cnt_en:
|
|
cnt += 1
|
|
if cnt % 100 == 0:
|
|
print("%d files cleaned" % (cnt))
|
|
|
|
def get_week_range(start_day:str|int="Monday", today=None)->str:
|
|
"""
|
|
- function:
|
|
- return the week range of the current week, the start day can be any day of the week
|
|
- for example, if today is 20240807, which is wednesday, if the start_day is "Monday", the output will be "0805~0811"; if the start day is "Tuesday", the output will be "0806~0812"; if the start day is "Thursday", the output will be "0801~0807"
|
|
- input:
|
|
- start_day: the start day of the week, can be a string or an integer
|
|
- string: the name of the day, for example, "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"
|
|
- integer: the index of the day, 0 is Monday, 1 is Tuesday, 2 is Wednesday, 3 is Thursday, 4 is Friday, 5 is Saturday, 6 is Sunday, invalid index will be mod 7
|
|
- today: the date of the day, if None, the current date will be used;
|
|
- formart: "%Y%m%d", e.g. "20240807"
|
|
"""
|
|
weekday_map = {"Monday": 0, "Tuesday": 1, "Wednesday": 2, "Thursday": 3, "Friday": 4, "Saturday": 5, "Sunday": 6}
|
|
start_day = weekday_map[start_day] if isinstance(start_day, str) else start_day % 7
|
|
# Get the current date
|
|
# today = datetime.today()
|
|
if today is None:
|
|
today = datetime.today()
|
|
else:
|
|
today = datetime.strptime(today, "%Y%m%d")
|
|
|
|
# Calculate the current day of the week (0 is Monday, 6 is Sunday)
|
|
current_weekday = today.weekday()
|
|
|
|
# Calculate the number of days to subtract to get to the start day
|
|
days_to_subtract = (current_weekday - start_day) % 7
|
|
|
|
# Subtract the days to get to the start day
|
|
start = today - timedelta(days=days_to_subtract)
|
|
|
|
# the output format is like "0805~0811"
|
|
end = start + timedelta(days=6)
|
|
return start.strftime("%m%d") + "~" + end.strftime("%m%d")
|
|
|
|
def run_with_timeout(timeout):
|
|
def decorator(func):
|
|
@wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
# Define a thread target function to run the target code
|
|
def target(result):
|
|
try:
|
|
result.append(func(*args, **kwargs))
|
|
except Exception as e:
|
|
result.append(e)
|
|
|
|
# List used to store the function result
|
|
result = []
|
|
# Create a thread
|
|
thread = threading.Thread(target=target, args=(result,))
|
|
# Start the thread
|
|
thread.start()
|
|
# Wait for the thread to complete, with a specified timeout
|
|
thread.join(timeout)
|
|
|
|
# If the thread is still alive, it timed out
|
|
if thread.is_alive():
|
|
raise TimeoutError(f"Function call timed out after {timeout} seconds")
|
|
|
|
# If the thread finished, check if there was an exception
|
|
if isinstance(result[0], Exception):
|
|
raise result[0]
|
|
|
|
# Return the function result
|
|
return result[0]
|
|
|
|
return wrapper
|
|
return decorator |