package torch

type t = {
  train_images : Tensor.t;
  train_labels : Tensor.t;
  test_images : Tensor.t;
  test_labels : Tensor.t;
}
val train_batch : ?device:Device.t -> ?augmentation:[ `flip | `crop_with_pad of int | `cutout of int ] list -> t -> batch_size:int -> batch_idx:int -> Tensor.t * Tensor.t

train_batch ?device ?augmentation t ~batch_size ~batch_idx returns two tensors corresponding to a training batch: the first holds the images and the second the labels. Both tensors have a first dimension of size batch_size; batch_idx = 0 returns the first batch, batch_idx = 1 the second, and so on. The tensors are placed on device if provided, and random data augmentation is applied as specified by augmentation.
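The indexing rule above (batch batch_idx starts at example batch_idx * batch_size) can be illustrated without tensors. The following is a minimal sketch, assuming plain OCaml arrays stand in for the first tensor dimension; train_batch_sketch is a hypothetical name, and behaviour when a batch runs past the end of the data is not described by the documentation, so the sketch simply lets Array.sub raise.

```ocaml
(* Hypothetical model: images and labels are plain arrays whose index plays
   the role of the first tensor dimension. *)
let train_batch_sketch images labels ~batch_size ~batch_idx =
  (* Batch [batch_idx] starts at example [batch_idx * batch_size]. *)
  let start = batch_idx * batch_size in
  (* [Array.sub] raises [Invalid_argument] if the batch runs past the end;
     the real function's behaviour in that case is not described here. *)
  (Array.sub images start batch_size, Array.sub labels start batch_size)

let () =
  let images = [| "img0"; "img1"; "img2"; "img3"; "img4"; "img5" |] in
  let labels = [| 0; 1; 2; 0; 1; 2 |] in
  let batch_images, batch_labels =
    train_batch_sketch images labels ~batch_size:2 ~batch_idx:1
  in
  (* Prints img2 and img3, then their labels 2 and 0. *)
  Array.iter print_endline batch_images;
  Array.iter (Printf.printf "%d\n") batch_labels
```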

val batch_accuracy : ?device:Device.t -> ?samples:int -> t -> [ `test | `train ] -> batch_size:int -> predict:(Tensor.t -> Tensor.t) -> float

batch_accuracy ?device ?samples t test_or_train ~batch_size ~predict computes the accuracy obtained by applying predict to the test or train images, as selected by test_or_train. Computations are done on batches of size batch_size.
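Computing accuracy batch by batch amounts to counting correct predictions per batch and dividing by the total number of examples. The sketch below is a hypothetical pure-OCaml stand-in: it assumes predict maps an array of inputs to an array of predicted class indices (rather than tensors), it ignores device and samples, and it assumes the last, possibly shorter batch is included.

```ocaml
(* Hypothetical sketch: [predict] maps an array of inputs to an array of
   predicted class indices. Accuracy is accumulated batch by batch, and the
   last, possibly shorter batch is included (an assumption). *)
let batch_accuracy_sketch ~batch_size ~predict inputs labels =
  if batch_size <= 0 then invalid_arg "batch_size must be positive";
  let n = Array.length inputs in
  let correct = ref 0 in
  let start = ref 0 in
  while !start < n do
    let len = min batch_size (n - !start) in
    let predicted = predict (Array.sub inputs !start len) in
    (* Count predictions that match the corresponding labels. *)
    Array.iteri (fun i p -> if p = labels.(!start + i) then incr correct) predicted;
    start := !start + len
  done;
  if n = 0 then 0. else float_of_int !correct /. float_of_int n

let () =
  (* Toy "model": predicts the parity of each input. *)
  let predict batch = Array.map (fun x -> x mod 2) batch in
  let acc =
    batch_accuracy_sketch ~batch_size:2 ~predict [| 0; 1; 2; 3; 4 |] [| 0; 1; 1; 1; 0 |]
  in
  (* 4 of the 5 predictions match, so this prints 0.80. *)
  Printf.printf "%.2f\n" acc
```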

val read_with_cache : cache_file:string -> read:(unit -> t) -> t

read_with_cache ~cache_file ~read returns the content of cache_file if the file exists; otherwise it regenerates the file using read.
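The read-through caching pattern described above can be sketched with the standard library alone. This is an assumption-laden illustration, not the library's implementation: it serializes the value with Marshal, which works for any closure-free value but is not type-safe, so reading a cache written for a different type is undefined behaviour. read_with_cache_sketch is a hypothetical name.

```ocaml
(* Minimal sketch of a read-through file cache, assuming the cached value can
   be serialized with the standard library's [Marshal] module. *)
let read_with_cache_sketch ~cache_file ~(read : unit -> 'a) : 'a =
  if Sys.file_exists cache_file then begin
    (* Cache hit: deserialize the stored value. *)
    let ic = open_in_bin cache_file in
    Fun.protect ~finally:(fun () -> close_in ic) (fun () ->
        (Marshal.from_channel ic : 'a))
  end else begin
    (* Cache miss: compute the value, then write it for next time. *)
    let v = read () in
    let oc = open_out_bin cache_file in
    Fun.protect ~finally:(fun () -> close_out oc) (fun () ->
        Marshal.to_channel oc v []);
    v
  end
```

Calling it twice with the same cache_file invokes read only on the first call; the second call loads the stored value.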

val batches_per_epoch : t -> batch_size:int -> int

batches_per_epoch t ~batch_size returns the total number of batches of size batch_size in t.
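The count is integer arithmetic on the number of examples. The documentation does not say whether a trailing partial batch is counted, so the sketch below assumes floor division and shows the ceiling variant for comparison; both function names are hypothetical and take the example count directly instead of t.

```ocaml
(* Assumes floor division: a trailing partial batch is not counted. *)
let batches_per_epoch_sketch ~num_examples ~batch_size = num_examples / batch_size

(* Alternative if a partial final batch were counted as well. *)
let batches_per_epoch_ceil ~num_examples ~batch_size =
  (num_examples + batch_size - 1) / batch_size

let () =
  (* 50_000 / 128 = 390.625, so this prints "390 391". *)
  Printf.printf "%d %d\n"
    (batches_per_epoch_sketch ~num_examples:50_000 ~batch_size:128)
    (batches_per_epoch_ceil ~num_examples:50_000 ~batch_size:128)
```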

val iter : ?device:Device.t -> ?augmentation:[ `flip | `crop_with_pad of int | `cutout of int ] list -> ?shuffle:bool -> t -> f:(int -> batch_images:Tensor.t -> batch_labels:Tensor.t -> unit) -> batch_size:int -> unit

iter ?device ?augmentation ?shuffle t ~f ~batch_size applies f to every batch of size batch_size taken from t. Random shuffling and data augmentation can optionally be enabled via shuffle and augmentation.
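The iteration pattern can be sketched on plain arrays. In this hypothetical version (iter_sketch is not part of the library), f receives an integer assumed to be the batch index, shuffling applies one Fisher-Yates permutation to the indices so each image stays paired with its label, augmentation and device are omitted, and only full batches are visited, which is also an assumption.

```ocaml
(* Hypothetical sketch of batch iteration over array-backed data. *)
let iter_sketch ?(shuffle = false) images labels ~f ~batch_size =
  if batch_size <= 0 then invalid_arg "batch_size must be positive";
  let n = Array.length images in
  let perm = Array.init n (fun i -> i) in
  if shuffle then
    (* Fisher-Yates shuffle of the index permutation, shared by images and
       labels so that pairs are preserved. *)
    for i = n - 1 downto 1 do
      let j = Random.int (i + 1) in
      let tmp = perm.(i) in
      perm.(i) <- perm.(j);
      perm.(j) <- tmp
    done;
  (* Visit only full batches; a trailing partial batch is skipped. *)
  for batch_idx = 0 to (n / batch_size) - 1 do
    let pick arr =
      Array.init batch_size (fun k -> arr.(perm.((batch_idx * batch_size) + k)))
    in
    f batch_idx ~batch_images:(pick images) ~batch_labels:(pick labels)
  done
```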

val random_flip : Tensor.t -> Tensor.t

random_flip t applies random flips to a tensor with dimensions N; C; H; W. Only the last dimension, which corresponds to width, can be flipped.
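A width flip reverses each row of an image. The sketch below models a batch as an array of height x width matrices (channels omitted for brevity) and flips each image independently with probability 1/2; that probability and both function names are assumptions for illustration only.

```ocaml
(* Reverse every row, i.e. flip along the width (last) dimension. *)
let flip_width (img : 'a array array) : 'a array array =
  Array.map
    (fun row ->
      let w = Array.length row in
      Array.init w (fun x -> row.(w - 1 - x)))
    img

(* Hypothetical: flip each image in the batch with probability 1/2. *)
let random_flip_sketch (batch : 'a array array array) : 'a array array array =
  Array.map (fun img -> if Random.bool () then flip_width img else img) batch
```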

val random_crop : Tensor.t -> pad:int -> Tensor.t

random_crop t ~pad performs data augmentation by zero-padding the last two dimensions of t with pad values on each side, then taking a random crop that restores the original shape.
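The pad-then-crop procedure can be sketched on a single height x width matrix of floats. This is a hypothetical stand-in for one image in the batch: the crop's top-left offsets are assumed to be drawn uniformly from 0 to 2 * pad, so an offset of pad in both directions recovers the original image.

```ocaml
(* Hypothetical sketch of pad-then-crop on one height x width matrix. *)
let random_crop_sketch (img : float array array) ~pad : float array array =
  if pad < 0 then invalid_arg "pad must be non-negative";
  let h = Array.length img in
  let w = if h = 0 then 0 else Array.length img.(0) in
  (* Zero-padded copy of size (h + 2 pad) x (w + 2 pad). *)
  let padded =
    Array.init (h + (2 * pad)) (fun y ->
        Array.init (w + (2 * pad)) (fun x ->
            let sy = y - pad and sx = x - pad in
            if sy >= 0 && sy < h && sx >= 0 && sx < w then img.(sy).(sx) else 0.))
  in
  (* Random top-left corner of an h x w window inside the padded image. *)
  let dy = Random.int ((2 * pad) + 1) and dx = Random.int ((2 * pad) + 1) in
  Array.init h (fun y -> Array.init w (fun x -> padded.(y + dy).(x + dx)))
```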

val shuffle : t -> t

shuffle t returns t where training images and labels have been shuffled.
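The key requirement is that images and labels are permuted together so each image keeps its label. The sketch below uses a hypothetical array-backed record in place of t (whose real fields hold tensors), applies one Fisher-Yates permutation to both training arrays, and assumes the test split is left unchanged since the description only mentions training data.

```ocaml
(* Hypothetical array-backed stand-in for [t]; the real record holds tensors. *)
type 'a dataset = {
  train_images : 'a array;
  train_labels : int array;
  test_images : 'a array;
  test_labels : int array;
}

(* One Fisher-Yates permutation applied to both training arrays. *)
let shuffle_sketch (t : 'a dataset) : 'a dataset =
  let n = Array.length t.train_images in
  let perm = Array.init n (fun i -> i) in
  for i = n - 1 downto 1 do
    let j = Random.int (i + 1) in
    let tmp = perm.(i) in
    perm.(i) <- perm.(j);
    perm.(j) <- tmp
  done;
  { t with
    train_images = Array.map (fun i -> t.train_images.(i)) perm;
    train_labels = Array.map (fun i -> t.train_labels.(i)) perm }
```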

val map : ?device:Device.t -> t -> f: (int -> batch_images:Tensor.t -> batch_labels:Tensor.t -> Tensor.t * Tensor.t) -> batch_size:int -> t
val print_summary : t -> unit
val read_char_tensor : string -> Tensor.t

read_char_tensor filename returns a tensor of chars holding the contents of the file filename.
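Only the file-reading half of this function can be shown with the standard library. The sketch below is a hypothetical analogue (read_chars is not a library function) that loads every byte of the file into a char array; the real function returns a Tensor.t instead.

```ocaml
(* Hypothetical analogue: read the whole file, byte by byte, into a char
   array. Converting to a tensor is not shown. *)
let read_chars filename : char array =
  let ic = open_in_bin filename in
  Fun.protect ~finally:(fun () -> close_in ic) (fun () ->
      let len = in_channel_length ic in
      let buf = really_input_string ic len in
      Array.init len (String.get buf))
```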