38#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
42#define DEBUG_TYPE "enable-arm-streaming"
46#define GEN_PASS_DEF_ENABLEARMSTREAMING
47#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
55constexpr StringLiteral
56 kEnableArmStreamingIgnoreAttr(
"llvm.enable_arm_streaming_ignore");
58template <
typename... Ops>
59constexpr auto opList() {
63bool isScalableVector(
Type type) {
64 if (
auto vectorType = dyn_cast<VectorType>(type))
65 return vectorType.isScalable();
69struct EnableArmStreamingPass
71 EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
72 bool ifRequiredByOps,
bool ifScalableAndSupported) {
73 this->streamingMode = streamingMode;
78 void runOnOperation()
override {
79 auto function = getOperation();
82 function->emitOpError(
83 "enable-arm-streaming: `if-required-by-ops` and "
84 "`if-scalable-and-supported` are mutually exclusive");
88 if (ifRequiredByOps) {
89 bool foundTileOp =
false;
91 if (llvm::isa<ArmSMETileOpInterface>(op)) {
101 if (ifScalableAndSupported) {
106 auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
107 bool isCompatibleScalableFunction =
false;
109 if (llvm::is_contained(disallowedOperations,
111 isCompatibleScalableFunction =
false;
114 if (!isCompatibleScalableFunction &&
117 isCompatibleScalableFunction =
true;
121 if (!isCompatibleScalableFunction)
125 if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
126 streamingMode == ArmStreamingMode::Disabled)
131 function->setDiscardableAttr(
132 (Twine(
"llvm.") + stringifyArmStreamingMode(streamingMode)).str(),
139 if (zaMode != ArmZaMode::Disabled)
140 function->setAttr((Twine(
"llvm.") + stringifyArmZaMode(zaMode)).str(),
147 const ArmStreamingMode streamingMode,
const ArmZaMode zaMode,
148 bool ifRequiredByOps,
bool ifScalableAndSupported) {
149 return std::make_unique<EnableArmStreamingPass>(
150 streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported);
TypeID getTypeID() const
Return the unique identifier of the derived Op class, or null if not registered.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
void signalPassFailure()
Signal that some invariant was broken when running.
static TypeID get()
Construct a type info object for the given type T.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static WalkResult advance()
static WalkResult interrupt()
::mlir::Pass::Option< bool > ifScalableAndSupported
::mlir::Pass::Option< bool > ifRequiredByOps
::mlir::Pass::Option< mlir::arm_sme::ArmZaMode > zaMode
std::unique_ptr< Pass > createEnableArmStreamingPass(const ArmStreamingMode=ArmStreamingMode::Streaming, const ArmZaMode=ArmZaMode::Disabled, bool ifRequiredByOps=false, bool ifContainsScalableVectors=false)
Pass to enable Armv9 Streaming SVE mode.
Include the generated interface declarations.