MLIR  21.0.0git
RuntimeOpVerification.cpp
Go to the documentation of this file.
1 //===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 
19 using namespace mlir;
20 
21 namespace mlir {
22 namespace tensor {
23 namespace {
24 /// Generate a runtime check for lb <= value < ub.
25 Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
26  Value lb, Value ub) {
27  Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
28  loc, arith::CmpIPredicate::sge, value, lb);
29  Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
30  loc, arith::CmpIPredicate::slt, value, ub);
31  Value inBounds =
32  builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
33  return inBounds;
34 }
35 
36 struct CastOpInterface
37  : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
38  CastOp> {
39  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
40  Location loc) const {
41  auto castOp = cast<CastOp>(op);
42  auto srcType = cast<TensorType>(castOp.getSource().getType());
43 
44  // Nothing to check if the result is an unranked tensor.
45  auto resultType = dyn_cast<RankedTensorType>(castOp.getType());
46  if (!resultType)
47  return;
48 
49  if (isa<UnrankedTensorType>(srcType)) {
50  // Check rank.
51  Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
52  Value resultRank =
53  builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
54  Value isSameRank = builder.create<arith::CmpIOp>(
55  loc, arith::CmpIPredicate::eq, srcRank, resultRank);
56  builder.create<cf::AssertOp>(
57  loc, isSameRank,
58  RuntimeVerifiableOpInterface::generateErrorMessage(op,
59  "rank mismatch"));
60  }
61 
62  // Check dimension sizes.
63  for (const auto &it : llvm::enumerate(resultType.getShape())) {
64  // Static dim size -> static/dynamic dim size does not need verification.
65  if (auto rankedSrcType = dyn_cast<RankedTensorType>(srcType))
66  if (!rankedSrcType.isDynamicDim(it.index()))
67  continue;
68 
69  // Static/dynamic dim size -> dynamic dim size does not need verification.
70  if (resultType.isDynamicDim(it.index()))
71  continue;
72 
73  Value srcDimSz =
74  builder.create<DimOp>(loc, castOp.getSource(), it.index());
75  Value resultDimSz =
76  builder.create<arith::ConstantIndexOp>(loc, it.value());
77  Value isSameSz = builder.create<arith::CmpIOp>(
78  loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
79  builder.create<cf::AssertOp>(
80  loc, isSameSz,
81  RuntimeVerifiableOpInterface::generateErrorMessage(
82  op, "size mismatch of dim " + std::to_string(it.index())));
83  }
84  }
85 };
86 
87 struct DimOpInterface
88  : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
89  DimOp> {
90  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
91  Location loc) const {
92  auto dimOp = cast<DimOp>(op);
93  Value rank = builder.create<RankOp>(loc, dimOp.getSource());
94  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
95  builder.create<cf::AssertOp>(
96  loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
97  RuntimeVerifiableOpInterface::generateErrorMessage(
98  op, "index is out of bounds"));
99  }
100 };
101 
102 /// Verifies that the indices on extract/insert ops are in-bounds of the
103 /// tensor's index space: 0 <= index#i < dim#i
104 template <typename OpTy>
105 struct ExtractInsertOpInterface
106  : public RuntimeVerifiableOpInterface::ExternalModel<
107  ExtractInsertOpInterface<OpTy>, OpTy> {
108  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
109  Location loc) const {
110  auto extractInsertOp = cast<OpTy>(op);
111 
112  Value tensor;
113  if constexpr (std::is_same_v<OpTy, ExtractOp>) {
114  tensor = extractInsertOp.getTensor();
115  } else if constexpr (std::is_same_v<OpTy, InsertOp>) {
116  tensor = extractInsertOp.getDest();
117  } else {
118  llvm_unreachable("invalid op");
119  }
120  auto tensorType = cast<RankedTensorType>(tensor.getType());
121  auto rank = tensorType.getRank();
122  if (rank == 0) {
123  // Nothing to check for 0-d tensors.
124  return;
125  }
126 
127  auto indices = extractInsertOp.getIndices();
128  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
129  Value assertCond;
130  for (auto i : llvm::seq<int64_t>(0, rank)) {
131  Value dimOp = builder.createOrFold<tensor::DimOp>(loc, tensor, i);
132  Value inBounds =
133  generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
134  assertCond =
135  i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
136  : inBounds;
137  }
138  builder.create<cf::AssertOp>(
139  loc, assertCond,
140  RuntimeVerifiableOpInterface::generateErrorMessage(
141  op, "out-of-bounds access"));
142  }
143 };
144 
145 struct ExtractSliceOpInterface
146  : public RuntimeVerifiableOpInterface::ExternalModel<
147  ExtractSliceOpInterface, ExtractSliceOp> {
148  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
149  Location loc) const {
150  auto extractSliceOp = cast<ExtractSliceOp>(op);
151  RankedTensorType sourceType = extractSliceOp.getSource().getType();
152 
153  // For each dimension, assert that:
154  // 0 <= offset < dim_size
155  // 0 <= offset + (size - 1) * stride < dim_size
156  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
157  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
158  for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
160  builder, loc, extractSliceOp.getMixedOffsets()[i]);
162  builder, loc, extractSliceOp.getMixedSizes()[i]);
164  builder, loc, extractSliceOp.getMixedStrides()[i]);
165 
166  // Verify that offset is in-bounds.
167  Value dimSize = builder.createOrFold<tensor::DimOp>(
168  loc, extractSliceOp.getSource(), i);
169  Value offsetInBounds =
170  generateInBoundsCheck(builder, loc, offset, zero, dimSize);
171  builder.create<cf::AssertOp>(
172  loc, offsetInBounds,
173  RuntimeVerifiableOpInterface::generateErrorMessage(
174  op, "offset " + std::to_string(i) + " is out-of-bounds"));
175 
176  // Verify that slice does not run out-of-bounds.
177  Value sizeMinusOne = builder.create<arith::SubIOp>(loc, size, one);
178  Value sizeMinusOneTimesStride =
179  builder.create<arith::MulIOp>(loc, sizeMinusOne, stride);
180  Value lastPos =
181  builder.create<arith::AddIOp>(loc, offset, sizeMinusOneTimesStride);
182  Value lastPosInBounds =
183  generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
184  builder.create<cf::AssertOp>(
185  loc, lastPosInBounds,
186  RuntimeVerifiableOpInterface::generateErrorMessage(
187  op, "extract_slice runs out-of-bounds along dimension " +
188  std::to_string(i)));
189  }
190  }
191 };
192 } // namespace
193 } // namespace tensor
194 } // namespace mlir
195 
197  DialectRegistry &registry) {
 // Register a dialect extension on the tensor dialect that attaches the
 // runtime-verification external models defined in this file to CastOp,
 // DimOp, ExtractOp, ExtractSliceOp and InsertOp.
198  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
199  CastOp::attachInterface<CastOpInterface>(*ctx);
200  DimOp::attachInterface<DimOpInterface>(*ctx);
 // ExtractOp and InsertOp share one templated model.
201  ExtractOp::attachInterface<ExtractInsertOpInterface<ExtractOp>>(*ctx);
202  ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
203  InsertOp::attachInterface<ExtractInsertOpInterface<InsertOp>>(*ctx);
204 
205  // Load additional dialects of which ops may get created.
 // The generated checks build arith ops (constants, comparisons, boolean
 // combinations, index arithmetic) and cf::AssertOp.
206  ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
207  });
208 }
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Definition: Builders.h:517
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Attach the RuntimeVerifiableOpInterface external models to the tensor dialect's ops.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112