MLIR 23.0.0git
VectorContractToAMXDotProduct.cpp
Go to the documentation of this file.
1//===- VectorContractToAMXDotProduct.cpp ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15
17#include "mlir/IR/Dominance.h"
19#include "llvm/Support/Casting.h"
20
21#include "mlir/Pass/Pass.h"
23
24using namespace mlir;
25using namespace mlir::vector;
26using namespace mlir::x86;
27
28namespace {
29
30// Collapses the last two dimensions of `input` (vnni and k) into a single
31// dimension to help amx.tile_load correctly load the packed element type.
// All leading dimensions are preserved one-to-one; a rank-1 input is returned
// unchanged. The result is produced by a memref.collapse_shape op.
// NOTE(review): this listing omits source lines 32 and 40 (presumably the
// start of the signature and the declaration of `reassociation`) — confirm
// against the original file.
33 Value input) {
34 ShapedType inputType = cast<ShapedType>(input.getType());
35 int64_t firstDimToCollapse = inputType.getRank() - 2;
36
 // Nothing to collapse for a rank-1 input.
37 if (inputType.getRank() == 1)
38 return input;
39
 // Each dimension before the last two forms its own reassociation group.
41 for (int64_t i = 0; i < firstDimToCollapse; ++i)
42 reassociation.push_back(ReassociationIndices{i});
43
 // The remaining (last two) dimensions are merged into one group.
44 ReassociationIndices collapsedIndices;
45 for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
46 collapsedIndices.push_back(i);
47
48 reassociation.push_back(collapsedIndices);
49 return memref::CollapseShapeOp::create(builder, loc, input, reassociation);
50}
51
52// Get the MemRef source and offset indices for an operand of
53// vector.contract.
//
// The operand must be defined by a vector.transfer_read or vector.load;
// operand 0 of that op is taken as the source buffer and its indices are
// materialized as Values. Returns failure when the operand has no defining op
// or no source buffer was recorded.
//
// When `isNotAcc` is true, the last index is dropped and the source buffer is
// passed through collapseInnerDims before being returned.
// NOTE(review): this listing omits source lines 62-63 and 76 (presumably the
// declaration of `indexVals`, the head of the TypeSwitch, and the declaration
// of `indices`) — confirm against the original file.
54static FailureOr<std::pair<Value, SmallVector<Value>>>
55getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
56 bool isNotAcc) {
57 Operation *defOp = operand.getDefiningOp();
58 if (!defOp)
59 return failure();
60
61 Value srcBuff;
64 .Case<TransferReadOp, LoadOp>([&](auto readOp) {
65 indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
66 readOp.getIndices().end());
67 srcBuff = readOp.getOperand(0);
68 });
69
 // `srcBuff` stays null when the defining op is not a supported read op.
70 if (!srcBuff)
71 return failure();
72
 // Drop the trailing index for non-accumulator operands; the buffer itself
 // is collapsed further below.
73 if (isNotAcc)
74 indexVals.pop_back();
75
77 indices.reserve(indexVals.size());
78
79 for (OpFoldResult ofr : indexVals) {
80 indices.push_back(
81 mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
82 }
83
84 if (isNotAcc) {
85 srcBuff = collapseInnerDims(rewriter, loc, srcBuff);
86 }
87
88 return std::make_pair(srcBuff, indices);
89}
90
91// Function to validate the vector.contract operation.
//
// Checks performed:
//  * If `srcValidate` is true, the source buffers of the LHS and RHS operands
//    (as returned by getSrcIndxValue with isNotAcc = false) must be equal to
//    `srcBuffLhs` and `srcBuffRhs` respectively.
//  * The accumulator must be a vector whose dims are all 16 or 1.
//  * LHS and RHS must each have exactly one dim that is neither 16 nor 1,
//    and that dim must equal `blockingFactor`.
// Returns success() only if every check passes.
92static LogicalResult validateContractOps(OpBuilder &rewriter,
93 vector::ContractionOp contractOp,
94 unsigned int blockingFactor,
95 Value srcBuffLhs, Value srcBuffRhs,
96 bool srcValidate) {
97
98 if (srcValidate) {
99 // Get the MemRef buffer of LHS operand.
100 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
101 contractOp.getLhs(), false);
102 if (failed(srcIndxLhs))
103 return failure();
104 auto [buffLhs, indicesLhs] = *srcIndxLhs;
105
106 // Get the MemRef buffer of RHS operand.
107 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
108 contractOp.getRhs(), false);
109 if (failed(srcIndxRhs))
110 return failure();
111 auto [buffRhs, indicesRhs] = *srcIndxRhs;
112
113 // Return failure if the MemRef buffers don't match.
114 if (buffLhs != srcBuffLhs)
115 return failure();
116
117 if (buffRhs != srcBuffRhs)
118 return failure();
119 }
120
121 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
122 if (!accTy)
123 return failure();
124
125 // The Accumulator dims should be 16 or 1. Like <1x16x16> or <16x16>.
 // Despite its name, `nonUnitDimAcc` collects the dims that are neither 16
 // nor 1; any such dim makes the accumulator invalid.
126 ArrayRef<int64_t> accShape = accTy.getShape();
127 llvm::SmallVector<int64_t> nonUnitDimAcc;
128 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
129 [](int64_t dim) { return (dim != 16 && dim != 1); });
130
131 if (nonUnitDimAcc.size() != 0)
132 return failure();
133
134 // The LHS dims should be 16 or vnni or 1. Like <1x16x16x2> or
135 // <16x16x4>. The vnni dims should be 2 or 4.
 // Exactly one dim may differ from 16 and 1, and it must be the blocking
 // factor passed in by the caller.
136 VectorType lhsTy = contractOp.getLhsType();
137 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
138 llvm::SmallVector<int64_t> nonUnitDimLhs;
139 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
140 [](int64_t dim) { return (dim != 16 && dim != 1); });
141
142 if (nonUnitDimLhs.size() != 1)
143 return failure();
144
145 if (nonUnitDimLhs[0] != blockingFactor)
146 return failure();
147
148 // The RHS dims should be 16 or vnni or 1. Like <1x16x16x2> or
149 // <16x16x4>. The vnni dims should be 2 or 4.
 // Same rule as for the LHS.
150 VectorType rhsTy = contractOp.getRhsType();
151 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
152 llvm::SmallVector<int64_t> nonUnitDimRhs;
153 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
154 [](int64_t dim) { return (dim != 16 && dim != 1); });
155
156 if (nonUnitDimRhs.size() != 1)
157 return failure();
158
159 if (nonUnitDimRhs[0] != blockingFactor)
160 return failure();
161
162 return success();
163}
164
165// Returns the position, within the offsets of the memref.subview that feeds
166// `operand`'s read op, of the offset equal to `loop`'s induction variable.
// The callers use this position when remapping operands during the clone of
// the MemRef-producing op.
//
// Returns 0 when the source buffer is not defined by a memref.subview or when
// no offset matches the induction variable. Position 0 is also a valid match,
// so callers cannot distinguish "found at 0" from "not found".
// NOTE(review): this listing omits source line 171 (presumably the head of
// the TypeSwitch on the operand's defining op). If no case matches, `srcBuff`
// stays null before `getDefiningOp` is called on it — confirm that callers
// guarantee a transfer_read/load operand.
167static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
168 Value iv = loop.getInductionVar();
169
170 Value srcBuff;
172 .Case<TransferReadOp, LoadOp>(
173 [&](auto readOp) { srcBuff = readOp.getOperand(0); });
174
175 auto subview = srcBuff.getDefiningOp<memref::SubViewOp>();
176 if (!subview)
177 return 0;
178
179 auto offsets = subview.getOffsets();
180
 // Linear search for the offset that is the loop induction variable.
181 for (auto it : llvm::enumerate(offsets)) {
182 if (it.value() == iv)
183 return it.index();
184 }
185
186 return 0;
187}
188
189// Creates an amx.tile_load from `mat` using the indices of the read op that
// defines `operand`.
//
// The last index of the read is dropped. For the RHS (`rhs` == true) the last
// remaining index is additionally multiplied by `offset`. The loaded tile has
// type !amx.tile<16 x (16 * offset) x ipType>.
// NOTE(review): the FailureOr returned by getSrcIndxValue is dereferenced
// without a failed() check; this presumably relies on the operand having been
// validated earlier (see validateContractOps with srcValidate) — confirm.
190static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
191 Value operand, Value mat, Type ipType,
192 bool rhs, unsigned int offset) {
193
194 auto srcIndx = getSrcIndxValue(rewriter, loc, operand, false);
195 auto [srcBuff, indices] = *srcIndx;
 // Only the indices are used; the tile is loaded from `mat`, not `srcBuff`.
196 indices.pop_back();
197
198 if (rhs) {
 // Scale the last remaining index by `offset` for the RHS operand.
199 auto cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
200 indices[indices.size() - 1] = arith::MulIOp::create(
201 rewriter, loc, indices[indices.size() - 1], cOffset);
202 }
203
204 amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
205 return amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
206}
207
208// Creates tiled amx dot-products, one per contraction op in `ops`.
//
// `matA` and `matB` are collapsed with collapseInnerDims and used as the
// sources of the LHS and RHS tile loads. The i-th dot-product accumulates
// into `accIterArgs[i]`. The op created depends on `ipType`: amx.tile_mulf
// for BF16, amx.tile_muli for signless 8-bit integers. Returns the resulting
// accumulator values in the same order as `ops`.
// NOTE(review): this listing omits source lines 210 and 221 (presumably the
// `ops` parameter and the declaration of `readsToTileLoads`) — confirm against
// the original file.
209static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
211 Value matA, Value matB, Type ipType,
212 Type opType, ValueRange accIterArgs,
213 unsigned int offset) {
214
215 auto subviewCollapseLhs = collapseInnerDims(rewriter, loc, matA);
216 auto subviewCollapseRhs = collapseInnerDims(rewriter, loc, matB);
217
218 SmallVector<Value> accumulators;
219 // Maps each vector transfer_read or load operation to its equivalent
220 // amx.tile_load, so that a read shared by several contractions is loaded once.
222
223 // Iterate over the contraction operations and compute the tiled dot-product.
224 for (size_t i = 0; i < ops.size(); i++) {
225
 // Reuse the tile load for this LHS read if one was already created.
226 Operation *readOpLhs = ops[i].getLhs().getDefiningOp();
227 amx::TileLoadOp tilesLhs;
228 auto itLhs = readsToTileLoads.find(readOpLhs);
229 if (itLhs != readsToTileLoads.end()) {
230 tilesLhs = itLhs->second;
231 } else {
232 tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(),
233 subviewCollapseLhs, ipType, false, offset);
234 readsToTileLoads.try_emplace(readOpLhs, tilesLhs);
235 }
236
 // Same caching scheme for the RHS read.
237 Operation *readOpRhs = ops[i].getRhs().getDefiningOp();
238 amx::TileLoadOp tilesRhs;
239 auto itRhs = readsToTileLoads.find(readOpRhs);
240 if (itRhs != readsToTileLoads.end()) {
241 tilesRhs = itRhs->second;
242 } else {
243 tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(),
244 subviewCollapseRhs, ipType, true, offset);
245 readsToTileLoads.try_emplace(readOpRhs, tilesRhs);
246 }
247
248 auto accTileType = amx::TileType::get({16, 16}, opType);
249
 // `dp` remains null if `ipType` is neither BF16 nor signless i8.
250 Value dp;
251 if (ipType.isBF16())
252 dp = amx::TileMulFOp::create(rewriter, loc, accTileType, tilesLhs,
253 tilesRhs, accIterArgs[i]);
254
255 if (ipType.isSignlessInteger(8))
256 dp = amx::TileMulIOp::create(rewriter, loc, accTileType, tilesLhs,
257 tilesRhs, accIterArgs[i]);
258
259 accumulators.push_back(dp);
260 }
261 return accumulators;
262}
263
// Creates `size` amx.tile_zero ops of type !amx.tile<16x16xopType> and returns
// their results.
//
// Side effect: moves the insertion point of `rewriter` to just before
// `outerLoop`, so the zero tiles (and whatever the caller creates next) are
// placed ahead of that loop.
264static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
265 Type opType, scf::ForOp outerLoop,
266 int64_t size) {
267 rewriter.setInsertionPoint(outerLoop);
268
269 SmallVector<Value> loopItrArgs;
270 auto zeroTileType = amx::TileType::get({16, 16}, opType);
271
272 for (int i = 0; i < size; i++) {
273 auto zeroTile = amx::TileZeroOp::create(rewriter, loc, zeroTileType);
274 loopItrArgs.push_back(zeroTile);
275 }
276 return loopItrArgs;
277}
278
279// Implements tiled dot-product operation for a vector.contract operation or a
280// sequence of vector.contracts inside the reduction loops.
281//
282// For example - for Int8 inputs accumulating into Int32:
283// ```
284// vector.transfer_read %arg0 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
285// vector.transfer_read %arg1 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
286// vector.contract <16x16x4xi8>, <16x16x4xi8> into <16x16xi32>
287// vector.transfer_write arg2 {{.}*} : vector<16x16xi32>, memref<32x32xi32>
288// ```
289// to
290// ```
291// amx.tile_load %arg0 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
292// amx.tile_load %arg1 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
293// amx.tile_muli !amx.tile<16x64xi8> -> !amx.tile<16x16xi32>
294// amx.tile_store %arg2{{.}*} : memref<32x32xi32>, !amx.tile<16x16xi32>
295// ```
//
// Two situations are handled by matchAndRewrite:
//  * Case 1: the accumulator read, the contraction and the result write are
//    all in the same block; a single tile load / multiply / store sequence is
//    emitted.
//  * Case 2: the accumulator is read outside one or two enclosing scf.for
//    loops; new loops carrying AMX tiles as iter_args are created, and the
//    result is added to the original accumulator after the loops.
296struct VectorContractToAMXDotProduct
297 : public OpRewritePattern<vector::ContractionOp> {
298 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
299
300 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
301 PatternRewriter &rewriter) const override {
302
303 if (contractOp.getKind() != vector::CombiningKind::ADD)
304 return rewriter.notifyMatchFailure(contractOp,
305 "Expects add combining kind.");
306
 // The VNNI blocking factor is 2 for BF16 and 4 otherwise; unsupported
 // element types are rejected right below.
307 unsigned int blockingFactor =
308 contractOp.getLhsType().getElementType().isBF16() ? 2 : 4;
309 bool isVnni =
310 isInVnniLayout(contractOp.getOperation(),
311 contractOp.getIndexingMapsArray(), blockingFactor);
312
313 VectorType lhsTy = contractOp.getLhsType();
314 if (!lhsTy.getElementType().isBF16() &&
315 !lhsTy.getElementType().isSignlessInteger(8))
316 return rewriter.notifyMatchFailure(
317 contractOp, "Only BF16/Int8 lowering is supported.");
318
319 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
320 if (!accTy)
321 return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
322
323 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
324 (lhsTy.getElementType().isSignlessInteger(8) &&
325 !accTy.getElementType().isSignlessInteger(32)))
326 return rewriter.notifyMatchFailure(contractOp,
327 "Only F32 for BF16 or Int32 for Int8 "
328 "accumulation type is supported.");
329 if (!isVnni)
330 return rewriter.notifyMatchFailure(
331 contractOp, "Only VNNI-packed inputs are supported.");
332
 // Locate the read that produces the accumulator and the write that
 // consumes the result; both must exist.
333 Operation *accReadOp =
334 traceToVectorReadLikeParentOperation(contractOp.getAcc());
335
336 Operation *resultWriteOp =
337 traceToVectorWriteLikeUserOperation(contractOp.getResult());
338
339 if (!accReadOp || !resultWriteOp)
340 return rewriter.notifyMatchFailure(
341 contractOp, "The ACC operand of the vector.contract should be a "
342 "transfer_read or a load. And, the result should be "
343 "stored using transfer_write or store.");
344
 // Input/output element types of the AMX tiles: BF16 -> F32 by default,
 // Int8 -> Int32 when the LHS element type is a signless 8-bit integer.
345 Type ipType = rewriter.getBF16Type();
346 Type opType = rewriter.getF32Type();
347
348 if (lhsTy.getElementType().isSignlessInteger(8)) {
349 ipType = rewriter.getIntegerType(8);
350 opType = rewriter.getIntegerType(32);
351 }
352
 // The accumulator read and the result write must either both be in the
 // contraction's block or both be outside of it.
353 if (accReadOp->getBlock() == contractOp->getBlock() &&
354 resultWriteOp->getBlock() != contractOp->getBlock())
355 return rewriter.notifyMatchFailure(
356 contractOp, "The accumulator store is in different block.");
357
358 if (accReadOp->getBlock() != contractOp->getBlock() &&
359 resultWriteOp->getBlock() == contractOp->getBlock())
360 return rewriter.notifyMatchFailure(
361 contractOp, "The accumulator read is in different block.");
362
363 // Case 1: Rewrite of a single vector.contract, where the accumulator
364 // read and write are within the same block as the contraction.
365 if (accReadOp->getBlock() == contractOp->getBlock() &&
366 resultWriteOp->getBlock() == contractOp->getBlock()) {
367
368 LogicalResult validate = validateContractOps(
369 rewriter, contractOp, blockingFactor, Value(), Value(), false);
370
371 if (failed(validate))
372 return rewriter.notifyMatchFailure(
373 contractOp, "The contract operation doesn't satisfy the operands "
374 "dimensions. M, N, and vnni dims are 16, 16, and 2/4. "
375 "The rest dims should be 1.");
376
377 Location loc = contractOp.getLoc();
378
 // LHS/RHS: isNotAcc = true, so the last index is dropped and the source
 // buffer is collapsed. ACC: buffer and indices are used as-is.
379 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
380 contractOp.getLhs(), true);
381 if (failed(srcIndxLhs))
382 return rewriter.notifyMatchFailure(contractOp,
383 "The LHS src is not a MemRef type.");
384 auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
385
386 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
387 contractOp.getRhs(), true);
388 if (failed(srcIndxRhs))
389 return rewriter.notifyMatchFailure(contractOp,
390 "The RHS src is not a MemRef type.");
391 auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
392
393 auto srcIndxAcc = getSrcIndxValue(rewriter, contractOp.getLoc(),
394 contractOp.getAcc(), false);
395 if (failed(srcIndxAcc))
396 return rewriter.notifyMatchFailure(contractOp,
397 "The ACC src is not a MemRef type.");
398 auto [srcBuffAcc, indicesAcc] = *srcIndxAcc;
399
400 // amx.tile_loads
401 auto tileType = amx::TileType::get({16, (16 * blockingFactor)}, ipType);
402 auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
403 srcBuffLhs, indicesLhs);
404 auto loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType,
405 srcBuffRhs, indicesRhs);
406
407 auto tileTypeAcc = amx::TileType::get({16, 16}, opType);
408 auto loadAcc = amx::TileLoadOp::create(rewriter, loc, tileTypeAcc,
409 srcBuffAcc, indicesAcc);
410
411 // Tiled dot-product.
412 Value dp;
413 if (ipType.isBF16())
414 dp = amx::TileMulFOp::create(rewriter, loc, tileTypeAcc, loadLhs,
415 loadRhs, loadAcc);
416
417 if (ipType.isSignlessInteger(8))
418 dp = amx::TileMulIOp::create(rewriter, loc, tileTypeAcc, loadLhs,
419 loadRhs, loadAcc);
420
 // Store the result tile back to the accumulator buffer at the same
 // indices it was loaded from.
421 amx::TileStoreOp::create(rewriter, loc, srcBuffAcc, indicesAcc, dp);
422
 // Only the original write is erased here; the contraction itself is not
 // erased by this pattern.
423 rewriter.eraseOp(resultWriteOp);
424 return success();
425 }
426
427 // Case 2: The accumulators are passed as iter args through the reduction
428 // loop. Reduction loops up to a depth of 2 are supported. TODO: Support
429 // n-depth reduction loops.
430 SmallVector<scf::ForOp> loopLists;
431 Operation *current = contractOp;
432
 // Walk outwards through enclosing scf.for ops (innermost first) until the
 // loop that lives in the same block as the accumulator read is reached.
 // NOTE(review): `parent` is not checked for null before being cast and
 // dereferenced; if no enclosing scf.for exists on the way to the
 // accumulator read's block this looks like a null dereference — verify.
433 while (true) {
434 Operation *parent = current->getParentOfType<scf::ForOp>();
435 loopLists.push_back(dyn_cast<scf::ForOp>(parent));
436
437 if (accReadOp->getBlock() == parent->getBlock()) {
438 break;
439 }
440
441 current = parent;
442 }
443
 // The loop above always pushes at least one entry, so only the "> 2" part
 // of this check can trigger.
444 if (loopLists.size() > 2 || loopLists.size() == 0)
445 return rewriter.notifyMatchFailure(
446 contractOp, "Rewrite is supported until reduction loop depth of 2.");
447
448 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
449 contractOp.getLhs(), false);
450 if (failed(srcIndxLhs))
451 return rewriter.notifyMatchFailure(contractOp,
452 "The LHS src is not a MemRef type.");
453 auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
454
455 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
456 contractOp.getRhs(), false);
457 if (failed(srcIndxRhs))
458 return rewriter.notifyMatchFailure(contractOp,
459 "The RHS src is not a MemRef type.");
460 auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
461
 // Ops that define the base buffers of the LHS/RHS reads; they are cloned
 // into the new loops below.
 // NOTE(review): these pointers are left uninitialized if the TypeSwitch
 // does not match, and are null if the base has no defining op (e.g. a
 // block argument); both are dereferenced later — confirm this cannot
 // happen.
462 Operation *vectorOpLhs;
463 llvm::TypeSwitch<Operation *>(contractOp.getLhs().getDefiningOp())
464 .Case<TransferReadOp, LoadOp>([&](auto readOp) {
465 vectorOpLhs = readOp.getBase().getDefiningOp();
466 });
467
468 Operation *vectorOpRhs;
469 llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
470 .Case<TransferReadOp, LoadOp>([&](auto readOp) {
471 vectorOpRhs = readOp.getBase().getDefiningOp();
472 });
473
474 // Retrieve all the contraction operations within the innermost loop. Each
 // one must pass validation and read from the same LHS/RHS buffers as
 // `contractOp`; otherwise the whole rewrite is abandoned.
475 SmallVector<vector::ContractionOp> ops;
476 for (mlir::Operation &op : loopLists[0].getBody()->getOperations()) {
477
478 if (auto contract = llvm::dyn_cast<mlir::vector::ContractionOp>(op)) {
479
480 LogicalResult validate = validateContractOps(
481 rewriter, contract, blockingFactor, srcBuffLhs, srcBuffRhs, true);
482
483 if (failed(validate))
484 return rewriter.notifyMatchFailure(
485 contractOp, "The associated contract operations doesn't satisfy "
486 "the re-write conditions either the dimensions are "
487 "wrong or MemRef source are different.");
488
489 ops.push_back(contract);
490 }
491 }
492
493 scf::ForOp outerLoop;
494 scf::ForOp innerLoop;
495
496 scf::ForOp newLoop;
497 // Case 2a: Reduction loop depth is 2.
498 if (loopLists.size() == 2) {
499 outerLoop = loopLists[1];
500 innerLoop = loopLists[0];
501
 // One zero-initialized tile per contraction, created before the outer
 // loop (createTileZeros also moves the insertion point there).
502 SmallVector<Value> loopItrArgs = createTileZeros(
503 rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
504
 // New loop nest with the same bounds/steps as the original loops and the
 // tiles threaded through as iter_args.
505 newLoop = scf::ForOp::create(
506 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
507 outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
508 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
509 Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
510 auto newInnerLoop = scf::ForOp::create(
511 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
512 innerLoop.getUpperBound(), innerLoop.getStep(),
513 iterArgsOuterLoop,
514 [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
515 Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
 // Clone the LHS base-defining op with the operands that
 // carried the old induction variables remapped to the new
 // ones. The "+ 1" presumably skips operand 0 (the source
 // of the subview) — TODO confirm.
516 IRMapping mapping;
517 mapping.map(
518 vectorOpLhs->getOperand(
519 getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
520 ivOuterLoop);
521 mapping.map(
522 vectorOpLhs->getOperand(
523 getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
524 ivNewInnerLoop);
525 auto lhsClone =
526 rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);
527
 // Same remapping and clone for the RHS.
528 IRMapping rhsMapping;
529 rhsMapping.map(
530 vectorOpRhs->getOperand(
531 getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
532 ivOuterLoop);
533 rhsMapping.map(
534 vectorOpRhs->getOperand(
535 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
536 ivNewInnerLoop);
537 auto rhsClone =
538 rewriterNewInnerLoop.clone(*vectorOpRhs, rhsMapping);
539
540 SmallVector<Value> accumulators = createTiledDp(
541 rewriter, locNewInnerLoop, ops, lhsClone->getResult(0),
542 rhsClone->getResult(0), ipType, opType,
543 iterArgsNewInnerLoop, blockingFactor);
544
545 scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
546 accumulators);
547 });
548
549 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
550 newInnerLoop.getResults());
551 });
552 }
553
554 // Case 2b: Reduction loop depth is 1.
555 if (loopLists.size() == 1) {
556 outerLoop = loopLists[0];
557
558 SmallVector<Value> loopItrArgs = createTileZeros(
559 rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
560 newLoop = scf::ForOp::create(
561 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
562 outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
563 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
564 Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
 // Single-loop variant of the remap-and-clone done in Case 2a.
565 IRMapping mapping;
566 mapping.map(
567 vectorOpLhs->getOperand(
568 getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
569 ivOuterLoop);
570
571 auto lhsClone = rewriterOuterLoop.clone(*vectorOpLhs, mapping);
572
573 IRMapping rhsMapping;
574 rhsMapping.map(
575 vectorOpRhs->getOperand(
576 getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
577 ivOuterLoop);
578
579 auto rhsClone = rewriterOuterLoop.clone(*vectorOpRhs, rhsMapping);
580
581 SmallVector<Value> accumulators = createTiledDp(
582 rewriter, locOuterLoop, ops, lhsClone->getResult(0),
583 rhsClone->getResult(0), ipType, opType, iterArgsOuterLoop,
584 blockingFactor);
585
586 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop, accumulators);
587 });
588 }
589
590 // Post-processing after the loop creation.
591 // Copy the amx tile accumulation results to a MemRef buffer, add the
592 // initial accumulation value, and store back to the C-Matrix.
 // A single 16x16 scratch buffer is shared by all contractions.
593 auto bufferType = MemRefType::get({16, 16}, opType);
594 auto bBuffer =
595 memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
596
597 SmallVector<Value> dps = newLoop.getResults();
598 for (size_t i = 0; i < ops.size(); i++) {
599 vector::ContractionOp contOp = ops[i];
600 Operation *resultWriteOp =
601 traceToVectorWriteLikeUserOperation(contOp.getResult());
 // Emit the replacement code right where the original write was.
602 rewriter.setInsertionPoint(resultWriteOp);
603
604 Value indexOp_0 =
605 arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
606
 // Spill the i-th result tile into the scratch buffer at [0, 0].
607 amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
608 ValueRange{indexOp_0, indexOp_0}, dps[i]);
609
610 auto c0 = arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
611 auto one =
612 arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
613 auto mBound =
614 arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 16);
615
 // Row-by-row (16 rows): load a 16-element row of the tile result, load
 // the matching row of the original accumulator, add them and store the
 // sum back to the accumulator buffer.
616 scf::ForOp::create(
617 rewriter, outerLoop.getLoc(), c0, mBound, one, ValueRange{},
618 [&](OpBuilder &builder, Location loc, Value iv, ValueRange iterArgs) {
619 auto resultAcc = vector::LoadOp::create(
620 rewriter, loc, VectorType::get(16, opType), bBuffer,
621 ValueRange{iv, c0});
622
 // NOTE(review): this listing omits source lines 624 and 639
 // (presumably the call that traces the accumulator read of
 // `contOp`, and the start of the index conversion call) — confirm
 // against the original file.
623 Operation *accReadOp =
625
626 Value srcBuffAcc;
627 SmallVector<Value> indicesAcc;
628
629 llvm::TypeSwitch<Operation *>(accReadOp)
630 .Case<TransferReadOp, LoadOp>([&](auto readOp) {
631 srcBuffAcc = readOp.getOperand(0);
632
633 auto indices = readOp.getIndices();
634 indicesAcc.reserve(indices.size());
635
636 llvm::transform(
637 indices, std::back_inserter(indicesAcc),
638 [&](OpFoldResult ofr) {
640 loc, ofr);
641 });
642 });
643
 // Offset the row index by the loop induction variable.
 // NOTE(review): the sum is computed from indicesAcc[0] but written
 // to indicesAcc[size - 2]; these coincide only for a rank-2
 // accumulator read — verify for higher ranks.
644 Value sum = arith::AddIOp::create(builder, loc, iv, indicesAcc[0]);
645 indicesAcc[indicesAcc.size() - 2] = sum;
646
647 auto acc = vector::LoadOp::create(rewriter, loc,
648 VectorType::get(16, opType),
649 srcBuffAcc, indicesAcc);
 // Floating-point add for BF16 inputs, integer add for Int8 inputs.
650 Value addition;
651 if (ipType.isBF16())
652 addition = arith::AddFOp::create(rewriter, loc, resultAcc, acc);
653
654 if (ipType.isSignlessInteger(8))
655 addition = arith::AddIOp::create(rewriter, loc, resultAcc, acc);
656
657 vector::StoreOp::create(builder, loc, addition, srcBuffAcc,
658 indicesAcc);
659
660 scf::YieldOp::create(builder, outerLoop.getLoc());
661 });
662
663 rewriter.eraseOp(resultWriteOp);
664 }
665
666 return success();
667 }
668};
669
670} // namespace
671
// Registers the VectorContractToAMXDotProduct rewrite pattern in `patterns`,
// using the pattern set's own MLIRContext.
// NOTE(review): this listing omits source line 672 (presumably the start of
// the signature) — confirm against the original file.
673 RewritePatternSet &patterns) {
674 patterns.add<VectorContractToAMXDotProduct>(patterns.getContext());
675}
return success()
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc, Value input, int64_t firstDimToCollapse)
Creates a memref.collapse_shape collapsing all inner dimensions of the input starting at firstDimToCollapse.
FloatType getF32Type()
Definition Builders.cpp:47
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
FloatType getBF16Type()
Definition Builders.cpp:41
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:234
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:259
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:66
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
mlir::x86::AMXTileType TileType
Definition X86Dialect.h:40
Operation * traceToVectorWriteLikeUserOperation(Value v)
Definition X86Utils.cpp:186
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Definition X86Utils.cpp:42
Operation * traceToVectorReadLikeParentOperation(Value v)
Definition X86Utils.cpp:146
void populateVectorContractToAMXDotProductPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...