38#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
42#define DEBUG_TYPE "enable-arm-streaming"
46#define GEN_PASS_DEF_ENABLEARMSTREAMING
47#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
55constexpr StringLiteral
56 kEnableArmStreamingIgnoreAttr(
"enable_arm_streaming_ignore");
58template <
typename... Ops>
59constexpr auto opList() {
63bool isScalableVector(
Type type) {
64 if (
auto vectorType = dyn_cast<VectorType>(type))
65 return vectorType.isScalable();
69struct EnableArmStreamingPass
71 EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
72 bool ifRequiredByOps,
bool ifScalableAndSupported) {
73 this->streamingMode = streamingMode;
78 void runOnOperation()
override {
79 auto function = getOperation();
82 function->emitOpError(
83 "enable-arm-streaming: `if-required-by-ops` and "
84 "`if-scalable-and-supported` are mutually exclusive");
88 if (ifRequiredByOps) {
89 bool foundTileOp =
false;
91 if (llvm::isa<ArmSMETileOpInterface>(op)) {
101 if (ifScalableAndSupported) {
106 auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
107 bool isCompatibleScalableFunction =
false;
109 if (llvm::is_contained(disallowedOperations,
111 isCompatibleScalableFunction =
false;
114 if (!isCompatibleScalableFunction &&
117 isCompatibleScalableFunction =
true;
121 if (!isCompatibleScalableFunction)
125 if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
126 streamingMode == ArmStreamingMode::Disabled)
131 function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
137 if (zaMode != ArmZaMode::Disabled)
138 function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
144 const ArmStreamingMode streamingMode,
const ArmZaMode zaMode,
145 bool ifRequiredByOps,
bool ifScalableAndSupported) {
146 return std::make_unique<EnableArmStreamingPass>(
147 streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported);
TypeID getTypeID() const
Return the unique identifier of the derived Op class, or null if not registered.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
void signalPassFailure()
Signal that some invariant was broken when running.
static TypeID get()
Construct a type info object for the given type T.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static WalkResult advance()
static WalkResult interrupt()
::mlir::Pass::Option< bool > ifScalableAndSupported
::mlir::Pass::Option< bool > ifRequiredByOps
::mlir::Pass::Option< mlir::arm_sme::ArmZaMode > zaMode
std::unique_ptr< Pass > createEnableArmStreamingPass(const ArmStreamingMode=ArmStreamingMode::Streaming, const ArmZaMode=ArmZaMode::Disabled, bool ifRequiredByOps=false, bool ifContainsScalableVectors=false)
Pass to enable Armv9 Streaming SVE mode.
Include the generated interface declarations.