18 #define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
19 #include "mlir/Conversion/Passes.h.inc"
22 #define DEBUG_TYPE "arith-to-arm-sme"
32 if (llvm::isa<FloatType>(elemType))
34 if (llvm::isa<IntegerType>(elemType))
46 struct ConstantOpToArmSMELowering :
public OpRewritePattern<arith::ConstantOp> {
49 LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
51 auto tileType = dyn_cast<VectorType>(constantOp.getType());
55 auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
56 if (!denseAttr || !denseAttr.isSplat())
59 auto tileElementType = tileType.getElementType();
63 rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
69 auto loc = constantOp.getLoc();
76 tileSliceType, denseAttr.getSplatValue<
Attribute>());
77 auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
79 auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
84 auto nextTile = b.
create<arm_sme::InsertTileSliceOp>(
85 loc, tileType, constantOp1D, currentTile, tileSliceIndex);
89 rewriter, loc, initTile, makeLoopBody);
90 rewriter.replaceOp(constantOp, forOp.getResult(0));
104 patterns.
add<ConstantOpToArmSMELowering>(patterns.
getContext());
112 struct ArithToArmSMEConversionPass final
113 : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> {
114 using impl::ArithToArmSMEConversionPassBase<
115 ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
117 void runOnOperation()
override {
122 return signalPassFailure();
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
void populateArithToArmSMEConversionPatterns(RewritePatternSet &patterns)
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and eac...
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Include the generated interface declarations.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...