26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallSet.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/Casting.h"
39 #define DEBUG_TYPE "mesh-ops"
40 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
45 #include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
49 struct DimensionSize {
50 static DimensionSize dynamic() {
return DimensionSize(ShapedType::kDynamic); }
51 DimensionSize(int64_t val) : val(val) {}
52 int64_t value()
const {
return val; }
53 operator int64_t()
const {
return val; }
54 bool isDynamic()
const {
return ShapedType::isDynamic(val); }
62 static DimensionSize
operator/(DimensionSize lhs, DimensionSize rhs) {
63 if (lhs.isDynamic() || rhs.isDynamic()) {
64 return DimensionSize::dynamic();
66 return lhs.value() / rhs.value();
69 static DimensionSize
operator*(DimensionSize lhs, DimensionSize rhs) {
70 if (lhs.isDynamic() || rhs.isDynamic()) {
71 return DimensionSize::dynamic();
73 return lhs.value() * rhs.value();
80 void MeshDialect::initialize() {
83 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
86 #define GET_ATTRDEF_LIST
87 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
93 return arith::ConstantOp::materialize(builder, value, type, loc);
105 return op->
emitError() <<
"Undefined required mesh symbol \""
/// Returns true iff the range [begin, end) contains no two equal adjacent
/// elements. Callers pass a *sorted* range (see verifyMeshAxes), which makes
/// this an overall uniqueness check: in sorted order duplicates are adjacent.
template <typename It>
bool isUnique(It begin, It end) {
  // Ranges of size 0 or 1 are trivially unique; also guards the
  // std::next(begin) below against an empty range.
  if (begin == end) {
    return true;
  }
  It next = std::next(begin);
  if (next == end) {
    return true;
  }
  for (; next != end; ++next, ++begin) {
    if (*begin == *next) {
      return false;
    }
  }
  return true;
}
133 if (!
isUnique(sorted.begin(), sorted.end())) {
134 return emitError(loc) <<
"Mesh axes contains duplicate elements.";
138 for (
auto axis : axes) {
139 if (axis >= rank || axis < 0) {
141 <<
"0-based mesh axis index " << axis
142 <<
" is out of bounds. The referenced mesh \"" << mesh.getSymName()
143 <<
"\" is of rank " << rank <<
".";
150 template <
typename InShape,
typename MeshShape,
typename SplitAxes,
152 static void shardShape(
const InShape &inShape,
const MeshShape &meshShape,
153 const SplitAxes &splitAxes, OutShape &outShape) {
154 std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
155 llvm::adl_begin(outShape));
165 using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
167 shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
169 return shape.clone(resShapeArr);
173 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
174 if (rankedTensorType) {
189 ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
190 if (shardOp && shardOp.getShard() == sharding &&
191 !shardOp.getAnnotateForUsers()) {
197 builder.
create<ShardOp>(operandValue.
getLoc(), operandValue, sharding,
200 rewriter.replaceUsesWithIf(
201 operandValue, newShardOp, [operandOp, operandValue](
OpOperand &use) {
202 return use.
getOwner() == operandOp && use.
get() == operandValue;
205 if (!shardOp || shardOp.getAnnotateForUsers()) {
209 auto newShardOp2 = builder.
create<ShardOp>(
210 operandValue.
getLoc(), newShardOp, sharding,
true);
211 rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
217 for (
auto &use : llvm::make_early_inc_range(result.
getUses())) {
229 bool isBlockArg = !operandSrcOp;
230 ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
232 if (shardOp && shardOp.getShard() == sharding &&
233 shardOp.getAnnotateForUsers()) {
240 builder.
create<ShardOp>(operandValue.
getLoc(), operandValue, sharding,
243 rewriter.replaceUsesWithIf(
244 operandValue, newShardOp, [operandOp, operandValue](
OpOperand &use) {
245 return use.
getOwner() == operandOp && use.
get() == operandValue;
248 if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
254 auto newPreceedingShardOp =
255 builder.
create<ShardOp>(operandValue.
getLoc(), operandValue, sharding,
260 newShardOp.getOperation();
269 int64_t rank = getRank();
272 return emitOpError(
"rank of mesh is expected to be a positive integer");
274 for (int64_t dimSize :
getShape()) {
275 if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
276 return emitOpError(
"dimension size of a mesh is expected to be "
277 "non-negative or dynamic");
297 size_t expectedResultsCount =
298 getAxes().empty() ? mesh->getRank() : getAxes().size();
299 if (getResult().size() != expectedResultsCount) {
300 return emitError() <<
"Unexpected number of results " << getResult().size()
301 <<
". Expected " << expectedResultsCount <<
".";
314 build(odsBuilder, odsState,
322 assert(!axes.empty());
323 build(odsBuilder, odsState,
328 void MeshShapeOp::getAsmResultNames(
330 setNameFn(getResults()[0],
"mesh_shape");
344 llvm::SmallSet<MeshAxis, 4> visitedAxes;
349 return emitError() <<
"mesh axis is expected to be non-negative";
350 if (!visitedAxes.insert(axis).second)
351 return emitError() <<
"mesh axis duplicated";
358 if (failed(checkMeshAxis(subAxesArray)))
361 if (failed(checkMeshAxis(partialAxes)))
368 mlir::dyn_cast<MeshShardingAttr>(rhs);
369 return rhsAsMeshShardingAttr && *
this == rhsAsMeshShardingAttr;
373 return !(*
this == rhs);
377 if (
getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
381 if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
385 auto minSize =
std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
386 if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
387 getSplitAxes().begin() + minSize),
388 llvm::make_range(rhs.getSplitAxes().begin(),
389 rhs.getSplitAxes().begin() + minSize))) {
393 return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
394 getSplitAxes().end()),
395 std::mem_fn(&MeshAxesAttr::empty)) &&
396 llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
397 rhs.getSplitAxes().end()),
398 std::mem_fn(&MeshAxesAttr::empty));
402 return !(*
this == rhs);
409 void ShardOp::getAsmResultNames(
411 setNameFn(getResult(),
"sharding_annotated");
428 size_t expectedResultsCount =
429 getAxes().empty() ? mesh->getRank() : getAxes().size();
430 if (getResult().size() != expectedResultsCount) {
431 return emitError() <<
"Unexpected number of results " << getResult().size()
432 <<
". Expected " << expectedResultsCount <<
".";
440 build(odsBuilder, odsState,
447 build(odsBuilder, odsState,
452 void ProcessMultiIndexOp::getAsmResultNames(
454 setNameFn(getResults()[0],
"proc_linear_idx");
470 void ProcessLinearIndexOp::build(
OpBuilder &odsBuilder,
472 build(odsBuilder, odsState, mesh.getSymName());
475 void ProcessLinearIndexOp::getAsmResultNames(
477 setNameFn(getResult(),
"proc_linear_idx");
486 template <
typename Op>
489 LogicalResult matchAndRewrite(
Op op,
491 auto meshAxes = op.getMeshAxes();
492 if (!meshAxes.empty()) {
500 rewriter.
eraseOp(op.getOperation());
512 if (device.size() != meshAxes.size()) {
513 return emitError(loc) <<
"In-group device \"" << deviceName
514 <<
"\" has unexpected multi-index size "
515 << device.size() <<
". Expected " << meshAxes.size()
519 for (
size_t i = 0; i < device.size(); ++i) {
520 if (!ShapedType::isDynamic(device[i]) &&
521 !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
522 meshShape[meshAxes[i]] <= device[i]) {
524 <<
"Out of bounds coordinate " << i <<
" for in-group device \""
525 << deviceName <<
"\"."
526 <<
" Got " << device[i] <<
", but expected value in the range [0, "
527 << (meshShape[meshAxes[i]] - 1) <<
"].";
533 template <
typename Op>
534 static FailureOr<MeshOp>
/// Multiplies all elements of [begin, end). The accumulator is the
/// (decayed) element type, and an empty range yields its multiplicative
/// identity, 1.
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  return std::accumulate(begin, end, static_cast<ElementType>(1),
                         std::multiplies<ElementType>());
}
/// Range overload of product: forwards to the iterator-pair overload using
/// ADL-aware begin/end so it works with both standard containers and LLVM
/// ranges.
template <typename R>
static auto product(R &&range) {
  return product(adl_begin(range), adl_end(range));
}
560 int64_t expectedDimSize,
561 int64_t resultDimSize,
562 int64_t resultAxis) {
563 if (!ShapedType::isDynamic(resultDimSize) &&
564 expectedDimSize != resultDimSize) {
565 return emitError(loc) <<
"Dimension size mismatch for result axis "
566 << resultAxis <<
". Expected "
567 << (ShapedType::isDynamic(expectedDimSize)
569 : Twine(expectedDimSize))
570 <<
", but got " << resultDimSize <<
".";
577 Value operand,
Value result, int64_t gatherAxis,
579 auto resultRank = cast<ShapedType>(result.
getType()).getRank();
580 if (gatherAxis < 0 || gatherAxis >= resultRank) {
582 <<
"Gather axis " << gatherAxis <<
" is out of bounds [0, "
583 << resultRank <<
").";
586 ShapedType operandType = cast<ShapedType>(operand.
getType());
587 ShapedType resultType = cast<ShapedType>(result.
getType());
588 auto deviceGroupSize =
590 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
591 auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
592 auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
593 auto expectedResultDimSize =
594 axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
596 result.
getLoc(), expectedResultDimSize, resultDimSize, axis))) {
604 Value operand,
Value result, int64_t splitAxis, int64_t concatAxis,
606 ShapedType operandType = cast<ShapedType>(operand.
getType());
607 ShapedType resultType = cast<ShapedType>(result.
getType());
608 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
609 if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
611 result.
getLoc(), operandType.getDimSize(axis),
612 resultType.getDimSize(axis), axis))) {
618 if (splitAxis == concatAxis) {
622 auto deviceGroupSize =
624 auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
625 auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
626 DimensionSize expectedResultConcatDimSize =
627 operandConcatDimSize * deviceGroupSize;
628 DimensionSize expectedResultSplitDimSize =
629 operandSplitDimSize / deviceGroupSize;
630 if (!expectedResultSplitDimSize.isDynamic() &&
631 int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
632 expectedResultSplitDimSize = DimensionSize::dynamic();
635 result.
getLoc(), expectedResultConcatDimSize.value(),
636 resultType.getDimSize(concatAxis), concatAxis))) {
640 result.
getLoc(), expectedResultSplitDimSize.value(),
641 resultType.getDimSize(splitAxis), splitAxis))) {
649 Value operand,
Value result, int64_t tensorAxis,
651 ShapedType operandType = cast<ShapedType>(operand.
getType());
652 ShapedType resultType = cast<ShapedType>(result.
getType());
653 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
654 if (axis != tensorAxis) {
656 result.
getLoc(), operandType.getDimSize(axis),
657 resultType.getDimSize(axis), axis))) {
663 auto deviceGroupSize =
665 auto operandScatterDimSize =
666 DimensionSize(operandType.getDimSize(tensorAxis));
667 if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
668 int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
670 <<
"Operand dimension size " << int64_t(operandScatterDimSize)
671 <<
" is not divisible by collective device group size "
672 << int64_t(deviceGroupSize) <<
" for tensor axis " << tensorAxis
675 DimensionSize expectedResultTensorDimSize =
676 operandScatterDimSize / deviceGroupSize;
678 result.
getLoc(), expectedResultTensorDimSize.value(),
679 resultType.getDimSize(tensorAxis), tensorAxis))) {
689 RankedTensorType operandRankedTensorType =
690 cast<RankedTensorType>(operandType);
691 DimensionSize operandSliceAxisSize =
692 operandRankedTensorType.getShape()[sliceAxis];
694 llvm::to_vector(operandRankedTensorType.getShape());
696 resultShape[sliceAxis] =
697 operandSliceAxisSize /
699 return operandRankedTensorType.clone(resultShape);
712 auto gatherAxis = getGatherAxis().getSExtValue();
714 gatherAxis, getMeshAxes(),
715 mesh.value().getShape());
720 patterns.
add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
723 void AllGatherOp::getAsmResultNames(
725 setNameFn(getResult(),
"all_gather");
739 patterns.
add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
743 Value input, StringRef mesh,
745 build(odsBuilder, odsState, input.
getType(), mesh, meshAxes, input,
749 void AllReduceOp::getAsmResultNames(
751 setNameFn(getResult(),
"all_reduce");
764 getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
765 mesh.value().getShape());
770 patterns.
add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
777 build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
782 Type resultType,
Value input, StringRef mesh,
784 build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
785 APInt(
sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
788 void AllSliceOp::getAsmResultNames(
790 setNameFn(getResult(),
"all_slice");
804 getOperand(), getResult(), getSplitAxis().getSExtValue(),
805 getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
810 patterns.
add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
813 void AllToAllOp::getAsmResultNames(
815 setNameFn(getResult(),
"all_to_all");
829 getRootDynamic(), getMeshAxes(),
830 mesh.value().getShape()))) {
839 patterns.
add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
842 void BroadcastOp::getAsmResultNames(
844 setNameFn(getResult(),
"broadcast");
857 getRootDynamic(), getMeshAxes(),
858 mesh.value().getShape()))) {
862 auto gatherAxis = getGatherAxis().getSExtValue();
865 mesh.value().getShape());
870 patterns.
add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
873 void GatherOp::getAsmResultNames(
875 setNameFn(getResult(),
"gather");
889 getSource().value(), getSourceDynamic(),
890 getMeshAxes(), mesh.value().getShape()))) {
898 patterns.
add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
902 setNameFn(getResult(),
"recv");
915 getRootDynamic(), getMeshAxes(),
916 mesh.value().getShape()))) {
925 patterns.
add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
928 void ReduceOp::getAsmResultNames(
930 setNameFn(getResult(),
"reduce");
945 getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
946 mesh.value().getShape());
951 patterns.
add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
954 void ReduceScatterOp::getAsmResultNames(
956 setNameFn(getResult(),
"reduce_scatter");
969 getRootDynamic(), getMeshAxes(),
970 mesh.value().getShape()))) {
974 auto scatterAxis = getScatterAxis().getSExtValue();
976 scatterAxis, getMeshAxes(),
977 mesh.value().getShape());
982 patterns.
add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
985 void ScatterOp::getAsmResultNames(
987 setNameFn(getResult(),
"scatter");
1000 getDestination(), getDestinationDynamic(),
1001 getMeshAxes(), mesh.value().getShape()))) {
1009 patterns.
add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
1012 void SendOp::getAsmResultNames(
function_ref<
void(
Value, StringRef)> setNameFn) {
1013 setNameFn(getResult(),
"send");
1026 auto meshAxes = getMeshAxes();
1027 auto shiftAxis = getShiftAxis().getZExtValue();
1028 if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
1029 return emitError() <<
"Invalid shift axis " << shiftAxis
1030 <<
". It must be one of the grouping mesh axes.";
1042 void ShiftOp::getAsmResultNames(
1044 setNameFn(getResult(),
"shift");
1051 #define GET_OP_CLASSES
1052 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
1054 #define GET_ATTRDEF_CLASSES
1055 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
1057 #include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef< MeshAxis > meshAxes, int64_t sliceAxis)
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs)
static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, int64_t resultAxis)
static FailureOr< MeshOp > getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable)
static FailureOr< MeshOp > getMeshAndVerify(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTable)
static LogicalResult verifyScatterOrSliceOperandAndResultShape(Value operand, Value result, int64_t tensorAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
static LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result, int64_t gatherAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
static LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
static auto product(It begin, It end)
static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape)
bool isUnique(It begin, It end)
static LogicalResult verifyMeshAxes(Location loc, ArrayRef< MeshAxis > axes, MeshOp mesh)
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef< int64_t > device, Operation::operand_range deviceDynamic, ArrayRef< MeshAxis > meshAxes, ArrayRef< int64_t > meshShape)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
MLIRContext * getContext() const
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
void replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this operation with the provided values if the given callback returns true...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
mesh::ReductionKind ReductionKind
mesh::MeshShardingAttr MeshShardingAttr
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding, OpOperand &operand, OpBuilder &builder)
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshShardingAttr sharding)
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding)
void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding, OpOperand &operand, OpBuilder &builder)
bool operator==(const Fraction &x, const Fraction &y)
bool operator!=(const Fraction &x, const Fraction &y)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineExpr operator*(int64_t val, AffineExpr expr)
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.