val train_batch :
?device:Device.t ->
?augmentation:[ `flip | `crop_with_pad of int | `cutout of int ] list ->
t ->
batch_size:int ->
batch_idx:int ->
Tensor.t * Tensor.t
train_batch ?device ?augmentation t ~batch_size ~batch_idx
returns two tensors forming a training batch: the first holds the images and the second the labels. Each tensor has a first dimension of size batch_size. batch_idx = 0 returns the first batch, batch_idx = 1 the second one, and so on. The tensors are placed on device if it is provided. Random data augmentation is applied as specified by augmentation.
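The mapping from batch_idx to samples can be illustrated with a minimal stdlib sketch. It assumes that batches are consecutive, non-overlapping slices of batch_size samples and uses a plain array in place of Tensor.t. The helpers batch_range and batch_of_array are hypothetical and are not part of this module.

```ocaml
(* Hypothetical helper: half-open range [start, stop) of sample indices
   covered by batch number [batch_idx], assuming batches are consecutive,
   non-overlapping slices of [batch_size] samples. *)
let batch_range ~batch_size ~batch_idx =
  let start = batch_idx * batch_size in
  (start, start + batch_size)

(* Extract that slice from an array of samples (a stand-in for a tensor). *)
let batch_of_array samples ~batch_size ~batch_idx =
  let start, _stop = batch_range ~batch_size ~batch_idx in
  Array.sub samples start batch_size
```

With ten samples and batch_size = 4, batch_idx = 1 selects samples 4 to 7.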
val batch_accuracy :
?device:Device.t ->
?samples:int ->
t ->
[ `test | `train ] ->
batch_size:int ->
predict:(Tensor.t -> Tensor.t) ->
float
batch_accuracy ?device ?samples t test_or_train ~batch_size ~predict
computes the accuracy of applying predict to the test or train images, as selected by test_or_train. Computations are performed in batches of size batch_size.
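Batched accuracy evaluation can be sketched with the stdlib alone. Here images and labels are plain arrays, predict maps a chunk of images to one predicted label per image, and a final chunk smaller than batch_size is still evaluated. All of these are assumptions made for illustration, and the function below is hypothetical: it does not reproduce the module's handling of ?samples or devices.

```ocaml
(* Hypothetical sketch: fraction of [images] whose predicted label equals
   the corresponding entry of [labels], evaluating [predict] on chunks of
   at most [batch_size] images. *)
let batch_accuracy ~batch_size ~predict images labels =
  let n = Array.length images in
  let correct = ref 0 in
  let start = ref 0 in
  while !start < n do
    (* The last chunk may be shorter than [batch_size]. *)
    let len = min batch_size (n - !start) in
    let preds = predict (Array.sub images !start len) in
    Array.iteri (fun i p -> if p = labels.(!start + i) then incr correct) preds;
    start := !start + len
  done;
  if n = 0 then 0. else float_of_int !correct /. float_of_int n
```

Summing correct predictions across chunks before dividing keeps the result identical to a single full pass, whatever the batch size.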
read_with_cache ~cache_file ~read
returns the content of cache_file if that file exists; otherwise it regenerates the file using read.
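The cache-or-regenerate pattern can be sketched with the stdlib. This version assumes the cached value is serialized with OCaml's Marshal module and that read takes unit; both are assumptions, and the module's actual on-disk format and signature may differ.

```ocaml
(* Sketch of a cache-or-regenerate helper. The value is stored with
   [Marshal] (an assumption), so reading the cache back is not type-checked:
   callers must use the same type for every call with a given file. *)
let read_with_cache ~cache_file ~read =
  if Sys.file_exists cache_file then begin
    (* Cache hit: deserialize the stored value. *)
    let ic = open_in_bin cache_file in
    Fun.protect ~finally:(fun () -> close_in ic)
      (fun () -> Marshal.from_channel ic)
  end else begin
    (* Cache miss: compute the value, write it out, then return it. *)
    let value = read () in
    let oc = open_out_bin cache_file in
    Fun.protect ~finally:(fun () -> close_out oc)
      (fun () -> Marshal.to_channel oc value []);
    value
  end
```

Only the first call with a given cache_file runs read; later calls load the stored value.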
val batches_per_epoch : t -> batch_size:int -> int
batches_per_epoch t ~batch_size
returns the total number of batches of size batch_size in t.
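The count is integer arithmetic on the dataset size. The sketch below assumes that only full batches are counted, so a trailing partial batch is dropped. For comparison, it also shows the ceiling variant that counts that last, smaller batch. Both functions and the num_samples parameter are hypothetical.

```ocaml
(* Number of full batches, assuming a trailing partial batch is dropped. *)
let batches_per_epoch ~num_samples ~batch_size = num_samples / batch_size

(* Alternative that also counts a final, smaller batch. *)
let batches_per_epoch_ceil ~num_samples ~batch_size =
  (num_samples + batch_size - 1) / batch_size
```

For 50000 samples and batch_size = 128, the first function gives 390 and the second gives 391, because 80 samples are left over.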
val iter :
?device:Device.t ->
?augmentation:[ `flip | `crop_with_pad of int | `cutout of int ] list ->
?shuffle:bool ->
t ->
f:(int -> batch_images:Tensor.t -> batch_labels:Tensor.t -> unit) ->
batch_size:int ->
unit
iter ?device ?augmentation ?shuffle t ~f ~batch_size
applies f to each batch of size batch_size taken from t. Random shuffling and augmentation can be enabled through shuffle and augmentation.
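Shuffled batch iteration can be sketched with the stdlib. The sample order is permuted with a Fisher-Yates shuffle, and then consecutive slices of the permutation form the batches. In this hypothetical iter_batches, f receives the batch index and the sample indices of the batch rather than image and label tensors. Two further choices are assumptions: samples that do not fill a last batch are skipped, and the default for shuffle is arbitrary.

```ocaml
(* Fisher-Yates shuffle: a uniformly random permutation of [a], in place. *)
let shuffle_in_place a =
  for i = Array.length a - 1 downto 1 do
    let j = Random.int (i + 1) in
    let tmp = a.(i) in
    a.(i) <- a.(j);
    a.(j) <- tmp
  done

(* Hypothetical sketch: call [f batch_idx indices] on every full batch. *)
let iter_batches ?(shuffle = false) ~num_samples ~batch_size ~f () =
  let order = Array.init num_samples (fun i -> i) in
  if shuffle then shuffle_in_place order;
  for batch_idx = 0 to (num_samples / batch_size) - 1 do
    f batch_idx (Array.sub order (batch_idx * batch_size) batch_size)
  done
```

Shuffling an index array instead of the data means the underlying samples never move. Each epoch sees every sample at most once.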
random_flip t
applies random flips to a tensor of dimensions N; C; H; W. Only the last dimension, which corresponds to the width, can be flipped.
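A horizontal flip reverses the order along the width axis. The sketch below works on a single H x W image stored as a float array array (rows are height, columns are width) instead of a batched tensor. It flips with probability 1/2; that probability is an assumption, and both functions are hypothetical.

```ocaml
(* Reverse every row, i.e. flip the width dimension of an H x W image. *)
let flip_width (img : float array array) =
  Array.map
    (fun row ->
      let w = Array.length row in
      Array.init w (fun j -> row.(w - 1 - j)))
    img

(* Flip with probability 1/2 (assumed), otherwise return the image as-is. *)
let random_flip img = if Random.bool () then flip_width img else img
```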
random_crop t ~pad
performs data augmentation by padding t with pad zeros on each side of its two last dimensions, then taking a random crop to return to the original shape.
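Pad-then-crop can be illustrated on a single H x W image stored as a float array array. After padding by pad on every side, any H x W window whose top-left offset lies in 0..2*pad on each axis stays inside the padded image. The offset (pad, pad) recovers the original. The functions below are hypothetical stand-ins for the tensor version.

```ocaml
(* Surround an H x W image with [pad] zeros on every side. *)
let pad_zeros img ~pad =
  let h = Array.length img and w = Array.length img.(0) in
  Array.init (h + 2 * pad) (fun i ->
      Array.init (w + 2 * pad) (fun j ->
          let i' = i - pad and j' = j - pad in
          if i' >= 0 && i' < h && j' >= 0 && j' < w then img.(i').(j') else 0.))

(* Cut an h x w window whose top-left corner is at row [dy], column [dx]. *)
let crop_at padded ~h ~w ~dy ~dx =
  Array.init h (fun i -> Array.sub padded.(dy + i) dx w)

(* Pad, then crop back to the original size at a uniformly random offset. *)
let random_crop img ~pad =
  let h = Array.length img and w = Array.length img.(0) in
  let padded = pad_zeros img ~pad in
  let dy = Random.int (2 * pad + 1) and dx = Random.int (2 * pad + 1) in
  crop_at padded ~h ~w ~dy ~dx
```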
val print_summary : t -> unit
val read_char_tensor : string -> Tensor.t
read_char_tensor filename
returns a tensor of chars holding the contents of the file filename.
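The file-reading step can be sketched with the stdlib. The whole file is read in binary mode, and its bytes are returned as character codes in an int array, which stands in for the char tensor. read_char_codes is a hypothetical name.

```ocaml
(* Read a whole file in binary mode and return its bytes as char codes. *)
let read_char_codes filename =
  let ic = open_in_bin filename in
  Fun.protect ~finally:(fun () -> close_in ic) (fun () ->
      let s = really_input_string ic (in_channel_length ic) in
      Array.init (String.length s) (fun i -> Char.code s.[i]))
```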