上传所有文件
This commit is contained in:
269
utils/utils.py
Normal file
269
utils/utils.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Description : some utils for project
|
||||
Author : Ruidi Qiu (ruidi.qiu@tum.de)
|
||||
Time : 2022/11/08 13:00:00
|
||||
LastEdited : 2024/9/3 17:33:26
|
||||
"""
|
||||
import time
|
||||
import datetime
|
||||
import collections
|
||||
import os
|
||||
import tiktoken
|
||||
import threading
|
||||
from functools import wraps
|
||||
from itertools import repeat
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
def str_list(list, precision=4) -> str:
|
||||
"""
|
||||
convert a list of string/number to a string;
|
||||
|
||||
to show the list in the way what we see it in the code
|
||||
|
||||
if string, add '' around it; if number, do nothing
|
||||
|
||||
Example:
|
||||
::
|
||||
|
||||
str_list(['a', 2, '3']) -> "['a', 2, '3']"
|
||||
"""
|
||||
if len(list) == 0:
|
||||
return '[]'
|
||||
str_list = '['
|
||||
for i in list:
|
||||
if isinstance(i, str):
|
||||
str_list += "'%s', " % (i)
|
||||
elif isinstance(i, int): # number
|
||||
str_list += "%d, " % (i)
|
||||
else: # number
|
||||
str_list += "%.*f, " % (precision, i)
|
||||
str_list = str_list[:-2] + ']'
|
||||
return str_list
|
||||
|
||||
|
||||
###################### decorators ######################
|
||||
def print_time(en=True):
|
||||
"""
|
||||
print the running time of a function
|
||||
|
||||
For example:
|
||||
::
|
||||
|
||||
@print_time()
|
||||
def forward(self, input):
|
||||
return self.top_level(input)
|
||||
"""
|
||||
def decorator_nopara(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
if en:
|
||||
old_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
func_name = str(func).split(' ')[1]
|
||||
run_time = time.time() - old_time
|
||||
print('{} use time: {}s'.format(func_name, run_time))
|
||||
else:
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
return wrapper
|
||||
return decorator_nopara
|
||||
|
||||
|
||||
def raise_error(func):
|
||||
"""
|
||||
decorator
|
||||
|
||||
raise error after a function
|
||||
"""
|
||||
def wrapper(*args, **kwargs):
|
||||
func(*args, **kwargs)
|
||||
raise Exception('this error is raised by debug decorator "raise_error"')
|
||||
return wrapper
|
||||
|
||||
##########################################################
|
||||
class Timer:
|
||||
"""
|
||||
print the running time of a code block
|
||||
|
||||
Args:
|
||||
|
||||
- code_name (str): the name of the code block; default: None
|
||||
|
||||
- print_en (bool): whether to print the running time; default: True
|
||||
|
||||
Example 1 (print time on the console):
|
||||
::
|
||||
|
||||
with Timer('test') as t:
|
||||
loss.backward() # some code
|
||||
# this will print 'test: time cost = 0.1s' on the console
|
||||
|
||||
Example 2 (get time of a code block):
|
||||
::
|
||||
|
||||
with Timer(print_en=False) as t:
|
||||
loss.backward() # some code
|
||||
time_cost = t.interval # time_cost = 0.1
|
||||
"""
|
||||
def __init__(self, code_name=None, print_en=True):
|
||||
self.code_name = code_name
|
||||
self.print_en = print_en
|
||||
|
||||
def __enter__(self):
|
||||
self.start = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(self , *args):
|
||||
self.end = time.time()
|
||||
self.interval_time = self.end - self.start
|
||||
print_line = 'time cost = %.4fs'%(self.interval_time)
|
||||
if self.code_name is not None:
|
||||
print_line = self.code_name + ': ' + print_line
|
||||
if self.print_en:
|
||||
print(print_line)
|
||||
self.print_line = print_line
|
||||
|
||||
@property
|
||||
def interval(self):
|
||||
return self.interval_time
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.code_name
|
||||
|
||||
@property
|
||||
def info(self):
|
||||
return self.print_line
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return self.print_line
|
||||
|
||||
def get_time(compact=False):
|
||||
"""
|
||||
get the string of current time, format: '%H:%M:%S %Y-%m-%d'
|
||||
"""
|
||||
if compact:
|
||||
return get_time_compact()
|
||||
else:
|
||||
return time.strftime('%H:%M:%S %Y-%m-%d', time.localtime(time.time()))
|
||||
|
||||
def get_time_compact():
|
||||
now = datetime.now()
|
||||
time_str = now.strftime("%Y%m%d_%H%M%S")
|
||||
return time_str
|
||||
|
||||
class run_in_dir:
|
||||
"""
|
||||
change the current directory to a new directory, and then change it back after the code block
|
||||
|
||||
Args:
|
||||
dir (str): the new directory (relative path to the current directory)
|
||||
"""
|
||||
def __init__(self, dir):
|
||||
self.new_dir_relative = dir
|
||||
def __enter__(self):
|
||||
self.old_dir = os.getcwd()
|
||||
self.new_dir = os.path.join(self.old_dir, self.new_dir_relative)
|
||||
os.chdir(self.new_dir)
|
||||
def __exit__(self, *args):
|
||||
os.chdir(self.old_dir)
|
||||
|
||||
################# utils from pytorch ###############
|
||||
def _ntuple(n, name="parse"):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return tuple(x)
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
parse.__name__ = name
|
||||
return parse
|
||||
|
||||
_single = _ntuple(1, "_single")
|
||||
_pair = _ntuple(2, "_pair")
|
||||
_triple = _ntuple(3, "_triple")
|
||||
_quadruple = _ntuple(4, "_quadruple")
|
||||
|
||||
|
||||
################# some tools #################
|
||||
def clean_wave_vcd(clean_dir, cnt_en=False):
|
||||
"""
|
||||
remove all the "wave.vcd" files in the directory
|
||||
"""
|
||||
cnt = 0
|
||||
for root, dirs, files in os.walk(clean_dir):
|
||||
for file in files:
|
||||
# must be strictly equal to "wave.vcd"
|
||||
if file == "wave.vcd":
|
||||
os.remove(os.path.join(root, file))
|
||||
if cnt_en:
|
||||
cnt += 1
|
||||
if cnt % 100 == 0:
|
||||
print("%d files cleaned" % (cnt))
|
||||
|
||||
def get_week_range(start_day:str|int="Monday", today=None)->str:
|
||||
"""
|
||||
- function:
|
||||
- return the week range of the current week, the start day can be any day of the week
|
||||
- for example, if today is 20240807, which is wednesday, if the start_day is "Monday", the output will be "0805~0811"; if the start day is "Tuesday", the output will be "0806~0812"; if the start day is "Thursday", the output will be "0801~0807"
|
||||
- input:
|
||||
- start_day: the start day of the week, can be a string or an integer
|
||||
- string: the name of the day, for example, "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"
|
||||
- integer: the index of the day, 0 is Monday, 1 is Tuesday, 2 is Wednesday, 3 is Thursday, 4 is Friday, 5 is Saturday, 6 is Sunday, invalid index will be mod 7
|
||||
- today: the date of the day, if None, the current date will be used;
|
||||
- formart: "%Y%m%d", e.g. "20240807"
|
||||
"""
|
||||
weekday_map = {"Monday": 0, "Tuesday": 1, "Wednesday": 2, "Thursday": 3, "Friday": 4, "Saturday": 5, "Sunday": 6}
|
||||
start_day = weekday_map[start_day] if isinstance(start_day, str) else start_day % 7
|
||||
# Get the current date
|
||||
# today = datetime.today()
|
||||
if today is None:
|
||||
today = datetime.today()
|
||||
else:
|
||||
today = datetime.strptime(today, "%Y%m%d")
|
||||
|
||||
# Calculate the current day of the week (0 is Monday, 6 is Sunday)
|
||||
current_weekday = today.weekday()
|
||||
|
||||
# Calculate the number of days to subtract to get to the start day
|
||||
days_to_subtract = (current_weekday - start_day) % 7
|
||||
|
||||
# Subtract the days to get to the start day
|
||||
start = today - timedelta(days=days_to_subtract)
|
||||
|
||||
# the output format is like "0805~0811"
|
||||
end = start + timedelta(days=6)
|
||||
return start.strftime("%m%d") + "~" + end.strftime("%m%d")
|
||||
|
||||
def run_with_timeout(timeout):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Define a thread target function to run the target code
|
||||
def target(result):
|
||||
try:
|
||||
result.append(func(*args, **kwargs))
|
||||
except Exception as e:
|
||||
result.append(e)
|
||||
|
||||
# List used to store the function result
|
||||
result = []
|
||||
# Create a thread
|
||||
thread = threading.Thread(target=target, args=(result,))
|
||||
# Start the thread
|
||||
thread.start()
|
||||
# Wait for the thread to complete, with a specified timeout
|
||||
thread.join(timeout)
|
||||
|
||||
# If the thread is still alive, it timed out
|
||||
if thread.is_alive():
|
||||
raise TimeoutError(f"Function call timed out after {timeout} seconds")
|
||||
|
||||
# If the thread finished, check if there was an exception
|
||||
if isinstance(result[0], Exception):
|
||||
raise result[0]
|
||||
|
||||
# Return the function result
|
||||
return result[0]
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
Reference in New Issue
Block a user