Program Listing for File NPUGuard.h
Return to documentation for file (csrc/backend/NPUGuard.h)
#pragma once
#include <c10/core/DeviceType.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/InlineStreamGuard.h>
#include "csrc/backend/NPUGuardImpl.h"
#include "csrc/core/guard/PrivateUse1Guard.h"
#include <cstddef>
#include <vector>
namespace c10::backend {
// NPU device guard.
//
// Every constructor and all behaviour come unchanged from
// Guard::PrivateUse1Guard<impl::NPUGuardImpl>; this type only gives it an
// NPU-specific name and forbids copying/moving. It is intentionally
// boilerplate — see Note [Whither the DeviceGuard boilerplate].
struct NPUGuard : public Guard::PrivateUse1Guard<impl::NPUGuardImpl> {
  using PrivateUse1Guard = Guard::PrivateUse1Guard<impl::NPUGuardImpl>;
  // Re-expose the base-class constructors on NPUGuard.
  using PrivateUse1Guard::PrivateUse1Guard;

  // Moving is rejected because the guard has no uninitialized state a
  // moved-from object could be left in; copying is rejected as well.
  NPUGuard(NPUGuard&&) = delete;
  NPUGuard(const NPUGuard&) = delete;
  NPUGuard& operator=(NPUGuard&&) = delete;
  NPUGuard& operator=(const NPUGuard&) = delete;
};
// Optional counterpart of NPUGuard.
//
// Constructors and behaviour are inherited unchanged from
// Guard::OptionalPrivateUse1Guard<impl::NPUGuardImpl>; copying and moving are
// disabled.
struct OptionalNPUGuard
    : public Guard::OptionalPrivateUse1Guard<impl::NPUGuardImpl> {
  using OptionalPrivateUse1Guard =
      Guard::OptionalPrivateUse1Guard<impl::NPUGuardImpl>;
  // Re-expose the base-class constructors on OptionalNPUGuard.
  using OptionalPrivateUse1Guard::OptionalPrivateUse1Guard;

  // Moves are deleted on purpose. See
  // Note [Move construction for RAII guards is tricky] and
  // Note [Move assignment for RAII guards is tricky].
  OptionalNPUGuard(OptionalNPUGuard&&) = delete;
  OptionalNPUGuard& operator=(OptionalNPUGuard&&) = delete;

  // Copies are not allowed either.
  OptionalNPUGuard(const OptionalNPUGuard&) = delete;
  OptionalNPUGuard& operator=(const OptionalNPUGuard&) = delete;
};
// Guard for a single NPU stream, built on
// c10::impl::InlineStreamGuard<c10::backend::impl::NPUGuardImpl>.
//
// A stream is mandatory at construction (the default constructor is deleted),
// and the guard can be neither copied nor moved.
struct NPUStreamGuard {
  explicit NPUStreamGuard() = delete;

  // Hands `stream` straight to the underlying InlineStreamGuard.
  explicit NPUStreamGuard(c10::Stream stream) : guard_(stream) {}

  // Copy and move are both disallowed.
  NPUStreamGuard(const NPUStreamGuard&) = delete;
  NPUStreamGuard(NPUStreamGuard&&) = delete;
  NPUStreamGuard& operator=(const NPUStreamGuard&) = delete;
  NPUStreamGuard& operator=(NPUStreamGuard&&) = delete;

  // Forwards to InlineStreamGuard::reset_stream.
  void reset_stream(c10::Stream stream) {
    guard_.reset_stream(stream);
  }

  // Underlying guard's original stream, re-wrapped as an NPUStream via the
  // NPUStream::UNCHECKED constructor tag.
  NPUStream original_stream() const {
    return NPUStream(NPUStream::UNCHECKED, guard_.original_stream());
  }

  // Underlying guard's current stream, re-wrapped as an NPUStream via the
  // NPUStream::UNCHECKED constructor tag.
  NPUStream current_stream() const {
    return NPUStream(NPUStream::UNCHECKED, guard_.current_stream());
  }

  // Forwards to InlineStreamGuard::original_device.
  c10::Device original_device() const {
    return guard_.original_device();
  }

  // Forwards to InlineStreamGuard::current_device.
  c10::Device current_device() const {
    return guard_.current_device();
  }

 private:
  c10::impl::InlineStreamGuard<c10::backend::impl::NPUGuardImpl> guard_;
};
// Optional NPU stream guard, built on
// c10::impl::InlineOptionalStreamGuard<c10::backend::impl::NPUGuardImpl>.
//
// Can be created with no stream, with a stream, or with an optional stream;
// the stream accessors therefore return c10::optional<NPUStream>.
struct OptionalNPUStreamGuard {
  // Default-constructs the underlying optional guard.
  explicit OptionalNPUStreamGuard() : guard_() {}
  explicit OptionalNPUStreamGuard(c10::Stream stream) : guard_(stream) {}
  explicit OptionalNPUStreamGuard(c10::optional<c10::Stream> stream_opt)
      : guard_(stream_opt) {}

  OptionalNPUStreamGuard(const OptionalNPUStreamGuard&) = delete;
  OptionalNPUStreamGuard& operator=(const OptionalNPUStreamGuard&) = delete;
  // See Note [Move construction for RAII guards is tricky]
  OptionalNPUStreamGuard(OptionalNPUStreamGuard&& other) = delete;
  // See Note [Move assignment for RAII guards is tricky]
  OptionalNPUStreamGuard& operator=(OptionalNPUStreamGuard&& other) = delete;

  // Forwards to InlineOptionalStreamGuard::reset_stream.
  void reset_stream(c10::Stream stream) {
    guard_.reset_stream(stream);
  }

  // Underlying guard's original stream as an NPUStream, or nullopt when the
  // underlying guard reports none.
  c10::optional<NPUStream> original_stream() const {
    return toNPUStream(guard_.original_stream());
  }

  // Underlying guard's current stream as an NPUStream, or nullopt when the
  // underlying guard reports none.
  c10::optional<NPUStream> current_stream() const {
    return toNPUStream(guard_.current_stream());
  }

  // Forwards to InlineOptionalStreamGuard::reset.
  void reset() {
    guard_.reset();
  }

 private:
  // Shared by original_stream()/current_stream(): wraps a present
  // c10::Stream with the NPUStream::UNCHECKED tag and passes nullopt through.
  static c10::optional<NPUStream> toNPUStream(
      c10::optional<c10::Stream> maybe_stream) {
    if (!maybe_stream.has_value()) {
      return c10::nullopt;
    }
    return NPUStream(NPUStream::UNCHECKED, *maybe_stream);
  }

  c10::impl::InlineOptionalStreamGuard<c10::backend::impl::NPUGuardImpl> guard_;
};
struct NPUMultiStreamGuard {
explicit NPUMultiStreamGuard(at::ArrayRef<NPUStream> streams)
: guard_(unwrapStreams(streams)) {}
NPUMultiStreamGuard(const NPUMultiStreamGuard&) = delete;
NPUMultiStreamGuard& operator=(const NPUMultiStreamGuard&) = delete;
// See Note [Move construction for RAII guards is tricky]
NPUMultiStreamGuard(NPUMultiStreamGuard&& other) = delete;
// See Note [Move assignment for RAII guards is tricky]
NPUMultiStreamGuard& operator=(NPUMultiStreamGuard&& other) = delete;
private:
c10::impl::InlineMultiStreamGuard<c10::backend::impl::NPUGuardImpl> guard_;
static std::vector<c10::Stream> unwrapStreams(
at::ArrayRef<NPUStream> NPUStreams) {
std::vector<c10::Stream> streams;
streams.reserve(NPUStreams.size());
for (const NPUStream& NPUStream : NPUStreams) {
streams.push_back(NPUStream);
}
return streams;
}
};
} // namespace c10::backend