18#define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
19#include "mlir/Conversion/Passes.h.inc"
22#define DEBUG_TYPE "arith-to-arm-sme"
32 if (llvm::isa<FloatType>(elemType))
34 if (llvm::isa<IntegerType>(elemType))
46struct ConstantOpToArmSMELowering :
public OpRewritePattern<arith::ConstantOp> {
49 LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
50 PatternRewriter &rewriter)
const final {
51 auto tileType = dyn_cast<VectorType>(constantOp.getType());
55 auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
56 if (!denseAttr || !denseAttr.isSplat())
59 auto tileElementType = tileType.getElementType();
63 rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
69 auto loc = constantOp.getLoc();
74 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
76 tileSliceType, denseAttr.getSplatValue<Attribute>());
77 auto constantOp1D = arith::ConstantOp::create(rewriter, loc, denseAttr1D);
79 auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
80 auto makeLoopBody = [&](OpBuilder &
b, Location loc, Value tileSliceIndex,
84 auto nextTile = arm_sme::InsertTileSliceOp::create(
85 b, loc, tileType, constantOp1D, currentTile, tileSliceIndex);
86 return nextTile.getResult();
89 rewriter, loc, initTile, makeLoopBody);
90 rewriter.replaceOp(constantOp, forOp.getResult(0));
104 patterns.
add<ConstantOpToArmSMELowering>(patterns.
getContext());
112struct ArithToArmSMEConversionPass final
115 ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
117 void runOnOperation()
override {
121 return signalPassFailure();
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e. if all element values are the same.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
void populateArithToArmSMEConversionPatterns(RewritePatternSet &patterns)
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and each iteration yields a new tile.
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.