MLIR 23.0.0git
VectorDropLeadUnitDim.cpp
Go to the documentation of this file.
1//===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <numeric>
10
16#include "mlir/IR/Builders.h"
18#include "llvm/ADT/STLExtras.h"
19
20#define DEBUG_TYPE "vector-drop-unit-dim"
21
22using namespace mlir;
23using namespace mlir::vector;
24
25// Trims leading one dimensions from `oldType` and returns the result type.
26// Returns `vector<1xT>` if `oldType` only has one element.
27static VectorType trimLeadingOneDims(VectorType oldType) {
28 ArrayRef<int64_t> oldShape = oldType.getShape();
29 ArrayRef<int64_t> newShape = oldShape;
30
31 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
32 ArrayRef<bool> newScalableDims = oldScalableDims;
33
34 while (!newShape.empty() && newShape.front() == 1 &&
35 !newScalableDims.front()) {
36 newShape = newShape.drop_front(1);
37 newScalableDims = newScalableDims.drop_front(1);
38 }
39
40 // Make sure we have at least 1 dimension per vector type requirements.
41 if (newShape.empty()) {
42 newShape = oldShape.take_back();
43 newScalableDims = oldType.getScalableDims().take_back();
44 }
45 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
46}
47
48/// Return a smallVector of size `rank` containing all zeros.
50 return SmallVector<int64_t>(rank, 0);
51}
52namespace {
53
54// Casts away leading one dimensions in vector.extract_strided_slice's vector
55// input by inserting vector.broadcast.
56struct CastAwayExtractStridedSliceLeadingOneDim
57 : public OpRewritePattern<vector::ExtractStridedSliceOp> {
58 using Base::Base;
59
60 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
61 PatternRewriter &rewriter) const override {
62 // vector.extract_strided_slice requires the input and output vector to have
63 // the same rank. Here we drop leading one dimensions from the input vector
64 // type to make sure we don't cause mismatch.
65 VectorType oldSrcType = extractOp.getSourceVectorType();
66 VectorType newSrcType = trimLeadingOneDims(oldSrcType);
67
68 if (newSrcType.getRank() == oldSrcType.getRank())
69 return failure();
70
71 int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
72
73 VectorType oldDstType = extractOp.getType();
74 VectorType newDstType =
75 VectorType::get(oldDstType.getShape().drop_front(dropCount),
76 oldDstType.getElementType(),
77 oldDstType.getScalableDims().drop_front(dropCount));
78
79 Location loc = extractOp.getLoc();
80
81 Value newSrcVector = vector::ExtractOp::create(
82 rewriter, loc, extractOp.getSource(), splatZero(dropCount));
83
84 // The offsets/sizes/strides attribute can have a less number of elements
85 // than the input vector's rank: it is meant for the leading dimensions.
86 auto newOffsets = rewriter.getArrayAttr(
87 extractOp.getOffsets().getValue().drop_front(dropCount));
88 auto newSizes = rewriter.getArrayAttr(
89 extractOp.getSizes().getValue().drop_front(dropCount));
90 auto newStrides = rewriter.getArrayAttr(
91 extractOp.getStrides().getValue().drop_front(dropCount));
92
93 auto newExtractOp = vector::ExtractStridedSliceOp::create(
94 rewriter, loc, newDstType, newSrcVector, newOffsets, newSizes,
95 newStrides);
96
97 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
98 newExtractOp);
99
100 return success();
101 }
102};
103
104// Casts away leading one dimensions in vector.insert_strided_slice's vector
105// inputs by inserting vector.broadcast.
106struct CastAwayInsertStridedSliceLeadingOneDim
107 : public OpRewritePattern<vector::InsertStridedSliceOp> {
108 using Base::Base;
109
110 LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
111 PatternRewriter &rewriter) const override {
112 VectorType oldSrcType = insertOp.getSourceVectorType();
113 VectorType newSrcType = trimLeadingOneDims(oldSrcType);
114 VectorType oldDstType = insertOp.getDestVectorType();
115 VectorType newDstType = trimLeadingOneDims(oldDstType);
116
117 int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
118 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
119 if (srcDropCount == 0 && dstDropCount == 0)
120 return failure();
121
122 // Trim leading one dimensions from both operands.
123 Location loc = insertOp.getLoc();
124
125 Value newSrcVector = vector::ExtractOp::create(
126 rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount));
127 Value newDstVector = vector::ExtractOp::create(
128 rewriter, loc, insertOp.getDest(), splatZero(dstDropCount));
129
130 auto newOffsets = rewriter.getArrayAttr(
131 insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
132 auto newStrides = rewriter.getArrayAttr(
133 insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
134
135 auto newInsertOp = vector::InsertStridedSliceOp::create(
136 rewriter, loc, newDstType, newSrcVector, newDstVector, newOffsets,
137 newStrides);
138
139 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
140 newInsertOp);
141
142 return success();
143 }
144};
145
146// Casts away leading one dimensions in vector.insert's vector inputs by
147// inserting vector.broadcast.
148struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
149 using Base::Base;
150
151 LogicalResult matchAndRewrite(vector::InsertOp insertOp,
152 PatternRewriter &rewriter) const override {
153 Type oldSrcType = insertOp.getValueToStoreType();
154 Type newSrcType = oldSrcType;
155 int64_t oldSrcRank = 0, newSrcRank = 0;
156 if (auto type = dyn_cast<VectorType>(oldSrcType)) {
157 newSrcType = trimLeadingOneDims(type);
158 oldSrcRank = type.getRank();
159 newSrcRank = cast<VectorType>(newSrcType).getRank();
160 }
161
162 VectorType oldDstType = insertOp.getDestVectorType();
163 VectorType newDstType = trimLeadingOneDims(oldDstType);
164
165 int64_t srcDropCount = oldSrcRank - newSrcRank;
166 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
167 if (srcDropCount == 0 && dstDropCount == 0)
168 return failure();
169
170 // Trim leading one dimensions from both operands.
171 Location loc = insertOp.getLoc();
172
173 Value newSrcVector = insertOp.getValueToStore();
174 if (oldSrcRank != 0) {
175 newSrcVector = vector::ExtractOp::create(
176 rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount));
177 }
178 Value newDstVector = vector::ExtractOp::create(
179 rewriter, loc, insertOp.getDest(), splatZero(dstDropCount));
180
181 // New position rank needs to be computed in two steps: (1) if destination
182 // type has leading unit dims, we also trim the position array accordingly,
183 // then (2) if source type also has leading unit dims, we need to append
184 // zeroes to the position array accordingly.
185 unsigned oldPosRank = insertOp.getNumIndices();
186 unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
187 SmallVector<OpFoldResult> oldPosition = insertOp.getMixedPosition();
188 SmallVector<OpFoldResult> newPosition =
189 llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank));
190 newPosition.resize(newDstType.getRank() - newSrcRank,
191 rewriter.getI64IntegerAttr(0));
192
193 auto newInsertOp = vector::InsertOp::create(rewriter, loc, newSrcVector,
194 newDstVector, newPosition);
195
196 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
197 newInsertOp);
198
199 return success();
200 }
201};
202
203static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
204 VectorType newType, AffineMap newMap,
205 VectorType oldMaskType) {
206 // Infer the type of the new mask from the new map.
207 VectorType newMaskType = inferTransferOpMaskType(newType, newMap);
208
209 // If the new mask is broadcastable to the old result type, we can safely
210 // use a `vector.extract` to get the new mask. Otherwise the best we can
211 // do is shape cast.
212 if (vector::isBroadcastableTo(newMaskType, oldMaskType) ==
214 int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
215 return vector::ExtractOp::create(b, loc, mask, splatZero(dropDim));
216 }
217 return vector::ShapeCastOp::create(b, loc, newMaskType, mask);
218}
219
220// Turns vector.transfer_read on vector with leading 1 dimensions into
221// vector.shape_cast followed by vector.transfer_read on vector without leading
222// 1 dimensions.
223struct CastAwayTransferReadLeadingOneDim
224 : public OpRewritePattern<vector::TransferReadOp> {
225 using Base::Base;
226
227 LogicalResult matchAndRewrite(vector::TransferReadOp read,
228 PatternRewriter &rewriter) const override {
229 // TODO(#78787): Not supported masked op yet.
230 if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
231 return failure();
232 // TODO: support 0-d corner case.
233 if (read.getTransferRank() == 0)
234 return failure();
235
236 auto shapedType = cast<ShapedType>(read.getBase().getType());
237 if (shapedType.getElementType() != read.getVectorType().getElementType())
238 return failure();
239
240 VectorType oldType = read.getVectorType();
241 VectorType newType = trimLeadingOneDims(oldType);
242
243 if (newType == oldType)
244 return failure();
245
246 AffineMap oldMap = read.getPermutationMap();
247 ArrayRef<AffineExpr> newResults =
248 oldMap.getResults().take_back(newType.getRank());
249 AffineMap newMap =
250 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
251 rewriter.getContext());
252
253 ArrayAttr inBoundsAttr;
254 if (read.getInBounds())
255 inBoundsAttr = rewriter.getArrayAttr(
256 read.getInBoundsAttr().getValue().take_back(newType.getRank()));
257
258 Value mask = Value();
259 if (read.getMask()) {
260 VectorType maskType = read.getMaskType();
261 mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
262 newType, newMap, maskType);
263 }
264
265 auto newRead = vector::TransferReadOp::create(
266 rewriter, read.getLoc(), newType, read.getBase(), read.getIndices(),
267 AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
268 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
269
270 return success();
271 }
272};
273
274// Turns vector.transfer_write on vector with leading 1 dimensions into
275// vector.shape_cast followed by vector.transfer_write on vector without leading
276// 1 dimensions.
277struct CastAwayTransferWriteLeadingOneDim
278 : public OpRewritePattern<vector::TransferWriteOp> {
279 using Base::Base;
280
281 LogicalResult matchAndRewrite(vector::TransferWriteOp write,
282 PatternRewriter &rewriter) const override {
283 // TODO(#78787): Not supported masked op yet.
284 if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
285 return failure();
286 // TODO: support 0-d corner case.
287 if (write.getTransferRank() == 0)
288 return failure();
289
290 auto shapedType = dyn_cast<ShapedType>(write.getBase().getType());
291 if (shapedType.getElementType() != write.getVectorType().getElementType())
292 return failure();
293
294 VectorType oldType = write.getVectorType();
295 VectorType newType = trimLeadingOneDims(oldType);
296 if (newType == oldType)
297 return failure();
298 int64_t dropDim = oldType.getRank() - newType.getRank();
299
300 AffineMap oldMap = write.getPermutationMap();
301 ArrayRef<AffineExpr> newResults =
302 oldMap.getResults().take_back(newType.getRank());
303 AffineMap newMap =
304 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
305 rewriter.getContext());
306
307 ArrayAttr inBoundsAttr;
308 if (write.getInBounds())
309 inBoundsAttr = rewriter.getArrayAttr(
310 write.getInBoundsAttr().getValue().take_back(newType.getRank()));
311
312 auto newVector = vector::ExtractOp::create(
313 rewriter, write.getLoc(), write.getVector(), splatZero(dropDim));
314
315 if (write.getMask()) {
316 VectorType maskType = write.getMaskType();
317 Value newMask = dropUnitDimsFromMask(
318 rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
319 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
320 write, newVector, write.getBase(), write.getIndices(),
321 AffineMapAttr::get(newMap), newMask, inBoundsAttr);
322 return success();
323 }
324
325 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
326 write, newVector, write.getBase(), write.getIndices(),
327 AffineMapAttr::get(newMap), inBoundsAttr);
328 return success();
329 }
330};
331
332} // namespace
333
334FailureOr<Value>
335mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
336 MaskingOpInterface maskingOp,
337 RewriterBase &rewriter) {
338 VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
339 if (oldAccType == nullptr)
340 return failure();
341 if (oldAccType.getRank() < 2)
342 return failure();
343 if (oldAccType.getShape()[0] != 1)
344 return failure();
345 // currently we support only dropping one dim but the pattern can be applied
346 // greedily to drop more.
347 int64_t dropDim = 1;
348
349 auto oldIndexingMaps = contractOp.getIndexingMapsArray();
350 SmallVector<AffineMap> newIndexingMaps;
351
352 auto oldIteratorTypes = contractOp.getIteratorTypes();
353 SmallVector<Attribute> newIteratorTypes;
354
355 int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
356
357 if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
358 // only parallel type iterators can be dropped.
359 return failure();
360
361 for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
362 int64_t currDim = it.index();
363 if (currDim == dimToDrop)
364 continue;
365 newIteratorTypes.push_back(it.value());
366 }
367
368 SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
369 contractOp.getAcc()};
370 SmallVector<Value> newOperands;
371 auto loc = contractOp.getLoc();
372
373 for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
374 // Check if the dim to be dropped exists as a leading dim in the operand
375 // if it does then we use vector.extract to drop it.
376 bool validExtract = false;
378 auto map = it.value();
379 int64_t orginalZeroDim = it.value().getDimPosition(0);
380 if (orginalZeroDim != dimToDrop) {
381 // There are two reasons to be in this path, 1. We need to
382 // transpose the operand to make the dim to be dropped
383 // leading. 2. The dim to be dropped does not exist and in
384 // that case we dont want to add a unit transpose but we must
385 // check all the indices to make sure this is the case.
386 bool transposeNeeded = false;
388 SmallVector<AffineExpr> transposeResults;
389
390 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
391 int64_t currDim = map.getDimPosition(i);
392 if (currDim == dimToDrop) {
393 transposeNeeded = true;
394 perm.insert(perm.begin(), i);
395 auto targetExpr = rewriter.getAffineDimExpr(currDim);
396 transposeResults.insert(transposeResults.begin(), targetExpr);
397 } else {
398 perm.push_back(i);
399 auto targetExpr = rewriter.getAffineDimExpr(currDim);
400 transposeResults.push_back(targetExpr);
401 }
402 }
403
404 // Checks if only the outer, unit dimensions (of size 1) are permuted.
405 // Such transposes do not materially effect the underlying vector and can
406 // be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32>
407 bool transposeNonOuterUnitDims = false;
408 auto operandShape = cast<ShapedType>(operands[it.index()].getType());
409 for (auto [index, dim] :
410 llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
411 if (dim != static_cast<int64_t>(index) &&
412 operandShape.getDimSize(index) != 1) {
413 transposeNonOuterUnitDims = true;
414 break;
415 }
416 }
417
418 // Do the transpose now if needed so that we can drop the
419 // correct dim using extract later.
420 if (transposeNeeded) {
421 map = AffineMap::get(map.getNumDims(), 0, transposeResults,
422 contractOp.getContext());
423 if (transposeNonOuterUnitDims) {
424 operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
425 loc, operands[it.index()], perm);
426 }
427 }
428 }
429 // We have taken care to have the dim to be dropped be
430 // the leading dim. If its still not leading that means it
431 // does not exist in this operand and hence we do not need
432 // an extract.
433 if (map.getDimPosition(0) == dimToDrop)
434 validExtract = true;
435
436 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
437 int64_t currDim = map.getDimPosition(i);
438 if (currDim == dimToDrop)
439 // This is the dim we are dropping.
440 continue;
441 auto targetExpr = rewriter.getAffineDimExpr(
442 currDim < dimToDrop ? currDim : currDim - 1);
443 results.push_back(targetExpr);
444 }
445 newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
446 contractOp.getContext()));
447 // Extract if its a valid extraction, otherwise use the operand
448 // without extraction.
449 newOperands.push_back(validExtract
450 ? vector::ExtractOp::create(rewriter, loc,
451 operands[it.index()],
452 splatZero(dropDim))
453 : operands[it.index()]);
454 }
455
456 // Depending on whether this vector.contract is masked, the replacing Op
457 // should either be a new vector.contract Op or vector.mask Op.
458 Operation *newOp = vector::ContractionOp::create(
459 rewriter, loc, newOperands[0], newOperands[1], newOperands[2],
460 rewriter.getAffineMapArrayAttr(newIndexingMaps),
461 rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
462
463 if (maskingOp) {
464 auto newMask = vector::ExtractOp::create(rewriter, loc, maskingOp.getMask(),
465 splatZero(dropDim));
466
467 newOp = mlir::vector::maskOperation(rewriter, newOp, newMask);
468 }
469
470 return vector::BroadcastOp::create(rewriter, loc,
471 contractOp->getResultTypes()[0],
472 newOp->getResults()[0])
473 .getResult();
474}
475
476namespace {
477
478/// Turns vector.contract on vector with leading 1 dimensions into
479/// vector.extract followed by vector.contract on vector without leading
480/// 1 dimensions. Also performs transpose of lhs and rhs operands if required
481/// prior to extract.
482struct CastAwayContractionLeadingOneDim
483 : public MaskableOpRewritePattern<vector::ContractionOp> {
484 using MaskableOpRewritePattern::MaskableOpRewritePattern;
485
486 FailureOr<Value>
487 matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
488 MaskingOpInterface maskingOp,
489 PatternRewriter &rewriter) const override {
490 return castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter);
491 }
492};
493
494/// Looks at elementwise operations on vectors with at least one leading
495/// dimension equal 1, e.g. vector<1x[4]x1xf32> (but not vector<2x[4]x1xf32>),
496/// and cast aways the leading one dimensions (_plural_) and then broadcasts
497/// the results.
498///
499/// Example before:
500/// %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
501/// Example after:
502/// %2 = arith.mulf %0, %1 : vector<4x1xf32>
503/// %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32>
504///
505/// Does support scalable vectors.
506class CastAwayElementwiseLeadingOneDim : public RewritePattern {
507public:
508 CastAwayElementwiseLeadingOneDim(MLIRContext *context,
509 PatternBenefit benefit = 1)
510 : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}
511
512 LogicalResult matchAndRewrite(Operation *op,
513 PatternRewriter &rewriter) const override {
515 return failure();
516 auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]);
517 if (!vecType)
518 return failure();
519 VectorType newVecType = trimLeadingOneDims(vecType);
520 if (newVecType == vecType)
521 return failure();
522 int64_t dropDim = vecType.getRank() - newVecType.getRank();
523 SmallVector<Value, 4> newOperands;
524 for (Value operand : op->getOperands()) {
525 if (auto opVecType = dyn_cast<VectorType>(operand.getType())) {
526 newOperands.push_back(vector::ExtractOp::create(
527 rewriter, op->getLoc(), operand, splatZero(dropDim)));
528 } else {
529 newOperands.push_back(operand);
530 }
531 }
532 Operation *newOp =
533 rewriter.create(op->getLoc(), op->getName().getIdentifier(),
534 newOperands, newVecType, op->getAttrs());
535 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
536 newOp->getResult(0));
537 return success();
538 }
539};
540
541// Drops leading 1 dimensions from vector.constant_mask and inserts a
542// vector.broadcast back to the original shape.
543struct CastAwayConstantMaskLeadingOneDim
544 : public OpRewritePattern<vector::ConstantMaskOp> {
545 using Base::Base;
546
547 LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
548 PatternRewriter &rewriter) const override {
549 VectorType oldType = mask.getType();
550 VectorType newType = trimLeadingOneDims(oldType);
551
552 if (newType == oldType)
553 return failure();
554
555 int64_t dropDim = oldType.getRank() - newType.getRank();
556 ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();
557
558 // If any of the dropped unit dims has a size of `0`, the entire mask is a
559 // zero mask, else the unit dim has no effect on the mask.
560 int64_t flatLeadingSize =
561 llvm::product_of(dimSizes.take_front(dropDim + 1));
562 SmallVector<int64_t> newDimSizes = {flatLeadingSize};
563 newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
564
565 auto newMask = vector::ConstantMaskOp::create(rewriter, mask.getLoc(),
566 newType, newDimSizes);
567 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
568 return success();
569 }
570};
571
572} // namespace
573
574void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
575 RewritePatternSet &patterns, PatternBenefit benefit) {
576 patterns
577 .add<CastAwayExtractStridedSliceLeadingOneDim,
578 CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
579 CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
580 CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
581 CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
582}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
static SmallVector< int64_t > splatZero(int64_t rank)
Return a smallVector of size rank containing all zeros.
static VectorType trimLeadingOneDims(VectorType oldType)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:368
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:541
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
result_type_range getResultTypes()
Definition Operation.h:457
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:444
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:433
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition VectorOps.h:151
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.