38#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
42#define DEBUG_TYPE "enable-arm-streaming"
46#define GEN_PASS_DEF_ENABLEARMSTREAMING
47#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
55constexpr StringLiteral
56 kEnableArmStreamingIgnoreAttr(
"llvm.enable_arm_streaming_ignore");
58template <
typename... Ops>
59constexpr auto opList() {
63bool isScalableVector(
Type type) {
64 if (
auto vectorType = dyn_cast<VectorType>(type))
65 return vectorType.isScalable();
69struct EnableArmStreamingPass
70 :
public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
71 EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
72 bool ifRequiredByOps,
bool ifScalableAndSupported) {
73 this->streamingMode = streamingMode;
74 this->zaMode = zaMode;
75 this->ifRequiredByOps = ifRequiredByOps;
76 this->ifScalableAndSupported = ifScalableAndSupported;
78 void runOnOperation()
override {
79 auto function = getOperation();
81 if (ifRequiredByOps && ifScalableAndSupported) {
82 function->emitOpError(
83 "enable-arm-streaming: `if-required-by-ops` and "
84 "`if-scalable-and-supported` are mutually exclusive");
85 return signalPassFailure();
88 if (ifRequiredByOps) {
89 bool foundTileOp =
false;
90 function.walk([&](Operation *op) {
91 if (llvm::isa<ArmSMETileOpInterface>(op)) {
101 if (ifScalableAndSupported) {
106 auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
107 bool isCompatibleScalableFunction =
false;
108 function.walk([&](Operation *op) {
109 if (llvm::is_contained(disallowedOperations,
111 isCompatibleScalableFunction =
false;
114 if (!isCompatibleScalableFunction &&
117 isCompatibleScalableFunction =
true;
121 if (!isCompatibleScalableFunction)
125 if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
126 streamingMode == ArmStreamingMode::Disabled)
131 function->setDiscardableAttr(
132 (Twine(
"llvm.") + stringifyArmStreamingMode(streamingMode)).str(),
139 if (zaMode != ArmZaMode::Disabled)
140 function->setAttr((Twine(
"llvm.") + stringifyArmZaMode(zaMode)).str(),
147 const ArmStreamingMode streamingMode,
const ArmZaMode zaMode,
148 bool ifRequiredByOps,
bool ifScalableAndSupported) {
149 return std::make_unique<EnableArmStreamingPass>(
150 streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported);
TypeID getTypeID() const
Return the unique identifier of the derived Op class, or null if not registered.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
static TypeID get()
Construct a type info object for the given type T.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
static WalkResult advance()
static WalkResult interrupt()
std::unique_ptr< Pass > createEnableArmStreamingPass(const ArmStreamingMode=ArmStreamingMode::Streaming, const ArmZaMode=ArmZaMode::Disabled, bool ifRequiredByOps=false, bool ifContainsScalableVectors=false)
Pass to enable Armv9 Streaming SVE mode.
Include the generated interface declarations.