package arrayjit

  1. Overview
  2. Docs
type symbol =
  1. | Symbol of Base.int
val compare_symbol : symbol -> symbol -> Base.int
val equal_symbol : symbol -> symbol -> Base.bool
val symbol_of_sexp : Sexplib0.Sexp.t -> symbol
val sexp_of_symbol : symbol -> Sexplib0.Sexp.t
val hash_fold_symbol : Ppx_hash_lib.Std.Hash.state -> symbol -> Ppx_hash_lib.Std.Hash.state
val hash_symbol : symbol -> Ppx_hash_lib.Std.Hash.hash_value
val symbol : Base.int -> symbol
val is_symbol : symbol -> bool
val symbol_val : symbol -> Base.int Stdlib.Option.t
module Variants_of_symbol : sig ... end
val unique_id : int Base.ref
val get_symbol : unit -> symbol
module CompareSymbol : sig ... end
module Symbol : sig ... end
val symbol_ident : symbol -> Base.String.t
type 'a environment = 'a Base.Map.M(Symbol).t
val environment_of_sexp : 'a. (Sexplib0.Sexp.t -> 'a) -> Sexplib0.Sexp.t -> 'a environment
val sexp_of_environment : 'a. ('a -> Sexplib0.Sexp.t) -> 'a environment -> Sexplib0.Sexp.t
val empty_env : 'a environment
type static_symbol = {
  1. static_symbol : symbol;
  2. mutable static_range : Base.int Base.option;
}
val compare_static_symbol : static_symbol -> static_symbol -> Base.int
val equal_static_symbol : static_symbol -> static_symbol -> Base.bool
val static_symbol_of_sexp : Sexplib0.Sexp.t -> static_symbol
val sexp_of_static_symbol : static_symbol -> Sexplib0.Sexp.t
val hash_fold_static_symbol : Ppx_hash_lib.Std.Hash.state -> static_symbol -> Ppx_hash_lib.Std.Hash.state
val hash_static_symbol : static_symbol -> Ppx_hash_lib.Std.Hash.hash_value
type 'a bindings =
  1. | Empty
  2. | Bind of static_symbol * (Base.int -> 'a) bindings
val sexp_of_bindings : 'a. ('a -> Sexplib0.Sexp.t) -> 'a bindings -> Sexplib0.Sexp.t
val bound_symbols : 'a bindings -> static_symbol Base.List.t
type ('r, 'idcs, 'p1, 'p2) variadic =
  1. | Result of 'r
  2. | Param_idx of Base.int Base.ref * (Base.int -> 'r, Base.int -> 'idcs, 'p1, 'p2) variadic
  3. | Param_1 of 'p1 Base.option Base.ref * ('p1 -> 'r, 'idcs, 'p1, 'p2) variadic
  4. | Param_2 of 'p2 Base.option Base.ref * ('p2 -> 'r, 'idcs, 'p1, 'p2) variadic

Helps with jitting the bindings.

type unit_bindings = (Base.unit -> Base.unit) bindings
val sexp_of_unit_bindings : unit_bindings -> Sexplib0.Sexp.t
type jitted_bindings = (static_symbol, Base.int Base.ref) Base.List.Assoc.t
val sexp_of_jitted_bindings : jitted_bindings -> Sexplib0.Sexp.t
val apply : 'r 'idcs 'p1 'p2. ('r, 'idcs, 'p1, 'p2) variadic -> 'r
val jitted_bindings : 'a bindings -> ('b, 'a, 'p1, 'p2) variadic -> (static_symbol * Base.int Base.ref) Base.List.t
val find_exn : jitted_bindings -> static_symbol -> Base.int Base.ref
val get_static_symbol : ?static_range:Base.int -> (Base.int -> 'a) bindings -> static_symbol * 'a bindings
val dims_to_string : ?with_axis_numbers:bool -> Base.Int.t Base.Array.t -> Base.String.t

Converts dimensions to a string, "x"-separated, e.g. 1x2x3 for batch dims 1, input dims 3, output dims 2. Outputs "-" for empty dimensions.

type axis_index =
  1. | Fixed_idx of Base.int
    (*

    The specific position along an axis.

    *)
  2. | Iterator of symbol
    (*

    The given member of the product_space corresponding to some product_iterators.

    *)
val compare_axis_index : axis_index -> axis_index -> Base.int
val equal_axis_index : axis_index -> axis_index -> Base.bool
val axis_index_of_sexp : Sexplib0.Sexp.t -> axis_index
val sexp_of_axis_index : axis_index -> Sexplib0.Sexp.t
val fixed_idx : Base.int -> axis_index
val iterator : symbol -> axis_index
val is_fixed_idx : axis_index -> bool
val is_iterator : axis_index -> bool
val fixed_idx_val : axis_index -> Base.int Stdlib.Option.t
val iterator_val : axis_index -> symbol Stdlib.Option.t
module Variants_of_axis_index : sig ... end
type str_osym_map = (Base.string, symbol Base.option, Base.String.comparator_witness) Base.Map.t
val sexp_of_str_osym_map : str_osym_map -> Base.Sexp.t
type projections_debug = {
  1. spec : Base.string;
  2. derived_for : Base.Sexp.t;
  3. trace : (Base.string * Base.int) Base.list;
}
val projections_debug_of_sexp : Sexplib0.Sexp.t -> projections_debug
val sexp_of_projections_debug : projections_debug -> Sexplib0.Sexp.t
val unique_debug_id : unit -> int
type projections = {
  1. product_space : Base.int Base.array;
    (*

    The product space dimensions that an operation should parallelize (map-reduce) over.

    *)
  2. lhs_dims : Base.int Base.array;
    (*

    The dimensions of the LHS array.

    *)
  3. rhs_dims : Base.int Base.array Base.array;
    (*

    The dimensions of the RHS arrays, needed for deriving projections from other projections.

    *)
  4. product_iterators : symbol Base.array;
    (*

    The product space iterators (concatenation of the relevant batch, output, input axes) for iterating over the product_space axes, where same axes are at same array indices.

    *)
  5. project_lhs : axis_index Base.array;
    (*

    A projection that takes a product_space-bound index and produces an index into the result of an operation.

    *)
  6. project_rhs : axis_index Base.array Base.array;
    (*

    project_rhs.(i) produces an index into the (i+1)-th argument of an operation.

    *)
  7. debug_info : projections_debug;
}

All the information relevant for code generation.

val compare_projections : projections -> projections -> Base.int
val equal_projections : projections -> projections -> Base.bool
val projections_of_sexp : Sexplib0.Sexp.t -> projections
val sexp_of_projections : projections -> Sexplib0.Sexp.t
val iterated : int -> bool
val opt_symbol : int -> symbol option
val opt_iterator : symbol option -> axis_index
val is_bijective : projections -> bool
val identity_projections : debug_info:projections_debug -> lhs_dims:Base.int Base.Array.t -> projections

Projections for a pointwise unary operator.

val derive_index : product_syms:Symbol.t Base.Array.t -> projection:axis_index Base.array -> product:axis_index Base.Array.t -> axis_index Base.Array.t
OCaml

Innovation. Community. Security.