MLIR 23.0.0git
VectorContractToAMXDotProduct.cpp
Go to the documentation of this file.
1//===- VectorContractToAMXDotProduct.cpp ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15
17#include "mlir/IR/Dominance.h"
19#include "llvm/Support/Casting.h"
20
21#include "mlir/Pass/Pass.h"
23
24using namespace mlir;
25using namespace mlir::vector;
26using namespace mlir::x86;
27
28namespace {
29
30// Recursively follows single-use values through scf.yield operations
31// and returns the first non-yield user result in the contraction chain.
32static Value contractionUsersAfterYield(Value v) {
33 if (v.getNumUses() != 1)
34 return nullptr;
35
36 OpOperand &use = *v.use_begin();
37 Operation *user = use.getOwner();
38
39 if (!isa<scf::YieldOp>(user))
40 return v;
41
42 auto yield = cast<scf::YieldOp>(user);
43 Operation *parent = yield->getParentOp();
44 unsigned idx = use.getOperandNumber();
45
46 return contractionUsersAfterYield(parent->getResult(idx));
47}
48
49// Collapses the last two dimensions of `input` (the innermost VNNI dimension
50// and the one before it) into one, so amx.tile_load reads the packed rows.
// Rank-1 inputs are returned unchanged. Otherwise a memref.collapse_shape is
// created at the builder's insertion point that keeps every leading
// dimension as its own group and merges the trailing two.
// NOTE(review): assumes `input` is a MemRef of rank >= 1 (rank 0 is not
// special-cased). The opening of the signature is not present in this
// listing; callers invoke it as `collapseInnerDims(rewriter, loc, value)`.
52 Value input) {
53 ShapedType inputType = cast<ShapedType>(input.getType());
// First of the two trailing dimensions that get merged.
54 int64_t firstDimToCollapse = inputType.getRank() - 2;
55
56 if (inputType.getRank() == 1)
57 return input;
58
// NOTE(review): the declaration of `reassociation` is not present in this
// listing.
// Each leading dimension forms its own reassociation group.
60 for (int64_t i = 0; i < firstDimToCollapse; ++i)
61 reassociation.push_back(ReassociationIndices{i});
62
// The trailing two dimensions form a single group.
63 ReassociationIndices collapsedIndices;
64 for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
65 collapsedIndices.push_back(i);
66
67 reassociation.push_back(collapsedIndices);
68 return memref::CollapseShapeOp::create(builder, loc, input, reassociation);
69}
70
71// Returns the MemRef read by the op that defines a vector.contract operand,
72// together with that read's access indices.
// Fails if `operand` has no defining op or, assuming the elided TypeSwitch
// dispatches on `defOp`, if that op is not a TransferReadOp or LoadOp.
// When `isNotAcc` is true (used for LHS/RHS operands), the innermost index
// is dropped and the returned MemRef is collapseInnerDims(source), which may
// create a memref.collapse_shape at the builder's insertion point.
// NOTE(review): the declarations of `indexVals` and `indices` and the
// llvm::TypeSwitch head are not present in this listing.
73static FailureOr<std::pair<Value, SmallVector<Value>>>
74getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
75 bool isNotAcc) {
76 Operation *defOp = operand.getDefiningOp();
77 if (!defOp)
78 return failure();
79
80 Value srcBuff;
83 .Case<TransferReadOp, LoadOp>([&](auto readOp) {
84 indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
85 readOp.getIndices().end());
// Operand 0 of the read is taken as its source MemRef.
86 srcBuff = readOp.getOperand(0);
87 });
88
// srcBuff is only set by the read cases above.
89 if (!srcBuff)
90 return failure();
91
// Drop the innermost index; that dimension disappears once the last two
// dimensions are collapsed below.
92 if (isNotAcc)
93 indexVals.pop_back();
94
96 indices.reserve(indexVals.size());
97
98 for (OpFoldResult ofr : indexVals) {
99 indices.push_back(
100 mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
101 }
102
103 if (isNotAcc) {
104 srcBuff = collapseInnerDims(rewriter, loc, srcBuff);
105 }
106
107 return std::make_pair(srcBuff, indices);
108}
109
110// Function to validate the loop step value.
111static LogicalResult validateLoopStep(OpBuilder &rewriter, Value step,
112 int64_t value) {
113
114 auto cst = step.getDefiningOp<arith::ConstantIndexOp>();
115 if (!cst)
116 return failure();
117
118 if (cst.value() != value && cst.value() != 1)
119 return failure();
120
121 return success();
122}
123
124// Function to validate the vector.contract operation.
125static LogicalResult validateContractOps(OpBuilder &rewriter,
126 vector::ContractionOp contractOp,
127 unsigned int blockingFactor,
128 Value srcBuffLhs, Value srcBuffRhs,
129 bool srcValidate) {
130
131 if (srcValidate) {
132 // Get the MemRef buffer of LHS operand.
133 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
134 contractOp.getLhs(), false);
135 if (failed(srcIndxLhs))
136 return failure();
137 auto [buffLhs, indicesLhs] = *srcIndxLhs;
138
139 // Get the MemRef buffer of RHS operand.
140 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
141 contractOp.getRhs(), false);
142 if (failed(srcIndxRhs))
143 return failure();
144 auto [buffRhs, indicesRhs] = *srcIndxRhs;
145
146 // Return failure if the Memref buff didn't match.
147 if (buffLhs != srcBuffLhs)
148 return failure();
149
150 if (buffRhs != srcBuffRhs)
151 return failure();
152 }
153
154 if (!contractionUsersAfterYield(contractOp.getResult()))
155 return failure();
156
157 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
158 if (!accTy)
159 return failure();
160
161 // The Accumulator dims should be 16 or 1. Like <1x16x16> or <16x16>.
162 ArrayRef<int64_t> accShape = accTy.getShape();
163 llvm::SmallVector<int64_t> nonUnitDimAcc;
164 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
165 [](int64_t dim) { return (dim != 16 && dim != 1); });
166
167 if (nonUnitDimAcc.size() != 0)
168 return failure();
169
170 // The LHS dims should be 16 or vnni or 1. Like <1x16x16x2> or
171 // <16x16x4>. The vnni dims should be 2 or 4.
172 VectorType lhsTy = contractOp.getLhsType();
173 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
174 llvm::SmallVector<int64_t> nonUnitDimLhs;
175 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
176 [](int64_t dim) { return (dim != 16 && dim != 1); });
177
178 if (nonUnitDimLhs.size() != 1)
179 return failure();
180
181 if (nonUnitDimLhs[0] != blockingFactor)
182 return failure();
183
184 // The RHS dims should be 16 or vnni or 1. Like <1x16x16x2> or
185 // <16x16x4>. The vnni dims should be 2 or 4.
186 VectorType rhsTy = contractOp.getRhsType();
187 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
188 llvm::SmallVector<int64_t> nonUnitDimRhs;
189 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
190 [](int64_t dim) { return (dim != 16 && dim != 1); });
191
192 if (nonUnitDimRhs.size() != 1)
193 return failure();
194
195 if (nonUnitDimRhs[0] != blockingFactor)
196 return failure();
197
198 return success();
199}
200
201// Returns the position, among the dynamic offset operands of the
202// memref.subview feeding the read that defines `operand`, of the offset that
// is `loop`'s induction variable. Callers (createLoops) add 1 to it before
// using it as an operand index of the cloned producer op.
// Returns 0 when the read source is not produced by a memref.subview or when
// no offset is the induction variable; both cases are indistinguishable from
// a match at position 0.
// NOTE(review): the llvm::TypeSwitch head is not present in this listing. If
// `operand` is not defined by a TransferReadOp/LoadOp, `srcBuff` stays null
// and `srcBuff.getDefiningOp` is called on a null Value; confirm callers
// guarantee the read.
203static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
204 Value iv = loop.getInductionVar();
205
206 Value srcBuff;
208 .Case<TransferReadOp, LoadOp>(
209 [&](auto readOp) { srcBuff = readOp.getOperand(0); });
210
211 auto subview = srcBuff.getDefiningOp<memref::SubViewOp>();
212 if (!subview)
213 return 0;
214
// Only the dynamic offsets are operands; static offsets are not searched.
215 auto offsets = subview.getOffsets();
216
217 for (auto it : llvm::enumerate(offsets)) {
218 if (it.value() == iv)
219 return it.index();
220 }
221
222 return 0;
223}
224
225// Creates amx.tile_loads.
226static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
227 Value operand, Value mat, Type ipType,
228 bool rhs, unsigned int offset,
229 bool isVnni) {
230
231 auto srcIndx = getSrcIndxValue(rewriter, loc, operand, false);
232 auto [srcBuff, indices] = *srcIndx;
233 if (isVnni) {
234 indices.pop_back();
235 }
236
237 if (rhs && isVnni) {
238 auto cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
239 indices[indices.size() - 1] = arith::MulIOp::create(
240 rewriter, loc, indices[indices.size() - 1], cOffset);
241 }
242
243 amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
244 return amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
245}
246
// Emits an scf.for that repacks rows of `matB` into VNNI order and stores
// them into `packedBuffer` under the first index `indxToStoreInBuffer`.
// The loop runs iv = 0, offset, ... while iv < 16 * offset. Each iteration:
//  - reads two vectors from `matB`, with the second-to-last index set to iv
//    and iv + offset / 2 and all other indices 0;
//  - interleaves them with two vector.shuffle ops;
//  - stores the results at [indxToStoreInBuffer, iv / offset, 0] and
//    [indxToStoreInBuffer, iv / offset + 16, 0].
// The shuffle masks are hard-coded for 32-element (bf16) and 64-element
// (i8 / f8) vectors. They therefore assume offset == 2 for bf16 and
// offset == 4 for the 8-bit types.
// NOTE(review): for any other `ipType`, shuffle1/shuffle2 stay null and are
// still passed to vector::StoreOp::create. The lines opening the ArrayRef
// literals of the 8-bit masks are not present in this listing. Body ops are
// created through `rewriter` rather than `nestedBuilder`; confirm they land
// inside the new loop.
247static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
248 Type ipType, unsigned int offset, Value packedBuffer,
249 Value indxToStoreInBuffer) {
250
251 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
252 Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
253 SmallVector<Value> subviewOffset(
254 llvm::cast<MemRefType>(matB.getType()).getRank(), c0);
255
256 Value cStep = arith::ConstantIndexOp::create(rewriter, loc, offset);
257 Value cBound = arith::ConstantIndexOp::create(rewriter, loc, (16 * offset));
258 Value offsetIndx =
259 arith::ConstantIndexOp::create(rewriter, loc, (offset / 2));
260
261 scf::ForOp::create(
262 rewriter, loc, c0, cBound, cStep, ValueRange{},
263 [&](OpBuilder &nestedBuilder, Location loc, Value iv,
264 ValueRange iterArgs) {
265 subviewOffset[subviewOffset.size() - 2] = iv;
266
267 // 8-bit types: read a 2 x 16*(offset/2) block (two rows), flattened
268 // below. bf16: read one row of 16*offset elements.
269 auto vectorType = VectorType::get({2, (16 * (offset / 2))}, ipType);
270 if (ipType.isBF16())
271 vectorType = VectorType::get((16 * offset), ipType);
272
273 int64_t srcRank = (dyn_cast<ShapedType>(matB.getType())).getRank();
274 Value padding = ub::PoisonOp::create(rewriter, loc, ipType);
275 auto map = AffineMap::getMinorIdentityMap(srcRank, vectorType.getRank(),
276 rewriter.getContext());
277 SmallVector<bool> inBounds(vectorType.getRank(), true);
278 Value vec1 = vector::TransferReadOp::create(
279 rewriter, loc, vectorType, matB, ValueRange(subviewOffset), padding,
280 map, inBounds);
281
282 if (!ipType.isBF16())
283 vec1 = vector::ShapeCastOp::create(
284 rewriter, loc, VectorType::get((16 * offset), ipType), vec1);
285
286 // Advance the second-to-last index by offset / 2 so the second read
287 // starts right after the rows covered by the first one.
288 Value incIV = arith::AddIOp::create(rewriter, loc, offsetIndx, iv);
289 subviewOffset[subviewOffset.size() - 2] = incIV;
290
291 Value vec2 = vector::TransferReadOp::create(
292 rewriter, loc, vectorType, matB, ValueRange(subviewOffset), padding,
293 map, inBounds);
294 if (!ipType.isBF16())
295 vec2 = vector::ShapeCastOp::create(
296 rewriter, loc, VectorType::get((16 * offset), ipType), vec2);
297
298 vector::ShuffleOp shuffle1;
299 vector::ShuffleOp shuffle2;
300
// Both mask sets gather columns 0-3, 8-11, 16-19 and 24-27 into shuffle1
// and columns 4-7, 12-15, 20-23 and 28-31 into shuffle2. Within each, the
// elements of consecutive rows for the same column are placed next to each
// other (2 rows for bf16, 4 rows for the 8-bit types).
301 if (ipType.isBF16()) {
302
303 shuffle1 = vector::ShuffleOp::create(
304 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
305 vec2,
306 ArrayRef<int64_t>{0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9,
307 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50,
308 19, 51, 24, 56, 25, 57, 26, 58, 27, 59});
309
310 shuffle2 = vector::ShuffleOp::create(
311 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
312 vec2,
313 ArrayRef<int64_t>{4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13,
314 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54,
315 23, 55, 28, 60, 29, 61, 30, 62, 31, 63});
316 }
317
318 if (ipType.isSignlessInteger(8) || ipType.isF8E5M2() ||
319 ipType.isF8E4M3FN()) {
320
321 shuffle1 = vector::ShuffleOp::create(
322 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
323 vec2,
325 0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3,
326 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42,
327 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81,
328 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120,
329 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123});
330
331 shuffle2 = vector::ShuffleOp::create(
332 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
333 vec2,
335 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39,
336 71, 103, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110,
337 15, 47, 79, 111, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54,
338 86, 118, 23, 55, 87, 119, 28, 60, 92, 124, 29, 61, 93, 125,
339 30, 62, 94, 126, 31, 63, 95, 127});
340 }
341
342 // Destination rows: iv / offset for shuffle1, 16 rows further for shuffle2.
343 Value ivShuff1 = arith::DivUIOp::create(rewriter, loc, iv, cStep);
344 Value ivShuff2 = arith::AddIOp::create(rewriter, loc, ivShuff1, c16);
345
346 vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
347 ValueRange{indxToStoreInBuffer, ivShuff1, c0});
348 vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
349 ValueRange{indxToStoreInBuffer, ivShuff2, c0});
350
351 scf::YieldOp::create(nestedBuilder, loc);
352 });
353}
354
// Emits the amx.tile_loads of the online-packed B matrix (used by
// createTiledDp when the input is not in VNNI layout).
// It considers each ordered pair (ops[j], ops[i]) with i != j that is
// accepted by validatePairVectorContract(ops[j], ops[i], true, 16) and whose
// ops[j] RHS producer has no tile yet. For such a pair it:
//  - when `pack` is set, first emits performShuffle to fill `packedBuffer`
//    at `indxToStoreInBuffer`;
//  - loads two 16 x (16 * offset) tiles from `packedBuffer` at
//    [indxToLoadFromMatB, 0, 0] and [indxToLoadFromMatB, 16, 0];
//  - records those tiles for the RHS producers of ops[j] and ops[i]
//    respectively.
// Returns the map from RHS-producing op to its amx.tile_load.
// NOTE(review): despite its name, `indxToLoadFromMatB` is used as the first
// index into `packedBuffer`. Every accepted pair loads the same two buffer
// rows, which appears to assume a single pair of RHS reads; confirm.
// The return type, the `ops`/`matB`/`ipType` parameters and the declaration
// of `readsToTileLoads` are not present in this listing.
356packInputs(OpBuilder &rewriter, Location loc,
358 unsigned int offset, Value packedBuffer, bool pack,
359 Value indxToStoreInBuffer, Value indxToLoadFromMatB) {
360
362 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
363 Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
364
365 for (size_t j = 0; j < ops.size(); j++) {
366 for (size_t i = 0; i < ops.size(); i++) {
367
368 if (i != j && validatePairVectorContract(ops[j], ops[i], true, 16)) {
369
// Skip pairs whose first RHS producer already has a tile.
370 Operation *readOpRhs = ops[j].getRhs().getDefiningOp();
371 auto itRhs = readsToTileLoads.find(readOpRhs);
372 if (itRhs != readsToTileLoads.end()) {
373 continue;
374 }
375
376 if (pack) {
377 performShuffle(rewriter, loc, matB, ipType, offset, packedBuffer,
378 indxToStoreInBuffer);
379 }
380
// Rows 0-15 and 16-31 of the selected buffer slot hold the two tiles.
381 amx::TileType tileType =
382 amx::TileType::get({16, (16 * offset)}, ipType);
383 auto loadRow1 =
384 amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
385 ValueRange{indxToLoadFromMatB, c0, c0});
386
387 auto loadRow2 =
388 amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
389 ValueRange{indxToLoadFromMatB, c16, c0});
390
391 readsToTileLoads.try_emplace(readOpRhs, loadRow1);
392 readsToTileLoads.try_emplace(ops[i].getRhs().getDefiningOp(), loadRow2);
393 }
394 }
395 }
396
397 return readsToTileLoads;
398}
399
400// Emits one AMX tiled dot-product per contraction in `ops`, accumulating into
// the matching entry of `accIterArgs`, and returns the resulting tiles in
// order of `ops`.
// With `isVnni`, `matA` and `matB` are first collapsed with
// collapseInnerDims. Without it, the RHS tiles come from packInputs (online
// packing into `packedBuffer`). LHS and RHS tiles are cached per producing
// read op, so contractions sharing a read share one amx.tile_load.
// bf16/f8 inputs use TileMulFOp; i8 inputs use TileMulIOp.
// NOTE(review): `dp` stays null for any other `ipType`. The return type, the
// `ops`/`matA`/`matB` parameters and the declaration of `readsToTileLoads`
// are not present in this listing.
402createTiledDp(OpBuilder &rewriter, Location loc,
404 Type ipType, Type opType, ValueRange accIterArgs,
405 unsigned int offset, bool isVnni, Value packedBuffer, bool pack,
406 Value indxToStoreInBuffer, Value indxToLoadFromMatB) {
407
408 if (isVnni) {
409 matA = collapseInnerDims(rewriter, loc, matA);
410 matB = collapseInnerDims(rewriter, loc, matB);
411 }
412
413 SmallVector<Value> accumulators;
414 // Maps each vector transfer_read / load op to the amx.tile_load created
415 // for it.
417
418 // Without VNNI layout, pack the B matrix online and pre-create its tiles.
419 if (!isVnni) {
420 readsToTileLoads =
421 packInputs(rewriter, loc, ops, matB, ipType, offset, packedBuffer, pack,
422 indxToStoreInBuffer, indxToLoadFromMatB);
423 }
424
425 // Iterate over the contraction operations and compute the tiled dot-product.
426 for (size_t i = 0; i < ops.size(); i++) {
427
428 Operation *readOpLhs = ops[i].getLhs().getDefiningOp();
429 amx::TileLoadOp tilesLhs;
430 auto itLhs = readsToTileLoads.find(readOpLhs);
431 if (itLhs != readsToTileLoads.end()) {
432 tilesLhs = itLhs->second;
433 } else {
434 tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(), matA, ipType,
435 false, offset, isVnni);
436 readsToTileLoads.try_emplace(readOpLhs, tilesLhs);
437 }
438
439 Operation *readOpRhs = ops[i].getRhs().getDefiningOp();
440 amx::TileLoadOp tilesRhs;
441 auto itRhs = readsToTileLoads.find(readOpRhs);
442 if (itRhs != readsToTileLoads.end()) {
443 tilesRhs = itRhs->second;
444 } else {
// NOTE(review): createLoops may pass a null `matB` on the non-VNNI path; this
// branch then relies on packInputs having created the RHS tile.
445 tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(), matB, ipType,
446 true, offset, isVnni);
447 readsToTileLoads.try_emplace(readOpRhs, tilesRhs);
448 }
449
450 auto accTileType = amx::TileType::get({16, 16}, opType);
451
452 Value dp;
453 if (ipType.isBF16() || ipType.isF8E5M2() || ipType.isF8E4M3FN())
454 dp = amx::TileMulFOp::create(rewriter, loc, accTileType, tilesLhs,
455 tilesRhs, accIterArgs[i]);
456
457 if (ipType.isSignlessInteger(8))
458 dp = amx::TileMulIOp::create(rewriter, loc, accTileType, tilesLhs,
459 tilesRhs, accIterArgs[i]);
460
461 accumulators.push_back(dp);
462 }
463 return accumulators;
464}
465
466static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
467 Type opType, scf::ForOp outerLoop,
468 int64_t size) {
469 rewriter.setInsertionPoint(outerLoop);
470
471 SmallVector<Value> loopItrArgs;
472 auto zeroTileType = amx::TileType::get({16, 16}, opType);
473
474 for (int i = 0; i < size; i++) {
475 auto zeroTile = amx::TileZeroOp::create(rewriter, loc, zeroTileType);
476 loopItrArgs.push_back(zeroTile);
477 }
478 return loopItrArgs;
479}
480
481static Value getIndxToLoadStoreFromPckBuffer(
482 OpBuilder &rewriter, Location loc, Value ivInnerLoop, Value ivOuterLoop,
483 bool isInnerLoopUBHasOddQuot, bool isInnerLoopUBLarger, bool pack,
484 unsigned int blockingFactor) {
485
486 Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
487 Value packOffset =
488 arith::ConstantIndexOp::create(rewriter, loc, (16 * blockingFactor));
489
490 Value quotientInnerLoop =
491 arith::DivUIOp::create(rewriter, loc, ivInnerLoop, packOffset);
492 Value remInnerLoop = arith::RemUIOp::create(
493 rewriter, loc, rewriter.getIndexType(), quotientInnerLoop, c2);
494
495 if (!isInnerLoopUBLarger && !pack) {
496 remInnerLoop = arith::RemUIOp::create(
497 rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
498 }
499
500 if (isInnerLoopUBHasOddQuot) {
501 auto remOuterLoop = arith::RemUIOp::create(
502 rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
503 auto remAdd = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
504 remInnerLoop, remOuterLoop);
505 remInnerLoop = arith::RemUIOp::create(rewriter, loc,
506 rewriter.getIndexType(), remAdd, c2);
507 }
508
509 return remInnerLoop;
510}
511
// Creates an scf.for over [lowerBound, upperBound) with step `step`. Its
// iteration arguments start as `loopItrArgs` and carry the accumulator
// tiles. Each iteration:
//  - clones `vectorOpLhs`, remapping the offsets that equal the original
//    outer/inner induction variables (located with getIndexPosition) to
//    `ivOuterLoop` and the new induction variable;
//  - without VNNI layout, computes which of the two `packedBuffer` slots to
//    store into and load from; with `pack`, the RHS is cloned for a later
//    iteration (see the branches below);
//  - clones `vectorOpRhs` the same way, or uses its result directly when it
//    has no operands;
//  - calls createTiledDp and yields its results.
// NOTE(review): `loopItrArgs` and `ops` are taken by value (copied).
// `ivOuterLoop` is captured by reference and reassigned inside the body
// builder. The declarations of `nLoadIndx` (presumably an index constant
// equal to `offset`) are not present in this listing.
512static scf::ForOp
513createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
514 Value upperBound, Value step, SmallVector<Value> loopItrArgs,
515 Type ipType, Type opType, unsigned int blockingFactor, bool isVnni,
516 Operation *vectorOpLhs, Operation *vectorOpRhs,
517 vector::ContractionOp contractOp, scf::ForOp outerLoop,
518 scf::ForOp innerLoop, SmallVector<vector::ContractionOp> ops,
519 Value ivOuterLoop, Value packedBuffer, bool pack,
520 arith::ConstantIndexOp innerLoopIndex, bool isInnerLoopUBLarger,
521 bool isInnerLoopUBHasOddQuot) {
522
523 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
524 Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
525 Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
526
// Amount added to the induction variable to address the next iteration:
// the step when it is a constant index, otherwise 16 * blockingFactor.
527 int64_t offset = 16 * blockingFactor;
528 if (auto cst = step.getDefiningOp<arith::ConstantIndexOp>())
529 offset = cst.value();
530
531 auto newLoop = scf::ForOp::create(
532 rewriter, loc, lowerBound, upperBound, step, loopItrArgs,
533 [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
534 Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
// Remap the LHS producer's offsets to the new induction variables (+1
// because getIndexPosition excludes operand 0).
535 IRMapping mapping;
536 if (outerLoop)
537 mapping.map(vectorOpLhs->getOperand(
538 getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
539 ivOuterLoop);
540
541 mapping.map(vectorOpLhs->getOperand(
542 getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
543 ivNewInnerLoop);
544 auto lhsClone = rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);
545
// Buffer slots: this iteration's B tiles are loaded from
// (store slot + 1) % 2, i.e. the slot that is not being written.
546 Value indxToStoreInBuffer = c0;
547 Value indxToLoadFromBuffer = c0;
548 if (!isVnni) {
549 if (outerLoop) {
550 if (innerLoopIndex.value() == 0) {
551 if (pack) {
// Prepare B for inner index 0 of the next outer iteration.
552 ivNewInnerLoop = c0;
553 ivOuterLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
554 c1, ivOuterLoop);
555
556 if (!isInnerLoopUBLarger || isInnerLoopUBHasOddQuot) {
557 indxToStoreInBuffer = arith::RemUIOp::create(
558 rewriter, locNewInnerLoop, rewriter.getIndexType(),
559 ivOuterLoop, c2);
560 }
561
562 Value indxToLoadFromMatB = arith::AddIOp::create(
563 rewriter, loc, indxToStoreInBuffer, c1);
564 indxToLoadFromBuffer = arith::RemUIOp::create(
565 rewriter, loc, rewriter.getIndexType(), indxToLoadFromMatB,
566 c2);
567 }
568
569 } else {
// Prepare B for the next inner iteration (iv + offset).
571 rewriter, locNewInnerLoop, offset);
572 ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
573 nLoadIndx, ivNewInnerLoop);
574 indxToStoreInBuffer = getIndxToLoadStoreFromPckBuffer(
575 rewriter, loc, ivNewInnerLoop, ivOuterLoop,
576 isInnerLoopUBHasOddQuot, isInnerLoopUBLarger, pack,
577 blockingFactor);
578 Value indxToLoadFromMatB =
579 arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
580 indxToLoadFromBuffer =
581 arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
582 indxToLoadFromMatB, c2);
583 }
584 } else {
585 if (pack) {
// No outer loop: prepare B for iv + offset; the store slot is
// ((iv + offset) / offset) % 2.
587 rewriter, locNewInnerLoop, offset);
588 ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
589 nLoadIndx, ivNewInnerLoop);
590 Value quotient_K = arith::DivUIOp::create(
591 rewriter, loc, ivNewInnerLoop, nLoadIndx);
592 indxToStoreInBuffer = arith::RemUIOp::create(
593 rewriter, loc, rewriter.getIndexType(), quotient_K, c2);
594
595 Value indxToLoadFromMatB =
596 arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
597 indxToLoadFromBuffer =
598 arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
599 indxToLoadFromMatB, c2);
600 }
601 }
602 }
603 IRMapping rhsMapping;
604
605 Value matB;
606 Operation *rhsOp = vectorOpRhs;
607
608 // Clone the RHS producer (e.g. a subview) with remapped offsets.
609 if (rhsOp->getNumOperands() > 0) {
610
611 if (outerLoop) {
612 int64_t outerPos = getIndexPosition(contractOp.getRhs(), outerLoop);
613
// NOTE(review): getIndexPosition returns unsigned, so this check (and the
// innerPos one below) is always true.
614 if (outerPos >= 0) {
615 unsigned operandIdx = static_cast<unsigned>(outerPos + 1);
616
617 if (operandIdx < rhsOp->getNumOperands())
618 rhsMapping.map(rhsOp->getOperand(operandIdx), ivOuterLoop);
619 }
620 }
621
622 int64_t innerPos = getIndexPosition(contractOp.getRhs(), innerLoop);
623
624 if (innerPos >= 0) {
625 unsigned operandIdx = static_cast<unsigned>(innerPos + 1);
626
627 if (operandIdx < rhsOp->getNumOperands())
628 rhsMapping.map(rhsOp->getOperand(operandIdx), ivNewInnerLoop);
629 }
630
631 auto rhsClone = rewriterNewInnerLoop.clone(*rhsOp, rhsMapping);
632 matB = rhsClone->getResult(0);
633
634 } else {
635 // The mat B is of kind 'memref.get_global @__constant'
636 matB = rhsOp->getResult(0);
637 }
638
// Without packing in this loop, B tiles are only loaded from
// `packedBuffer`, so `matB` is cleared.
639 if (!isVnni) {
640 if (outerLoop) {
641 if (!pack) {
643 rewriter, locNewInnerLoop, offset);
644 matB = Value();
// NOTE(review): this store of c0 is immediately overwritten.
645 indxToLoadFromBuffer = c0;
646 indxToLoadFromBuffer = getIndxToLoadStoreFromPckBuffer(
647 rewriter, loc, nLoadIndx, ivOuterLoop,
648 isInnerLoopUBHasOddQuot, isInnerLoopUBLarger, pack,
649 blockingFactor);
650 }
651 } else {
652 if (!pack) {
654 rewriter, locNewInnerLoop, offset);
655 matB = Value();
656 Value quotient_K = arith::DivUIOp::create(
657 rewriter, loc, ivNewInnerLoop, nLoadIndx);
658 indxToLoadFromBuffer = arith::RemUIOp::create(
659 rewriter, loc, rewriter.getIndexType(), quotient_K, c2);
660 }
661 }
662 }
663 // Compute the tiled dot-products (blockingFactor is createTiledDp's offset).
664 SmallVector<Value> accumulators = createTiledDp(
665 rewriter, locNewInnerLoop, ops, lhsClone->getResult(0), matB,
666 ipType, opType, iterArgsNewInnerLoop, blockingFactor, isVnni,
667 packedBuffer, pack, indxToStoreInBuffer, indxToLoadFromBuffer);
668
669 scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
670 accumulators);
671 });
672
673 return newLoop;
674}
675
676// Implements tiled dot-product operation for a vector.contract operation or a
677// sequence of vector.contracts inside the reduction loops.
678//
679// For example:
680// Case 1: register blocked vector.contract with prepacked input
681// ```
682// vector.transfer_read %arg0 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
683// vector.transfer_read %arg1 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
684// vector.contract <16x16x4xi8>, <16x16x4xi8> into <16x16xi32>
685// vector.transfer_write %arg2 {{.}*} : vector<16x16xi32>, memref<32x32xi32>
686// ```
687// to
688// ```
689// amx.tile_load %arg0 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
690// amx.tile_load %arg1 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
691// amx.tile_muli !amx.tile<16x64xi8> -> !amx.tile<16x16xi32>
692// amx.tile_store %arg2{{.}*} : memref<32x32xi32>, !amx.tile<16x16xi32>
693// ```
694//
695//
696// Case 2: register-blocked vector.contracts inside the reduction loops
697//
698// Output IR with online packing (software-pipelined):
699// Prologue: load, pack to VNNI, and store the B sub-matrix
700// of the 0th batch-reduce and K iteration.
701// scf.for (0 to 31) {
702// - load the 0th and 1st vector<32xbf16>, pack them into VNNI, and
703// store the first shuffle at index 0 and the second shuffle at
704// index 16 of the buffer.
705// }
706// scf.for (0 to br-2) { batch-reduce loop
707// scf.for (0 to k-2) { K loop
708// - load the A matrix
709// - s/w pipeline: an scf.for loads, packs to VNNI, and stores the B
710// sub-matrix for the next K iteration; then this iteration's VNNI-packed
711// B matrix is loaded from the buffer and the tiled dot-product computed
712// }
713// Last iteration of the K loop (k-1) {
714// - load the A matrix
715// - s/w pipeline: an scf.for loads, packs to VNNI, and stores the B
716// sub-matrix for the next batch-reduce + K iteration; then this
717// iteration's VNNI-packed B is loaded and the tiled dot-product computed
718// }
719// }
720// Last iteration of the batch-reduce loop (br-1) {
721// scf.for (0 to k-2) { K loop
722// - load the A matrix
723// - s/w pipeline: an scf.for loads, packs to VNNI, and stores the B
724// sub-matrix for the next K iteration; then this iteration's VNNI-packed
725// B matrix is loaded from the buffer and the tiled dot-product computed
726// }
727// Last iteration of the K loop (k-1) {
728// - load the A matrix
729// - load this iteration's VNNI-packed B matrix from the buffer
730// - compute the tiled dot-product
731// }
732// }
733//
734// scf.for (0 to M)
735// scf.for (0 to N)
736// - load the i-th and (i+1)-th accumulators
737// - shuffle them to account for the vpunpck-style packing of B
738// - load the C matrix and add the shuffled values to it with arith.add
739// - store the result back into the C matrix
740struct VectorContractToAMXDotProduct
741 : public OpRewritePattern<vector::ContractionOp> {
742 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
743
744 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
745 PatternRewriter &rewriter) const override {
746
747 if (contractOp.getKind() != vector::CombiningKind::ADD)
748 return rewriter.notifyMatchFailure(contractOp,
749 "Expects add combining kind.");
750
751 unsigned int blockingFactor =
752 contractOp.getLhsType().getElementType().isBF16() ? 2 : 4;
753 bool isVnni =
754 isInVnniLayout(contractOp.getOperation(),
755 contractOp.getIndexingMapsArray(), blockingFactor);
756
757 VectorType lhsTy = contractOp.getLhsType();
758 if (!lhsTy.getElementType().isBF16() &&
759 !lhsTy.getElementType().isSignlessInteger(8) &&
760 !lhsTy.getElementType().isF8E4M3FN() &&
761 !lhsTy.getElementType().isF8E5M2())
762 return rewriter.notifyMatchFailure(
763 contractOp, "Only BF16/Int8/F8 lowering is supported.");
764
765 if (lhsTy.getElementType() != contractOp.getRhsType().getElementType())
766 return rewriter.notifyMatchFailure(
767 contractOp, "Contraction should have same lhs and rhs type.");
768
769 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
770 if (!accTy)
771 return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
772
773 if (((lhsTy.getElementType().isBF16() ||
774 lhsTy.getElementType().isF8E4M3FN() ||
775 lhsTy.getElementType().isF8E5M2()) &&
776 !accTy.getElementType().isF32()) ||
777 (lhsTy.getElementType().isSignlessInteger(8) &&
778 !accTy.getElementType().isSignlessInteger(32)))
779 return rewriter.notifyMatchFailure(contractOp,
780 "Only F32 for BF16 or Int32 for Int8 "
781 "accumulation type is supported.");
782
783 Operation *accReadOp =
784 traceToVectorReadLikeParentOperation(contractOp.getAcc());
785
786 Operation *resultWriteOp =
787 traceToVectorWriteLikeUserOperation(contractOp.getResult());
788
789 if (!accReadOp || !resultWriteOp)
790 return rewriter.notifyMatchFailure(
791 contractOp, "The ACC operand of the vector.contract should be a "
792 "transfer_read or a load. And, the result should be "
793 "stored using transfer_write or store.");
794
795 Type ipType = rewriter.getBF16Type();
796 Type opType = rewriter.getF32Type();
797
798 if (lhsTy.getElementType().isSignlessInteger(8)) {
799 ipType = rewriter.getIntegerType(8);
800 opType = rewriter.getIntegerType(32);
801 }
802
803 if (lhsTy.getElementType().isF8E4M3FN())
804 ipType = rewriter.getF8E4M3FNType();
805
806 if (lhsTy.getElementType().isF8E5M2())
807 ipType = rewriter.getF8E5M2Type();
808
809 if (accReadOp->getBlock() == contractOp->getBlock() &&
810 resultWriteOp->getBlock() != contractOp->getBlock())
811 return rewriter.notifyMatchFailure(
812 contractOp, "The accumulator store is in different block.");
813
814 if (accReadOp->getBlock() != contractOp->getBlock() &&
815 resultWriteOp->getBlock() == contractOp->getBlock())
816 return rewriter.notifyMatchFailure(
817 contractOp, "The accumulator read is in different block.");
818
819 unsigned int dimValue = blockingFactor;
820 if (!isVnni)
821 dimValue = 16 * blockingFactor;
822
823 // Case 1: For just one VC rewrite. Where all accumulator read/write
824 // within the same block.
825 if (accReadOp->getBlock() == contractOp->getBlock() &&
826 resultWriteOp->getBlock() == contractOp->getBlock()) {
827
828 bool collapse = false;
829 if (isVnni)
830 collapse = true;
831
832 LogicalResult validate = validateContractOps(
833 rewriter, contractOp, dimValue, Value(), Value(), false);
834
835 if (failed(validate))
836 return rewriter.notifyMatchFailure(
837 contractOp, "The contract operation doesn't satisfy the operands "
838 "dimensions. M, N, and vnni dims are 16, 16, and 2/4. "
839 "The rest dims should be 1. Op should have one user.");
840
841 Location loc = contractOp.getLoc();
842
843 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
844 contractOp.getLhs(), collapse);
845 if (failed(srcIndxLhs))
846 return rewriter.notifyMatchFailure(contractOp,
847 "The LHS src is not a MemRef type.");
848 auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
849
850 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
851 contractOp.getRhs(), collapse);
852 if (failed(srcIndxRhs))
853 return rewriter.notifyMatchFailure(contractOp,
854 "The RHS src is not a MemRef type.");
855 auto rhsSrc = *srcIndxRhs;
856 auto srcBuffRhs = rhsSrc.first;
857 auto indicesRhs = rhsSrc.second;
858
859 auto srcIndxAcc = getSrcIndxValue(rewriter, contractOp.getLoc(),
860 contractOp.getAcc(), false);
861 if (failed(srcIndxAcc))
862 return rewriter.notifyMatchFailure(contractOp,
863 "The ACC src is not a MemRef type.");
864 auto [srcBuffAcc, indicesAcc] = *srcIndxAcc;
865
866 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
867
868 // amx.tile_loads
869 auto tileType = amx::TileType::get({16, (16 * blockingFactor)}, ipType);
870 auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
871 srcBuffLhs, indicesLhs);
872
873 // Create the subview and then load.
874 amx::TileLoadOp loadRhs;
875 if (!isVnni) {
876 VectorType vecTy;
877 SmallVector<OpFoldResult> indexVals;
878 llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
879 .Case<TransferReadOp, LoadOp>([&](auto readOp) {
880 indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
881 readOp.getIndices().end());
882 vecTy = readOp.getType();
883 });
884 auto one = rewriter.getIndexAttr(1);
885 SmallVector<OpFoldResult> strides(indexVals.size(), one);
886 SmallVector<OpFoldResult> sizes = getAsIndexOpFoldResult(
887 contractOp.getRhs().getDefiningOp()->getContext(),
888 vecTy.getShape());
889 auto subview = memref::SubViewOp::create(rewriter, loc, srcBuffRhs,
890 indexVals, sizes, strides);
891 auto bufferType = MemRefType::get({16, (16 * blockingFactor)}, ipType);
892 auto packedBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
893
894 // create a loop that does online packing.
895 Value step =
896 arith::ConstantIndexOp::create(rewriter, loc, blockingFactor);
897 Value uBound = arith::ConstantIndexOp::create(rewriter, loc,
898 (blockingFactor * 16));
899 Value nextLoadIndx =
900 arith::ConstantIndexOp::create(rewriter, loc, (blockingFactor / 2));
901 Value nextStoreIndx = arith::ConstantIndexOp::create(
902 rewriter, loc, 16 * (blockingFactor / 2));
903
904 scf::ForOp::create(
905 rewriter, loc, c0, uBound, step, ValueRange{},
906 [&](OpBuilder &nestedBuilder, Location loc, Value iv,
907 ValueRange iterArgs) {
908 Value i1_load =
909 arith::AddIOp::create(rewriter, loc, nextLoadIndx, iv);
910
911 indicesRhs[indicesRhs.size() - 2] = iv;
912 indicesRhs[indicesRhs.size() - 1] = c0;
913 ValueRange range1(indicesRhs);
914 auto vec1 = vector::LoadOp::create(
915 rewriter, loc,
916 VectorType::get(16 * (blockingFactor / 2), ipType), subview,
917 range1);
918
919 indicesRhs[indicesRhs.size() - 2] = i1_load;
920 ValueRange range2(indicesRhs);
921 auto vec2 = vector::LoadOp::create(
922 rewriter, loc,
923 VectorType::get(16 * (blockingFactor / 2), ipType), subview,
924 range2);
925
926 vector::ShuffleOp shuffle1;
927 vector::ShuffleOp shuffle2;
928
929 if (blockingFactor == 2) {
930
931 shuffle1 = vector::ShuffleOp::create(
932 rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
933 ArrayRef<int64_t>{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21,
934 6, 22, 7, 23});
935
936 shuffle2 = vector::ShuffleOp::create(
937 rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
938 ArrayRef<int64_t>{8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13,
939 29, 14, 30, 15, 31});
940 }
941
942 if (blockingFactor == 4) {
943 shuffle1 = vector::ShuffleOp::create(
944 rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
945 ArrayRef<int64_t>{0, 16, 32, 48, 1, 17, 33, 49,
946 2, 18, 34, 50, 3, 19, 35, 51,
947 4, 20, 36, 52, 5, 21, 37, 53,
948 6, 22, 38, 54, 7, 23, 39, 55});
949
950 shuffle2 = vector::ShuffleOp::create(
951 rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
952 ArrayRef<int64_t>{8, 24, 40, 56, 9, 25, 41, 57,
953 10, 26, 42, 58, 11, 27, 43, 59,
954 12, 28, 44, 60, 13, 29, 45, 61,
955 14, 30, 46, 62, 15, 31, 47, 63});
956 }
957
958 auto rem = arith::DivUIOp::create(
959 rewriter, loc, rewriter.getIndexType(), iv, step);
960
961 vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
962 ValueRange{rem, c0});
963 vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
964 ValueRange{rem, nextStoreIndx});
965
966 scf::YieldOp::create(nestedBuilder, loc);
967 });
968 loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
969 ValueRange{c0, c0});
970 } else {
971
972 loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, srcBuffRhs,
973 indicesRhs);
974 }
975
976 auto tileTypeAcc = amx::TileType::get({16, 16}, opType);
977 auto loadAcc = amx::TileLoadOp::create(rewriter, loc, tileTypeAcc,
978 srcBuffAcc, indicesAcc);
979
980 // Tiled dot-product.
981 Value dp;
982 if (ipType.isBF16() || ipType.isF8E5M2() || ipType.isF8E4M3FN())
983 dp = amx::TileMulFOp::create(rewriter, loc, tileTypeAcc, loadLhs,
984 loadRhs, loadAcc);
985
986 if (ipType.isSignlessInteger(8))
987 dp = amx::TileMulIOp::create(rewriter, loc, tileTypeAcc, loadLhs,
988 loadRhs, loadAcc);
989
990 auto bufferType = MemRefType::get({16, 16}, opType);
991 auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
992
993 amx::TileStoreOp::create(rewriter, loc, resultBuffer, ValueRange{c0, c0},
994 dp);
995
996 auto vectorType = mlir::VectorType::get({16, 16}, opType);
997 int64_t srcRank =
998 (dyn_cast<ShapedType>(resultBuffer.getType())).getRank();
999 Value padding = ub::PoisonOp::create(rewriter, loc, opType);
1000 auto map = AffineMap::getMinorIdentityMap(srcRank, vectorType.getRank(),
1001 rewriter.getContext());
1002 SmallVector<bool> inBounds(vectorType.getRank(), true);
1003
1004 Value vecRow = vector::TransferReadOp::create(
1005 rewriter, loc, vectorType, resultBuffer, ValueRange{c0, c0}, padding,
1006 map, inBounds);
1007
1008 Value resultOp = contractionUsersAfterYield(contractOp.getResult());
1009 if (auto vecType = llvm::dyn_cast<VectorType>(resultOp.getType()))
1010 vecRow = vector::ShapeCastOp::create(rewriter, loc, vecType, vecRow);
1011
1012 rewriter.replaceAllUsesWith(resultOp, vecRow);
1013 return success();
1014 }
1015
1016 // Case 2: The accumulators are passed as iter args through the reduction loop.
1017 // We support reduction loop depths of up to 2. TODO: Support n-depth
1018 // reduction loops.
1019 // TODO: Refactor 2a and 2b.
1020 SmallVector<scf::ForOp> loopLists;
1021 Operation *current = contractOp;
1022 while (true) {
1023 Operation *parent = current->getParentOfType<scf::ForOp>();
1024
1025 if (!parent)
1026 return rewriter.notifyMatchFailure(
1027 contractOp,
1028 "Accumulator read and contract op not within scf.for op");
1029
1030 loopLists.push_back(dyn_cast<scf::ForOp>(parent));
1031
1032 if (accReadOp->getBlock() == parent->getBlock()) {
1033 break;
1034 }
1035
1036 current = parent;
1037 }
1038 if (loopLists.size() > 2 || loopLists.size() == 0)
1039 return rewriter.notifyMatchFailure(
1040 contractOp, "Rewrite is supported until reduction loop depth of 2.");
1041
1042 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
1043 contractOp.getLhs(), false);
1044 if (failed(srcIndxLhs))
1045 return rewriter.notifyMatchFailure(contractOp,
1046 "The LHS src is not a MemRef type.");
1047 auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
1048
1049 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
1050 contractOp.getRhs(), false);
1051 if (failed(srcIndxRhs))
1052 return rewriter.notifyMatchFailure(contractOp,
1053 "The RHS src is not a MemRef type.");
1054 auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
1055 Operation *vectorOpLhs;
1056 llvm::TypeSwitch<Operation *>(contractOp.getLhs().getDefiningOp())
1057 .Case<TransferReadOp, LoadOp>([&](auto readOp) {
1058 vectorOpLhs = readOp.getBase().getDefiningOp();
1059 });
1060
1061 Operation *vectorOpRhs;
1062 llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
1063 .Case<TransferReadOp, LoadOp>([&](auto readOp) {
1064 vectorOpRhs = readOp.getBase().getDefiningOp();
1065 });
1066
1067 // Retrieve all the contraction operations within the loop.
1068 SmallVector<vector::ContractionOp> ops;
1069 for (mlir::Operation &op : loopLists[0].getBody()->getOperations()) {
1070
1071 if (auto contract = llvm::dyn_cast<mlir::vector::ContractionOp>(op)) {
1072
1073 LogicalResult validate = validateContractOps(
1074 rewriter, contract, dimValue, srcBuffLhs, srcBuffRhs, true);
1075
1076 if (failed(validate))
1077 return rewriter.notifyMatchFailure(
1078 contractOp,
1079 "The associated contract operations doesn't satisfy "
1080 "the re-write conditions either the dimensions are "
1081 "wrong or MemRef source are different or many users.");
1082
1083 ops.push_back(contract);
1084 }
1085 }
1086
1087 if (!isVnni) {
1088 unsigned int pairCount = 0;
1089 for (size_t j = 0; j < ops.size(); j++) {
1090 for (size_t i = j; i < ops.size(); i++) {
1091 if (i != j && validatePairVectorContract(ops[j], ops[i], true, 16))
1092 pairCount = pairCount + 2;
1093 }
1094 }
1095
1096 if (pairCount != ops.size())
1097 return rewriter.notifyMatchFailure(
1098 contractOp, "Coudn't find the pair vector contract ");
1099 }
1100
1101 scf::ForOp innerLoop;
1102 scf::ForOp outerLoop;
1103
1104 scf::ForOp newLoop;
1105 // Case 2a: Reduction loop depth is 2.
1106 if (loopLists.size() == 2) {
1107 outerLoop = loopLists[1];
1108 innerLoop = loopLists[0];
1109
1110 LogicalResult validateOuterLoopStep =
1111 validateLoopStep(rewriter, outerLoop.getStep(), 1);
1112 if (failed(validateOuterLoopStep))
1113 return rewriter.notifyMatchFailure(contractOp, "Invalid loop step.");
1114
1115 int64_t stepValue = 16;
1116 if (!isVnni)
1117 stepValue = stepValue * blockingFactor;
1118 LogicalResult validateInnerLoopStep =
1119 validateLoopStep(rewriter, innerLoop.getStep(), stepValue);
1120 if (failed(validateInnerLoopStep))
1121 return rewriter.notifyMatchFailure(
1122 contractOp, "Invalid loop step. The step should be 32 for BF16 and "
1123 "64 for Int8/F8.");
1124
1125 SmallVector<Value> loopItrArgs = createTileZeros(
1126 rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
1127
1128 if (isVnni) {
1129 newLoop = scf::ForOp::create(
1130 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
1131 outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
1132 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1133 Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
1134 auto newInnerLoop = createLoops(
1135 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1136 innerLoop.getUpperBound(), innerLoop.getStep(),
1137 iterArgsOuterLoop, ipType, opType, blockingFactor, isVnni,
1138 vectorOpLhs, vectorOpRhs, contractOp, outerLoop, innerLoop,
1139 ops, ivOuterLoop, nullptr, true, nullptr, false, false);
1140
1141 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1142 newInnerLoop.getResults());
1143 });
1144
1145 } else {
1146
1147 bool isInnerLoopUBLarger = false;
1148 bool isInnerLoopUBHasOddQuot = false;
1149
1150 int64_t ubVal = 16 * blockingFactor;
1151 mlir::Value ub = innerLoop.getUpperBound();
1152 if (auto constOp = ub.getDefiningOp<mlir::arith::ConstantOp>()) {
1153 if (auto intAttr =
1154 llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
1155 ubVal = intAttr.getInt();
1156 }
1157 }
1158
1159 isInnerLoopUBLarger = ubVal > 16 * blockingFactor;
1160 isInnerLoopUBHasOddQuot =
1161 (((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
1162
1163 rewriter.setInsertionPoint(outerLoop);
1164
1165 auto c0 =
1166 arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
1167 auto c1 =
1168 arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
1169 auto spillLoopBound = arith::ConstantIndexOp::create(
1170 rewriter, outerLoop.getLoc(), 16 * blockingFactor);
1171
1172 Value spillOuterLoop = arith::SubIOp::create(
1173 rewriter, outerLoop.getLoc(), outerLoop.getUpperBound(), c1);
1174 Value spillInnerLoop =
1175 arith::SubIOp::create(rewriter, innerLoop.getLoc(),
1176 innerLoop.getUpperBound(), spillLoopBound);
1177 auto bufferType =
1178 MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
1179 auto packedBuffer =
1180 memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
1181
1182 // First Shuffling outside the reduction loops
1183 IRMapping rhsMapping;
1184 rhsMapping.map(
1185 vectorOpRhs->getOperand(
1186 getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
1187 outerLoop.getLowerBound());
1188 rhsMapping.map(
1189 vectorOpRhs->getOperand(
1190 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
1191 innerLoop.getLowerBound());
1192 auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
1193
1194 Value quotient_batch = arith::DivUIOp::create(
1195 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
1196 outerLoop.getStep());
1197 Value quotient_k = arith::DivUIOp::create(rewriter, outerLoop.getLoc(),
1198 innerLoop.getLowerBound(),
1199 innerLoop.getStep());
1200
1201 Value quotient_add = arith::AddIOp::create(rewriter, outerLoop.getLoc(),
1202 quotient_batch, quotient_k);
1203 Value c2 =
1204 arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 2);
1205 Value rem = arith::RemUIOp::create(rewriter, outerLoop.getLoc(),
1206 quotient_add, c2);
1207
1208 performShuffle(rewriter, outerLoop.getLoc(), rhsClone->getResult(0),
1209 ipType, blockingFactor, packedBuffer, rem);
1210
1211 // First Set of Loops
1212 auto newLoopNonSpill = scf::ForOp::create(
1213 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
1214 spillOuterLoop, outerLoop.getStep(), loopItrArgs,
1215 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1216 Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
1217 auto newInnerLoop1 = createLoops(
1218 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1219 spillInnerLoop, innerLoop.getStep(), iterArgsOuterLoop,
1220 ipType, opType, blockingFactor, isVnni, vectorOpLhs,
1221 vectorOpRhs, contractOp, outerLoop, innerLoop, ops,
1222 ivOuterLoop, packedBuffer, true, spillLoopBound,
1223 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1224
1225 auto newInnerLoop = createLoops(
1226 rewriter, innerLoop.getLoc(), spillInnerLoop,
1227 innerLoop.getUpperBound(), innerLoop.getStep(),
1228 newInnerLoop1.getResults(), ipType, opType, blockingFactor,
1229 isVnni, vectorOpLhs, vectorOpRhs, contractOp, outerLoop,
1230 innerLoop, ops, ivOuterLoop, packedBuffer, true, c0,
1231 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1232
1233 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1234 newInnerLoop.getResults());
1235 });
1236
1237 // Last set of Loops
1238 newLoop = scf::ForOp::create(
1239 rewriter, outerLoop.getLoc(), spillOuterLoop,
1240 outerLoop.getUpperBound(), outerLoop.getStep(),
1241 newLoopNonSpill.getResults(),
1242 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1243 Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
1244 auto newInnerLoop1 = createLoops(
1245 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1246 spillInnerLoop, innerLoop.getStep(), iterArgsOuterLoop,
1247 ipType, opType, blockingFactor, isVnni, vectorOpLhs,
1248 vectorOpRhs, contractOp, outerLoop, innerLoop, ops,
1249 ivOuterLoop, packedBuffer, true, spillLoopBound,
1250 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1251
1252 auto newInnerLoop = createLoops(
1253 rewriter, innerLoop.getLoc(), spillInnerLoop,
1254 innerLoop.getUpperBound(), innerLoop.getStep(),
1255 newInnerLoop1.getResults(), ipType, opType, blockingFactor,
1256 isVnni, vectorOpLhs, vectorOpRhs, contractOp, outerLoop,
1257 innerLoop, ops, ivOuterLoop, packedBuffer, false, c0,
1258 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1259
1260 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1261 newInnerLoop.getResults());
1262 });
1263 }
1264 }
1265
1266 // Case 2b: Reduction loop depth is 1.
1267 if (loopLists.size() == 1) {
1268
1269 innerLoop = loopLists[0];
1270 int64_t stepValue = 16;
1271 if (!isVnni)
1272 stepValue = stepValue * blockingFactor;
1273
1274 LogicalResult validateInnerLoopStep =
1275 validateLoopStep(rewriter, innerLoop.getStep(), stepValue);
1276 if (failed(validateInnerLoopStep))
1277 return rewriter.notifyMatchFailure(
1278 contractOp,
1279 "Invalid loop step. The step should be 32 for BF16 and "
1280 "64 for Int8/F8 or 1 if it is rduction loop other than K.");
1281
1282 SmallVector<Value> loopItrArgs = createTileZeros(
1283 rewriter, innerLoop.getLoc(), opType, innerLoop, ops.size());
1284
1285 if (isVnni) {
1286 newLoop = createLoops(
1287 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1288 innerLoop.getUpperBound(), innerLoop.getStep(), loopItrArgs, ipType,
1289 opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
1290 contractOp, nullptr, innerLoop, ops, nullptr, nullptr, true,
1291 nullptr, false, false);
1292
1293 } else {
1294
1295 bool isInnerLoopUBLarger = false;
1296 bool isInnerLoopUBHasOddQuot = false;
1297
1298 int64_t ubVal = 16 * blockingFactor;
1299 mlir::Value ub = innerLoop.getUpperBound();
1300 if (auto constOp = ub.getDefiningOp<mlir::arith::ConstantOp>()) {
1301 if (auto intAttr =
1302 llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
1303 ubVal = intAttr.getInt();
1304 }
1305 }
1306
1307 isInnerLoopUBLarger = ubVal > 16 * blockingFactor;
1308 isInnerLoopUBHasOddQuot =
1309 (((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
1310
1311 rewriter.setInsertionPoint(innerLoop);
1312
1313 auto c0 =
1314 arith::ConstantIndexOp::create(rewriter, innerLoop.getLoc(), 0);
1315 int64_t offset = 16 * blockingFactor;
1316 if (auto cst =
1317 innerLoop.getStep().getDefiningOp<arith::ConstantIndexOp>())
1318 offset = cst.value();
1319
1320 auto spillLoopBound = arith::ConstantIndexOp::create(
1321 rewriter, innerLoop.getLoc(), offset);
1322 Value spillInnerLoop =
1323 arith::SubIOp::create(rewriter, innerLoop.getLoc(),
1324 innerLoop.getUpperBound(), spillLoopBound);
1325
1326 auto bufferType =
1327 MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
1328 auto packedBuffer =
1329 memref::AllocaOp::create(rewriter, innerLoop.getLoc(), bufferType);
1330
1331 // First Shuffling outside the reduction loops
1332 IRMapping rhsMapping;
1333 rhsMapping.map(
1334 vectorOpRhs->getOperand(
1335 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
1336 innerLoop.getLowerBound());
1337 auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
1338
1339 Value quotient_k = arith::DivUIOp::create(rewriter, innerLoop.getLoc(),
1340 innerLoop.getLowerBound(),
1341 innerLoop.getStep());
1342 Value c2 =
1343 arith::ConstantIndexOp::create(rewriter, innerLoop.getLoc(), 2);
1344 Value rem = arith::RemUIOp::create(rewriter, innerLoop.getLoc(),
1345 quotient_k, c2);
1346
1347 performShuffle(rewriter, innerLoop.getLoc(), rhsClone->getResult(0),
1348 ipType, blockingFactor, packedBuffer, rem);
1349
1350 auto newLoopNonSpill = createLoops(
1351 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1352 spillInnerLoop, innerLoop.getStep(), loopItrArgs, ipType, opType,
1353 blockingFactor, isVnni, vectorOpLhs, vectorOpRhs, contractOp,
1354 nullptr, innerLoop, ops, nullptr, packedBuffer, true,
1355 spillLoopBound, isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1356
1357 newLoop = createLoops(rewriter, innerLoop.getLoc(), spillInnerLoop,
1358 innerLoop.getUpperBound(), innerLoop.getStep(),
1359 newLoopNonSpill.getResults(), ipType, opType,
1360 blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
1361 contractOp, nullptr, innerLoop, ops, nullptr,
1362 packedBuffer, false, c0, isInnerLoopUBLarger,
1363 isInnerLoopUBHasOddQuot);
1364 }
1365
1366 // This lets the final store back to the acc use the same code for
1367 // both reduction loop depths (1 or 2).
1368 outerLoop = innerLoop;
1369 }
1370
1371 // Copy the amx tile accumulation results to a MemRef buffer, add the
1372 // initial accumulation value, and store back to the C-Matrix
1373 Location loc = outerLoop.getLoc();
1374 Value srcBuffAcc;
1375 SmallVector<Value> indicesAcc;
1376
1377 llvm::TypeSwitch<Operation *>(accReadOp).Case<TransferReadOp, LoadOp>(
1378 [&](auto readOp) {
1379 srcBuffAcc = readOp.getOperand(0);
1380
1381 auto indices = readOp.getIndices();
1382 indicesAcc.reserve(indices.size());
1383
1384 llvm::transform(indices, std::back_inserter(indicesAcc),
1385 [&](OpFoldResult ofr) {
1387 rewriter, loc, ofr);
1388 });
1389 });
1390
1391 auto outputShapes =
1392 mlir::cast<mlir::MemRefType>(srcBuffAcc.getType()).getShape();
1393 unsigned int M = outputShapes[outputShapes.size() - 2];
1394 unsigned int N = outputShapes[outputShapes.size() - 1];
1395
1396 SmallVector<Value> dps = newLoop.getResults();
1397 auto bufferType = MemRefType::get({M, N}, opType);
1398 auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
1399
1400 // Store the amx tiled-dot product output into an MxN memref.
1401 for (unsigned int i = 0, k = 0; i < M; i = i + 16) {
1402 for (unsigned int j = 0; j < N; j = j + 16) {
1403 Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
1404 Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
1405 amx::TileStoreOp::create(rewriter, loc, resultBuffer,
1406 ValueRange{indexOp_i, indexOp_j}, dps[k]);
1407 k++;
1408 }
1409 }
1410 auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
1411 auto c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
1412 auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
1413 auto nBound = arith::ConstantIndexOp::create(rewriter, loc, N);
1414
1415 // Create a loop that iterates over the MxN memref, retrieves two rows,
1416 // shuffles them, adds the C element values, and stores them to the temp buffer.
1417 scf::ForOp::create(
1418 rewriter, loc, c0, nBound, one, ValueRange{},
1419 [&](OpBuilder &nestedBuilder, Location loc, Value iv,
1420 ValueRange iterArgs) {
1421 auto row =
1422 vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
1423 resultBuffer, ValueRange{iv, c0});
1424
1425 auto row2 =
1426 vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
1427 resultBuffer, ValueRange{iv, c16});
1428
1429 Value shuffle1 = row;
1430 Value shuffle2 = row2;
1431
1432 if (!isVnni) {
1433 shuffle1 = vector::ShuffleOp::create(
1434 rewriter, loc, VectorType::get(16, opType), row, row2,
1435 ArrayRef<int64_t>{0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20,
1436 21, 22, 23});
1437
1438 shuffle2 = vector::ShuffleOp::create(
1439 rewriter, loc, VectorType::get(16, opType), row, row2,
1440 ArrayRef<int64_t>{8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15,
1441 28, 29, 30, 31});
1442 }
1443 indicesAcc[indicesAcc.size() - 2] = iv;
1444 indicesAcc[indicesAcc.size() - 1] = c0;
1445
1446 Value valueCRow1 =
1447 vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
1448 srcBuffAcc, indicesAcc);
1449 indicesAcc[indicesAcc.size() - 1] = c16;
1450
1451 Value valueCRow2 =
1452 vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
1453 srcBuffAcc, indicesAcc);
1454
1455 Value addOp;
1456 Value addOp2;
1457
1458 if (ipType.isBF16() || ipType.isF8E5M2() || ipType.isF8E4M3FN()) {
1459 addOp = arith::AddFOp::create(rewriter, loc, shuffle1, valueCRow1);
1460
1461 addOp2 = arith::AddFOp::create(rewriter, loc, shuffle2, valueCRow2);
1462 }
1463
1464 if (ipType.isSignlessInteger(8)) {
1465 addOp = arith::AddIOp::create(rewriter, loc, shuffle1, valueCRow1);
1466
1467 addOp2 = arith::AddIOp::create(rewriter, loc, shuffle2, valueCRow2);
1468 }
1469
1470 vector::StoreOp::create(rewriter, loc, addOp, resultBuffer,
1471 ValueRange{iv, c0});
1472 vector::StoreOp::create(rewriter, loc, addOp2, resultBuffer,
1473 ValueRange{iv, c16});
1474
1475 scf::YieldOp::create(nestedBuilder, loc);
1476 });
1477
1478 SmallVector<Value> writeResults;
1479 for (unsigned int i = 0; i < M; i = i + 16) {
1480 for (unsigned int j = 0; j < N; j = j + 16) {
1481 Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
1482 Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
1483
1484 auto vectorType = mlir::VectorType::get({16, 16}, opType);
1485
1486 int64_t srcRank =
1487 (dyn_cast<ShapedType>(resultBuffer.getType())).getRank();
1488 Value padding = ub::PoisonOp::create(rewriter, loc, opType);
1489 auto map = AffineMap::getMinorIdentityMap(srcRank, vectorType.getRank(),
1490 rewriter.getContext());
1491 SmallVector<bool> inBounds(vectorType.getRank(), true);
1492
1493 auto vec1 = vector::TransferReadOp::create(
1494 rewriter, loc, vectorType, resultBuffer,
1495 ValueRange{indexOp_i, indexOp_j}, padding, map, inBounds);
1496 writeResults.push_back(vec1);
1497 }
1498 }
1499
1500 // Replace use of vector.contract with dot-products.
1501 for (size_t i = 0; i < ops.size(); i++) {
1502 vector::ContractionOp contOp = ops[i];
1503 Value vecRow = writeResults[i];
1504
1505 Value resultWriteOp = contractionUsersAfterYield(contOp.getResult());
1506 if (auto vecType = llvm::dyn_cast<VectorType>(resultWriteOp.getType()))
1507 vecRow = mlir::vector::ShapeCastOp::create(rewriter, loc, vecType,
1508 writeResults[i]);
1509
1510 rewriter.replaceAllUsesWith(resultWriteOp, vecRow);
1511 }
1512
1513 return success();
1514 }
1515};
1516
1517} // namespace
1518
1520 RewritePatternSet &patterns) {
1521 patterns.add<VectorContractToAMXDotProduct>(patterns.getContext());
1522}
return success()
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc, Value input, int64_t firstDimToCollapse)
Creates a memref.collapse_shape collapsing all inner dimensions of the input starting at firstDimToCo...
#define rem(a, b)
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
FloatType getF32Type()
Definition Builders.cpp:47
FloatType getF8E5M2Type()
Definition Builders.cpp:39
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
FloatType getBF16Type()
Definition Builders.cpp:41
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
FloatType getF8E4M3FNType()
Definition Builders.cpp:37
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:567
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Value getOperand(unsigned idx)
Definition Operation.h:375
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:230
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
unsigned getNumOperands()
Definition Operation.h:371
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:255
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF8E5M2() const
Definition Types.cpp:45
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:66
bool isF8E4M3FN() const
Definition Types.cpp:44
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
unsigned getNumUses() const
This method computes the number of uses of this Value.
Definition Value.cpp:52
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
use_iterator use_begin() const
Definition Value.h:184
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:114
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
mlir::x86::AMXTileType TileType
Definition X86Dialect.h:40
Operation * traceToVectorWriteLikeUserOperation(Value v)
Definition X86Utils.cpp:194
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Definition X86Utils.cpp:42
Operation * traceToVectorReadLikeParentOperation(Value v)
Definition X86Utils.cpp:154
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
Definition X86Utils.cpp:352
void populateVectorContractToAMXDotProductPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.