18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
28 populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
29 patterns, ReductionKind::Sum);
30 populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
31 patterns, ReductionKind::Sum);
33 populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
34 patterns, ReductionKind::Min);
35 populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
36 patterns, ReductionKind::Min);
37 populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
38 patterns, ReductionKind::Min);
40 populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
41 patterns, ReductionKind::Max);
42 populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
43 patterns, ReductionKind::Max);
44 populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
45 patterns, ReductionKind::Max);
59 struct MeshShapeFolder
60 : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
67 op.getOperation(), op.getMeshAttr());
71 ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
72 SmallVector<MeshAxis> opAxesIota;
73 if (opMeshAxes.empty()) {
74 opAxesIota.resize(mesh.getRank());
75 std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
76 opMeshAxes = opAxesIota;
78 if (llvm::all_of(opMeshAxes, [&mesh](
MeshAxis axis) {
79 return ShapedType::isDynamic(mesh.getShape()[axis]);
85 SmallVector<Value> newResults(op->getResults().size());
86 SmallVector<MeshAxis> newShapeOpMeshAxes;
87 SmallVector<size_t> newToOldResultsIndexMap;
89 for (
size_t i = 0; i < opMeshAxes.size(); ++i) {
90 auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
91 if (ShapedType::isDynamic(meshAxisSize)) {
92 newToOldResultsIndexMap.push_back(i);
93 newShapeOpMeshAxes.push_back(opMeshAxes[i]);
96 newResults[i] = builder.create<arith::ConstantOp>(
97 builder.getIndexAttr(meshAxisSize));
102 if (!newShapeOpMeshAxes.empty()) {
103 MeshShapeOp newShapeOp =
104 builder.create<MeshShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
105 for (
size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
106 newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
119 patterns.
add<MeshShapeFolder>(symbolTableCollection, patterns.
getContext());
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class represents a collection of SymbolTables.
void populateSimplificationPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
This class represents an efficient way to signal success or failure.
OpRewritePatternWithSymbolTableCollection(SymbolTableCollection &symbolTableCollection, OpRewritePatternArgs &&...opRewritePatternArgs)