Shortcuts

Program Listing for File NPUStream.h

Return to documentation for file (csrc/backend/NPUStream.h)

#pragma once

#include <c10/core/DeviceGuard.h>
#include <c10/core/Stream.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/util/SmallVector.h>
#include <cstdint>
#include <mutex>

#include "csrc/aten/generated/NPUNativeFunctions.h"
#include "csrc/core/Macros.h"

// TODO(FFFrog):
// Remove later
#include "acl/include/acl/acl.h"
#include "acl/include/acl/acl_op.h"
#include "core/NPUException.h"

/*
 * Stream pool note.
 *
 * An NPUStream is an abstraction of an actual aclrtStream on the NPU.
 * NPUStreams are backed by aclrtStreams, but they use several pools to
 * minimize the costs associated with creating, retaining, and destroying
 * aclrtStreams.
 *
 * There are three pools per device, and a device's pools are lazily created.
 *
 * The first pool contains only the default stream. When the default stream
 * is requested it's returned.
 *
 * The second pool is the "low priority" or "default priority" streams. In
 * HIP builds there is no distinction between streams in this pool and streams
 * in the third pool (below). There are 32 of these streams per device, and
 * when a stream is requested one of these streams is returned round-robin.
 * That is, the first stream requested is at index 0, the second at index 1...
 * to index 31, then index 0 again.
 *
 * This means that if 33 low priority streams are requested, the first and
 * last streams requested are actually the same stream (under the covers)
 * and kernels enqueued on them cannot run concurrently.
 *
 * The third pool is the "high priority" streams. The third pool acts like
 * the second pool except the streams are created with a higher priority.
 *
 * These pools suggest that stream users should prefer many short-lived streams,
 * as the cost of acquiring and releasing streams is effectively zero. If
 * many longer-lived streams are required in performance critical scenarios
 * then the functionality here may need to be extended to allow, for example,
 * "reserving" a subset of the pool so that other streams do not accidentally
 * overlap the performance critical streams.
 *
 * Note: although the notion of "current stream for device" is thread local
 * (every OS thread has a separate current stream, as one might expect),
 * the stream pool is global across all threads; stream 0 is always stream 0
 * no matter which thread you use it on.  Multiple threads can synchronize
 * on the same stream.  Although the NPU documentation is not very clear
 * on the matter, streams are thread safe; e.g., it is safe to enqueue
 * a kernel on the same stream from two different threads.
 */

namespace c10::backend {

// Compile-time constant for the number of stream priority levels.
// NOTE(review): presumably an upper bound matching the two prioritized pools
// ("low"/"high") described in the stream pool note above — confirm against
// the implementation, which is not visible in this header.
static constexpr int max_compile_time_stream_priorities = 2;

// Value object representing a NPU stream.  This is just a wrapper
// around c10::Stream, but it comes with a little extra NPU-specific
// functionality (conversion to aclrtStream), and a guarantee that
// the wrapped c10::Stream really is a NPU stream.
class C10_BACKEND_API NPUStream {
 public:
  enum Unchecked { UNCHECKED };

  explicit NPUStream(c10::Stream stream) : stream_(stream) {
    TORCH_CHECK(
        stream_.device_type() == c10::DeviceType::PrivateUse1,
        PTA_ERROR(ErrCode::TYPE));
  }

  explicit NPUStream(Unchecked, c10::Stream stream) : stream_(stream) {}

  bool operator==(const NPUStream& other) const noexcept {
    return unwrap() == other.unwrap();
  }

  bool operator!=(const NPUStream& other) const noexcept {
    return unwrap() != other.unwrap();
  }

  // Implicit conversion to aclrtStream.
  operator aclrtStream() const {
    return stream();
  }

  // Implicit conversion to pytorch Stream. (a.k.a., forget that the stream is a
  operator c10::Stream() const {
    return unwrap();
  }

  // Used to avoid baking in device type explicitly to Python-side API.
  c10::DeviceType device_type() const {
    return c10::DeviceType::PrivateUse1;
  }

  // Get the NPU device index that this stream is associated with.
  c10::DeviceIndex device_index() const {
    return stream_.device_index();
  }

  // Get the full Device that this stream is associated with.  The Device
  // is guaranteed to be a NPU device.
  c10::Device device() const {
    return c10::Device(c10::DeviceType::PrivateUse1, device_index());
  }

  c10::StreamId id() const {
    return stream_.id();
  }

  bool query() const {
    c10::DeviceGuard guard{stream_.device()};
    aclrtStreamStatus status = ACL_STREAM_STATUS_RESERVED;
    NPU_CHECK_ERROR(aclrtStreamQuery(stream(), &status));
    if (status == ACL_STREAM_STATUS_COMPLETE) {
      return true;
    }
    return false;
  }

  void synchronize() const {
    c10::DeviceGuard guard{stream_.device()};
    NPU_CHECK_ERROR(aclrtSynchronizeStreamWithTimeout(stream(), -1));
  }

  // Explicit conversion to aclrtStream.
  aclrtStream stream() const;

  // Explicit conversion to Stream.
  c10::Stream unwrap() const {
    return stream_;
  }

  struct c10::StreamData3 pack3() const {
    return stream_.pack3();
  }

  // Unpack a NPUStream from the 3 fields generated by pack().
  static NPUStream unpack3(
      c10::StreamId stream_id,
      c10::DeviceIndex device_index,
      c10::DeviceType device_type) {
    return NPUStream(
        c10::Stream::unpack3(stream_id, device_index, device_type));
  }

 private:
  c10::Stream stream_;
};

// Returns an NPUStream from the stream pool, with `isHighPriority` choosing
// between the prioritized pools described in the stream pool note at the top
// of this file.
// NOTE(review): device = -1 presumably means "use the current device" —
// confirm in the implementation, which is not visible in this header.
C10_BACKEND_API NPUStream getStreamFromPool(
    const bool isHighPriority = false,
    c10::DeviceIndex device = -1);

// Overload of getStreamFromPool that takes an integer `priority` instead of a
// boolean flag.
// NOTE(review): the valid range and ordering of `priority` values are not
// visible here — confirm in the implementation.
C10_BACKEND_API NPUStream
getStreamFromPool(const int priority, c10::DeviceIndex device = -1);

// Returns an NPUStream for an externally supplied aclrtStream handle and the
// device index it is associated with.
// NOTE(review): ownership/lifetime of `ext_stream` presumably stays with the
// caller — confirm in the implementation.
C10_BACKEND_API NPUStream
getStreamFromExternal(aclrtStream ext_stream, c10::DeviceIndex device_index);

// Returns the default NPU stream for `device_index`.
// NOTE(review): device_index = -1 presumably selects the current device —
// confirm in the implementation.
C10_BACKEND_API NPUStream
getDefaultNPUStream(c10::DeviceIndex device_index = -1);

// Returns the current NPU stream for `device_index`.  Per the stream pool
// note at the top of this file, the notion of "current stream" is thread
// local.
// NOTE(review): device_index = -1 presumably selects the current device —
// confirm in the implementation.
C10_BACKEND_API NPUStream
getCurrentNPUStream(c10::DeviceIndex device_index = -1);

// Sets `stream` as the current NPU stream.
// NOTE(review): presumably applies to the device `stream` belongs to and to
// the calling thread only — confirm in the implementation.
C10_BACKEND_API void setCurrentNPUStream(NPUStream stream);

// Stream-insertion operator for NPUStream; the output format is defined by
// the implementation (not visible in this header).
std::ostream& operator<<(std::ostream& stream, const NPUStream& s);

// Returns an aclError status code.
// NOTE(review): judging by the name, this destroys streams that have been
// handed out; the exact scope is not visible here — confirm in the
// implementation.
aclError DestroyUsedStreams();
} // namespace c10::backend

namespace std {
// Hash support for NPUStream: hashes the wrapped c10::Stream with
// std::hash<c10::Stream>.
template <>
struct hash<c10::backend::NPUStream> {
  size_t operator()(c10::backend::NPUStream s) const noexcept {
    const c10::Stream generic_stream = s.unwrap();
    const std::hash<c10::Stream> stream_hasher{};
    return stream_hasher(generic_stream);
  }
};
} // namespace std

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources