Comments (1)
I added two parameters, clip_idx and num_frame, to the read_video function to change the audio start_pts and end_pts.
torchvision/io/video.py
def read_video(
    filename: str,
    clip_idx: int,
    num_frame: int,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
    output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """
    Reads a video from a file, returning both the video frames and the audio frames

    Args:
        filename (str): path to the video file
        clip_idx (int): index of the first video frame of the clip, used to derive the audio start pts
        num_frame (int): number of video frames in the clip, used to derive the audio end pts
        start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The start presentation time of the video
        end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The end presentation time
        pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
            either 'pts' or 'sec'. Defaults to 'pts'.
        output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".

    Returns:
        vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video)

    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

    from torchvision import get_video_backend

    if not os.path.exists(filename):
        raise RuntimeError(f"File not found: {filename}")

    if get_video_backend() != "pyav":
        vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
    else:
        _check_av_available()

        if end_pts is None:
            end_pts = float("inf")

        if end_pts < start_pts:
            raise ValueError(
                f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
            )

        info = {}
        video_frames = []
        audio_frames = []
        audio_timebase = _video_opt.default_timebase

        try:
            with av.open(filename, metadata_errors="ignore") as container:
                if container.streams.audio:
                    audio_timebase = container.streams.audio[0].time_base
                if container.streams.video:
                    video_frames = _read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        container.streams.video[0],
                        {"video": 0},
                    )
                    video_fps = container.streams.video[0].average_rate
                    # guard against potentially corrupted files
                    if video_fps is not None:
                        info["video_fps"] = float(video_fps)

                if container.streams.audio:
                    # map the clip's video-frame range to audio sample pts:
                    # frame index / video fps gives seconds, multiplied by the
                    # audio sample rate gives the corresponding audio pts
                    start_audio_pts = int(
                        clip_idx
                        / math.ceil(container.streams.video[0].average_rate)
                        * math.floor(container.streams.audio[0].rate)
                    )
                    end_audio_pts = int(
                        (num_frame + clip_idx)
                        / math.ceil(container.streams.video[0].average_rate)
                        * math.floor(container.streams.audio[0].rate)
                    )
                    audio_frames = _read_from_stream(
                        container,
                        start_audio_pts,
                        end_audio_pts,
                        pts_unit,
                        container.streams.audio[0],
                        {"audio": 0},
                    )
                    info["audio_fps"] = container.streams.audio[0].rate

        except av.AVError:
            # TODO raise a warning?
            pass

        vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
        aframes_list = [frame.to_ndarray() for frame in audio_frames]

        if vframes_list:
            vframes = torch.as_tensor(np.stack(vframes_list))
        else:
            vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

        if aframes_list:
            aframes = np.concatenate(aframes_list, 1)
            aframes = torch.as_tensor(aframes)
            # if pts_unit == "sec":
            #     start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
            #     if end_pts != float("inf"):
            #         end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
            # aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
            if pts_unit == "sec":
                start_audio_pts = int(math.floor(start_audio_pts * (1 / audio_timebase)))
                if end_audio_pts != float("inf"):
                    end_audio_pts = int(math.ceil(end_audio_pts * (1 / audio_timebase)))
            aframes = _align_audio_frames(aframes, audio_frames, start_audio_pts, end_audio_pts)
        else:
            aframes = torch.empty((1, 0), dtype=torch.float32)

    if output_format == "TCHW":
        # [T,H,W,C] --> [T,C,H,W]
        vframes = vframes.permute(0, 3, 1, 2)

    return vframes, aframes, info
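
For reference, the frame-to-audio-sample conversion driven by the two new parameters can be checked in isolation. This is only a sketch with made-up rates (30 fps video, 44100 Hz audio), not part of the patch:

import math

video_fps = 30      # stands in for container.streams.video[0].average_rate
audio_rate = 44100  # stands in for container.streams.audio[0].rate
clip_idx = 60       # index of the first video frame of the clip
num_frame = 16      # number of video frames in the clip

# same arithmetic as in the modified read_video:
# frame index / video fps -> seconds, seconds * audio rate -> audio sample pts
start_audio_pts = int(clip_idx / math.ceil(video_fps) * math.floor(audio_rate))
end_audio_pts = int((num_frame + clip_idx) / math.ceil(video_fps) * math.floor(audio_rate))

print(start_audio_pts, end_audio_pts)  # 88200 111720
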
torchvision/datasets/video_utils.py
def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]:
    """
    Gets a subclip from a list of videos.

    Args:
        idx (int): index of the subclip. Must be between 0 and num_clips().

    Returns:
        video (Tensor)
        audio (Tensor)
        info (Dict)
        video_idx (int): index of the video in `video_paths`
    """
    if idx >= self.num_clips():
        raise IndexError(f"Index {idx} out of range ({self.num_clips()} number of clips)")
    video_idx, clip_idx = self.get_clip_location(idx)
    video_path = self.video_paths[video_idx]
    clip_pts = self.clips[video_idx][clip_idx]
    num_frame = len(clip_pts)

    from torchvision import get_video_backend

    backend = get_video_backend()

    if backend == "pyav":
        # check for invalid options
        if self._video_width != 0:
            raise ValueError("pyav backend doesn't support _video_width != 0")
        if self._video_height != 0:
            raise ValueError("pyav backend doesn't support _video_height != 0")
        if self._video_min_dimension != 0:
            raise ValueError("pyav backend doesn't support _video_min_dimension != 0")
        if self._video_max_dimension != 0:
            raise ValueError("pyav backend doesn't support _video_max_dimension != 0")
        if self._audio_samples != 0:
            raise ValueError("pyav backend doesn't support _audio_samples != 0")

    if backend == "pyav":
        start_pts = clip_pts[0].item()
        end_pts = clip_pts[-1].item()
        video, audio, info = read_video(
            video_path, clip_idx * self.frames_between_clips, num_frame, start_pts, end_pts
        )
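
A minimal usage sketch, assuming the modified read_video and get_clip above are in place; the video path and clip settings below are placeholders, not values from the patch:

from torchvision.datasets.video_utils import VideoClips

# placeholder arguments; any real video and clip configuration would do
clips = VideoClips(
    video_paths=["/path/to/video.mp4"],
    clip_length_in_frames=16,
    frames_between_clips=16,
)

# with the change above, the returned audio should cover only the samples that
# overlap the 16 video frames of the requested clip, rather than a range derived
# from the video pts alone
video, audio, info, video_idx = clips.get_clip(0)
print(video.shape, audio.shape, info.get("audio_fps"))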