package torch

module Tensor_id : sig ... end
type t

A VarStore holds all the variables used by a given model. The model creates each variable by calling Var_store.new_var, which requires a name for the variable. Var_store.sub creates a sub-directory in the var store, which is useful for grouping related variables together.
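A minimal sketch of that workflow follows. It assumes the library is opened as `Torch` (so this module is `Torch.Var_store`) and that `Init` provides a `Zeros` constructor; the `Init` signature is elided on this page, so that constructor name is an assumption. A root store is created, one layer's weights are grouped under a `"linear"` sub-directory, and a named variable is created there.

```ocaml
open Torch

(* Root store for the whole model; no ~device, so it defaults to the CPU. *)
let vs = Var_store.create ~name:"model" ()

(* Group the parameters of one layer under the "linear" sub-directory. *)
let linear_vs = Var_store.sub vs "linear"

(* Every variable needs a name. Init.Zeros is assumed to exist in Init. *)
let weight =
  Var_store.new_var linear_vs ~shape:[ 3; 4 ] ~init:Var_store.Init.Zeros
    ~name:"weight"
```

Because the sub-directory lives inside the root store, the variable is also visible through `vs` (for instance via `Var_store.all_vars vs`).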

val create : ?frozen:Base.bool -> ?device:Device.t -> name:Base.string -> Base.unit -> t

create ?frozen ?device ~name () creates a new variable store on the specified device (defaulting to cpu).
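As a small sketch (assuming the library is opened as `Torch`), the two stores below both end up on the CPU: the first through the default, the second by passing `~device` explicitly. `Device.Cpu` is used so the example runs without a GPU; any other `Device.t` value could be passed instead.

```ocaml
open Torch

(* No ~device given: the store defaults to the CPU. *)
let default_vs = Var_store.create ~name:"default" ()

(* Explicit device argument. *)
let explicit_vs = Var_store.create ~device:Device.Cpu ~name:"explicit" ()

let () =
  Printf.printf "%s on CPU: %b\n" (Var_store.name default_vs)
    (Var_store.device default_vs = Device.Cpu)
```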

val sub : t -> Base.string -> t

sub t subname returns a var-store corresponding to path subname in t.

val subi : t -> Base.int -> t

subi is similar to sub but uses an integer for the subname.

val (/) : t -> Base.string -> t

Same as sub.

val (//) : t -> Base.int -> t

Same as subi.
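The sketch below (assuming `open Torch` and an `Init.Zeros` constructor, since `Init` is elided here) shows why integer sub-directories are convenient: a stack of identical layers can each get its own `"bias"` variable without name clashes, because every layer lives under a different `encoder // i` path.

```ocaml
open Torch

let vs = Var_store.create ~name:"net" ()

(* Var_store.sub vs "encoder" could equally be written Var_store.(vs / "encoder"). *)
let encoder = Var_store.sub vs "encoder"

(* encoder // i is the operator form of Var_store.subi encoder i. *)
let biases =
  List.init 3 (fun i ->
      let layer_vs = Var_store.(encoder // i) in
      Var_store.new_var layer_vs ~shape:[ 4 ] ~init:Var_store.Init.Zeros
        ~name:"bias")
```

The local open `Var_store.( ... )` brings `(/)` and `(//)` into scope only for that expression, so the standard integer division operator is not shadowed elsewhere.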

val num_trainable_vars : t -> Base.int

num_trainable_vars t returns the number of trainable variables stored in t.

val iter_trainable_vars : t -> f:(Tensor_id.t -> Tensor.t -> Base.unit) -> Base.unit

iter_trainable_vars t ~f applies f to all trainable variables stored in t.
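A common use of these two functions is reporting model size. The helper `count_parameters` below is hypothetical (not part of the library); it assumes `open Torch`, an `Init.Zeros` constructor, and `Tensor.shape` returning the dimensions as an `int list`. It sums the number of scalar elements over all trainable variables, whereas `num_trainable_vars` counts the variables themselves.

```ocaml
open Torch

(* Hypothetical helper: total number of scalar elements across trainable variables. *)
let count_parameters vs =
  let total = ref 0 in
  Var_store.iter_trainable_vars vs ~f:(fun _id tensor ->
      (* Element count of one variable = product of its dimensions. *)
      total := !total + List.fold_left ( * ) 1 (Tensor.shape tensor));
  !total

let vs = Var_store.create ~name:"mlp" ()
let _w = Var_store.new_var vs ~shape:[ 2; 3 ] ~init:Var_store.Init.Zeros ~name:"w"
let _b = Var_store.new_var vs ~shape:[ 3 ] ~init:Var_store.Init.Zeros ~name:"b"
```

Here the store holds two trainable variables with 2 × 3 + 3 = 9 elements in total.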

val all_vars : t -> (Base.string * Tensor.t) Base.list

all_vars t returns all the variables stored in t.
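For inspecting a store, for example before saving a checkpoint, the pairs returned by `all_vars` can be printed directly. This sketch assumes `open Torch` and an `Init.Ones` constructor; `describe` is a hypothetical helper, and the exact format of the returned names (how sub-directory components are joined) is not specified on this page, so the code simply prints whatever string it receives.

```ocaml
open Torch

let vs = Var_store.create ~name:"demo" ()

let _ =
  Var_store.new_var Var_store.(vs / "layer") ~shape:[ 2; 2 ]
    ~init:Var_store.Init.Ones ~name:"w"

(* Hypothetical helper: print each variable's name and shape. *)
let describe vs =
  List.iter
    (fun (name, tensor) ->
      let dims = List.map string_of_int (Tensor.shape tensor) in
      Printf.printf "%s: [%s]\n" name (String.concat "; " dims))
    (Var_store.all_vars vs)

let () = describe vs
```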

val name : t -> Base.string

name t returns the var-store name.

val device : t -> Device.t

device t returns the device used to store the variables held by t.

module Init : sig ... end
val new_var : ?trainable:Base.bool -> t -> shape:Base.int Base.list -> init:Init.t -> name:Base.string -> Tensor.t

new_var ?trainable t ~shape ~init ~name creates a new variable named name in t, with the given shape and initialized according to init. The tensor associated with the variable is returned.
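The `?trainable` flag is useful for state that must be stored with the model but not optimized, such as running statistics. In this sketch (assuming `open Torch` and `Init.Zeros`/`Init.Ones` constructors), `running_mean` appears in `all_vars` but is excluded from the trainable set, so `num_trainable_vars` counts only `scale`.

```ocaml
open Torch

let vs = Var_store.create ~name:"norm" ()

(* A learned parameter, created with the default trainable setting. *)
let scale = Var_store.new_var vs ~shape:[ 8 ] ~init:Var_store.Init.Ones ~name:"scale"

(* Stored alongside the parameters, but not trainable. *)
let running_mean =
  Var_store.new_var ~trainable:false vs ~shape:[ 8 ] ~init:Var_store.Init.Zeros
    ~name:"running_mean"
```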

val new_var_copy : ?trainable:Base.bool -> t -> src:Tensor.t -> name:Base.string -> Tensor.t

new_var_copy ?trainable t ~src ~name creates a new variable named name in t by copying the tensor src, so the variable has the same shape, element kind and element values as src.
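A typical use is seeding a variable from an existing tensor, such as pretrained embeddings. In this sketch (assuming `open Torch`), `Tensor.randn` stands in for a tensor loaded from elsewhere. No `Init` value is needed because shape, kind and values all come from `src`.

```ocaml
open Torch

(* Stand-in for a tensor obtained elsewhere, e.g. loaded pretrained weights. *)
let pretrained = Tensor.randn [ 10; 4 ]

let vs = Var_store.create ~name:"embeddings" ()

(* The new variable copies the shape, element kind and values of pretrained. *)
let table = Var_store.new_var_copy vs ~src:pretrained ~name:"table"
```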

val freeze : t -> Base.unit

freeze t freezes all the variables in the var-store: none of them are trainable anymore and their gradients are no longer tracked.

val unfreeze : t -> Base.unit

unfreeze t unfreezes the variables of t that were created as trainable.
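When a computation should run without gradient tracking, for example during evaluation, the freeze/unfreeze pair can be wrapped so the store is always unfrozen afterwards, even if the computation raises. `with_frozen` is a hypothetical helper built on the standard library's `Fun.protect` (OCaml 4.08+). It assumes `open Torch` and an `Init.Zeros` constructor. Since `unfreeze` re-enables every variable created as trainable, this helper does not restore a store that was already frozen before the call.

```ocaml
open Torch

(* Hypothetical helper: run f with vs frozen, then unfreeze vs even if f raises. *)
let with_frozen vs ~f =
  Var_store.freeze vs;
  Fun.protect ~finally:(fun () -> Var_store.unfreeze vs) f

let vs = Var_store.create ~name:"eval" ()
let w = Var_store.new_var vs ~shape:[ 2 ] ~init:Var_store.Init.Zeros ~name:"w"

(* Any computation using w while frozen; here it just reads the shape. *)
let shape_seen = with_frozen vs ~f:(fun () -> Tensor.shape w)
```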

val copy : src:t -> dst:t -> Base.unit