Table Batched Embedding Operators¶
- 
std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>> get_unique_indices_cuda(const at::Tensor &linear_indices, const int64_t max_indices, const bool compute_count)¶
- Deduplicate indices. 
- 
std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>> get_unique_indices_with_inverse_cuda(const at::Tensor &linear_indices, const int64_t max_indices, const bool compute_count, const bool compute_inverse_indices)¶
- Deduplicate indices, optionally also returning the inverse mapping from the original indices to the deduplicated ones.
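The deduplication contract of these two operators can be sketched on the host with standard C++. This is a conceptual sketch of the returned tuple (sorted unique indices, their count, optional per-index counts, optional inverse mapping), not the CUDA implementation; all names are illustrative.

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Host-side sketch of the deduplication contract (not the CUDA kernel):
    // sorted unique indices, their count, per-index occurrence counts, and
    // the inverse mapping from each input position to its unique slot.
    struct UniqueResult {
      std::vector<int64_t> unique_indices;
      int64_t unique_indices_length;
      std::vector<int64_t> counts;           // optional in the real op
      std::vector<int64_t> inverse_indices;  // optional in the real op
    };

    UniqueResult get_unique_indices_sketch(const std::vector<int64_t>& linear_indices) {
      std::vector<int64_t> sorted = linear_indices;
      std::sort(sorted.begin(), sorted.end());

      UniqueResult out;
      for (int64_t v : sorted) {
        if (out.unique_indices.empty() || out.unique_indices.back() != v) {
          out.unique_indices.push_back(v);
          out.counts.push_back(0);
        }
        out.counts.back() += 1;
      }
      out.unique_indices_length = static_cast<int64_t>(out.unique_indices.size());

      // Inverse mapping: position of each original index in unique_indices.
      for (int64_t v : linear_indices) {
        auto it = std::lower_bound(out.unique_indices.begin(), out.unique_indices.end(), v);
        out.inverse_indices.push_back(it - out.unique_indices.begin());
      }
      return out;
    }

    int main() {
      auto r = get_unique_indices_sketch({7, 3, 7, 1, 3, 7});
      // unique: 1 3 7; counts: 1 2 3; inverse: 2 1 2 0 1 2
      for (int64_t u : r.unique_indices) std::cout << u << ' ';
      std::cout << '\n';
    }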
- 
std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>> lru_cache_find_uncached_cuda(at::Tensor unique_indices, at::Tensor unique_indices_length, int64_t max_indices, at::Tensor lxu_cache_state, int64_t time_stamp, at::Tensor lru_state, bool gather_cache_stats, at::Tensor uvm_cache_stats, bool lock_cache_line, at::Tensor lxu_cache_locking_counter, const bool compute_inverse_indices)¶
- Look up the LRU cache to find uncached indices, then sort them by cache set.
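A host-side sketch of this step, under two stated assumptions: a plain modulo set mapping and a simple [set][way] tag array standing in for the lxu_cache_state layout. It collects the unique indices that miss in the cache and sorts them by cache set, as the description above states.

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Sketch only: a modulo set mapping and a [set][way] tag array stand in
    // for the real lxu_cache_state layout.
    std::vector<int64_t> find_uncached_sorted_by_set(
        const std::vector<int64_t>& unique_indices,
        const std::vector<std::vector<int64_t>>& cache_tags,  // tags[set][way]
        int64_t num_sets) {
      std::vector<int64_t> uncached;
      for (int64_t idx : unique_indices) {
        const auto& ways = cache_tags[idx % num_sets];
        if (std::find(ways.begin(), ways.end(), idx) == ways.end()) {
          uncached.push_back(idx);  // not present in any way of its set: a miss
        }
      }
      // Sort misses by cache set so each set's insertions are handled together.
      std::sort(uncached.begin(), uncached.end(), [num_sets](int64_t a, int64_t b) {
        return a % num_sets < b % num_sets;
      });
      return uncached;
    }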
- 
int64_t host_lxu_cache_slot(int64_t h_in, int64_t C)¶
- Map an index to a cache set. h_in: linear index; C: number of cache sets.
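As a sketch of this mapping, assuming a plain modulo placement (the real implementation may hash h_in before reducing modulo C):

    #include <cstdint>

    // Illustrative stand-in for host_lxu_cache_slot: map a linear index to a
    // cache set in [0, C). Assumption: plain modulo placement; the real op
    // may hash h_in first.
    int64_t host_lxu_cache_slot_sketch(int64_t h_in, int64_t C) {
      return h_in % C;
    }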
- 
at::Tensor linearize_cache_indices_cuda(const at::Tensor &cache_hash_size_cumsum, const at::Tensor &indices, const at::Tensor &offsets, const std::optional<at::Tensor> &B_offsets, const int64_t max_B, const int64_t indices_base_offset)¶
- Linearize the indices of all tables to make them unique.
- 
at::Tensor linearize_cache_indices_from_row_idx_cuda(at::Tensor cache_hash_size_cumsum, at::Tensor update_table_indices, at::Tensor update_row_indices)¶
- Linearize the indices of all tables to make them unique. Note that update_table_indices and update_row_indices are in the row-index format used for in-place updates.
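Both linearization operators rest on the same arithmetic: table t owns the half-open range [cache_hash_size_cumsum[t], cache_hash_size_cumsum[t+1]), so adding a table's cumulative offset to a row index produces an index that is unique across all tables. A host-side sketch of the row-index variant (parameter names mirror the signature; the function itself is illustrative):

    #include <cstdint>
    #include <vector>

    // Sketch of the row-index variant: shift each row index by its table's
    // cumulative hash-size offset so indices from different tables never collide.
    std::vector<int64_t> linearize_from_row_idx_sketch(
        const std::vector<int64_t>& cache_hash_size_cumsum,  // size: num_tables + 1
        const std::vector<int64_t>& update_table_indices,
        const std::vector<int64_t>& update_row_indices) {
      std::vector<int64_t> linear(update_row_indices.size());
      for (size_t i = 0; i < update_row_indices.size(); ++i) {
        linear[i] =
            cache_hash_size_cumsum[update_table_indices[i]] + update_row_indices[i];
      }
      return linear;
    }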
- 
at::Tensor direct_mapped_lxu_cache_lookup_cuda(at::Tensor linear_cache_indices, at::Tensor lxu_cache_state, int64_t invalid_index, bool gather_cache_stats, std::optional<at::Tensor> uvm_cache_stats)¶
- Look up the slots in the cache corresponding to linear_cache_indices, with a sentinel value for missing.
- 
void lru_cache_populate_cuda(at::Tensor weights, at::Tensor hash_size_cumsum, int64_t total_cache_hash_size, at::Tensor cache_index_table_map, at::Tensor weights_offsets, at::Tensor D_offsets, at::Tensor linear_cache_indices, at::Tensor lxu_cache_state, at::Tensor lxu_cache_weights, int64_t time_stamp, at::Tensor lru_state, bool stochastic_rounding, bool gather_cache_stats, std::optional<at::Tensor> uvm_cache_stats, bool lock_cache_line, std::optional<at::Tensor> lxu_cache_locking_counter)¶
- LRU cache: fetch the rows corresponding to linear_cache_indices from weights, and insert them into the cache at timestep time_stamp.
- 
void lru_cache_populate_byte_cuda(at::Tensor weights, at::Tensor hash_size_cumsum, int64_t total_cache_hash_size, at::Tensor cache_index_table_map, at::Tensor weights_offsets, at::Tensor weights_tys, at::Tensor D_offsets, at::Tensor linear_cache_indices, at::Tensor lxu_cache_state, at::Tensor lxu_cache_weights, int64_t time_stamp, at::Tensor lru_state, int64_t row_alignment, bool gather_cache_stats, std::optional<at::Tensor> uvm_cache_stats)¶
- LRU cache: fetch the rows corresponding to linear_cache_indices from weights, and insert them into the cache at timestep time_stamp. weights and lxu_cache_weights have uint8_t byte elements.
- 
void direct_mapped_lru_cache_populate_byte_cuda(at::Tensor weights, at::Tensor hash_size_cumsum, int64_t total_cache_hash_size, at::Tensor cache_index_table_map, at::Tensor weights_offsets, at::Tensor weights_tys, at::Tensor D_offsets, at::Tensor linear_cache_indices, at::Tensor lxu_cache_state, at::Tensor lxu_cache_weights, int64_t time_stamp, at::Tensor lru_state, at::Tensor lxu_cache_miss_timestamp, int64_t row_alignment, bool gather_cache_stats, std::optional<at::Tensor> uvm_cache_stats)¶
- Direct-mapped (assoc=1) variant of lru_cache_populate_byte_cuda.
- 
void lfu_cache_populate_cuda(at::Tensor weights, at::Tensor cache_hash_size_cumsum, int64_t total_cache_hash_size, at::Tensor cache_index_table_map, at::Tensor weights_offsets, at::Tensor D_offsets, at::Tensor linear_cache_indices, at::Tensor lxu_cache_state, at::Tensor lxu_cache_weights, at::Tensor lfu_state, bool stochastic_rounding)¶
- LFU cache: fetch the rows corresponding to linear_cache_indices from weights, and insert them into the cache.
- 
void lfu_cache_populate_byte_cuda(at::Tensor weights, at::Tensor cache_hash_size_cumsum, int64_t total_cache_hash_size, at::Tensor cache_index_table_map, at::Tensor weights_offsets, at::Tensor weights_tys, at::Tensor D_offsets, at::Tensor linear_cache_indices, at::Tensor lxu_cache_state, at::Tensor lxu_cache_weights, at::Tensor lfu_state, int64_t row_alignment)¶
- LFU cache: fetch the rows corresponding to linear_cache_indices from weights, and insert them into the cache. weights and lxu_cache_weights have uint8_t byte elements.
- 
at::Tensor lxu_cache_lookup_cuda(at::Tensor linear_cache_indices, at::Tensor lxu_cache_state, int64_t invalid_index, bool gather_cache_stats, std::optional<at::Tensor> uvm_cache_stats, std::optional<at::Tensor> num_uniq_cache_indices, std::optional<at::Tensor> lxu_cache_locations_output)¶
- Lookup the LRU/LFU cache: find the cache weights location for all indices. Look up the slots in the cache corresponding to linear_cache_indices, with a sentinel value for missing.
- 
at::Tensor emulate_cache_miss(at::Tensor lxu_cache_locations, const int64_t enforced_misses_per_256, const bool gather_cache_stats, at::Tensor uvm_cache_stats)¶
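To make the lookup/populate contract concrete, the following host-side miniature models a direct-mapped (assoc=1) cache: lookup returns each index's slot or the invalid_index sentinel, and populate inserts missing rows at timestep time_stamp, updating a per-slot LRU timestamp. The state layout here is an illustrative assumption, not the actual lxu_cache_state/lru_state format.

    #include <cstdint>
    #include <vector>

    // Miniature direct-mapped (assoc=1) cache illustrating the contract.
    // tags[set]      : linear index currently occupying the set (-1 if empty)
    // last_used[set] : timestamp of the most recent access (the LRU state)
    struct MiniCache {
      std::vector<int64_t> tags;
      std::vector<int64_t> last_used;
      explicit MiniCache(int64_t num_sets) : tags(num_sets, -1), last_used(num_sets, 0) {}

      // lxu_cache_lookup-style: each index's slot, or invalid_index on a miss.
      std::vector<int64_t> lookup(const std::vector<int64_t>& indices,
                                  int64_t invalid_index) const {
        std::vector<int64_t> locations;
        for (int64_t idx : indices) {
          int64_t set = idx % static_cast<int64_t>(tags.size());
          locations.push_back(tags[set] == idx ? set : invalid_index);
        }
        return locations;
      }

      // lru_cache_populate-style: insert missing rows at timestep time_stamp;
      // with assoc=1 the previous occupant of the set is simply evicted.
      void populate(const std::vector<int64_t>& indices, int64_t time_stamp) {
        for (int64_t idx : indices) {
          int64_t set = idx % static_cast<int64_t>(tags.size());
          if (tags[set] != idx) {
            tags[set] = idx;  // the real op also copies the row from weights here
          }
          last_used[set] = time_stamp;
        }
      }
    };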
- 
void lxu_cache_flush_cuda(at::Tensor uvm_weights, at::Tensor cache_hash_size_cumsum, at::Tensor cache_index_table_map, at::Tensor weights_offsets, at::Tensor D_offsets, int64_t total_D, at::Tensor lxu_cache_state, at::Tensor lxu_cache_weights, bool stochastic_rounding)¶
- Flush the cache: store the weights from the cache to the backing storage. 
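A sketch of the flush data flow under an assumed layout (one tag and one weight row per cache slot, -1 marking an empty slot): every occupied cache row is written back to its home row in the backing storage.

    #include <cstdint>
    #include <vector>

    // Sketch of the flush direction: copy each occupied cache row back to its
    // home row in the backing storage (-1 marks an empty slot).
    void flush_sketch(std::vector<std::vector<float>>& backing_rows,
                      const std::vector<int64_t>& tags,
                      const std::vector<std::vector<float>>& cache_rows) {
      for (size_t slot = 0; slot < tags.size(); ++slot) {
        if (tags[slot] >= 0) {
          backing_rows[static_cast<size_t>(tags[slot])] = cache_rows[slot];
        }
      }
    }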
- 
void reset_weight_momentum_cuda(at::Tensor dev_weights, at::Tensor uvm_weights, at::Tensor lxu_cache_weights, at::Tensor weights_placements, at::Tensor weights_offsets, at::Tensor momentum1_dev, at::Tensor momentum1_uvm, at::Tensor momentum1_placements, at::Tensor momentum1_offsets, at::Tensor D_offsets, at::Tensor pruned_indices, at::Tensor pruned_indices_offsets, at::Tensor logical_table_ids, at::Tensor buffer_ids, at::Tensor cache_hash_size_cumsum, at::Tensor lxu_cache_state, int64_t total_cache_hash_size)¶