18#define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
19#include "mlir/Conversion/Passes.h.inc"
22#define DEBUG_TYPE "arith-to-arm-sme"
32 if (llvm::isa<FloatType>(elemType))
34 if (llvm::isa<IntegerType>(elemType))
46struct ConstantOpToArmSMELowering :
public OpRewritePattern<arith::ConstantOp> {
49 LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
50 PatternRewriter &rewriter)
const final {
51 auto tileType = dyn_cast<VectorType>(constantOp.getType());
55 auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
56 if (!denseAttr || !denseAttr.isSplat())
59 auto tileElementType = tileType.getElementType();
63 rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
69 auto loc = constantOp.getLoc();
74 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
76 tileSliceType, denseAttr.getSplatValue<Attribute>());
77 auto constantOp1D = arith::ConstantOp::create(rewriter, loc, denseAttr1D);
79 auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
80 auto makeLoopBody = [&](OpBuilder &
b, Location loc, Value tileSliceIndex,
84 auto nextTile = arm_sme::InsertTileSliceOp::create(
85 b, loc, tileType, constantOp1D, currentTile, tileSliceIndex);
86 return nextTile.getResult();
89 rewriter, loc, initTile, makeLoopBody);
90 rewriter.replaceOp(constantOp, forOp.getResult(0));
112struct ArithToArmSMEConversionPass final
115 ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
117 void runOnOperation()
override {
121 return signalPassFailure();
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e. if all element values are the same.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
void populateArithToArmSMEConversionPatterns(RewritePatternSet &patterns)
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
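Generates a for loop over ZA tile slices where the induction variable is the tile slice index and each iteration yields a new tile. The loop body is built via `makeLoopBody`, which returns the next tile value.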
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy worklist-driven manner until a fixpoint is reached.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.