17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
27 populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
28 patterns, ReductionKind::Sum);
29 populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
30 patterns, ReductionKind::Sum);
32 populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
33 patterns, ReductionKind::Min);
34 populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
35 patterns, ReductionKind::Min);
36 populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
37 patterns, ReductionKind::Min);
39 populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
40 patterns, ReductionKind::Max);
41 populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
42 patterns, ReductionKind::Max);
43 populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
44 patterns, ReductionKind::Max);
58 struct MeshShapeFolder
59 : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
62 LogicalResult matchAndRewrite(MeshShapeOp op,
66 op.getOperation(), op.getMeshAttr());
70 ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
71 SmallVector<MeshAxis> opAxesIota;
72 if (opMeshAxes.empty()) {
73 opAxesIota.resize(mesh.getRank());
74 std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
75 opMeshAxes = opAxesIota;
77 if (llvm::all_of(opMeshAxes, [&mesh](
MeshAxis axis) {
78 return ShapedType::isDynamic(mesh.getShape()[axis]);
84 SmallVector<Value> newResults(op->getResults().size());
85 SmallVector<MeshAxis> newShapeOpMeshAxes;
86 SmallVector<size_t> newToOldResultsIndexMap;
88 for (
size_t i = 0; i < opMeshAxes.size(); ++i) {
89 auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
90 if (ShapedType::isDynamic(meshAxisSize)) {
91 newToOldResultsIndexMap.push_back(i);
92 newShapeOpMeshAxes.push_back(opMeshAxes[i]);
95 newResults[i] = builder.create<arith::ConstantOp>(
96 builder.getIndexAttr(meshAxisSize));
101 if (!newShapeOpMeshAxes.empty()) {
102 MeshShapeOp newShapeOp =
103 builder.create<MeshShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
104 for (
size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
105 newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
118 patterns.
add<MeshShapeFolder>(symbolTableCollection, patterns.
getContext());
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements). The result types of the given op and the replacements must match. The original op is erased.
This class represents a collection of SymbolTables.
void populateSimplificationPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Include the generated interface declarations.
OpRewritePatternWithSymbolTableCollection(SymbolTableCollection &symbolTableCollection, OpRewritePatternArgs &&...opRewritePatternArgs)