Shortcuts

Program Listing for File NPUGuard.h

Return to documentation for file (csrc/backend/NPUGuard.h)

#pragma once

#include <c10/core/DeviceType.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/InlineStreamGuard.h>
#include "csrc/backend/NPUGuardImpl.h"
#include "csrc/core/guard/PrivateUse1Guard.h"

#include <cstddef>
#include <vector>

namespace c10::backend {

// This code is kind of boilerplatey.  See Note [Whither the DeviceGuard
// boilerplate]

/// Device guard for the NPU backend: Guard::PrivateUse1Guard specialized
/// with impl::NPUGuardImpl.
///
/// Every constructor is inherited from the base class; refer to
/// PrivateUse1Guard for what each one does. This type only fixes the backend
/// implementation and forbids copying and moving.
struct NPUGuard : public Guard::PrivateUse1Guard<impl::NPUGuardImpl> {
  // Alias for the base class, also used to inherit all of its constructors.
  using PrivateUse1Guard = Guard::PrivateUse1Guard<impl::NPUGuardImpl>;
  using PrivateUse1Guard::PrivateUse1Guard;

  // Moving is disabled because the guard has no uninitialized state that a
  // moved-from object could be left in. Copying is disabled as well.
  NPUGuard(NPUGuard&&) = delete;
  NPUGuard& operator=(NPUGuard&&) = delete;
  NPUGuard(const NPUGuard&) = delete;
  NPUGuard& operator=(const NPUGuard&) = delete;
};

/// Optional device guard for the NPU backend:
/// Guard::OptionalPrivateUse1Guard specialized with impl::NPUGuardImpl.
///
/// Every constructor is inherited from the base class; refer to
/// OptionalPrivateUse1Guard for their exact semantics.
struct OptionalNPUGuard
    : public Guard::OptionalPrivateUse1Guard<impl::NPUGuardImpl> {
  // Alias for the base class, also used to pull in every base constructor.
  using OptionalPrivateUse1Guard =
      Guard::OptionalPrivateUse1Guard<impl::NPUGuardImpl>;
  using OptionalPrivateUse1Guard::OptionalPrivateUse1Guard;

  // Moving is disabled; see Note [Move construction for RAII guards is
  // tricky] and Note [Move assignment for RAII guards is tricky].
  OptionalNPUGuard(OptionalNPUGuard&&) = delete;
  OptionalNPUGuard& operator=(OptionalNPUGuard&&) = delete;

  // Copying is disabled too.
  OptionalNPUGuard(const OptionalNPUGuard&) = delete;
  OptionalNPUGuard& operator=(const OptionalNPUGuard&) = delete;
};

/// Stream guard for the NPU backend; a thin wrapper around
/// c10::impl::InlineStreamGuard<NPUGuardImpl>.
///
/// The stream accessors re-wrap the underlying c10::Stream values as
/// NPUStream using NPUStream::UNCHECKED, so this layer does not validate them.
struct NPUStreamGuard {
  // A stream is mandatory; the guard cannot be default-constructed.
  explicit NPUStreamGuard() = delete;

  /// Initializes the underlying InlineStreamGuard with `stream`.
  explicit NPUStreamGuard(c10::Stream stream) : guard_(stream) {}

  // Neither copyable nor movable.
  NPUStreamGuard(const NPUStreamGuard&) = delete;
  NPUStreamGuard(NPUStreamGuard&&) = delete;
  NPUStreamGuard& operator=(const NPUStreamGuard&) = delete;
  NPUStreamGuard& operator=(NPUStreamGuard&&) = delete;

  /// Forwards `stream` to InlineStreamGuard::reset_stream.
  void reset_stream(c10::Stream stream) {
    guard_.reset_stream(stream);
  }

  /// The underlying guard's original stream, wrapped as an NPUStream.
  NPUStream original_stream() const {
    const auto raw = guard_.original_stream();
    return NPUStream(NPUStream::UNCHECKED, raw);
  }

  /// The underlying guard's current stream, wrapped as an NPUStream.
  NPUStream current_stream() const {
    const auto raw = guard_.current_stream();
    return NPUStream(NPUStream::UNCHECKED, raw);
  }

  /// Device reported by InlineStreamGuard::current_device.
  c10::Device current_device() const {
    return guard_.current_device();
  }

  /// Device reported by InlineStreamGuard::original_device.
  c10::Device original_device() const {
    return guard_.original_device();
  }

 private:
  c10::impl::InlineStreamGuard<c10::backend::impl::NPUGuardImpl> guard_;
};

/// Optional stream guard for the NPU backend; a thin wrapper around
/// c10::impl::InlineOptionalStreamGuard<NPUGuardImpl>.
///
/// The stream accessors return c10::nullopt whenever the underlying guard
/// reports no stream; otherwise the stream is re-wrapped as an NPUStream
/// using NPUStream::UNCHECKED.
struct OptionalNPUStreamGuard {
  /// Default-constructs the underlying optional guard (no stream supplied).
  explicit OptionalNPUStreamGuard() : guard_() {}

  /// Initializes the underlying guard with `stream`.
  explicit OptionalNPUStreamGuard(c10::Stream stream) : guard_(stream) {}

  /// Initializes the underlying guard with `stream_opt`, which may be empty.
  explicit OptionalNPUStreamGuard(c10::optional<c10::Stream> stream_opt)
      : guard_(stream_opt) {}

  // Copying is disabled.
  OptionalNPUStreamGuard(const OptionalNPUStreamGuard&) = delete;
  OptionalNPUStreamGuard& operator=(const OptionalNPUStreamGuard&) = delete;

  // Moving is disabled; see Note [Move construction for RAII guards is
  // tricky] and Note [Move assignment for RAII guards is tricky].
  OptionalNPUStreamGuard(OptionalNPUStreamGuard&&) = delete;
  OptionalNPUStreamGuard& operator=(OptionalNPUStreamGuard&&) = delete;

  /// Forwards `stream` to InlineOptionalStreamGuard::reset_stream.
  void reset_stream(c10::Stream stream) {
    guard_.reset_stream(stream);
  }

  /// The underlying guard's original stream as an NPUStream, or c10::nullopt
  /// if it reports none.
  c10::optional<NPUStream> original_stream() const {
    return to_npu_stream(guard_.original_stream());
  }

  /// The underlying guard's current stream as an NPUStream, or c10::nullopt
  /// if it reports none.
  c10::optional<NPUStream> current_stream() const {
    return to_npu_stream(guard_.current_stream());
  }

  /// Forwards to InlineOptionalStreamGuard::reset.
  void reset() {
    guard_.reset();
  }

 private:
  // Re-wraps an optional generic stream as an optional NPUStream (UNCHECKED).
  // An empty input produces c10::nullopt.
  static c10::optional<NPUStream> to_npu_stream(
      const c10::optional<c10::Stream>& maybe_stream) {
    if (!maybe_stream.has_value()) {
      return c10::nullopt;
    }
    return c10::make_optional(
        NPUStream(NPUStream::UNCHECKED, maybe_stream.value()));
  }

  c10::impl::InlineOptionalStreamGuard<c10::backend::impl::NPUGuardImpl> guard_;
};

struct NPUMultiStreamGuard {
  explicit NPUMultiStreamGuard(at::ArrayRef<NPUStream> streams)
      : guard_(unwrapStreams(streams)) {}

  NPUMultiStreamGuard(const NPUMultiStreamGuard&) = delete;
  NPUMultiStreamGuard& operator=(const NPUMultiStreamGuard&) = delete;

  // See Note [Move construction for RAII guards is tricky]
  NPUMultiStreamGuard(NPUMultiStreamGuard&& other) = delete;

  // See Note [Move assignment for RAII guards is tricky]
  NPUMultiStreamGuard& operator=(NPUMultiStreamGuard&& other) = delete;

 private:
  c10::impl::InlineMultiStreamGuard<c10::backend::impl::NPUGuardImpl> guard_;

  static std::vector<c10::Stream> unwrapStreams(
      at::ArrayRef<NPUStream> NPUStreams) {
    std::vector<c10::Stream> streams;
    streams.reserve(NPUStreams.size());
    for (const NPUStream& NPUStream : NPUStreams) {
      streams.push_back(NPUStream);
    }
    return streams;
  }
};

} // namespace c10::backend

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources