38#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
42#define DEBUG_TYPE "enable-arm-streaming"
46#define GEN_PASS_DEF_ENABLEARMSTREAMING
47#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
55constexpr StringLiteral
56 kEnableArmStreamingIgnoreAttr(
"enable_arm_streaming_ignore");
58template <
typename... Ops>
59constexpr auto opList() {
63bool isScalableVector(
Type type) {
64 if (
auto vectorType = dyn_cast<VectorType>(type))
65 return vectorType.isScalable();
69struct EnableArmStreamingPass
70 :
public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
71 EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
72 bool ifRequiredByOps,
bool ifScalableAndSupported) {
73 this->streamingMode = streamingMode;
74 this->zaMode = zaMode;
75 this->ifRequiredByOps = ifRequiredByOps;
76 this->ifScalableAndSupported = ifScalableAndSupported;
78 void runOnOperation()
override {
79 auto function = getOperation();
81 if (ifRequiredByOps && ifScalableAndSupported) {
82 function->emitOpError(
83 "enable-arm-streaming: `if-required-by-ops` and "
84 "`if-scalable-and-supported` are mutually exclusive");
85 return signalPassFailure();
88 if (ifRequiredByOps) {
89 bool foundTileOp =
false;
90 function.walk([&](Operation *op) {
91 if (llvm::isa<ArmSMETileOpInterface>(op)) {
101 if (ifScalableAndSupported) {
106 auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
107 bool isCompatibleScalableFunction =
false;
108 function.walk([&](Operation *op) {
109 if (llvm::is_contained(disallowedOperations,
111 isCompatibleScalableFunction =
false;
114 if (!isCompatibleScalableFunction &&
117 isCompatibleScalableFunction =
true;
121 if (!isCompatibleScalableFunction)
125 if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
126 streamingMode == ArmStreamingMode::Disabled)
131 function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
137 if (zaMode != ArmZaMode::Disabled)
138 function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
144 const ArmStreamingMode streamingMode,
const ArmZaMode zaMode,
145 bool ifRequiredByOps,
bool ifScalableAndSupported) {
146 return std::make_unique<EnableArmStreamingPass>(
147 streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported);
TypeID getTypeID() const
Return the unique identifier of the derived Op class, or null if not registered.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
static TypeID get()
Construct a type info object for the given type T.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static WalkResult advance()
static WalkResult interrupt()
std::unique_ptr< Pass > createEnableArmStreamingPass(const ArmStreamingMode=ArmStreamingMode::Streaming, const ArmZaMode=ArmZaMode::Disabled, bool ifRequiredByOps=false, bool ifContainsScalableVectors=false)
Pass to enable Armv9 Streaming SVE mode.
Include the generated interface declarations.