Shortcuts

Program Listing for File NPUFunctions.h

Return to documentation for file (csrc/backend/NPUFunctions.h)

#pragma once

// This header provides C++ wrappers around commonly used AscendCL API
// functions. The benefit of using C++ here is that we can raise an exception in
// the event of an error, rather than explicitly pass around error codes.  This
// leads to more natural APIs.
//
// The naming convention used here matches the naming convention of torch.npu

#include <c10/core/Device.h>
#include <c10/macros/Macros.h>

#include <mutex>
#include <optional>
#include "csrc/backend/NPUDeviceProp.h"
#include "csrc/core/Macros.h"

// TODO(FFFrog):
// Remove later
#include "acl/include/acl/acl.h"
#include "core/NPUException.h"

namespace c10::backend {

// Number of NPU devices. Declared noexcept, so it never propagates an
// exception to the caller.
C10_BACKEND_API c10::DeviceIndex device_count() noexcept;

// Version of device_count that throws if no devices are detected.
C10_BACKEND_API c10::DeviceIndex device_count_ensure_non_zero();

C10_BACKEND_API c10::DeviceIndex current_device();

C10_BACKEND_API void set_device(c10::DeviceIndex device);

C10_BACKEND_API void device_synchronize();

// This function has to be called from callers performing NPU synchronizing
// operations, to raise a proper error or warning according to the current
// SyncDebugMode (see warning_state() below).
C10_BACKEND_API void warn_or_error_on_sync();

// Raw NPU (AscendCL) device management functions. Most of these return the
// raw aclError status code, which the caller is responsible for checking.
C10_BACKEND_API aclError GetDeviceCount(int* dev_count);

C10_BACKEND_API aclError InitDevice();

C10_BACKEND_API void FinalizeDevice();

// Writes the current device index through `device`.
C10_BACKEND_API aclError GetDevice(c10::DeviceIndex* device);

C10_BACKEND_API aclError SetDevice(c10::DeviceIndex device);

C10_BACKEND_API aclError MaybeSetDevice(c10::DeviceIndex device);

// NOTE(review): presumably makes `device` current and returns the previously
// current device index (mirroring c10::cuda::ExchangeDevice) — confirm in the
// implementation.
C10_BACKEND_API c10::DeviceIndex ExchangeDevice(c10::DeviceIndex device);

C10_BACKEND_API c10::DeviceIndex MaybeExchangeDevice(c10::DeviceIndex device);

C10_BACKEND_API void SetTargetDevice();

C10_BACKEND_API aclrtContext GetDeviceContext(c10::DeviceIndex device);

// How NPU synchronizing operations are reported: ignored, warned about, or
// treated as errors.
enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };

// Stores the NPU synchronization debug state. The process-wide instance is
// obtained through warning_state(), and its mode determines how synchronizing
// operations are reported.
//
// NOTE: the mode is a plain member with no atomic or lock, so setting and
// reading it concurrently from different threads is not synchronized.
class WarningState {
 public:
  void set_sync_debug_mode(SyncDebugMode l) noexcept {
    sync_debug_mode = l;
  }

  // const so it can be queried through a const reference; reading the mode
  // does not modify the state.
  SyncDebugMode get_sync_debug_mode() const noexcept {
    return sync_debug_mode;
  }

 private:
  SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED;
};

// Returns the global WarningState. The function-local static is constructed
// on first call (initialization is thread-safe since C++11). The standard
// `inline` keyword (rather than the GNU-only `__inline__`) lets this definition
// live in the header without violating the one-definition rule.
C10_BACKEND_API inline WarningState& warning_state() {
  static WarningState warning_state_;
  return warning_state_;
}

// NOTE(review): presumably reports whether a primary context has already been
// created for `device_index` — confirm in the implementation.
C10_BACKEND_API bool hasPrimaryContext(c10::DeviceIndex device_index);
// Returns a device index if one is found, std::nullopt otherwise (the meaning
// of "with primary context" is defined by the implementation).
C10_BACKEND_API std::optional<c10::DeviceIndex>
getDeviceIndexWithPrimaryContext();

// NOTE(review): returns a pointer to a mutex owned by the backend; callers
// should not delete it. Confirm what it guards (presumably memory frees).
C10_BACKEND_API std::mutex* getFreeMutex();

// Presumably fills `*device_prop` with the properties of `device` — confirm
// in the implementation.
C10_BACKEND_API void get_device_properties(
    c10::backend::NPUDeviceProp* device_prop,
    c10::DeviceIndex device);

} // namespace c10::backend

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources