MLIR 22.0.0git
RuntimeOpVerification.cpp
Go to the documentation of this file.
1//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
23
24namespace mlir {
25namespace linalg {
26namespace {
27/// Verify that the runtime sizes of the operands to linalg structured ops are
28/// compatible with the runtime sizes inferred by composing the loop ranges with
29/// the linalg op's indexing maps. This is similar to the verifier except that
30/// here we insert IR to perform the verification at runtime.
31template <typename T>
32struct StructuredOpInterface
33 : public RuntimeVerifiableOpInterface::ExternalModel<
34 StructuredOpInterface<T>, T> {
35 void
36 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
37 function_ref<std::string(Operation *, StringRef)>
38 generateErrorMessage) const {
39 auto linalgOp = llvm::cast<LinalgOp>(op);
40
41 SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
42 auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges);
43
44 auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
45 auto one = arith::ConstantIndexOp::create(builder, loc, 1);
46
47 Value iterationDomainIsNonDegenerate;
48 for (auto [start, end] : llvm::zip(starts, ends)) {
49 auto startValue = getValueOrCreateConstantIndexOp(builder, loc, start);
50 auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
51
52 // Loop Trip count > 0 iff start < end
53 Value dimensionHasNonZeroTripCount = index::CmpOp::create(
54 builder, loc, index::IndexCmpPredicate::SLT, startValue, endValue);
55
56 if (!iterationDomainIsNonDegenerate) {
57 iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount;
58 } else {
59 // Iteration domain is non-degenerate iff all dimensions have loop trip
60 // count > 0
61 iterationDomainIsNonDegenerate =
62 arith::AndIOp::create(builder, loc, iterationDomainIsNonDegenerate,
63 dimensionHasNonZeroTripCount);
64 }
65 }
66
67 if (!iterationDomainIsNonDegenerate)
68 return;
69
70 auto ifOp = scf::IfOp::create(builder, loc, iterationDomainIsNonDegenerate,
71 /*withElseRegion=*/false);
72 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
73
74 // Subtract one from the loop ends before composing with the indexing map
75 transform(ends, ends.begin(), [&](OpFoldResult end) {
76 auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
77 return builder.createOrFold<index::SubOp>(loc, endValue, one);
78 });
79
80 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
81 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
83 builder, loc, indexingMap, starts);
85 builder, loc, indexingMap, ends);
86
87 for (auto dim : llvm::seq(linalgOp.getRank(&opOperand))) {
88 auto startIndex =
89 getValueOrCreateConstantIndexOp(builder, loc, startIndices[dim]);
90 auto endIndex =
91 getValueOrCreateConstantIndexOp(builder, loc, endIndices[dim]);
92
93 // Generate:
94 // minIndex = min(startIndex, endIndex)
95 // assert(minIndex >= 0)
96 // To ensure we do not generate a negative index. We take the minimum of
97 // the start and end indices in order to handle reverse loops such as
98 // `affine_map<(i) -> (3 - i)>`
99 auto min =
100 builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
101 auto cmpOp = builder.createOrFold<index::CmpOp>(
102 loc, index::IndexCmpPredicate::SGE, min, zero);
103 auto msg = generateErrorMessage(
104 linalgOp, "unexpected negative result on dimension #" +
105 std::to_string(dim) + " of input/output operand #" +
106 std::to_string(opOperand.getOperandNumber()));
107 builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
108
109 // Generate:
110 // inferredDimSize = max(startIndex, endIndex) + 1
111 // actualDimSize = dim(operand)
112 // assert(inferredDimSize <= actualDimSize)
113 // To ensure that we do not index past the bounds of the operands.
114 auto max =
115 builder.createOrFold<index::MaxSOp>(loc, startIndex, endIndex);
116
117 auto inferredDimSize =
118 builder.createOrFold<index::AddOp>(loc, max, one);
119
120 auto actualDimSize =
121 createOrFoldDimOp(builder, loc, opOperand.get(), dim);
122
123 // Similar to the verifier, when the affine expression in the indexing
124 // map is complicated, we just check that the inferred dimension sizes
125 // are in the boundary of the operands' size. Being more precise than
126 // that is difficult.
127 auto predicate = isa<AffineDimExpr>(indexingMap.getResult(dim))
128 ? index::IndexCmpPredicate::EQ
129 : index::IndexCmpPredicate::SLE;
130
131 cmpOp = builder.createOrFold<index::CmpOp>(
132 loc, predicate, inferredDimSize, actualDimSize);
133 msg = generateErrorMessage(
134 linalgOp, "dimension #" + std::to_string(dim) +
135 " of input/output operand #" +
136 std::to_string(opOperand.getOperandNumber()) +
137 " is incompatible with inferred dimension size");
138 builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
139 }
140 }
141 builder.setInsertionPointAfter(ifOp);
142 }
143};
144
145template <typename... OpTs>
146void attachInterface(MLIRContext *ctx) {
147 (OpTs::template attachInterface<StructuredOpInterface<OpTs>>(*ctx), ...);
148}
149} // namespace
150} // namespace linalg
151} // namespace mlir
152
154 DialectRegistry &registry) {
155 registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *) {
156 attachInterface<
157#define GET_OP_LIST
158#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
159 >(ctx);
160
161 // Load additional dialects of which ops may get created.
162 ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
163 cf::ControlFlowDialect, index::IndexDialect,
164 tensor::TensorDialect>();
165 });
166}
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void loadDialect()
Load a dialect in the context.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition LinalgOps.cpp:95
Include the generated interface declarations.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152