38 #include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
42 #define DEBUG_TYPE "enable-arm-streaming"
46 #define GEN_PASS_DEF_ENABLEARMSTREAMING
47 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
55 constexpr StringLiteral
56 kEnableArmStreamingIgnoreAttr(
"enable_arm_streaming_ignore");
58 template <
typename... Ops>
59 constexpr
auto opList() {
60 return std::array{TypeID::get<Ops>()...};
63 bool isScalableVector(
Type type) {
64 if (
auto vectorType = dyn_cast<VectorType>(type))
65 return vectorType.isScalable();
// Pass that tags the function it runs on with Arm SME attributes. It sets
// `unitAttr` under the name of the selected ArmStreamingMode and, unless
// zaMode is ArmZaMode::Disabled, under the name of the selected ArmZaMode.
// The `ifRequiredByOps` / `ifScalableAndSupported` options restrict which
// functions get tagged.
// NOTE(review): this listing elides many lines (walk lambdas, early returns,
// closing braces); comments below hedge where behaviour depends on them.
69 struct EnableArmStreamingPass
70 :
public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
// Stores the four option values into the pass's option members. No
// validation happens here; conflicting options are rejected in
// runOnOperation.
71 EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
72 bool ifRequiredByOps,
bool ifScalableAndSupported) {
73 this->streamingMode = streamingMode;
74 this->zaMode = zaMode;
75 this->ifRequiredByOps = ifRequiredByOps;
76 this->ifScalableAndSupported = ifScalableAndSupported;
// Decides whether the function should be tagged and, if so, attaches the
// streaming-mode attribute and (optionally) the ZA-mode attribute.
78 void runOnOperation()
override {
79 auto function = getOperation();
// The two filtering options are mutually exclusive: emit an error on the
// function and fail the pass.
81 if (ifRequiredByOps && ifScalableAndSupported) {
82 function->emitOpError(
83 "enable-arm-streaming: `if-required-by-ops` and "
84 "`if-scalable-and-supported` are mutually exclusive");
85 return signalPassFailure();
// Only tag functions containing an op that implements ArmSMETileOpInterface.
// NOTE(review): the walk that binds `op`, the update of `foundTileOp`, and
// the (presumed) early return when no such op is found are elided here.
88 if (ifRequiredByOps) {
89 bool foundTileOp =
false;
91 if (llvm::isa<ArmSMETileOpInterface>(op)) {
// Only tag functions judged "compatible scalable" by the walk below.
101 if (ifScalableAndSupported) {
// vector.gather and vector.scatter are disallowed. Finding one marks the
// function incompatible and interrupts the walk. NOTE(review): the second
// argument to is_contained is elided; presumably it is the op's TypeID.
106 auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
107 bool isCompatibleScalableFunction =
false;
109 if (llvm::is_contained(disallowedOperations,
111 isCompatibleScalableFunction = false;
112 return WalkResult::interrupt();
// NOTE(review): the rest of this condition is elided; presumably it tests
// the op's operand/result types with isScalableVector — confirm.
114 if (!isCompatibleScalableFunction &&
117 isCompatibleScalableFunction = true;
// NOTE(review): the body of this check is elided; presumably the pass
// returns without tagging an incompatible function.
121 if (!isCompatibleScalableFunction)
// Functions carrying kEnableArmStreamingIgnoreAttr, or a Disabled streaming
// mode, are handled by the elided branch body (presumably an early return).
125 if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
126 streamingMode == ArmStreamingMode::Disabled)
// Attach `unitAttr` under the name of the selected streaming mode.
131 function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
// Attach a ZA-mode attribute only when ZA mode is not Disabled.
137 if (zaMode != ArmZaMode::Disabled)
138 function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
144 const ArmStreamingMode streamingMode,
const ArmZaMode zaMode,
145 bool ifRequiredByOps,
bool ifScalableAndSupported) {
// Factory for the pass: forwards all four options unchanged to the
// EnableArmStreamingPass constructor. The mutual-exclusion check on
// ifRequiredByOps / ifScalableAndSupported happens later, in runOnOperation.
// NOTE(review): the declaration's signature names the last parameter
// `ifContainsScalableVectors`, while this definition calls it
// `ifScalableAndSupported`; consider aligning the names.
146 return std::make_unique<EnableArmStreamingPass>(
147 streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported);
static MLIRContext * getContext(OpFoldResult val)
TypeID getTypeID() const
Return the unique identifier of the derived Op class, or null if not registered.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
static WalkResult advance()
static WalkResult interrupt()
std::unique_ptr< Pass > createEnableArmStreamingPass(const ArmStreamingMode=ArmStreamingMode::Streaming, const ArmZaMode=ArmZaMode::Disabled, bool ifRequiredByOps=false, bool ifContainsScalableVectors=false)
Pass to enable Armv9 Streaming SVE mode.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.